In [None]:
##### Functional version with optional padding masking and dropout; see Transformer_V1.ipynb for an example
import sys
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)
from torch.utils.data import DataLoader
from torchviz import make_dot
import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter
from codes.models.data_form.DataForm import DataTransfo_1SNP, PatientList, Patient
import seaborn as sns

In [None]:
import os
import sys
sys.path.append('/gpfs/commons/groups/gursoy_lab/pmeddeb/phenotype_embedding')
import time
import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn as nn



class EmbeddingPheno(nn.Module):
    """Disease and disease-count embedding tables for the phenotype model.

    Holds two ``nn.Embedding`` tables:
      * ``distinct_diseases_embeddings`` -- either randomly initialised
        (``method=None``) or loaded from a pretrained file (``'Abby'``/``'Paul'``);
      * ``counts_embeddings`` -- one row per occurrence count, trained from scratch.

    Args:
        method: None for random initialisation, or 'Abby' / 'Paul' to load
            pretrained disease embeddings from disk.
        vocab_size: number of rows of the disease table when ``method`` is None
            (also stored as ``nb_distinct_diseases_patient``).
        max_count_same_disease: number of rows of the counts table.
        Embedding_size: embedding width when ``method`` is None; replaced by the
            width of the pretrained weights otherwise.
        rollup_depth: selects which 'Paul' embedding file is loaded.
        freeze_embed: if True, pretrained disease embeddings are not trained.
        dicts: optional mapping with keys 'id', 'name', 'cat' used to build
            TensorBoard projector metadata (``self.metadata``).

    Raises:
        ValueError: if ``method`` is not None, 'Abby' or 'Paul'.
    """

    _EMBEDDINGS_DIR = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings'

    def __init__(self, method=None, vocab_size=None, max_count_same_disease=None, Embedding_size=None, rollup_depth=4, freeze_embed=False, dicts=None):
        super().__init__()

        self.dicts = dicts
        self.rollup_depth = rollup_depth
        self.nb_distinct_diseases_patient = vocab_size
        self.Embedding_size = Embedding_size
        # Keep the value actually used to size the counts table (was hard-coded to None).
        self.max_count_same_disease = max_count_same_disease
        self.metadata = None

        if self.dicts is not None:
            id_dict = self.dicts['id']
            name_dict = self.dicts['name']
            cat_dict = self.dicts['cat']
            # One [name, category] row per code, in id_dict iteration order.
            # NOTE(review): assumes this order matches the embedding row order
            # and that the row counts agree - verify against how id_dict is built.
            self.metadata = [[name_dict[code], cat_dict[code]] for code in id_dict]

        if method is None:
            self.distinct_diseases_embeddings = nn.Embedding(vocab_size, Embedding_size)
            self.counts_embeddings = nn.Embedding(max_count_same_disease, Embedding_size)
            torch.nn.init.normal_(self.distinct_diseases_embeddings.weight, mean=0.0, std=0.02)
            torch.nn.init.normal_(self.counts_embeddings.weight, mean=0.0, std=0.02)
        elif method in ('Abby', 'Paul'):
            pretrained_weights_diseases = torch.load(self._pretrained_file(method, self.rollup_depth))
            # The pretrained width overrides the requested Embedding_size.
            self.Embedding_size = pretrained_weights_diseases.shape[1]
            self.distinct_diseases_embeddings = nn.Embedding.from_pretrained(pretrained_weights_diseases, freeze=freeze_embed)
            # The counts table keeps PyTorch's default init here (unlike method=None).
            self.counts_embeddings = nn.Embedding(max_count_same_disease, self.Embedding_size)
        else:
            raise ValueError(f"Unknown embedding method {method!r}; expected None, 'Abby' or 'Paul'")

    @classmethod
    def _pretrained_file(cls, method, rollup_depth):
        """Return the path of the pretrained disease-embedding file for ``method``."""
        if method == 'Abby':
            return f'{cls._EMBEDDINGS_DIR}/Abby/embedding_abby_no_1_diseases.pth'
        return f'{cls._EMBEDDINGS_DIR}/Paul_Glove/glove_UKBB_omop_rollup_closest_depth_{rollup_depth}_no_1_diseases.pth'

    def write_embedding(self, writer):
        """Send the disease embedding table (with ``self.metadata``) to ``writer.add_embedding``."""
        embedding_tensor = self.distinct_diseases_embeddings.weight.detach().cpu().numpy()
        writer.add_embedding(embedding_tensor, metadata=self.metadata, metadata_header=["Name", "Label"])


