In [1]:
import os
import copy
from dataclasses import dataclass

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch import Tensor

#from datasets import Dataset

from tqdm import tqdm

from transformers import (
    BitsAndBytesConfig,
    Gemma2ForSequenceClassification,
    Gemma2Model,
    GemmaTokenizerFast,
    Gemma2Config,
    AutoTokenizer,
    AutoModel,
    PreTrainedTokenizerBase, 
    EvalPrediction,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)

from peft import LoraModel, PeftModel, LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

from sklearn.metrics import log_loss, accuracy_score
from sklearn.model_selection import train_test_split

import ModelsUtils as Utils
import Configurations as Configs


In [2]:
print('Torch version:', torch.__version__)
print('Torch is build with CUDA:', torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Torch device : {device}')
print('------------------------------')

Torch version: 2.5.1+cu118
Torch is build with CUDA: True
Torch device : cuda
------------------------------


## Config

In [3]:
config_file = 'Configs.py'
manager = Configs.ConfigManager(config_file)
config = manager.micro

## Load

In [4]:
predictionModelLoaded = Utils.custom_load_model_chkpt(
                        config,
                        checkpointName="Original_notrain",
                        device=device,
                        optimizer=None)


Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


In [5]:
predictionModelLoaded.gemma_model.print_trainable_parameters()

trainable params: 1,761,280 || all params: 2,616,103,168 || trainable%: 0.0673


In [6]:
predictionModelLoaded

PreferencePredictionModel(
  (gemma_model): PeftModelForFeatureExtraction(
    (base_model): LoraModel(
      (model): Gemma2ForCausalLM(
        (model): Gemma2Model(
          (embed_tokens): Embedding(256000, 2304, padding_idx=0)
          (layers): ModuleList(
            (0-15): 16 x Gemma2DecoderLayer(
              (self_attn): Gemma2Attention(
                (q_proj): Linear4bit(in_features=2304, out_features=2048, bias=False)
                (k_proj): Linear4bit(in_features=2304, out_features=1024, bias=False)
                (v_proj): Linear4bit(in_features=2304, out_features=1024, bias=False)
                (o_proj): Linear4bit(in_features=2048, out_features=2304, bias=False)
                (rotary_emb): Gemma2RotaryEmbedding()
              )
              (mlp): Gemma2MLP(
                (gate_proj): Linear4bit(in_features=2304, out_features=9216, bias=False)
                (up_proj): Linear4bit(in_features=2304, out_features=9216, bias=False)
                (down_pr