Based on [this Kaggle notebook](https://www.kaggle.com/code/emiz6413/training-gemma-2-9b-4-bit-qlora-fine-tuning) by @emiz6413.
Thank you for sharing this amazing notebook!

In [1]:
import os
import copy
from dataclasses import dataclass
import random

import numpy as np
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
    DataCollatorWithPadding,
    EvalPrediction,
    PreTrainedTokenizerBase, 
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
from sklearn.metrics import log_loss, accuracy_score

In [2]:
# Experiment identifier: used as the W&B run name and as the output sub-directory name.
EXPERIMENT_NAME = "gemma-seq-continue_training-focal-loss"

In [3]:
@dataclass
class Config:
    """Hyper-parameters and paths for this training run.

    NOTE: in this notebook the ``lora_*`` fields are not used. The adapter and its
    LoraConfig (r, alpha, dropout, target modules) are restored from an earlier checkpoint
    via ``PeftConfig.from_pretrained`` / ``PeftModel.from_pretrained``. These fields only
    matter if a fresh ``LoraConfig`` is built from them again.
    """

    output_dir: str = f"output/{EXPERIMENT_NAME}"
    model_path: str = "unsloth/gemma-2-9b-it-bnb-4bit"  # 4-bit quantized gemma-2-9b-instruct
    max_length: int = 2048  # token budget for prompt + both responses
    n_splits: int = 5  # number of folds in the index-modulo split
    fold_idx: int = 0  # which fold is held out for evaluation
    optim_type: str = "adamw_8bit"
    per_device_train_batch_size: int = 8
    # Samples per optimizer step = 8 * 16 = 128 per device (multiply by the GPU count).
    gradient_accumulation_steps: int = 16
    per_device_eval_batch_size: int = 4
    n_epochs: int = 2
    lr: float = 3e-5
    warmup_steps: int = 40
    lora_r: int = 32
    lora_alpha: float = lora_r * 2
    lora_dropout: float = 0.1
    lora_bias: str = "none"


config = Config()

In [4]:
import wandb

# Authenticate with Weights & Biases; the cell displays the login result.
logged_in = wandb.login()
logged_in

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjdubkim[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# W&B reads the project name from the environment when the run is created.
os.environ["WANDB_PROJECT"] = "lmsys-arena"

training_args = TrainingArguments(
    run_name=EXPERIMENT_NAME,
    output_dir=config.output_dir,
    report_to="wandb",
    num_train_epochs=config.n_epochs,
    per_device_train_batch_size=config.per_device_train_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    # Logging and checkpointing both happen every `gradient_accumulation_steps` (16) optimizer steps.
    logging_steps=config.gradient_accumulation_steps,
    eval_strategy="steps",
    eval_steps=256,
    save_strategy="steps",
    save_steps=config.gradient_accumulation_steps,
    optim=config.optim_type,
    bf16=True,
    learning_rate=config.lr,
    warmup_steps=config.warmup_steps,
    gradient_checkpointing=True,
    # As far as this notebook shows, the 4-bit base model is wrapped by PeftModel without
    # prepare_model_for_kbit_training / enable_input_require_grads, so the embedding output does
    # not require grad. Reentrant checkpointing (the transformers default) then yields no
    # gradients for LoRA weights inside checkpointed layers; the non-reentrant variant does not
    # have that requirement.
    gradient_checkpointing_kwargs={"use_reentrant": False},
    weight_decay=0.01,
    # NOTE(review): CustomTrainer overrides compute_loss with a focal loss, so Trainer's label
    # smoother is never applied; this setting currently has no effect on the training loss.
    label_smoothing_factor=0.1,
)

In [6]:
from peft import PeftConfig

# Continue training from an earlier run. The LoRA hyper-parameters (r, alpha, dropout, target
# modules, modules_to_save) come from the adapter config saved in this checkpoint, not from
# Config.lora_*. To train a fresh adapter instead, build a LoraConfig from Config.lora_* and
# apply it with get_peft_model.
lora_dir = "./output/gemma-seq-continue_training2/checkpoint-64"
lora_config = PeftConfig.from_pretrained(lora_dir)

In [7]:
# Show the adapter config restored from `lora_dir`. Its r / alpha / dropout come from the
# checkpoint and can differ from Config.lora_*.
lora_config

LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path='unsloth/gemma-2-9b-it-bnb-4bit', revision=None, task_type='SEQ_CLS', inference_mode=True, r=16, target_modules={'v_proj', 'k_proj', 'q_proj'}, lora_alpha=32, lora_dropout=0.05, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=['classifier', 'score', 'classifier', 'score'], init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False))

In [8]:
tokenizer = AutoTokenizer.from_pretrained(config.model_path)
tokenizer.padding_side = "right"  # pad tokens go after the real tokens
tokenizer.add_eos_token = True  # append <eos> to every encoded sequence

