In [1]:
import os
import copy
from dataclasses import dataclass

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch import Tensor

#from datasets import Dataset

from tqdm import tqdm

from transformers import (
    BitsAndBytesConfig,
    Gemma2ForSequenceClassification,
    Gemma2Model,
    GemmaTokenizerFast,
    Gemma2Config,
    AutoTokenizer,
    AutoModel,
    PreTrainedTokenizerBase, 
    EvalPrediction,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)

from peft import LoraModel, PeftModel, LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

from sklearn.metrics import log_loss, accuracy_score
from sklearn.model_selection import train_test_split

import ModelsUtils as Utils
import Configurations as Configs


In [2]:
# Select the compute device first (GPU when CUDA is usable, otherwise CPU),
# then emit a short environment report in a single print call.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(
    f'Torch version: {torch.__version__}',
    f'Torch is build with CUDA: {torch.cuda.is_available()}',
    f'Torch device : {device}',
    '------------------------------',
    sep='\n',
)

Torch version: 2.5.1+cu118
Torch is build with CUDA: True
Torch device : cuda
------------------------------


## Config

In [3]:
# Build the project's ConfigManager from the config file and select the preset
# named `prepare_gemma2_9b_fp16_h4096` as the active configuration.
# The preset is read as an attribute (no call) -- presumably a property on
# ConfigManager; confirm in the Configurations module.
config_file = 'Configs.py'
manager = Configs.ConfigManager(config_file)
config = manager.prepare_gemma2_9b_fp16_h4096

In [4]:
# LoRA adapter settings. Rank, alpha, dropout and bias mode are all taken from
# the active config. Adapters are attached to the attention query/key/value
# projections only (o_proj and the MLP projections are not targeted).
# A per-layer restriction is left disabled:
#   layers_to_transform=[i for i in range(42) if i >= config.freeze_layers]
lora_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,  # alternative: TaskType.SEQ_CLS
    target_modules=["q_proj", "k_proj", "v_proj"],
    r=config.lora_r,
    lora_alpha=config.lora_alpha,
    lora_dropout=config.lora_dropout,
    bias=config.lora_bias,
)

___________________________________________________________________________

## Tokenize

In [5]:
# Load the tokenizer from the same path as the base model checkpoint.
tokenizer = AutoTokenizer.from_pretrained(config.transformers_basemodel_path)
tokenizer.add_eos_token = True      # We'll add <eos> at the end
tokenizer.padding_side = "right"    # pad sequences on the right-hand side

## Model

In [6]:
# Build a 4-bit NF4 quantization config only when the active config requests it;
# otherwise the model is loaded without quantization.
if config.quantize == '4bit':
    print("quantized")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_use_double_quant=False,
        bnb_4bit_compute_dtype=torch.bfloat16,  # bfloat16 is recommended
    )
else:
    quantization_config = None

# Load the backbone through AutoModel directly onto `device`. Weights are
# requested in float16 when config.fp16 is set, otherwise the dtype is "auto".
# NOTE: the variable keeps its historical `2b` name, but the selected preset is
# a gemma2_9b one (the printed module tree has 42 decoder layers, width 3584).
gemma_2b_base = AutoModel.from_pretrained(
    config.transformers_basemodel_path,
    quantization_config=quantization_config,
    device_map=device,
    torch_dtype=torch.float16 if config.fp16 else "auto",
)

