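Based on "https://www.kaggle.com/code/emiz6413/training-gemma-2-9b-4-bit-qlora-fine-tuning" by @emiz6413.
Thank you for sharing this amazing notebook! Unlike the original 4-bit QLoRA setup, this version loads Gemma-2-9B-it without quantization (bf16) and trains LoRA adapters on the attention q/k/v projections.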

In [1]:
import os
import copy
from dataclasses import dataclass
import random

import numpy as np
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    PreTrainedTokenizerBase, 
    EvalPrediction,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
from sklearn.metrics import log_loss, accuracy_score

In [2]:
# Experiment identifier: used as the W&B run name and as the output sub-directory
# (Config.output_dir). The "wd-0.01" suffix matches weight_decay=0.01 in training_args.
EXPERIMENT_NAME = "gemma-seq-ittokens-wd-0.01"

In [3]:
@dataclass
class Config:
    """Hyper-parameters for LoRA fine-tuning Gemma-2-9B-it as a 3-way (A / B / tie) classifier."""

    output_dir: str = f"output/{EXPERIMENT_NAME}"
    # Official instruct checkpoint (not a pre-quantized one); it is loaded without
    # quantization in this notebook, which is why model.dtype reports bfloat16.
    model_path: str = "google/gemma-2-9b-it"
    max_length: int = 2048  # tokens; longer rendered conversations are truncated
    n_splits: int = 5  # number of interleaved folds (row i belongs to fold i % n_splits)
    fold_idx: int = 0  # which fold is held out for evaluation
    optim_type: str = "adamw_8bit"
    per_device_train_batch_size: int = 2
    # Effective batch per optimizer step (per process) = 2 * 16 = 32 sequences.
    gradient_accumulation_steps: int = 16
    per_device_eval_batch_size: int = 4
    n_epochs: int = 1
    lr: float = 3e-4
    warmup_steps: int = 100
    lora_r: int = 16
    # Evaluated once at class-definition time from the default lora_r (alpha = 2 * r);
    # overriding lora_r on an instance does NOT update this default.
    lora_alpha: float = lora_r * 2
    lora_dropout: float = 0.05
    lora_bias: str = "none"

config = Config()

In [4]:
import wandb

# Authenticate with Weights & Biases; training_args below uses report_to="wandb".
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjdubkim[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# `os` is already imported in the first cell; no need to re-import it here.
# W&B project that runs are logged under (presumably read when the W&B run is
# initialised, so set it before training starts).
os.environ["WANDB_PROJECT"] = "lmsys-arena"

training_args = TrainingArguments(
    run_name=EXPERIMENT_NAME,
    output_dir=config.output_dir,
    report_to="wandb",
    num_train_epochs=config.n_epochs,
    per_device_train_batch_size=config.per_device_train_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    # NOTE(review): per the TrainingArguments docs, logging/eval/save steps count
    # optimizer (update) steps, so this logs every 16 optimizer steps, not every
    # optimizer step — confirm that is intended.
    logging_steps=config.gradient_accumulation_steps,
    eval_strategy="steps",
    eval_steps=1024,
    save_strategy="steps",
    save_steps=512,  # two checkpoints per evaluation interval
    optim=config.optim_type,
    bf16=True,
    learning_rate=config.lr,
    warmup_steps=config.warmup_steps,
    gradient_checkpointing=True,
    weight_decay=0.01,
)

In [6]:
# LoRA adapters on the attention query/key/value projections only; o_proj and the
# MLP layers stay frozen. With TaskType.SEQ_CLS the 3-way `score` head is trainable
# too: 42 layers * 303,104 LoRA params + 3,584 * 3 head params = 12,741,120, the
# count reported by print_trainable_parameters().
lora_config = LoraConfig(
    r=config.lora_r,
    lora_alpha=config.lora_alpha,
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=config.lora_dropout,
    bias=config.lora_bias,
    task_type=TaskType.SEQ_CLS,
)

In [7]:
# Tokenizer for the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(config.model_path)
tokenizer.add_eos_token = True  # append <eos> to every encoded sequence
# Pad on the right so each row's real tokens come first.
# NOTE(review): presumably required so the sequence-classification head pools the
# last non-pad token — confirm against the model's pooling logic.
tokenizer.padding_side = "right"

In [8]:
# Load Gemma-2-9B-it with a freshly initialised 3-way classification head
# (the loader reports `score.weight` as newly initialised in the output below).
model = AutoModelForSequenceClassification.from_pretrained(
    config.model_path,
    num_labels=3,  # 0 = model A wins, 1 = model B wins, 2 = tie (see CustomTokenizer)
    torch_dtype="auto",  # keep the checkpoint dtype (model.dtype reports bfloat16)
    device_map="auto",
    attn_implementation="flash_attention_2",
)
# Disable the generation KV cache; it is not needed for classification training.
model.config.use_cache = False
# Wrap with the LoRA adapters defined in lora_config; the base weights are frozen.
model = get_peft_model(model, lora_config)
model

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of Gemma2ForSequenceClassification were not initialized from the model checkpoint at google/gemma-2-9b-it and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): Gemma2ForSequenceClassification(
      (model): Gemma2Model(
        (embed_tokens): Embedding(256000, 3584, padding_idx=0)
        (layers): ModuleList(
          (0-41): 42 x Gemma2DecoderLayer(
            (self_attn): Gemma2FlashAttention2(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=3584, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3584, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict(

In [9]:
# Sanity check: parameter dtype after loading with torch_dtype="auto".
model.dtype

torch.bfloat16

In [10]:
# Sanity check: only the LoRA adapters (and the classification head) should be trainable.
model.print_trainable_parameters()

trainable params: 12,741,120 || all params: 9,254,457,856 || trainable%: 0.1377


In [11]:
# Toggle between a local copy of the competition data and the Kaggle input mount.
LOCAL = True

TRAIN_CSV = "./data/train.csv" if LOCAL else "/kaggle/input/lmsys-chatbot-arena/train.csv"

# Raw competition rows; rendering and tokenization happen in a later cell.
ds = Dataset.from_csv(TRAIN_CSV)

In [12]:
class CustomTokenizer:
    """Render LMSYS arena rows as Gemma chat prompts, tokenize them, and attach labels.

    Intended as a batched ``Dataset.map`` function: ``batch`` is a dict of column
    lists with ``prompt``, ``response_a``, ``response_b`` (stringified lists, see
    ``process_text``) and the ``winner_model_a`` / ``winner_model_b`` flags.

    Labels: 0 = model A won, 1 = model B won, 2 = neither flag set (tie).

    Args:
        tokenizer: Hugging Face tokenizer; called as
            ``tokenizer(texts, max_length=..., truncation=True)`` and must return a mapping.
        max_length: Maximum tokenized length; longer texts are truncated.
        random_flip: If True, each row's two responses are swapped with probability
            0.5 and the A/B label is swapped to match (ties stay 2). This is a
            label-consistent augmentation against position bias.
        seed: Optional seed for the flip RNG so that flips are reproducible.
    """

    # Label ids for the three outcomes.
    A_WINS, B_WINS, TIE = 0, 1, 2

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        max_length: int,
        random_flip: bool = False,
        seed: "int | None" = None,
    ) -> None:
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.random_flip = random_flip
        # Private RNG: reproducible given a seed, and independent of the global
        # `random` state. Only consulted when random_flip is True.
        self._rng = random.Random(seed)

    def __call__(self, batch: dict) -> dict:
        """Return the tokenizer outputs for ``batch`` plus a ``labels`` list."""
        texts, labels = [], []
        rows = zip(
            batch["prompt"],
            batch["response_a"],
            batch["response_b"],
            batch["winner_model_a"],
            batch["winner_model_b"],
        )
        for prompt, resp_a, resp_b, a_win, b_win in rows:
            text_a = self.process_text(resp_a)
            text_b = self.process_text(resp_b)
            label = self._label(a_win, b_win)
            if self.random_flip and self._rng.random() < 0.5:
                # Swap the raw response texts *before* tagging them, so whatever is
                # shown under <response_a> is always the label's "A" side.
                text_a, text_b = text_b, text_a
                if label != self.TIE:
                    label = self.B_WINS if label == self.A_WINS else self.A_WINS
            texts.append(
                "<start_of_turn>user"
                f"<prompt>: {self.process_text(prompt)}"
                f"\n\n<response_a>: {text_a}"
                f"\n\n<response_b>: {text_b}"
                "<end_of_turn><start_of_turn>model"
            )
            labels.append(label)
        tokenized = self.tokenizer(texts, max_length=self.max_length, truncation=True)
        return {**tokenized, "labels": labels}

    @classmethod
    def _label(cls, a_win, b_win) -> int:
        """Map the winner flags to a class id: A wins, else B wins, else tie."""
        if a_win:
            return cls.A_WINS
        if b_win:
            return cls.B_WINS
        return cls.TIE

    @staticmethod
    def process_text(text: str) -> str:
        """Join a stringified list of strings (e.g. '["hi", null]') with single spaces.

        ``null`` entries become empty strings.
        """
        # SECURITY(review): `eval` on dataset text can execute arbitrary code —
        # builtins remain reachable even with a restricted globals dict. Only use on
        # trusted data. `json.loads` is the safer parser but treats some escapes
        # differently (e.g. "\/"), so diff the rendered texts before switching.
        return " ".join(eval(text, {"null": ""}))

In [13]:
# Render and tokenize every row once; each mapped batch gains the tokenizer
# outputs and a `labels` column.
# NOTE(review): random_flip=True is requested here — whether rows are actually
# flipped depends on CustomTokenizer honouring that flag; confirm. The whole
# dataset (both train and eval folds) is mapped before the fold split below.
encode = CustomTokenizer(tokenizer, max_length=config.max_length, random_flip=True)
ds = ds.map(encode, batched=True)

In [14]:
def compute_metrics(eval_preds: EvalPrediction) -> dict:
    """Accuracy and multi-class log loss for the 3-way (A / B / tie) classifier.

    Args:
        eval_preds: predictions from the Trainer; ``predictions`` holds raw logits
            (presumably shape (n_samples, n_classes)) and ``label_ids`` the integer
            class ids.

    Returns:
        ``{"acc": accuracy, "log_loss": multi-class log loss}``.
    """
    logits = eval_preds.predictions
    # Some models return extra outputs alongside the logits; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    logits = np.asarray(logits)
    labels = eval_preds.label_ids
    # Softmax in float32 for numerical stability.
    probs = torch.from_numpy(logits).float().softmax(-1).numpy()
    # Pass the full class list so log_loss still works (and maps columns correctly)
    # when an evaluation slice happens to contain only some of the classes.
    loss = log_loss(y_true=labels, y_pred=probs, labels=list(range(probs.shape[-1])))
    acc = accuracy_score(y_true=labels, y_pred=logits.argmax(-1))
    return {"acc": acc, "log_loss": loss}

## Train & Validation Split

In [15]:
# Interleaved K-fold split: row i is held out in fold (i % n_splits).
# Each entry is a (train_indices, eval_indices) pair of plain int lists.
folds = [
    (
        [row for row in range(len(ds)) if row % config.n_splits != k],
        list(range(k, len(ds), config.n_splits)),
    )
    for k in range(config.n_splits)
]

In [16]:
# Hold out fold `config.fold_idx` for evaluation and train on the remaining folds.
train_idx, eval_idx = folds[config.fold_idx]

trainer = Trainer(
    args=training_args, 
    model=model,
    tokenizer=tokenizer,
    train_dataset=ds.select(train_idx),
    eval_dataset=ds.select(eval_idx),
    compute_metrics=compute_metrics,
    # Pads the variable-length tokenized rows into batches using the tokenizer's
    # pad token and padding side.
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
)
trainer.train()

[2024-07-31 11:13:51,415] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011115552844451789, max=1.0…



Step,Training Loss,Validation Loss
