In [None]:
import sys
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)

import numpy as np
import pandas as pd
import torch.nn as nn
import torch
import pickle
import time
import os
import shutil
import shap
import seaborn as sns
from torch.utils.data import DataLoader
from torch.nn import functional as F
from tqdm.auto import tqdm
from torch.optim.lr_scheduler import LambdaLR, LinearLR, SequentialLR
from functools import partial
from sklearn.metrics import f1_score, accuracy_score
from sklearn.calibration import calibration_curve


from codes.models.metrics import calculate_roc_auc
from codes.models.data_form.DataForm import DataTransfo_1SNP, PatientList, Patient
from codes.models.Transformers.dic_model_versions import DIC_MODEL_VERSIONS
from codes.tests.TestsClass import TestSet, TrainTransformerModel, TrainModel
import matplotlib.pyplot as plt



In [None]:
# Reload the saved TrainModel instance for test run 242 and pull out the objects used below.
train_model =  TrainModel.load_instance_test(242)
model = train_model.model
# Later cells call .numpy() on model outputs, so the model is flagged as CPU here.
# NOTE(review): this only sets the attribute — confirm the weights themselves are on CPU.
model.device = 'cpu'
patient_list = train_model.patient_list
data = train_model.dataT
# Sentence length is read from the first patient — presumably all sentences are padded to this length; TODO confirm.
nb_max_distinct_disease = len(patient_list[0].diseases_sentence)
nb_max_distinct_diseases_tot = patient_list.get_nb_distinct_diseases_tot()

# Per-token frequency over the whole cohort: for each patient, add 1 at every index
# appearing in their sentence, then divide by the number of patients.
# NOTE(review): presumably diseases_sentence holds integer token ids; with NumPy fancy
# indexing an id repeated inside one sentence is still only counted once for that patient.
frequencies = np.zeros(nb_max_distinct_diseases_tot)
for patient in patient_list:
    frequencies[patient.diseases_sentence] +=1  # one hit per distinct token per patient
frequencies /= len(patient_list)

In [None]:
# Rebuild the train/test transformer datasets from the split stored with the run.
indices_train, indices_test = train_model.dataT.indices_train, train_model.dataT.indices_test
train_model.patient_list_transformer_train, train_model.patient_list_transformer_test = train_model.patient_list.get_transformer_data(indices_train.astype(int), indices_test.astype(int))
# Creation of torch DataLoaders.
# These loaders are only used for evaluation and inspection in this notebook, never for
# training. shuffle=False keeps the batches in dataset order, so per-sample outputs of
# model.evaluate(...) can be compared with arrays built directly from
# train_model.patient_list_transformer_train (e.g. the number of diseases per patient).
# NOTE(review): assumes model.evaluate returns predictions in loader order — confirm.
dataloader_train = DataLoader(train_model.patient_list_transformer_train, batch_size=train_model.batch_size, shuffle=False)
dataloader_test = DataLoader(train_model.patient_list_transformer_test, batch_size=train_model.batch_size, shuffle=False)



### Calibration plots

In [None]:
################################## Calibration plots ##############################################################
# Evaluate the model on the training loader; only the per-sample probabilities and
# labels are reused by the following cells.
f1, accuracy, auc_score, loss, proba_avg_zero, proba_avg_one, predicted_probas_list, true_labels_list = model.evaluate(dataloader_train)
# Column 1 is taken as the probability of the positive class.
# NOTE(review): confirm the column order returned by model.evaluate.
predicted_probs_ones = np.array(predicted_probas_list)[:, 1]
true_labels = np.array(true_labels_list)
# Distribution of the predicted positive-class probabilities.
plt.hist(predicted_probs_ones, bins=100)

In [None]:
# Reliability curve: observed fraction of positives vs mean predicted probability, over 80 bins.
prob_true, prob_pred = calibration_curve(true_labels_list, predicted_probs_ones, n_bins=80)
# NOTE(review): `auc` is computed here but not used anywhere in this cell.
auc = calculate_roc_auc(true_labels_list, predicted_probs_ones)
# Draw the calibration plot
plt.plot(prob_pred, prob_true, marker='o', linestyle='--', label='Calibration Plot')
# Diagonal reference line: a perfectly calibrated model lies on y = x.
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfectly Calibrated')
plt.xlabel('Mean Predicted Probability')
plt.ylabel('Fraction of Positives')
plt.title('Calibration Plot')
plt.legend()

