Fine-tuning of gemma-2b-it

In [2]:
# ALL THE NECESSARY IMPORTS

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm
import pickle

from dataclasses import dataclass, field
from typing import Optional
from sklearn.model_selection import train_test_split

from functools import partial
from peft import LoraConfig, TaskType, get_peft_model, get_peft_config

# Filepath to embeddings
fname = "/mnt/mimic/data/HAIM/mimic_extras/embeddings.csv"

Setting up the model

There are different versions of this setup; this notebook uses the Hugging Face PEFT LoRA class (`LoraConfig`).

In [3]:
# LoRA parameter-efficient fine-tuning.
# The base parameters are frozen and small sub-modules holding low-rank
# matrices are inserted at the target layers.

# Tokenizer: reuse the end-of-sequence token id as the padding token id.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
tokenizer.pad_token_id = tokenizer.eos_token_id

# Base model, loaded with 4-bit quantization.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
gemma = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    quantization_config=quantization_config,
    attn_implementation="sdpa",
    device_map="auto",
)

# LoRA adapter configuration: rank-8 adapters on the listed linear layers.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=[
        "q_proj", "o_proj", "k_proj", "v_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

# Wrap the quantized base model so that the adapters are attached.
gemma = get_peft_model(gemma, lora_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Model structure and trainable-parameter count. The count follows from the
# LoRA hyperparameters chosen above (e.g. the rank `r` and `target_modules`).
gemma.print_trainable_parameters()
# Bare expression: the notebook displays the module tree of the wrapped model.
gemma

trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GemmaForCausalLM(
      (model): GemmaModel(
        (embed_tokens): Embedding(256000, 2048, padding_idx=0)
        (layers): ModuleList(
          (0-17): 18 x GemmaDecoderLayer(
            (self_attn): GemmaSdpaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2048, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear4bit(
                (base_la

Projection module

In [21]:
class ProjectionNN(nn.Module):
    """Two-layer MLP that maps one modality embedding to a short sequence of
    vectors with the language model's embedding width.

    Args:
        embedding_size: Width of the input embedding (1024 by default; the
            notebook also mentions 768).
        projection_size: Number of output vectors produced per input row.
        hidden_size: Width of the hidden layer.
        model_dim: Width of each output vector (2048, the embedding width of
            the Gemma model printed earlier in this notebook).
        device: Device to place the layers on. ``None`` selects CUDA when it
            is available and falls back to the CPU otherwise, so the module
            can also be built on a machine without a GPU.
    """

    def __init__(self, embedding_size=1024, projection_size=6, hidden_size=128,
                 model_dim=2048, device=None):
        super().__init__()

        self.embedding_size = embedding_size
        self.projection_size = projection_size
        self.model_dim = model_dim
        # Architecture
        self.fc1 = nn.Linear(self.embedding_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, self.model_dim * self.projection_size)

        # The original called .cuda() unconditionally, which raises on hosts
        # without a GPU; pick the device once and move the whole module.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device)

    def forward(self, x):
        """Project ``x`` of shape (batch, embedding_size) to a tensor of shape
        (batch, projection_size, model_dim)."""
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        # Reshape with the configured sizes rather than hard-coded 6 and 2048.
        return x.view(-1, self.projection_size, self.model_dim)

In [22]:
class ProjectionNNmedium(nn.Module):
    """Two-layer MLP for mid-sized modality embeddings; maps each input row
    to a short sequence of vectors with the language model's embedding width.

    Args:
        embedding_size: Width of the input embedding (99 by default; the
            notebook also mentions 110 and 242).
        projection_size: Number of output vectors produced per input row.
        hidden_size: Width of the hidden layer.
        model_dim: Width of each output vector (2048, the embedding width of
            the Gemma model printed earlier in this notebook).
        device: Device to place the layers on. ``None`` selects CUDA when it
            is available and falls back to the CPU otherwise.
    """

    def __init__(self, embedding_size=99, projection_size=6, hidden_size=50,
                 model_dim=2048, device=None):
        super().__init__()

        self.embedding_size = embedding_size
        self.projection_size = projection_size
        self.model_dim = model_dim
        # Architecture
        self.fc1 = nn.Linear(self.embedding_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, self.model_dim * self.projection_size)

        # Avoid the unconditional .cuda() of the original, which fails on
        # CPU-only hosts.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device)

    def forward(self, x):
        """Project ``x`` of shape (batch, embedding_size) to a tensor of shape
        (batch, projection_size, model_dim)."""
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        # Reshape with the configured sizes rather than hard-coded 6 and 2048.
        return x.view(-1, self.projection_size, self.model_dim)

In [23]:
class ProjectionNNsmall(nn.Module):
    """Single linear layer followed by ReLU for small modality embeddings;
    maps each input row to a short sequence of vectors with the language
    model's embedding width.

    Args:
        embedding_size: Width of the input embedding (18 by default).
        projection_size: Number of output vectors produced per input row.
        model_dim: Width of each output vector (2048, the embedding width of
            the Gemma model printed earlier in this notebook).
        device: Device to place the layers on. ``None`` selects CUDA when it
            is available and falls back to the CPU otherwise.
    """

    def __init__(self, embedding_size=18, projection_size=6, model_dim=2048,
                 device=None):
        super().__init__()

        self.embedding_size = embedding_size
        self.projection_size = projection_size
        self.model_dim = model_dim
        # Architecture
        self.fc1 = nn.Linear(self.embedding_size, self.model_dim * self.projection_size)
        self.relu = nn.ReLU()

        # Avoid the unconditional .cuda() of the original, which fails on
        # CPU-only hosts.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device)

    def forward(self, x):
        """Project ``x`` of shape (batch, embedding_size) to a tensor of shape
        (batch, projection_size, model_dim)."""
        x = self.fc1(x)
        # NOTE(review): unlike the two larger projectors, the ReLU here is the
        # last operation, so every output value is non-negative. Behaviour is
        # kept as-is; confirm that this is intended.
        x = self.relu(x)
        # Reshape with the configured sizes rather than hard-coded 6 and 2048.
        return x.view(-1, self.projection_size, self.model_dim)

Fetching and preprocessing of data

In [7]:
# Load the embedding table.
df = pd.read_csv(fname)

# Row masks combining the length-of-stay threshold (48) with death status.
condition_death_small48 = df['img_length_of_stay'].lt(48) & df['death_status'].eq(1)
condition_alive_big48 = df['img_length_of_stay'].ge(48) & df['death_status'].eq(0)
condition_death_big48 = df['img_length_of_stay'].ge(48) & df['death_status'].eq(1)

# Binary target as a plain list of ints: 1 where the first mask holds, else 0.
y = condition_death_small48.astype(int).tolist()

In [8]:
#vd_cols = df.filter(regex='^vd_')
#vp_cols = df.filter(regex='^vp_')
#vmd_cols = df.filter(regex='^vmd_')
#ts_ce_cols = df.filter(regex='^ts_ce_')
#ts_le_cols = df.filter(regex='^ts_le_')
#ts_pe_cols = df.filter(regex='^ts_pe_')
#n_rad_cols = df.filter(regex='^n_rad_')
#y_col = pd.Series(y, name='y')
#haim_col = df[['haim_id']]
#df = pd.concat([haim_col, vd_cols, y_col], axis=1)

#pkl_list = df['haim_id'].unique().tolist()

#print(df.head())

   haim_id      vd_0      vd_1      vd_2      vd_3      vd_4      vd_5  \
0     6514  0.000000  0.102385  0.188977  0.007367  0.219433  0.000106   
1     6514  0.000399  0.063669  0.297278  0.007873  0.288133  0.000000   
2     6515  0.000000  0.073280  0.390735  0.007879  0.094356  0.006252   
3     6515  0.000000  0.003337  0.084882  0.008524  0.030514  0.000936   
4     6515  0.000121  0.098648  0.514754  0.001866  0.211975  0.011927   

       vd_6      vd_7      vd_8  ...   vd_1015   vd_1016   vd_1017   vd_1018  \
0  0.074859  0.017974  0.138016  ...  0.010239  0.000589  0.000743  0.102930   
1  0.099269  0.004799  0.215243  ...  0.000000  0.013072  0.000000  0.078393   
2  0.113489  0.021230  0.324026  ...  0.173980  0.009676  0.095614  0.052150   
3  0.242137  0.027981  0.025548  ...  0.071969  0.000301  0.142212  0.017643   
4  0.081207  0.010555  0.364878  ...  0.204686  0.013269  0.134133  0.044195   

    vd_1019   vd_1020   vd_1021   vd_1022   vd_1023  y  
0  0.008906  0.00

In [8]:
# Distinct haim_id values; passed to data_split so that the split is made over
# ids rather than over individual rows.
pkl_list = df['haim_id'].unique().tolist()
# NOTE(review): y_col is not read anywhere in the visible part of the notebook.
y_col = pd.Series(y, name='y')
# Attach the binary target as a column. This mutates df in place, so the cell
# that builds `y` must have run first.
df['y'] = y
print(df.head())

   haim_id                                        img_id        img_charttime  \
0     6514  9567365b-8c5e80ed-204b01e9-dc58545b-3ee1c587  2115-04-03 04:30:00   
1     6514  647a34d5-835d6009-83dfd135-fc276178-a8dea8bb  2115-04-04 00:45:00   
2     6515  61791945-f1385b04-36d05ad7-4b89cec1-362d8fa6  2119-01-12 11:15:00   
3     6515  ea7cec23-a81b223b-78dbd0dd-ac12dd87-d246e865  2119-01-07 12:35:00   
4     6515  66100c3f-996e7ceb-873019ef-39b4c0aa-28731cf0  2119-01-12 11:15:00   

   img_deltacharttime        discharge_location  img_length_of_stay  \
0            1.250000  SKILLED NURSING FACILITY          154.500000   
1           21.500000  SKILLED NURSING FACILITY          134.250000   
2          143.216667  SKILLED NURSING FACILITY           77.166667   
3           24.550000  SKILLED NURSING FACILITY          195.833333   
4          143.216667  SKILLED NURSING FACILITY           77.166667   

   death_status  de_0      vd_0      vd_1  ...  Lung Opacity  No Finding  \
0         

In [9]:

def data_split(df, pkl_list, test_size=0.3, validation_size=0.1, random_state=None):
    """Split the rows of ``df`` into train/validation/test subsets by ``haim_id``.

    The ids in ``pkl_list`` are partitioned first, and every row of ``df`` then
    follows its ``haim_id``, so rows sharing an id always land in the same
    subset.

    Args:
        df: DataFrame with a ``haim_id`` column, a ``y`` column (0/1) and
            feature columns whose names start with one of the prefixes below.
        pkl_list: Ids to partition.
        test_size: Fraction of the ids placed in the test subset.
        validation_size: Fraction of *all* ids placed in the validation subset.
        random_state: Seed forwarded to both ``train_test_split`` calls.

    Returns:
        ``(x_train, x_validation, x_test, y_train, y_validation, y_test)``.
        Each ``x_*`` is a dict mapping a column prefix to a 2-D array of the
        matching columns; each ``y_*`` is a list of answer strings
        (``' ###ANSWER: No'`` for 0, ``' ###ANSWER: Yes'`` otherwise).
    """
    types = ['vd_', 'vp_', 'vmd_', 'ts_ce_', 'ts_le_', 'ts_pe_', 'n_rad_']

    def _to_answers(values):
        # Encode the 0/1 target as the answer text the language model is trained on.
        return [' ###ANSWER: No' if value == 0 else ' ###ANSWER: Yes' for value in values]

    # Split into training and test sets
    train_set, test_set = train_test_split(pkl_list, test_size=test_size, random_state=random_state)

    # The second split only sees the non-test remainder, so the requested share
    # of the whole data must be divided by the size of that remainder. The
    # original multiplied instead, which yielded 0.1 * 0.7 * 0.7 = 4.9% of the
    # data rather than 10% for the default arguments.
    remaining_data_size = 1.0 - test_size
    validation_fraction = validation_size / remaining_data_size

    # Further split the training set into training and validation sets
    train_set, validation_set = train_test_split(
        train_set, test_size=validation_fraction, random_state=random_state
    )

    # Select the rows of each subset once and reuse them for every prefix.
    train_rows = df[df['haim_id'].isin(train_set)]
    validation_rows = df[df['haim_id'].isin(validation_set)]
    test_rows = df[df['haim_id'].isin(test_set)]

    x_train = {t: train_rows.filter(regex='^' + t).values for t in types}
    x_validation = {t: validation_rows.filter(regex='^' + t).values for t in types}
    x_test = {t: test_rows.filter(regex='^' + t).values for t in types}

    y_train = _to_answers(train_rows['y'].values)
    y_validation = _to_answers(validation_rows['y'].values)
    y_test = _to_answers(test_rows['y'].values)

    return x_train, x_validation, x_test, y_train, y_validation, y_test

In [10]:
# Fixed seed so the train/validation/test partition is identical on every run
# (without it, train_test_split draws a new random partition each time).
x_train, x_val, x_test, y_train, y_val, y_test = data_split(df, pkl_list, random_state=42)

In [43]:
# EMBEDDING SIZES
# vd = 1024
# vp = 18
# vmd = 1024
# ts_ce = 99
# ts_le = 242
# ts_pe = 110
# n_rad = 768

110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110
110


Creating Dataset and DataLoader

Training functions

Experiments

In [11]:
# dataset for generative

class CustomDataset(Dataset):
    """Dataset pairing each modality embedding with its answer string and a
    fixed instruction prompt.

    Each item is a dict with keys ``"Emb"`` (the embedding at that index),
    ``"Class"`` (the label at that index) and ``"Inst"`` (the prompt text).
    """

    def __init__(self, embedding, labels):
        # NOTE(review): the labels built earlier in this notebook are 1 only
        # for rows with a stay below 48 AND death_status == 1, which does not
        # obviously match the question asked here — confirm the wording.
        self.instruction = "###INSTRUCTION: Did this patient stay longer than 48 h? ###MODALITY: "
        self.embedding = embedding
        self.labels = labels

    def __len__(self):
        # The number of samples is defined by the label sequence.
        return len(self.labels)

    def __getitem__(self, idx):
        answer = self.labels[idx]
        vector = self.embedding[idx]
        return {"Emb": vector, "Class": answer, 'Inst': self.instruction}

In [12]:
# collate_batch for generative

def collate_batch(batch):
    """Collate a list of ``CustomDataset`` samples into three batch tensors.

    Returns, in this order:
        1. the stacked modality embeddings,
        2. the token ids of the instruction prompts,
        3. the token ids of the answer strings.

    Token-id rows are concatenated without padding, so all prompts (and all
    answers) in a batch must tokenize to the same length.
    """
    # NOTE(review): the tokenizer is called with its default special-token
    # handling, so each row may start with a BOS token, including the answer
    # rows — confirm that this is wanted for the labels.
    answer_ids = [tokenizer(sample['Class'], return_tensors="pt")['input_ids'] for sample in batch]
    prompt_ids = [tokenizer(sample['Inst'], return_tensors="pt")['input_ids'] for sample in batch]

    embeddings = torch.tensor([sample['Emb'] for sample in batch])
    classes = torch.cat(answer_ids, dim=0)
    instructions = torch.cat(prompt_ids, dim=0)
    return embeddings, instructions, classes

In [31]:
bsz = 8
# Column prefix of the modality fed to the projector. It must match the
# projector's input width (the ts_ce embeddings have 99 columns).
MODALITY = 'ts_ce_'

Trainset = CustomDataset(x_train[MODALITY].tolist(), y_train)
Valset = CustomDataset(x_val[MODALITY].tolist(), y_val)
Testset = CustomDataset(x_test[MODALITY].tolist(), y_test)

# Shuffle the training batches only: the rows come out of data_split in
# DataFrame order, so without shuffling every epoch sees the same
# consecutive rows grouped in the same batches.
Train_loader = DataLoader(Trainset, batch_size=bsz, shuffle=True, collate_fn=collate_batch)
Val_loader = DataLoader(Valset, batch_size=bsz, collate_fn=collate_batch)
Test_loader = DataLoader(Testset, batch_size=bsz, collate_fn=collate_batch)

In [32]:
# Peek at the first five training batches. collate_batch returns
# (embeddings, instruction token ids, label token ids) in that order; the
# original printed the last two under swapped captions. Descriptive loop
# names also avoid overwriting the notebook-level label list `y`.
for batch_index, (batch_emb, batch_inst, batch_label) in enumerate(Train_loader, 1):
    print(f"Batch {batch_index} - Inputs: {batch_emb}, Instructions: {batch_inst}, Labels: {batch_label}")
    if batch_index == 5:
        break

Batch 1 - Inputs: tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00

In [16]:
# SECOND APPROACH GEN TRYOUT

def output_to_label(logits):
    """Return the index of the highest-scoring entry along the last dimension.

    Softmax is strictly increasing, so taking the argmax of the raw logits
    selects the same entry as taking it after a softmax. Skipping the softmax
    avoids a full pass over the vocabulary dimension and avoids near-equal
    logits being rounded to identical probabilities.
    """
    return torch.argmax(logits, dim=-1)

def left_padding(concatenated_emb, labels, device):
    """Build per-sequence targets in which only the trailing label tokens count.

    For every (sequence, label) pair the label ids are left-padded with -100
    up to the sequence length, so that positions before the label are ignored
    by a loss that skips -100. The padded rows are stacked into one tensor on
    ``device``.
    """
    rows = [
        torch.cat([
            # One -100 per position that precedes the label tokens.
            torch.full((sequence.size(0) - answer.size(0),), -100, device=device),
            answer.to(device),
        ])
        for sequence, answer in zip(concatenated_emb, labels)
    ]
    return torch.stack(rows, dim=0)

def masking(logits, padded_target):
    """Keep only the positions whose target is not the ignore value -100.

    Returns the selected logit rows and the matching target ids, both
    flattened over the batch and sequence dimensions.
    """
    keep = padded_target.ne(-100)
    return logits[keep], padded_target[keep]

def train_epoch(model, gemma, optimizer, loss_fn, train_loader, device):
    """Run one training epoch over ``train_loader``.

    Each batch is turned into the embedding sequence
    ``[instruction tokens | projected modality | label tokens]`` and fed to
    ``gemma`` through ``inputs_embeds``. The loss and the token accuracy are
    computed on the label positions only.

    Args:
        model: Projector mapping a modality embedding to a sequence of vectors
            of the language model's embedding width.
        gemma: Causal language model accepting ``inputs_embeds`` and ``labels``.
        optimizer: Optimizer to step. The projector only learns if its
            parameters were registered with this optimizer by the caller.
        loss_fn: Callable ``(logits, targets) -> scalar loss``.
        train_loader: Iterable yielding ``(modality, instruction_ids, label_ids)``.
        device: Device on which the concatenated sequence is assembled.

    Returns:
        ``(model, per-batch losses, per-batch token accuracies)``.
    """
    # Train:
    gemma.train()
    train_loss_batches, train_acc_batches = [], []
    embedding_matrix = gemma.get_input_embeddings().weight
    for mod, inst, label in train_loader:
        optimizer.zero_grad()

        mod_embeddings = model(mod.to(device))
        # Look up the token embeddings for the whole batch in one indexing
        # operation (same result as stacking the per-row lookups).
        inst_embeddings = embedding_matrix[inst.to(device=embedding_matrix.device, dtype=torch.long)]
        label_embeddings = embedding_matrix[label.to(device=embedding_matrix.device, dtype=torch.long)]

        conc_emb = torch.cat(
            [
                inst_embeddings.to(device=device, dtype=torch.float16),
                mod_embeddings.to(dtype=torch.float16),
                label_embeddings.to(device=device, dtype=torch.float16),
            ],
            dim=1,
        )
        padded_target = left_padding(conc_emb, label, device)
        output = gemma(inputs_embeds=conc_emb, labels=padded_target)

        logits = output['logits']  # (batch, sequence length, vocabulary size)
        # The logits at position t score the token at position t + 1, so the
        # last logit and the first target are dropped before they are paired.
        # Pairing them unshifted (as the original did) rewards the model for
        # repeating the label token it has just been given as input.
        logits_masked, targets_masked = masking(logits[:, :-1, :], padded_target[:, 1:])

        # No squeeze: logits_masked is already (num_targets, vocab), and
        # squeezing would break the case of exactly one target.
        loss = loss_fn(logits_masked, targets_masked)
        loss.backward()
        optimizer.step()
        train_loss_batches.append(loss.item())

        hard_preds = output_to_label(logits_masked)
        acc_batch_avg = (hard_preds == targets_masked).float().mean().item()
        train_acc_batches.append(acc_batch_avg)

    return model, train_loss_batches, train_acc_batches

def validate(model, gemma, loss_fn, val_loader, device):
    """Evaluate the projector and language model on ``val_loader`` without
    updating any weights.

    Each batch is turned into the embedding sequence
    ``[instruction tokens | projected modality | label tokens]``; the loss and
    the token accuracy are computed on the label positions only.

    Args:
        model: Projector mapping a modality embedding to a sequence of vectors
            of the language model's embedding width.
        gemma: Causal language model accepting ``inputs_embeds`` and ``labels``.
        loss_fn: Callable ``(logits, targets) -> scalar loss``.
        val_loader: Iterable yielding ``(modality, instruction_ids, label_ids)``.
        device: Device on which the concatenated sequence is assembled.

    Returns:
        ``(mean loss per batch, mean token accuracy per batch)``.
    """
    val_loss_cum = 0
    val_acc_cum = 0
    gemma.eval()
    embedding_matrix = gemma.get_input_embeddings().weight
    with torch.no_grad():
        for mod, inst, label in val_loader:
            mod_embeddings = model(mod.to(device))
            # Batched token-embedding lookup (same result as stacking the
            # per-row lookups).
            inst_embeddings = embedding_matrix[inst.to(device=embedding_matrix.device, dtype=torch.long)]
            label_embeddings = embedding_matrix[label.to(device=embedding_matrix.device, dtype=torch.long)]

            conc_emb = torch.cat(
                [
                    inst_embeddings.to(device=device, dtype=torch.float16),
                    mod_embeddings.to(dtype=torch.float16),
                    label_embeddings.to(device=device, dtype=torch.float16),
                ],
                dim=1,
            )
            padded_target = left_padding(conc_emb, label, device)

            output = gemma(inputs_embeds=conc_emb, labels=padded_target)

            logits = output['logits']  # (batch, sequence length, vocabulary size)
            # The logits at position t score the token at position t + 1, so
            # the last logit and the first target are dropped before pairing.
            # The original compared them unshifted, which measures how well
            # the model repeats its own input rather than how well it predicts.
            logits_masked, targets_masked = masking(logits[:, :-1, :], padded_target[:, 1:])

            # No squeeze: logits_masked is already (num_targets, vocab).
            loss = loss_fn(logits_masked, targets_masked)
            val_loss_cum += loss.item()
            hard_preds = output_to_label(logits_masked)
            acc_batch_avg = (hard_preds == targets_masked).float().mean().item()
            val_acc_cum += acc_batch_avg
    return val_loss_cum/len(val_loader), val_acc_cum/len(val_loader)

In [17]:
def training_loop(model, gemma, optimizer, loss_fn, train_loader, val_loader, num_epochs):
    """Train for ``num_epochs`` epochs, validating and printing a summary
    after each one.

    Both models are moved to CUDA when it is available, otherwise to the CPU.

    Returns:
        ``(gemma, train_losses, train_accs, val_losses, val_accs)`` where the
        training lists hold one entry per batch and the validation lists hold
        one entry per epoch.
    """
    print("Starting training")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    gemma.to(device)

    train_losses, train_accs = [], []
    val_losses, val_accs = [], []

    for epoch in range(1, num_epochs + 1):
        model, epoch_losses, epoch_accs = train_epoch(
            model, gemma, optimizer, loss_fn, train_loader, device
        )
        val_loss, val_acc = validate(model, gemma, loss_fn, val_loader, device)

        # Per-epoch summary: training metrics are averaged over the batches.
        mean_loss = sum(epoch_losses) / len(epoch_losses)
        mean_acc = sum(epoch_accs) / len(epoch_accs)
        summary = ", ".join([
            f"Train loss: {mean_loss:.3f}",
            f"Train acc.: {mean_acc:.3f}",
            f"Val. loss: {val_loss:.3f}",
            f"Val. acc.: {val_acc:.3f}",
        ])
        print(f"Epoch {epoch}/{num_epochs}: " + summary)

        train_losses.extend(epoch_losses)
        train_accs.extend(epoch_accs)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

    return gemma, train_losses, train_accs, val_losses, val_accs

In [18]:
def get_projector(size):
    """Return a freshly constructed projector for the given size name.

    Args:
        size: ``'small'``, ``'medium'`` or ``'large'``.

    Raises:
        ValueError: If ``size`` is not one of the names above. The original
            fell through to ``return model`` and raised an
            ``UnboundLocalError`` that did not mention the offending value.
    """
    if size == 'small':
        return ProjectionNNsmall()
    if size == 'medium':
        return ProjectionNNmedium()
    if size == 'large':
        return ProjectionNN()
    raise ValueError(
        f"Unknown projector size {size!r}; expected 'small', 'medium' or 'large'"
    )

In [33]:
model = get_projector(size='medium')
print(model)
# Optimise the projector together with the trainable weights of the language
# model. The original passed only gemma.parameters(), so the projector kept
# its random initial weights for the whole run.
trainable_params = list(model.parameters()) + [p for p in gemma.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
num_epochs = 5
fine_tuned, train_losses, train_accs, val_losses, val_accs = training_loop(model, gemma, optimizer, loss_fn, Train_loader, Val_loader, num_epochs)

#torch.save(fine_tuned, 'finetuned.pth')

#with open('train_losses.pkl', 'wb') as f1:
#    pickle.dump(train_losses, f1)

#with open('train_accs.pkl', 'wb') as f2:
#    pickle.dump(train_accs, f2)

#with open('val_losses.pkl', 'wb') as f3:
#    pickle.dump(val_losses, f3)

#with open('val_accs.pkl', 'wb') as f4:
#    pickle.dump(val_accs, f4)

ProjectionNNmedium(
  (fc1): Linear(in_features=99, out_features=50, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=50, out_features=12288, bias=True)
)
Starting training


KeyboardInterrupt: 