class EmbeddingPhenoCat(nn.Module):
    """One embedding table per input category (diseases, counts, extra features).

    ``dic_embedding_cat_params`` maps a category name to its number of rows.
    One ``nn.Embedding`` is built per entry and stored in ``self.dic_embedding_cat``.
    This is an ``nn.ModuleDict``, so the parameters are registered with this module,
    are visible to optimizers and are moved by ``.to(device)``.

    * ``'diseases'``: N(0, 0.02) init when ``method`` is None; otherwise loaded
      from the pretrained 'Abby' / 'Paul' file. The pretrained width then becomes
      ``self.Embedding_size`` and is used for every other table.
    * every other category (including ``'counts'``): N(0, 0.02) init.

    Args:
        method: None, 'Abby' or 'Paul' (see above).
        Embedding_size: width of the tables when no pretrained weights are loaded.
        rollup_depth: selects which 'Paul' embedding file is loaded.
        freeze_embed: if True, pretrained disease embeddings are not trained.
        dic_embedding_cat_params: ``{category: number_of_rows}``; None means empty.

    Raises:
        ValueError: if ``method`` is not None, 'Abby' or 'Paul'.
    """

    _EMBEDDINGS_DIR = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/Embeddings'

    def __init__(self, method=None, Embedding_size=10, rollup_depth=4, freeze_embed=False, dic_embedding_cat_params=None):
        super().__init__()
        if method not in (None, 'Abby', 'Paul'):
            raise ValueError(f"Unknown embedding method {method!r}; expected None, 'Abby' or 'Paul'")
        if dic_embedding_cat_params is None:
            dic_embedding_cat_params = {}

        self.rollup_depth = rollup_depth
        self.Embedding_size = Embedding_size
        self.max_count_same_disease = None
        self.dic_embedding_cat_params = dic_embedding_cat_params

        # Build 'diseases' first: a pretrained table fixes self.Embedding_size,
        # and every other table must then share that width. sorted() is stable,
        # so the remaining categories keep their original order.
        ordered_cats = sorted(dic_embedding_cat_params, key=lambda cat: cat != 'diseases')

        dic_embedding_cat = nn.ModuleDict()
        for cat in ordered_cats:
            max_number = dic_embedding_cat_params[cat]
            if cat == 'diseases' and method is not None:
                pretrained_weights_diseases = torch.load(self._pretrained_file(method, self.rollup_depth))
                self.Embedding_size = pretrained_weights_diseases.shape[1]
                dic_embedding_cat[cat] = nn.Embedding.from_pretrained(pretrained_weights_diseases, freeze=freeze_embed)
            else:
                embedding = nn.Embedding(max_number, self.Embedding_size)
                torch.nn.init.normal_(embedding.weight, mean=0.0, std=0.02)
                dic_embedding_cat[cat] = embedding

        self.dic_embedding_cat = dic_embedding_cat

    @classmethod
    def _pretrained_file(cls, method, rollup_depth):
        """Return the path of the pretrained disease-embedding file for ``method``."""
        if method == 'Abby':
            return f'{cls._EMBEDDINGS_DIR}/Abby/embedding_abby_no_1_diseases.pth'
        return f'{cls._EMBEDDINGS_DIR}/Paul_Glove/glove_UKBB_omop_rollup_closest_depth_{rollup_depth}_no_1_diseases.pth'


