In [None]:
import sys
import os
sys.path.append(os.getcwd()+"/../..")
from src import paths

import torch

from pydantic import BaseModel, Field
from enum import Enum, auto
from src.utils import (load_model_and_tokenizer, 
                        get_sampler, 
                        get_format_fun, 
                        format_prompt, 
                        get_outlines_generator, 
                        get_pydantic_schema, 
                        outlines_medication_prompting,
                        get_default_pydantic_model)

from outlines import samplers
from outlines.generate import SequenceGenerator
from outlines.samplers import Sampler

import pandas as pd

import json

from datasets import Dataset

from typing import Callable, Union

from pydantic import BaseModel

In [None]:
import importlib
import src.utils

# Pick up edits to src.utils without restarting the kernel.
# importlib.reload only refreshes the module object: names bound earlier via
# `from src.utils import ...` keep pointing at the old functions, so bind
# them again from the reloaded module.
importlib.reload(src.utils)
from src.utils import (load_model_and_tokenizer,
                       get_sampler,
                       get_format_fun,
                       format_prompt,
                       get_outlines_generator,
                       get_pydantic_schema,
                       outlines_medication_prompting,
                       get_default_pydantic_model)

In [None]:
# Load the Llama2-MedTuned-7b checkpoint (4-bit quantized, flash attention 2)
# prepared for use with outlines.
model, tokenizer = load_model_and_tokenizer(
    "Llama2-MedTuned-7b",
    task_type="outlines",
    quantization="4bit",
    attn_implementation="flash_attention_2",
)

In [None]:
# Greedy decoding, constrained to the medication pydantic schema (JSON task).
# `default_model` is the fallback value for prompts whose generation fails.
sampler = get_sampler("greedy")
schema = get_pydantic_schema("medication")
default_model = get_default_pydantic_model("medication")
generator = get_outlines_generator(model=model, sampler=sampler, task="json", schema=schema)

In [None]:
# Sanity check: build a medication schema instance by hand and serialise it.
# The German note means "for 4 days"; it was previously mojibake ("fÃ¼r").
schema(
    medications=[
        {
            "name": "Cipralex",
            "unit": "IE/mmol",
            "amount": 20,
            "morning": 0.5,
            "noon": 0,
            "evening": 0,
            "night": 0,
            "extra": "für 4 Tage",
        }
    ]
).json()

In [None]:
# Load the preprocessed medication sample (a `datasets.Dataset`, despite the name `df`).
df = Dataset.load_from_disk(
    paths.DATA_PATH_PREPROCESSED / "medication/kisim_medication_sample"
)

In [None]:
# Look at different medication formats:
# Dump every medication text with its index to inspect the different input formats.
for position, medication_text in enumerate(df["text"]):
    print("---" * 5)
    print(position)
    print(medication_text)

In [None]:
# Load the prompt components for the medication extraction task.
# The encoding is given explicitly so reading does not depend on the
# platform's locale default (the texts may contain German umlauts).
with open(paths.DATA_PATH_PREPROCESSED/"medication/task_instruction.txt", "r", encoding="utf-8") as f:
    task_instruction = f.read()

