In [None]:
import sys
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)

import pandas as pd
import numpy as np 
import time
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR, LinearLR, SequentialLR
from sklearn.metrics import f1_score, accuracy_score


from codes.models.data_form.DataForm import DataTransfo_1SNP, PatientList
from codes.models.metrics import calculate_roc_auc, calculate_classification_report, calculate_loss, get_proba
from codes.models.utils import clear_last_line, print_file

import matplotlib.pyplot as plt

In [None]:
from codes.models.Generative.Embeddings import EmbeddingPheno, EmbeddingSNPS
from codes.models.Generative.GenerativeModel import GenerativeModelPheWasV1

In [None]:
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
class EmbeddingPheno(nn.Module):
    """Phenotype (disease-code) embedding table plus a phenotype cosine-similarity table.

    Args:
        method: None for a randomly initialised table, 'Abby' or 'Paul' to load
            pretrained weights from the group's embedding files.
        vocab_size: number of rows of the random table (method=None only).
        Embedding_size: width of the random table; replaced by the pretrained
            width when method is 'Abby' or 'Paul'.
        rollup_depth: selects which 'Paul' GloVe file is loaded.
        freeze_embed: freeze pretrained weights (not used when method is None).
        dicts: optional mapping with keys 'id', 'name', 'cat' and
            'diseases_present'. It is required for pretrained weights and for
            ``similarities_tab``.

    Attributes:
        similarities_tab: non-persistent buffer holding pairwise cosine
            similarities of the 'Abby' embeddings (padding row dropped), or None
            when ``dicts`` is not given. As a buffer it follows ``.to(device)``.

    Raises:
        ValueError: unknown ``method``, or pretrained weights requested without
            a ``dicts['diseases_present']`` entry.
    """

    _ABBY_EMBEDDING_FILE = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings/Abby/embedding_abby_no_1_diseases.pth'
    _PAUL_EMBEDDING_FILE = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings/Paul_Glove/glove_UKBB_omop_rollup_closest_depth_{}_no_1_diseases.pth'

    def __init__(self, method=None, vocab_size=None, Embedding_size=None, rollup_depth=4, freeze_embed=False, dicts=None):
        super(EmbeddingPheno, self).__init__()

        self.dicts = dicts
        self.rollup_depth = rollup_depth
        self.nb_distinct_diseases_patient = vocab_size
        self.Embedding_size = Embedding_size
        self.metadata = None

        diseases_present = None
        if self.dicts is not None:
            id_dict = self.dicts['id']
            name_dict = self.dicts['name']
            cat_dict = self.dicts['cat']
            diseases_present = self.dicts['diseases_present']
            # One [name, category] row per code, used as TensorBoard projector metadata.
            self.metadata = [[name_dict[code], cat_dict[code]] for code in id_dict.keys()]

        if method is None:
            self.distinct_diseases_embeddings = nn.Embedding(vocab_size, Embedding_size)
            torch.nn.init.normal_(self.distinct_diseases_embeddings.weight, mean=0.0, std=0.02)
        elif method in ('Abby', 'Paul'):
            if method == 'Abby':
                embedding_file_diseases = self._ABBY_EMBEDDING_FILE
            else:
                embedding_file_diseases = self._PAUL_EMBEDDING_FILE.format(self.rollup_depth)
            pretrained_weights_diseases = self._load_pretrained(embedding_file_diseases, diseases_present)
            self.Embedding_size = pretrained_weights_diseases.shape[1]
            self.distinct_diseases_embeddings = nn.Embedding.from_pretrained(pretrained_weights_diseases, freeze=freeze_embed)
        else:
            raise ValueError(f"Unknown phenotype embedding method: {method!r} (expected None, 'Abby' or 'Paul')")

        similarities_tab = None
        if diseases_present is not None:
            # Row 0 is dropped so that row/column k corresponds to phenotype id k + 1 (id 0 is padding).
            abby_weights = self._load_pretrained(self._ABBY_EMBEDDING_FILE, diseases_present)[1:]
            similarities_tab = self._cosine_similarity_table(abby_weights)
        # Non-persistent: moves with .to(device) but is not written to state_dict.
        self.register_buffer('similarities_tab', similarities_tab, persistent=False)

    @staticmethod
    def _load_pretrained(embedding_file, diseases_present):
        """Load a saved weight tensor on CPU and keep only the rows listed in ``diseases_present``."""
        if diseases_present is None:
            raise ValueError("dicts with a 'diseases_present' entry is required to load pretrained phenotype embeddings")
        return torch.load(embedding_file, map_location='cpu')[diseases_present]

    @staticmethod
    def _cosine_similarity_table(weights):
        """Return the (N, N) matrix of pairwise cosine similarities between the rows of ``weights``."""
        with torch.no_grad():
            normed = F.normalize(weights, p=2, dim=-1, eps=1e-8)
            return normed @ normed.transpose(0, 1)

    def write_embedding(self, writer):
        """Export the embedding matrix (with name/category metadata) to a TensorBoard ``writer``."""
        embedding_tensor = self.distinct_diseases_embeddings.weight.data.detach().cpu().numpy()
        writer.add_embedding(embedding_tensor, metadata=self.metadata, metadata_header=["Name","Label"])
class EmbeddingSNPS(nn.Module):
    """Embedding table for SNPs: two rows (one per allele class) for each of ``nb_SNPS`` SNPs.

    Args:
        method: only None (random normal init, std 0.02) is implemented.
        nb_SNPS: number of SNPs; the table has ``nb_SNPS * 2`` rows.
        Embedding_size: embedding width.
        freeze_embed: accepted for API symmetry with EmbeddingPheno; not used.

    Raises:
        ValueError: for any ``method`` other than None. Without this check the
            module would be created without ``SNPS_embeddings`` and fail later.
    """
    def __init__(self, method=None, nb_SNPS=1, Embedding_size=10, freeze_embed=False):
        super(EmbeddingSNPS, self).__init__()

        self.method = method
        self.Embedding_size = Embedding_size
        self.nb_SNPS = nb_SNPS

        if method is None:
            self.SNPS_embeddings = nn.Embedding(self.nb_SNPS*2, Embedding_size)
            torch.nn.init.normal_(self.SNPS_embeddings.weight, mean=0.0, std=0.02)
        else:
            raise ValueError(f"Unknown SNP embedding method: {method!r} (only None is supported)")
            


In [None]:
import sys
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)

import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
import time
from sklearn.metrics import f1_score, accuracy_score
from codes.models.metrics import calculate_roc_auc, calculate_classification_report, calculate_loss, get_proba


class PhenotypeEncoding(nn.Module):
    """Transformer encoder over a batch of phenotype (disease-token) sentences.

    Args:
        Embedding: module exposing ``Embedding_size`` and ``distinct_diseases_embeddings``.
        Head_size, n_head, n_layer, p_dropout: attention-block hyper-parameters.
        mask_padding: build a padding mask from ``padding_token`` on every forward.
        padding_token: token id treated as padding.
        device: device the inputs and the padding mask are moved to.
        instance_size: model width after the optional input projection.
        proj_embed: project embeddings to ``instance_size``; if False the width
            is the embedding width.
    """
    def __init__(self, Embedding, Head_size, n_head, n_layer, mask_padding=False, padding_token=None, p_dropout=0, device='cpu', instance_size=None, proj_embed=True):
        super().__init__()

        self.Embedding_size = Embedding.Embedding_size
        self.instance_size = instance_size
        self.proj_embed = proj_embed
        if not self.proj_embed:
            self.instance_size = self.Embedding_size
        if self.proj_embed:
            self.projection_embed = nn.Linear(self.Embedding_size, self.instance_size)
        self.blocks = PadMaskSequential(*[BlockPheno(self.instance_size, n_head=n_head, Head_size=Head_size, p_dropout=p_dropout) for _ in range(n_layer)])
        # Final layer norm. NOTE(review): defined (and saved in state_dict) but not applied in forward.
        self.ln_f = nn.LayerNorm(self.instance_size)
        self.Embedding = Embedding
        self.mask_padding = mask_padding
        self.padding_token = padding_token
        self.padding_mask = None
        self.padding_mask_probas = None
        self.device = device

        self.diseases_embedding_table = Embedding.distinct_diseases_embeddings

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-style init for linear layers: N(0, 0.02) weights, zero bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def create_padding_mask(self, diseases_sentence):
        """Build attention/padding masks for a (B, S) tensor of token ids.

        Returns:
            padding_mask_mat: bool (B, S, S) on ``self.device``. True where both
                query i and key j are real tokens (True = keep). Padding queries
                and padding keys are both excluded.
            padding_mask_probas: bool (B, S), True at padding positions. It is
                created on the same device as ``diseases_sentence``.
        """
        is_padding = diseases_sentence == self.padding_token
        is_token = ~is_padding
        # Outer AND: row i (query) and column j (key) must both be real tokens.
        padding_mask_mat = is_token.unsqueeze(2) & is_token.unsqueeze(1)
        padding_mask_probas = is_padding.clone()
        return padding_mask_mat.to(self.device), padding_mask_probas

    def set_padding_mask_transformer(self, padding_mask, padding_mask_probas):
        """Store the current batch's masks on the encoder."""
        self.padding_mask = padding_mask
        self.padding_mask_probas = padding_mask_probas

    def forward(self, diseases_sentence):
        """Encode a (B, S) batch of token ids into (B, S, instance_size) features."""
        diseases_sentence = diseases_sentence.to(self.device)

        if self.mask_padding:
            padding_mask, padding_mask_probas = self.create_padding_mask(diseases_sentence)
            self.set_padding_mask_transformer(padding_mask, padding_mask_probas)
            self.blocks.set_padding_mask_sequential(self.padding_mask)

        x = self.diseases_embedding_table(diseases_sentence)  # shape B, S, E

        if self.proj_embed:
            x = self.projection_embed(x)
        x = self.blocks(x)  # shape B, S, instance_size

        return x
   

class PadMaskSequential(nn.Sequential):
    """``nn.Sequential`` that hands one shared padding mask to every block before running it.

    Each contained module must implement ``set_padding_mask_block(mask)``.
    """
    def __init__(self, *args):
        super().__init__(*args)
        self.padding_mask = None

    def set_padding_mask_sequential(self, padding_mask):
        """Remember the mask that ``forward`` will pass to each block."""
        self.padding_mask = padding_mask

    def forward(self, x):
        shared_mask = self.padding_mask
        hidden = x
        for block in self:
            block.set_padding_mask_block(shared_mask)
            hidden = block(hidden)
        return hidden
   
class BlockPheno(nn.Module):
    """Pre-norm transformer block (self-attention then feed-forward, each with a residual)
    that forwards its stored padding mask to the attention layer."""
    def __init__(self, instance_size, n_head, Head_size, p_dropout):
        super().__init__()
        self.sa = MultiHeadSelfAttention(n_head, Head_size, instance_size, p_dropout)
        self.ffwd = FeedForward(instance_size, p_dropout)
        self.ln1 = nn.LayerNorm(instance_size)
        self.ln2 = nn.LayerNorm(instance_size)
        self.padding_mask = None

    def set_padding_mask_block(self, padding_mask):
        """Store the mask handed to the attention layer on the next forward."""
        self.padding_mask = padding_mask

    def forward(self, x):
        self.sa.set_padding_mask_sa(self.padding_mask)
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))
    
    

