In [1]:
!pip install --upgrade transformers datasets accelerate peft trl



In [2]:
# This step is needed only on Apple Metal
!pip uninstall bitsandbytes -y

[0m

In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # Replace with the appropriate Llama 3.1 model name

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# Set up LoRA configuration
lora_config = LoraConfig(
    r=16,  # Low-rank dimension
    lora_alpha=32,  # Scaling factor
    lora_dropout=0.1,  # Dropout probability
    bias="none",  # Don't add bias to the LoRA adapters
    target_modules=['down_proj', 'gate_proj', 'o_proj', 'v_proj', 'up_proj', 'q_proj', 'k_proj'],
    task_type="CAUSAL_LM",
)

# Wrap the model with LoRA
model = get_peft_model(model, lora_config)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
# Check if pad token exists, add it if missing
if tokenizer.pad_token is None:
   print(f'Added pad_token {tokenizer.eos_token}')
   tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "right"

# Update the model's token embeddings to accommodate the new pad token
model.resize_token_embeddings(len(tokenizer))

Added pad_token <|eot_id|>


Embedding(128256, 4096)

In [5]:
import torch
print(f'MPS available: {torch.backends.mps.is_available()}')

device = "mps" if torch.backends.mps.is_available() else "cpu"
model = model.to(device)

MPS available: True


Prepare data

In [6]:
from datasets import load_dataset

dataset = load_dataset("ruslanmv/ai-medical-chatbot", split='all')
dataset

Dataset({
    features: ['Description', 'Patient', 'Doctor'],
    num_rows: 256916
})

In [7]:
# Inspired by https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/alpaca_dataset.py
# See also https://crfm.stanford.edu/2023/03/13/alpaca.html

from copy import deepcopy

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}

# Tokenize the dataset
def tokenize_function(example):
    if example.get("Patient", "") == "":
        prompt = PROMPT_DICT["prompt_no_input"].format(instruction=example['Description'])
    else:
        prompt = PROMPT_DICT["prompt_input"].format(instruction=example['Description'], input=example['Patient'])

    formatted_example = prompt + example['Doctor']
    return tokenizer(formatted_example)


tokenized_dataset = dataset.map(tokenize_function, remove_columns=['Description', 'Patient', 'Doctor'], batched=False)
tokenized_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 256916
})

In [8]:
train_test_split = tokenized_dataset.train_test_split(test_size=0.2)
train_dataset = train_test_split['train']
eval_dataset = train_test_split['test']
small_eval_dataset = eval_dataset.shuffle(seed=42).select(range(20))

Training

In [9]:
import math

# Define the perplexity metric function
def compute_metrics(eval_pred):
    logits, labels = eval_pred

    # Convert logits and labels from NumPy arrays to PyTorch tensors
    logits = torch.tensor(logits)
    labels = torch.tensor(labels)
    
    # Shift the labels so that they're aligned with the next token prediction
    labels = labels[:, 1:].reshape(-1)
    logits = logits[:, :-1].reshape(-1, logits.shape[-1])
    
    # Compute cross-entropy loss
    loss_fct = torch.nn.CrossEntropyLoss()
    loss = loss_fct(logits, labels)

    # Compute perplexity from loss
    perplexity = math.exp(loss.item()) if loss.item() < 100 else float("inf")
    
    return {"perplexity": perplexity}

In [10]:
from transformers import TrainingArguments
from torch.optim import AdamW
from trl import SFTTrainer

training_args = TrainingArguments(
    output_dir="./llama_lora_finetuned",
    eval_strategy="steps",
    eval_steps = 1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=32,
    num_train_epochs=1,
    save_steps=1000,
    logging_steps=1,
    learning_rate=5e-4,
    weight_decay=0.001,
    warmup_steps=10,
    load_best_model_at_end=True,
)

optimizer = AdamW(model.parameters(), lr=5e-5)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=small_eval_dataset,
    #peft_config=lora_config,
    #dataset_text_field="input_ids",
    tokenizer=tokenizer, 
    optimizers=(optimizer, None),  # No need for a scheduler here
    max_seq_length=tokenizer.model_max_length,
    compute_metrics=compute_metrics,
    packing=False,
    #dataset_kwargs={
    #    "add_special_tokens": False,
    #    "append_concat_token": False,
    #}
)



Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


In [11]:
trainer.train()

  0%|          | 0/6422 [00:00<?, ?it/s]

{'loss': 2.9393, 'grad_norm': 7451.27783203125, 'learning_rate': 5e-06, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 2.9937121868133545, 'eval_perplexity': 18.989596695152287, 'eval_runtime': 121.7758, 'eval_samples_per_second': 0.164, 'eval_steps_per_second': 0.164, 'epoch': 0.0}
{'loss': 2.9506, 'grad_norm': 8231.7890625, 'learning_rate': 1e-05, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 2.993868827819824, 'eval_perplexity': 18.99148927302952, 'eval_runtime': 237.5224, 'eval_samples_per_second': 0.084, 'eval_steps_per_second': 0.084, 'epoch': 0.0}
{'loss': 3.1613, 'grad_norm': 5187.17578125, 'learning_rate': 1.5e-05, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 2.995087146759033, 'eval_perplexity': 19.01076989755655, 'eval_runtime': 227.1277, 'eval_samples_per_second': 0.088, 'eval_steps_per_second': 0.088, 'epoch': 0.0}
{'loss': 3.095, 'grad_norm': 119288.234375, 'learning_rate': 2e-05, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 2.993140459060669, 'eval_perplexity': 18.975480757070926, 'eval_runtime': 235.4051, 'eval_samples_per_second': 0.085, 'eval_steps_per_second': 0.085, 'epoch': 0.0}
{'loss': 3.0124, 'grad_norm': 158538.34375, 'learning_rate': 2.5e-05, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 2.9888112545013428, 'eval_perplexity': 18.901033701471555, 'eval_runtime': 272.9314, 'eval_samples_per_second': 0.073, 'eval_steps_per_second': 0.073, 'epoch': 0.0}


In [None]:
model.save_pretrained("./llama3_lora_model_finetuned")
tokenizer.save_pretrained("./llama3_lora_tokenizer_finetuned")