# Notebook for testing mutlimodal capability of Gemma

In [3]:
import pickle
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from focal_loss.focal_loss import FocalLoss
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Filepath to embeddings
fname = "/mnt/mimic/data/HAIM/mimic_extras/embeddings.csv"

# YES-TOKEN: 3276
# NO-TOKEN: 956

In [4]:
quantization_config = BitsAndBytesConfig(load_in_4bit=True, 
                                         bnb_4bit_use_double_quant=True,
                                         bnb_4bit_quant_type="nf4",
                                         bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto", quantization_config=quantization_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
embedding_size = 1024
projection_size = 1

class ProjectionNN(nn.Module):
    def __init__(self):
        super(ProjectionNN, self).__init__()

        # Architecture
        self.fc1 = nn.Linear(embedding_size, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 2048 * projection_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = x.view(-1,projection_size,2048)
        return x


### Load and pre-process data

In [6]:
df = pd.read_csv(fname)

condition_death_small48 = (df['img_length_of_stay'] < 48) & (df['death_status'] == 1)
condition_alive_big48 = (df['img_length_of_stay'] >= 48) & (df['death_status'] == 0)
condition_death_big48 = (df['img_length_of_stay'] >= 48) & (df['death_status'] == 1)

y = [0]*len(df)
for i, condition in enumerate(condition_death_small48):
    if condition:
        y[i] = 1

# Use .loc to avoid SettingWithCopyWarning
#df.loc[condition_death_small48, 'y'] = 1
#df.loc[condition_alive_big48, 'y'] = 0
#df.loc[condition_death_big48, 'y'] = 0

In [7]:
vd_cols = df.filter(regex='^vd_')
y_col = pd.Series(y, name='y')
haim_col = df[['haim_id']]
df = pd.concat([haim_col, vd_cols, y_col], axis=1)

pkl_list = df['haim_id'].unique().tolist()

print(df.head())


   haim_id      vd_0      vd_1      vd_2      vd_3      vd_4      vd_5  \
0     6514  0.000000  0.102385  0.188977  0.007367  0.219433  0.000106   
1     6514  0.000399  0.063669  0.297278  0.007873  0.288133  0.000000   
2     6515  0.000000  0.073280  0.390735  0.007879  0.094356  0.006252   
3     6515  0.000000  0.003337  0.084882  0.008524  0.030514  0.000936   
4     6515  0.000121  0.098648  0.514754  0.001866  0.211975  0.011927   

       vd_6      vd_7      vd_8  ...   vd_1015   vd_1016   vd_1017   vd_1018  \
0  0.074859  0.017974  0.138016  ...  0.010239  0.000589  0.000743  0.102930   
1  0.099269  0.004799  0.215243  ...  0.000000  0.013072  0.000000  0.078393   
2  0.113489  0.021230  0.324026  ...  0.173980  0.009676  0.095614  0.052150   
3  0.242137  0.027981  0.025548  ...  0.071969  0.000301  0.142212  0.017643   
4  0.081207  0.010555  0.364878  ...  0.204686  0.013269  0.134133  0.044195   

    vd_1019   vd_1020   vd_1021   vd_1022   vd_1023  y  
0  0.008906  0.00

### Setup functions for training

*Dataset class*

In [8]:
class CustomDataset(Dataset):
    def __init__(self, vectors, labels):
        self.vectors = vectors
        self.labels = labels
    
    def __len__(self):
        return len(self.vectors)
    
    def __getitem__(self, index):
        vector = torch.tensor(self.vectors[index]).float()
        label = torch.from_numpy(np.array([self.labels[index]]))
        return vector, label.squeeze()

*Data splitter*

In [9]:
def data_split(df, pkl_list):
    train_id, val_id = train_test_split(pkl_list, test_size=0.3, random_state=42)
    val_id, test_id = train_test_split(val_id, test_size=0.25, random_state=42)

    train_idx = df[df['haim_id'].isin(train_id)]['haim_id'].tolist()
    val_idx = df[df['haim_id'].isin(val_id)]['haim_id'].tolist()
    test_idx = df[df['haim_id'].isin(test_id)]['haim_id'].tolist()

    x_train = df[df['haim_id'].isin(train_idx)].drop(['haim_id','y'],axis=1).values
    x_val = df[df['haim_id'].isin(val_idx)].drop(['haim_id','y'],axis=1).values
    x_test = df[df['haim_id'].isin(test_idx)].drop(['haim_id','y'],axis=1).values

    y_train = df[df['haim_id'].isin(train_idx)]['y'].values
    y_val = df[df['haim_id'].isin(val_idx)]['y'].values
    y_test = df[df['haim_id'].isin(test_idx)]['y'].values

    return x_train, x_val, x_test, y_train, y_val, y_test

*Train/Val funcs (needs to be updated)*

In [10]:
def custom_output(emb, gemma):
    outputs = gemma(inputs_embeds=emb)
    noyes = [956, 3276]
    logits = outputs['logits']
    logits = logits[:,-1:,noyes].mean(dim=1)
    probs = torch.softmax(logits, dim=-1)
    return probs

def output_to_label(logits):
    probs = torch.softmax(logits, dim=-1)
    predicted_token_id = torch.argmax(probs, dim=-1)
    return predicted_token_id

def select_random_subset(data, subset_fraction=0.1):
    num_samples = int(len(data) * subset_fraction)
    indices = np.random.choice(len(data), num_samples, replace=False)
    subset = data[indices]
    return subset
    
def train_epoch(model, gemma, optimizer, loss_fn, train_loader, device):
    # Train:
    model.train()
    train_loss_batches, train_acc_batches = [], []
    for batch_index, (x, y) in enumerate(train_loader, 1):
        inputs, labels = x.to(device), y.to(device)

        optimizer.zero_grad()

        emb = model.forward(inputs).to(torch.float16)
        #word_embs_extended = word_embs.repeat(len(inputs),1,1).detach()

        #concatted = torch.cat((word_embs_extended, emb), dim=1).to(torch.float16)
        logits = custom_output(emb, gemma).float()
        
        loss = loss_fn(logits, labels.long())
        loss.backward()
        optimizer.step()
        train_loss_batches.append(loss.item())

        hard_preds = output_to_label(logits)
        acc_batch_avg = (hard_preds == labels).float().mean().item()
        train_acc_batches.append(acc_batch_avg)

    return model, train_loss_batches, train_acc_batches

def validate(model, gemma, loss_fn, val_loader, device):
    val_loss_cum = 0
    val_acc_cum = 0
    model.eval()
    with torch.no_grad():
        for batch_index, (x, y) in enumerate(val_loader, 1):
            inputs, labels = x.to(device), y.to(device)

            emb = model.forward(inputs).to(torch.float16)
            #word_embs_extended = word_embs.repeat(len(inputs),1,1).detach()

            #concatted = torch.cat((word_embs_extended, emb), dim=1).to(torch.float16)
            logits = custom_output(emb, gemma)

            batch_loss = loss_fn(logits, labels.long())
            val_loss_cum += batch_loss.item()
            hard_preds = output_to_label(logits)
            acc_batch_avg = (hard_preds == labels).float().mean().item()
            val_acc_cum += acc_batch_avg
    return val_loss_cum/len(val_loader), val_acc_cum/len(val_loader)

*Training framework*

In [11]:
def training_loop(model, gemma, optimizer, loss_fn, train_loader, val_loader, num_epochs):
    print("Starting training")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    train_losses, train_accs, val_losses, val_accs = [], [], [], []

    for epoch in range(1, num_epochs+1):
        model, train_loss, train_acc = train_epoch(model,
                                                   gemma,
                                                   optimizer,
                                                   loss_fn,
                                                   train_loader,
                                                   device)
        val_loss, val_acc = validate(model, gemma, loss_fn, val_loader, device)
        print(f"Epoch {epoch}/{num_epochs}: "
              f"Train loss: {sum(train_loss)/len(train_loss):.3f}, "
              f"Train acc.: {sum(train_acc)/len(train_acc):.3f}, "
              f"Val. loss: {val_loss:.3f}, "
              f"Val. acc.: {val_acc:.3f}")
        train_losses.extend(train_loss)
        train_accs.extend(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
    return model, train_losses, train_accs, val_losses, val_accs

*Main*

In [12]:
input_text = "Based on the following image, output yes if the patient is likely to die and no otherwise."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
word_embs = gemma.get_input_embeddings().weight[input_ids.input_ids].to("cuda")

In [None]:
x_train, x_val, _, y_train, y_val, _ = data_split(df, pkl_list)
w0 = sum(y_train == 0)
w1 = sum(y_train == 1)
print('sum 0 labels:', w0)
print('sum 1 labels:', w1)

In [None]:
w0w = len(y_train)/(2*sum(y_train == 0))
w1w = len(y_train)/(2*sum(y_train == 1))
print('weight class 0:', w0w)
print('weight class 1:', w1w)

In [13]:
batch_size = 8
x_train, x_val, _, y_train, y_val, _ = data_split(df, pkl_list)

np.random.seed(42)

x_train_small = select_random_subset(x_train)
y_train_small = select_random_subset(y_train)

x_val_small = select_random_subset(x_val)
y_val_small = select_random_subset(y_val)

train_set = CustomDataset(x_train_small, y_train_small)
val_set = CustomDataset(x_val_small, y_val_small)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=5)
val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=5)

w0 = len(y_train_small)/(2*sum(y_train_small == 0))
w1 = len(y_train_small)/(2*sum(y_train_small == 1))
weights = torch.tensor([w0, w1], dtype = torch.float).to("cuda")

model = ProjectionNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)
#loss_fn = nn.CrossEntropyLoss(weight=weights)
loss_fn = FocalLoss(gamma=2)