class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with an optional boolean padding mask.

    The mask given to ``set_padding_mask_sa`` is a (B, S, S) tensor where True
    means "query i may attend to key j". Masked keys are excluded *before* the
    softmax, so each query's attention still sums to 1 over its allowed keys.
    Queries with no allowed key produce a zero attention output.

    Raises:
        ValueError: if ``Head_size`` is not divisible by ``n_head``.
    """
    def __init__(self, n_head, Head_size, instance_size, p_dropout):
        super().__init__()
        if Head_size % n_head != 0:
            raise ValueError(f'Head_size ({Head_size}) must be divisible by n_head ({n_head})')
        # q, k and v are computed by a single linear layer and split afterwards.
        self.qkv_network = nn.Linear(instance_size, Head_size * 3, bias=False)
        self.proj = nn.Linear(Head_size, instance_size)
        self.attention_dropout = nn.Dropout(p_dropout)
        self.resid_dropout = nn.Dropout(p_dropout)

        self.multi_head_size = Head_size // n_head
        self.flash = False
        self.n_head = n_head
        self.Head_size = Head_size
        self.padding_mask = None

    def set_padding_mask_sa(self, padding_mask):
        """Store the (B, S, S) keep-mask used by the next forward (None disables masking)."""
        self.padding_mask = padding_mask

    def _keep_masks(self, device):
        """Return (keep, row_has_key), each broadcastable over heads, or (None, None) without a mask.

        keep: bool (B, 1, S, S) saying which query/key pairs may interact.
        row_has_key: bool (B, 1, S, 1) saying whether a query has any allowed key.
        """
        if self.padding_mask is None:
            return None, None
        keep = self.padding_mask.to(device=device, dtype=torch.bool).unsqueeze(1)
        row_has_key = keep.any(dim=-1, keepdim=True)
        return keep, row_has_key

    def forward(self, x):
        """Attend over a (B, S, instance_size) input and return a tensor of the same shape."""
        Batch_len, Sentence_len, _ = x.shape
        q, k, v = self.qkv_network(x).split(self.Head_size, dim=2)  # each (B, S, H)
        # Split the head dimension so every head attends independently: (B, HN, S, MH).
        q_multi_head = q.view(Batch_len, Sentence_len, self.n_head, self.multi_head_size).transpose(1, 2)
        k_multi_head = k.view(Batch_len, Sentence_len, self.n_head, self.multi_head_size).transpose(1, 2)
        v_multi_head = v.view(Batch_len, Sentence_len, self.n_head, self.multi_head_size).transpose(1, 2)

        keep, row_has_key = self._keep_masks(x.device)

        if self.flash:
            # Fully-masked query rows are unmasked here to avoid NaNs; their output is zeroed below.
            attn_mask = None if keep is None else (keep | ~row_has_key)
            out = F.scaled_dot_product_attention(
                q_multi_head, k_multi_head, v_multi_head, attn_mask=attn_mask,
                dropout_p=self.attention_dropout.p if self.training else 0.0, is_causal=False)
        else:
            attention_weights = (q_multi_head @ k_multi_head.transpose(-2, -1)) / np.sqrt(self.multi_head_size)  # B, HN, S, S
            if keep is not None:
                attention_weights = attention_weights.masked_fill(~keep, float('-inf'))
                # Rows with no allowed key would be all -inf (NaN after softmax); neutralise them here.
                attention_weights = attention_weights.masked_fill(~row_has_key, 0.0)
            attention_probas = F.softmax(attention_weights, dim=-1)
            attention_probas = self.attention_dropout(attention_probas)
            out = attention_probas @ v_multi_head  # B, HN, S, MH

        if keep is not None:
            out = out.masked_fill(~row_has_key, 0.0)
        out = out.transpose(1, 2).contiguous().view(Batch_len, Sentence_len, self.Head_size)
        out = self.proj(out)
        out = self.resid_dropout(out)
        return out
    
  
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d -> 4d), ReLU, Linear(4d -> d), then dropout."""
    def __init__(self, instance_size, p_dropout):
        super().__init__()
        hidden_size = 4 * instance_size
        layers = [
            nn.Linear(instance_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, instance_size),
            nn.Dropout(p_dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        transformed = self.net(x)
        return transformed
 

class SNPEncoding(nn.Module):
    """Transformer encoder over a batch of SNP sentences.

    Each input value is the class (0 or 1) of one SNP. It is shifted by
    ``2 * position`` so that SNP p, class c reads row ``2p + c`` of the
    ``nb_SNPS * 2``-row embedding table.

    Args:
        Embedding: module exposing ``Embedding_size`` and ``SNPS_embeddings``.
        Head_size, n_head, n_layer, p_dropout: attention-block hyper-parameters.
        device: stored for API symmetry with PhenotypeEncoding.
        padding_token: token id treated as padding by ``create_padding_mask``.
    """
    def __init__(self, Embedding, Head_size, n_head, n_layer, p_dropout=0, device='cpu', padding_token=None):
        super().__init__()

        self.Embedding_size = Embedding.Embedding_size
        self.instance_size = self.Embedding_size
        self.blocks = nn.Sequential(*[BlockSNP(self.instance_size, n_head=n_head, Head_size=Head_size, p_dropout=p_dropout) for _ in range(n_layer)])
        # Final layer norm. NOTE(review): defined (and saved in state_dict) but not applied in forward.
        self.ln_f = nn.LayerNorm(self.instance_size)
        self.Embedding = Embedding
        self.device = device
        self.padding_token = padding_token
        self.padding_mask = None
        self.padding_mask_probas = None

        self.SNPS_embedding_table = Embedding.SNPS_embeddings

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-style init for linear layers: N(0, 0.02) weights, zero bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def create_padding_mask(self, diseases_sentence):
        """Build masks for a (B, S) tensor of token ids (True = masked).

        Returns:
            padding_mask_mat: bool (B, S, S), True when query i or key j is padding.
            padding_mask_probas: bool (B, S), True at padding positions.
            Both are on the same device as the input.

        Raises:
            ValueError: if no ``padding_token`` was configured.
        """
        if self.padding_token is None:
            raise ValueError('SNPEncoding.create_padding_mask requires padding_token to be set')
        is_padding = diseases_sentence == self.padding_token
        padding_mask_mat = is_padding.unsqueeze(2) | is_padding.unsqueeze(1)
        return padding_mask_mat, is_padding.clone()

    def set_padding_mask_transformer(self, padding_mask, padding_mask_probas):
        """Store the current batch's masks on the encoder."""
        self.padding_mask = padding_mask
        self.padding_mask_probas = padding_mask_probas

    def forward(self, SNPS_sentence):
        """Encode a (B, Nb_SNP) tensor of SNP classes into (B, Nb_SNP, E) features."""
        _, Nb_SNP = SNPS_sentence.shape
        # Build the offsets on the input's device so that GPU batches do not mix devices.
        pos_SNPS = torch.arange(Nb_SNP, device=SNPS_sentence.device) * 2
        SNPS_sentence = SNPS_sentence + pos_SNPS  # shape B, Nb_SNP: row ids into the nb_SNPS*2 table
        SNP_sentences_embedded = self.SNPS_embedding_table(SNPS_sentence)  # shape B, Nb_SNP, E

        x = self.blocks(SNP_sentences_embedded)  # shape B, Nb_SNP, E

        return x
   

class BlockSNP(nn.Module):
    """Pre-norm transformer block for SNP tokens: residual self-attention followed by a residual MLP.

    ``padding_mask`` is stored for symmetry with BlockPheno but is not used here.
    """
    def __init__(self, instance_size, n_head, Head_size, p_dropout):
        super().__init__()
        self.sa = MultiHeadSelfAttention(n_head, Head_size, instance_size, p_dropout)
        self.ffwd = FeedForward(instance_size, p_dropout)
        self.ln1 = nn.LayerNorm(instance_size)
        self.ln2 = nn.LayerNorm(instance_size)
        self.padding_mask = None

    def forward(self, x):
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))


class CrossMultiAttentionSNPPheno(nn.Module):
    """Bidirectional multi-head cross-attention between phenotype and SNP tokens.

    Keys come from the phenotypes and queries from the SNPs. The single score
    matrix (B, HN, S_PHENO, S_SNPS) is softmax-normalised along each axis.
    Phenotypes aggregate SNP values, SNPs aggregate phenotype values, and each
    stream adds its own value projection as a residual.

    Raises:
        ValueError: if ``Head_size`` is not divisible by ``n_head``.
    """
    def __init__(self, n_head, Head_size, instance_size_pheno, instance_size_SNPS, p_dropout):
        super().__init__()
        if Head_size % n_head != 0:
            raise ValueError(f'Head_size ({Head_size}) must be divisible by n_head ({n_head})')
        # Queries are computed from the SNP encodings, so they take SNP-width inputs.
        self.queries_network = nn.Linear(instance_size_SNPS, Head_size, bias=False)
        self.keys_network = nn.Linear(instance_size_pheno, Head_size, bias=False)
        self.values_network_SNP = nn.Linear(instance_size_SNPS, Head_size, bias=False)
        self.values_network_pheno = nn.Linear(instance_size_pheno, Head_size, bias=False)

        self.attention_dropout = nn.Dropout(p_dropout)
        self.resid_dropout = nn.Dropout(p_dropout)

        self.multi_head_size = Head_size // n_head
        self.n_head = n_head
        self.Head_size = Head_size
        self.padding_mask = None

    def set_padding_mask_sa(self, padding_mask):
        """Store a bool mask, True = masked, broadcastable to (B, HN, S_PHENO, S_SNPS).

        NOTE(review): this setter is never called in this notebook; confirm the
        intended mask shape before relying on it.
        """
        self.padding_mask = padding_mask

    def forward(self, pheno_encoded, SNPS_encoded):
        """Cross-attend (B, S_PHENO, E_p) phenotypes with (B, S_SNPS, E_s) SNPs.

        Returns:
            (out_pheno, out_SNPS) of shapes (B, S_PHENO, Head_size) and (B, S_SNPS, Head_size).
        """
        Batch_len, Sentence_len_pheno, _ = pheno_encoded.shape
        _, Sentence_len_SNPS, _ = SNPS_encoded.shape
        keys = self.keys_network(pheno_encoded)
        queries = self.queries_network(SNPS_encoded)

        values_pheno = self.values_network_pheno(pheno_encoded)
        values_SNPS = self.values_network_SNP(SNPS_encoded)

        # Split heads: (B, HN, S, MH).
        q_multi_head = queries.view(Batch_len, Sentence_len_SNPS, self.n_head, self.multi_head_size).transpose(1, 2)
        k_multi_head = keys.view(Batch_len, Sentence_len_pheno, self.n_head, self.multi_head_size).transpose(1, 2)
        values_pheno_multihead = values_pheno.view(Batch_len, Sentence_len_pheno, self.n_head, self.multi_head_size).transpose(1, 2)
        values_SNPS_multihead = values_SNPS.view(Batch_len, Sentence_len_SNPS, self.n_head, self.multi_head_size).transpose(1, 2)
        attention_weights = (k_multi_head @ q_multi_head.transpose(-2, -1)) / np.sqrt(self.multi_head_size)  # B, HN, S_PHENO, S_SNPS

        if self.padding_mask is not None:
            masked = self.padding_mask.to(device=attention_weights.device, dtype=torch.bool)
            attention_weights = attention_weights.masked_fill(masked, float('-inf'))

        attention_probas_phenos = F.softmax(attention_weights, dim=-1)  # each phenotype over SNPs
        attention_probas_SNPS = F.softmax(attention_weights.transpose(-1, -2), dim=-1)  # each SNP over phenotypes
        # A fully-masked row softmaxes to NaN; treat it as attending to nothing.
        attention_probas_phenos = torch.nan_to_num(attention_probas_phenos, nan=0.0)
        attention_probas_SNPS = torch.nan_to_num(attention_probas_SNPS, nan=0.0)

        attention_probas_phenos = self.attention_dropout(attention_probas_phenos)
        attention_probas_SNPS = self.attention_dropout(attention_probas_SNPS)

        out_pheno = attention_probas_phenos @ values_SNPS_multihead  # B, HN, S_PHENO, MH
        out_SNPS = attention_probas_SNPS @ values_pheno_multihead  # B, HN, S_SNPS, MH

        out_pheno = out_pheno.transpose(1, 2).contiguous().view(Batch_len, Sentence_len_pheno, self.Head_size) + values_pheno
        out_SNPS = out_SNPS.transpose(1, 2).contiguous().view(Batch_len, Sentence_len_SNPS, self.Head_size) + values_SNPS

        out_pheno = self.resid_dropout(out_pheno)
        out_SNPS = self.resid_dropout(out_SNPS)
        return out_pheno, out_SNPS
    
