In [1]:
import torch
import pickle
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


Some links to be used for fine-tuning
1. https://medium.com/@alexandros_chariton/how-to-fine-tune-llama-3-2-instruct-on-your-own-data-a-detailed-guide-e5f522f397d7
2. https://drlee.io/step-by-step-guide-fine-tuning-metas-llama-3-2-1b-model-f1262eda36c8
3. https://huggingface.co/blog/ImranzamanML/fine-tuning-1b-llama-32-a-comprehensive-article
4. https://medium.com/@hakeemsyd/how-to-fine-tune-your-llama-3-2-model-49a6f8c7621a
5. https://www.analyticsvidhya.com/blog/2024/12/fine-tuning-llama-3-2-3b-for-rag/
6. https://blog.futuresmart.ai/fine-tune-llama-32-vision-language-model-on-custom-datasets
7. https://www.kdnuggets.com/fine-tuning-llama-using-unsloth
8. https://www.linkedin.com/pulse/step-guide-use-fine-tune-llama-32-dr-oualid-soula-xmnff/
9. https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/multigpu_finetuning.md



### Reading the Counsel Chat Dataset

In [2]:
with open('processed_data/counselchat_top_votes.pkl', 'rb') as file:
    dataset_df_top_votes = pickle.load(file)

dataset_df_top_votes

Unnamed: 0,topic,question,answerText
0,relationships,My partner seems to always get depressed over ...,"Hold on, Sanger! You know, I meet with a lot o..."
1,depression,My depression has been reoccurring for a long ...,I couldn't help but notice that you did not sp...
2,relationships,"My boyfriend is in Ireland for 11 days, and I ...",It sounds like you and your boyfriend are very...
3,family-conflict,Ever since my mother passed away my family has...,Understandably you'd like support from those w...
4,grief-and-loss,What can I do to stop grieving my mother's dea...,I am sorry that you lost your mother. That is ...
...,...,...,...
858,relationships,"Ever since I was little, I loved the idea of l...",Attention is often something that is both want...
859,relationship-dissolution,"We're not together, but I'm still doing things...",You didn't ask a direct question because I fee...
860,parenting,"When my son was a teenager, we sent him to liv...",Probably the best way to be supportive of your...
861,depression,I've had posttraumatic stress disorder for yea...,Post traumatic stress disorder (PTSD) is a ver...


### Writing the Fine-Tuning Code

In [None]:
with open('hf_token.key', 'r') as f:
    hf_token = f.read()

model_id = "meta-llama/Llama-3.2-3B-Instruct"

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Adding a special token for pad token so that eos token can be recognized 
# (https://github.com/unslothai/unsloth/issues/416)
# https://github.com/huggingface/transformers/issues/22794
# https://github.com/huggingface/transformers/issues/23230
tokenizer.add_special_tokens({"pad_token": "<|reserved_special_token_0|>"})
tokenizer.padding_side = "right"
tokenizer.model_max_length = 2048

In [None]:
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto") # Must be float32 for MacBooks!

model.config.use_cache=False
model.config.pad_token_id = tokenizer.pad_token_id # Updating the model config to use the special pad token

In [None]:
# Define a function to apply the chat template
def format_chat_template(example):
    question_text_title_combined = None
    
    if example['questionTitle'] == None:
        question_text_title_combined = example['questionText']
    elif example['questionText'] == None:
        question_text_title_combined = example['questionTitle']
    else:
        question_text_title_combined = example['questionText'] + " " + example['questionTitle']
    
    messages = [
        {"role": "user", "content": question_text_title_combined},
        {"role": "assistant", "content": example['answerText']}
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    return {"prompt": prompt}

In [None]:
new_dataset = dataset.map(format_chat_template)
new_dataset = new_dataset.train_test_split(0.05)

In [None]:
# Tokenize the data
def tokenize_function(example):
    tokens = tokenizer(example['prompt'], padding="max_length", truncation=True)
    # Set padding token labels to -100 to ignore them in loss calculation
    tokens['labels'] = [
        -100 if token == tokenizer.pad_token_id else token for token in tokens['input_ids']
    ]
    return tokens

In [None]:
# Apply tokenize_function to each row
tokenized_dataset = new_dataset.map(tokenize_function)
tokenized_dataset = tokenized_dataset.remove_columns(['questionID', 'questionTitle', 'questionText', 'questionLink', 'topic', 'therapistInfo', 'therapistURL', 'answerText', 'upvotes', 'views', 'prompt'])

In [None]:
model.train()
training_args = TrainingArguments(
    output_dir="./llama32-sft-fine-tune-counselchat",
    eval_strategy="steps", # To evaluate during training
    eval_steps=50,
    logging_steps=50,
    save_steps=500,
    per_device_train_batch_size=2, # Adjust based on your hardware
    per_device_eval_batch_size=2,
    num_train_epochs=2, # How many times to loop through the dataset
    fp16=False, # Must be False for MacBooks
    report_to="none", # Here we can use something like tensorboard to see the training metrics
    log_level="info",
    learning_rate=1e-5, # Would avoid larger values here
    max_grad_norm=2 # Clipping the gradients is always a good idea
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer)

In [None]:
# Train the model
trainer.train()

# Save the model and tokenizer
trainer.save_model("./llama32-sft-fine-tune-counselchat")
tokenizer.save_pretrained("./llama32-sft-fine-tune-counselchat")

### Inference

In [None]:
model_id = "llama32-sft-fine-tune-counselchat"
pipe = pipeline(
    "text-generation",
    model=model_id,
    device_map="auto",
)

messages = [
    {"role": "user", "content": dataset[1]['questionText']}
]
outputs = pipe(
    messages,
    max_new_tokens=1024
)
print(outputs[0]["generated_text"][-1]['content'])