In [6]:
import sys
import os
sys.path.append(os.getcwd()+"/../..")
from src import paths

import torch

from pydantic import BaseModel, Field
from enum import Enum, auto
from src.utils import (load_model_and_tokenizer, 
                        get_sampler, 
                        get_format_fun, 
                        format_prompt, 
                        get_outlines_generator, 
                        get_pydantic_schema, 
                        outlines_medication_prompting,
                        get_default_pydantic_model)

from outlines import samplers
from outlines.generate import SequenceGenerator
from outlines.samplers import Sampler

import pandas as pd

import json

from datasets import Dataset

from typing import Callable, Union

from pydantic import BaseModel

In [22]:
import importlib
import src.utils
importlib.reload(src.utils)

<module 'src.utils' from '/cluster/home/eglimar/inf-extr/notebooks/medication/../../src/utils.py'>

In [3]:
model, tokenizer = load_model_and_tokenizer("Llama2-MedTuned-7b",
                                            task_type = "outlines",
                                            quantization = "4bit",
                                            attn_implementation = "flash_attention_2",
                                           )

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [29]:
sampler = get_sampler("greedy")
schema = get_pydantic_schema("medication")
default_model = get_default_pydantic_model("medication")
generator = get_outlines_generator(model, sampler, task = "json", schema = schema)

In [37]:
schema(medications = [{"name": "Cipralex", "unit": "IE/mmol", "amount": 20, "morning": 0.5, "noon": 0, "evening": 0, "night": 0, "extra": "für 4 Tage"}]).json()

'{"medications":[{"name":"Cipralex","unit":"IE/mmol","amount":20.0,"morning":0.5,"noon":0.0,"evening":0.0,"night":0.0,"extra":"für 4 Tage"}]}'

In [30]:
df = Dataset.load_from_disk(paths.DATA_PATH_PREPROCESSED/"medication/kisim_medication_sample")

In [110]:
# Look at different medication formats:
for idx, text in enumerate(df["text"]):
    print(5*"---")
    print(idx)
    print(text)

---------------
0
Prednison 100 mg	1-0-0		24.09. - 07.10.2016
Prednison 80 mg		1-0-0		08.10. - 14.10.2016
Prednison 60 mg		1-0-0		15.10. - 21.10.2016
Prednison 40 mg		1-0-0		22.10.2016 bis auf weiteres

Pantozol 40 mg		1-0-0 		für die Dauer der Prednison-Behandlung
---------------
1
Auge rechts:
Floxal AT 4x/d für 5 Tage
Vitamine A AS zur Nacht
---------------
2
Volare Handgelenksschiene zur Nacht, bitte 1x für beide Hände

Dg.: CTS bds
---------------
3
Ebrufen 200 mg
---------------
4
Paracetamol 500 Hänseler neue Formel Tabl 20 (teilbar)
bei Bedarf


Ibuprofen Adico Filmtabl 400 mg 50 
bei Bedarf

Dauerrezept
---------------
5
Nexium MUPS-Tabl 40 mg 14 Stück Einnahme: 1-0-0
Zolpidem Winthrop Filmtabl 10mg 10 Stück (teilbar) Einnahme: 0-0-1
---------------
6
Einlagen
---------------
7
Amoxicillin Sandoz (Disp Tabl 1000 mg)
---------------
8
Fampyra 10mg 1-0-1 per os
---------------
9
Burgerstein Vitamin D3 (Kaps)
---------------
10
Novalgin (Filmtabl 500 mg)
---------------
11
Cetall

In [31]:
with open(paths.DATA_PATH_PREPROCESSED/"medication/task_instruction.txt", "r") as f:
    task_instruction = f.read()

with open(paths.DATA_PATH_PREPROCESSED/"medication/system_prompt.txt", "r") as f:
    system_prompt = f.read()

with open(paths.DATA_PATH_PREPROCESSED/"medication/examples.json", "r") as file:
    examples = json.load(file)                  
format_fun = get_format_fun("few_shot_instruction")

In [84]:
# text_input = format_prompt(df["text"], format_fun, system_prompt = system_prompt, task_instruction=task_instruction, examples = examples)

In [15]:
# results, successful = outlines_prompting_to(text = text_input, generator = generator, schema = schema, filename="medication-few-shot-instruction", batch_size = 1, wait_time = 250) 

In [34]:
test = df["text"][74]
test_input = format_prompt([test], format_fun, system_prompt = system_prompt, task_instruction=task_instruction, examples = examples)

In [35]:
print(test)
result, successful = outlines_medication_prompting(text= test_input, generator = generator, max_tokens = 1000, batch_size = 1)