# Display the module tree of the loaded backbone.
gemma_2b_base

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Gemma2Model(
  (embed_tokens): Embedding(256000, 3584, padding_idx=0)
  (layers): ModuleList(
    (0-41): 42 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=3584, out_features=4096, bias=False)
        (k_proj): Linear(in_features=3584, out_features=2048, bias=False)
        (v_proj): Linear(in_features=3584, out_features=2048, bias=False)
        (o_proj): Linear(in_features=4096, out_features=3584, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=3584, out_features=14336, bias=False)
        (up_proj): Linear(in_features=3584, out_features=14336, bias=False)
        (down_proj): Linear(in_features=14336, out_features=3584, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((3584,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((3584,), eps=1e-06)
      (post_feedforward_layernorm): Gemma2RMSNorm((3584

In [7]:
# Write the base model and the tokenizer to config.basemodel_path. This cell
# runs before the LoRA adapters are attached further down, so the saved copy
# is the backbone as it was loaded.
save_path = config.basemodel_path

#save base model
gemma_2b_base.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

('../BaseModel/gemma2_9b_unsloth_fp16\\tokenizer_config.json',
 '../BaseModel/gemma2_9b_unsloth_fp16\\special_tokens_map.json',
 '../BaseModel/gemma2_9b_unsloth_fp16\\tokenizer.json')

In [8]:
# Turn off the KV cache on the backbone config before wrapping it for training.
gemma_2b_base.config.use_cache = False

# prepare_model_for_kbit_training() targets 8-bit / 4-bit quantized models.
# On a non-quantized fp16/bf16 model it upcasts every parameter to float32,
# which doubles the memory footprint of the backbone for no benefit. Only
# call it when the model was actually loaded quantized.
if getattr(gemma_2b_base, "is_loaded_in_4bit", False) or getattr(gemma_2b_base, "is_loaded_in_8bit", False):
    gemma_2b_base = prepare_model_for_kbit_training(gemma_2b_base)

# Attach the LoRA adapters described by `lora_config`; PEFT leaves only the
# adapter weights trainable.
lora_model = get_peft_model(gemma_2b_base, lora_config)

# Display the module tree of the LoRA-wrapped backbone.
lora_model

PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): Gemma2Model(
      (embed_tokens): Embedding(256000, 3584, padding_idx=0)
      (layers): ModuleList(
        (0-41): 42 x Gemma2DecoderLayer(
          (self_attn): Gemma2Attention(
            (q_proj): lora.Linear(
              (base_layer): Linear(in_features=3584, out_features=4096, bias=False)
              (lora_dropout): ModuleDict(
                (default): Dropout(p=0.05, inplace=False)
              )
              (lora_A): ModuleDict(
                (default): Linear(in_features=3584, out_features=16, bias=False)
              )
              (lora_B): ModuleDict(
                (default): Linear(in_features=16, out_features=4096, bias=False)
              )
              (lora_embedding_A): ParameterDict()
              (lora_embedding_B): ParameterDict()
              (lora_magnitude_vector): ModuleDict()
            )
            (k_proj): lora.Linear(
              (base_layer): Linear(in_featur

In [9]:
# Print the trainable vs. total parameter counts of the LoRA-wrapped backbone.
lora_model.print_trainable_parameters()

trainable params: 12,730,368 || all params: 9,254,436,352 || trainable%: 0.1376


In [10]:
# Wrap the LoRA-adapted backbone in the project's PreferencePredictionModel.
# The feature size, hidden size and number of classes come from the active config.
predictionModel_original = Utils.PreferencePredictionModel(
                gemma_model=lora_model, 
                feature_dim=config.feature_dims,
                hidden_dim=config.hidden_dim,
                num_classes=config.num_classes)

In [11]:
# Display the module tree of the assembled prediction model.
predictionModel_original

PreferencePredictionModel(
  (gemma_model): PeftModelForFeatureExtraction(
    (base_model): LoraModel(
      (model): Gemma2Model(
        (embed_tokens): Embedding(256000, 3584, padding_idx=0)
        (layers): ModuleList(
          (0-41): 42 x Gemma2DecoderLayer(
            (self_attn): Gemma2Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=3584, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3584, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
         

## Save

In [12]:
# Save the assembled model (no training has been run in this notebook) through
# the project helper, under the checkpoint name "Original_notrain".
Utils.custom_save_model_chkpt(predictionModel_original, config, checkpointName="Original_notrain")

In [13]:
# Give Adam only the parameters that can actually be updated. The backbone
# weights are frozen (only ~0.14% of parameters are trainable), so passing
# every parameter just bloats the optimizer's parameter groups.
optimizer = optim.Adam(
    [param for param in predictionModel_original.parameters() if param.requires_grad],
    lr=config.start_lr,
)

## Load

In [4]:
# Reload the "Original_notrain" checkpoint using a different preset
# (`prepare_gemma2_9b_fp16_4bit_h4096`) from the one it was saved with.
# `loadFrom=config` hands the helper the preset that was active when saving --
# presumably so it can locate the checkpoint; confirm in ModelsUtils.
# `optimizer=None`: no optimizer is passed to the helper here.
# NOTE(review): this cell depends on `manager`, `config` and `device` from
# earlier cells, so those must have been executed in the current kernel first.
configload = manager.prepare_gemma2_9b_fp16_4bit_h4096

predictionModelLoaded = Utils.custom_load_model_chkpt(
                        configload,
                        checkpointName="Original_notrain",
                        device=device,
                        loadFrom=config,
                        optimizer=None)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
# Print the trainable vs. total parameter counts of the reloaded model's backbone.
predictionModelLoaded.gemma_model.print_trainable_parameters()

trainable params: 12,730,368 || all params: 34,225,225,216 || trainable%: 0.0372


In [6]:
# Display the module tree of the reloaded prediction model.
predictionModelLoaded

PreferencePredictionModel(
  (gemma_model): PeftModelForFeatureExtraction(
    (base_model): LoraModel(
      (model): Gemma2Model(
        (embed_tokens): Embedding(256000, 3584, padding_idx=0)
        (layers): ModuleList(
          (0-41): 42 x Gemma2DecoderLayer(
            (self_attn): Gemma2Attention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3584, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3584, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
 