# Llama inference

In [None]:
# Loading libraries
import torch
from trl import SFTTrainer
from datasets import load_dataset
from transformers import TrainingArguments, TextStreamer
from unsloth.chat_templates import get_chat_template
from unsloth import FastLanguageModel, is_bfloat16_supported

from pydantic import BaseModel
from typing import Optional, List, Union
from pydantic import Field
import json
from collections import OrderedDict, defaultdict, Counter

import pickle
from datasets import Dataset
import pandas as pd
import dspy
import csv
import re

import ast
import tqdm
from collections import OrderedDict

from json_repair import repair_json

In [None]:
# Fix the global random seed so repeated runs of this notebook are reproducible.
seed = 42
torch.manual_seed(seed)  # seed the default torch generator

# When a GPU is present, seed every CUDA device explicitly as well.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

## Load model

In [None]:
# Local directory of the fine-tuned checkpoint (a Llama 3.1 8B model, per its name).
refined = "Meta-Llama-3_1-8b_refined_i2b2_pydantic/"

# Maximum sequence length (in tokens) the model is loaded with; the inference
# loop also uses it as the prompt truncation limit.
max_seq_length = 2048
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model, tokenizer = FastLanguageModel.from_pretrained(model_name=refined, max_seq_length=max_seq_length)

# Switch the loaded model into unsloth's inference mode.
model = FastLanguageModel.for_inference(model)

# Prepare data

In [None]:
# Load the pickled n2c2 test split. pickle.load can execute arbitrary code,
# so this file must come from a trusted source.
with open('./data/n2c2_test.pkl', 'rb') as f:
    dataset_test = pickle.load(f)

# One record per example: the input paragraph and its gold answer.
test_dict_data = [
    dict(prompt=item.paragraph, response=item.answer)
    for item in dataset_test
]

# Wrap the records in a Hugging Face Dataset, going through a pandas DataFrame.
test_dataset = Dataset.from_pandas(pd.DataFrame(test_dict_data))

### Convert to ChatML format

In [None]:
# Pydantic model for one extracted medication and its attributes.
# NOTE: this model is embedded in the system prompt via
# MedicationInfo.model_json_schema(); pydantic puts the class docstring and every
# Field description into that schema. Rewording any of them changes the prompt the
# fine-tuned model sees, which is why the text below is left untouched.
class Medication(BaseModel):
    """Medication information extracted from the text."""
    # Required field: the only attribute without a default.
    medication: str = Field(description="A drug name or an active ingredient.")
    # Every attribute below accepts a single string or a list of strings and
    # defaults to "" when absent. Several example lists mix English and German terms.
    ade: Optional[Union[str, List[str]]] = Field(default="", description="Extract adverse drug events from the text. Example: rash, hypotension, thrombocytopenia, toxicity, diarrhea, altered mental status, Rash, Thrombocytopenia, GI bleed, somnolent, etc.")
    strength: Optional[Union[str, List[str]]] = Field(default="", description="Extract the strength of the medication from the text. Examples: 100 mg, 10 mg, 5 mg, 20 mg, 40 mg, 25 mg, 500 mg, 10mg, 50 mg, 5mg, etc.")
    frequency: Optional[Union[str, List[str]]] = Field(default="", description="Extract the frequency of the medication from the text. Examples: 1-0-0, 1-0-1, daily, 0-0-1, DAILY (Daily), once a day, DAILY, BID, BID (2 times a day), twice a day, etc.")
    duration: Optional[Union[str, List[str]]] = Field(default="", description="Extract the duration of the medication from the text. Examples: dauerhaft, pausiert, abgesetzt, für 12 Monate, B-DATE - B-DATE, Pause, dauerhafte, 14 day, for 7 days, for 10 days, etc.")
    route: Optional[Union[str, List[str]]] = Field(default="", description="Extract the route of the medication from the text. Examples: PO, Oral, IV, by mouth, po, Inhalation, oral, drip, gtt, i.v., etc.")
    form: Optional[Union[str, List[str]]] = Field(default="", description="Extract the form of the medication from the text. Examples: Tablet, Capsule, Solution, Tablet, Delayed Release (E.C.), Tablets, Tablet, Chewable, tablet, Appl, Capsule, Delayed Release(E.C.), Tablet(s), etc.")
    dosage: Optional[Union[str, List[str]]] = Field(default="", description="Extract the dosage of the medication from the text. Examples: One (1), Two (2), 1, 1-2, 2, sliding scale, Three (3), 0.5, taper, 3, etc.")
    reason: Optional[Union[str, List[str]]] = Field(default="", description="Extract the reason of the medication from the text. Examples: pain, Antikoagulation, constipation, Thrombozytenaggregationshemmung, Stentverschlussprophylaxe, anxiety, pneumonia, Antibiose, duale Thrombozytenaggregationshemmung, wheezing, etc.")

