In [None]:
import torch
from torch import nn
from torch.nn import functional as F

### Helper Function for Comparing Results

In [None]:
from IPython.display import display, Markdown

def generate(model, tokenizer, context, question):
    """Generate an answer for ``question`` given ``context`` and return the decoded text.

    Args:
        model: Causal LM exposing ``generate(**inputs, max_new_tokens=...)``.
        tokenizer: Tokenizer matching ``model``; called on the prompt and used
            to decode the generated ids.
        context: Background text placed in the CONTEXT section of the prompt.
        question: Question placed in the QUESTION section of the prompt.

    Returns:
        The decoded first sequence returned by ``model.generate`` with special
        tokens stripped.
    """
    # Use the same section markers as the training prompt built by
    # `create_prompt`; the previous "**CONTEXT:**" markdown markers never
    # appeared in the fine-tuning data.
    prompt = f"CONTEXT:\n{context}\n\nQUESTION:\n{question}\n\nANSWER:\n"
    batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False)

    # The model may live on an accelerator (it is loaded with device_map='auto'),
    # so the inputs have to follow it.
    device = getattr(model, "device", None)
    if device is not None:
        batch = batch.to(device)

    # Training leaves the model in train mode, which keeps dropout (including
    # the LoRA dropout) active. Switch to eval mode for generation and restore
    # the previous mode afterwards.
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    try:
        output_ids = model.generate(**batch, max_new_tokens=200)
    finally:
        if was_training:
            model.train()

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def compare_inference(model, tokenizer, context, question):
    """Display the fine-tuned (LoRA on) and raw (LoRA off) generations side by side.

    Args:
        model: Either an object exposing ``enable_lora()`` / ``disable_lora()``
            itself, or an ``nn.Module`` whose sub-modules expose them (for
            example a pretrained model whose linear layers were wrapped).
        tokenizer: Tokenizer matching ``model``.
        context: Context text forwarded to ``generate``.
        question: Question text forwarded to ``generate``.

    LoRA is always re-enabled on exit, even if generation raises.
    """
    def _set_lora(enabled):
        # Toggle on the model itself when it supports it; otherwise walk the
        # module tree. A plain pretrained model has no enable_lora/disable_lora
        # of its own; only the injected LoRA layers do.
        method_name = "enable_lora" if enabled else "disable_lora"
        toggle = getattr(model, method_name, None)
        if callable(toggle):
            toggle()
            return
        for module in model.modules():
            toggle = getattr(module, method_name, None)
            if callable(toggle):
                toggle()

    _set_lora(True)
    try:
        out = generate(model, tokenizer, context, question)
        display(Markdown("# Finetuned Model\n"))
        display(Markdown(out))

        _set_lora(False)
        out = generate(model, tokenizer, context, question)
        display(Markdown("# Raw Model\n"))
        display(Markdown(out))
    finally:
        # Leave the model in its fine-tuned configuration.
        _set_lora(True)

### Importing dependencies and downloading pre-trained model

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Checkpoint selection. An alternative pair that can be swapped in:
#   model_id = "bigscience/bloom-560m"
#   peft_model_id = "./checkpoint/BLOOM-560m-LoRA"
model_id = "facebook/opt-350m"
peft_model_id = "./checkpoint/opt-350m-lora"

# Tokenizer for this checkpoint (turns text into model inputs).
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Pretrained causal LM; device placement is delegated to device_map='auto'.
# (torch_dtype=torch.float32 can be passed here to force full precision.)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto')

### Define LoRA linear module

In [None]:
from dataclasses import dataclass, field
from typing import List

@dataclass
class LoraConfig:
    """Hyper-parameters for the LoRA layers injected into a pretrained model."""

    # Rank of the low-rank update (inner dimension of the A/B projections).
    r: int = 8
    # Numerator of the update scale; the LoRA layers scale by lora_alpha / r.
    lora_alpha: int = 8
    # Dropout probability used on the LoRA branch.
    lora_dropout: float = 0.1
    # Attribute names of the nn.Linear sub-modules that get wrapped.
    # Declared as a real dataclass field with a default_factory so that it can
    # be set through the constructor and every instance owns its own list
    # (an un-annotated class attribute would be shared by all instances).
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])