In [9]:
# Pre-quantized 4-bit checkpoint plus a new 3-way classification head ("score"). The head's
# random init is presumably replaced by the adapter's saved `score` module when the PEFT
# checkpoint is loaded below.
model = AutoModelForSequenceClassification.from_pretrained(
    config.model_path,
    # three classes: 0 = model A wins, 1 = model B wins, 2 = tie
    num_labels=3,
    # placement and precision
    device_map="auto",
    torch_dtype="auto",
    # attention kernel and KV cache (no cache needed for training).
    # NOTE(review): confirm the installed flash-attn supports Gemma-2 logit soft-capping.
    attn_implementation="flash_attention_2",
    use_cache=False,
)


Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
Some weights of Gemma2ForSequenceClassification were not initialized from the model checkpoint at unsloth/gemma-2-9b-it-bnb-4bit and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
from peft import PeftModel

# Re-attach the saved adapter (LoRA weights plus the saved `score` head) and keep it trainable,
# so this run continues from that checkpoint instead of starting from fresh LoRA weights.
peft_model = PeftModel.from_pretrained(model=model, model_id=lora_dir, is_trainable=True)

In [11]:
# Sanity check: dtype reported by the PEFT-wrapped model.
peft_model.dtype

torch.bfloat16

In [12]:
# Confirm that only a small fraction of the weights (the adapter) is trainable after loading
# with is_trainable=True.
peft_model.print_trainable_parameters()

trainable params: 12,741,120 || all params: 9,254,457,856 || trainable%: 0.1377


In [13]:
LOCAL = True

# Local copy of the competition data vs. the Kaggle-mounted input directory.
TRAIN_CSV = "./data/train.csv" if LOCAL else "/kaggle/input/lmsys-chatbot-arena/train.csv"

ds = Dataset.from_csv(TRAIN_CSV)

In [14]:
class CustomTokenizer:
    """Batched ``datasets.map`` callback that turns arena rows into classifier inputs.

    For each row, the prompt and both responses are decoded with ``process_text``.
    They are concatenated as ``<prompt>: ...``, ``<response_a>: ...``, ``<response_b>: ...``,
    with a blank line between segments, and the result is tokenized with truncation at
    ``max_length`` tokens. Truncation cuts the combined text from the tokenizer's truncation
    side (right by default), so a long prompt or response_a can push response_b out entirely.

    Labels: 0 = model A won, 1 = model B won, 2 = neither flag set (tie).
    """

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        max_length: int
    ) -> None:
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch: dict) -> dict:
        """Tokenize one batch of rows.

        Args:
            batch: columnar batch with "prompt", "response_a", "response_b" (stringified
                lists of strings) and "winner_model_a" / "winner_model_b" flags.

        Returns:
            The tokenizer's output fields plus a "labels" list of ints.
        """
        prompts = ["<prompt>: " + self.process_text(t) for t in batch["prompt"]]
        responses_a = ["\n\n<response_a>: " + self.process_text(t) for t in batch["response_a"]]
        responses_b = ["\n\n<response_b>: " + self.process_text(t) for t in batch["response_b"]]
        texts = [p + r_a + r_b for p, r_a, r_b in zip(prompts, responses_a, responses_b)]
        tokenized = self.tokenizer(texts, max_length=self.max_length, truncation=True)
        labels = [
            self._label(a_win, b_win)
            for a_win, b_win in zip(batch["winner_model_a"], batch["winner_model_b"])
        ]
        return {**tokenized, "labels": labels}

    @staticmethod
    def _label(a_win, b_win) -> int:
        """Map the two winner flags to a class id: 0 = A wins, 1 = B wins, 2 = tie."""
        if a_win:
            return 0
        if b_win:
            return 1
        return 2

    @staticmethod
    def process_text(text: str) -> str:
        """Join a stringified list of strings (as stored in the CSV) with single spaces.

        ``null`` entries become empty strings.
        """
        # SECURITY NOTE(review): `eval` executes the CSV cell as Python. Stripping builtins blocks
        # obvious payloads (e.g. __import__) but is NOT a sandbox; use only on trusted data.
        # json.loads would be safe, but it decodes some escapes (e.g. "\/", surrogate pairs)
        # differently, so switching would change the produced text.
        return " ".join(eval(text, {"null": "", "__builtins__": {}}))
    
# class CustomTokenizer:
#     def __init__(
#         self, 
#         tokenizer: PreTrainedTokenizerBase, 
#         max_length: int,
#         random_flip: bool = False,
#     ) -> None:
#         self.tokenizer = tokenizer
#         self.max_length = max_length
#         self.random_flip = random_flip
        
