In [19]:
import sys
import os
sys.path.append(os.getcwd()+"/../..")
from src import paths

from pydantic import BaseModel, Field
from enum import Enum, auto
from src.utils import load_model_and_tokenizer

from outlines import samplers
import outlines

import pandas as pd


In [2]:
medication_names = [
    "Avonex",
    "Betaferon",
    "Plegridy",
    "Copaxone",
    "Glatiramyl",
    "Aubagio",
    "Tecfidera",
    "Gilenya",
    "Tysabri",
    "Ocrevus",
    "Lemtrada",
    "Novantron",
    "Endoxan",
    "MabThera",
    "Imurek",
    "Mayzent",
    "Medrol",
    "Solu-Medrol",
    "Solumedrol",
    "Cortison",
    "Interferon beta-1a",
    "Interferon beta-1b",
    "Peginterferon beta-1a",
    "Glatirameracetat",
    "Teriflunomid",
    "Dimethylfumarat",
    "Fingolimod",
    "Natalizumab",
    "Ocrelizumab",
    "Alemtuzumab",
    "Mitoxantron",
    "Cyclophosphamid",
    "Rituximab",
    "Azathioprin",
    "Siponimod",
    "Glucocorticosteroid"
]

medication_keys = [medication.replace("-", "_").replace(" ", "_").lower() for medication in medication_names]

MedicationName = Enum("MedicationName", [(key, value) for key, value in zip(medication_keys, medication_names)])

class MedicationUnit(str, Enum):
    mg = "mg"
    ug = "ug"
    g = "g"

intake_amount = Field(pattern=r"\d+(\.\d{1,2})?")

class Medication(BaseModel):
    name: MedicationName
    unit: MedicationUnit
    amount: float
    morning: float
    noon: float
    evening: float
    night: float
    # intake: str = Field(pattern=r"(-99)|(\d+(\.\d{1})?-\d+(\.\d{1})?-\d+(\.\d{1})?)-\d+(\.\d{1})?")


class MedicationList(BaseModel):
    medications: list[Medication]

# Instantiate Medication
medication1 = Medication(
    name="Avonex",
    unit="mg",
    amount=100.0,
    morning=1,
    noon=0.5,
    evening=0,
    night=1,
)

In [3]:
model, tokenizer = load_model_and_tokenizer("Llama2-MedTuned-13b",
                                            task_type = "outlines",
                                            quantization = "4bit",
                                           )

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [16]:
sampler = samplers.greedy()
generator = outlines.generate.json(model, Medication, sampler = sampler)

In [33]:
task_instruction = ("Your task is to extract specific information from medication descriptions. "
                    "The information to extract includes the medication name, dose, dose unit, "
                    "and the amount of doses that have to be taken at morning-noon-evening-night. "
                    "The Medication name can consist of multiple words with whitespaces and should be returned as a single string. "
                    "The dose is one of the following: mg, ug, g. "
                    "The dose unit is a float. "
                    "The amount of intake doses can be given in two ways: "
                    "If The amount of doses is given in the format of float-float-float, it corresponds to "
                    "morning-noon-evening and night should be set to 0. If it is given in the form float-float-float-float it corresponds to "
                    "morning-noon-evening-night. "
                    "Your output should follow this specific json format: {name: str:[MedicationName], "
                    "unit: str:[MedicationUnit], amount: float: [MedicationAmount], morning: float[MorningIntake], "
                    "noon: float[NoonIntake], evening: float[EveningIntake], night: float[NightIntake]"
                   )

def zero_shot_base(report:str, task_instruction:str)->str:
    """Zero-shot base for Llama prompting

    Args:
        report (str): medical report
        task_instruction (str): instruction for the task

    Returns:
        str: reformatted medical report with base

    """
    base_prompt = "<s>[INST]<<SYS>>{system_prompt}<</SYS>>\n\n{instruction}{input}[/INST]\nThe type of multiple sclerosis stated in the german medical report is: "
    system_prompt =  ("\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
                      "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
                       "Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make "
                        "any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t "
                        "know the answer to a question, please don’t share false information.\n"
                        )
    
    input = base_prompt.format(system_prompt = system_prompt, instruction = task_instruction, input =  report)

    return input
    
def zero_shot_instruction(report:str, task_instruction)->str:
    """Zero-shot instruction for the MS extraction task
    
    Args:
        report (str): medical report
        task_instruction (str): instruction for the task
        
        Returns:
            str: reformatted medical report with instruction
            
            """
    instruction_base_prompt = "<s>[INST]\n### Instruction:\n{instruction}\n\n### Input:\n{input}[/INST]\n\n### Output:\n"
    input = instruction_base_prompt.format(instruction = task_instruction, input =  report)

    return input



def few_shot_base(report:str, task_instruction:str, examples:list[dict])->str:
    """Few Shot base for the MS extraction task

    Args:
        report (str): medical report
        task_instruction (str): instruction for the task
        examples (list[dict]): list of examples. Each example is a dict with keys text, labels.

    Returns:
        str: reformatted medical report with base

    """
    base_prompt = ("<s>[INST]<<SYS>>{system_prompt}<</SYS>>\n\n{instruction} "
                   "Here is an example to help you understand the task: {examples} \n"
                   "Please provide your answer for the following Report:\n{input}[/INST]\nDiagnosis:\n"
                   )
    
    system_prompt = (
    "\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
    "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
    "Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make "
    "any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t "
    "know the answer to a question, please don’t share false information.\n"
    )
    insert_examples = ""
    for example in examples:
        text = example["text"]
        label = example["labels"]
        insert_examples += f"Report:\n{text}\nLabel:\n{label}\n"
    
    input = base_prompt.format(system_prompt = system_prompt, instruction = task_instruction, examples = insert_examples, input = report)
    return input


In [35]:
test = kisim_medications["medication_name"].iloc[0].split("\n")[0]
examples = [{"text": test, "labels" : '{name: "Medrol", unit: "mg", amount: 32, morning: 1, noon: 0, evening: 0, night: 0}'}]
input = few_shot_base(test, task_instruction, examples)

In [18]:
kisim_medications = pd.read_csv(os.path.join(paths.DATA_PATH_SEANTIS, "kisim_medication.csv"))

In [20]:
input = zero_shot_instruction(test, task_instruction)
input

'<s>[INST]\n### Instruction:\nYour task is to extract specific information from medication descriptions. The information to extract includes the medication name, dose, dose unit, and the amount of doses that have to be taken at morning-noon-evening-night. The Medication name can consist of multiple words with whitespaces and should be returned as a single string. The dose is one of the following: mg, ug, g. The dose unit is a float. The amount of intake doses can be given in two ways: If The amount of doses is given in the format of float-float-float, it corresponds to morning-noon-evening and night should be set to 0. If it is given in the form float-float-float-float it corresponds to morning-noon-evening-night. Your output should follow this specific json format: {name: str:[MedicationName], unit: str:[MedicationUnit], amount: float: [MedicationAmount], morning: float[MorningIntake], noon: float[NoonIntake], evening: float[EveningIntake], night: float[NightIntake]\n\n### Input:\n1 O

In [None]:
for word in generator.stream(input):
    print(word, end = "")

{
"name": "Medrol",
"unit": "mg",
"amount": 32,
"morning