In [None]:
# bitsandbytes
# accelerate 
# loralib
!pip install -q -U torch
!pip install -q -U transformers
!pip install -q -U peft
!pip install -q -U datasets

In [6]:
import torch
import torch.nn as nn
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from datasets import load_dataset

In [None]:
def load_model(id: str, load_in_8bit: bool = True) -> tuple:
    """Load a causal-LM checkpoint and its tokenizer.

    Args:
        id: Hugging Face Hub model id or local path (e.g. "databricks/dolly-v2-3b").
            The name shadows the builtin `id` but is kept because callers pass it
            by keyword.
        load_in_8bit: quantize the weights to 8-bit via bitsandbytes (needs the
            `bitsandbytes` and `accelerate` packages and, typically, a CUDA GPU).

    Returns:
        (model, tokenizer). The tokenizer always has a pad token, because the
        training data is tokenized with padding="max_length".
    """
    # Local import: only this function needs it. Passing `load_in_8bit=` directly
    # to from_pretrained is deprecated in favour of a quantization config.
    from transformers import BitsAndBytesConfig

    quantization_config = BitsAndBytesConfig(load_in_8bit=True) if load_in_8bit else None
    model = AutoModelForCausalLM.from_pretrained(
        id,
        quantization_config=quantization_config,
        device_map="auto",  # let accelerate place the layers on available devices
    )

    tokenizer = AutoTokenizer.from_pretrained(id)
    if tokenizer.pad_token is None:
        # Some causal-LM tokenizers ship without a pad token; padding would fail.
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

def generate_text(prompt: str, model: "AutoModelForCausalLM", tokenizer: "AutoTokenizer", max_new_tokens: int = 20) -> str:
    """Generate a continuation of `prompt` with `model` and return the decoded text.

    Args:
        prompt: input text.
        model: causal LM exposing `.device`, `.eval()` and `.generate()`.
        tokenizer: tokenizer matching `model`.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The decoded first sequence (prompt + continuation), special tokens removed.

    Side effect: puts `model` in eval mode, so dropout (including LoRA dropout)
    is disabled even when this is called right after training.
    """
    model.eval()
    device = model.device
    # Move the encoded inputs onto the model's device so generate() receives
    # them where the weights live.
    encoded = tokenizer(prompt, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast(); mixed
    # precision is only enabled when the model is on CUDA, as before.
    with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
        output_tokens = model.generate(**batch, max_new_tokens=max_new_tokens)

    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)

## Load and evaluate the base model

In [None]:
# Base checkpoint to fine-tune. Note: the LoRA cell below wraps `model.embed_out`,
# a layer name specific to this architecture; check it if you change the model.
MODEL_ID = "databricks/dolly-v2-3b"
model, tokenizer = load_model(id=MODEL_ID)

In [None]:
# Display the module tree; useful for locating layer names such as `embed_out`,
# which the LoRA cell below wraps.
model

In [None]:
# Baseline answer from the untuned model (repeated after fine-tuning for comparison).
generate_text(prompt="What is your name?", model=model, tokenizer=tokenizer)

In [None]:
# Same question, phrased differently.
generate_text(prompt="Tell me your name!", model=model, tokenizer=tokenizer)

## Freeze the base model and attach LoRA adapters

In [7]:
# Hyperparameters for the LoRA adapters and the Trainer. The LoRA cell and the
# training cell both read this dict, so keep the name `config` bound to it.
config = {
    # LoRA rank, scaling numerator (adapter scale = alpha / r) and adapter dropout.
    "LORA_R": 16,
    "LORA_ALPHA": 32,
    "LORA_DROPOUT": 0.05,
    # Effective batch per device = PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS.
    "PER_DEVICE_TRAIN_BATCH_SIZE": 4,
    "GRADIENT_ACCUMULATION_STEPS": 4,
    # Must be smaller than MAX_STEPS: with the Trainer's default linear warmup,
    # warmup >= MAX_STEPS means the LR never reaches LEARNING_RATE.
    "WARMUP_STEPS": 10,
    "MAX_STEPS": 40,
    "LEARNING_RATE": 0.0002,
}
assert config["WARMUP_STEPS"] < config["MAX_STEPS"], "warmup must finish before training ends"

