# Evaluation of projection module

In [None]:
import pickle
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from focal_loss.focal_loss import FocalLoss
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import metrics
import os
from mlp_utils import *
# Turn off HuggingFace tokenizers' internal parallelism for this process.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# CSV with the precomputed embedding features (vd_* columns) and outcome
# columns read by the cells below.
fname = "/mnt/mimic/data/HAIM/mimic_extras/embeddings.csv"

# Token ids recorded for the yes/no answers (presumably in the gemma-2b
# vocabulary -- confirm against the tokenizer):
#   YES-TOKEN: 3276
#   NO-TOKEN: 956

### Load data and make predictions

In [None]:
# Load gemma-2b in 4-bit NF4 with nested ("double") quantization; compute is
# carried out in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

gemma_checkpoint = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(gemma_checkpoint)
gemma = AutoModelForCausalLM.from_pretrained(
    gemma_checkpoint,
    quantization_config=quantization_config,
    device_map="auto",
)

In [None]:
# Load the fine-tuned projection model. The checkpoint is a whole pickled
# nn.Module (the result is used directly with .to() and .eval()), not a
# state_dict, so it must be unpickled with weights_only=False: since
# PyTorch 2.6, torch.load defaults to weights_only=True and refuses to
# rebuild arbitrary modules.
# SECURITY: full unpickling can execute arbitrary code -- only load
# checkpoints produced by this project.
model = torch.load('results/results_small_bce_vd/finetuned.pth', weights_only=False).to('cuda')
# Switch layers such as dropout to inference behaviour.
model.eval()

In [None]:
# Text prompt for the yes/no mortality question, tokenized and looked up
# directly in gemma's input-embedding matrix.
input_text = "Based on the following image, output yes if the patient is likely to die and no otherwise."
# BatchEncoding on the GPU; the token ids are in input_ids.input_ids.
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# The lookup is for inference only. Indexing a trainable parameter outside
# no_grad would give word_embs autograd history, so anything later built from
# it (e.g. concatenation with the image embedding) would record a graph.
with torch.no_grad():
    word_embs = gemma.get_input_embeddings().weight[input_ids.input_ids].to("cuda")

In [None]:
# Build the test set with mlp_utils.data_split.
# NOTE(review): x_test and labels produced here are overwritten by the
# DataSplit-based cell that follows; keep only one of the two approaches.
df = pd.read_csv(fname)

# Outcome masks by img_length_of_stay vs. 48 (units presumably hours --
# confirm) and death_status.
condition_death_small48 = (df['img_length_of_stay'] < 48) & (df['death_status'] == 1)
# NOTE(review): the next two masks are not used in this cell.
condition_alive_big48 = (df['img_length_of_stay'] >= 48) & (df['death_status'] == 0)
condition_death_big48 = (df['img_length_of_stay'] >= 48) & (df['death_status'] == 1)

# Label 1 = died with img_length_of_stay < 48; every other row (including
# deaths with a longer stay) is 0. Vectorized replacement for the former
# index loop; yields the same list of Python ints.
y = condition_death_small48.astype(int).tolist()

vd_cols = df.filter(regex='^vd_')
y_col = pd.Series(y, name='y')
haim_col = df[['haim_id']]
df = pd.concat([haim_col, vd_cols, y_col], axis=1)

pkl_list = df['haim_id'].unique().tolist()

_, _, x_test, _, _ , labels = data_split(df, pkl_list)

In [None]:
# Rebuild the evaluation set with the DataSplit helper (mlp_utils): split the
# data for the 'mortality' task, then take the test portion of the vd_
# feature columns together with the matching test labels.
df = pd.read_csv(fname)

Data = DataSplit(df)
Data.split_data('mortality')

vd_splits = Data.get_type('vd_')
X, V, T = vd_splits  # presumably train / validation / test -- only T is used below

labels = Data.y_test.tolist()
x_test = T.values.tolist()

In [None]:
# Evaluate on a random subset. Features and labels must be subsampled with
# the SAME indices. If select_random_subset draws from NumPy's global RNG
# (not visible here -- verify in mlp_utils), seeding only once makes the
# second call draw from an advanced RNG state. The labels would then no
# longer match their features. Re-seeding before each call pins identical
# draws for equal-length inputs; if the helper seeds internally, this is
# harmless.
if len(x_test) != len(labels):
    raise ValueError("x_test and labels must have the same length before subsampling")

np.random.seed(42)
x_test = select_random_subset(x_test)
np.random.seed(42)
labels = select_random_subset(labels)

In [None]:
# Run each test embedding through the projection model and gemma, collecting
# one predicted label per item.
preds = []