class LoraLinear(nn.Module):
    """Wrap an ``nn.Linear`` and add a low-rank update to its output.

    With LoRA enabled the layer computes
    ``pre_model(x) + lora_B(lora_A(dropout(x))) * (lora_alpha / r)``;
    with LoRA disabled it returns ``pre_model(x)`` unchanged.

    Args:
        pre_model: The linear layer to wrap. It is kept as-is; freezing its
            parameters is left to the caller.
        config: Object providing ``r``, ``lora_alpha`` and ``lora_dropout``.

    Raises:
        ValueError: If ``config.r`` is not positive.
    """

    def __init__(
        self,
        pre_model: nn.Linear,
        config: "LoraConfig",
    ):
        super().__init__()

        if config.r <= 0:
            raise ValueError(f"LoRA rank r must be a positive integer, got {config.r!r}")

        self.pre_model = pre_model
        self.config = config

        # Build the adapter matrices on the same device (and floating dtype) as
        # the wrapped layer; otherwise a model placed on an accelerator or
        # loaded in another precision fails on the first forward pass.
        weight = pre_model.weight
        dtype = weight.dtype if weight.dtype.is_floating_point else None
        factory_kwargs = {"device": weight.device, "dtype": dtype}

        self.lora_dropout = nn.Dropout(config.lora_dropout)
        self.lora_A = nn.Linear(pre_model.in_features, config.r, bias=False, **factory_kwargs)
        self.lora_B = nn.Linear(config.r, pre_model.out_features, bias=False, **factory_kwargs)
        self.scaling = config.lora_alpha / config.r
        self.lora_enabled = True
        self.reset_lora_parameters()

    def reset_lora_parameters(self) -> None:
        """Re-initialise the adapter so that the LoRA update starts at zero."""
        # A is drawn from a standard normal (nn.init.normal_ defaults); B is
        # all zeros, so B(A(x)) == 0 right after initialisation.
        nn.init.normal_(self.lora_A.weight)
        nn.init.constant_(self.lora_B.weight, 0)

    def enable_lora(self) -> None:
        """Include the low-rank update in ``forward``."""
        self.lora_enabled = True

    def disable_lora(self) -> None:
        """Make ``forward`` return the wrapped layer's output only."""
        self.lora_enabled = False

    def forward(self, input: torch.Tensor):
        """Apply the wrapped layer and, when enabled, add the scaled LoRA update."""
        out = self.pre_model(input)
        if self.lora_enabled:
            # Dropout is applied to the adapter *input*, as in the reference
            # LoRA formulation, not to the adapter output.
            update = self.lora_B(self.lora_A(self.lora_dropout(input)))
            # Out-of-place add: do not mutate the tensor returned by the
            # wrapped layer.
            out = out + update * self.scaling

        return out

class LoraAdapter(nn.Module):
    """Inject LoRA layers into a pretrained model and freeze everything else.

    The pretrained model is modified in place: every ``nn.Linear`` child whose
    attribute name is listed in ``config.target_modules`` is replaced by a
    ``LoraLinear`` wrapper, and all pre-existing parameters are frozen.

    Args:
        pt_model: Pretrained model to adapt (mutated in place).
        config: LoRA configuration; ``target_modules`` is read here and the
            whole object is handed to each ``LoraLinear``.
        enable_lora: Whether the injected layers start enabled.
    """
    def __init__(
        self,
        pt_model: nn.Module,
        config: "LoraConfig",
        enable_lora: bool = True,
    ):
        super().__init__()
        self.lora_modules = nn.ModuleList()
        self.pt_model = pt_model
        self.config = config
        self.lora_enabled = enable_lora

        self._apply_lora_linear(pt_model)
        if not enable_lora:
            self.disable_lora()

    def _apply_lora_linear(self, model: nn.Module) -> None:
        """Recursively freeze ``model`` and wrap its targeted linear children."""
        # Freeze the parameters registered directly on this module. Doing it
        # here (rather than per child) also covers parameters that live on the
        # root module itself.
        for param in model.parameters(recurse=False):
            param.requires_grad = False

        # Iterate over a snapshot because targeted children are replaced below.
        for attr_name, child in list(model.named_children()):
            if attr_name in self.config.target_modules and isinstance(child, nn.Linear):
                # Freeze the original layer before wrapping it, so only the
                # freshly created LoRA matrices remain trainable.
                for param in child.parameters():
                    param.requires_grad = False
                lora_linear = LoraLinear(child, self.config)
                setattr(model, attr_name, lora_linear)
                self.lora_modules.append(lora_linear)
            else:
                self._apply_lora_linear(child)

    def disable_lora(self) -> None:
        """Turn the LoRA update off in every injected layer."""
        self.lora_enabled = False
        for module in self.lora_modules:
            module.disable_lora()

    def enable_lora(self) -> None:
        """Turn the LoRA update on in every injected layer."""
        self.lora_enabled = True
        for module in self.lora_modules:
            module.enable_lora()

    def forward(self, *args, **kwargs):
        """Delegate the call to the wrapped pretrained model."""
        return self.pt_model(*args, **kwargs)

