In [1]:
import sys
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)
from codes.models.data_form.DataForm import DataTransfo_1SNP

In [119]:
##### Functional version with optional padding mask and dropout; see Transformer_V1.ipynb for an example
import sys
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)


### imports
import torch
import torch.nn as nn
import numpy as np
import time
from torch.nn import functional as F
from sklearn.metrics import f1_score, accuracy_score
from codes.models.data_form.DataForm import DataTransfo_1SNP, PatientList, Patient

from codes.models.Transformers.Embedding import EmbeddingPhenoCat
from codes.models.metrics import calculate_roc_auc, calculate_classification_report, calculate_loss, get_proba
from torch.utils.data import DataLoader
### Transformer's instance
# B, S, E, H, HN, MH = Batch_len, Sentence_len, Embedding_len, Head_size, Head number, MultiHead size.
class TabTransformerGeneModel_V2(nn.Module):
    """Transformer classifier over a patient's disease tokens and environment features.

    The embedded sequence goes through a stack of ``Block`` layers and a final
    layer norm. Each token then produces class logits (``lm_head_logits``) and
    one pooling score (``lm_head_proba``); the per-token logits are averaged
    with the softmax of the pooling scores to give one logit vector per sample.

    Args:
        pheno_method: Stored on the instance; not used by the methods of this class.
        Embedding: Module mapping the input dict to a (B, S, E) tensor. Must
            expose ``Embedding_size`` (and ``dic_embedding_cat`` / ``metadata``
            for ``write_embedding``).
        instance_size: Width of the transformer. Replaced by
            ``Embedding.Embedding_size`` when ``proj_embed`` is falsy.
        proj_embed: Whether a projection from ``Embedding_size`` to
            ``instance_size`` is created on this module.
        list_env_features: Names of the environment features; only its length
            (``nb_env``) is used here, to size the padding masks.
        Head_size: Total width of the attention heads of each block.
        binary_classes: Two output classes if true, three otherwise.
        n_head: Number of attention heads per block.
        n_layer: Number of stacked blocks.
        mask_padding: Whether padding tokens are masked in attention and pooling.
        padding_token: Token id treated as padding in ``input_dict['diseases']``.
        p_dropout: Dropout probability passed to the blocks.
        device: Device the inputs and masks are moved to.
        loss_version, gamma, alpha: Forwarded to ``calculate_loss``.
    """

    def __init__(self, pheno_method, Embedding, instance_size, proj_embed, list_env_features, Head_size, binary_classes, n_head, n_layer, mask_padding=False, padding_token=None, p_dropout=0, device='cpu', loss_version='cross_entropy', gamma=2, alpha=1):
        super().__init__()

        self.Embedding_size = Embedding.Embedding_size

        self.mask_padding = mask_padding
        self.padding_token = padding_token
        self.padding_mask = None
        # Defined up front so that forward() also works when mask_padding is False.
        self.padding_mask_mat = None
        self.padding_mask_probas = None
        self.device = device
        self.pheno_method = pheno_method
        self.binary_classes = binary_classes
        self.Classes_nb = 2 if self.binary_classes else 3
        self.loss_version = loss_version
        self.gamma = gamma
        self.alpha = alpha
        self.list_env_features = list_env_features
        self.nb_env = len(self.list_env_features)
        self.instance_size = instance_size
        self.Embedding = Embedding
        self.proj_embed = proj_embed
        if not self.proj_embed:
            # Without projection the transformer works directly at embedding width.
            self.instance_size = self.Embedding_size
        if self.proj_embed:
            self.projection_embed = nn.Linear(self.Embedding_size, self.instance_size)

        self.blocks = PadMaskSequential(*[Block(self.instance_size, n_head=n_head, Head_size=Head_size, p_dropout=p_dropout, nb_env=self.nb_env) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(self.instance_size)  # final layer norm
        self.lm_head_logits = nn.Linear(self.instance_size, self.Classes_nb)
        self.lm_head_proba = nn.Linear(self.instance_size, 1)  # one pooling score per token

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise every ``nn.Linear`` with N(0, 0.02) weights and zero bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def create_padding_mask(self, input_dict):
        """Build the attention and pooling masks for one batch.

        Args:
            input_dict: Batch dict; ``input_dict['diseases']`` is read as a
                (B, S) tensor of token ids.

        Returns:
            ``(padding_mask_mat, padding_mask_probas)`` of shapes
            (B, S + nb_env, S + nb_env) and (B, S + nb_env). Entries are 0 where
            a padding token is involved (as row or column for the matrix) and 1
            elsewhere; environment-feature positions are never masked.
        """
        diseases_sentence = input_dict['diseases']
        B, S = diseases_sentence.shape[0], diseases_sentence.shape[1]

        # Build the masks on the same device as the token ids: indexing or
        # combining a CPU tensor with CUDA tensors would raise.
        padding_mask_probas = torch.ones((B, S + self.nb_env), device=diseases_sentence.device)
        padding_mask_probas[:, :S] = (diseases_sentence != self.padding_token).to(padding_mask_probas.dtype)

        # A pair (i, j) is kept only if neither token i nor token j is padding.
        padding_mask_mat = padding_mask_probas.unsqueeze(2) * padding_mask_probas.unsqueeze(1)
        return padding_mask_mat.to(self.device), padding_mask_probas.to(self.device)  # 0 if padded, 1 else

    def set_padding_mask_transformer(self, padding_mask_mat, padding_mask_probas):
        """Store the masks used by the attention blocks and by the pooling step."""
        self.padding_mask_mat = padding_mask_mat
        self.padding_mask_probas = padding_mask_probas

    def _prepare_inputs(self, input_dict):
        """Move the batch to ``self.device`` and pop the targets, both in place.

        The caller's dict is modified on purpose (as before): tensors are
        replaced by their on-device version and 'SNP_label' is removed so that
        the dict can be handed to the embedding module.
        """
        for key, value in input_dict.items():
            input_dict[key] = value.to(self.device)
        return input_dict.pop('SNP_label', None)

    def _refresh_padding_masks(self, input_dict):
        """Recompute the masks for this batch, or clear them when masking is off."""
        if self.mask_padding:
            padding_mask_mat, padding_mask_probas = self.create_padding_mask(input_dict)
        else:
            padding_mask_mat, padding_mask_probas = None, None
        self.set_padding_mask_transformer(padding_mask_mat, padding_mask_probas)
        self.blocks.set_padding_mask_sequential(self.padding_mask_mat)

    def _pool_logits(self, x, Batch_len, Sentence_len):
        """Turn the block output (B, S, E) into sample logits (B, Classes_nb).

        Returns the pooled logits and the (B, S) pooling weights.
        """
        x = self.ln_f(x)  # shape B, S, E
        logits = self.lm_head_logits(x)  # shape B, S, Classes_nb: per-token logits
        weights_logits = self.lm_head_proba(x).view(Batch_len, Sentence_len)
        # Explicit dim: the implicit-dim form of softmax is deprecated.
        probas = F.softmax(weights_logits, dim=-1)  # shape B, S
        if self.padding_mask_probas is not None:
            # Padded tokens get zero weight; the weights are not renormalised.
            probas = probas * self.padding_mask_probas
        logits = (logits.transpose(1, 2) @ probas.view(Batch_len, Sentence_len, 1)).view(Batch_len, self.Classes_nb)  # weighted average of the token logits
        return logits, probas

    def forward(self, input_dict):
        """Run the model on one batch.

        Args:
            input_dict: Dict of tensors. It is modified in place: values are
                moved to ``self.device`` and 'SNP_label', if present, is popped
                and used as target.

        Returns:
            ``(logits, loss)`` with logits of shape (B, Classes_nb) and the value
            returned by ``calculate_loss`` (targets are None without labels).
        """
        targets = self._prepare_inputs(input_dict)
        input_embedded = self.Embedding(input_dict)
        Batch_len, Sentence_len, _ = input_embedded.shape

        self._refresh_padding_masks(input_dict)

        x = self.blocks(input_embedded)  # shape B, S, E
        logits, _ = self._pool_logits(x, Batch_len, Sentence_len)
        loss = calculate_loss(logits, targets, self.loss_version, self.gamma, self.alpha)
        return logits, loss

    def forward_decomposed(self, input_dict):
        """Same computation as ``forward`` but also returns intermediate tensors.

        Returns:
            ``(logits, loss, input_embedded, attention_probas, probas)`` where
            ``attention_probas`` is whatever ``self.blocks.forward_decompose``
            returns as attention weights and ``probas`` are the pooling weights.
        """
        targets = self._prepare_inputs(input_dict)
        input_embedded = self.Embedding(input_dict)
        Batch_len, Sentence_len, _ = input_embedded.shape

        self._refresh_padding_masks(input_dict)

        x, attention_probas = self.blocks.forward_decompose(input_embedded)  # shape B, S, E
        logits, probas = self._pool_logits(x, Batch_len, Sentence_len)
        loss = calculate_loss(logits, targets, self.loss_version, self.gamma, self.alpha)
        return logits, loss, input_embedded, attention_probas, probas

    def predict(self, input_dict):
        """Return the arg-max class index of each sample, shape (B,)."""
        input_dict.pop('SNP_label', None)
        logits, _ = self(input_dict)  # shape B, Classes_nb
        return torch.argmax(logits, dim=1)  # (B,)

    def predict_proba(self, input_dict):
        """Return the softmax class probabilities of each sample, shape (B, Classes_nb)."""
        input_dict.pop('SNP_label', None)
        logits, _ = self(input_dict)
        probabilities = F.softmax(logits, dim=1)
        return probabilities

    def evaluate(self, dataloader_test):
        """Evaluate the model on a dataloader of labelled batches.

        The model is put in eval mode for the duration of the call and switched
        back to train mode at the end.

        Returns:
            ``(f1, accuracy, auc_score, mean_loss, proba_avg_zero,
            proba_avg_one, predicted_probas_list, true_labels_list)``.
        """
        print('beginning inference evaluation')
        start_time_inference = time.time()
        predicted_labels_list = []
        predicted_probas_list = []
        true_labels_list = []

        total_loss = 0.
        self.eval()
        with torch.no_grad():
            for input_dicts in dataloader_test:
                # Read the labels first: the forward pass pops them from the dict.
                batch_labels = input_dicts['SNP_label']

                # One forward pass is enough: in eval mode the labels and the
                # probabilities derive deterministically from the same logits.
                logits, loss = self(input_dicts)
                total_loss += loss.item()
                predicted_labels_list.extend(torch.argmax(logits, dim=1).cpu().numpy())
                predicted_probas_list.extend(F.softmax(logits, dim=1).cpu().numpy())
                true_labels_list.extend(batch_labels.cpu().numpy())
        f1 = f1_score(true_labels_list, predicted_labels_list, average='macro')
        accuracy = accuracy_score(true_labels_list, predicted_labels_list)
        auc_score = calculate_roc_auc(true_labels_list, np.array(predicted_probas_list)[:, 1], return_nan=True)
        proba_avg_zero, proba_avg_one = get_proba(true_labels_list, predicted_probas_list)

        self.train()
        print(f'end inference evaluation in {time.time() - start_time_inference}s')
        return f1, accuracy, auc_score, total_loss/len(dataloader_test), proba_avg_zero, proba_avg_one, predicted_probas_list, true_labels_list

    def write_embedding(self, writer):
        """Send the disease embedding table to ``writer.add_embedding``.

        NOTE(review): with ``proj_embed`` the table is projected through this
        module's own ``projection_embed``, which no forward pass of this class
        uses — confirm this is the projection that should be visualised.
        """
        if self.proj_embed:
            embedding_tensor = self.projection_embed(self.Embedding.dic_embedding_cat['diseases'].weight).detach().cpu().numpy()
        else:
            embedding_tensor = self.Embedding.dic_embedding_cat['diseases'].weight.detach().cpu().numpy()
        writer.add_embedding(embedding_tensor, metadata=self.Embedding.metadata, metadata_header=["Name","Label"])


class PadMaskSequential(nn.Sequential):
    """``nn.Sequential`` that hands a shared padding mask to every sub-module.

    Each contained module must implement ``set_padding_mask_block`` and, for
    ``forward_decompose``, a method of the same name returning
    ``(output, attention_probas)``.
    """

    def __init__(self, *args):
        super().__init__(*args)
        self.padding_mask = None

    def set_padding_mask_sequential(self, padding_mask):
        """Store the mask that will be passed to each module on the next call."""
        self.padding_mask = padding_mask

    def forward(self, x):
        """Apply the modules in order, giving each one the current padding mask."""
        for module in self:
            module.set_padding_mask_block(self.padding_mask)
            x = module(x)
        return x

    def forward_decompose(self, x):
        """Apply the modules in order and also return attention weights.

        Returns:
            ``(x, attention_probas)`` where ``attention_probas`` comes from the
            last module (None when the container is empty).
        """
        attention_probas = None
        for module in self:
            module.set_padding_mask_block(self.padding_mask)
            # Unpack here: feeding the (output, attention) tuple to the next
            # module would fail as soon as there is more than one layer.
            x, attention_probas = module.forward_decompose(x)
        return x, attention_probas
    
class Block(nn.Module):
    """One transformer layer: self-attention then feed-forward, each followed
    by a residual addition and a layer norm (normalisation applied after the sum).
    """

    def __init__(self, instance_size, n_head, Head_size, p_dropout, nb_env):
        super().__init__()
        self.sa = MultiHeadSelfAttention(n_head, Head_size, instance_size, p_dropout,  nb_env=nb_env)
        self.ffwd = FeedForward(instance_size, p_dropout)
        self.ln1 = nn.LayerNorm(instance_size)
        self.ln2 = nn.LayerNorm(instance_size)
        self.padding_mask = None

    def set_padding_mask_block(self, padding_mask):
        """Remember the padding mask to pass on to the attention layer."""
        self.padding_mask = padding_mask

    def _after_attention(self, residual_sum):
        """Normalise the attention residual, then apply the feed-forward stage."""
        normed = self.ln1(residual_sum)
        normed = normed + self.ffwd(normed)
        return self.ln2(normed)

    def forward(self, x):
        """Return the block output for ``x``."""
        self.sa.set_padding_mask_sa(self.padding_mask)
        attended = self.sa(x)
        return self._after_attention(x + attended)

    def forward_decompose(self, x):
        """Return ``(block_output, attention_probas)`` for ``x``."""
        self.sa.set_padding_mask_sa(self.padding_mask)
        attended, attention_probas = self.sa.forward_decompose(x)
        return self._after_attention(attended + x), attention_probas


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with an optional multiplicative padding mask.

    Queries, keys and values are projected to ``Head_size`` features, split into
    ``n_head`` heads of ``Head_size // n_head`` features, and the concatenated
    head outputs are projected back to ``instance_size``.

    Args:
        n_head: Number of heads; must divide ``Head_size``.
        Head_size: Total width of all heads.
        instance_size: Width of the input and output tokens.
        p_dropout: Dropout probability on the attention weights and the output.
        nb_env: Stored on the instance; not used by this class.

    Raises:
        ValueError: If ``Head_size`` is not a positive multiple of ``n_head``.
    """

    def __init__(self, n_head, Head_size, instance_size, p_dropout, nb_env):
        super().__init__()
        if n_head <= 0 or Head_size % n_head != 0:
            # Otherwise the per-head reshape in the forward pass cannot succeed.
            raise ValueError(f'Head_size ({Head_size}) must be a positive multiple of n_head ({n_head})')
        self.q_network = nn.Linear(instance_size, Head_size, bias = False)
        self.k_network =  nn.Linear(instance_size, Head_size, bias = False)
        self.v_network =  nn.Linear(instance_size, Head_size, bias = False)
        self.proj = nn.Linear(Head_size, instance_size)
        self.attention_dropout = nn.Dropout(p_dropout)
        self.resid_dropout = nn.Dropout(p_dropout)

        self.multi_head_size = Head_size // n_head
        self.flash = False
        self.n_head = n_head
        self.Head_size = Head_size
        self.padding_mask = None
        self.nb_env = nb_env

    def set_padding_mask_sa(self, padding_mask):
        """Set the mask multiplied into the attention weights (None disables it)."""
        self.padding_mask = padding_mask

    def _split_heads(self, projected, Batch_len, Sentence_len):
        """Reshape (B, S, Head_size) into (B, HN, S, MH)."""
        return projected.view(Batch_len, Sentence_len, self.n_head, self.multi_head_size).transpose(1, 2)

    def _attend(self, x):
        """Shared implementation of ``forward`` and ``forward_decompose``.

        Returns ``(out, attention_probas)``; ``attention_probas`` is None on the
        fused (flash) path, which does not expose the attention weights.
        """
        # x of size (B, S, E)
        Batch_len, Sentence_len, _ = x.shape

        # add a dimension to compute the different attention heads separately
        q_multi_head = self._split_heads(self.q_network(x), Batch_len, Sentence_len)  # shape B, HN, S, MH
        k_multi_head = self._split_heads(self.k_network(x), Batch_len, Sentence_len)
        v_multi_head = self._split_heads(self.v_network(x), Batch_len, Sentence_len)

        if self.flash:
            # Fused attention kernel.
            # NOTE(review): this path is causal and ignores the padding mask,
            # unlike the explicit path below — confirm before enabling `flash`.
            dropout_p = self.attention_dropout.p if self.training else 0.0
            out = torch.nn.functional.scaled_dot_product_attention(q_multi_head, k_multi_head, v_multi_head, attn_mask=None, dropout_p=dropout_p, is_causal=True)
            attention_probas = None
        else:
            attention_weights = (q_multi_head @ k_multi_head.transpose(-2, -1))/np.sqrt(self.multi_head_size)  # shape B, HN, S, S
            attention_probas = F.softmax(attention_weights, dim=-1)  # shape B, HN, S, S
            if self.padding_mask is not None:
                # The mask has no head dimension: move heads first so that it
                # broadcasts over them. Masked weights are zeroed after the
                # softmax and the rows are not renormalised.
                attention_probas = (attention_probas.transpose(0, 1)*self.padding_mask).transpose(0, 1)
            attention_probas = self.attention_dropout(attention_probas)

            # weighted aggregation of the values
            out = attention_probas @ v_multi_head  # shape B, HN, S, MH
        out = out.transpose(1, 2).contiguous().view(Batch_len, Sentence_len, self.Head_size)
        out = self.proj(out)
        out = self.resid_dropout(out)
        return out, attention_probas

    def forward(self, x):
        """Return the attention output for ``x`` of shape (B, S, E)."""
        out, _ = self._attend(x)
        return out

    def forward_decompose(self, x):
        """Return ``(out, attention_probas)`` for ``x`` of shape (B, S, E)."""
        return self._attend(x)

class FeedForward(nn.Module):
    """Token-wise MLP: Linear -> ReLU -> Linear -> Dropout, with a hidden layer
    four times as wide as the input."""

    def __init__(self, instance_size, p_dropout):
        super().__init__()
        hidden_size = 4 * instance_size
        layers = [
            nn.Linear(instance_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, instance_size),
            nn.Dropout(p_dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        transformed = self.net(x)
        return transformed
        


In [120]:
import os
import sys
sys.path.append('/gpfs/commons/groups/gursoy_lab/pmeddeb/phenotype_embedding')
import time
import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

import torch
import torch.nn as nn

class SineCosineEncoding(nn.Module):
    """Fixed sinusoidal lookup table indexed by integer values.

    Row ``p`` holds ``sin(p * w_k)`` in the even columns and ``cos(p * w_k)`` in
    the odd columns, with ``w_k = exp(-2k * ln(10000) / Embedding_size)``.

    Args:
        Embedding_size: Number of columns of the table (even or odd).
        max_len: Number of rows, i.e. one more than the largest valid index.
    """

    def __init__(self, Embedding_size, max_len=1000):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, Embedding_size, 2).float() * -(np.log(10000.0) / Embedding_size))
        angles = position * div_term

        encoding = torch.zeros(max_len, Embedding_size)
        encoding[:, 0::2] = torch.sin(angles)
        # With an odd Embedding_size there is one fewer odd column than even
        # columns, so drop the last frequency for the cosine part.
        encoding[:, 1::2] = torch.cos(angles[:, : Embedding_size // 2])

        # Non-persistent buffer: follows the module across devices without
        # adding a key to the state dict.
        self.register_buffer('encoding', encoding, persistent=False)

    def forward(self, x):
        """Return the table rows selected by the integer tensor ``x``."""
        return self.encoding.to(x.device)[x]

class ZeroEmbedding(nn.Module):
    """Lookup table filled with zeros: every index maps to a null vector of
    ``Embedding_size`` components, which neutralises the embedded feature."""

    def __init__(self, Embedding_size, max_len=1000):
        super().__init__()
        table_shape = (max_len, Embedding_size)
        self.encoding = torch.zeros(table_shape)

    def forward(self, x):
        """Return the (all-zero) rows selected by the integer tensor ``x``."""
        table = self.encoding.to(x.device)
        return table[x]


class EmbeddingPheno(nn.Module):
    """Container for a disease embedding table and a count embedding table.

    Args:
        method: None for randomly initialised tables, or 'Abby' / 'Paul' to load
            pretrained disease vectors from disk.
        counts_method: With method 'Paul', 'SineCosine' or 'no_counts' select a
            fixed count encoding; any other value gives a learned table.
        vocab_size: Number of rows of the disease table (method None only).
        max_count_same_disease: Number of rows of the count table.
        Embedding_size: Embedding width; replaced by the width of the pretrained
            vectors when they are loaded.
        rollup_depth: Selects the pretrained file for method 'Paul'.
        freeze_embed: Whether pretrained disease vectors are frozen.
        dicts: Dict with keys 'id', 'name', 'cat' and 'diseases_present';
            required for the pretrained methods, optional otherwise.

    Raises:
        ValueError: For an unknown ``method``, or a pretrained method without ``dicts``.
    """

    def __init__(self, method=None, counts_method=None, vocab_size=None, max_count_same_disease=None, Embedding_size=None, rollup_depth=4, freeze_embed=False, dicts=None):
        super().__init__()

        self.dicts = dicts
        self.rollup_depth = rollup_depth
        self.nb_distinct_diseases_patient = vocab_size
        self.Embedding_size = Embedding_size
        # Keep the value that was passed in (it used to be discarded).
        self.max_count_same_disease = max_count_same_disease
        self.metadata = None
        self.counts_method = counts_method

        diseases_present = None
        if self.dicts is not None:
            id_dict = self.dicts['id']
            name_dict = self.dicts['name']
            cat_dict = self.dicts['cat']
            diseases_present = self.dicts['diseases_present']
            # One [name, category] pair per disease code, in the order of id_dict.
            self.metadata = [[name_dict[code], cat_dict[code]] for code in id_dict.keys()]

        if method is None:
            self.distinct_diseases_embeddings = nn.Embedding(vocab_size, Embedding_size)
            self.counts_embeddings = nn.Embedding(max_count_same_disease, Embedding_size)
            torch.nn.init.normal_(self.distinct_diseases_embeddings.weight, mean=0.0, std=0.02)
            torch.nn.init.normal_(self.counts_embeddings.weight, mean=0.0, std=0.02)

        elif method in ('Abby', 'Paul'):
            if diseases_present is None:
                raise ValueError(f"method={method!r} requires `dicts` (with a 'diseases_present' entry)")
            if method == 'Abby':
                embedding_file_diseases = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings/Abby/embedding_abby_no_1_diseases.pth'
            else:
                embedding_file_diseases = f'/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings/Paul_Glove/glove_UKBB_omop_rollup_closest_depth_{self.rollup_depth}_no_1_diseases.pth'
            # NOTE: torch.load unpickles the file; only load trusted files.
            pretrained_weights_diseases = torch.load(embedding_file_diseases)[diseases_present]
            self.Embedding_size = pretrained_weights_diseases.shape[1]

            self.distinct_diseases_embeddings = nn.Embedding.from_pretrained(pretrained_weights_diseases, freeze=freeze_embed)
            # Only the 'Paul' method honours counts_method; 'Abby' always uses
            # a learned count table.
            if method == 'Paul' and self.counts_method == 'SineCosine':
                self.counts_embeddings = SineCosineEncoding(self.Embedding_size, max_count_same_disease)
            elif method == 'Paul' and self.counts_method == 'no_counts':
                self.counts_embeddings = ZeroEmbedding(self.Embedding_size, max_count_same_disease)
            else:
                self.counts_embeddings = nn.Embedding(max_count_same_disease, self.Embedding_size)

        else:
            # Previously an unknown method left the module without any table.
            raise ValueError(f"unknown embedding method {method!r}; expected None, 'Abby' or 'Paul'")

    def write_embedding(self, writer):
        """Send the disease embedding table and its metadata to ``writer.add_embedding``."""
        embedding_tensor = self.distinct_diseases_embeddings.weight.data.detach().cpu().numpy()
        writer.add_embedding(embedding_tensor, metadata=self.metadata, metadata_header=["Name","Label"])


class EmbeddingPhenoCat(nn.Module):
    """Embed a dict of categorical features into one token sequence.

    The 'diseases' sentence is embedded token by token (optionally projected to
    ``instance_size`` and, for ``pheno_method == 'Paul'``, summed with a
    'counts' embedding). Every other feature contributes one extra token,
    appended after the disease tokens in the iteration order of the input dict.

    Args:
        pheno_method: 'counts' is only embedded when this is 'Paul'.
        method: How the disease table is built: None (random init), 'Abby' or
            'Paul' (pretrained vectors loaded from disk).
        proj_embed: Whether disease embeddings are projected from
            ``Embedding_size`` to ``instance_size``.
        counts_method: Mapping feature name -> 'SineCosine' | 'no_counts' |
            anything else (learned table), consulted for 'counts' and 'age'.
            A missing mapping or key means a learned table.
        Embedding_size: Width of the disease table; replaced by the width of the
            pretrained vectors when they are loaded.
        instance_size: Width of every non-disease embedding.
        rollup_depth: Selects the pretrained file for method 'Paul'.
        freeze_embed: Whether pretrained disease vectors are frozen.
        dic_embedding_cat_params: Mapping feature name -> number of distinct
            values. Names are used as ``nn.ModuleDict`` keys, so they must not
            collide with ``nn.Module`` attribute names.
        dicts: Dict with keys 'id', 'name', 'cat' and 'diseases_present';
            required for the pretrained methods.
        device: Device the sub-modules are created on.

    Raises:
        ValueError: For an unknown ``method``, or a pretrained method without ``dicts``.
    """

    def __init__(self, pheno_method=None,  method=None, proj_embed=None, counts_method=None, Embedding_size=10, instance_size=10, rollup_depth=4, freeze_embed=False, dic_embedding_cat_params=None, dicts=None, device='cpu'):
        super().__init__()

        self.rollup_depth = rollup_depth
        self.Embedding_size = Embedding_size
        self.max_count_same_disease = None
        # None default instead of a shared mutable `{}` default.
        self.dic_embedding_cat_params = {} if dic_embedding_cat_params is None else dic_embedding_cat_params
        self.method = method
        self.pheno_method = pheno_method
        self.dicts = dicts
        self.proj_embed = proj_embed
        self.projection_embed = None
        self.instance_size = instance_size
        self.counts_method = counts_method
        self.metadata = None

        self.device = device
        diseases_present = None
        if self.dicts is not None:
            id_dict = self.dicts['id']
            name_dict = self.dicts['name']
            cat_dict = self.dicts['cat']
            diseases_present = self.dicts['diseases_present']
            # One [name, category] pair per disease code, in the order of id_dict.
            self.metadata = [[name_dict[code], cat_dict[code]] for code in id_dict.keys()]

        # ModuleDict (not a plain dict) so that the tables are registered:
        # their weights appear in parameters()/state_dict() and follow .to().
        dic_embedding_cat = nn.ModuleDict()
        for cat, max_number in self.dic_embedding_cat_params.items():
            if cat == 'diseases':
                dic_embedding_cat[cat] = self._build_disease_embedding(max_number, diseases_present, freeze_embed)
            elif cat == 'counts':
                # Counts are only embedded for the 'Paul' phenotype method.
                if self.pheno_method == 'Paul':
                    dic_embedding_cat[cat] = self._build_index_embedding(cat, max_number)
            elif cat == 'age':
                dic_embedding_cat[cat] = self._build_index_embedding(cat, max_number)
            else:
                dic_embedding_cat[cat] = self._build_learned_embedding(max_number)

        if self.proj_embed:
            self.projection_embed = nn.Linear(self.Embedding_size, self.instance_size).to(self.device)

        self.dic_embedding_cat = dic_embedding_cat

    def _build_learned_embedding(self, max_number):
        """Learned table of width ``instance_size`` initialised with N(0, 0.02)."""
        module = nn.Embedding(max_number, self.instance_size).to(self.device)
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        return module

    def _build_index_embedding(self, cat, max_number):
        """Table for an integer feature, chosen by ``counts_method[cat]``."""
        kind = (self.counts_method or {}).get(cat)
        if kind == 'SineCosine':
            return SineCosineEncoding(self.instance_size, max_number).to(self.device)
        if kind == 'no_counts':
            return ZeroEmbedding(self.instance_size, max_number).to(self.device)
        return self._build_learned_embedding(max_number)

    def _build_disease_embedding(self, max_number, diseases_present, freeze_embed):
        """Disease table: random init or pretrained vectors, depending on ``method``."""
        if self.method is None:
            module = nn.Embedding(max_number, self.Embedding_size).to(self.device)
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            return module
        if self.method == 'Abby':
            embedding_file_diseases = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings/Abby/embedding_abby_no_1_diseases.pth'
        elif self.method == 'Paul':
            embedding_file_diseases = f'/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings/Paul_Glove/glove_UKBB_omop_rollup_closest_depth_{self.rollup_depth}_no_1_diseases.pth'
        else:
            # Previously an unknown method silently created no disease table.
            raise ValueError(f"unknown embedding method {self.method!r}; expected None, 'Abby' or 'Paul'")
        if diseases_present is None:
            raise ValueError(f"method={self.method!r} requires `dicts` (with a 'diseases_present' entry)")
        # NOTE: torch.load unpickles the file; only load trusted files.
        pretrained_weights_diseases = torch.load(embedding_file_diseases)[diseases_present]
        self.Embedding_size = pretrained_weights_diseases.shape[1]
        return nn.Embedding.from_pretrained(pretrained_weights_diseases, freeze=freeze_embed).to(self.device)

    def forward(self, input_dict):
        """Embed one batch.

        Args:
            input_dict: Mapping feature name -> integer tensor. 'diseases' is
                required; 'counts' is used only when ``pheno_method == 'Paul'``;
                every other key must have a table and is reshaped to one token
                per sample.

        Returns:
            The disease tokens followed by one token per remaining feature,
            concatenated along dimension 1.
        """
        # Handle 'diseases' and 'counts' explicitly so that the result does not
        # depend on the order of the keys in input_dict.
        diseases_sentences_embedded = self.dic_embedding_cat['diseases'](input_dict['diseases'])
        if self.proj_embed:
            diseases_sentences_embedded = self.projection_embed(diseases_sentences_embedded)

        if 'counts' in input_dict and self.pheno_method == 'Paul':
            counts_sentence_embedded = self.dic_embedding_cat['counts'](input_dict['counts'])
            diseases_sentences_embedded = diseases_sentences_embedded + counts_sentence_embedded

        list_env_embedded = [
            self.dic_embedding_cat[key](value).view(len(value), 1, self.instance_size)
            for key, value in input_dict.items()
            if key not in ('diseases', 'counts')
        ]
        if not list_env_embedded:
            # Nothing to append (concatenating an empty list would raise).
            return diseases_sentences_embedded

        env_embedded = torch.concat(list_env_embedded, dim=1)
        return torch.concat([diseases_sentences_embedded, env_embedded], dim=1)
            



In [106]:
## Configuration of the reference model.
## These are notebook-level globals read by the cells below.
#### framework constants:
model_type = 'tab_transformer'
model_version = 'transformer_V2'
test_name = 'baseline_model_focal'
pheno_method = 'Abby' # Paul, Abby
tryout = False # True if we are doing a tryout, False otherwise 
### data constants:
CHR = 1
SNP = 'rs673604'
rollup_depth = 4
Classes_nb = 2 # number of classes related to an SNP (here 0 or 1)
vocab_size = None # to be defined with data
padding_token = 0
prop_train_test = 0.8
load_data = True
save_data = False
remove_none = True
compute_features = False
padding = True
list_env_features = ['age', 'sex']
### data format
batch_size = 200
data_share = 1/1000#402555
seuil_diseases = 600 # NOTE(review): "seuil" is French for "threshold"; not read by any visible cell
equalize_label = False
decorelate = False
threshold_corr = 1
threshold_rare = 1000
remove_rare = 'all' # None, 'all', 'one_class'
##### model constants
embedding_method = 'Abby' #None, Paul, Abby
freeze_embedding = True
Embedding_size = 4 # Size of embedding.
proj_embed = True
instance_size = 10
n_head = 2 # number of SA heads
n_layer = 1 # number of transformer blocks, applied one after the other
Head_size = 4 # size of the "single Attention head", which is the sum of the size of all multi Attention heads
eval_epochs_interval = 5 # number of epoch between each evaluation print of the model (no impact on results)
eval_batch_interval = 40
p_dropout = 0.3 # dropout probability in the model
masking_padding = True # do we include padding masking or not
loss_version = 'cross_entropy' #cross_entropy or focal_loss
gamma = 2
alpha = 63
##### training constants
total_epochs = 100 # number of epochs
learning_rate_max = 0.001 # maximum learning rate (at the end of the warmup phase)
learning_rate_ini = 0.00001 # initial learning rate 
learning_rate_final = 0.0001
warm_up_frac = 0.5 # fraction of the size of the warmup stage with regards to the total number of epochs.
# Learning-rate scheduler factors, expressed relative to the maximum learning rate.
start_factor_lr = learning_rate_ini / learning_rate_max
end_factor_lr = learning_rate_final / learning_rate_max

In [63]:
# Build the data-transformation object for the configured SNP.
# Every option comes from the configuration cell; `remove_none` used to be
# hard-coded to True here, which silently ignored the configuration constant.
dataT = DataTransfo_1SNP(SNP=SNP,
                         CHR=CHR,
                         method=pheno_method,
                         padding=padding,
                         pad_token=padding_token,
                         load_data=load_data,
                         save_data=save_data,
                         compute_features=compute_features,
                         prop_train_test=prop_train_test,
                         remove_none=remove_none,
                         equalize_label=equalize_label,
                         rollup_depth=rollup_depth,
                         decorelate=decorelate,
                         threshold_corr=threshold_corr,
                         threshold_rare=threshold_rare,
                         remove_rare=remove_rare,
                         list_env_features=list_env_features,
                         data_share=data_share)
#patient_list = dataT.get_patientlist()


In [67]:
from codes.models.data_form.DataSets import TabDictDataset
from torch.utils.data import DataLoader

In [68]:
# Tabular data as a dict of per-patient columns; the cells below read the
# 'diseases', 'counts' and 'age' entries.
dic_data= dataT.get_data_tabtransfo(actualise_phenos=True)


In [69]:
# Train/test split of the tabular data and construction of the dataloaders.
indices_train, indices_test = dataT.get_indices_train_test(nb_data=len(dic_data['diseases']))

# Convert each column to an array once, then slice it for both splits.
dic_data_arrays = {key: np.array(values) for key, values in dic_data.items()}
dic_data_train = {key: column[indices_train] for key, column in dic_data_arrays.items()}
dic_data_test = {key: column[indices_test] for key, column in dic_data_arrays.items()}

# Size of each embedding table = largest value + 1 (values are used as indices).
max_number_diseases = len(dataT.dicts['id'])
max_number_counts = np.max([np.max(patient_counts) for patient_counts in dic_data['counts']]) + 1
max_number_age = np.max(np.array(dic_data['age'])) + 1
max_number_sex = 2
dic_embedding_cat_params = {'diseases':max_number_diseases, 'counts':max_number_counts, 'age':max_number_age, 'sex':max_number_sex}

dataset_train = TabDictDataset(dic_data_train)
dataset_test = TabDictDataset(dic_data_test)

dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
# NOTE(review): the test loader is shuffled too, so the batch drawn with
# next(iter(dataloader_test)) differs from run to run.
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)

vocab_size = max_number_diseases
# NOTE(review): this is set to the number of diseases, not to max_number_counts — confirm.
max_count_same_disease = max_number_diseases

In [70]:
# Same table sizes as computed in the previous cell (re-declared here).
dic_embedding_cat_params = {'diseases':max_number_diseases, 'counts':max_number_counts, 'age':max_number_age, 'sex':max_number_sex}
# NOTE(review): EmbeddingPhenoCat only recognises 'SineCosine' and 'no_counts';
# with 'SineCos' the 'age' feature falls through to a learned nn.Embedding.
# Confirm whether 'SineCosine' was intended.
counts_method = {'age':'SineCos'}

In [121]:
# Embedding layer: settings collected in one dict, then expanded as keywords.
embedding_settings = {
    'pheno_method': pheno_method,
    'method': embedding_method,
    'proj_embed': proj_embed,
    'counts_method': counts_method,
    'Embedding_size': Embedding_size,
    'instance_size': instance_size,
    'rollup_depth': rollup_depth,
    'freeze_embed': freeze_embedding,
    'dic_embedding_cat_params': dic_embedding_cat_params,
    'dicts': dataT.dicts,
}
Embedding = EmbeddingPhenoCat(**embedding_settings)

# Transformer classifier built on top of the embedding layer, on CPU.
model_settings = {
    'pheno_method': pheno_method,
    'Embedding': Embedding,
    'instance_size': instance_size,
    'proj_embed': proj_embed,
    'list_env_features': list_env_features,
    'Head_size': Head_size,
    'binary_classes': True,
    'n_head': n_head,
    'n_layer': n_layer,
    'mask_padding': masking_padding,
    'padding_token': padding_token,
    'p_dropout': p_dropout,
    'device': 'cpu',
    'loss_version': loss_version,
    'gamma': gamma,
    'alpha': alpha,
}
model = TabTransformerGeneModel_V2(**model_settings)

In [122]:
# Draw a single batch from the test loader for manual inspection.
input_dict = next(iter(dataloader_test))

In [123]:
# Forward pass on that batch. Note that forward() pops 'SNP_label' from
# input_dict in place, so later cells see the batch without its labels.
model.forward(input_dict)

In [136]:
# Token-level padding mask of the first sample, as stored by the last forward pass.
model.padding_mask_probas[0]

In [137]:
# Same computation, also returning the embedded input, the attention weights
# and the pooling weights.
logits, loss, input_embedded, attention_probas, probas = model.forward_decomposed(input_dict)

In [142]:
attention_probas[0][0][2]

In [103]:
# NOTE(review): padding_mask is set to None in __init__ and never reassigned
# by the model class defined above.
model.padding_mask

In [86]:
logits

In [104]:

# One pass over the training loader that only accumulates the loss: the
# optimizer calls are commented out, so no weight is updated here.
start_time_epoch = time.time()
total_loss = 0.0  

#with tqdm(total=len(dataloader_train), position=0, leave=True) as pbar:
for k, input_dict in enumerate(dataloader_train):
    


    # evaluate the loss
    logits, loss = model(input_dict)
    #optimizer.zero_grad(set_to_none=True)
    #loss.backward()


    total_loss += loss.item()

    #optimizer.step()

   


In [33]:
# Exploratory cells; the execution counters show they were run out of order.
Embedding.dic_embedding_cat

In [35]:
Embedding(input_dict)

In [12]:
# NOTE(review): `x` is not defined in any visible cell, and forward() above
# expects a dict of tensors.
logits, loss = model.forward(x)

In [141]:
# NOTE(review): the model class above never assigns padding_mask after
# initialising it to None, so this indexing presumably targets an older version.
model.padding_mask[0].shape

In [49]:
# NOTE(review): the EmbeddingPhenoCat defined above keeps list_env_embedded as
# a local variable of forward(), not as an attribute.
torch.concat(Embedding.list_env_embedded).view(2, 89, 4).transpose(0, 1)[0][1]

In [50]:
Embedding.list_env_embedded[1][0]

In [12]:
import torch
from torch.utils.data import Dataset, DataLoader

class MyDictDataset(Dataset):
    """Dataset over a dict of equally indexed columns.

    Item ``i`` is a dict holding the ``i``-th element of every column; the
    dataset length is the length of the column named ``key``.
    """

    def __init__(self, data_dict, key):
        self.key = key
        self.keys = [*data_dict.keys()]
        self.data_dict = data_dict

    def __len__(self):
        reference_column = self.data_dict[self.key]
        return len(reference_column)

    def __getitem__(self, index):
        row = {}
        for name in self.keys:
            row[name] = self.data_dict[name][index]
        return row

# Small example dictionary: two columns of six values each.
huge_dict = {
    'a': [1, 2, 3, 4, 5, 6],
    'b': [1, 2, 3, 4, 5, 6],
}

# One dataset per reference key (both expose every column of the dict).
dataset_a, dataset_b = (MyDictDataset(huge_dict, key=name) for name in ('a', 'b'))

# One independently shuffled DataLoader per dataset.
# Note: this reassigns the notebook-level `batch_size`.
batch_size = 2
dataloader_a, dataloader_b = (
    DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for dataset in (dataset_a, dataset_b)
)

# Walk both loaders in lockstep and show the batches.
for batch_a, batch_b in zip(dataloader_a, dataloader_b):
    print("Batch for key 'a':", batch_a)
    print("Batch for key 'b':", batch_b)


In [10]:
import torch
from torch.utils.data import Dataset, DataLoader

class MyDictDataset(Dataset):
    """Dataset with one item per dictionary entry.

    Item ``i`` is ``{'key': k, 'value': data[k]}`` for the ``i``-th key of the
    dictionary, in insertion order.
    """

    def __init__(self, data):
        self.keys = [*data.keys()]
        self.data = data

    def __len__(self):
        number_of_entries = len(self.keys)
        return number_of_entries

    def __getitem__(self, index):
        selected_key = self.keys[index]
        return dict(key=selected_key, value=self.data[selected_key])

# Example dictionary: three keys, three values each.
my_data = {
    'a': [1, 2, 3],
    'b': [4, 5, 6],
    'c': [7, 8, 9],
}

# Wrap it in the per-key dataset and batch it with a shuffled DataLoader.
# Note: this reassigns the notebook-level `batch_size`.
my_dataset = MyDictDataset(my_data)
batch_size = 2
my_dataloader = DataLoader(my_dataset, shuffle=True, batch_size=batch_size)

# Show every batch.
for batch in my_dataloader:
    print(batch)


In [151]:
# Scratch cells. NOTE(review): `E` is not defined in any visible cell.
embedding = EmbeddingPheno(method=None, vocab_size=10, Embedding_size=E, max_count_same_disease=10)

In [152]:
# NOTE(review): TransformerGeneModel_V2 is neither defined nor imported in the
# visible cells; presumably an earlier model class — confirm or remove.
model =TransformerGeneModel_V2(pheno_method=None,Embedding=embedding, Head_size=5,  Classes_nb=2, n_head=1, n_layer=1)

In [22]:
# Shape check: append one extra token to a 3-token sequence.
u = torch.rand(4, 3, 6)
v = torch.rand(4, 6).view(4, 1, 6)

In [23]:
f = torch.concat([u, v], dim=1)

In [24]:
f.shape

In [153]:
data = torch.tensor([[1, 2, 3, 0, 0]])
counts = torch.tensor([1, 1, 1, 0, 0])
labels = torch.tensor(0)

In [154]:
# NOTE(review): this call (two positional inputs, three outputs) does not match
# forward_decomposed of TabTransformerGeneModel_V2 above, which takes one dict
# and returns five values.
logits, probas, attention_probas = model.forward_decomposed(data,counts)

In [155]:
probas

In [98]:
# Check of the pooling step: per-token logits (B, S, C) weighted by per-token
# weights (B, S) through a batched matrix product.
probas = torch.rand(2, 3,)

In [99]:
logits = torch.rand(2, 3, 2)

In [107]:
logt = logits.transpose(1, 2) 

In [108]:
logt.shape

In [112]:
probt = probas.view(2, 3, 1)

In [114]:
u = (logt @ probt)

In [128]:
n = logits[0][:,0]

In [131]:
# Manual weighted sum for sample 0, class 0; to be compared with `u` below.
torch.sum(n* probas[0])

In [133]:
u

In [156]:
logits = torch.rand(2, 3, 4)

In [157]:
logits.transpose(1, 2).shape