Prednison-Ausschleichschema:

Prednison 15 mg für 5 Tage
folgend: Prednison 12,5 mg für 5 Tage
		Prednison 10 mg für 5 Tage
		Prednison 7,5 mg für 5 Tage
		Prednison 5 mg für 5 Tage
		Prednison 2,5 mg für 5 Tage

Esomeprazol 40 mg für die Dauer der Cortisoneinnahme


                                                         

Saved intermediate results to /cluster/dataset/midatams/inf-extr/results/intermediate/intermediate_results1710490674.3404074.pt




In [36]:
result

['{"medications":[{"name":"Prednison","dose":15.0,"dose_unit":"mg","morning":1.0,"noon":1.0,"evening":1.0,"night":1.0,"extra":"für 5 Tage"},{"name":"Prednison","dose":12.0,"dose_unit":"mg","morning":1.0,"noon":1.0,"evening":1.0,"night":1.0,"extra":"für 5 Tage"},{"name":"Prednison","dose":10.0,"dose_unit":"mg","morning":1.0,"noon":1.0,"evening":1.0,"night":1.0,"extra":"für 5 Tage"},{"name":"Prednison","dose":7.0,"dose_unit":"mg","morning":1.0,"noon":1.0,"evening":1.0,"night":1.0,"extra":"für 5 Tage"},{"name":"Prednison","dose":5.0,"dose_unit":"mg","morning":1.0,"noon":1.0,"evening":1.0,"night":1.0,"extra":"für 5 Tage"}]}']

In [21]:
for token in generator.stream(test_input, max_tokens = 200):
    print(token, end = "")