class BlockCrossSNPPHENO(nn.Module):
    """One cross-attention layer between the phenotype and SNP streams.

    Both streams are layer-normalised and mixed by ``CrossMultiAttentionSNPPheno``.
    Each stream is then projected back to its own width, normalised again, and
    passed through a residual feed-forward sub-layer.
    """
    def __init__(self, instance_size_SNPS, instance_size_pheno, n_head, Head_size, p_dropout):
        super().__init__()
        self.ca = CrossMultiAttentionSNPPheno(n_head=n_head, Head_size=Head_size,
                                              instance_size_SNPS=instance_size_SNPS,
                                              instance_size_pheno=instance_size_pheno,
                                              p_dropout=p_dropout)
        self.ffwd_pheno = FeedForward(instance_size_pheno, p_dropout)
        self.ffwd_SNPS = FeedForward(instance_size_SNPS, p_dropout)
        self.ln1_pheno = nn.LayerNorm(instance_size_pheno)
        self.ln1_SNPS = nn.LayerNorm(instance_size_SNPS)
        self.proj_pheno = nn.Linear(Head_size, instance_size_pheno)
        self.proj_SNPS = nn.Linear(Head_size, instance_size_SNPS)
        self.padding_mask = None  # stored for API symmetry; forward does not use it

    @staticmethod
    def _post_attention(stream, proj, norm, ffwd):
        """Project an attention output to model width, normalise it, then add a feed-forward residual."""
        hidden = norm(proj(stream))
        return hidden + ffwd(hidden)

    def forward(self, encoded_phenos, encoded_SNPS):
        mixed_pheno, mixed_SNPS = self.ca(self.ln1_pheno(encoded_phenos), self.ln1_SNPS(encoded_SNPS))
        # NOTE(review): ln1_* is reused after the projection (same weights as the pre-norm); confirm this sharing is intended.
        out_pheno = self._post_attention(mixed_pheno, self.proj_pheno, self.ln1_pheno, self.ffwd_pheno)
        out_SNPS = self._post_attention(mixed_SNPS, self.proj_SNPS, self.ln1_SNPS, self.ffwd_SNPS)
        return out_pheno, out_SNPS
    
class CrossPadMaskSequential(nn.Sequential):
    """``nn.Sequential`` for two-stream blocks: each block maps (phenos, SNPs) to (phenos, SNPs)."""
    def __init__(self, *args):
        super().__init__(*args)
        self.padding_mask = None

    def set_padding_mask_sequential(self, padding_mask):
        """Store a padding mask. NOTE(review): it is not forwarded to the contained blocks."""
        self.padding_mask = padding_mask

    def forward(self, encoded_phenos, encoded_SNPS):
        phenos, snps = encoded_phenos, encoded_SNPS
        for block in self:
            phenos, snps = block(phenos, snps)
        return phenos, snps
   
class PredictLogit(nn.Module):
    """Output heads: a layer norm plus a linear head per stream, averaged over the sequence axis.

    ``value == 'pheno'`` uses the phenotype head (``nb_phenos_possible`` classes).
    Any other value uses the 2-class SNP head.
    """
    def __init__(self, instance_size_SNPS, instance_size_pheno, nb_phenos_possible):
        super().__init__()
        self.ln2_phenos = nn.LayerNorm(instance_size_pheno)
        self.ln2_SNPS = nn.LayerNorm(instance_size_SNPS)
        self.get_logits_phenos = nn.Linear(instance_size_pheno, nb_phenos_possible)
        self.get_logits_SNPS = nn.Linear(instance_size_SNPS, 2)

    def forward(self, pheno_sentence, SNPS_sentence, value):
        if value == 'pheno':
            norm, head, hidden = self.ln2_phenos, self.get_logits_phenos, pheno_sentence
        else:
            norm, head, hidden = self.ln2_SNPS, self.get_logits_SNPS, SNPS_sentence
        per_token_logits = head(norm(hidden))
        return per_token_logits.mean(dim=1)