#     def __call__(self, batch: dict) -> dict:
#         instruction = (
#             f"<start_of_turn>user"
#             f" Evaluate the user’s question along with the two sets of responses provided by <response_a> and <response_b>."
#             f" Determine which model’s responses are better. If both responses are of similar quality, classify the result as a tie."
#         )
#         texts = []
#         labels = []
#         for prompts, responses_a, responses_b, a_win, b_win in zip(batch["prompt"], batch["response_a"], batch["response_b"], batch["winner_model_a"], batch["winner_model_b"]):
#             prompts = self.process_text(prompts)
#             responses_a = self.process_text(responses_a)
#             responses_b = self.process_text(responses_b)
#             # prompts, response_a, responses_b are list of string
#             if not len(prompts) == len(responses_a) == len(responses_b):
#                 print(len(prompts), len(responses_a), len(responses_b))
#                 print(prompts)
#                 print(responses_a)
#                 print(responses_b)

#             assert len(prompts) == len(responses_a) == len(responses_b)

#             multi_turns = [
#                 f"<prompt>: {p} </prompt> <response_a>: {r_a} </response_a> <response_b>: {r_b} </response_b> " 
#                 for p, r_a, r_b in zip(prompts, responses_a, responses_b)
#             ]

#             texts.append(f"{instruction} {' '.join(multi_turns)} <end_of_turn>model")
#             if a_win:
#                 label = 0
#             elif b_win:
#                 label = 1
#             else:
#                 label = 2
#             labels.append(label)
        
#         tokenized = self.tokenizer(texts, max_length=self.max_length, truncation=True)
#         return {**tokenized, "labels": labels}

#     @staticmethod
#     def process_text(text: str) -> list:
#         return eval(text, {"null": ""})

def process_text(text):
    """Join a stringified list of strings (as stored in the CSV) with single spaces.

    ``null`` entries become empty strings, e.g. ``'["a", null]' -> "a "``.
    """
    # SECURITY NOTE(review): `eval` executes the CSV cell as Python. Stripping builtins blocks
    # obvious payloads (e.g. __import__) but is NOT a sandbox; use only on trusted data.
    # json.loads would be safe, but it decodes some escapes (e.g. "\/", surrogate pairs)
    # differently, so switching would change the produced text.
    return " ".join(eval(text, {"null": "", "__builtins__": {}}))


def tokenize_func_truncate(batch, tok=None, max_length=None):
    """Alternative ``datasets.map`` callback that truncates each field to its own token budget.

    The prompt gets ~1/9 of ``max_length`` and each response ~4/9. Each field is truncated
    separately, decoded back to text, joined with blank lines and re-tokenized. Unlike
    truncating the concatenation, this keeps part of every field.

    Args:
        batch: columnar batch with "prompt", "response_a", "response_b" (stringified lists,
            decoded by ``process_text``) and "winner_model_a" / "winner_model_b" flags.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        max_length: total token budget; defaults to ``config.max_length``.

    Returns:
        The tokenizer's output fields plus "labels" (0 = A wins, 1 = B wins, 2 = tie).
    """
    tok = tokenizer if tok is None else tok
    max_length = config.max_length if max_length is None else max_length

    def process_and_truncate(texts, prefix, segment_max_length):
        # Truncate each text at the token level, then decode back to plain text.
        # Decoding, rather than space-joining raw token pieces, keeps the text intact.
        # skip_special_tokens drops the per-segment <bos>/<eos>, which would otherwise be
        # re-tokenized as special tokens in the middle of the combined sequence.
        ids = tok(
            [f"{prefix} " + process_text(text) for text in texts],
            max_length=segment_max_length,
            truncation=True,
        )["input_ids"]
        return tok.batch_decode(ids, skip_special_tokens=True)

    max_prompt_length = max_length // 9
    max_response_length = (max_length * 4) // 9

    truncated_prompts = process_and_truncate(batch["prompt"], "<prompt>", max_prompt_length)
    truncated_response_as = process_and_truncate(batch["response_a"], "<response_a>", max_response_length)
    truncated_response_bs = process_and_truncate(batch["response_b"], "<response_b>", max_response_length)

    # Separate the segments with a blank line, as CustomTokenizer does, so the field
    # boundaries do not run together.
    texts = [
        "\n\n".join((p, r_a, r_b))
        for p, r_a, r_b in zip(truncated_prompts, truncated_response_as, truncated_response_bs)
    ]

    # Final guard: re-tokenizing decoded text is not guaranteed to round-trip exactly.
    tokenized = tok(texts, max_length=max_length, truncation=True)

    labels = [
        0 if a_win else 1 if b_win else 2
        for a_win, b_win in zip(batch["winner_model_a"], batch["winner_model_b"])
    ]

    return {**tokenized, "labels": labels}