# Top-level response model: its JSON schema is serialized into the system prompt.
# The docstring and Field description are part of that schema, so leave them as-is.
class MedicationInfo(BaseModel):
    """A list of medication information extracted from the text."""
    # An empty list is the expected answer when the text mentions no medication.
    medications: List[Medication] = Field(default_factory=list, description="A list of medications and their related information.")

# JSON schema of the expected answer, embedded verbatim in the system prompt.
# json.dumps uses its default ensure_ascii=True here, so non-ASCII characters in
# the field descriptions (e.g. "für") appear as \uXXXX escapes in the prompt.
schema = json.dumps(MedicationInfo.model_json_schema())

# System prompt shared by every conversation. The doubled braces in
# {{'medications': []}} render as literal braces inside the f-string.
# NOTE(review): the wording contains typos ("create a this JSON", "doctoral letter"),
# but it presumably has to match the prompt used during fine-tuning. Confirm before editing.
system_message = f"""You are a physician. Your task is to extract ALL drug names (active ingredients or drug names) and their related information, such as ADE, strength, frequency, duration, route, form, dosage, and reason from a given text snippet of a doctoral letter. 
Please make sure to extract the medications **in the order they appear** in the text. Maintain this order in the JSON response.
If a medication occurs more than once in the text, append a unique count in parentheses to its name, starting from (1). 
If there is NO medication information in the text, create a this JSON: {{'medications': []}}
ONLY respond with an instance of JSON without any additional information. You have access to a JSON schema, which will determine how the JSON should be structured.
Make sure to return ONLY an instance of the JSON, NOT the schema itself. Do not add any additional information.
JSON schema:
{schema}
"""

# Define conversation format
def create_conversation(sample):
    """Wrap one prompt/response row as a three-turn chat.

    The turns are, in order: the shared system prompt, the user's text
    (``sample["prompt"]``) and the gold assistant answer (``sample["response"]``).
    Returns a dict with a single ``"messages"`` key.
    """
    turns = [
        dict(role="system", content=system_message),
        dict(role="user", content=sample["prompt"]),
        dict(role="assistant", content=sample["response"]),
    ]
    return {"messages": turns}

# Apply the conversation function to the dataset
test_chat_dataset = test_dataset.map(create_conversation, remove_columns=['prompt', 'response'])

tokenizer = get_chat_template(
    tokenizer,
    chat_template="chatml",
)

# Render conversations to ChatML text.
def apply_template(examples):
    """Render a batch of conversations with the tokenizer's chat template.

    The gold assistant turn stays in the text and no generation prompt is
    appended. Returns ``{"text": [...]}`` with one rendered string per conversation.
    """
    rendered = []
    for conversation in examples["messages"]:
        rendered.append(
            tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
        )
    return {"text": rendered}

# NOTE(review): final_test is not read by the inference cells in this notebook view.
final_test = test_chat_dataset.map(apply_template, batched=True)

# Inference

In [None]:
# Conversations per generate() call, and a running count of predictions that
# could not be parsed into a dict.
batch_size, malformed_json_preds = 16, 0