# Pure inference: no_grad stops autograd from recording a graph through the
# projection model and gemma for every item, saving GPU memory and time.
with torch.no_grad():
    for i, item in enumerate(x_test):
        # Single example -> batch of one on the GPU.
        inputs = torch.tensor(item, dtype=torch.float32).unsqueeze(0).to('cuda')
        # Call the module itself (not .forward) so any registered hooks run.
        emb = model(inputs).to(torch.float16)
        outputs = custom_output(emb, gemma)
        preds.append(output_to_label(outputs))
        if i % 100 == 0:
            print('item num ', i)

### Training and validation curves (loss and accuracy)

In [None]:
folder = 'results/results_small_bce_vd'

# Pickled training/validation histories for this run, loaded in the order
# train losses, train accuracies, validation losses, validation accuracies.
curves = {}
for curve_name in ("train_losses", "train_accs", "val_losses", "val_accs"):
    with open(f"{folder}/{curve_name}.pkl", "rb") as handle:
        curves[curve_name] = pickle.load(handle)

train_losses = curves["train_losses"]
train_accs = curves["train_accs"]
val_losses = curves["val_losses"]
val_accs = curves["val_accs"]

In [None]:
# Plot per-batch training loss against validation loss. Validation has fewer
# points than training, so each validation value is held constant over the
# batches it covers to share the same x-axis.
x = list(range(len(train_losses)))

# Training batches per validation point (at least 1 when validation is the
# longer sequence, in which case its first len(train_losses) values are used).
repetitions = max(len(train_losses) // len(val_losses), 1)
# Stretch val_losses to len(train_losses). Batches left over after the last
# full group keep the LAST validation value; previously they wrapped around
# to the first values, drawing a spurious jump back to the initial loss.
extended_val_loss = [
    val_losses[min(i // repetitions, len(val_losses) - 1)]
    for i in range(len(train_losses))
]

plt.plot(x, train_losses, color='r', label='train_losses')
plt.plot(x, extended_val_loss, color='g', label='val_losses')

plt.title('Loss for training and validation')
plt.ylabel('Loss')
plt.xlabel('batch')

plt.legend()


In [None]:
# Plot per-batch training accuracy against validation accuracy, stretching
# the sparser validation series onto the training x-axis.
x = list(range(len(train_accs)))

# Training batches per validation point (at least 1 when validation is the
# longer sequence, in which case its first len(train_accs) values are used).
repetitions = max(len(train_accs) // len(val_accs), 1)
# Leftover batches after the last full group keep the LAST validation value
# instead of wrapping around to the first ones.
extended_val_accs = [
    val_accs[min(i // repetitions, len(val_accs) - 1)]
    for i in range(len(train_accs))
]

# Legend labels previously read 'train_losses'/'val_losses' on this accuracy plot.
plt.plot(x, train_accs, color='r', label='train_accs')
plt.plot(x, extended_val_accs, color='g', label='val_accs')

plt.title('Accuracy for training and validation')
plt.ylabel('Accuracy')
plt.xlabel('batch')

plt.legend()

### Confusion Matrix

In [None]:
# Convert the predictions to a flat NumPy label array. This accepts either
# the list of tensors from the prediction loop or the array produced by an
# earlier run of this cell. Previously the cell rebound preds to an ndarray
# and then crashed on .cpu() when re-executed.
preds = np.asarray(
    [p.cpu().numpy() if torch.is_tensor(p) else p for p in preds]
).ravel()

conf_matrix = metrics.confusion_matrix(labels, preds)
disp = metrics.ConfusionMatrixDisplay(confusion_matrix=conf_matrix)
disp.plot()

### F1-score and AUC

In [None]:
# F1 of the positive class and ROC AUC of the predictions.
# NOTE(review): preds are discrete labels, not scores, so roc_auc_score here
# measures a single operating point (it equals balanced accuracy). Feed
# predicted probabilities to obtain a threshold-sweeping AUC.
f1 = metrics.f1_score(labels, preds)
auc = metrics.roc_auc_score(labels, preds)
for caption, value in (('f1: ', f1), ('auc: ', auc)):
    print(caption, value)

### Precision, Recall and Accuracy

In [None]:
# Precision and recall for the positive class (sklearn's default
# pos_label=1), plus overall accuracy, of the predicted labels.
precision, recall, accuracy = (
    scorer(labels, preds)
    for scorer in (metrics.precision_score, metrics.recall_score, metrics.accuracy_score)
)
for caption, value in (('precision: ', precision), ('recall: ', recall), ('accuracy: ', accuracy)):
    print(caption, value)

### ROC curve (TPR vs. FPR across thresholds)

In [None]:
# ROC curve and its area for the predictions.
# NOTE(review): preds are hard labels, so roc_curve yields only one real
# operating point and the "curve" is two straight segments. Pass predicted
# probabilities for a meaningful ROC curve.
fpr, tpr, thresholds = metrics.roc_curve(labels, preds)
roc_auc = metrics.auc(fpr, tpr)
# Named roc_display rather than display so IPython's built-in display()
# is not shadowed for later cells.
roc_display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc)

roc_display.plot()