Fine-tuning of gemma-2b-it

In [1]:
# ALL THE NECESSARY IMPORTS

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm
import pickle

from dataclasses import dataclass, field
from typing import Optional
from sklearn.model_selection import train_test_split

from functools import partial
from peft import LoraConfig, TaskType, get_peft_model, get_peft_config

# Filepath to embeddings
fname = "/mnt/mimic/data/HAIM/mimic_extras/embeddings.csv"

Setting up the model

Different versions, this notebook uses huggingface LoRA-class.

In [None]:
# LoRA parameter efficient fine-tuning
# Parameters are freezed and small submodules with low-rank matrices ar inserted at the target layers.
# initialization of model
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
tokenizer.pad_token_id = tokenizer.eos_token_id
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto", quantization_config=quantization_config,attn_implementation="sdpa")
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
    lora_alpha=16,
    lora_dropout=0.1
)

gemma = get_peft_model(gemma, lora_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Model-structure and trainable parameters (this can be tuned by hyperparameters)
gemma.print_trainable_parameters()
gemma

trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GemmaForCausalLM(
      (model): GemmaModel(
        (embed_tokens): Embedding(256000, 2048, padding_idx=0)
        (layers): ModuleList(
          (0-17): 18 x GemmaDecoderLayer(
            (self_attn): GemmaSdpaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2048, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear4bit(
                (base_la

Projection module

In [3]:
class ProjectionNN(nn.Module):
    def __init__(self):
        super(ProjectionNN, self).__init__()

        self.embedding_size = 1024 # or 768
        self.projection_size = 6
        # Architecture
        self.fc1 = nn.Linear(self.embedding_size, 128).cuda()
        self.relu = nn.ReLU().cuda()
        self.fc2 = nn.Linear(128, 2048 * self.projection_size).cuda()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = x.view(-1,6,2048)
        return x

In [4]:
class ProjectionNNmedium(nn.Module):
    def __init__(self):
        super(ProjectionNNmedium, self).__init__()

        self.embedding_size = 99 # or 110 or 242
        self.projection_size = 6
        # Architecture
        self.fc1 = nn.Linear(self.embedding_size, 50).cuda()
        self.relu = nn.ReLU().cuda()
        self.fc2 = nn.Linear(50, 2048 * self.projection_size).cuda()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = x.view(-1,6,2048)
        return x

In [5]:
class ProjectionNNsmall(nn.Module):
    def __init__(self):
        super(ProjectionNNsmall, self).__init__()

        self.embedding_size = 18
        self.projection_size = 6
        # Architecture
        self.fc1 = nn.Linear(self.embedding_size, 2048 * self.projection_size).cuda()
        self.relu = nn.ReLU().cuda()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = x.view(-1,6,2048)
        return x

Fetching and preprocessing of data

In [2]:
class DataSplit():

    def __init__(self, df):
        self.df = df
        self.types = ['vd_', 'vp_', 'vmd_', 'ts_ce_', 'ts_le_', 'ts_pe_', 'n_rad_']
        self.partition = None

    def partitiondata(self, partition):
        self.pkl_list = []
        if partition == 'mortality':
            condition_death_small48 = (self.df['img_length_of_stay'] < 48) & (self.df['death_status'] == 1)

            y = [0]*len(self.df)
            for i, condition in enumerate(condition_death_small48):
                if condition:
                    y[i] = 1

        if partition == 'los':
            condition_alive_small48 = self.df[((self.df['img_length_of_stay'] < 48) & (self.df['death_status'] == 0))]

            y = [0]*len(self.df)
            for i, condition in enumerate(condition_alive_small48):
                if condition:
                    y[i] = 1

        self.df['y'] = y

    def get_data(self, partition, test_size=0.3, validation_size=0.1, random_state=None):

        self.partition = partition

        self.partitiondata(partition)
        pkl_list = self.df['haim_id'].unique().tolist()

        # Split into training and test sets
        train_set, test_set = train_test_split(pkl_list, test_size=test_size, random_state=random_state)

        remaining_data_size = 1.0 - test_size
        validation_size = validation_size*remaining_data_size

        # Further split the training set into training and validation sets
        train_set, validation_set = train_test_split(train_set, test_size=validation_size, random_state=random_state)

        train_idx = self.df[self.df['haim_id'].isin(train_set)]['haim_id'].tolist()
        validation_idx = self.df[self.df['haim_id'].isin(validation_set)]['haim_id'].tolist()
        test_idx = self.df[self.df['haim_id'].isin(test_set)]['haim_id'].tolist()

        self.x_train = {t: self.df[self.df['haim_id'].isin(train_idx)].filter(regex='^'+t).values for t in self.types}
        self.x_validation = {t: self.df[self.df['haim_id'].isin(validation_idx)].filter(regex='^'+t).values for t in self.types}
        self.x_test = {t: self.df[self.df['haim_id'].isin(test_idx)].filter(regex='^'+t).values for t in self.types}

        self.y_train = [' No' if value == 0 else ' Yes' for value in self.df[self.df['haim_id'].isin(train_idx)]['y'].values]
        self.y_validation = [' No' if value == 0 else ' Yes' for value in self.df[self.df['haim_id'].isin(validation_idx)]['y'].values]
        self.y_test = [' No' if value == 0 else ' Yes' for value in self.df[self.df['haim_id'].isin(test_idx)]['y'].values]

In [7]:
# EMBEDDING SIZES
# vd = 1024
# vp = 18
# vmd = 1024
# vmp = 18
# ts_ce = 99
# ts_le = 242
# ts_pe = 110
# n_rad = 768

In [121]:
df = pd.read_csv(fname)
data = DataSplit(df)
data.get_data('mortality')

In [128]:
neg = 0
pos = 0

for v in trainy:
    #print(v)
    if v == ' Yes':
        pos +=1
    else:
        neg +=1

print(f'neg: {neg}, pos: {pos}')

neg: 19946, pos: 3022


In [126]:
idxes = []
idxno = []
#print(data.y_train)
# undersample + oversample
for x,i in enumerate(trainy):
    if i == ' Yes':
        idxes.append(x)
    else:
        idxno.append(x)
print(len(idxno), len(idxes))
    

19946 1511


In [122]:
trainx = data.x_train['vd_'].tolist()
trainy = data.y_train.copy()

In [125]:
print(len(idxno), len(idxes))

#trainx = data.x_train['vp_'].tolist()
#trainy = data.y_train.copy()
print(trainy is data.y_train)
remove_idx = idxno[:40000]  #11953 or 29973
for i in sorted(remove_idx, reverse=True):
    #print(i)
    trainx.pop(i)
    trainy.pop(i)

59946 1511
False


In [127]:
for i in idxes:
    trainy.append(' Yes')
    #print(trainx[i])
    trainx.append(trainx[i])

In [132]:
from sklearn.neural_network import MLPClassifier

In [133]:
#Initializing the MLPClassifier
classifier = MLPClassifier(hidden_layer_sizes=(150,100,50), max_iter=500,activation = 'relu',solver='adam',random_state=1)
classifier.fit(trainx, trainy)
score = classifier.score(data.x_validation['vd_'],data.y_validation)

In [130]:
import numpy as np
from sklearn import metrics

In [134]:
y_prob = classifier.predict_proba(data.x_validation['vd_'])[:,1] # This will give you positive class prediction probabilities  
y_pred = np.where(y_prob > 0.5, 1, 0) # This will threshold the probabilities to give class predictions.
auc_roc= metrics.roc_auc_score(data.y_validation,y_pred)
auc_roc

0.5822750640752614

In [135]:
score

0.941679801846431

Creating Dataset and DataLoader

Training functions

Experiments

In [12]:
# dataset for generative

class CustomDataset(Dataset):
    def __init__(self, embedding, labels):
        self.labels = labels
        self.embedding = embedding
        self.instruction = "###INSTRUCTION: Did this patient stay longer than 48 h? ###MODALITY: "

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        emb = self.embedding[idx]
        inst = self.instruction
        sample = {"Emb": emb, "Class": label, 'Inst': inst}
        #print(sample)
        return sample

In [13]:
# collate_batch for generative

def collate_batch(batch):
     
    emb_list, classes, instructions = [], [], []
    for thing in batch:
        emb_list.append(thing['Emb'])
        classes.append(tokenizer(thing['Class'], return_tensors="pt"))
        instructions.append(tokenizer(thing['Inst'], return_tensors="pt"))
    text = torch.tensor(emb_list)
    classes = torch.cat([item['input_ids'] for item in classes], dim=0)
    instructions = torch.cat([item['input_ids'] for item in instructions], dim=0)
    return text, instructions, classes

In [28]:
bsz = 8

Trainset = CustomDataset(data.x_train['ts_ce_'].tolist(), data.y_train)
Valset = CustomDataset(data.x_validation['ts_ce_'].tolist(), data.y_validation)
Testset = CustomDataset(data.x_test['ts_ce_'].tolist(), data.y_test)

Train_loader = DataLoader(Trainset, batch_size=bsz, collate_fn=collate_batch)
Val_loader = DataLoader(Valset, batch_size=bsz, collate_fn=collate_batch)
Test_loader = DataLoader(Testset, batch_size=bsz, collate_fn=collate_batch)

In [32]:
n = 0
for batch_index, (x, y, z) in enumerate(Train_loader, 1):
    print(f"Batch {batch_index} - Inputs: {x}, Labels: {y}, Instructions {z}")
    n += 1
    if n == 5:
        break

In [16]:
# SECOND APPROACH GEN TRYOUT

def output_to_label(logits):
    probs = torch.softmax(logits, dim=-1)
    predicted_token_id = torch.argmax(probs, dim=-1)
    return predicted_token_id

def left_padding(concatenated_emb, labels, device):
    padded_labels = []
    for concatenated_emb_item, label in zip(concatenated_emb, labels):
        prompt_length = concatenated_emb_item.size(0) - label.size(0)
        padded_label = torch.cat([torch.full((prompt_length,), -100, device=device), label.to(device)])
        padded_labels.append(padded_label)
    padded_labels = torch.stack(padded_labels, dim=0)  # Stack padded labels into a batch tensor
    return padded_labels

def masking(logits, padded_target):
    mask = (padded_target != -100)
    logits_masked = logits[mask]
    targets_masked = padded_target[mask]
    return logits_masked, targets_masked

def train_epoch(model, gemma, optimizer, loss_fn, train_loader, device):
    # Train:
    gemma.train()
    train_loss_batches, train_acc_batches = [], []
    num_batches = len(train_loader)
    embedding_matrix = gemma.get_input_embeddings().weight
    for batch_index, (mod, inst, label) in enumerate(train_loader, 1):
        mod_embeddings = model(mod.to(device))
        inst_list = [embedding_matrix[token_id] for token_id in inst.to(dtype=torch.long)]
        label_list = [embedding_matrix[token_id] for token_id in label.to(dtype=torch.long)]
        inst_embeddings = torch.stack(inst_list)
        label_embeddings = torch.stack(label_list)
        optimizer.zero_grad()

        conc_emb = torch.cat([inst_embeddings.to(dtype=torch.float16), mod_embeddings, label_embeddings.to(dtype=torch.float16)], dim=1).to(device)
        padded_target = left_padding(conc_emb, label, device)
        output = gemma(inputs_embeds=conc_emb.to(dtype=torch.float16), labels=padded_target)

        logits = output['logits'] #torch.Size([1, 31, 256000])
        logits_masked, targets_masked = masking(logits, padded_target)

        loss = loss_fn(logits_masked.squeeze(0), targets_masked)
        loss.backward()
        optimizer.step()
        train_loss_batches.append(loss.item())

        hard_preds = output_to_label(logits_masked)
        acc_batch_avg = (hard_preds == targets_masked).float().mean().item()
        train_acc_batches.append(acc_batch_avg)

    return model, train_loss_batches, train_acc_batches

def validate(model, gemma, loss_fn, val_loader, device):
    val_loss_cum = 0
    val_acc_cum = 0
    gemma.eval()
    embedding_matrix = gemma.get_input_embeddings().weight
    with torch.no_grad():
        for batch_index, (mod, inst, label) in enumerate(val_loader, 1):
            mod_embeddings = model(mod.to(device))
            inst_list = [embedding_matrix[token_id] for token_id in inst.to(dtype=torch.long)]
            label_list = [embedding_matrix[token_id] for token_id in label.to(dtype=torch.long)]
            inst_embeddings = torch.stack(inst_list)
            label_embeddings = torch.stack(label_list)

            conc_emb = torch.cat([inst_embeddings.to(dtype=torch.float16), mod_embeddings, label_embeddings.to(dtype=torch.float16)], dim=1).to(device)
            padded_target = left_padding(conc_emb, label, device)

            output = gemma(inputs_embeds=conc_emb.to(dtype=torch.float16), labels=padded_target)

            logits = output['logits'] #torch.Size([1, 31, 256000])
            logits_masked, targets_masked = masking(logits, padded_target)


            loss = loss_fn(logits_masked.squeeze(0), targets_masked)
            val_loss_cum += loss.item()
            hard_preds = output_to_label(logits_masked)
            acc_batch_avg = (hard_preds == targets_masked).float().mean().item()
            val_acc_cum += acc_batch_avg
    return val_loss_cum/len(val_loader), val_acc_cum/len(val_loader)

In [17]:
def training_loop(model, gemma, optimizer, loss_fn, train_loader, val_loader, num_epochs):
    print("Starting training")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    gemma.to(device)
    #for name, param in gemma.named_parameters():
    #    print(f"{name}: {param.device}")
    train_losses, train_accs, val_losses, val_accs = [], [], [], []

    for epoch in range(1, num_epochs+1):
        model, train_loss, train_acc = train_epoch(model, gemma,
                                                   optimizer,
                                                   loss_fn,
                                                   train_loader,
                                                   device)
        val_loss, val_acc = validate(model, gemma, loss_fn, val_loader, device)
        print(f"Epoch {epoch}/{num_epochs}: "
              f"Train loss: {sum(train_loss)/len(train_loss):.3f}, "
              f"Train acc.: {sum(train_acc)/len(train_acc):.3f}, "
              f"Val. loss: {val_loss:.3f}, "
              f"Val. acc.: {val_acc:.3f}")
        train_losses.extend(train_loss)
        train_accs.extend(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
    return gemma, train_losses, train_accs, val_losses, val_accs

In [18]:
def get_projector(size):
    if size == 'small':
        model = ProjectionNNsmall()
    elif size == 'medium':
        model = ProjectionNNmedium()
    elif size == 'large':
        model = ProjectionNN()
    return model

In [33]:
model = get_projector(size='medium')
print(model)
optimizer = torch.optim.Adam(gemma.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
num_epochs = 5
fine_tuned, train_losses, train_accs, val_losses, val_accs = training_loop(model, gemma, optimizer, loss_fn, Train_loader, Val_loader, num_epochs)

#torch.save(fine_tuned, 'finetuned.pth')

#with open('train_losses.pkl', 'wb') as f1:
#    pickle.dump(train_losses, f1)

#with open('train_accs.pkl', 'wb') as f2:
#    pickle.dump(train_accs, f2)

#with open('val_losses.pkl', 'wb') as f3:
#    pickle.dump(val_losses, f3)

#with open('val_accs.pkl', 'wb') as f4:
#    pickle.dump(val_accs, f4)

ProjectionNNmedium(
  (fc1): Linear(in_features=99, out_features=50, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=50, out_features=12288, bias=True)
)
Starting training


KeyboardInterrupt: 