### Analysis of the number of diseases

In [None]:
# Split the training dataset into its three per-patient columns and turn each into an array.
list_diseases, list_counts, list_labels = (
    np.array(column) for column in zip(*train_model.patient_list_transformer_train)
)

# Number of non-zero tokens in each patient's sentence.
nb_diseases = (list_diseases != 0).sum(axis=1)
# Hard labels from the positive-class probability, thresholded at 0.5.
predicted_labels = np.where(predicted_probs_ones > 0.5, 1, 0)

In [None]:
# Accuracy of the predictions, grouped by the number of diseases in the patient's sentence.
# `acc` was never initialised in the original cell, which raised NameError on a fresh kernel.
acc = []
for nb_disease in np.unique(nb_diseases):
    indices = nb_diseases == nb_disease
    true = true_labels[indices]
    pred = predicted_labels[indices]
    # Each group is non-empty because nb_disease comes from np.unique(nb_diseases).
    acc.append(np.sum(true==pred)/len(true))

### Analysis of the well-predicted patients

In [None]:
# Per-patient columns of the training dataset (tokens, counts, labels) as arrays.
list_diseases, list_counts, list_labels = map(
    np.array, zip(*train_model.patient_list_transformer_train)
)

In [None]:
indices_predicted_low = predicted_probs_ones<0.4

In [None]:
patients_selected = list_diseases[indices_predicted_low]

In [None]:
np.sum(indices_predicted_low)

In [None]:
# Per-token frequency among the selected (low predicted probability) patients.
n_selected = np.sum(indices_predicted_low)
frequencies_selected = np.zeros(nb_max_distinct_diseases_tot)
for sentence_tokens in patients_selected:
    # One hit per distinct token id of this patient.
    frequencies_selected[sentence_tokens] = frequencies_selected[sentence_tokens] + 1
frequencies_selected = frequencies_selected / n_selected

In [None]:
# Relative change in frequency between the whole cohort and the selected patients.
# Tokens that never occur in the cohort have frequency 0; a plain division would give
# 0/0 = nan for them, and np.argmax then returns the position of the first nan instead
# of the largest ratio. Those entries are left at 0.
frequencies_ratio = np.divide(
    np.abs(frequencies - frequencies_selected),
    frequencies,
    out=np.zeros_like(frequencies),
    where=frequencies != 0,
)
#diff_frequencies =np.max(np.concatenate([frequencies_ratio,  frequencies_ratio**-1], axis=0).reshape(2, len(frequencies_ratio)), axis=0)


In [None]:
#diff_frequencies = diff_frequencies[diff_frequencies!=np.inf]


In [None]:
np.argmax(frequencies_ratio)

In [None]:
# NOTE(review): 15 and 2330 are hard-coded — presumably the argmax index from the previous
# cell and the number of selected patients (np.sum(indices_predicted_low)); confirm and
# replace with the computed values so the cell survives a re-run on another model.
frequencies_ratio[15], frequencies[15], frequencies_selected[15]*2330

In [None]:
plt.plot(frequencies_ratio)

In [None]:
patients_selected

In [None]:
np.array(patient_list.patients_list).shape

In [None]:
# NOTE(review): indices_predicted_low is a boolean mask over the predictions made on the
# training loader, while patient_list.patients_list presumably holds every patient —
# confirm that the mask length and ordering match this array before trusting the selection.
# This also rebinds `patients_selected`, which earlier held token sentences.
patients_selected = np.array(patient_list.patients_list)[indices_predicted_low]

### Shap values

In [None]:
# Background batch for the SHAP explainer, plus a single sample to explain.
batch_sentence, batch_counts, batch_labels = next(iter(dataloader_test))
input_data = [batch_sentence.to(torch.float), batch_counts.to(torch.float)]
# Keep the batch dimension with a length-1 slice rather than `.view(1, 122)`, so the
# cell no longer depends on a hard-coded sequence length.
# Assumes each batch tensor is (batch, sequence), as the original view(1, 122) implies.
shap_input = [tensor[:1] for tensor in input_data]