In [None]:
class CastOutputToFloat(nn.Sequential):
    """`nn.Sequential` wrapper that always returns its output as float32.

    Wrapping a low-precision output layer with it means downstream computations
    (e.g. the loss) see fp32 values.
    """

    def forward(self, x):
        result = super().forward(x)
        return result.float()


# NOTE: not idempotent. Re-running this cell re-wraps `embed_out` and re-applies
# LoRA; reload the base model first.

# Parameter freezing: base weights stay fixed, only the LoRA adapters train.
for param in model.parameters():
    param.requires_grad = False  # freeze the model - train adapters later
    if param.ndim == 1:
        # cast the small parameters (e.g. layernorm) to fp32 for stability
        param.data = param.data.to(torch.float32)
model.gradient_checkpointing_enable()  # reduce number of stored activations
# With a frozen base and gradient checkpointing, make the input embeddings'
# output require grad so gradients can still reach the adapters.
model.enable_input_require_grads()
# Upcast the output layer's result to fp32 (see CastOutputToFloat above).
model.embed_out = CastOutputToFloat(model.embed_out)

# LORA. Bind to `lora_config`, NOT `config`: the training cell still reads the
# hyperparameter dict `config`, and overwriting it made that cell fail.
lora_config = LoraConfig(
    r=config["LORA_R"],
    lora_alpha=config["LORA_ALPHA"],
    lora_dropout=config["LORA_DROPOUT"],
    bias="none",
    task_type="CAUSAL_LM",
)
# Add Low Rank Adapters; the frozen base weights are left untouched.
model = get_peft_model(model, lora_config)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_param = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")
model

## Load the data

In [None]:
def generate_prompt(data_point):
    """Render one training record as its instruction, a blank line, then its output.

    `data_point` is a mapping with "instruction" and "output" entries.
    """
    instruction = data_point["instruction"]
    response = data_point["output"]
    return f"{instruction}\n\n{response}"


def load_data(data_files: str = "./data/chris_train.json", max_length: int = 256, seed: int = 42):
    """Load the JSON training set and tokenize each record for causal-LM training.

    Uses the module-level `tokenizer` (from `load_model`) and `generate_prompt`.

    Args:
        data_files: JSON file whose records have "instruction" and "output" fields.
        max_length: every example is truncated / padded to exactly this many tokens.
        seed: shuffle seed, so repeated runs see the same example order.

    Returns:
        A DatasetDict whose splits contain only the tokenizer outputs; the raw
        text columns are dropped.
    """
    raw = load_dataset("json", data_files=data_files)

    def tokenize(data_point):
        return tokenizer(
            generate_prompt(data_point),
            truncation=True,
            max_length=max_length,
            padding="max_length",
        )

    return raw.shuffle(seed=seed).map(
        tokenize,
        # Drop the raw string columns: the LM data collator cannot pad strings,
        # and the model only needs the token ids / attention mask.
        remove_columns=raw["train"].column_names,
    )

data = load_data()
# Sanity check: the Trainer below consumes token ids from the "train" split.
assert "input_ids" in data["train"].column_names, "tokenization did not produce input_ids"
data

## Define and start training

In [None]:
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=config["PER_DEVICE_TRAIN_BATCH_SIZE"],
    gradient_accumulation_steps=config["GRADIENT_ACCUMULATION_STEPS"],
    warmup_steps=config["WARMUP_STEPS"],
    max_steps=config["MAX_STEPS"],
    learning_rate=config["LEARNING_RATE"],
    fp16=True,  # mixed-precision training; typically needs a GPU
    logging_steps=1,
    output_dir="outputs",
)
trainer = transformers.Trainer(
    model=model,
    train_dataset=data["train"],
    args=training_args,
    # mlm=False -> causal-LM labels: a copy of input_ids with padding masked out.
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
# The KV cache is incompatible with gradient checkpointing during training;
# disabling it silences the warnings.
model.config.use_cache = False
trainer.train()

# Restore inference settings for the evaluation cells below: re-enable the
# KV cache and switch off dropout (including LoRA dropout).
model.config.use_cache = True
model.eval()

## Evaluate the fine-tuned model

In [None]:
# Same prompt as the base-model evaluation, to compare before/after fine-tuning.
generate_text(prompt="What is your name?", model=model, tokenizer=tokenizer)

In [None]:
# Alternate phrasing, also asked of the base model above.
generate_text(prompt="Tell me your name!", model=model, tokenizer=tokenizer)