with open(paths.DATA_PATH_PREPROCESSED/"medication/system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

with open(paths.DATA_PATH_PREPROCESSED/"medication/examples.json", "r", encoding="utf-8") as f:
    examples = json.load(f)

format_fun = get_format_fun("few_shot_instruction")

In [None]:
# text_input = format_prompt(df["text"], format_fun, system_prompt = system_prompt, task_instruction=task_instruction, examples = examples)

In [None]:
# results, successful = outlines_prompting_to(text = text_input, generator = generator, schema = schema, filename="medication-few-shot-instruction", batch_size = 1, wait_time = 250) 

In [None]:
# Take sample 74 and wrap it with the prompt format selected in `format_fun`.
test = df["text"][74]
test_input = format_prompt(
    [test],
    format_fun,
    system_prompt=system_prompt,
    task_instruction=task_instruction,
    examples=examples,
)

In [None]:
# Show the raw text, then run the structured extraction on its prompt.
print(test)
result, successful = outlines_medication_prompting(
    text=test_input,
    generator=generator,
    max_tokens=1000,
    batch_size=1,
)

In [None]:
# Display the output returned by outlines_medication_prompting in the previous cell.
result

In [None]:
# Watch the generation for the same prompt live, capped at 200 tokens.
for chunk in generator.stream(test_input, max_tokens=200):
    print(chunk, end="")

In [None]:
import torch

# Persist the extraction result. torch.save does not create missing
# directories, so make sure the target folder exists first.
os.makedirs(paths.RESULTS_PATH/"medication", exist_ok=True)
torch.save(result, paths.RESULTS_PATH/"medication/testing.pt")

In [None]:
# The file holds the saved `result`: objects with a `.json()` method
# (presumably pydantic models), not tensors. PyTorch >= 2.6 defaults to
# weights_only=True, which refuses such objects, so disable it explicitly.
# Only do this for trusted files: unpickling can execute arbitrary code.
test_res = torch.load(paths.RESULTS_PATH/"medication/testing.pt", weights_only=False)

In [None]:
# Serialise every reloaded result to JSON for inspection.
list(map(lambda res: res.json(), test_res))

In [None]:
def medication_prompting_to(text: list[str], generator: "SequenceGenerator", default_model: "BaseModel", batch_size: int = 1, wait_time: int = 120) -> "tuple[list[Union[str, BaseModel]], list[bool]]":
    """
    Run an outlines generator over a list of prompts, batch by batch, with a
    per-batch time-out (generation is prone to hang with a complicated schema).

    A batch that does not finish within `wait_time` seconds contributes one
    `default_model` per prompt instead of generated output.

    Args:
        text (list[str]): prompts to generate from.
        generator (outlines.SequenceGenerator): outlines generator; it is
            called with a list of prompts.
        default_model (pydantic.BaseModel): fallback stored for every prompt
            of a batch that times out.
        batch_size (int, optional): prompts per generator call. Defaults to 1.
        wait_time (int, optional): seconds allowed per batch. Defaults to 120.

    Returns:
        tuple[list[Union[str, pydantic.BaseModel]], list[bool]]: one entry per
        prompt (generated output or `default_model`) and, aligned with it,
        whether the generation for that prompt succeeded.

    Note:
        The time-out uses SIGALRM, so this only works on Unix and in the main
        thread. The previous SIGALRM handler is restored on exit.
    """
    import signal

    try:
        from tqdm import tqdm
    except ImportError:  # the progress bar is optional
        def tqdm(iterable, **_kwargs):
            return iterable

    def timeout_handler(signum, frame):
        raise TimeoutError("Timed out")

    # Consecutive, non-shuffled batches (same as DataLoader(shuffle=False) on a list of str).
    batches = [text[start:start + batch_size] for start in range(0, len(text), batch_size)]

    results = []
    successful = []

    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    try:
        for batch in tqdm(batches, desc="Prompting", leave=False):
            try:
                signal.alarm(wait_time)  # Set the timeout
                result = generator(batch)
            except TimeoutError:
                print("Timed out at observation number", len(results))
                results.extend([default_model] * len(batch))
                successful.extend([False] * len(batch))
                continue
            finally:
                signal.alarm(0)

            # For a single prompt the generator returns the output itself, not a
            # one-element list. Extending with a bare pydantic model would add its
            # (field_name, value) pairs (a bare string: its characters), so wrap
            # it. This also covers a trailing batch of size 1 when batch_size > 1.
            if len(batch) == 1:
                result = [result]
            results.extend(result)
            successful.extend([True] * len(batch))
    finally:
        signal.signal(signal.SIGALRM, previous_handler)

    return results, successful

In [None]:
# Run the time-out-guarded batch prompting defined in the previous cell.
# (`outlines_prompting_to` is neither defined nor imported in this notebook.)
results2, successful2 = medication_prompting_to(
    text=test_input,
    generator=generator,
    default_model=default_model,
    batch_size=2,
    wait_time=180,
)

In [None]:
# NOTE(review): `results` is not assigned by any earlier cell (the previous
# cell binds `results2`); this only works with stale kernel state.
results

In [None]:
# Display the outputs of the time-out-guarded run.
results2

In [None]:
# NOTE(review): `results` is not assigned by any earlier cell. The `[0][1]`
# indexing assumes the first entry is a (field_name, value) tuple, which is
# what extending a list with a single pydantic model produces.
type(results[0][1][0])

In [None]:
# medication
# NOTE(review): `MedicationList` is neither defined nor imported in this
# notebook, and `results` is not assigned by an earlier cell. `schema` is
# constructed with the same `medications=` keyword earlier, so it may be what
# was meant here; confirm.

MedicationList(medications = results[0][1])

In [None]:
# Show the fallback model (stored for timed-out prompts) serialised as JSON.
default_model.json()

In [None]:
# NOTE(review): `default_schema` is not defined anywhere in this notebook;
# perhaps `schema` or `default_model` was meant. Confirm.
default_schema

In [None]:
# Inspect an intermediate results file from an earlier run. The timestamp in
# the filename is hard-coded, so this only works where that exact file exists.
torch.load(paths.RESULTS_PATH/"intermediate/intermediate_results1710344797.0171728.pt")

In [None]:
# Serialise each result to JSON and store the list under RESULTS_PATH/medication.
# NOTE(review): `results` is not assigned by an earlier cell; confirm which run is meant.
filename = "medication_test"
results_dump = list(map(lambda res: res.json(), results))
target_path = paths.RESULTS_PATH/"medication"/f"{filename}.pt"
torch.save(results_dump, target_path)

In [None]:
# Display the JSON strings saved by the previous cell.
results_dump

In [None]:
# Build the few-shot prompt for sample 1 and print it.
# `few_shot_instruction` is not defined in this notebook, so the formatting
# function is obtained through get_format_fun, as for `format_fun`.
test = df["text"][1]
test_few_shot = format_prompt(
    [test],
    get_format_fun("few_shot_instruction"),
    system_prompt=system_prompt,
    task_instruction=task_instruction,
    examples=examples,
)
print(test_few_shot[0])

In [None]:
# Echo the raw text, then stream the few-shot generation piece by piece.
print(test)
for piece in generator.stream(test_few_shot):
    print(piece, end="")

In [None]:
# Same sample with the zero-shot instruction format, streamed.
# `zero_shot_instruction` is not defined in this notebook; obtain the format
# function through get_format_fun instead.
# NOTE(review): assumes get_format_fun accepts "zero_shot_instruction"; confirm.
test_instruct = format_prompt(
    [test],
    get_format_fun("zero_shot_instruction"),
    system_prompt=system_prompt,
    task_instruction=task_instruction,
    examples=examples,
)
print(test)
for token in generator.stream(test_instruct):
    print(token, end="")

In [None]:
# Few-shot prompt for sample 38, streamed.
# `few_shot_instruction` is not defined in this notebook; use get_format_fun.
test1 = df["text"][38]
test1_few_shot = format_prompt(
    [test1],
    get_format_fun("few_shot_instruction"),
    system_prompt=system_prompt,
    task_instruction=task_instruction,
    examples=examples,
)
print(test1)
for token in generator.stream(test1_few_shot):
    print(token, end="")

In [None]:
# Zero-shot prompt for sample 38, streamed.
# `zero_shot_instruction` is not defined in this notebook; use get_format_fun.
# NOTE(review): assumes get_format_fun accepts "zero_shot_instruction"; confirm.
test1_instruct = format_prompt(
    [test1],
    get_format_fun("zero_shot_instruction"),
    system_prompt=system_prompt,
    task_instruction=task_instruction,
    examples=examples,
)
print(test1)
for token in generator.stream(test1_instruct):
    print(token, end="")

In [None]:
# Few-shot prompt for sample 23, streamed.
# `few_shot_instruction` is not defined in this notebook; use get_format_fun.
test3 = df["text"][23]
test3_few_shot = format_prompt(
    [test3],
    get_format_fun("few_shot_instruction"),
    system_prompt=system_prompt,
    task_instruction=task_instruction,
    examples=examples,
)
print(test3)
for token in generator.stream(test3_few_shot):
    print(token, end="")

In [None]:
# Zero-shot prompt for sample 23, streamed.
# `zero_shot_instruction` is not defined in this notebook; use get_format_fun.
# NOTE(review): assumes get_format_fun accepts "zero_shot_instruction"; confirm.
test3_instruct = format_prompt(
    [test3],
    get_format_fun("zero_shot_instruction"),
    system_prompt=system_prompt,
    task_instruction=task_instruction,
    examples=examples,
)
print(test3)
for token in generator.stream(test3_instruct):
    print(token, end="")

In [None]:
# NOTE(review): this passes the raw medication text as the prompt, without the
# system prompt / task instruction that format_prompt adds in the cells above.
# Confirm this is intended.
result3 = generator(test3)

In [None]:
# Display the generated output for sample 23.
result3

Problems:
- Some inputs have one line per medication (e.g. df["text"][0]), while others are prescriptions in a less structured form (e.g. df["text"][38]). The model sometimes struggles with these less structured inputs, although the results are still very good.
- This problem seems solvable by providing appropriate examples, but I am not sure I have covered all the different input formats.
- Some medication names are misspelled (e.g. "Propanolol" instead of "Propranolol"), and the model extracts them as written. I think this is the desired behaviour, because I lack the medical expertise to correct them, and I am unsure what the best way to correct them would be.
- The intake schedule often changes over time, e.g. the dose is increased or decreased after two weeks. Extracting this in detail could be very hard and might hurt the quality of the other outputs (which seem more important to me, though I am not a doctor). As far as I can tell, this only affects a few of the examples.
- If the text only says "once daily" or similar, I instructed the model to assign the whole dose to the morning (i.e. "once daily" becomes 1-0-0), but I am not sure this is the desired behaviour.
- How should performance be evaluated (spelling mistakes, forgotten medications, etc.)? I could annotate a test set myself (100 examples), but I cannot guarantee that my criteria are reasonable from a medical point of view.

In [None]:
# Switch to unconstrained free-text generation with multinomial sampling.
sampler = get_sampler("multinomial")
generator = get_outlines_generator(model, sampler, task="text")

In [None]:
# Reload the sample dataset and the saved results of the
# Llama2-MedTuned-13b (4-bit, few-shot instruction) run.
df = Dataset.load_from_disk(paths.DATA_PATH_PREPROCESSED/"medication/kisim_medication_sample")
# weights_only=False keeps this load independent of the PyTorch version:
# since 2.6 the default (True) rejects arbitrary pickled Python objects.
# Only load trusted files this way: unpickling can execute arbitrary code.
results = torch.load(
    paths.RESULTS_PATH/"medication/medication_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_examples_10.pt",
    weights_only=False,
)

In [None]:
# Model answer for sample 77 from the loaded results.
results["model_answers"][77]

In [None]:
# Raw input text for sample 77, to compare with the model answer.
df["text"][77]

In [None]:
# Hand-written step-by-step prompt in the Llama-2 [INST] ... [/INST] format.
# It asks for intake schedules as morning-noon-evening-night strings and gives
# one worked example before the query "Vitamine A AS zur Nacht".
# NOTE(review): the name `input` shadows the built-in input(); renaming it
# (together with its use in the streaming cell) would be cleaner.
# NOTE: "\t" and "\n" in the example input are escape sequences (not a raw
# string), so the model sees real tabs and newlines there.
input = """[INST]Your task is to extract specific information from medication descriptions. 
The input for this task is a list of medication descriptions, a report or doctors recipe, and the output should be how each these medications have to be taken during the day.
The intake follows for each medication can follow a specific schema, which has to be translated to a format like 0.5-1-2-0
The output should consist of the following:
- morning (float): The dose to be taken in the morning.
- noon (float): The dose to be taken at noon.
- evening (float): The dose to be taken in the evening.
- night (float): The dose to be taken at night.

The output format should be: 

float-float-float-float

This corresponds to morningDose-noonDose-eveningDose-nightDose

- The intake doses over the day can be given several ways:
    - If the amount of doses is given in the form of float-float-float, it corresponds to MorningDose-NoonDose-EveningDose with NightDose being 0.
    - If the amount of doses is given in the form float-float-float-float, it corresponds to MorningDose-NoonDose-EveningDose-NightDose.
    - If keywords like "Morgen", "Mittag", "Abend", "Nacht" are used, the corresponding doses should be extracted and the others set set 0.
    - If an intake schema like the ones above is not detected, MorningDose, NoonDose, EveningDose and NightDose should all be represented as -99.

Here is an example:

Input:
Fampyra 10 mg \t\t1-1-0\nSifrol 0.125 mg \t\t2h vor Schlafegehen

Output:
Let's think step by step.

Step 1: Split Input
First, we need to split the input into separate medication descriptions.

Fampyra 10 mg     1-1-0
Sifrol 0.125 mg     2h vor Schlafegehen

Step 2: Extract Intake Schema
Now, we'll go through each medication description to extract the intake schema.

For Fampyra 10 mg:
The intake schema is given as 1-1-0, which corresponds to Morning-Noon-Evening with Night being 0.
So the output for Fampyra would be: 1.0-1.0-0.0-0.0

For Sifrol 0.125 mg:
The description mentions "2h vor Schlafegehen", which means "2 hours before going to sleep". This corresponds to NightDose.
So the output for Sifrol would be: 0.0-0.0-0.0-1.0

Answer:
1.0-1.0-0.0-0.0
0.0-0.0-0.0-1.0
Each line corresponds to one medication's intake schema in the format: morning-noon-evening-night. If a dose is not mentioned, it's represented as 0.0.

###Input:
Vitamine A AS zur Nacht

[/INST]
Let's think step by step:"""

In [None]:
# Stream the free-text answer to the hand-written step-by-step prompt.
for piece in generator.stream(input):
    print(piece, end="")