In [None]:
model.shap=True

In [None]:
batch_counts

In [None]:
explainer = shap.DeepExplainer(model, input_data)


In [None]:
shap_values = explainer.shap_values(shap_input)


In [None]:
shap_values

### Clustering patients


### Attention visualization

In [None]:
diseases_batch, counts_batch, labels_batch = next(iter(dataloader_test))

In [None]:
logits, probas, x_out = model.forward_decomposed(diseases_batch, counts_batch)


In [None]:
indice = 9
torch.sum(model.padding_mask[indice][0]==1), torch.sum(diseases_batch[indice]!=0)

In [None]:
# Convert every attention tensor stored on the model into a NumPy array.
attention_probas_raw = model.list_attention_layers
attention_probas_list = [layer_probas.detach().numpy() for layer_probas in attention_probas_raw]


In [None]:
# Which sentence of the batch, which layer and which head to display.
indice_sentence = 1
indice_layer = 0
indice_head = 1

attention_probas = attention_probas_list[indice_layer][indice_sentence][indice_head]
# Boolean padding mask for the chosen sentence.
# NOTE(review): presumably True marks real (non-padding) positions — confirm against the model.
mask = model.padding_mask.detach().numpy()[indice_sentence].astype(bool)
# Number of True entries in the first row of the mask, used as the side of the square matrix below.
n_real = np.sum(mask[0])

# Keep only the masked-in entries and reshape them into an (n_real, n_real) matrix.
attention_probas_masked = attention_probas[mask].reshape(n_real, n_real)

sns.set(style="whitegrid")
plt.figure(figsize=(20, 16))
sns.heatmap(attention_probas_masked, cmap="YlGnBu", annot=False, fmt=".2f", cbar=True)

# Add axis labels
plt.xlabel("Token")
plt.ylabel("Token")
plt.title("Self-Attention Matrix")

# Show the plot
plt.show()

In [None]:
## definition of the attention score:
# For every token id, accumulate the column sums of each masked attention matrix over all
# training sentences, layers and heads (presumably the attention each token receives),
# together with the number of times the token was visited, so the score can be
# normalised per occurrence afterwards.
nb_distinct_diseases_tot = patient_list.get_nb_distinct_diseases_tot()
frequencies = np.zeros(nb_distinct_diseases_tot)
attention_score_diseases = np.zeros(nb_distinct_diseases_tot)
for batch_sentence, batch_counts, batch_labels in dataloader_train:
    # Run the model on the batch being iterated. The original passed `diseases_batch` /
    # `counts_batch`, a single test batch left over from an earlier cell, so every
    # iteration scored the same stale batch (and could mismatch the batch length).
    logits, probas, x_out = model.forward_decomposed(batch_sentence, batch_counts)
    attention_probas_raw = model.list_attention_layers
    attention_probas_list = [attention_probas.detach().numpy() for attention_probas in attention_probas_raw]
    # The padding mask belongs to the forward pass above; fetch it once per batch.
    padding_mask = model.padding_mask.detach().numpy()

    for indice_layer in range(train_model.n_layer):
        for indice_head in range(train_model.n_head):
            for indice_sentence in range(len(batch_sentence)):
                sentence = batch_sentence[indice_sentence]
                attention_probas = attention_probas_list[indice_layer][indice_sentence][indice_head]
                mask = padding_mask[indice_sentence].astype(bool)
                # Count of True entries in the first mask row.
                # NOTE(review): slicing sentence[:n_real] assumes the real tokens come first
                # and padding is at the end — confirm.
                n_real = np.sum(mask[0])
                sentence = sentence[:n_real]
                # Visits are counted once per (layer, head), matching the accumulation below.
                frequencies[sentence] +=1
                attention_probas_masked = attention_probas[mask].reshape(n_real, n_real)

                # Column sums of the masked attention matrix, added at the token ids.
                attention_score_diseases[sentence] += attention_probas_masked.sum(axis=0)


    
    

