In [None]:
# %pip install torch torchvision torchaudio
# %pip install bitsandbytes datasets accelerate transformers peft

In [None]:
"""
Helper Function for Comparing Results
"""

from IPython.display import display, Markdown

def generate(model, tokenizer, context, question, max_new_tokens=200):
    """Complete a CONTEXT/QUESTION/ANSWER prompt with ``model`` and return the decoded text.

    The prompt layout is the same as the one ``create_prompt`` uses to build the
    fine-tuning data (minus the answer), so the adapter is queried in the format
    it was trained on.

    Args:
        model: causal language model exposing ``generate`` and ``device``.
        tokenizer: tokenizer matching ``model``.
        context: text placed under the ``CONTEXT:`` header.
        question: text placed under the ``QUESTION:`` header.
        max_new_tokens: upper bound on the number of generated tokens (default 200).

    Returns:
        The decoded first output sequence (prompt followed by the completion),
        with special tokens stripped.
    """
    prompt = f"CONTEXT:\n{context}\n\nQUESTION:\n{question}\n\nANSWER:\n"

    # turn the prompt into model inputs
    batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False)
    # inputs must live on the same device as the model; with device_map='auto'
    # the model can sit on a GPU while the tokenizer output is on the CPU
    batch = batch.to(model.device)

    output_ids = model.generate(**batch, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def compare_inference(model, tokenizer, context, question):
    """Display the fine-tuned completion and the raw (adapter-free) completion.

    Generates once with the adapters active and once with them disabled, and
    renders each result as Markdown. The adapters are re-enabled afterwards,
    even if the raw generation raises. Without that, every later call would
    silently show the raw model under the "Finetuned Model" heading.

    Args:
        model: causal LM with adapters attached (e.g. via ``add_adapter``).
        tokenizer: tokenizer matching ``model``.
        context: context text inserted into the prompt.
        question: question text inserted into the prompt.
    """
    finetuned_out = generate(model, tokenizer, context, question)
    display(Markdown("# Finetuned Model\n"))
    display(Markdown(finetuned_out))

    model.disable_adapters()
    try:
        raw_out = generate(model, tokenizer, context, question)
    finally:
        # restore the adapters so subsequent calls compare the right models
        model.enable_adapters()
    display(Markdown("# Raw Model\n"))
    display(Markdown(raw_out))


In [None]:
"""Importing dependencies and downloading pre-trained bloom model
"""

import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

# Base model and the directory the trained LoRA adapter is written to.
# (The cell title says "bloom", but OPT-350m is the active choice; the
# BLOOM-560m settings are kept below as a commented-out alternative.)
# model_id = "bigscience/bloom-560m"
# peft_model_id = "./checkpoint/BLOOM-560m-LoRA"
model_id = "facebook/opt-350m"
peft_model_id = "./checkpoint/opt-350m-lora"

# load the model; device_map='auto' lets transformers choose the device placement
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
)

# load the tokenizer for this model (turns text into model inputs)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# DataCollatorForLanguageModeling (used for training below) pads batches and so
# needs a pad token. Fall back to EOS for checkpoints that ship without one;
# this is a no-op when the tokenizer already defines a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# model.load_adapter(peft_model_id)

# context = "you are a math wizard"
# question = "what is 1+1 equal to?"
# compare_inference(model, tokenizer, context, question)

In [None]:
"""
Setting up LoRA using parameter efficient fine tuning
"""

from peft import LoraConfig

# LoRA hyperparameters for this run: rank-8 update matrices scaled by
# lora_alpha, dropout on the LoRA branch, and no bias terms trained.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    bias="none",
    # target_modules is left unset; "query_key_value" only applies to BLOOM-style
    # attention. Presumably PEFT's per-architecture defaults are used -- confirm.
    # target_modules=["query_key_value"],
)

# attach the (single, default-named) adapter to the loaded model
model.add_adapter(peft_config)

