Fine-tuning of gemma-2b-it

In [14]:
# ALL THE NECESSARY IMPORTS

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from tqdm import tqdm

from dataclasses import dataclass, field
from typing import Optional


from datasets import load_dataset
from functools import partial
from peft import LoraConfig, TaskType, get_peft_model, get_peft_config

# Filepath to embeddings
fname = "/mnt/mimic/data/HAIM/mimic_extras/embeddings.csv"

Setting up the model

Different versions, with huggingface LoRA-class or custom Adapter-module.

In [15]:
# LoRA parameter efficient fine-tuning
# Parameters are freezed and small submodules with low-rank matrices ar inserted at the target layers.
# initialization of model
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto", quantization_config=quantization_config,attn_implementation="sdpa")
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
    lora_alpha=16,
    lora_dropout=0.1
)

model = get_peft_model(model, lora_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
# Model-structure and trainable parameters (this can be tuned by hyperparameters)
model.print_trainable_parameters()
model

trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GemmaForCausalLM(
      (model): GemmaModel(
        (embed_tokens): Embedding(256000, 2048, padding_idx=0)
        (layers): ModuleList(
          (0-17): 18 x GemmaDecoderLayer(
            (self_attn): GemmaSdpaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2048, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear4bit(
                (base_la

In [17]:
# Adapter NN for parameter efficient fine-tuning
# Adapters (bottleneck feed-forward networks) are added as modules to the layers of the model
# adapting attention projections and MLP projections while freezing original model parameters

class Adapter(nn.Module):
    def __init__(self, size = 6, model_dim = 2048):
        super().__init__()
        self.adapter_block = nn.Sequential(
            nn.Linear(model_dim, size),
            nn.ReLU(),
            nn.Linear(size, model_dim)
        )

    def forward(self, x):

        output = self.adapter_block(x)
        adapter_out = output + x

        return adapter_out


class Adaptered(nn.Module):
    def __init__(self, orig_layer):
        super().__init__()
        self.orig_layer = orig_layer
        self.adapter = Adapter()

    def forward(self, *x):
        orig_out = self.orig_layer(*x)
        output = (self.adapter.forward(orig_out[0].unsqueeze(0))[0],)

        return output



class model_with_adapter(nn.Module):

    def __init__(self):
        super().__init__()
        self.quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        self.model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto", quantization_config=self.quantization_config,attn_implementation="sdpa")
        # Freeze the original model parameters
        for params in self.model.parameters():
            params.requires_grad = False
        # Embed adapter layers into the transformer blocks 
        for i, gemma_layer in enumerate(self.model.model.layers):
            gemma_layer.self_attn.q_proj = Adaptered(gemma_layer.self_attn.q_proj)
            gemma_layer.self_attn.k_proj = Adaptered(gemma_layer.self_attn.k_proj)
            gemma_layer.self_attn.v_proj = Adaptered(gemma_layer.self_attn.v_proj)
            gemma_layer.self_attn.o_proj = Adaptered(gemma_layer.self_attn.o_proj)
    
            gemma_layer.mlp.gate_proj = Adaptered(gemma_layer.mlp.gate_proj)
            gemma_layer.mlp.up_proj = Adaptered(gemma_layer.mlp.up_proj)
            gemma_layer.mlp.down_proj = Adaptered(gemma_layer.mlp.down_proj)

    def get_model(self):

        return self.model



In [18]:
# Custom get_parameters function
def get_parameters(model):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_percentage = (trainable_params / total_params) * 100

    trainable_params_str = "{:,}".format(trainable_params)
    total_params_str = "{:,}".format(total_params)

    print(f"trainable params: {trainable_params_str} || all params: {total_params_str} || trainable%: {trainable_percentage}")

In [19]:
# Initialization of adapter-model.
# 
model = model_with_adapter().to('cuda')

# Model-structure and trainable parameters (this can be tuned by hyperparameters)
get_parameters(model)
model

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 3,355,380 || all params: 1,518,623,476 || trainable%: 0.2209487771674616


model_with_adapter(
  (model): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(256000, 2048, padding_idx=0)
      (layers): ModuleList(
        (0-17): 18 x GemmaDecoderLayer(
          (self_attn): GemmaSdpaAttention(
            (q_proj): Adaptered(
              (orig_layer): Linear4bit(in_features=2048, out_features=2048, bias=False)
              (adapter): Adapter(
                (adapter_block): Sequential(
                  (0): Linear(in_features=2048, out_features=6, bias=True)
                  (1): ReLU()
                  (2): Linear(in_features=6, out_features=2048, bias=True)
                )
              )
            )
            (k_proj): Adaptered(
              (orig_layer): Linear4bit(in_features=2048, out_features=256, bias=False)
              (adapter): Adapter(
                (adapter_block): Sequential(
                  (0): Linear(in_features=2048, out_features=6, bias=True)
                  (1): ReLU()
                  (2):

Fetching and preprocessing of data

In [20]:
df = pd.read_csv(fname)
df_death_small48 = df[((df['img_length_of_stay'] < 48) & (df['death_status'] == 1))]
df_alive_big48 = df[((df['img_length_of_stay'] >= 48) & (df['death_status'] == 0))]
df_death_big48 = df[((df['img_length_of_stay'] >= 48) & (df['death_status'] == 1))]

df_death_small48['y'] = 1
df_alive_big48['y'] = 0
df_death_big48['y'] = 0
df = pd.concat([df_death_small48, df_alive_big48, df_death_big48], axis = 0)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_death_small48['y'] = 1
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_alive_big48['y'] = 0
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_death_big48['y'] = 0


In [21]:
vd_cols = df.filter(regex='^vd_')
y_col = df[['y']]
haim_col = df[['haim_id']]
df = pd.concat([haim_col, vd_cols, y_col], axis=1)
print(df.head())

     haim_id      vd_0      vd_1      vd_2      vd_3      vd_4      vd_5  \
256     6557  0.005299  0.082119  0.274407  0.017487  0.255308  0.003707   
259     6557  0.000000  0.079306  0.381579  0.015250  0.402685  0.011122   
267     6558  0.005299  0.082119  0.274407  0.017487  0.255308  0.003707   
270     6558  0.000000  0.079306  0.381579  0.015250  0.402685  0.011122   
319     6581  0.002288  0.078941  0.088397  0.017775  0.071482  0.006970   

         vd_6      vd_7      vd_8  ...   vd_1015   vd_1016   vd_1017  \
256  0.137267  0.024046  0.145395  ...  0.008003  0.013876  0.005360   
259  0.125938  0.033254  0.227433  ...  0.042140  0.036560  0.006585   
267  0.137267  0.024046  0.145395  ...  0.008003  0.013876  0.005360   
270  0.125938  0.033254  0.227433  ...  0.042140  0.036560  0.006585   
319  0.223354  0.045017  0.056177  ...  0.004973  0.000343  0.000000   

      vd_1018   vd_1019   vd_1020   vd_1021   vd_1022   vd_1023  y  
256  0.039292  0.029467  0.003972  0.0002

In [22]:
# Prompt function to be fed into training loop

def formatting_func(example):
    text = f"### INSTRUCTION: {'Use this input to create the correct label.'}\n### INPUT: {example['modality_embedding']}\n### LABEL: {example['label']}"
    return text