In [None]:
# Average attention score per visit. Tokens that were never visited (frequency 0) keep a
# score of 0 instead of going through a 0-division (RuntimeWarning, nan or inf).
attention_score_freq = np.divide(
    attention_score_diseases,
    frequencies,
    out=np.zeros_like(attention_score_diseases),
    where=frequencies != 0,
)
# Any nan already present in the accumulated scores is still mapped to 0, as before.
attention_score_freq[np.isnan(attention_score_freq)] = 0

In [None]:
frequencies

In [None]:
plt.plot(frequencies)

In [None]:
plt.plot(attention_score_freq, 'o')

In [None]:
np.argmax(attention_score_freq)

In [None]:
# Lookup tables stored with the dataset, and their value -> key inverses.
pheno_dicts = train_model.dataT.dicts['id']
name_dicts = train_model.dataT.dicts['name']
pheno_dicts_reverse = dict((value, key) for key, value in pheno_dicts.items())
name_dicts_reverse = dict((value, key) for key, value in name_dicts.items())

In [None]:
pheno_dicts_reverse[1373], name_dicts[pheno_dicts_reverse[1373]]

In [None]:
# Average attention score per visit (this cell repeats the computation of an earlier cell).
# Division by a zero frequency yields nan for 0/0, which the next line maps to 0.
attention_score_freq = attention_score_diseases / frequencies
attention_score_freq[np.isnan(attention_score_freq)] = 0

In [None]:
attention_score_freq

In [None]:
np.argmax(attention_score_freq)

In [None]:
plt.plot(attention_score_freq, 'o')

In [None]:
attention_score_diseases[0]

In [None]:
probas_weights = probas[:, :n_real]

In [None]:
probas_weights[12].sum()

In [None]:
# Heatmap of `probas_weights` (the first n_real columns of `probas`).
# NOTE(review): the title says "Self-Attention Matrix" but the data plotted is `probas` — confirm the label.
sns.set(style="whitegrid")
plt.figure(figsize=(20, 16))
sns.heatmap(probas_weights.detach().cpu(), cmap="YlGnBu", annot=False, fmt=".2f", cbar=True)

# Add axis labels
plt.xlabel("Token")
plt.ylabel("Token")
plt.title("Self-Attention Matrix")

# Show the plot
plt.show()

In [None]:
logits, probas, x_out = model.forward_decomposed(diseases_batch, counts_batch)


In [None]:
attention_probas = model.list_attention_layers[1][1][0]

In [None]:
attention_probas.sum()

In [None]:

model.list_attention_layers

In [None]:
train_model.dataT.indices_test, train_model.dataT.indices_train

In [None]:
len(train_model.dataT.indices_test) + len( train_model.dataT.indices_train)

In [None]:
f1, accuracy, auc_score, loss, proba_avg_zero, proba_avg_one, predicted_probas_list, true_labels_list = model.evaluate(dataloader_train)

In [None]:
predicted_probs_ones = np.array(predicted_probas_list)[:, 1]
true_labels = np.array(true_labels_list)

In [None]:
len(predicted_probs_ones)/20

In [None]:
plt.hist(predicted_probs_ones, bins=100)

In [None]:
# Reliability curve over 80 bins (same plot as in the "Calibration plots" section above).
prob_true, prob_pred = calibration_curve(true_labels_list, predicted_probs_ones, n_bins=80)
# NOTE(review): `auc` is computed here but not used anywhere in this cell.
auc = calculate_roc_auc(true_labels_list, predicted_probs_ones)
# Draw the calibration plot
plt.plot(prob_pred, prob_true, marker='o', linestyle='--', label='Calibration Plot')
# Diagonal reference line: a perfectly calibrated model lies on y = x.
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfectly Calibrated')
plt.xlabel('Mean Predicted Probability')
plt.ylabel('Fraction of Positives')
plt.title('Calibration Plot')
plt.legend()

In [None]:
hist, edges = np.histogram(predicted_probs_ones, bins=5)


In [None]:
edges