{
"medications": [
{
"name": "Novalgin",
"unit": "mg",
"amount": 500

KeyboardInterrupt: 

In [12]:
import torch
torch.save(result, paths.RESULTS_PATH/"medication/testing.pt")

In [14]:
test_res = torch.load(paths.RESULTS_PATH/"medication/testing.pt")

In [18]:
[res.json() for res in test_res]

['{"medications":[{"name":"Novalgin","unit":"mg","amount":500.0,"morning":0.0,"noon":0.0,"evening":0.0,"night":0.0}]}',
 '{"medications":[{"name":"Cetallerg","unit":"mg","amount":10.0,"morning":0.5,"noon":0.0,"evening":0.0,"night":0.0}]}',
 '{"medications":[{"name":"Magnesium Diasporal","unit":"mg","amount":300.0,"morning":0.0,"noon":0.0,"evening":0.0,"night":0.0}]}']

In [121]:
def medication_prompting_to(text: list[str], generator: SequenceGenerator, default_model:Type[BaseModel], batch_size: int = 1, wait_time:int = 120)-> list[Union[str, BaseModel]]:
    """
    Generates a list of sequences using the given outlines generator and sampler. Function has built in time-out function, as the generation is prone
    to hang with a complicated schema.

    Args:
        text (list[str]): list of strings to be used as prompts
        generator (outlines.SequenceGenerator): outlines generator
        default_model (Type[BaseModel]): default pydantic model to return if the generation times out
        batch_size (int, optional): batch size. Defaults to 1.
        wait_time (int, optional): wait time. Defaults to 120.

    Returns:
        list[Union[str, pydantic.BaseModel]]: list of generated sequences
    """

    dataloader = DataLoader(text, batch_size = batch_size, shuffle = False)

    results = []
    successful = []

    # Save intermediate results
    filename = "intermediate_results" + str(time.time()) + ".pt"

    def timeout_handler(signum, frame):
        raise TimeoutError("Timed out")
    
    signal.signal(signal.SIGALRM, timeout_handler)

    bar_prompt = tqdm(dataloader, desc="Prompting", leave=False)

    for batch in bar_prompt:
        try:
            signal.alarm(wait_time)  # Set the timeout
            result = generator(batch)
            if batch_size == 1:
                result = 
            results.extend(result)
            successful.extend([True] * len(batch))
        
        except TimeoutError:
            print("Timed out at observation number", len(results))
            successful.extend([False] * len(batch))
            results.extend([default_model for _ in range(len(batch))])

        # except TimeoutError:
        #     print("Timed out, trying stream_input")
        #     bar_stream = tqdm(batch, desc="Stream input", leave=False)
        #     for text in bar_stream:
        #         _res = stream_input(text, generator)
        #         try: 
        #             _res = schema.model_validate_json(_res)
        #             successful.append(True)
        #         except:
        #             successful.append(False)
        #         results.append(_res)
        finally:
            signal.alarm(0)
        print(results)
        print(type(results))
        print(len(results))
        # if batch_size == 1:
        #     # Using this because generator returns tuples with the first item being the keys of the highest pedantic schema, the second being the value
        #     results_unpacked = [res[1] for res in results]
        # results_json = [res.json() for res in results_unpacked]
        # results_json += [{"successful": s} for s in successful]
        # os.makedirs(paths.RESULTS_PATH/"intermediate", exist_ok=True)
        # torch.save(results_json, paths.RESULTS_PATH/"intermediate"/filename)
        # print(f"Saved intermediate results to {paths.RESULTS_PATH/'intermediate'}/{filename}")
    return results, successful

In [125]:
results2, successful2 = outlines_prompting_to(text = test_input, generator = generator, default_model = default_model, batch_size = 2, wait_time = 180) 

                                                        

[MedicationList(medications=[Medication(name='Novalgin', unit=<MedicationUnit.mg: 'mg'>, amount=500.0, morning=1.0, noon=0.0, evening=0.0, night=0.0)]), MedicationList(medications=[Medication(name='Cetallerg', unit=<MedicationUnit.mg: 'mg'>, amount=10.0, morning=1.0, noon=0.0, evening=0.0, night=0.0)])]
<class 'list'>
2




In [124]:
results

[('medications',
  [Medication(name='Novalgin', unit=<MedicationUnit.mg: 'mg'>, amount=500.0, morning=1.0, noon=0.0, evening=0.0, night=0.0)]),
 ('medications',
  [Medication(name='Cetallerg', unit=<MedicationUnit.mg: 'mg'>, amount=10.0, morning=1.0, noon=0.0, evening=0.0, night=0.0)])]

In [126]:
results2

[MedicationList(medications=[Medication(name='Novalgin', unit=<MedicationUnit.mg: 'mg'>, amount=500.0, morning=1.0, noon=0.0, evening=0.0, night=0.0)]),
 MedicationList(medications=[Medication(name='Cetallerg', unit=<MedicationUnit.mg: 'mg'>, amount=10.0, morning=1.0, noon=0.0, evening=0.0, night=0.0)])]

In [153]:
type(results[0][1][0])

src.utils.Medication

In [145]:
# medication

MedicationList(medications = results[0][1])

MedicationList(medications=[Medication(name='Novalgin', unit=<MedicationUnit.mg: 'mg'>, amount=500.0, morning=1.0, noon=0.0, evening=0.0, night=0.0)])

In [123]:
default_model.json()

'{"name":"unknown","unit":"unknown","amount":-99.0,"morning":-99.0,"noon":-99.0,"evening":-99.0,"night":-99.0}'

In [117]:
default_schema

Medication(name='unknown', unit=<MedicationUnit.unknown: 'unknown'>, amount=-99.0, morning=-99.0, noon=-99.0, evening=-99.0, night=-99.0)

In [72]:
torch.load(paths.RESULTS_PATH/"intermediate/intermediate_results1710344797.0171728.pt")

['{"medications":[{"name":"unknown","unit":"unknown","amount":-99.0,"morning":-99.0,"noon":-99.0,"evening":-99.0,"night":-99.0}]}',
 {'successful': False}]

In [46]:
filename = "medication_test"
results_dump = [res.json() for res in results]
torch.save(results_dump, paths.RESULTS_PATH/"medication"/f"{filename}.pt")

In [45]:
results_dump

['{"medications":[{"name":"Vi-De","unit":"tropfen","amount":4500.0,"morning":0.0,"noon":0.0,"evening":0.0,"night":0.0}]}']

In [63]:
test = df["text"][1]
test_few_shot = format_prompt([test], few_shot_instruction, system_prompt = system_prompt, task_instruction=task_instruction, examples = examples)
print(test_few_shot[0])

[INST]<<SYS>>You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.
If a question does not makeany sense, or is not factually coherent, explain why instead of answering something not correct. 
If you don’t know the answer to a question, please don’t share false information.
<</SYS>>

### Instruction:
Your task is to extract specific information from medication descriptions. 
The input for this task is a list of medication descriptions, a report or doctors recipe, and the output should be a complete list of dictionaries (one per medication) with the following keys:
- name (str): The name of the medication.
- dose (float): The dose of the medication.
- unit (str): The unit of the dose (mg, ug, g, stk).
- morning (float): The dose to be taken in th

In [64]:
print(test)
for token in generator.stream(test_few_shot):
    print(token, end ="")

1 OP Propanolol 20 mg 1-0-1
2 OP Medrol 32 mg 1-0-0
1 OP Calcimagon D3 forte 1-0-0
{
"medications": [
{
"name": "Propanolol",
"unit": "mg",
"amount": 20,
"morning": 1,
"noon": 0,
"evening": 1,
"night": 0
},
{
"name": "Medrol",
"unit": "mg",
"amount": 32,
"morning": 1,
"noon": 0,
"evening": 0,
"night": 0
},
{
"name": "Calcimagon D3 forte",
"unit": "mg",
"amount": 1,
"morning": 0,
"noon": 0,
"evening": 0,
"night": 0
}
]
}

In [65]:
test_instruct= format_prompt([test], zero_shot_instruction, system_prompt = system_prompt, task_instruction=task_instruction, examples = examples)
print(test)
for token in generator.stream(test_instruct):
    print(token, end ="")

1 OP Propanolol 20 mg 1-0-1
2 OP Medrol 32 mg 1-0-0
1 OP Calcimagon D3 forte 1-0-0
{
"medications": [
{
"name": "Propanolol",
"unit": "mg",
"amount": 20,
"morning": 1,
"noon": 1,
"evening": 1,
"night": 1
},
{
"name": "Medrol",
"unit": "mg",
"amount": 32,
"morning": 1,
"noon": 1,
"evening": 1,
"night": 1
},
{
"name": "Calcimagon D3 forte",
"unit": "ml",
"amount": 1,
"morning": 1,
"noon": 1,
"evening": 1,
"night": 1
}
]
}

In [66]:
test1 = df["text"][38]
test1_few_shot = format_prompt([test1], few_shot_instruction, system_prompt = system_prompt, task_instruction=task_instruction, examples = examples)
print(test1)
for token in generator.stream(test1_few_shot):
    print(token, end ="")

Gültig für 12 Monate
Gilenya
Fingolimod.

Eine Kapsel mit 0,5 mg einmal täglich oral

für Herr Holenstein Daniel
Geboren 14.01.1965
Weltiweg 4
CH-5330 Bad Zurzach

Krankenkasse: Die Eidgenössische Gesundheitskasse
Vers.-nummer: 2194915
{
"medications" : [
{
"name" : "Gilenya",
"unit" : "mg",
"amount" : 0.5,
"morning" : 1,
"noon" : 0,
"evening" : 0,
"night" : 0
},
{
"name" : "Fingolimod",
"unit" : "mg",
"amount" : 0.5,
"morning" : 1,
"noon" : 0,
"evening" : 0,
"night" : 0
}
]
}

In [67]:
test1_instruct = format_prompt([test1], zero_shot_instruction, system_prompt = system_prompt, task_instruction=task_instruction, examples = examples)
print(test1)
for token in generator.stream(test1_instruct):
    print(token, end ="")

Gültig für 12 Monate
Gilenya
Fingolimod.

Eine Kapsel mit 0,5 mg einmal täglich oral

für Herr Holenstein Daniel
Geboren 14.01.1965
Weltiweg 4
CH-5330 Bad Zurzach

Krankenkasse: Die Eidgenössische Gesundheitskasse
Vers.-nummer: 2194915
{
"medications" : [
{
"name" : "Gilenya",
"unit" : "tropfen",
"amount" : 0.5
,"morning" : 0,
"noon" : 0,
"evening" : 0,
"night" : 0
}, 
{
"name" : "Fingolimod.",
"unit" : "tropfen",
"amount" : 0.5
,"morning" : 0,
"noon" : 0,
"evening" : 0,
"night" : 0
}
]
}

In [68]:
test3 = df["text"][23]
test3_few_shot = format_prompt([test3], few_shot_instruction, system_prompt = system_prompt, task_instruction=task_instruction, examples = examples)
print(test3)
for token in generator.stream(test3_few_shot):
    print(token, end ="")

Fussorthese
{ "medications": [
{ "name": "Fussorthese", "unit": "unknown", "amount": 1, "morning": -99, "noon": -99, "evening": -99, "night": -99 }
] }

In [69]:
test3_instruct = format_prompt([test3], zero_shot_instruction, system_prompt = system_prompt, task_instruction=task_instruction, examples = examples)
print(test3)
for token in generator.stream(test3_instruct):
    print(token, end ="")

Fussorthese
{"medications": [
{"name": "Fussorthese", "unit": "unknown", "amount": 1, "morning": 1, "noon": 1, "evening": 1, "night": 1}
]}

In [70]:
result3 = generator(test3)

In [71]:
result3

MedicationList(medications=[Medication(name='Diclofenac', unit=<MedicationUnit.tropfen: 'tropfen'>, amount=50.0, morning=1.0, noon=1.0, evening=1.0, night=1.0)])

Problems:
- Some inputs are structured with one line per medication e.g. df["text"][0], others are medical recipes like df["text"][38]. The model struggles a bit sometimes with inputs that are not as well structured (but still really good)
- The problem above seems to be solvable by providing appropriate examples but I don't know if I get all the different input formats.
- Sometimes medications are misspelled (like Propanolol) and model extracts it the way it was (which is the desired behaviour I think, because I don't have the medical expertise to correct it). Unsure what the best way to correct it is.
- A lot of times the schema for intake changes, so after 2 weeks maybe it is less or more. Additionally extracting this in detail could be very hard and might negatively affect the performance of the other outputs (which seem more important to me, but I am no doctor). This is also only the case for a few of the examples as far as I can tell.
- If the text just mentions "once daily" or similar I told the model to map it all in the morning (so once daily is 1-0-0) but not sure if that would be desired behaviour.
- How would I evaluate the performance (spelling mistakes, forget medicine etc.). I could evaluate a test set myself (100 examples) but I can't guarantee that the criteria I set would be reasonable from a medical point of view.

In [4]:
sampler = get_sampler("multinomial")
generator = get_outlines_generator(model = model, sampler = sampler, task = "text")

In [9]:
df = Dataset.load_from_disk(paths.DATA_PATH_PREPROCESSED/"medication/kisim_medication_sample")
results = torch.load(paths.RESULTS_PATH/"medication/medication_outlines_Llama2-MedTuned-13b_4bit_few_shot_instruction_examples_10.pt")

In [11]:
results["model_answers"][77]

'{"medications":[{"name":"Fampyra","dose":10.0,"dose_unit":"mg","morning":1.0,"noon":1.0,"evening":0.0,"night":0.0,"extra":"1-1-0"},{"name":"Sifrol","dose":0.125,"dose_unit":"mg","morning":-99.0,"noon":-99.0,"evening":-99.0,"night":-99.0,"extra":"2h vor Schlafegehen"}]}'

In [12]:
df["text"][77]

'Fampyra 10 mg \t\t1-1-0\nSifrol 0.125 mg \t\t2h vor Schlafegehen'

In [27]:
input = """[INST]Your task is to extract specific information from medication descriptions. 
The input for this task is a list of medication descriptions, a report or doctors recipe, and the output should be how each these medications have to be taken during the day.
The intake follows for each medication can follow a specific schema, which has to be translated to a format like 0.5-1-2-0
The output should consist of the following:
- morning (float): The dose to be taken in the morning.
- noon (float): The dose to be taken at noon.
- evening (float): The dose to be taken in the evening.
- night (float): The dose to be taken at night.

The output format should be: 

float-float-float-float

This corresponds to morningDose-noonDose-eveningDose-nightDose

- The intake doses over the day can be given several ways:
    - If the amount of doses is given in the form of float-float-float, it corresponds to MorningDose-NoonDose-EveningDose with NightDose being 0.
    - If the amount of doses is given in the form float-float-float-float, it corresponds to MorningDose-NoonDose-EveningDose-NightDose.
    - If keywords like "Morgen", "Mittag", "Abend", "Nacht" are used, the corresponding doses should be extracted and the others set set 0.
    - If an intake schema like the ones above is not detected, MorningDose, NoonDose, EveningDose and NightDose should all be represented as -99.

Here is an example:

Input:
Fampyra 10 mg \t\t1-1-0\nSifrol 0.125 mg \t\t2h vor Schlafegehen

Output:
Let's think step by step.

Step 1: Split Input
First, we need to split the input into separate medication descriptions.

Fampyra 10 mg     1-1-0
Sifrol 0.125 mg     2h vor Schlafegehen

Step 2: Extract Intake Schema
Now, we'll go through each medication description to extract the intake schema.

For Fampyra 10 mg:
The intake schema is given as 1-1-0, which corresponds to Morning-Noon-Evening with Night being 0.
So the output for Fampyra would be: 1.0-1.0-0.0-0.0

For Sifrol 0.125 mg:
The description mentions "2h vor Schlafegehen", which means "2 hours before going to sleep". This corresponds to NightDose.
So the output for Sifrol would be: 0.0-0.0-0.0-1.0

Answer:
1.0-1.0-0.0-0.0
0.0-0.0-0.0-1.0
Each line corresponds to one medication's intake schema in the format: morning-noon-evening-night. If a dose is not mentioned, it's represented as 0.0.

###Input:
Vitamine A AS zur Nacht

[/INST]
Let's think step by step:"""

In [28]:
for token in generator.stream(input):
    print(token, end = "")



Step 1: Split Input
First, we need to split the input into separate medication descriptions.

Vitamine : B-Nacht
A : I-Nacht
AS : O
zur : O
Nacht : O 
