In [1]:
!pip install --upgrade transformers datasets accelerate peft trl wandb python-dotenv



In [2]:
# This step is needed only on Apple Metal
!pip uninstall bitsandbytes -y

[0m

In [3]:
import os
from pathlib import Path
from dotenv import load_dotenv

dotenv_path = Path('../.env') # path to your .env file
load_dotenv(dotenv_path=dotenv_path)

True

Set wandb

In [4]:
import wandb

WANDB_KEY=os.getenv('WANDB_KEY')
wandb.login(key=WANDB_KEY)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33milya-ivensky[0m ([33milya-ivensky-free-lancer[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/ilyaivensky/.netrc


True

Set the model

In [5]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

#model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # Replace with the appropriate Llama 3.1 model name
model_name = "meta-llama/Llama-3.2-1B-Instruct"  # Replace with the appropriate Llama 3.1 model name

os.environ["WANDB_PROJECT"] = f"fine-tune-{model_name.replace('/', '-')}"  # name your W&B project
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# Set up LoRA configuration
lora_config = LoraConfig(
    r=16,  # Low-rank dimension
    lora_alpha=16,  # Scaling factor
    lora_dropout=0,  # Dropout probability
    bias="none",  # Don't add bias to the LoRA adapters
    target_modules=['down_proj', 'gate_proj', 'o_proj', 'v_proj', 'up_proj', 'q_proj', 'k_proj'],
    task_type="CAUSAL_LM",
)

# Wrap the model with LoRA
model = get_peft_model(model, lora_config)

In [6]:
# Check if pad token exists, add it if missing
if tokenizer.pad_token is None:
   print(f'Added pad_token {tokenizer.eos_token}')
   tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "right"
tokenizer.is_split_into_words = True

# Update the model's token embeddings to accommodate the new pad token
model.resize_token_embeddings(len(tokenizer))

Added pad_token <|eot_id|>


Embedding(128256, 2048)

In [7]:
import torch
print(f'MPS available: {torch.backends.mps.is_available()}')
print(f'CUDA available: {torch.cuda.is_available()}')

device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

torch.manual_seed(42)
print(f'device={device}')

MPS available: True
CUDA available: False
device=mps


Prepare data

In [8]:
from datasets import load_dataset

dataset = load_dataset("ruslanmv/ai-medical-chatbot", split='all')
dataset

Dataset({
    features: ['Description', 'Patient', 'Doctor'],
    num_rows: 256916
})

In [9]:
dataset[0]

{'Description': 'Q. What does abutment of the nerve root mean?',
 'Patient': 'Hi doctor,I am just wondering what is abutting and abutment of the nerve root means in a back issue. Please explain. What treatment is required for\xa0annular bulging and tear?',
 'Doctor': 'Hi. I have gone through your query with diligence and would like you to know that I am here to help you. For further information consult a neurologist online -->'}

In [10]:
alpaca_propmt = """Below is an instruction that describes a task, paired with an input that provides further context. 

### Instruction: 
{}

### Input:
{}

### Response:
{}"""

def formatting_prompt_fn(example):

    instruction = example['Description'][3:] # Remove 'Q. '
    input = example['Patient']
    response = example['Doctor']

    text = alpaca_propmt.format(instruction, input, response) + tokenizer.eos_token

    return tokenizer(text)

formatted_dataset = dataset.map(formatting_prompt_fn, remove_columns=['Description', 'Patient', 'Doctor'], batched=False)
formatted_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 256916
})

In [11]:
train_test_split = formatted_dataset.train_test_split(test_size=0.2)
train_dataset = train_test_split['train']
eval_dataset = train_test_split['test']
small_eval_dataset = eval_dataset.shuffle(seed=42).select(range(20))

Training

In [12]:
import math

# Define the perplexity metric function
def compute_metrics(eval_pred):
    logits, labels = eval_pred

    # Convert logits and labels from NumPy arrays to PyTorch tensors
    logits = torch.tensor(logits)
    labels = torch.tensor(labels)
    
    # Shift the labels so that they're aligned with the next token prediction
    labels = labels[:, 1:].reshape(-1)
    logits = logits[:, :-1].reshape(-1, logits.shape[-1])
    
    # Compute cross-entropy loss
    loss_fct = torch.nn.CrossEntropyLoss()
    loss = loss_fct(logits, labels)

    # Compute perplexity from loss
    perplexity = math.exp(loss.item()) if loss.item() < 100 else float("inf")
    
    return {"perplexity": perplexity}

In [13]:
from torch.optim import AdamW
from trl import SFTTrainer, SFTConfig

training_args = SFTConfig(
    output_dir="./llama_lora_finetuned",
    eval_strategy="steps",
    eval_steps = 1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=128,
    num_train_epochs=1,
    save_steps=1000,
    logging_steps=1,
    learning_rate=1e-5,
    lr_scheduler_type="linear",
    weight_decay=0.01,
    warmup_steps=10,
    load_best_model_at_end=True,
    report_to="wandb",
    seed=3407
)

optimizer = AdamW(
    model.parameters(), 
    lr=training_args.learning_rate, 
    weight_decay=training_args.weight_decay)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=small_eval_dataset,
    #peft_config=lora_config,
    #dataset_text_field="input_ids",
    tokenizer=tokenizer, 
    optimizers=(optimizer, None),  # No need for a scheduler here
    max_seq_length=tokenizer.model_max_length,
    compute_metrics=compute_metrics,
    packing=False,
    #dataset_kwargs={
    #    "add_special_tokens": False,
    #    "append_concat_token": False,
    #}
)



Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


In [None]:
trainer.train()



  0%|          | 0/1605 [00:00<?, ?it/s]

{'loss': 3.4839, 'grad_norm': 0.6205085515975952, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.4462876319885254, 'eval_perplexity': 29.968461061195264, 'eval_runtime': 19.0327, 'eval_samples_per_second': 1.051, 'eval_steps_per_second': 1.051, 'epoch': 0.0}
{'loss': 3.4586, 'grad_norm': 0.6195924282073975, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.445777416229248, 'eval_perplexity': 29.95563846214156, 'eval_runtime': 28.8273, 'eval_samples_per_second': 0.694, 'eval_steps_per_second': 0.694, 'epoch': 0.0}
{'loss': 3.427, 'grad_norm': 0.5974190831184387, 'learning_rate': 3e-06, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.4458515644073486, 'eval_perplexity': 29.957752563043176, 'eval_runtime': 85.6423, 'eval_samples_per_second': 0.234, 'eval_steps_per_second': 0.234, 'epoch': 0.0}
{'loss': 3.4744, 'grad_norm': 0.599490225315094, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.4453091621398926, 'eval_perplexity': 29.94190044233763, 'eval_runtime': 63.2806, 'eval_samples_per_second': 0.316, 'eval_steps_per_second': 0.316, 'epoch': 0.0}
{'loss': 3.4491, 'grad_norm': 0.6212579607963562, 'learning_rate': 5e-06, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.4443740844726562, 'eval_perplexity': 29.914785646702878, 'eval_runtime': 70.2403, 'eval_samples_per_second': 0.285, 'eval_steps_per_second': 0.285, 'epoch': 0.0}
{'loss': 3.4681, 'grad_norm': 0.6232648491859436, 'learning_rate': 6e-06, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.442688465118408, 'eval_perplexity': 29.86817001825144, 'eval_runtime': 102.8323, 'eval_samples_per_second': 0.194, 'eval_steps_per_second': 0.194, 'epoch': 0.0}
{'loss': 3.5008, 'grad_norm': 0.623978853225708, 'learning_rate': 7e-06, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.4399242401123047, 'eval_perplexity': 29.790174635340183, 'eval_runtime': 99.8256, 'eval_samples_per_second': 0.2, 'eval_steps_per_second': 0.2, 'epoch': 0.0}
{'loss': 3.4974, 'grad_norm': 0.6112125515937805, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.0}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.4363837242126465, 'eval_perplexity': 29.691584544793336, 'eval_runtime': 99.6403, 'eval_samples_per_second': 0.201, 'eval_steps_per_second': 0.201, 'epoch': 0.0}
{'loss': 3.497, 'grad_norm': 0.6269823312759399, 'learning_rate': 9e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.4314169883728027, 'eval_perplexity': 29.55338480325731, 'eval_runtime': 65.433, 'eval_samples_per_second': 0.306, 'eval_steps_per_second': 0.306, 'epoch': 0.01}
{'loss': 3.5251, 'grad_norm': 0.6304206252098083, 'learning_rate': 1e-05, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.425340175628662, 'eval_perplexity': 29.385735705623972, 'eval_runtime': 71.2562, 'eval_samples_per_second': 0.281, 'eval_steps_per_second': 0.281, 'epoch': 0.01}
{'loss': 3.3942, 'grad_norm': 0.6119317412376404, 'learning_rate': 9.993730407523512e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.418275833129883, 'eval_perplexity': 29.192633135334475, 'eval_runtime': 69.6928, 'eval_samples_per_second': 0.287, 'eval_steps_per_second': 0.287, 'epoch': 0.01}
{'loss': 3.4418, 'grad_norm': 0.6032811999320984, 'learning_rate': 9.987460815047023e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.4107518196105957, 'eval_perplexity': 28.98687738388793, 'eval_runtime': 65.4954, 'eval_samples_per_second': 0.305, 'eval_steps_per_second': 0.305, 'epoch': 0.01}
{'loss': 3.4271, 'grad_norm': 938.4490356445312, 'learning_rate': 9.981191222570533e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.404735565185547, 'eval_perplexity': 28.822538149226002, 'eval_runtime': 58.2904, 'eval_samples_per_second': 0.343, 'eval_steps_per_second': 0.343, 'epoch': 0.01}
{'loss': 3.4512, 'grad_norm': 0.6245319247245789, 'learning_rate': 9.974921630094044e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.398347854614258, 'eval_perplexity': 28.652496677200787, 'eval_runtime': 57.0653, 'eval_samples_per_second': 0.35, 'eval_steps_per_second': 0.35, 'epoch': 0.01}
{'loss': 3.4491, 'grad_norm': 0.6270241141319275, 'learning_rate': 9.968652037617555e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.391522169113159, 'eval_perplexity': 28.46996110142165, 'eval_runtime': 65.5853, 'eval_samples_per_second': 0.305, 'eval_steps_per_second': 0.305, 'epoch': 0.01}
{'loss': 3.4429, 'grad_norm': 0.6250674724578857, 'learning_rate': 9.962382445141066e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.384915590286255, 'eval_perplexity': 28.2934853483842, 'eval_runtime': 62.0576, 'eval_samples_per_second': 0.322, 'eval_steps_per_second': 0.322, 'epoch': 0.01}
{'loss': 3.4322, 'grad_norm': 0.6088733673095703, 'learning_rate': 9.956112852664579e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.3780887126922607, 'eval_perplexity': 28.114443430119795, 'eval_runtime': 63.1326, 'eval_samples_per_second': 0.317, 'eval_steps_per_second': 0.317, 'epoch': 0.01}
{'loss': 3.3926, 'grad_norm': 0.6078761219978333, 'learning_rate': 9.94984326018809e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.3712234497070312, 'eval_perplexity': 27.934349906089906, 'eval_runtime': 65.3962, 'eval_samples_per_second': 0.306, 'eval_steps_per_second': 0.306, 'epoch': 0.01}
{'loss': 3.4511, 'grad_norm': 0.5961207151412964, 'learning_rate': 9.943573667711599e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.364393949508667, 'eval_perplexity': 27.756839408863375, 'eval_runtime': 72.0865, 'eval_samples_per_second': 0.277, 'eval_steps_per_second': 0.277, 'epoch': 0.01}
{'loss': 3.3679, 'grad_norm': 0.6088525056838989, 'learning_rate': 9.93730407523511e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.3573365211486816, 'eval_perplexity': 27.573697928274598, 'eval_runtime': 57.0204, 'eval_samples_per_second': 0.351, 'eval_steps_per_second': 0.351, 'epoch': 0.01}
{'loss': 3.3867, 'grad_norm': 0.6238860487937927, 'learning_rate': 9.931034482758622e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.3502273559570312, 'eval_perplexity': 27.391764827408682, 'eval_runtime': 62.4587, 'eval_samples_per_second': 0.32, 'eval_steps_per_second': 0.32, 'epoch': 0.01}
{'loss': 3.416, 'grad_norm': 0.6212234497070312, 'learning_rate': 9.924764890282133e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.3429551124572754, 'eval_perplexity': 27.204791762740744, 'eval_runtime': 53.8429, 'eval_samples_per_second': 0.371, 'eval_steps_per_second': 0.371, 'epoch': 0.01}
{'loss': 3.4185, 'grad_norm': 0.6258429884910583, 'learning_rate': 9.918495297805644e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.335667371749878, 'eval_perplexity': 27.021581625055266, 'eval_runtime': 70.7195, 'eval_samples_per_second': 0.283, 'eval_steps_per_second': 0.283, 'epoch': 0.01}
{'loss': 3.4631, 'grad_norm': 0.6170189380645752, 'learning_rate': 9.912225705329155e-06, 'epoch': 0.01}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.32844877243042, 'eval_perplexity': 26.839611711517644, 'eval_runtime': 68.1691, 'eval_samples_per_second': 0.293, 'eval_steps_per_second': 0.293, 'epoch': 0.01}
{'loss': 3.3584, 'grad_norm': 0.5973426699638367, 'learning_rate': 9.905956112852665e-06, 'epoch': 0.02}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.321087598800659, 'eval_perplexity': 26.657138460129627, 'eval_runtime': 57.7455, 'eval_samples_per_second': 0.346, 'eval_steps_per_second': 0.346, 'epoch': 0.02}
{'loss': 3.3623, 'grad_norm': 0.6013031005859375, 'learning_rate': 9.899686520376176e-06, 'epoch': 0.02}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.3138535022735596, 'eval_perplexity': 26.476631711316532, 'eval_runtime': 48.6342, 'eval_samples_per_second': 0.411, 'eval_steps_per_second': 0.411, 'epoch': 0.02}
{'loss': 3.3587, 'grad_norm': 0.6248693466186523, 'learning_rate': 9.893416927899687e-06, 'epoch': 0.02}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.306919574737549, 'eval_perplexity': 26.305386331539708, 'eval_runtime': 52.74, 'eval_samples_per_second': 0.379, 'eval_steps_per_second': 0.379, 'epoch': 0.02}
{'loss': 3.3286, 'grad_norm': 0.6103606820106506, 'learning_rate': 9.887147335423198e-06, 'epoch': 0.02}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.2994861602783203, 'eval_perplexity': 26.12418438930847, 'eval_runtime': 54.233, 'eval_samples_per_second': 0.369, 'eval_steps_per_second': 0.369, 'epoch': 0.02}
{'loss': 3.4237, 'grad_norm': 0.6364954113960266, 'learning_rate': 9.880877742946709e-06, 'epoch': 0.02}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.2922966480255127, 'eval_perplexity': 25.94853616347324, 'eval_runtime': 71.2458, 'eval_samples_per_second': 0.281, 'eval_steps_per_second': 0.281, 'epoch': 0.02}
{'loss': 3.3424, 'grad_norm': 0.6177639961242676, 'learning_rate': 9.874608150470221e-06, 'epoch': 0.02}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.2848517894744873, 'eval_perplexity': 25.76944214137618, 'eval_runtime': 50.1222, 'eval_samples_per_second': 0.399, 'eval_steps_per_second': 0.399, 'epoch': 0.02}
{'loss': 3.4172, 'grad_norm': 0.638789713382721, 'learning_rate': 9.86833855799373e-06, 'epoch': 0.02}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.277390718460083, 'eval_perplexity': 25.59080932714049, 'eval_runtime': 48.8885, 'eval_samples_per_second': 0.409, 'eval_steps_per_second': 0.409, 'epoch': 0.02}
{'loss': 3.3012, 'grad_norm': 0.5967254638671875, 'learning_rate': 9.862068965517241e-06, 'epoch': 0.02}


  0%|          | 0/20 [00:00<?, ?it/s]

{'eval_loss': 3.2700819969177246, 'eval_perplexity': 25.41783826593779, 'eval_runtime': 61.6283, 'eval_samples_per_second': 0.325, 'eval_steps_per_second': 0.325, 'epoch': 0.02}


In [None]:
model.save_pretrained("./llama3_lora_model_finetuned")
tokenizer.save_pretrained("./llama3_lora_tokenizer_finetuned")

In [None]:
wandb.finish()