In [None]:
# AUC computed separately inside each probability bin returned by np.histogram.
list_auc = []
last_bin = len(edges) - 2
for i in range(len(edges) - 1):
    # np.histogram bins are half-open [low, high), except the last one, which also
    # contains its right edge. The original strict `<` silently dropped the samples
    # equal to the maximum predicted probability.
    in_bin = predicted_probs_ones >= edges[i]
    if i == last_bin:
        in_bin &= predicted_probs_ones <= edges[i+1]
    else:
        in_bin &= predicted_probs_ones < edges[i+1]
    indices_bin = np.flatnonzero(in_bin)
    bin_probas = predicted_probs_ones[indices_bin]
    bin_labels = true_labels[indices_bin]
    # NOTE(review): a bin holding a single class may make calculate_roc_auc fail or
    # return an undefined value — confirm how the helper handles it.
    bin_auc = calculate_roc_auc(bin_labels, bin_probas)
    list_auc.append(bin_auc)


In [None]:
accuracy

In [None]:
list_auc

In [None]:
# Inspect the first probability bin by hand (same selection as the per-bin loop, for i = 0).
i=0
indices_bin = np.flatnonzero((predicted_probs_ones >= edges[i]) & (predicted_probs_ones < edges[i+1]))
bin_probas = predicted_probs_ones[indices_bin]
bin_labels = true_labels[indices_bin]
bin_auc = calculate_roc_auc(bin_labels, bin_probas)
# Not appended to list_auc: the per-bin loop already stored this bin, and appending again
# left list_auc with one more entry than there are bins (and one more on every re-run).
bin_auc


In [None]:
edges

In [None]:
predicted_labels = (bin_probas>0.5).astype(int)

In [None]:
bin_probas

In [None]:
predicted_labels, bin_labels

In [None]:
list_auc

In [None]:

# For the test patients only: `counts` holds, per token id, how many patients carry it,
# `res` how many of those were predicted correctly, and `counts_ok` the number of
# correctly predicted patients.
res = np.zeros(patient_list.get_nb_distinct_diseases_tot())
counts_ok = 0
counts = np.zeros(patient_list.get_nb_distinct_diseases_tot())
# Build the membership set once: `k in ndarray` rescans the whole index array for every
# patient. Indices are cast to int, as is done when the datasets are built.
indices_test_set = set(np.asarray(train_model.dataT.indices_test).astype(int).tolist())
for k, patient in enumerate(patient_list):
    if k not in indices_test_set:
        continue
    # Shape the patient's sentence and counts as a batch of one.
    diseases_sentence = torch.tensor(patient.diseases_sentence).view(1, nb_max_distinct_disease)
    counts_sentence = torch.tensor(patient.counts_sentence).view(1, nb_max_distinct_disease)
    label_pred_patient = model.predict(diseases_sentence, counts_sentence)
    if label_pred_patient[0]==patient.SNP_label:
        counts_ok += 1
        res[patient.diseases_sentence] = res[patient.diseases_sentence] + 1
    counts[patient.diseases_sentence] = counts[patient.diseases_sentence] + 1

In [None]:
# forward_decomposed is unpacked into three values everywhere else in this notebook
# (logits, probas, x_out); unpacking four here raised ValueError.
# NOTE(review): confirm the return signature against the loaded model version.
logits, probas, x_out = model.forward_decomposed(diseases_sentence, counts_sentence)

In [None]:
plt.plot(probas)

In [None]:
def get_risk_pheno(data, labels, pheno_nb):
    """Risk ratio of label 1 between patients with and without one phenotype.

    Args:
        data: 2-D array indexed as ``data[:, pheno_nb]``; entries equal to 1 mark
            patients with the phenotype, entries equal to 0 patients without it.
        labels: 1-D array of labels aligned with the rows of ``data``.
        pheno_nb: column index of the phenotype.

    Returns:
        P(label == 1 | phenotype == 1) / P(label == 1 | phenotype == 0).
        No guard is applied: an empty group or a zero denominator follows
        NumPy division semantics (nan / inf with a RuntimeWarning).
    """
    pheno_column = data[:, pheno_nb]
    carrier_labels = labels[pheno_column == 1]
    non_carrier_labels = labels[pheno_column == 0]
    risk_carriers = np.sum(carrier_labels == 1) / len(carrier_labels)
    risk_non_carriers = np.sum(non_carrier_labels == 1) / len(non_carrier_labels)
    return risk_carriers / risk_non_carriers
