# Training a Llama model for n2c2 information extraction task

In [None]:
# Loading necessary libraries

import torch

from trl import SFTTrainer
from datasets import load_dataset
from transformers import TrainingArguments, TextStreamer

from unsloth.chat_templates import get_chat_template
from unsloth import FastLanguageModel, is_bfloat16_supported

from sklearn.model_selection import train_test_split

from pydantic import BaseModel
from pydantic import Field
from typing import Optional, List, Union

import json
from collections import OrderedDict, defaultdict, Counter

import pickle
from datasets import Dataset
import pandas as pd
import dspy
import ast

In [None]:
# Choose random seed
# Seed PyTorch's default generator so that repeated runs start from the same state
seed = 123
torch.manual_seed(seed)
# Explicitly seed every CUDA device as well, but only when a GPU is available
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

## Load model

In [None]:
# Load model
# Maximum sequence length; the same value is passed to SFTTrainer further down
max_seq_length = 2048

# Load the meta-llama/Meta-Llama-3.1-8B checkpoint and its tokenizer through unsloth
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Meta-Llama-3.1-8B",
    max_seq_length=max_seq_length,
    load_in_4bit=False,  # 4-bit loading is not requested
    dtype=None,  # NOTE(review): presumably lets unsloth choose the dtype automatically — confirm
)

In [None]:
# Folder the fine-tuned model (and its checkpoints) will be saved to
refined_model = "Meta-Llama-3_1-8b_refined_i2b2_pydantic/"

## Load data

In [None]:
# Loading training data
# NOTE: pickle.load can execute arbitrary code while unpickling; only load
# files that come from a trusted source.
with open('./data/n2c2_train.pkl', 'rb') as pkl_file:
    dataset_train = pickle.load(pkl_file)

# Build one {'prompt', 'response'} record per pickled example
train_dict_data = [
    {'prompt': example.paragraph, 'response': example.answer}
    for example in dataset_train
]

# Go through a pandas DataFrame to obtain a Hugging Face Dataset
train_frame = pd.DataFrame(train_dict_data)
train_dataset = Dataset.from_pandas(train_frame)


### Convert data to CHATML

In [None]:
# Define pydantic object
# One extracted medication together with its optional attributes.
# `medication` is the only required field; every other field accepts a string,
# a list of strings, or None, and defaults to the empty string.
# NOTE(review): the class docstring and the Field descriptions are part of the JSON
# schema produced by model_json_schema(), which is embedded in the system prompt —
# editing them changes the prompt the model is trained on.
# NOTE(review): several example values are German (e.g. "dauerhaft", "Antikoagulation");
# presumably carried over from another dataset — confirm they are intended for n2c2.
class Medication(BaseModel):
    """Medication information extracted from the text."""
    # Required: the drug name or active ingredient
    medication: str = Field(description="A drug name or an active ingredient.")
    # Optional attributes of the medication
    ade: Optional[Union[str, List[str]]] = Field(default="", description="Extract adverse drug events from the text. Example: rash, hypotension, thrombocytopenia, toxicity, diarrhea, altered mental status, Rash, Thrombocytopenia, GI bleed, somnolent, etc.")
    strength: Optional[Union[str, List[str]]] = Field(default="", description="Extract the strength of the medication from the text. Examples: 100 mg, 10 mg, 5 mg, 20 mg, 40 mg, 25 mg, 500 mg, 10mg, 50 mg, 5mg, etc.")
    frequency: Optional[Union[str, List[str]]] = Field(default="", description="Extract the frequency of the medication from the text. Examples: 1-0-0, 1-0-1, daily, 0-0-1, DAILY (Daily), once a day, DAILY, BID, BID (2 times a day), twice a day, etc.")
    duration: Optional[Union[str, List[str]]] = Field(default="", description="Extract the duration of the medication from the text. Examples: dauerhaft, pausiert, abgesetzt, für 12 Monate, B-DATE - B-DATE, Pause, dauerhafte, 14 day, for 7 days, for 10 days, etc.")
    route: Optional[Union[str, List[str]]] = Field(default="", description="Extract the route of the medication from the text. Examples: PO, Oral, IV, by mouth, po, Inhalation, oral, drip, gtt, i.v., etc.")
    form: Optional[Union[str, List[str]]] = Field(default="", description="Extract the form of the medication from the text. Examples: Tablet, Capsule, Solution, Tablet, Delayed Release (E.C.), Tablets, Tablet, Chewable, tablet, Appl, Capsule, Delayed Release(E.C.), Tablet(s), etc.")
    dosage: Optional[Union[str, List[str]]] = Field(default="", description="Extract the dosage of the medication from the text. Examples: One (1), Two (2), 1, 1-2, 2, sliding scale, Three (3), 0.5, taper, 3, etc.")
    reason: Optional[Union[str, List[str]]] = Field(default="", description="Extract the reason of the medication from the text. Examples: pain, Antikoagulation, constipation, Thrombozytenaggregationshemmung, Stentverschlussprophylaxe, anxiety, pneumonia, Antibiose, duale Thrombozytenaggregationshemmung, wheezing, etc.")

