# Notebook for testing mutlimodal capability of Gemma

In [27]:
import pickle
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Filepath to embeddings
fname = "/mnt/mimic/data/HAIM/mimic_extras/embeddings.csv"

# YES-TOKEN: 3276
# NO-TOKEN: 956

In [None]:
quantization_config = BitsAndBytesConfig(load_in_4bit=True, 
                                         bnb_4bit_use_double_quant=True,
                                         bnb_4bit_quant_type="nf4",
                                         bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto", quantization_config=quantization_config)

In [None]:
embedding_size = 1024
projection_size = 6

class ProjectionNN(nn.Module):
    def __init__(self):
        super(ProjectionNN, self).__init__()

        # Architecture
        self.fc1 = nn.Linear(embedding_size, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 2048 * projection_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = x.view(-1,6,2048)
        return x


### Load and pre-process data

In [None]:
df = pd.read_csv(fname)

condition_death_small48 = (df['img_length_of_stay'] < 48) & (df['death_status'] == 1)
condition_alive_big48 = (df['img_length_of_stay'] >= 48) & (df['death_status'] == 0)
condition_death_big48 = (df['img_length_of_stay'] >= 48) & (df['death_status'] == 1)

y = [0]*len(df)
for i, condition in enumerate(condition_death_small48):
    if condition:
        y[i] = 1

# Use .loc to avoid SettingWithCopyWarning
#df.loc[condition_death_small48, 'y'] = 1
#df.loc[condition_alive_big48, 'y'] = 0
#df.loc[condition_death_big48, 'y'] = 0

In [None]:
vd_cols = df.filter(regex='^vd_')
y_col = pd.Series(y, name='y')
haim_col = df[['haim_id']]
df = pd.concat([haim_col, vd_cols, y_col], axis=1)

pkl_list = df['haim_id'].unique().tolist()

print(df.head())


### Setup functions for training

*Dataset class*

In [None]:
class CustomDataset(Dataset):
    def __init__(self, vectors, labels):
        self.vectors = vectors
        self.labels = labels
    
    def __len__(self):
        return len(self.vectors)
    
    def __getitem__(self, index):
        vector = torch.tensor(self.vectors[index]).float()
        label = torch.from_numpy(np.array([self.labels[index]]))
        return vector, label.squeeze()

*Data splitter*

In [None]:
def data_split(df, pkl_list):
    train_id, test_id = train_test_split(pkl_list, test_size=0.3)
    
    train_idx = df[df['haim_id'].isin(train_id)]['haim_id'].tolist()
    test_idx = df[df['haim_id'].isin(test_id)]['haim_id'].tolist()

    x_train = df[df['haim_id'].isin(train_idx)].drop(['haim_id','y'],axis=1).values
    x_test = df[df['haim_id'].isin(test_idx)].drop(['haim_id','y'],axis=1).values

    y_train = df[df['haim_id'].isin(train_idx)]['y'].values
    y_test = df[df['haim_id'].isin(test_idx)]['y'].values

    return x_train, x_test, y_train, y_test

*Train/Val funcs (needs to be updated)*

In [None]:
def custom_output(emb, gemma):
    outputs = gemma(inputs_embeds=emb)
    noyes = [956, 3276]
    logits = outputs['logits']
    logits = logits[:,1,noyes]
    return logits

def output_to_label(logits):
    probs = torch.softmax(logits, dim=-1)
    predicted_token_id = torch.argmax(probs, dim=-1)
    return predicted_token_id

    
def train_epoch(model, gemma, optimizer, loss_fn, train_loader, device, word_embs):
    # Train:
    model.train()
    train_loss_batches, train_acc_batches = [], []
    for batch_index, (x, y) in enumerate(train_loader, 1):
        inputs, labels = x.to(device), y.to(device)

        optimizer.zero_grad()

        emb = model.forward(inputs)
        word_embs_extended = word_embs.repeat(len(inputs),1,1).detach()

        concatted = torch.cat((word_embs_extended, emb), dim=1).to(torch.float16)
        logits = custom_output(concatted, gemma).float()
        
        loss = loss_fn(logits, labels.long())
        loss.backward()
        optimizer.step()
        train_loss_batches.append(loss.item())

        hard_preds = output_to_label(logits)
        acc_batch_avg = (hard_preds == labels).float().mean().item()
        train_acc_batches.append(acc_batch_avg)

    return model, train_loss_batches, train_acc_batches

def validate(model, gemma, loss_fn, val_loader, device, word_embs):
    val_loss_cum = 0
    val_acc_cum = 0
    model.eval()
    with torch.no_grad():
        for batch_index, (x, y) in enumerate(val_loader, 1):
            inputs, labels = x.to(device), y.to(device)

            emb = model.forward(inputs)
            word_embs_extended = word_embs.repeat(len(inputs),1,1).detach()

            concatted = torch.cat((word_embs_extended, emb), dim=1).to(torch.float16)
            logits = custom_output(concatted, gemma)

            batch_loss = loss_fn(logits, labels.long())
            val_loss_cum += batch_loss.item()
            hard_preds = output_to_label(logits)
            acc_batch_avg = (hard_preds == labels).float().mean().item()
            val_acc_cum += acc_batch_avg
    return val_loss_cum/len(val_loader), val_acc_cum/len(val_loader)

*Training framework*

In [None]:
def training_loop(model, gemma, optimizer, loss_fn, train_loader, val_loader, num_epochs, word_embs):
    print("Starting training")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    train_losses, train_accs, val_losses, val_accs = [], [], [], []

    for epoch in range(1, num_epochs+1):
        model, train_loss, train_acc = train_epoch(model,
                                                   gemma,
                                                   optimizer,
                                                   loss_fn,
                                                   train_loader,
                                                   device,
                                                   word_embs)
        val_loss, val_acc = validate(model, gemma, loss_fn, val_loader, device, word_embs)
        print(f"Epoch {epoch}/{num_epochs}: "
              f"Train loss: {sum(train_loss)/len(train_loss):.3f}, "
              f"Train acc.: {sum(train_acc)/len(train_acc):.3f}, "
              f"Val. loss: {val_loss:.3f}, "
              f"Val. acc.: {val_acc:.3f}")
        train_losses.extend(train_loss)
        train_accs.extend(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
    return model, train_losses, train_accs, val_losses, val_accs

*Main*

In [None]:
input_text = "Given this input, is it more likely than not that the patient will die?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
word_embs = gemma.get_input_embeddings().weight[input_ids.input_ids].to("cuda")

In [31]:
batch_size = 8
x_train, x_val, y_train, y_val = data_split(df, pkl_list)
x_train_small, x_val_small, y_train_small, y_val_small = data_split(df.iloc[:500], pkl_list)
train_set = CustomDataset(x_train_small, y_train_small)
val_set = CustomDataset(x_val_small, y_val_small)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=5)
val_loader = DataLoader(val_set, batch_size=batch_size, num_workers=5)

w0 = 1 - sum(y_train_small == 0)/len(y_train_small)
w1 = 1 - sum(y_train_small == 1)/len(y_train_small)
weights = torch.tensor([w0, w1], dtype = torch.float).to("cuda")

model = ProjectionNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss(weight=weights)

num_epochs = 5

fine_tuned, train_losses, train_accs, val_losses, val_accs = training_loop(model, gemma, optimizer, loss_fn, train_loader, val_loader, num_epochs, word_embs)

torch.save(fine_tuned, 'finetuned.pth')

with open('train_losses.pkl', 'wb') as f1:
    pickle.dump(train_losses, f1)

with open('train_accs.pkl', 'wb') as f2:
    pickle.dump(train_accs, f2)

with open('val_losses.pkl', 'wb') as f3:
    pickle.dump(val_losses, f3)

with open('val_accs.pkl', 'wb') as f4:
    pickle.dump(val_accs, f4)



Starting training
Epoch 1/5: Train loss: 0.494, Train acc.: 0.990, Val. loss: 0.386, Val. acc.: 0.984
Epoch 2/5: Train loss: 0.494, Train acc.: 0.990, Val. loss: 0.386, Val. acc.: 0.984
Epoch 3/5: Train loss: 0.376, Train acc.: 0.990, Val. loss: 0.386, Val. acc.: 0.984
Epoch 4/5: Train loss: 0.494, Train acc.: 0.990, Val. loss: 0.386, Val. acc.: 0.984
Epoch 5/5: Train loss: 0.494, Train acc.: 0.990, Val. loss: 0.386, Val. acc.: 0.984


# Testing out gemma instruct on text generation

In [None]:
input_text = "no"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
print(input_ids.input_ids)
#print(model)
tmp = gemma.get_input_embeddings().weight[input_ids.input_ids]
tmp.to(device='cuda')

conc = torch.cat((tmp,projected), dim=1).to(torch.float16)

outputs = model(inputs_embeds=conc)
noyes = [1294,3553]
logits = outputs['logits']
print(logits)
logits = logits[:,1,noyes]

probs = torch.softmax(logits, dim=-1)
print(probs)
predicted_token_id = torch.argmax(probs, dim=-1)

if predicted_token_id.item() == 0:
    predicted_token_id[0] = 1294
else:
    predicted_token_id[0] = 3553

decoded_token = tokenizer.decode(predicted_token_id[0])
print(decoded_token)



In [None]:
train, _, _, _ = data_split(df, pkl_list)

In [None]:
print(len(np.where(np.isnan(y_train))[0]))

print(len(np.where(np.isnan(df['y']))[0]))

In [None]:
proj = ProjectionNN()

tmp12 = torch.tensor(train[0]).float()

projected = proj(tmp12).to(device='cuda').to(torch.float16)
print(projected.size())

In [None]:
input_text = "Is smoking good for you"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
test = gemma.get_input_embeddings().weight[input_ids.input_ids]
#print(input_ids)

outputs = gemma.generate(inputs_embeds=test, max_length = 150)

notoken = torch.tensor(956).to(device='cuda')
yestoken = torch.tensor(3276).to(device='cuda')

print(tokenizer.decode(yestoken))
print(tokenizer.decode(outputs[0]))