In [None]:
from torch.utils.data import DataLoader
from torchviz import make_dot
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
##### Functional version with an optional padding mask and dropout; see Transformer_V1.ipynb for an example
import sys
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)


### imports
import torch
import torch.nn as nn
import numpy as np
import time
from torch.nn import functional as F
from sklearn.metrics import f1_score, accuracy_score

from codes.models.Transformers.Embedding import EmbeddingPheno
from codes.models.metrics import calculate_roc_auc, calculate_classification_report, calculate_loss, get_proba

### Transformer's instance
# B, S, E, H, HN, MH = Batch_len, Sentence_len, Embedding_len, Head_size, Head number, MultiHead size.

def custom_softmax(input_tensor, dim=-1):
    """Softmax along ``dim`` that yields 0 instead of NaN for all-``-inf`` slices.

    ``F.softmax`` returns NaN for a slice whose entries are all ``-inf`` (e.g. a
    fully padded row); such slices are set to zero here.

    Args:
        input_tensor: tensor of logits.
        dim: dimension along which the softmax is taken (any valid dim, negative
            values included).

    Returns:
        A new tensor with the same shape as ``input_tensor``.
    """
    softmax_output = F.softmax(input_tensor, dim=dim)
    # keepdim=True keeps the reduced axis so the mask broadcasts against the
    # output for ANY dim; boolean indexing with the reduced mask only lined up
    # when dim was the last axis.
    all_neg_inf = torch.all(input_tensor == float('-inf'), dim=dim, keepdim=True)
    return softmax_output.masked_fill(all_neg_inf, 0.0)