num_epochs = 10

fine_tuned, train_losses, train_accs, val_losses, val_accs = training_loop(model, gemma, optimizer, loss_fn, train_loader, val_loader, num_epochs)

torch.save(fine_tuned, 'results_small_focal_2_001/finetuned.pth')

with open('results_small_focal_2_001/train_losses.pkl', 'wb') as f1:
    pickle.dump(train_losses, f1)

with open('results_small_focal_2_001/train_accs.pkl', 'wb') as f2:
    pickle.dump(train_accs, f2)

with open('results_small_focal_2_001/val_losses.pkl', 'wb') as f3:
    pickle.dump(val_losses, f3)

with open('results_small_focal_2_001/val_accs.pkl', 'wb') as f4:
    pickle.dump(val_accs, f4)



Starting training
Epoch 1/10: Train loss: 0.078, Train acc.: 0.970, Val. loss: 0.038, Val. acc.: 0.975
Epoch 2/10: Train loss: 0.042, Train acc.: 0.975, Val. loss: 0.038, Val. acc.: 0.975
Epoch 3/10: Train loss: 0.041, Train acc.: 0.974, Val. loss: 0.035, Val. acc.: 0.975
Epoch 4/10: Train loss: 0.039, Train acc.: 0.975, Val. loss: 0.042, Val. acc.: 0.975


KeyboardInterrupt: 

# Testing out gemma instruct on text generation

In [None]:
train, _, _, _ = data_split(df, pkl_list)

In [None]:
model = ProjectionNN().to('cuda')
emb = torch.tensor(train[0]).float().to(device='cuda')
projected = model(emb).to(device='cuda').to(torch.float16)
concatted = torch.cat((word_embs, projected), dim=1).to(torch.float16)
test = custom_output(concatted,gemma)
print(torch.softmax(test, dim=-1))