In [None]:
"""Comparing parameters before and after LoRA
"""

# (element count, requires_grad) for every parameter of the model
sizes_and_flags = [(tensor.numel(), tensor.requires_grad) for _, tensor in model.named_parameters()]

# total size, and the part the optimizer will update (requires a gradient)
all_param = sum(size for size, _ in sizes_and_flags)
trainable_params = sum(size for size, needs_grad in sizes_and_flags if needs_grad)

for report_line in (
    f"trainable params: {trainable_params}",
    f"all params: {all_param}",
    f"trainable: {100 * trainable_params / all_param:.2f}%",
):
    print(report_line)

In [None]:
"""Loading SQUAD dataset
"""

from datasets import load_dataset

# DatasetDict of SQuAD v2 splits; the "train" split is used for fine-tuning below
qa_dataset = load_dataset(path="squad_v2")

In [None]:
"""Reformatting SQUAD to respect our defined structure
"""

#defining a function for reformatting
def create_prompt(context, question, answer, eos_token="</s>"):
    """Format one SQuAD-style example as a CONTEXT/QUESTION/ANSWER training prompt.

    Args:
        context: passage text.
        question: question about the passage.
        answer: SQuAD ``answers`` entry; only its ``"text"`` list is read. An
            empty list becomes the literal answer ``"Cannot Find Answer"``;
            otherwise the first answer text is used.
        eos_token: string appended after the answer so the model learns where
            to stop. Defaults to ``"</s>"`` (the previously hard-coded value);
            pass ``tokenizer.eos_token`` when using a model whose
            end-of-sequence token differs.

    Returns:
        The formatted prompt string.
    """
    answer_texts = answer["text"]
    answer_text = answer_texts[0] if len(answer_texts) > 0 else "Cannot Find Answer"
    return f"CONTEXT:\n{context}\n\nQUESTION:\n{question}\n\nANSWER:\n{answer_text}{eos_token}"

def _tokenize_squad_example(sample):
    """Build the training prompt for one SQuAD example and tokenize it."""
    prompt = create_prompt(sample['context'], sample['question'], sample['answers'])
    return tokenizer(prompt)

# format + tokenize every example in every split of the dataset
mapped_qa_dataset = qa_dataset.map(_tokenize_squad_example)

In [None]:
"""Fine Tuning
This code is largly co-opted. In the absence of a rigid validation
procedure, the best practice is to just copy a successful tutorial or,
better yet, directly from the documentation.
"""

import transformers

MAX_STEPS = 100
# Warm up over the first 10% of steps. With warmup_steps equal to max_steps,
# the learning rate would still be ramping up when training stops and would
# never settle at learning_rate.
WARMUP_STEPS = MAX_STEPS // 10

trainer = transformers.Trainer(
    model=model,
    train_dataset=mapped_qa_dataset["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=WARMUP_STEPS,
        max_steps=MAX_STEPS,
        learning_rate=1e-3,
        logging_steps=10,
        output_dir='checkpoint',
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# The KV cache is only useful for generation: turn it off while training and
# back on afterwards (even if training is interrupted) for the inference cells.
model.config.use_cache = False
try:
    trainer.train()
finally:
    model.config.use_cache = True
    # training puts the model in train mode; switch to eval mode so dropout
    # (including lora_dropout) is off during the comparisons below
    model.eval()

# write the trained adapter to disk
model.save_pretrained(peft_model_id)

In [None]:
context, question = (
    "You are a monster, and you eat yellow legos.",
    "What is the best food?",
)
compare_inference(model, tokenizer, context=context, question=question)

In [None]:
context, question = (
    "you are a math wizard",
    "what is 1+1 equal to?",
)
compare_inference(model, tokenizer, context=context, question=question)

In [None]:
context, question = (
    "Answer the riddle",
    "What gets bigger the more you take away?",
)
compare_inference(model, tokenizer, context=context, question=question)