class GenerativeModelPheWasV1(nn.Module):
    """Generative PheWAS model: phenotype encoder, SNP encoder, cross-attention, logit heads.

    ``forward(..., value='pheno')`` scores the masked phenotype with
    ``nb_phenos_possible - 1`` logits. The padding id 0 has no logit, so logit
    column k scores phenotype id k + 1 (see ``calcul_loss_pheno``).
    ``value='SNP'`` scores the 2 SNP classes.
    """
    def __init__(self, n_head_pheno, Head_size_pheno, Embedding_pheno, Embedding_SNPS, instance_size_pheno,
                n_layer_pheno,  nb_SNPS, n_layer_SNPS, n_head_SNPS, Head_size_SNPS, instance_size_SNPS, nb_phenos_possible,
                n_head_cross, Head_size_cross, n_layer_cross, p_dropout, device='cpu', mask_padding=True,
                loss_version_pheno='cross_entropy', loss_version_SNPS='cross_entropy', gamma=2, alpha=1, padding_token=0):
        super().__init__()
        self.Embedding_pheno = Embedding_pheno
        self.Embedding_SNPS = Embedding_SNPS
        self.Embedding_size_pheno = Embedding_pheno.Embedding_size
        self.Embedding_size_SNP = Embedding_SNPS.Embedding_size

        self.instance_size_pheno = instance_size_pheno
        self.n_head_pheno = n_head_pheno
        self.Head_size_pheno = Head_size_pheno
        self.n_layer_pheno = n_layer_pheno
        self.loss_version_pheno = loss_version_pheno

        self.Head_size_SNPS = Head_size_SNPS
        self.n_layer_SNPS = n_layer_SNPS
        self.nb_SNPS = nb_SNPS
        self.n_head_SNPS = n_head_SNPS
        self.instance_size_SNPS = instance_size_SNPS
        self.loss_version_SNPS = loss_version_SNPS

        # Focal-loss parameters (used when loss_version_SNPS == 'focal_loss').
        self.gamma = gamma
        self.alpha = alpha

        self.n_layer_cross = n_layer_cross
        self.Head_size_cross = Head_size_cross
        self.n_head_cross = n_head_cross
        self.p_dropout = p_dropout

        self.nb_phenos_possible = nb_phenos_possible
        self.device = device
        self.padding_token = padding_token

        self.encoding_phenos = PhenotypeEncoding(Embedding=Embedding_pheno, Head_size=Head_size_pheno,
            n_head=n_head_pheno, n_layer=n_layer_pheno, instance_size=instance_size_pheno, device=device, mask_padding=mask_padding,
            p_dropout=p_dropout, padding_token=self.padding_token)
        # The SNP encoder depth is controlled by n_layer_SNPS.
        self.encoding_SNPS = SNPEncoding(Embedding=Embedding_SNPS, Head_size=Head_size_SNPS, n_head=n_head_SNPS,
                    device=device, n_layer=n_layer_SNPS, p_dropout=p_dropout)
        self.blocks = CrossPadMaskSequential(*[BlockCrossSNPPHENO(n_head=n_head_cross, Head_size=Head_size_cross,
                                             instance_size_SNPS=instance_size_SNPS,
                                             instance_size_pheno=instance_size_pheno,
                                             p_dropout=p_dropout) for _ in range(n_layer_cross)])

        # nb_phenos_possible - 1: the padding id gets no output logit.
        self.predict_logit = PredictLogit(instance_size_pheno=instance_size_pheno, instance_size_SNPS=instance_size_SNPS, nb_phenos_possible=nb_phenos_possible-1)

    def forward(self, diseases_sentence, SNPS_sentence, value, targets=None):
        """Run both encoders and the cross-attention stack.

        Args:
            diseases_sentence: phenotype token ids, (B, S_PHENO).
            SNPS_sentence: SNP classes, (B, S_SNPS).
            value: 'pheno' for phenotype logits, 'SNP' for SNP logits.
            targets: optional labels; when given, the matching loss is computed.

        Returns:
            (logits, loss); loss is None without targets.
        """
        diseases_sentence = diseases_sentence.to(self.device)
        SNPS_sentence = SNPS_sentence.to(self.device)

        phenotype_encoded = self.encoding_phenos(diseases_sentence)
        SNPS_encoded = self.encoding_SNPS(SNPS_sentence)

        out_pheno, out_SNPS = self.blocks(phenotype_encoded, SNPS_encoded)
        logits = self.predict_logit(out_pheno, out_SNPS, value)

        loss = None
        if targets is not None:
            targets = targets.to(self.device)
            if value == 'pheno':
                loss = self.calcul_loss_pheno(logits, targets, loss_version=self.loss_version_pheno)
            elif value == 'SNP':
                loss = self.calcul_loss_SNPS(logits, targets, loss_version=self.loss_version_SNPS, gamma=self.gamma, alpha=self.alpha)

        return logits, loss

    def calcul_loss_SNPS(self, logits, targets=None, loss_version='cross_entropy', gamma=None, alpha=None):
        """SNP-classification loss on (B, 2) logits and (B,) integer targets.

        loss_version:
            'cross_entropy': mean cross-entropy.
            'focal_loss': sum over the batch of (1 - p_true)**gamma * alpha_t * CE,
                where alpha_t is ``alpha`` for class 0 and 1 for class 1.
            'predictions': mean squared error between hard predictions (p1 > 0.5) and targets.

        Returns None when ``targets`` is None.

        Raises:
            ValueError: unknown ``loss_version``.
        """
        if targets is None:
            return None
        if loss_version == 'cross_entropy':
            return F.cross_entropy(logits, targets)
        if loss_version == 'focal_loss':
            alphas = ((1 - targets) * (alpha - 1)).to(torch.float) + 1
            probas = F.softmax(logits, dim=1)
            proba_true = probas[torch.arange(probas.shape[0], device=probas.device), targets]
            certitude_factor = (1 - proba_true) ** gamma * alphas
            cross_entropy = F.cross_entropy(logits, targets, reduction='none')
            return torch.dot(cross_entropy, certitude_factor)
        if loss_version == 'predictions':
            probas = F.softmax(logits, dim=1)
            predictions = (probas[:, 1] > 0.5).to(int)
            return torch.sum((predictions - targets) ** 2) / len(predictions)
        raise ValueError(f'Unknown SNP loss_version: {loss_version!r}')

    def predict_pheno(self, pheno_sentence, SNPS_sentences):
        """Return the argmax phenotype logit column, (B,), computed in eval mode without gradients.

        Column k corresponds to phenotype id k + 1. The module's previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits, _ = self.forward(pheno_sentence, SNPS_sentences, value='pheno')
        finally:
            self.train(was_training)
        return torch.argmax(logits, axis=1)

    def calcul_loss_pheno(self, logits, targets=None, loss_version='cross_entropy'):
        """Phenotype loss: cross-entropy against embedding-similarity soft targets.

        Target id t selects row t - 1 of ``Embedding_pheno.similarities_tab``
        (that table has no padding row). ``loss_version`` is accepted but not used.
        NOTE(review): the soft-target rows are raw cosine similarities (they can
        be negative and do not sum to 1); confirm this is intended.

        Raises:
            ValueError: if the embedding has no similarity table.
        """
        if targets is None:
            return None
        similarities_tab = self.Embedding_pheno.similarities_tab
        if similarities_tab is None:
            raise ValueError('Embedding_pheno.similarities_tab is required for the phenotype loss (build EmbeddingPheno with dicts)')
        soft_targets = similarities_tab.to(logits.device)[targets - 1]
        return F.cross_entropy(logits, soft_targets)

    def evaluate(self, dataloader_test):
        """Evaluate phenotype prediction on ``dataloader_test``.

        The loader yields (pheno sentences, pheno targets, SNP sentences).
        Predicted labels are converted from logit columns to phenotype ids
        (column + 1) so that they are comparable with the targets.

        Returns:
            (f1, accuracy, auc_score, mean batch loss, proba_avg_zero,
             proba_avg_one, predicted_probas_list, true_labels_list).
            auc_score is currently fixed at 0.
        """
        print('beginning inference evaluation')
        start_time_inference = time.time()
        predicted_labels_list = []
        predicted_probas_list = []
        true_labels_list = []

        total_loss = 0.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            for batch_sentences_pheno, batch_labels_pheno, batch_sentences_SNPS in dataloader_test:
                logits, loss = self(batch_sentences_pheno, batch_sentences_SNPS, value='pheno', targets=batch_labels_pheno)
                total_loss += loss.item()
                # Reuse the logits of this (deterministic, eval-mode) pass; +1 maps columns to phenotype ids.
                predicted_labels = torch.argmax(logits, dim=1) + 1
                predicted_labels_list.extend(predicted_labels.cpu().numpy())
                predicted_probas = F.softmax(logits, dim=1)
                predicted_probas_list.extend(predicted_probas.cpu().numpy())
                true_labels_list.extend(batch_labels_pheno.cpu().numpy())
        f1 = f1_score(true_labels_list, predicted_labels_list, average='macro')
        accuracy = accuracy_score(true_labels_list, predicted_labels_list)
        auc_score = 0  # calculate_roc_auc(true_labels_list, np.array(predicted_probas_list)[:, 1], return_nan=True)
        proba_avg_zero, proba_avg_one = get_proba(true_labels_list, predicted_probas_list)

        self.train(was_training)
        print(f'end inference evaluation in {time.time() - start_time_inference}s')
        return f1, accuracy, auc_score, total_loss/len(dataloader_test), proba_avg_zero, proba_avg_one, predicted_probas_list, true_labels_list


        
            
        
            

In [None]:
### data constants:
CHR = 1
SNP = 'rs673604'
pheno_method = 'Abby' # Paul, Abby
rollup_depth = 4
Classes_nb = 2 #nb of classes related to an SNP (here 0 or 1)
vocab_size = None # to be defined with data
padding_token = 0
prop_train_test = 0.8
load_data = True
save_data = False
remove_none = True
decorelate = False
equalize_label = False
threshold_corr = 0.9
threshold_rare = 50
remove_rare = 'all' # None, 'all', 'one_class'
compute_features = True
padding = False
list_env_features = ['age', 'sex']
### data format
batch_size = 20
data_share = 1/1000

# Build the dataset for one SNP; every option above is forwarded as-is
# (remove_none included, so editing the constant actually takes effect).
dataT = DataTransfo_1SNP(SNP=SNP,
                         CHR=CHR,
                         method=pheno_method,
                         padding=padding,
                         pad_token=padding_token,
                         load_data=load_data,
                         save_data=save_data,
                         compute_features=compute_features,
                         prop_train_test=prop_train_test,
                         remove_none=remove_none,
                         equalize_label=equalize_label,
                         rollup_depth=rollup_depth,
                         decorelate=decorelate,
                         threshold_corr=threshold_corr,
                         threshold_rare=threshold_rare,
                         remove_rare=remove_rare,
                         list_env_features=list_env_features,
                         data_share=data_share)
patient_list = dataT.get_patientlist()


In [None]:
# Remove padding from each patient's disease sentence (presumably what PatientList.unpad_data does; confirm).
# The masked sentences are re-padded to a common length in the padding cell below.
patient_list.unpad_data()

In [None]:
# Total number of distinct disease tokens. The next cell reuses this value as the [MASK] token id.
nb_phenos = patient_list.get_nb_distinct_diseases_tot()

In [None]:
# Build one training example per (patient, disease) pair: copy the patient's
# sentence, replace that disease with the [MASK] id (nb_phenos), and keep the
# replaced disease id as the target and the patient's SNP label alongside it.
nb_patients_used = 5  # NOTE(review): only the first 5 patients are used; set to None to use all patients
list_pheno_truth = []
list_labels = []
list_diseases_sentence_masked = []
for patient in patient_list[:nb_patients_used]:
    diseases_sentence = torch.tensor(patient.diseases_sentence)
    nb_diseases = len(diseases_sentence)
    # Diagonal mask: row i of the tiled matrix has disease i masked.
    masks = np.zeros((nb_diseases, nb_diseases)).astype(bool)
    np.fill_diagonal(masks, True)
    diseases_sentence_masked = np.tile(diseases_sentence, nb_diseases).reshape(nb_diseases, nb_diseases)
    pheno_Truth = diseases_sentence_masked[masks]
    labels = [np.array([patient.SNP_label])]*nb_diseases
    diseases_sentence_masked[masks] = nb_phenos

    list_pheno_truth.extend(pheno_Truth)
    list_labels.extend(labels)
    list_diseases_sentence_masked.extend(diseases_sentence_masked)

In [None]:
# Inspection cell: displays the patient list.
patient_list

In [None]:
## padding the data: right-pad every masked sentence with padding_token up to the
## longest patient sentence, so that all examples can be batched together.
list_diseases_new = []
nb_max_distinct_diseases_patient = patient_list.get_nb_max_distinct_diseases_patient()
for list_diseases in list_diseases_sentence_masked:
    padd = np.full(nb_max_distinct_diseases_patient - len(list_diseases), padding_token, dtype=int)
    list_diseases_new.append(np.concatenate([list_diseases, padd]).astype(int))
list_diseases_sentence_masked = list_diseases_new

In [None]:
# One (masked sentence, masked-out disease id, SNP label) tuple per example.
list_data_gen = list(zip(list_diseases_sentence_masked, list_pheno_truth, list_labels))


In [None]:
# Reproducible train/test split: shuffle example indices with a fixed seed,
# then take the first prop_train_test fraction for training.
split_seed = 0
split_rng = np.random.default_rng(split_seed)
indices = split_rng.permutation(len(list_data_gen))
nb_train = int(prop_train_test * len(list_data_gen))
indices_train = indices[:nb_train]
indices_test = indices[nb_train:]

In [None]:
# Materialise the train and test example lists from the shuffled index split.
data_training = [list_data_gen[i] for i in indices_train]
data_test = [list_data_gen[i] for i in indices_test]


In [None]:
# Shuffle only the training data; the test loader keeps a fixed order so that
# evaluation outputs (probabilities, labels) are reproducible and aligned across runs.
dataloader_training = DataLoader(data_training, batch_size=batch_size, shuffle=True)
dataloader_test = DataLoader(data_test, batch_size=batch_size, shuffle=False)

In [None]:
#### phenotype encoder
rollup_depth = 4
Head_size_pheno = 4
n_head_pheno = 2
n_layer_pheno = 2
instance_size_pheno = 10
Embedding_size_pheno = 10
embedding_method_pheno = None
proj_embed_pheno = False  # NOTE(review): not passed to the model below (PhenotypeEncoding defaults to proj_embed=True)
freeze_embed_pheno = False
loss_version_pheno = 'cross_entropy'
p_dropout = 0.1
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pheno_method = 'Abby'

#### SNP encoder
embedding_method_SNPS = None
freeze_embed_SNPS = False
nb_SNPS = 2
Embedding_size_SNPS = 10
n_head_SNPS = 2
Head_size_SNPS = 4
loss_version_SNPS = 'cross_entropy'
n_layer_SNPS = 2
instance_size_SNPS = 10
mask_padding = True

#### cross-attention
n_head_cross = 2
Head_size_cross = 4
n_layer_cross = 2
instance_size_cross = 10

#### vocabulary
nb_phenos_possible = patient_list.get_nb_distinct_diseases_tot()
nb_phenos = nb_phenos_possible
vocab_size = nb_phenos_possible + 1  # +1 for the [MASK] token id (nb_phenos)

##### training constants
total_epochs = 100  # number of epochs
learning_rate_max = 0.001  # maximum learning rate (at the end of the warmup phase)
learning_rate_ini = 0.00001  # initial learning rate
learning_rate_final = 0.0001
warm_up_frac = 0.5  # fraction of the size of the warmup stage with regards to the total number of epochs.
start_factor_lr = learning_rate_ini / learning_rate_max
end_factor_lr = learning_rate_final / learning_rate_max
warm_up_size = int(warm_up_frac*total_epochs)
padding_masking = True  # NOTE(review): not referenced in the cells shown; mask_padding is what the model receives

eval_batch_interval = 40
eval_epochs_interval = 1

In [None]:
# Phenotype (disease) embedding; `dataT.dicts` supplies the per-code name / category metadata.
Embedding_pheno = EmbeddingPheno(
    method=embedding_method_pheno,
    vocab_size=vocab_size,
    Embedding_size=Embedding_size_pheno,
    rollup_depth=rollup_depth,
    freeze_embed=freeze_embed_pheno,
    dicts=dataT.dicts,
)

# SNP embedding.
Embedding_SNPS = EmbeddingSNPS(
    method=embedding_method_SNPS,
    nb_SNPS=nb_SNPS,
    Embedding_size=Embedding_size_SNPS,
    freeze_embed=freeze_embed_SNPS,
)
    



In [None]:

model = GenerativeModelPheWasV1(n_head_pheno=n_head_pheno, Head_size_pheno=Head_size_pheno, Embedding_pheno=Embedding_pheno, Embedding_SNPS=Embedding_SNPS,
    instance_size_pheno=instance_size_pheno, n_layer_pheno=n_layer_pheno,  nb_SNPS=nb_SNPS, n_layer_SNPS=n_layer_SNPS, n_head_SNPS=n_head_SNPS, mask_padding=mask_padding,
    Head_size_SNPS=Head_size_SNPS, instance_size_SNPS=instance_size_SNPS, nb_phenos_possible=nb_phenos_possible,
    n_head_cross=n_head_cross, Head_size_cross=Head_size_cross, n_layer_cross=n_layer_cross, p_dropout=p_dropout, device=device,
    loss_version_pheno=loss_version_pheno, loss_version_SNPS=loss_version_SNPS, gamma=2, alpha=1, padding_token=padding_token)
# The training loop moves every batch to `device`, so the parameters must live there too
# (otherwise a CUDA run fails with a device-mismatch error).
model.to(device)



In [None]:

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate_max)
# Linear warm-up from learning_rate_ini to learning_rate_max over `warm_up_size` epochs ...
lr_scheduler_warm_up = LinearLR(optimizer, start_factor=start_factor_lr, end_factor=1, total_iters=warm_up_size)
# ... then linear decay to learning_rate_final over the remaining epochs.
lr_scheduler_final = LinearLR(optimizer, start_factor=1, end_factor=end_factor_lr, total_iters=total_epochs - warm_up_size)
# `verbose` is omitted: it defaulted to False and passing it explicitly is deprecated in recent PyTorch.
lr_scheduler = SequentialLR(optimizer, schedulers=[lr_scheduler_warm_up, lr_scheduler_final], milestones=[warm_up_size])

In [None]:
# Training Loop (progress printed to the notebook)
start_time_training = time.time()
print(f'Beginning of the program for {total_epochs} epochs')
for epoch in range(1, total_epochs+1):
    # model.evaluate() may switch the model to eval mode (dropout off); make sure every epoch
    # trains in training mode.
    model.train()
    start_time_epoch = time.time()
    total_loss = 0.0

    # Batches are (masked disease sentences, masked-out true disease, SNP label) -- see list_data_gen.
    for k, (batch_sentences_pheno, batch_labels_pheno, batch_sentences_SNPS) in enumerate(dataloader_training):
        start_time_batch = time.time()

        batch_sentences_pheno = batch_sentences_pheno.to(device)
        batch_labels_pheno = batch_labels_pheno.to(device)
        batch_sentences_SNPS = batch_sentences_SNPS.to(device)

        # forward pass + loss, then one optimisation step
        logits, loss = model(batch_sentences_pheno, batch_sentences_SNPS,value='pheno', targets= batch_labels_pheno)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        total_loss += loss.item()

        optimizer.step()

        if k % eval_batch_interval == 0:
            print( f'Progress in epoch {epoch}  = {round(k / len(dataloader_training)*100, 2)} %, time batch : {time.time() - start_time_batch}')

    if epoch % eval_epochs_interval == 0:
        f1_val, accuracy_val, auc_score_val, loss_val, proba_avg_zero_val, proba_avg_one_val, predicted_probas_list_val, true_labels_list_val = model.evaluate(dataloader_test)
        print( f'epoch {epoch}, time epoch : {time.time() - start_time_epoch}')

    lr_scheduler.step()


In [None]:
# Inspect the model's decomposed forward pass on the last batch left over from the training loop.
# NOTE(review): the arguments here are (pheno, pheno targets, SNPS), whereas training calls
# model(pheno, SNPS, targets=...) -- confirm that forward_decomposed expects this order.
model.forward_decomposed(batch_sentences_pheno,
        batch_labels_pheno,
        batch_sentences_SNPS)

In [None]:

# Log files for the interactive training run below.
log_dir = os.path.join('/gpfs/commons/groups/gursoy_lab/mstoll/codes', 'logs')
output_file = os.path.join(log_dir, 'TestGene_output.txt')
error_file = os.path.join(log_dir, 'TestGene_error.txt')


In [None]:
# Create the log file (and its directory) before print_file / clear_last_line use it.
# `os.makefile` does not exist, so the original line raised AttributeError.
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, 'a'):  # 'a' creates the file if missing without truncating an existing log
    pass

In [None]:
# Training Loop (progress written to `output_file`)
start_time_training = time.time()
print_file(output_file, f'Beginning of the program for {total_epochs} epochs')
for epoch in range(1, total_epochs+1):
    # model.evaluate() may switch the model to eval mode (dropout off); make sure every epoch
    # trains in training mode.
    model.train()
    start_time_epoch = time.time()
    total_loss = 0.0

    # Batches are (masked disease sentences, masked-out true disease, SNP label) -- see list_data_gen.
    for k, (batch_sentences_pheno, batch_labels_pheno, batch_sentences_SNPS) in enumerate(dataloader_training):
        start_time_batch = time.time()

        batch_sentences_pheno = batch_sentences_pheno.to(device)
        batch_labels_pheno = batch_labels_pheno.to(device)
        batch_sentences_SNPS = batch_sentences_SNPS.to(device)

        # forward pass + loss, then one optimisation step
        logits, loss = model(batch_sentences_pheno, batch_sentences_SNPS,value='pheno', targets= batch_labels_pheno)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        total_loss += loss.item()

        optimizer.step()

        if k % eval_batch_interval == 0:
            # overwrite the previous progress line instead of appending a new one
            clear_last_line(output_file)
            print_file(output_file, f'Progress in epoch {epoch}  = {round(k / len(dataloader_training)*100, 2)} %, time batch : {time.time() - start_time_batch}')

    if epoch % eval_epochs_interval == 0:
        f1_val, accuracy_val, auc_score_val, loss_val, proba_avg_zero_val, proba_avg_one_val, predicted_probas_list_val, true_labels_list_val = model.evaluate(dataloader_test)
        print_file(output_file,  f'epoch {epoch}, time epoch : {time.time() - start_time_epoch}')

    lr_scheduler.step()


In [None]:
# Inspect the last phenotype batch left over from the training loop.
batch_sentences_pheno

In [None]:
# NOTE(review): ad-hoc forward call. `SNPS_sentence` is not defined in this part of the notebook,
# and value='f' differs from the value='pheno' used during training -- confirm the intended mode.
model.forward(diseases_sentence, SNPS_sentence, 'f',targets=torch.tensor([0,1]) )

In [None]:
# Encode both modalities separately; the results feed the cross-attention cells below.
# Names fixed: `SNP_sentence` / `SNP_encoded` were undefined -- the SNP sentence is `SNPS_sentence`
# (as in the forward probe above) and the encoder output is stored in `SNPS_encoded`.
pheno_encoded = phenotype_encoding.forward(diseases_sentence)
SNPS_encoded = SNP_encoding.forward(SNPS_sentence)
pheno_encoded.shape, SNPS_encoded.shape

In [None]:
# Run a cross-attention module on the encoded phenotypes and SNPs.
# NOTE(review): `crossattention` is not defined in this part of the notebook.
out = crossattention.forward(pheno_encoded=pheno_encoded, SNPS_encoded=SNPS_encoded, value='phenos')

In [None]:
# Shape of the cross-attention output.
out.shape

In [None]:
# Manual re-implementation of multi-head cross-attention: every phenotype position attends over
# the SNP positions (softmax over the SNP axis) and aggregates the SNP values.
# NOTE(review): the phenotype projection is called "keys" but plays the query role here.
keys_network = nn.Linear(instance_size_pheno, Head_size_cross, bias = False)
values_network_SNPS = nn.Linear(instance_size_SNPS, Head_size_cross, bias = False)
values_network_pheno= nn.Linear(instance_size_pheno, Head_size_cross, bias = False)

queries_network = nn.Linear(instance_size_SNPS, Head_size_cross, bias = False)
# per-head size; the .view calls below require Head_size_cross to be divisible by n_head_cross
multi_head_size = Head_size_cross // n_head_cross

Batch_len, Sentence_len_pheno, _ = pheno_encoded.shape
Batch_len, Sentence_len_SNPS, _ = SNPS_encoded.shape
Sentence_len_out = Sentence_len_pheno 
keys = keys_network(pheno_encoded)
queries = queries_network(SNPS_encoded) 
values_pheno = values_network_pheno(pheno_encoded)
values_SNPS = values_network_SNPS(SNPS_encoded)

# add a dimension to compute the different attention heads separately
q_multi_head = queries.view(Batch_len, Sentence_len_SNPS, n_head_cross, multi_head_size).transpose(1, 2) #shape B, HN, S_SNPS, MH
k_multi_head = keys.view(Batch_len, Sentence_len_pheno, n_head_cross, multi_head_size).transpose(1, 2)#shape B, HN, S_PHENO, MH
values_SNPS_multihead = values_SNPS.view(Batch_len, Sentence_len_SNPS, n_head_cross, multi_head_size).transpose(1, 2)  # shape B, HN, S_SNPS, MH
values_pheno_multihead = values_pheno.view(Batch_len, Sentence_len_pheno, n_head_cross, multi_head_size).transpose(1, 2)  # shape B, HN, S_PHENO, MH


# scaled dot-product scores
attention_weights = (k_multi_head @ q_multi_head.transpose(-2, -1))/np.sqrt(multi_head_size) # shape B, HN, S_PHENO, S_SNPS
attention_probas = F.softmax(attention_weights, dim=-1)
out = attention_probas @ values_SNPS_multihead # shape B, HN, S_PHENO, S_SNPS @ B, HN, S_SNPS, MH = B, HN, S_PHENO, MH
# merge the heads back: B, S_PHENO, Head_size_cross
out = out.transpose(2, 1).contiguous().view(Batch_len, Sentence_len_pheno, Head_size_cross)

In [None]:
# Compare the per-head phenotype values with the attention probabilities computed above.
values_pheno_multihead.shape, attention_probas.shape

In [None]:
# Shape with the last two axes swapped (B, HN, S_SNPS, S_PHENO).
attention_probas.transpose(-1, -2).shape

In [None]:
# Recompute the unscaled attention scores step by step for inspection.
a = q_multi_head.transpose(-2, -1)
u = k_multi_head @a


In [None]:
# Operand shapes of the score product above.
a.shape, k_multi_head.shape

In [None]:
attention_probas = F.softmax(attention_weights, dim=-1) # shape B, HN, S_PHENO, S_SNPS (softmax over the SNP axis)


In [None]:
# Shape of the scaled attention scores.
attention_weights.shape

In [None]:
# Shape of the scores with the head and phenotype axes swapped (B, S_PHENO, HN, S_SNPS).
# Fixed the `..shape` syntax error.
attention_weights.transpose(1, 2).shape

In [None]:
# NOTE(review): reconstructed from a broken fragment (`contiguous()` was called on nothing).
# Re-assemble the per-head attention output into (B, S_PHENO, Head_size_cross), as done in the
# manual cross-attention cell.
(attention_probas @ values_SNPS_multihead).transpose(1, 2).contiguous().view(Batch_len, Sentence_len_out, Head_size_cross)

In [None]:
## padding mask #####
# NOTE(review): scratch copy of an attention forward pass; `padding_mask`, `attention_dropout`,
# `v_multi_head`, `Head_size`, `proj` and `resid_dropout` are not defined in this notebook section.
if padding_mask is not None:
    # Masked scores must become a very large negative number so softmax gives them ~0 weight.
    # The previous value `1**-10` evaluates to 1.0, which left padded positions effectively unmasked.
    # finfo.min (rather than -inf) keeps a fully masked row from turning into NaN after softmax.
    attention_weights = attention_weights.masked_fill(padding_mask, torch.finfo(attention_weights.dtype).min)
attention_probas = F.softmax(attention_weights, dim=-1)
attention_probas = attention_dropout(attention_probas)

## weighted aggregation of the values, then merge heads, project and apply residual dropout
out = attention_probas @ v_multi_head
out = out.transpose(1, 2).contiguous().view(Batch_len, Sentence_len_out, Head_size)
out = proj(out)
out = resid_dropout(out)

In [None]:

    def predict(self, diseases_sentence, diseases_count):
        """Predict one class index per sample.

        Calls the model on ``(diseases_sentence, diseases_count)``, takes the logits (first
        element of the returned pair) and returns the arg-max along dim 1.
        """
        class_logits, _ = self(diseases_sentence, diseases_count)  # shape B, Classes_nb
        predicted_classes = class_logits.argmax(dim=1)
        return predicted_classes  # shape B
        
    def predict_proba(self, diseases_sentence, diseases_count):
        """Return class probabilities: the softmax over dim 1 of the model logits."""
        class_logits, _ = self(diseases_sentence, diseases_count)
        return class_logits.softmax(dim=1)

In [None]:
# NOTE(review): fragment pasted from a model's forward method. It references `self`, `Batch_len`,
# `Sentence_len` and `targets`, and its indented lines make this cell raise IndentationError if
# executed; it is kept only as a reference.
x = self.ln_f(x) # shape B, S, E
        logits = self.lm_head_logits(x) #shape B, S, Classes_Numb, token logits
        # one scalar score per position, used below to weight the per-position logits
        weights_logits = self.lm_head_proba(x).view(Batch_len, Sentence_len)
        # masked (presumably padded) positions get -inf so they receive zero weight after softmax
        # NOTE(review): a row whose positions are all masked becomes all -inf and softmax returns NaN.
        weights_logits[self.padding_mask_probas] = -torch.inf
        # NOTE(review): F.softmax without `dim` is deprecated; for this 2-D tensor PyTorch implicitly
        # uses dim=1 (the position axis) -- pass dim=-1 explicitly when moving this back into the model.
        probas = F.softmax(weights_logits) # shape B, S(represent the probas to be chosen)
        logits = (logits.transpose(1, 2) @ probas.view(Batch_len, Sentence_len, 1)).view(Batch_len, self.Classes_nb)# (B,Classes_Numb) Weighted Average logits
        loss = calculate_loss(logits, targets, self.loss_version, self.gamma, self.alpha)
        return logits, loss

In [None]:
# Pre-computed 'Abby' disease embeddings used by the similarity experiments below.
# NOTE(review): torch.load unpickles the file -- only load files you trust.
embedding_file_diseases = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings/Abby/embedding_abby_no_1_diseases.pth'
x = torch.load(embedding_file_diseases)
# target 1 = "these pairs should be similar" for F.cosine_embedding_loss
target = torch.tensor([1])

In [None]:
# Shape of the loaded embedding tensor.
x.shape

In [None]:
# Sanity check: each row is compared with itself with target 1, so the loss is expected to be ~0.
F.cosine_embedding_loss(x, x, target)

In [None]:
# Cosine similarity of every row of `x` with row 1, as a (1, nb_rows) tensor.
# `-1` replaces the hard-coded vocabulary size 1718 so any embedding size works.
similarities = F.cosine_similarity(x, x[1], dim=-1).view(1, -1)

In [None]:
# Two example class targets (ids 1 and 2, int64).
targets = torch.arange(1, 3)

In [None]:
# Grab the phenotype nn.Embedding module.
# NOTE(review): this rebinds `x` (previously the loaded embedding tensor) to a module, so cells that
# use `x.shape` / `x[i]` afterwards will fail.
x = Embedding_pheno.distinct_diseases_embeddings

In [None]:
# Display the embedding module.
x

In [None]:
# (duplicate display of `x`)
x

In [None]:
# Replace every -1 entry of the last batch with 0, in place, so it is a valid embedding index.
is_minus_one = batch_sentences_pheno.eq(-1)
batch_sentences_pheno.masked_fill_(is_minus_one, 0);

In [None]:
# Embed the batch ids (nn.Embedding rejects negative ids, hence the -1 -> 0 replacement above).
x(batch_sentences_pheno)

In [None]:
# Pairwise cosine similarity between all rows of `x`, kept in the original (nb_rows, 1, nb_rows) shape.
# Normalising once and using a single matrix product replaces the per-row Python loop, the numpy
# round-trip and the hard-coded vocabulary size 1718.
# NOTE(review): requires `x` to be a 2-D tensor (not the nn.Embedding module bound in a later cell).
x_normed = F.normalize(x.detach(), dim=-1)
similarities_tab = (x_normed @ x_normed.T).unsqueeze(1)

In [None]:
# NOTE(review): `l` is not defined anywhere visible in this notebook section, so this cell raises
# NameError -- presumably a similarity table indexed by the class targets was intended; confirm.
u = l[targets]

In [None]:
# NOTE(review): uses whatever `logits` currently holds; the random (2, 1718) logits are only defined
# in a later cell, so this cell depends on out-of-order execution.
F.cross_entropy(logits, targets)

In [None]:
# Random stand-in logits: 2 samples x 1718 classes, uniform in [0, 1).
logits = torch.rand(2, 1718)

In [None]:
# Shape of the row-1 similarity vector.
similarities.shape

In [None]:
# (duplicate of the previous cell)
similarities.shape

In [None]:
import sys
# Make the project root importable so the `codes.models...` imports below resolve.
# NOTE(review): `path` is reassigned further down to the `codes/` logs root.
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)

import os
import pandas as pd
import numpy as np 
import time
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR, LinearLR, SequentialLR
from sklearn.metrics import f1_score, accuracy_score
from torch.utils.tensorboard import SummaryWriter


from codes.models.data_form.DataForm import DataTransfo_1SNP, PatientList
from codes.models.metrics import calculate_roc_auc, calculate_classification_report, calculate_loss, get_proba
from codes.models.Generative.Embeddings import EmbeddingPheno, EmbeddingSNPS
from codes.models.Generative.GenerativeModel import GenerativeModelPheWasV1
from codes.models.utils import print_file, plot_infos, plot_ini_infos, clear_last_line
from sklearn.metrics import f1_score, accuracy_score


import matplotlib.pyplot as plt


### data constants:
# experiment identification
model_type = 'Generative_Transformer'
model_version = 'V1'
test_name = 'tests_generative_1'
CHR, SNP = 1, 'rs673604'
pheno_method = 'Abby'  # Paul, Abby
rollup_depth = 4
Classes_nb = 2  # nb of classes related to an SNP (here 0 or 1)
vocab_size = None  # to be defined with data

# dataset construction options (forwarded to DataTransfo_1SNP)
padding, padding_token = False, 0
prop_train_test = 0.8
load_data, save_data = True, False
remove_none = True
decorelate = False
equalize_label = False
threshold_corr = 0.9
threshold_rare = 50
remove_rare = 'all'  # None, 'all', 'one_class'
compute_features = True
list_env_features = ['age', 'sex']

### data format
batch_size = 20
data_share = 1 / 10000

dataT = DataTransfo_1SNP(SNP=SNP,
                         CHR=CHR,
                         method=pheno_method,
                         padding=padding,  
                         pad_token=padding_token, 
                         load_data=load_data, 
                         save_data=save_data, 
                         compute_features=compute_features,
                         prop_train_test=prop_train_test,
                         remove_none=remove_none,  # was hard-coded True, silently ignoring the constant
                         equalize_label=equalize_label,
                         rollup_depth=rollup_depth,
                         decorelate=decorelate,
                         threshold_corr=threshold_corr,
                         threshold_rare=threshold_rare,
                         remove_rare=remove_rare, 
                         list_env_features=list_env_features,
                         data_share=data_share)
patient_list = dataT.get_patientlist()
# Padding is re-applied manually (with a fixed max length) after the masked samples are generated.
patient_list.unpad_data()


# ---- phenotype encoder ----------------------------------------------------------
pheno_method = 'Abby'
rollup_depth = 4
embedding_method_pheno = None
Embedding_size_pheno = 10
proj_embed_pheno = False
freeze_embed_pheno = False
n_head_pheno, Head_size_pheno = 2, 4
n_layer_pheno = 2
instance_size_pheno = 10
loss_version_pheno = 'cross_entropy'

# ---- SNP encoder ----------------------------------------------------------------
nb_SNPS = 2
embedding_method_SNPS = None
Embedding_size_SNPS = 10
freeze_embed_SNPS = False
n_head_SNPS, Head_size_SNPS = 2, 4
n_layer_SNPS = 2
instance_size_SNPS = 10
loss_version_SNPS = 'cross_entropy'

# ---- cross attention (multi) ------------------------------------------------------
n_head_cross, Head_size_cross = 2, 4
n_layer_cross = 2
instance_size_cross = 10

# ---- shared model settings ----------------------------------------------------------
p_dropout = 0.1
mask_padding = True
padding_masking = True
device = 'cpu'

# ---- vocabulary: every distinct phenotype, plus one extra id (= nb_phenos) used as mask token ----
nb_phenos_possible = patient_list.get_nb_distinct_diseases_tot()
nb_phenos = nb_phenos_possible
vocab_size = nb_phenos_possible + 1

# ---- training schedule -------------------------------------------------------------
total_epochs = 10
learning_rate_ini = 0.00001   # learning rate at the start of the warm-up
learning_rate_max = 0.001     # learning rate reached at the end of the warm-up
learning_rate_final = 0.0001  # learning rate reached at the last epoch
warm_up_frac = 0.5            # share of the epochs spent in the warm-up stage
warm_up_size = int(warm_up_frac * total_epochs)
start_factor_lr = learning_rate_ini / learning_rate_max
end_factor_lr = learning_rate_final / learning_rate_max

eval_batch_interval = 40
eval_epochs_interval = 1

#################### generate the output files and dirs ############################################
path = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/'
model_dir = path + f'logs/runs/SNPS/{str(CHR)}/{SNP}/{model_type}/{model_version}/{pheno_method}'
model_plot_dir = path + f'logs/plots/tests/SNP/{str(CHR)}/{SNP}/{model_type}/{model_version}/{pheno_method}/'

os.makedirs(model_dir, exist_ok=True)
os.makedirs(model_plot_dir, exist_ok=True)

test_dir = f'{model_dir}/{test_name}/'
print(test_dir)
# test_dir already ends with '/', so sub-directories are appended without another separator
# (the previous f'{test_dir}/data/' form produced '//' in every path).
log_data_dir = f'{test_dir}data/'
log_tensorboard_dir = f'{test_dir}tensorboard/'
log_slurm_outputs_dir = f'{test_dir}Slurm/Outputs/'
log_slurm_errors_dir = f'{test_dir}Slurm/Errors/'
os.makedirs(log_data_dir, exist_ok=True)
os.makedirs(log_tensorboard_dir, exist_ok=True)
os.makedirs(log_slurm_outputs_dir, exist_ok=True)
os.makedirs(log_slurm_errors_dir, exist_ok=True)

# File paths are derived from the directories above instead of repeating the templates.
log_data_path_pickle = f'{log_data_dir}{test_name}.pkl'
log_tensorboard_path = f'{log_tensorboard_dir}{test_name}'
log_slurm_outputs_path = f'{log_slurm_outputs_dir}{test_name}.txt'
log_slurm_error_path = f'{log_slurm_errors_dir}{test_name}.txt'
model_plot_path = f'{model_plot_dir}{test_name}.png'

############ generate the masked list of diseases #############################################
# Every patient with n diseases yields n samples: the disease sentence with position i replaced by
# the mask token `nb_phenos`, the true disease at position i (target) and the patient's SNP label.
start_time = time.time()
print('generating the data files')
list_pheno_truth = []
list_labels = []
list_diseases_sentence_masked = []
list_patient_index = []  # source patient of every generated sample, used for the split below
for patient_index, patient in enumerate(patient_list):
    diseases_sentence = torch.tensor(patient.diseases_sentence)
    nb_diseases = len(diseases_sentence)
    masks = np.zeros((nb_diseases, nb_diseases)).astype(bool)
    np.fill_diagonal(masks,True)  # row i masks position i
    diseases_sentence_masked = np.tile(diseases_sentence, nb_diseases).reshape(nb_diseases, nb_diseases)
    pheno_Truth = diseases_sentence_masked[masks]
    labels = [np.array([patient.SNP_label])]*nb_diseases
    diseases_sentence_masked[masks] = nb_phenos

    list_pheno_truth.extend(pheno_Truth)
    list_labels.extend(labels)
    list_diseases_sentence_masked.extend(diseases_sentence_masked)
    list_patient_index.extend([patient_index] * nb_diseases)
print(f'generated files in {time.time() - start_time} seconds')

################################### padding the data ###################################################
# Right-pad with `padding_token` (the id the model is told to treat as padding) up to the longest sentence.
list_diseases_new = []
nb_max_distinct_diseases_patient = patient_list.get_nb_max_distinct_diseases_patient()
for list_diseases in list_diseases_sentence_masked:
    padd = np.full(nb_max_distinct_diseases_patient - len(list_diseases), padding_token, dtype=int)
    list_diseases_new.append(np.concatenate([list_diseases, padd]).astype(int))
list_diseases_sentence_masked = list_diseases_new

################################### train / test split ###################################################
# Split by PATIENT, not by sample: the samples of one patient share all but one disease and the SNP
# label, so a sample-level split would leak test patients into the training set.
list_data_gen = list(zip(list_diseases_sentence_masked, list_pheno_truth, list_labels))
rng_split = np.random.default_rng(0)  # fixed seed -> reproducible split
patient_index_per_sample = np.asarray(list_patient_index, dtype=int)
patient_ids = np.unique(patient_index_per_sample)
patients_train = rng_split.permutation(patient_ids)[:int(prop_train_test * len(patient_ids))]
is_train_sample = np.isin(patient_index_per_sample, patients_train)
indices_train = np.flatnonzero(is_train_sample)
indices_test = np.flatnonzero(~is_train_sample)

data_training = [list_data_gen[i] for i in indices_train]
data_test = [list_data_gen[i] for i in indices_test]

dataloader_train = DataLoader(data_training, batch_size=batch_size, shuffle=True)
# evaluation metrics do not depend on sample order, so the test loader is not shuffled
dataloader_test = DataLoader(data_test, batch_size=batch_size, shuffle=False)



Embedding_pheno = EmbeddingPheno(method=embedding_method_pheno, vocab_size=vocab_size, Embedding_size=Embedding_size_pheno,
     rollup_depth=rollup_depth, freeze_embed=freeze_embed_pheno, dicts=dataT.dicts)

Embedding_SNPS = EmbeddingSNPS(method=embedding_method_SNPS, nb_SNPS=nb_SNPS, Embedding_size=Embedding_size_SNPS, freeze_embed=freeze_embed_SNPS)


model = GenerativeModelPheWasV1(n_head_pheno=n_head_pheno, Head_size_pheno=Head_size_pheno, Embedding_pheno=Embedding_pheno, Embedding_SNPS=Embedding_SNPS,
    instance_size_pheno=instance_size_pheno, n_layer_pheno=n_layer_pheno,  nb_SNPS=nb_SNPS, n_layer_SNPS=n_layer_SNPS, n_head_SNPS=n_head_SNPS, mask_padding=mask_padding,
    Head_size_SNPS=Head_size_SNPS, instance_size_SNPS=instance_size_SNPS, nb_phenos_possible=nb_phenos_possible,
    n_head_cross=n_head_cross, Head_size_cross=Head_size_cross, n_layer_cross=n_layer_cross, p_dropout=p_dropout, device=device,
    loss_version_pheno=loss_version_pheno, loss_version_SNPS=loss_version_SNPS, gamma=2, alpha=1, padding_token=padding_token)
# The training loop moves every batch to `device`; keep the parameters there too.
model.to(device)



optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate_max)
# Linear warm-up from learning_rate_ini to learning_rate_max over `warm_up_size` epochs ...
lr_scheduler_warm_up = LinearLR(optimizer, start_factor=start_factor_lr, end_factor=1, total_iters=warm_up_size)
# ... then linear decay to learning_rate_final over the remaining epochs.
lr_scheduler_final = LinearLR(optimizer, start_factor=1, end_factor=end_factor_lr, total_iters=total_epochs - warm_up_size)
# `verbose` is omitted: it defaulted to False and passing it explicitly is deprecated in recent PyTorch.
lr_scheduler = SequentialLR(optimizer, schedulers=[lr_scheduler_warm_up, lr_scheduler_final], milestones=[warm_up_size])


######################################################## Training Loop ###################################################
output_file = log_slurm_outputs_path
## Open tensor board writer
writer = SummaryWriter(log_tensorboard_path)

# metric histories filled by plot_ini_infos / plot_infos
dic_features_list = {
'list_training_loss' : [],
'list_validation_loss' : [],
'list_proba_avg_zero' : [],
'list_proba_avg_one' : [],
'list_auc_validation' : [],
'list_accuracy_validation' : [],
'list_f1_validation' : [],
'epochs' : [] }

start_time_training = time.time()
print_file(output_file, f'Beginning of the program for {total_epochs} epochs', new_line=True)
plot_ini_infos(model, output_file, dataloader_test, dataloader_train, writer, dic_features_list)
for epoch in range(1, total_epochs+1):
    # plot_ini_infos / plot_infos evaluate the model and may leave it in eval mode (dropout off);
    # make sure every epoch trains in training mode.
    model.train()
    start_time_epoch = time.time()
    total_loss = 0.0

    # Batches are (masked disease sentences, masked-out true disease, SNP label) -- see list_data_gen.
    for k, (batch_sentences_pheno, batch_labels_pheno, batch_sentences_SNPS) in enumerate(dataloader_train):
        start_time_batch = time.time()

        batch_sentences_pheno = batch_sentences_pheno.to(device)
        batch_labels_pheno = batch_labels_pheno.to(device)
        batch_sentences_SNPS = batch_sentences_SNPS.to(device)

        # forward pass + loss, then one optimisation step
        logits, loss = model(batch_sentences_pheno, batch_sentences_SNPS,value='pheno', targets= batch_labels_pheno)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        total_loss += loss.item()

        optimizer.step()

        if k % eval_batch_interval == 0:
            clear_last_line(output_file)
            print_file(output_file, f'Progress in epoch {epoch}  = {round(k / len(dataloader_train)*100, 2)} %, time batch : {time.time() - start_time_epoch}', new_line=False)

    if epoch % eval_epochs_interval == 0:
        # `dic_features` (the last returned metrics) is logged with the hyper-parameters afterwards
        dic_features = plot_infos(model, output_file, epoch, total_loss, start_time_epoch, dataloader_train, dataloader_test, optimizer, writer, dic_features_list, model_plot_path)

    lr_scheduler.step()

model.to('cpu')
#model.write_embedding(writer)
print_file(output_file, f"Training finished: {int(time.time() - start_time_training)} seconds", new_line=True)
start_time = time.time()





## Add hyper parameters to tensorboard
hyperparams = {"CHR" : CHR, "SNP" : SNP, "ROLLUP LEVEL" : rollup_depth,
            'PHENO_METHOD': pheno_method, 'EMBEDDING_METHOD': embedding_method_pheno,
            'EMBEDDING SIZE' : Embedding_size_pheno, 'ATTENTION HEADS' : n_head_pheno, 'BLOCKS' : n_layer_pheno,
            'LR': learning_rate_max, 'DROPOUT' : p_dropout, 'NUM_EPOCHS' : total_epochs,  # LR was a placeholder `1`
            'BATCH_SIZE' : batch_size, 
            'PADDING_MASKING': padding_masking,
            'VERSION' : model_version,
            'NB_Patients'  : len(patient_list),
            'LOSS_VERSION'  : loss_version_pheno,
            }

# `dic_features` is the metric dict returned by the last plot_infos call of the training loop.
writer.add_hparams(hyperparams, dic_features)
writer.close()  # flush pending TensorBoard events to disk




In [None]:
# Which entries of the last patient's disease sentence equal 0 (the padding token).
diseases_sentence == 0

In [None]:
# NOTE(review): scratch re-run of the padding loop; the sentences are already padded at this point,
# nothing is appended, and only the last `padd` survives.
for list_diseases in list_diseases_sentence_masked:
    padd = np.zeros(nb_max_distinct_diseases_patient- len(list_diseases), dtype=int)



In [None]:
# Last sentence visited by the loop above.
list_diseases

In [None]:
# Padding length that would be needed for that sentence.
nb_max_distinct_diseases_patient- len(list_diseases)

In [None]:
############# test.py file
import sys
# Make the project root importable so the `codes.models...` imports below resolve.
# NOTE(review): `path` is reassigned further down to the `codes/` logs root.
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)

import os
import pandas as pd
import numpy as np 
import time
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR, LinearLR, SequentialLR
from sklearn.metrics import f1_score, accuracy_score
from torch.utils.tensorboard import SummaryWriter


from codes.models.data_form.DataForm import DataTransfo_1SNP, PatientList
from codes.models.metrics import calculate_roc_auc, calculate_classification_report, calculate_loss, get_proba
from codes.models.Generative.Embeddings import EmbeddingPheno, EmbeddingSNPS
from codes.models.Generative.GenerativeModel import GenerativeModelPheWasV1
from codes.models.utils import print_file, plot_infos, plot_ini_infos, clear_last_line
from sklearn.metrics import f1_score, accuracy_score


import matplotlib.pyplot as plt

### data constants:
# experiment identification
model_type = 'Generative_Transformer'
model_version = 'V1'
test_name = 'tests_generative_1'
CHR, SNP = 1, 'rs673604'
pheno_method = 'Abby'  # Paul, Abby
rollup_depth = 4
Classes_nb = 2  # nb of classes related to an SNP (here 0 or 1)
vocab_size = None  # to be defined with data

# dataset construction options (forwarded to DataTransfo_1SNP)
padding, padding_token = False, 0
prop_train_test = 0.8
load_data, save_data = True, False
remove_none = True
decorelate = False
equalize_label = False
threshold_corr = 0.9
threshold_rare = 50
remove_rare = 'all'  # None, 'all', 'one_class'
compute_features = True
list_env_features = ['age', 'sex']

### data format
batch_size = 200
data_share = 1

dataT = DataTransfo_1SNP(SNP=SNP,
                         CHR=CHR,
                         method=pheno_method,
                         padding=padding,  
                         pad_token=padding_token, 
                         load_data=load_data, 
                         save_data=save_data, 
                         compute_features=compute_features,
                         prop_train_test=prop_train_test,
                         remove_none=remove_none,  # was hard-coded True, silently ignoring the constant
                         equalize_label=equalize_label,
                         rollup_depth=rollup_depth,
                         decorelate=decorelate,
                         threshold_corr=threshold_corr,
                         threshold_rare=threshold_rare,
                         remove_rare=remove_rare, 
                         list_env_features=list_env_features,
                         data_share=data_share)
patient_list = dataT.get_patientlist()
# Padding is re-applied manually (with a fixed max length) after the masked samples are generated.
patient_list.unpad_data()


# ---- phenotype encoder ----------------------------------------------------------
pheno_method = 'Abby'
rollup_depth = 4
embedding_method_pheno = None
Embedding_size_pheno = 10
proj_embed_pheno = False
freeze_embed_pheno = False
n_head_pheno, Head_size_pheno = 2, 4
n_layer_pheno = 2
instance_size_pheno = 10
loss_version_pheno = 'cross_entropy'

# ---- SNP encoder ----------------------------------------------------------------
nb_SNPS = 2
embedding_method_SNPS = None
Embedding_size_SNPS = 10
freeze_embed_SNPS = False
n_head_SNPS, Head_size_SNPS = 2, 4
n_layer_SNPS = 2
instance_size_SNPS = 10
loss_version_SNPS = 'cross_entropy'

# ---- cross attention (multi) ------------------------------------------------------
n_head_cross, Head_size_cross = 2, 4
n_layer_cross = 2
instance_size_cross = 10

# ---- shared model settings ----------------------------------------------------------
p_dropout = 0.1
mask_padding = True
padding_masking = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---- vocabulary: every distinct phenotype, plus one extra id (= nb_phenos) used as mask token ----
nb_phenos_possible = patient_list.get_nb_distinct_diseases_tot()
nb_phenos = nb_phenos_possible
vocab_size = nb_phenos_possible + 1

# ---- training schedule -------------------------------------------------------------
total_epochs = 100
learning_rate_ini = 0.00001   # learning rate at the start of the warm-up
learning_rate_max = 0.001     # learning rate reached at the end of the warm-up
learning_rate_final = 0.0001  # learning rate reached at the last epoch
warm_up_frac = 0.5            # share of the epochs spent in the warm-up stage
warm_up_size = int(warm_up_frac * total_epochs)
start_factor_lr = learning_rate_ini / learning_rate_max
end_factor_lr = learning_rate_final / learning_rate_max

eval_batch_interval = 40
eval_epochs_interval = 1

#################### generate the output files and dirs ############################################
path = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/'
model_dir = path + f'logs/runs/SNPS/{str(CHR)}/{SNP}/{model_type}/{model_version}/{pheno_method}'
model_plot_dir = path + f'logs/plots/tests/SNP/{str(CHR)}/{SNP}/{model_type}/{model_version}/{pheno_method}/'

os.makedirs(model_dir, exist_ok=True)
os.makedirs(model_plot_dir, exist_ok=True)

test_dir = f'{model_dir}/{test_name}/'
print(test_dir)
# test_dir already ends with '/', so sub-directories are appended without another separator
# (the previous f'{test_dir}/data/' form produced '//' in every path).
log_data_dir = f'{test_dir}data/'
log_tensorboard_dir = f'{test_dir}tensorboard/'
log_slurm_outputs_dir = f'{test_dir}Slurm/Outputs/'
log_slurm_errors_dir = f'{test_dir}Slurm/Errors/'
os.makedirs(log_data_dir, exist_ok=True)
os.makedirs(log_tensorboard_dir, exist_ok=True)
os.makedirs(log_slurm_outputs_dir, exist_ok=True)
os.makedirs(log_slurm_errors_dir, exist_ok=True)

# File paths are derived from the directories above instead of repeating the templates.
log_data_path_pickle = f'{log_data_dir}{test_name}.pkl'
log_tensorboard_path = f'{log_tensorboard_dir}{test_name}'
log_slurm_outputs_path = f'{log_slurm_outputs_dir}{test_name}.txt'
log_slurm_error_path = f'{log_slurm_errors_dir}{test_name}.txt'
model_plot_path = f'{model_plot_dir}{test_name}.png'

# Send this script's stderr to the error log. The original `sys.stdrerr = <path>` was a typo that set
# an unused attribute; assigning the path string to sys.stderr would also break every stderr write.
# The file stays open for the lifetime of the process; line buffering keeps it up to date.
sys.stderr = open(log_slurm_error_path, 'a', buffering=1)

############ generate the masked list of diseases #############################################
# Every patient with n diseases yields n samples: the disease sentence with position i replaced by
# the mask token `nb_phenos`, the true disease at position i (target) and the patient's SNP label.
start_time = time.time()
print('generating the data files')
list_pheno_truth = []
list_labels = []
list_diseases_sentence_masked = []
list_patient_index = []  # source patient of every generated sample, used for the split below
for patient_index, patient in enumerate(patient_list):
    diseases_sentence = torch.tensor(patient.diseases_sentence)
    nb_diseases = len(diseases_sentence)
    masks = np.zeros((nb_diseases, nb_diseases)).astype(bool)
    np.fill_diagonal(masks,True)  # row i masks position i
    diseases_sentence_masked = np.tile(diseases_sentence, nb_diseases).reshape(nb_diseases, nb_diseases)
    pheno_Truth = diseases_sentence_masked[masks]
    labels = [np.array([patient.SNP_label])]*nb_diseases
    diseases_sentence_masked[masks] = nb_phenos

    list_pheno_truth.extend(pheno_Truth)
    list_labels.extend(labels)
    list_diseases_sentence_masked.extend(diseases_sentence_masked)
    list_patient_index.extend([patient_index] * nb_diseases)
print(f'generated files in {time.time() - start_time} seconds')

################################### padding the data ###################################################
# Right-pad with `padding_token` (the id the model is told to treat as padding) up to the longest sentence.
list_diseases_new = []
nb_max_distinct_diseases_patient = patient_list.get_nb_max_distinct_diseases_patient()
for list_diseases in list_diseases_sentence_masked:
    padd = np.full(nb_max_distinct_diseases_patient - len(list_diseases), padding_token, dtype=int)
    list_diseases_new.append(np.concatenate([list_diseases, padd]).astype(int))
list_diseases_sentence_masked = list_diseases_new

################################### train / test split ###################################################
# Split by PATIENT, not by sample: the samples of one patient share all but one disease and the SNP
# label, so a sample-level split would leak test patients into the training set.
list_data_gen = list(zip(list_diseases_sentence_masked, list_pheno_truth, list_labels))
rng_split = np.random.default_rng(0)  # fixed seed -> reproducible split
patient_index_per_sample = np.asarray(list_patient_index, dtype=int)
patient_ids = np.unique(patient_index_per_sample)
patients_train = rng_split.permutation(patient_ids)[:int(prop_train_test * len(patient_ids))]
is_train_sample = np.isin(patient_index_per_sample, patients_train)
indices_train = np.flatnonzero(is_train_sample)
indices_test = np.flatnonzero(~is_train_sample)

data_training = [list_data_gen[i] for i in indices_train]
data_test = [list_data_gen[i] for i in indices_test]

print(len(data_training), flush=True)
dataloader_train = DataLoader(data_training, batch_size=batch_size, shuffle=True)
# evaluation metrics do not depend on sample order, so the test loader is not shuffled
dataloader_test = DataLoader(data_test, batch_size=batch_size, shuffle=False)



# Phenotype (disease) embedding; `dataT.dicts` supplies the per-code name / category metadata.
Embedding_pheno = EmbeddingPheno(
    method=embedding_method_pheno,
    vocab_size=vocab_size,
    Embedding_size=Embedding_size_pheno,
    rollup_depth=rollup_depth,
    freeze_embed=freeze_embed_pheno,
    dicts=dataT.dicts,
)

# SNP embedding.
Embedding_SNPS = EmbeddingSNPS(
    method=embedding_method_SNPS,
    nb_SNPS=nb_SNPS,
    Embedding_size=Embedding_size_SNPS,
    freeze_embed=freeze_embed_SNPS,
)

# Generative model combining the phenotype encoder, the SNP encoder and cross attention.
model = GenerativeModelPheWasV1(
    # embeddings
    Embedding_pheno=Embedding_pheno, Embedding_SNPS=Embedding_SNPS,
    # phenotype encoder
    n_head_pheno=n_head_pheno, Head_size_pheno=Head_size_pheno,
    instance_size_pheno=instance_size_pheno, n_layer_pheno=n_layer_pheno,
    nb_phenos_possible=nb_phenos_possible,
    # SNP encoder
    nb_SNPS=nb_SNPS, n_layer_SNPS=n_layer_SNPS, n_head_SNPS=n_head_SNPS,
    Head_size_SNPS=Head_size_SNPS, instance_size_SNPS=instance_size_SNPS,
    # cross attention
    n_head_cross=n_head_cross, Head_size_cross=Head_size_cross, n_layer_cross=n_layer_cross,
    # shared settings and losses
    mask_padding=mask_padding, p_dropout=p_dropout, device=device,
    loss_version_pheno=loss_version_pheno, loss_version_SNPS=loss_version_SNPS,
    gamma=2, alpha=1, padding_token=padding_token,
)

model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate_max)
# Linear warm-up from learning_rate_ini to learning_rate_max over `warm_up_size` epochs ...
lr_scheduler_warm_up = LinearLR(optimizer, start_factor=start_factor_lr, end_factor=1, total_iters=warm_up_size)
# ... then linear decay to learning_rate_final over the remaining epochs.
lr_scheduler_final = LinearLR(optimizer, start_factor=1, end_factor=end_factor_lr, total_iters=total_epochs - warm_up_size)
# `verbose` is omitted: it defaulted to False and passing it explicitly is deprecated in recent PyTorch.
lr_scheduler = SequentialLR(optimizer, schedulers=[lr_scheduler_warm_up, lr_scheduler_final], milestones=[warm_up_size])


######################################################## Training Loop ###################################################
output_file = log_slurm_outputs_path
## Open tensor board writer
writer = SummaryWriter(log_tensorboard_path)

# metric histories filled by plot_ini_infos / plot_infos
dic_features_list = {
'list_training_loss' : [],
'list_validation_loss' : [],
'list_proba_avg_zero' : [],
'list_proba_avg_one' : [],
'list_auc_validation' : [],
'list_accuracy_validation' : [],
'list_f1_validation' : [],
'epochs' : [] }

start_time_training = time.time()
print_file(output_file, f'Beginning of the program for {total_epochs} epochs', new_line=True)
plot_ini_infos(model, output_file, dataloader_test, dataloader_train, writer, dic_features_list)
for epoch in range(1, total_epochs+1):
    # plot_ini_infos / plot_infos evaluate the model and may leave it in eval mode (dropout off);
    # make sure every epoch trains in training mode.
    model.train()
    start_time_epoch = time.time()
    total_loss = 0.0

    # Batches are (masked disease sentences, masked-out true disease, SNP label) -- see list_data_gen.
    for k, (batch_sentences_pheno, batch_labels_pheno, batch_sentences_SNPS) in enumerate(dataloader_train):
        start_time_batch = time.time()

        batch_sentences_pheno = batch_sentences_pheno.to(device)
        batch_labels_pheno = batch_labels_pheno.to(device)
        batch_sentences_SNPS = batch_sentences_SNPS.to(device)

        # forward pass + loss, then one optimisation step
        logits, loss = model(batch_sentences_pheno, batch_sentences_SNPS,value='pheno', targets= batch_labels_pheno)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        total_loss += loss.item()

        optimizer.step()

        if k % eval_batch_interval == 0:
            clear_last_line(output_file)
            print_file(output_file, f'Progress in epoch {epoch}  = {round(k / len(dataloader_train)*100, 2)} %, time batch : {time.time() - start_time_epoch}', new_line=False)

    if epoch % eval_epochs_interval == 0:
        # `dic_features` (the last returned metrics) is logged with the hyper-parameters afterwards
        dic_features = plot_infos(model, output_file, epoch, total_loss, start_time_epoch, dataloader_train, dataloader_test, optimizer, writer, dic_features_list, model_plot_path)

    lr_scheduler.step()

model.to('cpu')
#model.write_embedding(writer)
print_file(output_file, f"Training finished: {int(time.time() - start_time_training)} seconds", new_line=True)
start_time = time.time()





## Add hyper parameters to tensorboard
hyperparams = {"CHR" : CHR, "SNP" : SNP, "ROLLUP LEVEL" : rollup_depth,
            'PHENO_METHOD': pheno_method, 'EMBEDDING_METHOD': embedding_method_pheno,
            'EMBEDDING SIZE' : Embedding_size_pheno, 'ATTENTION HEADS' : n_head_pheno, 'BLOCKS' : n_layer_pheno,
            'LR': learning_rate_max, 'DROPOUT' : p_dropout, 'NUM_EPOCHS' : total_epochs,  # LR was a placeholder `1`
            'BATCH_SIZE' : batch_size, 
            'PADDING_MASKING': padding_masking,
            'VERSION' : model_version,
            'NB_Patients'  : len(patient_list),
            'LOSS_VERSION'  : loss_version_pheno,
            }

# `dic_features` is the metric dict returned by the last plot_infos call of the training loop.
writer.add_hparams(hyperparams, dic_features)
writer.close()  # flush pending TensorBoard events to disk




In [None]:
# Check which device the phenotype embedding's similarity table lives on (the model was moved to CPU
# at the end of training).
model.Embedding_pheno.similarities_tab.device

In [None]:
# Move the phenotype embedding to the configured training device; hard-coding 'cuda' failed on CPU-only hosts.
model.Embedding_pheno.to(device)