In [15]:
# Tokenize the whole dataset in batches. This adds the tokenizer's fields
# (input_ids, attention_mask, ...) plus a labels column.
encode = CustomTokenizer(tokenizer, max_length=config.max_length)
ds = ds.map(function=encode, batched=True)

In [16]:
def compute_metrics(eval_preds: EvalPrediction) -> dict:
    """Accuracy and multi-class log loss for the 3-way (A / B / tie) classifier.

    Args:
        eval_preds: ``predictions`` are raw logits, assumed (n_samples, n_classes);
            ``label_ids`` are integer class ids.

    Returns:
        ``{"acc": float, "log_loss": float}``.
    """
    preds = eval_preds.predictions
    # Trainer can hand back a tuple when the model returns extra outputs; logits come first.
    if isinstance(preds, tuple):
        preds = preds[0]
    labels = eval_preds.label_ids
    probs = torch.from_numpy(preds).float().softmax(-1).numpy()
    # Pass the full class list so log_loss still works when the eval slice lacks a class.
    loss = log_loss(y_true=labels, y_pred=probs, labels=np.arange(probs.shape[-1]))
    acc = accuracy_score(y_true=labels, y_pred=preds.argmax(-1))
    return {"acc": acc, "log_loss": loss}

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Multi-class focal loss built on softmax cross-entropy.

    ``loss_i = alpha[y_i] * (1 - p_i) ** gamma * CE_i``, where ``CE_i`` is the per-sample
    cross-entropy and ``p_i = exp(-CE_i)`` is the predicted probability of the true class.
    With ``gamma = 0`` and no ``alpha`` this is plain cross-entropy.

    Args:
        gamma: focusing exponent; larger values down-weight confidently-correct samples more.
        alpha: optional per-class weights indexed by target id (sequence or 1-D tensor).
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, gamma=3.0, alpha=None, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: unnormalized logits, (batch, n_classes).
            targets: integer class ids, (batch,).
        """
        # Work in float32: bf16 logits lose precision through exp/pow otherwise.
        ce_loss = F.cross_entropy(inputs.float(), targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.alpha is not None:
            # Accept lists/tuples as well as tensors, and put the weights on the loss's device.
            alpha = torch.as_tensor(self.alpha, dtype=focal_loss.dtype, device=focal_loss.device)
            focal_loss = alpha[targets] * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

            
class CustomTrainer(Trainer):
    """``Trainer`` whose loss is ``FocalLoss`` on the model's logits instead of the built-in loss.

    Args:
        *args, **kwargs: forwarded unchanged to ``Trainer.__init__``.
        gamma, alpha: forwarded to ``FocalLoss``. They are keyword-only so that positional
            arguments keep their usual ``Trainer`` meaning (``model``, ``args``, ...).
    """

    def __init__(self, *args, gamma=3.0, alpha=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.focal_loss_fn = FocalLoss(gamma=gamma, alpha=alpha)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Forward pass plus focal loss.

        ``num_items_in_batch`` and any extra keyword arguments are accepted for compatibility
        with newer ``Trainer`` versions that pass them, and are ignored.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs.get("labels")
        # Leave labels out of the forward call so the model does not also compute its own
        # (unused) cross-entropy. Copy rather than pop, so the caller's dict is untouched.
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.get("logits")

        loss = self.focal_loss_fn(logits, labels)

        return (loss, outputs) if return_outputs else loss

## Train & Validation Split

In [18]:
# Index-modulo K-fold split: row i belongs to fold i % n_splits.
# Each entry is (train_indices, eval_indices); the eval part is exactly one stride of the range.
folds = [
    (
        [i for i in range(len(ds)) if i % config.n_splits != k],
        list(range(k, len(ds), config.n_splits)),
    )
    for k in range(config.n_splits)
]

In [19]:
train_idx, eval_idx = folds[config.fold_idx]
train_ds, eval_ds = ds.select(train_idx), ds.select(eval_idx)
# Persist the held-out fold to disk (presumably for later offline evaluation).
eval_ds.save_to_disk(f"{config.output_dir}/eval_ds")

Saving the dataset (0/1 shards):   0%|          | 0/11496 [00:00<?, ? examples/s]

In [20]:
# Fixed-seed shuffle of the training fold.
# NOTE(review): Trainer's default train sampler already shuffles, so this is likely redundant — confirm.
train_ds_shuffled = train_ds.shuffle(seed=42)

In [21]:
# Focal-loss trainer: fits on the shuffled training fold, reports metrics on the held-out fold.
trainer = CustomTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_ds_shuffled,
    eval_dataset=eval_ds,
    tokenizer=tokenizer,
    # Pads each batch dynamically to its longest sequence.
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
)
trainer.train()

[2024-08-05 12:17:51,090] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)




Step,Training Loss,Validation Loss




KeyboardInterrupt: 