Fine-tuning of gemma-2b-it

In [1]:
# ALL THE NECESSARY IMPORTS

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm
import pickle

from dataclasses import dataclass, field
from typing import Optional
from sklearn.model_selection import train_test_split


from datasets import load_dataset
from functools import partial
from peft import LoraConfig, TaskType, get_peft_model, get_peft_config

# Filepath to embeddings
fname = "/mnt/mimic/data/HAIM/mimic_extras/embeddings.csv"

Setting up the model

Different versions, this notebook uses huggingface LoRA-class.

In [2]:
# LoRA parameter efficient fine-tuning
# Parameters are freezed and small submodules with low-rank matrices ar inserted at the target layers.
# initialization of model
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
tokenizer.pad_token_id = tokenizer.eos_token_id
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto", quantization_config=quantization_config,attn_implementation="sdpa")
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
    lora_alpha=16,
    lora_dropout=0.1
)

gemma = get_peft_model(gemma, lora_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Model-structure and trainable parameters (this can be tuned by hyperparameters)
gemma.print_trainable_parameters()
gemma

trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GemmaForCausalLM(
      (model): GemmaModel(
        (embed_tokens): Embedding(256000, 2048, padding_idx=0)
        (layers): ModuleList(
          (0-17): 18 x GemmaDecoderLayer(
            (self_attn): GemmaSdpaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2048, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear4bit(
                (base_la

Projection module

In [3]:
embedding_size = 1024
projection_size = 6

class ProjectionNN(nn.Module):
    def __init__(self):
        super(ProjectionNN, self).__init__()

        # Architecture
        self.fc1 = nn.Linear(embedding_size, 128).cuda()
        self.relu = nn.ReLU().cuda()
        self.fc2 = nn.Linear(128, 2048 * projection_size).cuda()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = x.view(-1,6,2048)
        return x

Fetching and preprocessing of data

In [4]:
df = pd.read_csv(fname)
df_death_small48 = df[((df['img_length_of_stay'] < 48) & (df['death_status'] == 1))]
df_alive_big48 = df[((df['img_length_of_stay'] >= 48) & (df['death_status'] == 0))]
df_death_big48 = df[((df['img_length_of_stay'] >= 48) & (df['death_status'] == 1))]

df_death_small48['y'] = 1
df_alive_big48['y'] = 0
df_death_big48['y'] = 0
df = pd.concat([df_death_small48, df_alive_big48, df_death_big48], axis = 0)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_death_small48['y'] = 1
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_alive_big48['y'] = 0
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_death_big48['y'] = 0


In [5]:
vd_cols = df.filter(regex='^vd_')
y_col = df[['y']]
haim_col = df[['haim_id']]
df = pd.concat([haim_col, vd_cols, y_col], axis=1)

pkl_list = df['haim_id'].unique().tolist()

print(df.head())

     haim_id      vd_0      vd_1      vd_2      vd_3      vd_4      vd_5  \
256     6557  0.005299  0.082119  0.274407  0.017487  0.255308  0.003707   
259     6557  0.000000  0.079306  0.381579  0.015250  0.402685  0.011122   
267     6558  0.005299  0.082119  0.274407  0.017487  0.255308  0.003707   
270     6558  0.000000  0.079306  0.381579  0.015250  0.402685  0.011122   
319     6581  0.002288  0.078941  0.088397  0.017775  0.071482  0.006970   

         vd_6      vd_7      vd_8  ...   vd_1015   vd_1016   vd_1017  \
256  0.137267  0.024046  0.145395  ...  0.008003  0.013876  0.005360   
259  0.125938  0.033254  0.227433  ...  0.042140  0.036560  0.006585   
267  0.137267  0.024046  0.145395  ...  0.008003  0.013876  0.005360   
270  0.125938  0.033254  0.227433  ...  0.042140  0.036560  0.006585   
319  0.223354  0.045017  0.056177  ...  0.004973  0.000343  0.000000   

      vd_1018   vd_1019   vd_1020   vd_1021   vd_1022   vd_1023  y  
256  0.039292  0.029467  0.003972  0.0002

In [9]:

def data_split(df, pkl_list, test_size=0.3, validation_size=0.1, random_state=None):
    # Split into training and test sets
    train_set, test_set = train_test_split(pkl_list, test_size=test_size, random_state=random_state)

    # Further split the training set into training and validation sets
    train_set, validation_set = train_test_split(train_set, test_size=validation_size, random_state=random_state)

    train_idx = df[df['haim_id'].isin(train_set)]['haim_id'].tolist()
    validation_idx = df[df['haim_id'].isin(validation_set)]['haim_id'].tolist()
    test_idx = df[df['haim_id'].isin(test_set)]['haim_id'].tolist()

    x_train = df[df['haim_id'].isin(train_idx)].drop(['haim_id', 'y'], axis=1).values
    x_validation = df[df['haim_id'].isin(validation_idx)].drop(['haim_id', 'y'], axis=1).values
    x_test = df[df['haim_id'].isin(test_idx)].drop(['haim_id', 'y'], axis=1).values

    y_train = df[df['haim_id'].isin(train_idx)]['y'].values
    y_train = [' ###ANSWER: No' if value == 0 else ' ###ANSWER: Yes' for value in y_train]
    y_validation = df[df['haim_id'].isin(validation_idx)]['y'].values
    y_validation = [' ###ANSWER: No' if value == 0 else ' ###ANSWER: Yes' for value in y_validation]
    y_test = df[df['haim_id'].isin(test_idx)]['y'].values
    y_test = [' ###ANSWER: No' if value == 0 else ' ###ANSWER: Yes' for value in y_test]

    return x_train, x_validation, x_test, y_train, y_validation, y_test

In [10]:
x_train, x_val, x_test, y_train, y_val, y_test, = data_split(df, pkl_list)

In [8]:
y_gen_train = [' ###ANSWER: No' if value == 0 else ' ###ANSWER: Yes' for value in y_train]
y_gen_val = [' ###ANSWER: No' if value == 0 else ' ###ANSWER: Yes' for value in y_val]
y_gen_test = [' ###ANSWER: No' if value == 0 else ' ###ANSWER: Yes' for value in y_test]

Creating Dataset and DataLoader

Training functions

Experiments

In [11]:
# dataset for generative

class CustomDataset(Dataset):
    def __init__(self, embedding, labels):
        self.labels = labels
        self.embedding = embedding
        self.instruction = "###INSTRUCTION: Did this patient stay longer than 48 h? ###MODALITY: "

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        emb = self.embedding[idx]
        inst = self.instruction
        sample = {"Emb": emb, "Class": label, 'Inst': inst}
        #print(sample)
        return sample

In [12]:
# collate_batch for generative

def collate_batch(batch):
     
    emb_list, classes, instructions = [], [], []
    for thing in batch:
        emb_list.append(thing['Emb'])
        classes.append(tokenizer(thing['Class'], return_tensors="pt"))
        instructions.append(tokenizer(thing['Inst'], return_tensors="pt"))
    text = torch.tensor(emb_list)
    classes = torch.cat([item['input_ids'] for item in classes], dim=0)
    instructions = torch.cat([item['input_ids'] for item in instructions], dim=0)
    return text, instructions, classes

In [29]:
bsz = 1

Trainset = CustomDataset(x_train.tolist(), y_train)
Valset = CustomDataset(x_val, y_val)
Testset = CustomDataset(x_test, y_test)

Train_loader = DataLoader(Trainset, batch_size=bsz, collate_fn=collate_batch)
Val_loader = DataLoader(Valset, batch_size=bsz, collate_fn=collate_batch)
Test_loader = DataLoader(Testset, batch_size=bsz, collate_fn=collate_batch)

In [31]:
# SECOND APPROACH GEN TRYOUT

def output_to_label(logits):
    probs = torch.softmax(logits, dim=-1)
    predicted_token_id = torch.argmax(probs, dim=-1)
    return predicted_token_id

def left_padding(concatenated_emb, labels, device):
    padded_labels = []
    for concatenated_emb_item, label in zip(concatenated_emb, labels):
        prompt_length = concatenated_emb_item.size(0) - label.size(0)
        padded_label = torch.cat([torch.full((prompt_length,), -100, device=device), label.to(device)])
        padded_labels.append(padded_label)
    padded_labels = torch.stack(padded_labels, dim=0)  # Stack padded labels into a batch tensor
    return padded_labels

def masking(logits, padded_target):
    mask = (padded_target != -100)
    logits_masked = logits[mask]
    targets_masked = padded_target[mask]
    return logits_masked, targets_masked

def train_epoch(model, gemma, optimizer, loss_fn, train_loader, device):
    # Train:
    gemma.train()
    train_loss_batches, train_acc_batches = [], []
    num_batches = len(train_loader)
    embedding_matrix = gemma.get_input_embeddings().weight
    for batch_index, (mod, inst, label) in enumerate(train_loader, 1):
        mod_embeddings = model(mod.to(device))
        inst_list = [embedding_matrix[token_id] for token_id in inst.to(dtype=torch.long)]
        label_list = [embedding_matrix[token_id] for token_id in label.to(dtype=torch.long)]
        inst_embeddings = torch.stack(inst_list)
        label_embeddings = torch.stack(label_list)
        optimizer.zero_grad()

        conc_emb = torch.cat([inst_embeddings.to(dtype=torch.float16), mod_embeddings, label_embeddings.to(dtype=torch.float16)], dim=1).to(device)
        padded_target = left_padding(conc_emb, label, device)
        output = gemma(inputs_embeds=conc_emb.to(dtype=torch.float16), labels=padded_target)

        logits = output['logits'] #torch.Size([1, 31, 256000])
        logits_masked, targets_masked = masking(logits, padded_target)

        loss = loss_fn(logits_masked.squeeze(0), targets_masked)
        loss.backward()
        optimizer.step()
        train_loss_batches.append(loss.item())

        hard_preds = output_to_label(logits_masked)
        acc_batch_avg = (hard_preds == targets_masked).float().mean().item()
        train_acc_batches.append(acc_batch_avg)

    return model, train_loss_batches, train_acc_batches

def validate(model, gemma, loss_fn, val_loader, device, word_embs):
    val_loss_cum = 0
    val_acc_cum = 0
    gemma.eval()
    embedding_matrix = gemma.get_input_embeddings().weight
    with torch.no_grad():
        for batch_index, (mod, inst, label) in enumerate(val_loader, 1):
            mod_embeddings = model(mod.to(device))
            inst_list = [embedding_matrix[token_id] for token_id in inst.to(dtype=torch.long)]
            label_list = [embedding_matrix[token_id] for token_id in label.to(dtype=torch.long)]
            inst_embeddings = torch.stack(inst_list)
            label_embeddings = torch.stack(label_list)

            conc_emb = torch.cat([inst_embeddings.to(dtype=torch.float16), mod_embeddings, label_embeddings.to(dtype=torch.float16)], dim=1).to(device)
            padded_target = left_padding(conc_emb, label, device)

            output = gemma(inputs_embeds=conc_emb.to(dtype=torch.float16), labels=padded_target)

            logits = output['logits'] #torch.Size([1, 31, 256000])
            logits_masked, targets_masked = masking(logits, padded_target)


            loss = loss_fn(logits_masked.squeeze(0), targets_masked)
            val_loss_cum += loss.item()
            hard_preds = output_to_label(logits_masked)
            acc_batch_avg = (hard_preds == targets_masked).float().mean().item()
            val_acc_cum += acc_batch_avg
    return val_loss_cum/len(val_loader), val_acc_cum/len(val_loader)

In [15]:
def training_loop(model, gemma, optimizer, loss_fn, train_loader, val_loader, num_epochs):
    print("Starting training")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    gemma.to(device)
    #for name, param in gemma.named_parameters():
    #    print(f"{name}: {param.device}")
    train_losses, train_accs, val_losses, val_accs = [], [], [], []

    for epoch in range(1, num_epochs+1):
        model, train_loss, train_acc = train_epoch(model, gemma,
                                                   optimizer,
                                                   loss_fn,
                                                   train_loader,
                                                   device)
        val_loss, val_acc = validate(model, loss_fn, val_loader, device)
        print(f"Epoch {epoch}/{num_epochs}: "
              f"Train loss: {sum(train_loss)/len(train_loss):.3f}, "
              f"Train acc.: {sum(train_acc)/len(train_acc):.3f}, "
              f"Val. loss: {val_loss:.3f}, "
              f"Val. acc.: {val_acc:.3f}")
        train_losses.extend(train_loss)
        train_accs.extend(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
    return gemma, train_losses, train_accs, val_losses, val_accs

In [32]:
model = ProjectionNN()
optimizer = torch.optim.Adam(gemma.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
num_epochs = 5
fine_tuned, train_losses, train_accs, val_losses, val_accs = training_loop(model, gemma, optimizer, loss_fn, Train_loader, Val_loader, num_epochs)

#torch.save(fine_tuned, 'finetuned.pth')

#with open('train_losses.pkl', 'wb') as f1:
#    pickle.dump(train_losses, f1)

#with open('train_accs.pkl', 'wb') as f2:
#    pickle.dump(train_accs, f2)

#with open('val_losses.pkl', 'wb') as f3:
#    pickle.dump(val_losses, f3)

#with open('val_accs.pkl', 'wb') as f4:
#    pickle.dump(val_accs, f4)

Starting training
torch.Size([5, 256000])
torch.Size([5])
torch.Size([5, 256000])
torch.Size([5])
torch.Size([5, 256000])
torch.Size([5])
torch.Size([5, 256000])
torch.Size([5])
torch.Size([5, 256000])
torch.Size([5])
torch.Size([5, 256000])
torch.Size([5])
torch.Size([5, 256000])
torch.Size([5])
torch.Size([5, 256000])
torch.Size([5])
torch.Size([5, 256000])
torch.Size([5])
torch.Size([5, 256000])
torch.Size([5])
torch.Size([5, 256000])
torch.Size([5])
torch.Size([5, 256000])
torch.Size([5])


KeyboardInterrupt: 