# Matches an apostrophe between two word characters (e.g. "patient's"). Such
# apostrophes are removed before parsing, presumably because the answers use
# single-quoted strings that an in-word apostrophe would break.
pattern = r"(?<=\w)'(?=\w)"

# Full conversations (system, user, gold assistant turn) to run inference on.
all_messages = test_chat_dataset['messages']

In [None]:
# Run batched greedy decoding over the test conversations and write one
# '|'-delimited CSV row per sample: input text, parsed gold dict, parsed prediction dict.

# Decoder-only models need left padding for batched generation. With right padding,
# shorter prompts end in pad tokens and generation continues after them. Left padding
# also makes every prompt end at the same position, so the generated tokens can be
# sliced off at a single offset below.
tokenizer.padding_side = "left"

# Explicit UTF-8: the clinical text contains non-ASCII characters (e.g. German umlauts).
with open('output.csv', 'w', newline='', encoding='utf-8') as csvfile:
    csvwriter = csv.writer(csvfile, delimiter='|')
    csvwriter.writerow(['text', 'gold', 'pred'])

    with torch.no_grad():
        for batch_start in tqdm.tqdm(range(0, len(all_messages), batch_size)):
            batch = all_messages[batch_start:batch_start + batch_size]

            texts = []
            golds = []
            messages = []

            for conversation in batch:
                # Each conversation is [system, user, assistant (gold)].
                texts.append(conversation[1]['content'].strip().replace("\n", ""))

                gold = conversation[2]['content']
                try:
                    gold_cleaned = re.sub(pattern, "", gold)
                    gold_dict = ast.literal_eval(gold_cleaned)
                    sorted_gold_dict = OrderedDict(sorted(gold_dict.items()))
                except Exception as e:
                    print(f"Error parsing gold data: {e}")
                    sorted_gold_dict = {}
                golds.append(sorted_gold_dict)

                # The prompt is the system and user turns only; the gold answer is withheld.
                messages.append(conversation[:2])

            # Render the prompts with the chat template, ending with the assistant header.
            input_texts = [
                tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
                for message in messages
            ]

            inputs = tokenizer(
                input_texts,
                padding=True,
                truncation=True,
                max_length=max_seq_length,
                return_tensors='pt',
            ).to(device)

            # Greedy decoding (do_sample=False); top_p/temperature are passed as neutral values.
            outputs = model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask'),
                max_new_tokens=384,
                use_cache=True,
                top_p=1.0,
                temperature=1.0,
                do_sample=False,
            )

            # Decode only the newly generated tokens. Splitting the full decoded text on
            # "<|im_start|>assistant" raises IndexError whenever the ChatML markers are
            # special tokens, because skip_special_tokens=True removes them.
            prompt_length = inputs['input_ids'].shape[1]
            generated_texts = tokenizer.batch_decode(
                outputs[:, prompt_length:], skip_special_tokens=True
            )

            for text, sorted_gold, generated in zip(texts, golds, generated_texts):
                # Keep only the first assistant turn in case "<|im_end|>" survives decoding.
                assistant_response = generated.split("<|im_end|>")[0].strip()
                try:
                    # Strip in-word apostrophes exactly as for the gold answer, then let
                    # repair_json fix minor syntax errors. repair_json returns JSON text,
                    # so parse it with json.loads (ast.literal_eval rejects null/true/false).
                    pred_cleaned = re.sub(pattern, "", assistant_response)
                    pred_dict = json.loads(repair_json(pred_cleaned))
                    sorted_pred_dict = OrderedDict(sorted(pred_dict.items()))
                except Exception as e_sample:
                    # Record the malformed output and fall back to an empty prediction.
                    print(e_sample)
                    malformed_json_preds += 1
                    print(f"{malformed_json_preds} malformed JSON outputs")
                    print(f"Gold: {sorted_gold}")
                    print(f"Prediction: {generated}")
                    sorted_pred_dict = OrderedDict([('medications', [])])

                csvwriter.writerow([text, sorted_gold, sorted_pred_dict])

            # Flush after every batch so partial results survive an interrupted run.
            csvfile.flush()