In [1]:
import os
import copy
from dataclasses import dataclass

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch import Tensor

#from datasets import Dataset

from tqdm import tqdm

from transformers import (
    BitsAndBytesConfig,
    Gemma2ForSequenceClassification,
    Gemma2Model,
    GemmaTokenizerFast,
    Gemma2Config,
    AutoTokenizer,
    AutoModel,
    PreTrainedTokenizerBase, 
    EvalPrediction,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)

from peft import LoraModel, PeftModel, LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

from sklearn.metrics import log_loss, accuracy_score
from sklearn.model_selection import train_test_split

import ModelsUtils as Utils
import Configurations as Configs


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Probe CUDA once, pick the compute device from the result, then report the
# PyTorch build and the chosen device.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")

print('Torch version:', torch.__version__)
print('Torch is build with CUDA:', cuda_available)
print(f'Torch device : {device}')
print('------------------------------')

Torch version: 2.5.1+cu118
Torch is build with CUDA: True
Torch device : cuda
------------------------------


In [3]:
# Build the project's ConfigManager from the configuration file name and
# select the `save_load_gemma2_2b_fp16` entry; every later cell reads its
# settings from `config`.
config_file = 'Configs.py'
manager = Configs.ConfigManager(config_file)
config = manager.save_load_gemma2_2b_fp16

## Config

## LoRA Config

In [4]:
# LoRA adapters are attached to the self-attention query/key/value
# projections only.
attention_projections = ["q_proj", "k_proj", "v_proj"]

# Rank, alpha, dropout and bias mode all come from the selected config entry.
lora_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,  # alternative left by the author: SEQ_CLS
    target_modules=attention_projections,
    r=config.lora_r,
    lora_alpha=config.lora_alpha,
    lora_dropout=config.lora_dropout,
    bias=config.lora_bias,
    # Disabled: restrict adapters to the layers above config.freeze_layers.
    # layers_to_transform=[i for i in range(42) if i >= config.freeze_layers],
)

___________________________________________________________________________

## Tokenize

In [5]:
# Load the tokenizer from the same location as the base model.
tokenizer = AutoTokenizer.from_pretrained(config.transformers_basemodel_path)

# Pad on the right-hand side and have the tokenizer add <eos> at the end.
tokenizer.padding_side = "right"
tokenizer.add_eos_token = True

## Model

In [6]:
# Optional 4-bit quantization: only built when the config asks for '4bit',
# otherwise no quantization config is passed to the loader.
if config.quantize == '4bit':
    # test gemma2 2b unsloth
    print("quantized")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_use_double_quant=False,
        bnb_4bit_compute_dtype=torch.bfloat16,  # bfloat16 is recommended
    )
else:
    quantization_config = None

# Request float16 weights when config.fp16 is set; otherwise let the loader
# decide via "auto".
base_dtype = torch.float16 if config.fp16 else "auto"

gemma_2b_base = AutoModel.from_pretrained(
    config.transformers_basemodel_path,
    torch_dtype=base_dtype,
    device_map="auto",
    quantization_config=quantization_config,
)

# Show the module tree of the loaded base model.
gemma_2b_base