def get_pred_naive(data, labels, pheno_nb):
    """Majority label among the patients that have one phenotype.

    Args:
        data: 2-D array indexed as ``data[:, pheno_nb]``; entries equal to 1 mark
            patients with the phenotype.
        labels: 1-D array of labels aligned with the rows of ``data``.
        pheno_nb: column index of the phenotype.

    Returns:
        1 if strictly more of those patients have label 1 than label 0, else 0
        (ties and an empty group give 0).
    """
    carrier_labels = labels[data[:, pheno_nb] == 1]
    positives = np.count_nonzero(carrier_labels == 1)
    negatives = np.count_nonzero(carrier_labels == 0)
    return int(positives > negatives)
# Pre-bind the dataset and labels so both helpers can be mapped over phenotype indices.
# This rebinds the function names to the partials.
# NOTE(review): `data` is train_model.dataT at this point while the helpers index it as
# data[:, pheno_nb], and `labels` is not defined before this cell on a fresh kernel —
# confirm which arrays are intended here.
get_risk_pheno = partial(get_risk_pheno, data, labels)
get_pred_naive = partial(get_pred_naive, data, labels)

In [None]:
data

In [None]:
odds_ratios = list(map(get_risk_pheno, phenos))
labels_pred_naive = list(map(get_pred_naive, phenos))

In [None]:
# NOTE(review): earlier cells predict the positive class with `> 0.5` on the class-1
# probability; here the test is `< 0.5` — confirm which probability `probas` holds.
preds = (np.array(probas) < 0.5).astype(int)

In [None]:
# Fraction of entries where `preds` and the naive per-phenotype prediction agree
# (for 0/1 values the squared difference is 1 exactly where they differ).
# NOTE(review): 1717 is hard-coded — presumably the number of compared entries; confirm
# and replace with the computed length.
1 - np.sum((preds-labels_pred_naive)**2)/1717

In [None]:
data.shape

In [None]:
len(patient_list)

In [None]:
######### correlations with number zeros

In [None]:
# For every patient, record the predicted label and the number of zero tokens in the sentence.
labels_res = []
nb_zeros_res = []
for patient in patient_list:
    # Shape the patient's sentence and counts as a batch of one.
    diseases_sentence = torch.tensor(patient.diseases_sentence).view(1, nb_max_distinct_disease)
    counts_sentence = torch.tensor(patient.counts_sentence).view(1, nb_max_distinct_disease)
    label_pred_patient = model.predict(diseases_sentence, counts_sentence)
    # Zero tokens are presumably padding — TODO confirm.
    nb_zeros = torch.sum(diseases_sentence==0)
    labels_res.append(label_pred_patient[0].item())
    nb_zeros_res.append(nb_zeros.item())


In [None]:
nb_zeros_res = np.array(nb_zeros_res)
labels_res = np.array(labels_res)

In [None]:
nb_zeros_res, labels_res

In [None]:
zeros = np.unique(nb_zeros_res)
labels = [np.mean(labels_res[nb_zeros_res == nb_zero]) for nb_zero in zeros ]

In [None]:
plt.plot(zeros, labels, 'o')

In [None]:
################## Calibration plot ################
# TODO: this per-patient loop was left unfinished. The original `for` header had no
# body, which made the cell a SyntaxError; `pass` keeps the notebook runnable end to end.
for patient in patient_list:
    pass

In [None]:
# Number of patients whose sentence starts with token 0.
count = sum(1 for patient in patient_list if patient.diseases_sentence[0] == 0)

In [None]:
# Rolled-up OMOP phenotype table. The location is built from the project root `path`
# defined in the first cell, so the root is hard-coded in one place only; the resulting
# string is the same as before.
file = os.path.join(path, 'codes/Data_Files/Pheno/Paul/ukbb_omop_rolled_up_depth_4_closest_ancestor.csv')
df_paul = pd.read_csv(file)

In [None]:
eid_list = df_paul.eid
eid = eid_list[0]

In [None]:
grouped = df_paul.groupby('eid')

In [None]:
df = grouped.get_group(eid)

In [None]:
# Concept ids and their occurrence counts for the selected participant.
unique_codes = [*df['concept_id'].values]
occurrences = [*df['condition_occurrence_count'].values]

# Independent copies used as the patient's sentence and counts.
disease_sentence = list(unique_codes)
counts_sentence = list(occurrences)

In [None]:
unique_codes