In [None]:
### Creation of the reference model
#### Framework constants
model_type = 'transformer'
model_version = 'transformer_V2'
test_name = 'baseline_model'
pheno_method = 'Paul'  # phenotype source passed to DataTransfo_1SNP: 'Paul' or 'Abby'
tryout = True  # True for a quick tryout run, False otherwise
# TensorBoard log directory, read by the SummaryWriter created after data loading
# (it was previously used without ever being defined).
# TODO(review): point this at the project's intended log location.
log_tensorboard_path = os.path.join('runs', model_type, model_version, test_name)
# Global RNG seed so shuffling / initialisation are reproducible across runs.
# NOTE(review): assumes downstream randomness draws from torch / numpy global RNGs.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
### Data constants
CHR = 1
SNP = 'rs673604'
rollup_depth = 4
binary_classes = False  # whether SNP labels are binarised (semantics defined in DataTransfo_1SNP)
vocab_size = None  # set from the data once the PatientList is built
padding_token = 0
prop_train_test = 0.8  # fraction of patients in the training split
load_data = True
save_data = False
remove_none = True
compute_features = False
indices = None
padding = True
list_env_features = ['age', 'sex']
### Data format
batch_size = 200
data_share = 1/100  # presumably the fraction of the dataset used (original note: 402555)
seuil_diseases = 600  # 'seuil' = threshold; exact meaning defined in DataTransfo_1SNP
equalize_label = True
decorelate = True
threshold_corr = 0.9
threshold_rare = 1000
remove_rare = 'all'  # None, 'all', 'one_class'
##### Model constants
embedding_method = 'Paul'  # None, 'Paul', 'Abby'
counts_method = 'normal'  # alternative seen before: {'counts': 'SineCos', 'age': 'SineCos'}
freeze_embedding = True
Embedding_size = 10  # size of the embedding
proj_embed = True
instance_size = 10
n_head = 2  # number of self-attention heads
n_layer = 1  # number of transformer blocks
Head_size = 8  # size of the "single attention head", i.e. the summed size of all heads
eval_epochs_interval = 5  # epochs between evaluation prints (no impact on results)
eval_batch_interval = 40
p_dropout = 0.3  # dropout probability in the model
masking_padding = True  # whether padding masking is applied
device = 'cuda' if torch.cuda.is_available() else 'cpu'
loss_version = 'cross_entropy'  # 'cross_entropy' or 'focal_loss'
gamma = 2
alpha = 1
##### Training constants
total_epochs = 50  # number of epochs
learning_rate_max = 0.001  # maximum learning rate (at the end of the warm-up phase)
learning_rate_ini = 0.00001  # initial learning rate
learning_rate_final = 0.0001
warm_up_frac = 0.5  # warm-up length as a fraction of total_epochs
# Learning rates expressed as multiplicative factors of learning_rate_max.
start_factor_lr = learning_rate_ini / learning_rate_max
end_factor_lr = learning_rate_final / learning_rate_max

warm_up_size = int(warm_up_frac*total_epochs)


In [None]:
# Build the data-transformation object for the selected SNP.
# Keyword arguments are grouped by concern; all values come from the config cell.
dataT = DataTransfo_1SNP(
    # target variant / labels
    SNP=SNP,
    CHR=CHR,
    binary_classes=binary_classes,
    # phenotype representation
    method=pheno_method,
    rollup_depth=rollup_depth,
    padding=padding,
    pad_token=padding_token,
    list_env_features=list_env_features,
    # loading / caching
    load_data=load_data,
    save_data=save_data,
    compute_features=compute_features,
    data_share=data_share,
    indices=indices,
    # filtering and balancing options
    remove_none=remove_none,
    equalize_label=equalize_label,
    seuil_diseases=seuil_diseases,
    decorelate=decorelate,
    threshold_corr=threshold_corr,
    threshold_rare=threshold_rare,
    remove_rare=remove_rare,
    # train / test split
    prop_train_test=prop_train_test,
)

In [None]:
patient_list = dataT.get_patientlist()
# Fail fast: the split, vocabulary sizes and DataLoaders below all need patients.
assert len(patient_list) > 0, 'DataTransfo_1SNP returned an empty PatientList'

In [None]:
# Split patients into train / test sets and wrap them in DataLoaders.
indices_train, indices_test = dataT.get_indices_train_test(nb_data=len(patient_list), prop_train_test=prop_train_test)
patient_list_transformer_train, patient_list_transformer_test = patient_list.get_transformer_data(indices_train.astype(int), indices_test.astype(int))
dataloader_train = DataLoader(patient_list_transformer_train, batch_size=batch_size, shuffle=True)
# The test set is only evaluated: a fixed order keeps per-batch eval metrics reproducible.
dataloader_test = DataLoader(patient_list_transformer_test, batch_size=batch_size, shuffle=False)

# Sizes of the embedding tables, taken from the data.
vocab_size = patient_list.get_nb_distinct_diseases_tot()
max_count_same_disease = patient_list.nb_max_counts_same_disease
if max_count_same_disease is None:
    # Not cached on the PatientList yet: compute it. The previous code discarded
    # this result and re-read the (possibly still None) attribute.
    max_count_same_disease = patient_list.get_max_count_same_disease()

print(f'\n vocab_size : {vocab_size}, max_count : {max_count_same_disease}\n', 
    f'length_patient = {patient_list.get_nb_max_distinct_diseases_patient()}\n',
    f'sparcity = {patient_list.sparsity}\n',
    f'nombres patients  = {len(patient_list)}')

writer = SummaryWriter(log_tensorboard_path)