Gemma2Model(
  (embed_tokens): Embedding(256000, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_feedforward_layernorm): Gemma2RMSNorm((2304,),

In [7]:
# Destination directory for the local copy of the base model; taken from the
# config. Earlier hard-coded locations are kept below for reference.
#save_path = '../BaseModel/gemma2-2b-unsloth'
save_path = config.basemodel_path #'../BaseModel/gemma2-2b-unsloth-fp16'

# Write the base model and the tokenizer into the same directory. The
# tokenizer call is last, so the cell displays the files it wrote.
gemma_2b_base.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

('../BaseModel/gemma2_2b_unsloth_fp16\\tokenizer_config.json',
 '../BaseModel/gemma2_2b_unsloth_fp16\\special_tokens_map.json',
 '../BaseModel/gemma2_2b_unsloth_fp16\\tokenizer.model',
 '../BaseModel/gemma2_2b_unsloth_fp16\\added_tokens.json',
 '../BaseModel/gemma2_2b_unsloth_fp16\\tokenizer.json')

In [8]:
# Turn off the model config's `use_cache` flag before setting up training.
gemma_2b_base.config.use_cache = False
# NOTE(review): the base model is only quantized when config.quantize == '4bit';
# confirm that running the k-bit preparation helper on an unquantized model is
# intended for the fp16 configuration selected above.
gemma_2b_base = prepare_model_for_kbit_training(gemma_2b_base)
# Wrap the prepared base model with the LoRA adapters described by lora_config.
lora_model = get_peft_model(gemma_2b_base, lora_config)
# Show the module tree of the adapter-wrapped model.
lora_model

PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): Gemma2Model(
      (embed_tokens): Embedding(256000, 2304, padding_idx=0)
      (layers): ModuleList(
        (0-25): 26 x Gemma2DecoderLayer(
          (self_attn): Gemma2Attention(
            (q_proj): lora.Linear(
              (base_layer): Linear(in_features=2304, out_features=2048, bias=False)
              (lora_dropout): ModuleDict(
                (default): Dropout(p=0.05, inplace=False)
              )
              (lora_A): ModuleDict(
                (default): Linear(in_features=2304, out_features=16, bias=False)
              )
              (lora_B): ModuleDict(
                (default): Linear(in_features=16, out_features=2048, bias=False)
              )
              (lora_embedding_A): ParameterDict()
              (lora_embedding_B): ParameterDict()
              (lora_magnitude_vector): ModuleDict()
            )
            (k_proj): lora.Linear(
              (base_layer): Linear(in_featur

In [9]:
# Report trainable vs. total parameter counts after attaching the LoRA adapters.
lora_model.print_trainable_parameters()

trainable params: 4,579,328 || all params: 2,618,921,216 || trainable%: 0.1749


In [10]:
# Wrap the LoRA-adapted model in the project's PreferencePredictionModel.
# NOTE(review): feature_dim=4 and num_classes=2 are hard-coded here; their
# meaning is defined in ModelsUtils — confirm they match the training data.
predictionModel_original = Utils.PreferencePredictionModel(gemma_model=lora_model, feature_dim=4, num_classes=2)

In [11]:
# Display the LoraModel held under the prediction model's `gemma_model` attribute.
predictionModel_original.gemma_model.base_model

LoraModel(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=2304, out_features=2048, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.05, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2304, out_features=16, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=16, out_features=2048, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear(
            (base_layer): Linear(in_features=2304, out_features=1024, bias=False)
            (lora_dropout): ModuleDict(
           

## Save

In [12]:
# Save the not-yet-trained model through the project helper under the
# checkpoint name "Original_notrain"; where and what exactly is written is
# decided by ModelsUtils and `config`.
Utils.custom_save_model_chkpt(predictionModel_original, config, checkpointName="Original_notrain")

In [13]:
# Adam optimizer over all parameters of the prediction model, with the
# learning rate taken from config.start_lr.
# NOTE(review): the load call further down passes optimizer=None, so this
# optimizer is not used in the visible cells — confirm it is still needed.
optimizer = optim.Adam(predictionModel_original.parameters(), lr=config.start_lr)

## Load

In [14]:
# Rebuild the prediction model from the "Original_notrain" checkpoint via the
# project helper. No optimizer is supplied, so only the model is requested.
predictionModelLoaded = Utils.custom_load_model_chkpt(
    config,
    checkpointName="Original_notrain",
    optimizer=None,
)


quantized


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.57it/s]
  checkpoint = torch.load(f'{loadPath}/PreferencePredictionModel.pt')


In [15]:
# Report trainable vs. total parameter counts for the reloaded model, for
# comparison with the figures printed for the original LoRA model.
predictionModelLoaded.gemma_model.print_trainable_parameters()

trainable params: 4,579,328 || all params: 2,618,921,216 || trainable%: 0.1749


In [16]:
# Display the module tree of the reloaded prediction model.
predictionModelLoaded

PreferencePredictionModel(
  (gemma_model): PeftModelForFeatureExtraction(
    (base_model): LoraModel(
      (model): Gemma2Model(
        (embed_tokens): Embedding(256000, 2304, padding_idx=0)
        (layers): ModuleList(
          (0-25): 26 x Gemma2DecoderLayer(
            (self_attn): Gemma2Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=2304, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2304, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
         