# Meta Llama-3 Fine-tuning

### Set up

In [2]:
import gc
import os
import random
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset
import torch
from torch.utils.data import Subset
# from torch.utils.tensorboard import SummaryWriter
import wandb
from transformers.integrations import WandbCallback
from transformers import (
        AutoTokenizer, pipeline,
        AutoModelForCausalLM,
        DataCollatorWithPadding,
        DataCollatorForSeq2Seq,
        AutoModelForSpeechSeq2Seq,
        BartForConditionalGeneration,
        TrainingArguments,
        Seq2SeqTrainingArguments,
        Trainer,
        BitsAndBytesConfig,
        EarlyStoppingCallback,
        ProgressCallback
)
from huggingface_hub import login
from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
import evaluate
from tqdm.auto import tqdm
import warnings

warnings.filterwarnings("ignore")

In [3]:
wandb.init(project="CareConnect",name="Fine-tuning with context(Accurate 50% of times)")

In [4]:
seed = 42
np.random.seed(seed)

In [6]:
output_dir = './models/llama3_8B/'
model_id = "meta-llama/Meta-Llama-3-8B"
checkpoints = [os.path.join(output_dir, d) for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
latest_checkpoint = max(checkpoints, key=os.path.getctime) if checkpoints else None

# Resume training from the latest checkpoint
model_id = latest_checkpoint if latest_checkpoint else model_id

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    print('Pad token is None')
    tokenizer.pad_token = tokenizer.eos_token
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})

In [11]:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [12]:
config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)

### Data Preperation

In [17]:
HMCDataset = load_dataset('hakeematyab/HealthCareMagicWithSummary-100k')
HMCDataset = HMCDataset['train'].train_test_split(train_size=0.8,seed=seed)
HMCDatasetTemp = HMCDataset['test'].train_test_split(train_size=0.5,seed=seed)
HMCDataset['validation'] = HMCDatasetTemp.pop('train')
HMCDataset['test']  = HMCDatasetTemp.pop('test')
del HMCDatasetTemp
HMCDataset

DatasetDict({
    train: Dataset({
        features: ['input', 'preprocessed_output', 'summarized_input', 'summarized_output'],
        num_rows: 89732
    })
    test: Dataset({
        features: ['input', 'preprocessed_output', 'summarized_input', 'summarized_output'],
        num_rows: 11217
    })
    validation: Dataset({
        features: ['input', 'preprocessed_output', 'summarized_input', 'summarized_output'],
        num_rows: 11216
    })
})

In [18]:
def preprocess_function(examples,teacher_forcing_ratio=1.0,accurate_context_ratio=0.5):    
    prompt_template = '''### system: 
You are CareConnect, an expert medical personal assistant.

### instruction: 
Answer the user's queries truthfully and accurately, based on the provided context if the context is applicable. Refuse to answer any questions unrelated to medicine.

### context: 
{context}

### user: 
{user_query}

### system: 
{response}'''+tokenizer.eos_token
    inputs = []
    for user_query, output, summ_input, summ_output  in zip(examples['input'],examples['preprocessed_output'],examples['summarized_input'], examples['summarized_output']):
        if random.random() < teacher_forcing_ratio:
            if random.random() < accurate_context_ratio:
                current_context= summ_input+'\n'+summ_output
            else:
                wrong_input_context = random.choice([c for c in examples['summarized_input'] if c != summ_input])
                wrong_output_context = random.choice([c for c in examples['summarized_output'] if c != summ_output])
                current_context= wrong_input_context+'\n'+wrong_output_context
        else:
            current_context= ""
        inputs.append(prompt_template.format(context=current_context, user_query=user_query,response=output))
    return {'text':inputs}

In [19]:
fullDataset=HMCDataset.map(preprocess_function, batched=True)
fullDataset

DatasetDict({
    train: Dataset({
        features: ['input', 'preprocessed_output', 'summarized_input', 'summarized_output', 'text'],
        num_rows: 89732
    })
    test: Dataset({
        features: ['input', 'preprocessed_output', 'summarized_input', 'summarized_output', 'text'],
        num_rows: 11217
    })
    validation: Dataset({
        features: ['input', 'preprocessed_output', 'summarized_input', 'summarized_output', 'text'],
        num_rows: 11216
    })
})

In [21]:
print(fullDataset['train'][0]['text'])

### system: 
You are CareConnect, an expert medical personal assistant.

### instruction: 
Answer the user's queries truthfully and accurately, based on the provided context if the context is applicable. Refuse to answer any questions unrelated to medicine.

### context: 
I have had a bout of bronchitis. Now that I am over it, I can't do any work. Next week, I will have an echo cardiogram to check if I have a heart condition. I am 52 years old and 40 pounds overweight.
According to the information provided by the doctor, low back pain is due to stenosis in the spine.

### user: 
im a 39yr old female,i just had a spinal fusion surgery on my back in December. in December I missed my period,january comes I get my period but its really heavy and then I started vomiting for 8 days straight when my period was over s were my symptoms, then febuary comes same thing I get my period and vomit for 8 days again,now march just before my period started I vomited for 8 days then my period started and

### Training

In [18]:
data_collator= DataCollatorWithPadding(tokenizer=tokenizer,padding="max_length", max_length=1024)

In [22]:
per_device_train_batch_size=8
per_device_eval_batch_size=8
epochs = 1
max_seq_length=1024

In [23]:
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps= 16,
    evaluation_strategy="steps",
    eval_steps=15,
    save_steps=15,
    save_total_limit=2,  # Keep only the last 2 checkpoints
    fp16=True,  # Enable mixed precision training
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
    greater_is_better=False,  # Metric is minimized
    logging_dir='./logs',
    report_to="wandb",
    logging_steps=5,  # Log every 10 steps
    
)

In [19]:
bert_score= evaluate.load("bertscore")

def compute_metrics(eval_preds):
    predictions, references = eval_preds
    print(references.shape)
    print(predictions[0].shape)
    decoded_preds = tokenizer.batch_decode(predictions[0], skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(references, skip_special_tokens=True)
    print(decoded_preds)
    print(decoded_labels)
    result= bert_score.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
    
    return {'precision':np.mean(result['precision']),'recall':np.mean(result['recall']),'f1':np.mean(result['f1'])}

In [20]:
def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak. 
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    print(logits.shape)
    print(labels.shape)
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels

In [24]:
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    dataset_text_field = "text",
    train_dataset=fullDataset['train'],
    eval_dataset=fullDataset['validation'],
    args=training_args,
    peft_config=config,
    max_seq_length=max_seq_length,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3),WandbCallback(),ProgressCallback()]
)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
You are adding a <class 'transformers.integrations.integration_utils.WandbCallback'> to the callbacks of this Trainer, but there is already one. The currentlist of callbacks is
:DefaultFlowCallback
WandbCallback
EarlyStoppingCallback


In [26]:
wandb.config.update(training_args.to_dict())

# Convert trainer parameters to dictionary and log to W&B
trainer_params = {
    "train_dataset_size": len(trainer.train_dataset),
    "eval_dataset_size": len(trainer.eval_dataset),
    "max_seq_length": max_seq_length,
    "compute_metrics":"Bertscore",
    "callbacks": ["Early Stopping: Patience=3","Logging","Wandb Reporting"],
}
wandb.config.update(trainer_params)

In [27]:
for name, module in trainer.model.named_modules():
    if "norm" in name:
        module = module.to(torch.float32)

In [None]:
trainer.train()