### Applying LoRA on model

In [None]:
# Default LoRA settings, then inject the adapters into the loaded model in place.
config = LoraConfig()
lora_adapter = LoraAdapter(pt_model=model, config=config)

In [None]:
# Quick smoke test: run one generation through the adapted model.
print(
    generate(
        model=lora_adapter.pt_model,
        tokenizer=tokenizer,
        context="Hello world",
        question="What is your name?",
    )
)

### Comparing parameters before and after LoRA

In [None]:
# Tally parameter counts over the adapted pretrained model.
all_param = 0
trainable_params = 0

for param in lora_adapter.pt_model.parameters():
    count = param.numel()
    # every parameter contributes to the total ...
    all_param += count
    # ... but only those that require a gradient are trainable
    trainable_params += count if param.requires_grad else 0

# report the totals and the trainable share
print(f"trainable params: {trainable_params}")
print(f"all params: {all_param}")
print(f"trainable: {100 * trainable_params / all_param:.2f}%")

### Loading and reformatting the SQuAD dataset

In [None]:
from datasets import load_dataset
# Load the "squad_v2" dataset through the datasets library.
qa_dataset = load_dataset(path="squad_v2")

In [None]:
def create_prompt(context, question, answer):
    """Build one training prompt in CONTEXT / QUESTION / ANSWER form.

    Args:
        context: Passage text.
        question: Question text.
        answer: Mapping whose ``"text"`` entry is a list of answer strings.

    Returns:
        The formatted prompt ending in ``</s>``. The first answer string is
        used; when the list is empty the answer is "Cannot Find Answer".
    """
    candidates = answer["text"]
    answer_text = candidates[0] if len(candidates) >= 1 else "Cannot Find Answer"
    return f"CONTEXT:\n{context}\n\nQUESTION:\n{question}\n\nANSWER:\n{answer_text}</s>"

# Format every example as a prompt and tokenize it, across the whole dataset.
mapped_qa_dataset = qa_dataset.map(
    lambda example: tokenizer(
        create_prompt(example['context'], example['question'], example['answers'])
    )
)

### Fine tuning

This code is largely borrowed from existing examples. In the absence of a
rigid validation procedure, the best practice is to copy from a successful
tutorial or, better yet, directly from the documentation.

In [None]:
import transformers

# Fine-tune the adapted model on the tokenized SQuAD prompts.
# An alternative, epoch-based configuration was previously kept here commented
# out: 1 epoch, batch size 4, no gradient accumulation, cosine schedule,
# learning_rate=5e-4, weight_decay=0.1, warmup_steps=10, evaluation and logging
# every 50 steps, reporting to tensorboard.
trainer = transformers.Trainer(
    model=lora_adapter.pt_model,
    args=transformers.TrainingArguments(
        output_dir='results',
        # NOTE(review): warmup_steps equals max_steps, so the whole run looks
        # like it is spent in warm-up. Confirm this is intended.
        max_steps=100,
        warmup_steps=100,
        learning_rate=1e-3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        logging_steps=10,
    ),
    train_dataset=mapped_qa_dataset["train"],
    # mlm=False selects causal (non-masked) language-model collation.
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
# Optional: peft_model.config.use_cache = False silences the warnings during
# training. Please re-enable for inference!
trainer.train()

# model.save_pretrained(peft_model_id)

### Test

In [None]:
# Example 1: the answer is stated in the context.
context = "You are a monster, and you eat yellow legos."
question = "What is the best food?"

compare_inference(
    model=lora_adapter.pt_model,
    tokenizer=tokenizer,
    context=context,
    question=question,
)

In [None]:
# Example 2: simple arithmetic question.
context = "you are a math wizard"
question = "what is 1+1 equal to?"

compare_inference(
    model=lora_adapter.pt_model,
    tokenizer=tokenizer,
    context=context,
    question=question,
)

In [None]:
# Example 3: riddle.
context = "Answer the riddle"
question = "What gets bigger the more you take away?"

compare_inference(
    model=lora_adapter.pt_model,
    tokenizer=tokenizer,
    context=context,
    question=question,
)