# Top-level response object: wraps the list of Medication entries.
# The list defaults to empty, matching the "no medication" case described in the prompt.
class MedicationInfo(BaseModel):
    """A list of medication information extracted from the text."""
    medications: List[Medication] = Field(default_factory=list, description="A list of medications and their related information.")

# Serialize the JSON schema of MedicationInfo so it can be embedded in the prompt
schema = json.dumps(MedicationInfo.model_json_schema())

# Define system message
# The doubled braces around 'medications' are f-string escapes for literal braces;
# {schema} is the only interpolated value.
# NOTE(review): the prompt text contains typos ("doctoral letter", "create a this JSON")
# and shows the empty result with single quotes, which is not strict JSON. The text is
# left untouched because it becomes part of every training example; presumably the
# same prompt must be used at inference time — confirm before editing.
system_message = f"""You are a physician. Your task is to extract ALL drug names (active ingredients or drug names) and their related information, such as ADE, strength, frequency, duration, route, form, dosage, and reason from a given text snippet of a doctoral letter. 
Please make sure to extract the medications **in the order they appear** in the text. Maintain this order in the JSON response.
If a medication occurs more than once in the text, append a unique count in parentheses to its name, starting from (1). 
If there is NO medication information in the text, create a this JSON: {{'medications': []}}
ONLY respond with an instance of JSON without any additional information. You have access to a JSON schema, which will determine how the JSON should be structured.
Make sure to return ONLY an instance of the JSON, NOT the schema itself. Do not add any additional information.
JSON schema:
{schema}
"""

# Define conversation format
def create_conversation(sample):
    """Turn one prompt/response record into a three-turn chat.

    Args:
        sample: Mapping with a "prompt" entry (the input text) and a
            "response" entry (the expected answer).

    Returns:
        A dict with a single "messages" key holding the system, user and
        assistant turns, in that order, as role/content dicts.
    """
    turns = (
        ("system", system_message),
        ("user", sample["prompt"]),
        ("assistant", sample["response"]),
    )
    return {"messages": [{"role": role, "content": content} for role, content in turns]}

# Apply the conversation function to the dataset
# Each row gains a 'messages' column; the raw 'prompt' and 'response' columns are dropped
train_chat_dataset = train_dataset.map(create_conversation, remove_columns=['prompt', 'response'])

# Rebind the tokenizer to the one unsloth returns for the "chatml" template
# NOTE(review): presumably this installs the ChatML chat template and may adjust
# special tokens — confirm against the unsloth documentation
tokenizer = get_chat_template(
    tokenizer,
    chat_template="chatml",
)

# Apply the CHATML format
def apply_template(examples):
    """Render a batch of conversations into plain training text.

    Args:
        examples: Batch of rows whose "messages" entry holds one conversation
            (a list of role/content dicts) per row.

    Returns:
        A dict with a "text" list containing one rendered string per
        conversation, in input order.
    """
    def render(conversation):
        # tokenize=False asks for the formatted string instead of token ids;
        # no generation prompt is appended after the final turn.
        return tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)

    return {"text": [render(conversation) for conversation in examples["messages"]]}

# Render every conversation to text; batched=True makes map() pass apply_template
# whole batches of rows, and the returned "text" list becomes a new column
final_train = train_chat_dataset.map(apply_template, batched=True)


## Train the model and save to folder

In [None]:
# Wrap the loaded model with LoRA adapters through unsloth; `model` is rebound to the wrapped model
model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # LoRA rank
    lora_alpha=16,  # LoRA alpha (scaling) value
    lora_dropout=0,  # no dropout on the adapter path
    # Names of the modules that receive adapters — presumably the attention and MLP projections
    target_modules=["q_proj", "k_proj", "v_proj", "up_proj", "down_proj", "o_proj", "gate_proj"], 
    use_rslora=True,  # NOTE(review): presumably enables rank-stabilized LoRA scaling — confirm against PEFT docs
    use_gradient_checkpointing="unsloth"  # NOTE(review): presumably selects unsloth's own checkpointing variant — confirm
)

# Configure supervised fine-tuning on the rendered "text" column of final_train
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=final_train,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=True,
    args=TrainingArguments(
        learning_rate=3e-4,
        lr_scheduler_type="linear",
        # 4 samples per device x 2 accumulation steps = 8 samples per optimizer step per device
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        num_train_epochs=1,
        # Exactly one mixed-precision mode is enabled: bf16 where supported, fp16 otherwise
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=100,
        optim="adamw_8bit",
        weight_decay=0.01,
        warmup_steps=10,
        output_dir=refined_model,
        # Reuse the notebook-wide seed: the Trainer reseeds from this argument, so the
        # previously hard-coded 0 silently overrode the seed chosen at the top
        seed=seed,
        # Write a checkpoint every 100 steps and keep only the two most recent ones
        save_steps=100,
        save_total_limit=2,
        save_strategy="steps",
    ),
)

# Run the fine-tuning loop
trainer.train()

# Persist the trained model together with the tokenizer. The tokenizer was replaced
# by get_chat_template(), so saving only the model would leave the output folder
# without the chat template the model was trained with.
trainer.model.save_pretrained(refined_model)
tokenizer.save_pretrained(refined_model)