class TransformerGeneModel_V2(nn.Module):
    """Transformer classifier over a patient's sequence ("sentence") of disease tokens.

    Pipeline: disease embeddings (+ count embeddings when ``pheno_method == 'Paul'``)
    -> optional linear projection to ``instance_size`` -> ``n_layer`` attention
    blocks -> final LayerNorm -> per-token class logits, pooled into one prediction
    per patient with a learned softmax weighting over tokens. When ``mask_padding``
    is True, tokens equal to ``padding_token`` are masked in attention and pooling.

    Args:
        pheno_method: embedding method name; only ``'Paul'`` changes behavior here
            (count embeddings are added to the disease embeddings).
        Embedding: embedding container exposing ``Embedding_size``,
            ``distinct_diseases_embeddings``, ``counts_embeddings`` (read for 'Paul')
            and ``metadata`` (read by ``write_embedding``).
        Head_size: total attention width, split across ``n_head`` heads.
        binary_classes: 2 output classes if True, else 3.
        n_head: number of attention heads per block.
        n_layer: number of attention blocks.
        mask_padding: build and apply padding masks from ``padding_token``.
        padding_token: token id treated as padding.
        p_dropout: dropout probability inside the blocks.
        device: device the inputs and masks are moved to.
        loss_version, gamma, alpha: forwarded to ``calculate_loss``.
        instance_size: model width after the projection; replaced by
            ``Embedding_size`` when ``proj_embed`` is False.
        proj_embed: linearly project embeddings to ``instance_size``.
    """

    def __init__(self, pheno_method, Embedding, Head_size, binary_classes, n_head, n_layer, mask_padding=False, padding_token=None, p_dropout=0, device='cpu', loss_version='cross_entropy', gamma=2, alpha=1, instance_size=None, proj_embed=True):
        super().__init__()

        self.Embedding_size = Embedding.Embedding_size
        self.binary_classes = binary_classes
        self.proj_embed = proj_embed
        # Without a projection the blocks work directly at the embedding width.
        self.instance_size = instance_size if proj_embed else self.Embedding_size
        if self.proj_embed:
            self.projection_embed = nn.Linear(self.Embedding_size, self.instance_size)
        self.classes_nb = 2 if self.binary_classes else 3
        # Submodules are created in the same order as before so that seeded
        # initialisation and parameter ordering are unchanged.
        self.blocks = PadMaskSequential(*[Block(self.instance_size, n_head=n_head, Head_size=Head_size, p_dropout=p_dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(self.instance_size)  # final layer norm
        self.lm_head_logits = nn.Linear(self.instance_size, self.classes_nb)  # per-token class logits
        self.lm_head_proba = nn.Linear(self.instance_size, 1)  # per-token pooling score
        self.Embedding = Embedding
        self.mask_padding = mask_padding
        self.padding_token = padding_token
        self.padding_mask = None
        self.padding_mask_probas = None
        self.list_attention_layers = []
        self.device = device
        self.pheno_method = pheno_method

        self.loss_version = loss_version
        self.gamma = gamma
        self.alpha = alpha
        # When True, forward() returns only the class-0 logits (presumably for SHAP explainers).
        self.shap = False

        self.diseases_embedding_table = Embedding.distinct_diseases_embeddings
        if self.pheno_method == 'Paul':
            self.counts_embedding_table = Embedding.counts_embeddings

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise every nn.Linear with N(0, 0.02) weights and zero bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def create_padding_mask(self, diseases_sentence):
        """Build the padding masks for a (B, S) batch of token ids.

        Returns:
            padding_mask_mat: int32 tensor (B, S, S); entry [b, i, j] is 1 when both
                tokens i and j of sentence b are real tokens and 0 when either one is
                padding (0 means masked).
            padding_mask_probas: float tensor (B, S); 0 for real tokens and -inf for
                padding, added to the token-pooling scores.
            Both are moved to ``self.device``.
        """
        is_padding = diseases_sentence == self.padding_token  # (B, S) bool
        not_padding = (~is_padding).to(torch.int)
        # Outer product: a (query, key) pair is kept only if neither token is padding.
        padding_mask_mat = not_padding.unsqueeze(2) * not_padding.unsqueeze(1)
        # Allocate on the input's device: the previous version built this on CPU
        # and indexed it with indices living on the input's device (fails on GPU).
        padding_mask_probas = torch.zeros(is_padding.shape, device=diseases_sentence.device).masked_fill(is_padding, float('-inf'))
        return padding_mask_mat.to(self.device), padding_mask_probas.to(self.device)

    def set_padding_mask_transformer(self, padding_mask, padding_mask_probas):
        """Store the masks produced by ``create_padding_mask``."""
        self.padding_mask = padding_mask
        self.padding_mask_probas = padding_mask_probas

    def _embed(self, diseases_sentence, counts_diseases):
        """Cast/move inputs, refresh padding masks and return the block input (B, S, instance_size)."""
        diseases_sentence = diseases_sentence.to(torch.long).to(self.device)
        counts_diseases = counts_diseases.to(torch.long).to(self.device)

        if self.mask_padding:
            padding_mask, padding_mask_probas = self.create_padding_mask(diseases_sentence)
            self.set_padding_mask_transformer(padding_mask, padding_mask_probas)
            self.blocks.set_padding_mask_sequential(self.padding_mask)

        x = self.diseases_embedding_table(diseases_sentence)  # (B, S, E)
        if self.pheno_method == 'Paul':
            x = x + self.counts_embedding_table(counts_diseases)  # (B, S, E)
        if self.proj_embed:
            x = self.projection_embed(x)
        return x

    def _pool_token_logits(self, x_out, Batch_len, Sentence_len):
        """Pool per-token logits with a softmax over tokens.

        Returns:
            pooled logits (B, classes_nb, 1) and token pooling weights (B, S).
        """
        token_logits = self.lm_head_logits(x_out)  # (B, S, classes_nb)
        weights_logits = self.lm_head_proba(x_out).view(Batch_len, Sentence_len)
        if self.mask_padding:
            # -inf on padding tokens gives them zero pooling weight.
            weights_logits = weights_logits + self.padding_mask_probas
        probas = F.softmax(weights_logits, dim=-1)  # (B, S)
        pooled = token_logits.transpose(1, 2) @ probas.view(Batch_len, Sentence_len, 1)
        return pooled, probas

    def forward(self, diseases_sentence, counts_diseases, targets=None):
        """Classify a batch of patients.

        Args:
            diseases_sentence: (B, S) disease token ids.
            counts_diseases: (B, S) count ids (only used when pheno_method == 'Paul').
            targets: optional (B,) class indices for the loss.

        Returns:
            (logits (B, classes_nb), loss or None); when ``self.shap`` is True,
            only the class-0 logits as a (B, 1) tensor.
        """
        Batch_len, Sentence_len = diseases_sentence.shape
        if targets is not None:
            targets = targets.to(self.device)

        x = self._embed(diseases_sentence, counts_diseases)
        x = self.blocks(x)
        x = self.ln_f(x)
        pooled, _ = self._pool_token_logits(x, Batch_len, Sentence_len)
        logits = pooled.view(Batch_len, self.classes_nb)  # weighted average of token logits
        loss = calculate_loss(logits, targets, self.loss_version, self.gamma, self.alpha)

        if self.shap:
            return logits[:, 0].view(Batch_len, 1)

        return logits, loss

    def forward_decomposed(self, diseases_sentence, diseases_count):
        """Same computation as ``forward`` (without loss), exposing intermediate values.

        Also stores one attention-probability tensor per block in
        ``self.list_attention_layers``.

        Returns:
            logits: (B, classes_nb, 1) pooled logits.
            probas: (B, S) token pooling weights.
            x_out: (B, S, instance_size) features after the final LayerNorm.
        """
        self.list_attention_layers = []
        Batch_len, Sentence_len = diseases_sentence.shape
        x = self._embed(diseases_sentence, diseases_count)
        x = self.blocks.forward_decompose(x, self.list_attention_layers)
        x_out = self.ln_f(x)
        # Pooling scores now use the normalized features, as forward() does; they
        # were previously computed from the pre-LayerNorm activations.
        logits, probas = self._pool_token_logits(x_out, Batch_len, Sentence_len)
        return logits, probas, x_out

    def predict(self, diseases_sentence, diseases_count):
        """Return the argmax class per patient, shape (B,)."""
        logits, _ = self(diseases_sentence, diseases_count)  # (B, classes_nb)
        return torch.argmax(logits, dim=1)

    def predict_proba(self, diseases_sentence, diseases_count):
        """Return class probabilities per patient, shape (B, classes_nb)."""
        logits, _ = self(diseases_sentence, diseases_count)
        return F.softmax(logits, dim=1)

    def evaluate(self, dataloader_test):
        """Evaluate on a dataloader yielding (sentences, counts, labels) batches.

        Returns:
            (f1_macro, accuracy, roc_auc, mean_batch_loss, mean P(class 0 | label 0),
            mean P(class 1 | label 1), predicted_probas_list, true_labels_list).
            The module's train/eval mode is restored afterwards.
        """
        print('beginning inference evaluation')
        start_time_inference = time.time()
        predicted_labels_list = []
        predicted_probas_list = []
        true_labels_list = []

        total_loss = 0.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            for batch_sentences, batch_counts, batch_labels in dataloader_test:
                logits, loss = self(batch_sentences, batch_counts, batch_labels)
                total_loss += loss.item()
                # Derive labels/probabilities from these logits instead of running
                # the model twice more through predict() and predict_proba().
                predicted_labels_list.extend(torch.argmax(logits, dim=1).cpu().numpy())
                predicted_probas_list.extend(F.softmax(logits, dim=1).cpu().numpy())
                true_labels_list.extend(batch_labels.cpu().numpy())
        f1 = f1_score(true_labels_list, predicted_labels_list, average='macro')
        accuracy = accuracy_score(true_labels_list, predicted_labels_list)
        auc_score = calculate_roc_auc(true_labels_list, np.array(predicted_probas_list)[:, 1], return_nan=True)
        proba_avg_zero, proba_avg_one = get_proba(true_labels_list, predicted_probas_list)

        self.train(was_training)
        print(f'end inference evaluation in {time.time() - start_time_inference}s')
        return f1, accuracy, auc_score, total_loss/len(dataloader_test), proba_avg_zero, proba_avg_one, predicted_probas_list, true_labels_list

    def write_embedding(self, writer):
        """Log the (projected, if enabled) disease embedding matrix to ``writer``."""
        if self.proj_embed:
            embedding_tensor = self.projection_embed(self.diseases_embedding_table.weight).detach().cpu().numpy()
        else:
            embedding_tensor = self.diseases_embedding_table.weight.detach().cpu().numpy()

        writer.add_embedding(embedding_tensor, metadata=self.Embedding.metadata, metadata_header=["Name","Label"])


class PadMaskSequential(nn.Sequential):
    """``nn.Sequential`` that pushes the current padding mask into every child block.

    Each child must provide ``set_padding_mask_block(mask)``; ``forward_decompose``
    additionally requires ``forward_decompose(x, list_attention_layers)``.
    """

    def __init__(self, *args):
        super().__init__(*args)
        self.padding_mask = None

    def set_padding_mask_sequential(self, padding_mask):
        """Remember the mask handed to each child on the next pass."""
        self.padding_mask = padding_mask

    def _children_with_mask(self):
        # Refresh each child's mask right before the caller runs that child.
        for block in self:
            block.set_padding_mask_block(self.padding_mask)
            yield block

    def forward(self, x):
        out = x
        for block in self._children_with_mask():
            out = block(out)
        return out

    def forward_decompose(self, x, list_attention_layers):
        out = x
        for block in self._children_with_mask():
            out = block.forward_decompose(out, list_attention_layers)
        return out
    
class Block(nn.Module):
    """Post-norm transformer block.

    x -> LayerNorm(x + SelfAttention(x)) -> LayerNorm(x + FeedForward(x)).
    The padding mask set on the block is forwarded to the attention layer.
    """

    def __init__(self, instance_size, n_head, Head_size, p_dropout):
        super().__init__()
        self.sa = MultiHeadSelfAttention(n_head, Head_size, instance_size, p_dropout)
        self.ffwd = FeedForward(instance_size, p_dropout)
        self.ln1 = nn.LayerNorm(instance_size)
        self.ln2 = nn.LayerNorm(instance_size)
        self.padding_mask = None

    def set_padding_mask_block(self, padding_mask):
        """Store the (B, S, S) padding mask used by the next forward pass."""
        self.padding_mask = padding_mask

    def forward(self, x):
        self.sa.set_padding_mask_sa(self.padding_mask)
        x = self.ln1(x + self.sa(x))
        x = self.ln2(x + self.ffwd(x))
        return x

    def forward_decompose(self, x, list_attention_layers=None):
        """Same computation as ``forward`` that also records attention probabilities.

        Args:
            x: block input.
            list_attention_layers: optional list; the attention probabilities of
                this block are appended to it.

        Note: the attention layer's ``forward_decompose`` is used, which (per its
        own definition) may skip the residual dropout applied in ``forward``; the
        two paths match when dropout is inactive.
        """
        self.sa.set_padding_mask_sa(self.padding_mask)
        out_sa, attention_probas = self.sa.forward_decompose(x)
        if list_attention_layers is not None:
            list_attention_layers.append(attention_probas)
        # ln1 was missing here, so the decomposed pass diverged from forward().
        x = self.ln1(x + out_sa)
        x = self.ln2(x + self.ffwd(x))
        return x

class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with an optional pairwise padding mask.

    The padding mask (set via ``set_padding_mask_sa``) is a (B, S, S) 0/1 tensor
    in which 0 marks a (query, key) pair involving a padding token. Masked keys
    receive ~zero attention and rows of padded queries are zeroed entirely.

    Args:
        n_head: number of heads; must divide ``Head_size``.
        Head_size: total width of q/k/v across heads.
        instance_size: input/output feature width.
        p_dropout: dropout on the attention probabilities and on the output.

    Raises:
        ValueError: if ``Head_size`` is not divisible by ``n_head``.
    """

    # Additive bias for masked scores. A large finite value (not -inf) keeps fully
    # padded rows NaN-free in the softmax; those rows are zeroed afterwards.
    _MASK_BIAS = -1e10

    def __init__(self, n_head, Head_size, instance_size, p_dropout):
        super().__init__()
        if Head_size % n_head != 0:
            raise ValueError(f'Head_size ({Head_size}) must be divisible by n_head ({n_head})')
        # One Linear computes q, k and v together (output width 3 * Head_size).
        self.qkv_network = nn.Linear(instance_size, Head_size * 3, bias=False)
        self.proj = nn.Linear(Head_size, instance_size)
        self.attention_dropout = nn.Dropout(p_dropout)
        self.resid_dropout = nn.Dropout(p_dropout)
        self.p_dropout = p_dropout  # needed by the fused (flash) kernel

        self.multi_head_size = Head_size // n_head
        self.flash = False
        self.n_head = n_head
        self.Head_size = Head_size
        self.padding_mask = None

    def set_padding_mask_sa(self, padding_mask):
        """Store the (B, S, S) padding mask used by the next forward pass."""
        self.padding_mask = padding_mask

    def _heads(self, x):
        """Project x (B, S, instance_size) to q, k, v, each shaped (B, HN, S, MH)."""
        Batch_len, Sentence_len, _ = x.shape
        q, k, v = self.qkv_network(x).split(self.Head_size, dim=2)
        return tuple(
            t.view(Batch_len, Sentence_len, self.n_head, self.multi_head_size).transpose(1, 2)
            for t in (q, k, v)
        )

    def _mask_bias(self, dtype):
        """(B, 1, S, S) additive bias: 0 for kept pairs, ``_MASK_BIAS`` for masked ones."""
        # Cast before scaling: the mask is int32 and (1 - mask) * 10**10 overflows int32.
        keep = self.padding_mask.unsqueeze(1).to(dtype)
        return (1.0 - keep) * self._MASK_BIAS

    def _explicit_attention(self, q, k, v):
        """Return (output (B, HN, S, MH), pre-dropout attention probabilities (B, HN, S, S))."""
        attention_weights = (q @ k.transpose(-2, -1)) / np.sqrt(self.multi_head_size)
        if self.padding_mask is not None:
            attention_weights = attention_weights + self._mask_bias(attention_weights.dtype)
        attention_probas = F.softmax(attention_weights, dim=-1)
        if self.padding_mask is not None:
            # Zero every pair touching padding (in particular whole padded-query rows).
            attention_probas = attention_probas * self.padding_mask.unsqueeze(1)
        out = self.attention_dropout(attention_probas) @ v
        return out, attention_probas

    def _flash_attention(self, q, k, v):
        """Fused attention whose masking matches ``_explicit_attention`` (non-causal)."""
        attn_mask = None
        if self.padding_mask is not None:
            attn_mask = self._mask_bias(q.dtype)
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.p_dropout if self.training else 0.0,
            is_causal=False,
        )
        if self.padding_mask is not None:
            # A row of the pairwise mask is all zeros exactly when that query is padding.
            query_keep = self.padding_mask.amax(dim=-1)  # (B, S)
            out = out * query_keep[:, None, :, None].to(out.dtype)
        return out

    def _merge_heads(self, out, Batch_len, Sentence_len):
        """Concatenate heads back to (B, S, Head_size) and project to instance_size."""
        out = out.transpose(1, 2).contiguous().view(Batch_len, Sentence_len, self.Head_size)
        return self.proj(out)

    def forward(self, x):
        """Attend over x (B, S, instance_size); returns (B, S, instance_size)."""
        Batch_len, Sentence_len, _ = x.shape
        q, k, v = self._heads(x)
        if self.flash:
            out = self._flash_attention(q, k, v)
        else:
            out, _ = self._explicit_attention(q, k, v)
        return self.resid_dropout(self._merge_heads(out, Batch_len, Sentence_len))

    def forward_decompose(self, x):
        """Like ``forward`` but returns (output, attention probabilities (B, HN, S, S)).

        Always uses the explicit path (the fused kernel does not expose the
        probabilities) and, unlike ``forward``, does not apply ``resid_dropout``.
        The returned probabilities are taken before attention dropout.
        """
        Batch_len, Sentence_len, _ = x.shape
        q, k, v = self._heads(x)
        out, attention_probas = self._explicit_attention(q, k, v)
        return self._merge_heads(out, Batch_len, Sentence_len), attention_probas

class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, ReLU, project back, then dropout."""

    def __init__(self, instance_size, p_dropout):
        super().__init__()
        hidden_size = 4 * instance_size
        layers = [
            nn.Linear(instance_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, instance_size),
            nn.Dropout(p_dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        mlp = self.net
        return mlp(x)
        


In [None]:
import numpy as np
import torch
from sklearn.metrics import roc_auc_score, classification_report
from torch.nn import functional as F

def calculate_roc_auc(y_true, predicted_probabilities, return_nan=False):
    """ROC-AUC of ``predicted_probabilities`` against ``y_true``.

    The score is undefined when ``y_true`` holds fewer than two classes; then
    ``np.nan`` is returned if ``return_nan`` is True, otherwise a notice is
    printed and None is returned.
    """
    if len(np.unique(y_true)) <= 1:
        if not return_nan:
            print("Only one class present in y_true. ROC AUC score is not defined in that case.")
            return None
        return np.nan
    return roc_auc_score(y_true, predicted_probabilities)
def calculate_classification_report(true_labels_list, predicted_labels_list, return_nan=True):
    """sklearn classification report, skipped when ``true_labels_list`` has one class.

    Returns:
        The report string when at least two classes are present; otherwise
        ``np.nan`` if ``return_nan`` is True, else prints a notice and returns None.
    """
    if len(np.unique(true_labels_list)) > 1:
        return classification_report(true_labels_list, predicted_labels_list)
    if return_nan:
        return np.nan
    # The previous message wrongly referred to the ROC AUC score.
    print("Only one class present in true labels. Classification report is not computed in that case.")
    return None

def get_proba(true_labels_list, predicted_probas_list):
    """Mean predicted probability of the true class, for class 0 and class 1.

    Args:
        true_labels_list: sequence of 0/1 labels.
        predicted_probas_list: one row of class probabilities per label, with at
            least two columns.

    Returns:
        (avg_proba_zero, avg_proba_one): mean of column 0 over label-0 samples and
        mean of column 1 over label-1 samples; NaN for a class with no samples.
    """
    labels = np.asarray(true_labels_list)
    probas = np.asarray(predicted_probas_list)

    def _class_mean(cls):
        selected = probas[labels == cls, cls]
        # np.mean of an empty slice is also NaN but emits a RuntimeWarning.
        return selected.mean() if selected.size else np.float64(np.nan)

    return _class_mean(0), _class_mean(1)

def calculate_loss(logits, targets=None, loss_type='cross_entropy', gamma=None, alpha=None):
    """Loss between (B, C) ``logits`` and (B,) integer ``targets``.

    Args:
        logits: (B, C) unnormalized class scores.
        targets: (B,) class indices, or None to skip the loss.
        loss_type:
            'cross_entropy': mean cross-entropy.
            'focal_loss': sum over the batch of ``w_i * (1 - p_i)**gamma * CE_i``,
                where ``p_i`` is the predicted probability of the target class and
                ``w_i`` is ``alpha`` for target 0 and 1 otherwise.
            'predictions': mean squared error between thresholded class-1
                predictions and targets; not differentiable w.r.t. ``logits``.
        gamma, alpha: focal-loss parameters.

    Returns:
        A scalar tensor, or None when ``targets`` is None.

    Raises:
        ValueError: for an unknown ``loss_type`` (previously returned None silently).
    """
    if targets is None:
        return None

    if loss_type == 'cross_entropy':
        return F.cross_entropy(logits, targets)

    if loss_type == 'focal_loss':
        # alpha re-weights class-0 samples; class-1 samples keep weight 1.
        alphas = ((1 - targets) * (alpha - 1)).to(torch.float) + 1
        probas = F.softmax(logits, dim=1)
        p_target = probas[torch.arange(probas.shape[0], device=probas.device), targets]
        focal_factor = (1 - p_target) ** gamma * alphas
        cross_entropy = F.cross_entropy(logits, targets, reduction='none')
        return torch.dot(cross_entropy, focal_factor)

    if loss_type == 'predictions':
        probas = F.softmax(logits, dim=1)
        predictions = (probas[:, 1] > 0.5).to(int)
        mse = torch.sum((predictions - targets) ** 2) / len(predictions)
        # Thresholding breaks the graph; like before, return a fresh leaf that
        # requires grad so backward() runs (it updates no model parameter).
        return mse.detach().clone().requires_grad_(True)

    raise ValueError(f"Unknown loss_type: {loss_type!r}")

In [None]:
import os
import sys
sys.path.append('/gpfs/commons/groups/gursoy_lab/pmeddeb/phenotype_embedding')
import time
import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn as nn



class EmbeddingPheno(nn.Module):
    """Disease and count embedding tables for patient "sentences".

    Args:
        method: None for randomly initialised tables; 'Abby' or 'Paul' to load
            pretrained disease embeddings from the lab's embedding files (the
            embedding size is then read from the loaded weights).
        vocab_size: number of disease tokens (used when method is None).
        max_count_same_disease: number of rows of the counts table.
        Embedding_size: embedding width (used when method is None).
        rollup_depth: selects the 'Paul' embedding file.
        freeze_embed: freeze the pretrained disease embeddings.
        dicts: optional dict with 'id', 'name' and 'cat' dicts keyed by code; builds
            ``metadata`` rows [name, cat] in the key order of ``dicts['id']``.

    Raises:
        ValueError: if ``method`` is not None, 'Abby' or 'Paul'.
    """

    _ABBY_FILE = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings/Abby/embedding_abby_no_1_diseases.pth'
    _PAUL_FILE = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings/Paul_Glove/glove_UKBB_omop_rollup_closest_depth_{rollup_depth}_no_1_diseases.pth'

    def __init__(self, method=None, vocab_size=None, max_count_same_disease=None, Embedding_size=None, rollup_depth=4, freeze_embed=False, dicts=None):
        super().__init__()

        self.dicts = dicts
        self.rollup_depth = rollup_depth
        self.nb_distinct_diseases_patient = vocab_size
        self.Embedding_size = Embedding_size
        # Previously always set to None, discarding the argument.
        self.max_count_same_disease = max_count_same_disease
        self.metadata = None

        if self.dicts is not None:
            name_dict = self.dicts['name']
            cat_dict = self.dicts['cat']
            self.metadata = [[name_dict[code], cat_dict[code]] for code in self.dicts['id']]

        if method is None:
            self.distinct_diseases_embeddings = nn.Embedding(vocab_size, Embedding_size)
            self.counts_embeddings = nn.Embedding(max_count_same_disease, Embedding_size)
            torch.nn.init.normal_(self.distinct_diseases_embeddings.weight, mean=0.0, std=0.02)
            torch.nn.init.normal_(self.counts_embeddings.weight, mean=0.0, std=0.02)
        elif method in ('Abby', 'Paul'):
            if method == 'Abby':
                embedding_file_diseases = self._ABBY_FILE
            else:
                embedding_file_diseases = self._PAUL_FILE.format(rollup_depth=self.rollup_depth)
            # torch.load unpickles: only point it at trusted files.
            pretrained_weights_diseases = torch.load(embedding_file_diseases)
            self.Embedding_size = pretrained_weights_diseases.shape[1]

            self.distinct_diseases_embeddings = nn.Embedding.from_pretrained(pretrained_weights_diseases, freeze=freeze_embed)
            self.counts_embeddings = nn.Embedding(max_count_same_disease, self.Embedding_size)
        else:
            # Previously an unknown method silently produced no embedding tables.
            raise ValueError(f"Unknown embedding method: {method!r}")

    def write_embedding(self, writer):
        """Log the disease embedding matrix with ``metadata`` labels to ``writer``."""
        embedding_tensor = self.distinct_diseases_embeddings.weight.data.detach().cpu().numpy()
        writer.add_embedding(embedding_tensor, metadata=self.metadata, metadata_header=["Name","Label"])


class EmbeddingPhenoCat(nn.Module):
    """One embedding table per categorical input (diseases, counts, other categories).

    Args:
        method: None (random init) or 'Abby'/'Paul' (pretrained disease table loaded
            from the lab's embedding files; ``Embedding_size`` is then read from it).
        Embedding_size: width of randomly initialised tables.
        rollup_depth: selects the 'Paul' embedding file.
        freeze_embed: freeze the pretrained disease embeddings.
        dic_embedding_cat_params: {category name: number of rows}. 'diseases' is
            built first (it may fix the width); 'counts' is only built for method
            None or 'Paul'; any other category gets a fresh random table.

    Attributes:
        dic_embedding_cat: ``nn.ModuleDict`` category -> ``nn.Embedding`` (registered,
            so its parameters are tracked by the module).

    Raises:
        ValueError: if 'diseases' is requested with an unknown ``method``.
    """

    def __init__(self, method=None, Embedding_size=10, rollup_depth=4, freeze_embed=False, dic_embedding_cat_params=None):
        # Previously called super(EmbeddingPheno, self), a TypeError for this class.
        super().__init__()

        self.rollup_depth = rollup_depth
        self.Embedding_size = Embedding_size
        self.max_count_same_disease = None
        # None default avoids a shared mutable {} default argument.
        self.dic_embedding_cat_params = {} if dic_embedding_cat_params is None else dic_embedding_cat_params

        # 'diseases' first: pretrained weights may change Embedding_size, and every
        # other table is built at self.Embedding_size so all widths agree.
        ordered_params = sorted(self.dic_embedding_cat_params.items(), key=lambda item: item[0] != 'diseases')
        dic_embedding_cat = nn.ModuleDict()
        for cat, max_number in ordered_params:
            if cat == 'diseases':
                dic_embedding_cat[cat] = self._diseases_embedding(method, max_number, freeze_embed)
            elif cat == 'counts':
                if method is None or method == 'Paul':
                    dic_embedding_cat[cat] = self._random_embedding(max_number)
            else:
                dic_embedding_cat[cat] = self._random_embedding(max_number)

        self.dic_embedding_cat = dic_embedding_cat

    def _random_embedding(self, num_embeddings):
        """Fresh table of width ``self.Embedding_size`` with N(0, 0.02) weights."""
        table = nn.Embedding(num_embeddings, self.Embedding_size)
        torch.nn.init.normal_(table.weight, mean=0.0, std=0.02)
        return table

    def _diseases_embedding(self, method, num_embeddings, freeze_embed):
        """Disease table: random for method None, pretrained for 'Abby'/'Paul'."""
        if method is None:
            return self._random_embedding(num_embeddings)
        if method == 'Abby':
            embedding_file_diseases = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings/Abby/embedding_abby_no_1_diseases.pth'
        elif method == 'Paul':
            embedding_file_diseases = f'/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings/Paul_Glove/glove_UKBB_omop_rollup_closest_depth_{self.rollup_depth}_no_1_diseases.pth'
        else:
            raise ValueError(f"Unknown embedding method: {method!r}")
        # torch.load unpickles: only point it at trusted files.
        pretrained_weights_diseases = torch.load(embedding_file_diseases)
        self.Embedding_size = pretrained_weights_diseases.shape[1]
        return nn.Embedding.from_pretrained(pretrained_weights_diseases, freeze=freeze_embed)


In [None]:
n_patient = 1000
val_possibles = np.arange(1, 4, dtype=int)
indices_random = np.unique(np.random.randint(0,3, size=(3)))

patient_list = np.zeros((n_patient, 3))
indices_1 = np.random.randint(0,2, n_patient)
indices_2 =np.random.randint(0,2, n_patient)

In [None]:
patient_diseases_list = []
for i in range(n_patient):
    indices_random = np.unique(np.random.randint(0,3, size=(3)))
    np.random.shuffle(indices_random)
    patient_diseases_list.append(np.concatenate([val_possibles[indices_random], np.zeros(3-len(val_possibles[indices_random]),dtype=int)]))
patient_diseases_list = np.array(patient_diseases_list)

In [None]:
patient_counts_list = np.ones((n_patient,3))

In [None]:
labels_list = []
for patient in patient_diseases_list:
    res = np.isin(np.array([1, 2, 3]), patient)
    if res[0] and not res[1]:
        labels_list.append(1)
    else:
        labels_list.append(0)
labels_list = np.array(labels_list)

In [None]:
Embedding  = EmbeddingPheno(method=None, vocab_size=4, max_count_same_disease=1, Embedding_size=10, rollup_depth=4, freeze_embed=False)


In [None]:
model = TransformerGeneModel_V2(pheno_method = 'Abby',
                             Embedding = Embedding,
                             Head_size=6,
                             binary_classes=True,
                             n_head=2,
                             n_layer=2,
                             mask_padding=True, 
                             padding_token=0, 
                             p_dropout=0, 
                             loss_version = 'cross_entropy',
                             gamma = 0,
                             instance_size=5)

In [None]:
# equalize:
patient_list_ones = patient_diseases_list[labels_list==1]
patient_list_zeros = patient_diseases_list[labels_list==0][:len(patient_list_ones)]
patient_diseases_list = np.concatenate([patient_list_ones, patient_list_zeros])
labels_list = np.array([1]*len(patient_list_ones) + [0]*len(patient_list_zeros))

In [None]:
prop_train = 0.8
n_train = int(len(patient_diseases_list)*prop_train)
data_train = list(zip(patient_diseases_list[:n_train],patient_counts_list[:n_train], labels_list[:n_train]))
data_test = list(zip(patient_diseases_list[n_train:],patient_counts_list[n_train:], labels_list[n_train:]))

In [None]:
dataloader_train = DataLoader(data_train, batch_size=20, shuffle=True)
dataloader_test = DataLoader(data_train, batch_size=20, shuffle=True)

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)


In [None]:
batch_sentence, batch_counts, batch_labels = next(iter(dataloader_train))

In [None]:
logits, probas, x_out = model.forward_decomposed(batch_sentence, batch_counts)


In [None]:
model.padding_mask_probas

In [None]:
probas

In [None]:
attention_probas = model.list_attention_layers

In [None]:
attention_probas[0]

In [None]:
attention_weights = torch.rand(20,2, 3,3)

In [None]:
padding_mask = model.padding_mask

In [None]:
### padding mask #####
padding_mask_weights = -((1-padding_mask)*(10**10))
attention_weights = (attention_weights.transpose(0, 1)+padding_mask_weights).transpose(0, 1)



In [None]:
attention_weights

In [None]:
attention_probas = F.softmax(attention_weights, dim=-1) # shape B, S, S


In [None]:
attention_probas

In [None]:

attention_probas = F.softmax(attention_weights, dim=-1) # shape B, S, S

attention_probas = (attention_probas.transpose(0, 1)*padding_mask).transpose(0, 1)
# attention_probas[attention_probas.isnan()]=0
attention_probas_dropout = self.attention_dropout(attention_probas)


#print(f'wei1={attention_probas}')
#attention_probas = self.dropout(attention_probas)
## weighted aggregation of the values
out = attention_probas_dropout @ v_multi_head # shape B, S, S @ B, S, MH = B, S, MH
out = out.transpose(1, 2).contiguous().view(Batch_len, Sentence_len, self.Head_size)
out = self.proj(out)
#out = self.resid_dropout(ou

In [None]:
logits, probas, x_out = model.forward_decomposed(batch_sentence, batch_counts)
attention_probas = model.list_attention_layers
indice_sentence = 0
indice_layer = 0
indice_head = 0
attention_probas = attention_probas[indice_layer].detach().numpy()[indice_sentence][indice_head]
mask = model.padding_mask.detach().numpy()[indice_sentence]
n_real = np.sum(mask[0])

attention_probas_masked = attention_probas[mask].reshape(n_real, n_real)

sns.set(style="whitegrid")
plt.figure(figsize=(10, 8))
sns.heatmap(attention_probas_masked, cmap="YlGnBu", annot=True, fmt=".2f", cbar=False)

# Ajoutez des Ã©tiquettes pour les axes
plt.xlabel("Token")
plt.ylabel("Token")
plt.title("Self-Attention Matrix")

# Affichez le plot
plt.show()

In [None]:
attention_probas[indice_sentence][indice_head]

In [None]:
total_epochs = 50  # was referenced but never defined; tune as needed

model.train()
for epoch in range(1, total_epochs + 1):
    total_loss = 0.0
    for batch_sentences, batch_counts, batch_labels in dataloader_train:
        optimizer.zero_grad(set_to_none=True)
        logits, loss = model(batch_sentences, batch_counts, batch_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the accumulated loss (it was previously computed and discarded).
    print(f'epoch {epoch}: mean train loss {total_loss / len(dataloader_train):.4f}')


In [None]:
logits, probas, attention_probas, attention_weights, attention_weights_bis,  x_out = model.forward_decomposed(batch_sentences, batch_counts)

In [None]:
batch_sentences[0]

In [None]:
attention_weights

In [None]:
attention_probas[0] * (1-model.padding_mask[0].to(int))

In [None]:
batch_sentences[0]

In [None]:
model.padding_mask[0]

In [None]:
pred_labels = model.predict(batch_sentences, batch_counts)

In [None]:
pred_labels

In [None]:
f1, accuracy, auc_score, loss, proba_avg_zero, proba_avg_one, predicted_probas_list, true_labels_list=model.evaluate(dataloader_train)

In [None]:
accuracy

In [None]:
true_labels_list = np.array(true_labels_list)

In [None]:
pred_probas = np.array(predicted_probas_list)



In [None]:
pred_labels = (pred_probas[:,0] < 0.5).astype(int)

In [None]:
np.sum(pred_labels==true_labels_list)

In [None]:
logits, probas, attention_probas, x_out = model.forward_decomposed(batch_sentences, batch_counts)

In [None]:
probas

In [None]:
logits.shape

In [None]:
ind = 3
logits[ind], x_out[ind], batch_sentences[ind], probas[ind], attention_probas[ind]

In [None]:
np.sum(pred_labels==1)

In [None]:
accuracy

In [None]:
probas = torch.rand(2, 3, 1)
mask = torch.zeros(2, 3, 1).to(bool)

In [None]:
probas

In [None]:
mask[0][0][0] = True
mask[0][1][0] = True
mask[1][1][0] = True

In [None]:
probas[mask]=0

In [None]:
probas

In [None]:
model.projection_embed(model.diseases_embedding_table.weight).shape

In [None]:
model.diseases_embedding_table