In [None]:
import sys
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)

import numpy as np
import pandas as pd
import torch.nn as nn
import torch
import pickle
import time
import os
import shutil

import tensorboard

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.nn import functional as F
from rich.progress import Progress
from tqdm.auto import tqdm
from torch.optim.lr_scheduler import LambdaLR, LinearLR, SequentialLR
from functools import partial
from sklearn.metrics import f1_score, accuracy_score


from codes.models.data_form.DataForm import DataTransfo_1SNP
from codes.models.Transformers.Embedding import EmbeddingPheno
from codes.models.Transformers.dic_model_versions import DIC_MODEL_VERSIONS
from codes.models.utils import clear_last_line, print_file, number_tests


In [None]:

#### framework constants:
# These three names, together with CHR, SNP and pheno_method, build the log directory tree of the run.
model_type = 'transformer'
model_version = 'transformer_V1'  # key looked up in DIC_MODEL_VERSIONS to pick the model class
test_name = 'test_train_transfo_V1'
pheno_method = 'Abby' # Paul, Abby
tryout = True # True if we are doing a tryout, False otherwise ('tryout' is appended to the run folder name)
### data constants:
CHR = 1
SNP = 'rs673604'
rollup_depth = 4
Classes_nb = 2 # number of classes related to an SNP (here 0 or 1)
vocab_size = None # to be defined with data
padding_token = 0  # forwarded to DataTransfo_1SNP as pad_token
prop_train_test = 0.8  # presumably the share of patients used for training -- confirm in DataTransfo_1SNP
load_data = True
save_data = False
remove_none = True
compute_features = True
padding = True
### data format
batch_size = 20  # batch size of both the train and the test DataLoader
data_share = 1/100

##### model constants
embedding_method = None # None, Paul, Abby
freeze_embedding = False
Embedding_size = 20 # Size of embedding.
n_head = 4 # number of SA heads
n_layer = 1 # number of blocks in parallel
Head_size = 20  # size of the "single Attention head", which is the sum of the size of all multi Attention heads
eval_epochs_interval = 1 # number of epochs between each evaluation print of the model (no impact on results)
eval_batch_interval = 10  # a progress message is written to the log file every this many batches
p_dropout = 0 # probability of dropout in the model
masking_padding = False # do we include padding masking or not
seuil_diseases = 50  # NOTE(review): "seuil" is French for threshold; exact meaning is defined by DataTransfo_1SNP
loss_version = 'focal_loss'
equalize_label = False
gamma = 2  # passed to the model next to loss_version; presumably the focal-loss exponent -- confirm
##### training constants
total_epochs = 10 # number of epochs
learning_rate_max = 0.01 # maximum learning rate (at the end of the warmup phase)
learning_rate_ini = 0.001 # initial learning rate
learning_rate_final = 0.001  # learning rate reached at the end of the decay phase
warm_up_frac = 0.20 # fraction of the total number of epochs spent in the warmup stage.
# LinearLR factors are relative to the optimizer's base lr (learning_rate_max).
start_factor_lr = learning_rate_ini / learning_rate_max
end_factor_lr = learning_rate_final / learning_rate_max
warm_up_size = int(total_epochs*warm_up_frac)  # number of warm-up epochs (truncated to an int)
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def lr_lambda(current_epoch, warm_up_size=warm_up_size):
    """Return the learning-rate value scheduled for ``current_epoch``.

    The value rises linearly from ``learning_rate_ini`` to ``learning_rate_max``
    during the first ``warm_up_size`` epochs, then decays as
    ``learning_rate_max / (epochs_since_warm_up + 1)``.

    Args:
        current_epoch: zero-based epoch index.
        warm_up_size: number of warm-up epochs. The previous version recomputed
            this value inside the body, so the argument was silently ignored;
            it is now honoured. The default is unchanged.

    Returns:
        The learning-rate value as a float.

    NOTE(review): the returned value is an absolute learning rate, whereas
    ``torch.optim.lr_scheduler.LambdaLR`` expects a multiplicative factor applied
    to the optimizer's base lr. Divide by the base lr before plugging this into
    ``LambdaLR``.
    """
    if current_epoch < warm_up_size:
        # Linear interpolation; never reached when warm_up_size == 0, so no division by zero.
        return learning_rate_ini + current_epoch*(learning_rate_max - learning_rate_ini) / warm_up_size
    return learning_rate_max / (current_epoch - warm_up_size + 1)
# Bind the configured warm-up length; callers can still override it by keyword.
lr_lambda = partial(lr_lambda, warm_up_size=warm_up_size)

 

In [None]:
### Output locations for this run
path = '/gpfs/commons/groups/gursoy_lab/mstoll/codes/'

# One folder per chromosome / SNP / model type / model version / phenotype method.
model_dir = path + f'logs/SNPS/{str(CHR)}/{SNP}/{model_type}/{model_version}/{pheno_method}'
os.makedirs(model_dir, exist_ok=True)

# The run folder is prefixed with the value returned by number_tests(model_dir)
# and suffixed with 'tryout' when the tryout flag is set.
number_test = number_tests(model_dir)
test_name_with_infos = str(number_test) + '_' + test_name + 'tryout'*tryout
test_dir = f'{model_dir}/{test_name_with_infos}/'

# Sub-folders of the run folder.
log_model_dir = f'{test_dir}/model/'
log_data_dir = f'{test_dir}/data/'
log_info_dir = f'{test_dir}/infos/tensorboard/'
log_slurm_outputs_dir = f'{test_dir}/Slurm/Outputs/'

# No exist_ok here: creating an already existing run sub-folder raises.
for run_subdir in (log_model_dir, log_data_dir, log_info_dir, log_slurm_outputs_dir):
    os.makedirs(run_subdir)

# Files written during / after training.
log_model_path_torch = f'{test_dir}/model/{test_name}.pth'
log_model_path_pickle = f'{test_dir}/model/{test_name}.pkl'
log_data_path_pickle = f'{test_dir}/data/{test_name}.pkl'
log_info_path = f'{test_dir}/infos/tensorboard/{test_name}'
log_slurm_outputs_path = f'{test_dir}/Slurm/Outputs/{test_name}.pth'



In [None]:
# Build the data handler for one SNP from the configuration constants.
dataT = DataTransfo_1SNP(
    # which SNP / phenotype encoding
    SNP=SNP,
    CHR=CHR,
    method=pheno_method,
    rollup_depth=rollup_depth,
    # loading / caching switches
    load_data=load_data,
    save_data=save_data,
    compute_features=compute_features,
    # sampling and filtering
    data_share=data_share,
    prop_train_test=prop_train_test,
    remove_none=remove_none,
    equalize_label=equalize_label,
    seuil_diseases=seuil_diseases,
    # padding
    padding=padding,
    pad_token=padding_token,
)
# Patient list produced by the handler; padd_data() is called on it before it is used further down.
patient_list = dataT.get_patientlist()
patient_list.padd_data()


In [None]:
# Split the patients into train / test indices and build the matching transformer datasets.
indices_train, indices_test = dataT.get_indices_train_test(patient_list=patient_list,prop_train_test=prop_train_test)
patient_list_transformer_train, patient_list_transformer_test = patient_list.get_transformer_data(indices_train.astype(int), indices_test.astype(int))
# Creation of the torch DataLoaders.
# Only the training loader is shuffled: shuffling evaluation data brings nothing and
# makes the order of the per-sample outputs of model.evaluate() vary between calls.
dataloader_train = DataLoader(patient_list_transformer_train, batch_size=batch_size, shuffle=True)
dataloader_test = DataLoader(patient_list_transformer_test, batch_size=batch_size, shuffle=False)


In [None]:
# Resolve the vocabulary size and the maximum per-disease count from the patient list.
# The cached attribute is used when it is set. Otherwise the getter is called, and the
# attribute is preferred again if the getter populated it; if not, the getter's return
# value is kept instead of falling back to None.
# `is None` is used rather than `== None`, which can be overridden by the compared object.
if patient_list.nb_distinct_diseases_tot is None:
    vocab_size = patient_list.get_nb_distinct_diseases_tot()
    if patient_list.nb_distinct_diseases_tot is not None:
        vocab_size = patient_list.nb_distinct_diseases_tot
else:
    vocab_size = patient_list.nb_distinct_diseases_tot

if patient_list.nb_max_counts_same_disease is None:
    max_count_same_disease = patient_list.get_max_count_same_disease()
    if patient_list.nb_max_counts_same_disease is not None:
        max_count_same_disease = patient_list.nb_max_counts_same_disease
else:
    max_count_same_disease = patient_list.nb_max_counts_same_disease

In [None]:
# Short summary of the loaded dataset. The pieces are handed to print() as separate
# arguments, so they are joined with single spaces exactly as before.
dataset_summary = (
    f'\n vocab_size : {vocab_size}, max_count : {max_count_same_disease}\n',
    f'length_patient = {patient_list.get_nb_max_distinct_diseases_patient()}\n',
    f'sparcity = {patient_list.sparsity}\n',
    f'nombres patients  = {len(patient_list)}',
)
print(*dataset_summary)

In [None]:
# Phenotype embedding module handed to the model constructor.
Embedding = EmbeddingPheno(
    method=embedding_method,
    freeze_embed=freeze_embedding,
    Embedding_size=Embedding_size,
    rollup_depth=rollup_depth,
    # sizes resolved from the loaded patient list
    vocab_size=vocab_size,
    max_count_same_disease=max_count_same_disease,
)

In [None]:
### creation of the model
# The model class is selected by name from the registry of model versions.
ClassModel = DIC_MODEL_VERSIONS[model_version]
model = ClassModel(pheno_method=pheno_method,
                   Embedding=Embedding,
                   Head_size=Head_size,
                   Classes_nb=Classes_nb,
                   n_head=n_head,
                   n_layer=n_layer,
                   mask_padding=masking_padding,
                   # Use the configured padding token (the one given to the data pipeline)
                   # rather than a hard-coded 0, so the two cannot drift apart.
                   padding_token=padding_token,
                   p_dropout=p_dropout,
                   loss_version=loss_version,
                   gamma=gamma)
# `m` is kept as an alias of the model moved to the selected device.
m = model.to(device)
# print the number of parameters in the model (in millions)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

In [None]:
# Evaluation of the still-untrained model on the training loader. Every name bound here
# is overwritten by the evaluation performed inside the training loop.
f1, accuracy, report, auc_score, val_loss, proba_avg_zero, proba_avg_one, predicted_probas_list, true_labels_list = model.evaluate(dataloader_train)


In [None]:
# The optimizer's base lr is the peak learning rate; both schedulers scale it by a factor.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate_max)
# Warm-up: factor goes linearly from learning_rate_ini/learning_rate_max to 1 over warm_up_size epochs.
# The `verbose` argument is no longer passed: it is deprecated in recent PyTorch releases and
# False was its default, so behaviour is unchanged.
lr_scheduler_warm_up = LinearLR(optimizer, start_factor=start_factor_lr, end_factor=1, total_iters=warm_up_size)
# Decay: factor goes linearly from 1 to learning_rate_final/learning_rate_max over the remaining epochs.
lr_scheduler_final = LinearLR(optimizer, start_factor=1, total_iters=total_epochs-warm_up_size, end_factor=end_factor_lr)
# Switch from the warm-up schedule to the decay schedule once warm_up_size epochs have elapsed.
lr_scheduler = SequentialLR(optimizer, schedulers=[lr_scheduler_warm_up, lr_scheduler_final], milestones=[warm_up_size])


In [None]:
## Reset the text log of the run and open the TensorBoard writer
output_file = log_slurm_outputs_path
# Opening in 'w' mode creates the file or empties an existing one; the context manager
# closes it, so no explicit truncate()/close() is needed.
with open(output_file, 'w'):
    pass
writer = SummaryWriter(log_info_path)

# Training loop
start_time_training = time.time()
print_file(output_file, f'Beginning of the program for {total_epochs} epochs', new_line=True)
try:
    for epoch in range(total_epochs):
        # model.evaluate() is not visible from this notebook and may leave the model in
        # eval mode; explicitly switch back to train mode at the start of every epoch.
        model.train()

        start_time_epoch = time.time()
        total_loss = 0.0
        # Placeholder line, presumably the one removed by the first clear_last_line() call below.
        print_file(output_file, 'will be deleted', new_line=True)
        for k, (batch_sentences, batch_counts, batch_labels) in enumerate(dataloader_train):

            batch_sentences = batch_sentences.to(device)
            batch_counts = batch_counts.to(device)
            batch_labels = batch_labels.to(device)

            # forward pass returns both the logits and the loss
            logits, loss = model(batch_sentences, batch_counts, batch_labels)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()

            total_loss += loss.item()

            optimizer.step()

            # Rewrite the progress line every eval_batch_interval batches.
            if k % eval_batch_interval == 0:
                clear_last_line(output_file)
                print_file(output_file, f'Progress in batch = {round(k / len(dataloader_train)*100, 2)} %, time batch : {time.time() - start_time_epoch}', new_line=False)

        # Mean of the per-batch losses over the epoch.
        mean_train_loss = total_loss / len(dataloader_train)
        writer.add_scalar('Training loss', mean_train_loss, epoch)

        if epoch % eval_epochs_interval == 0:
            print_file(output_file, f"Epoch {epoch} finished: {int(time.time() - start_time_epoch)} seconds", new_line=True)
            print_file(output_file, f"    Training loss : {mean_train_loss:.4f}", new_line = True)
            f1, accuracy, report, auc_score, val_loss, proba_avg_zero, proba_avg_one, predicted_probas_list, true_labels_list = model.evaluate(dataloader_test)
            print_file(output_file, " Evaluation on validation", new_line=True)
            print_file(output_file, f"        Validation loss : {val_loss:.4f}", new_line=True)
            writer.add_scalar('Validation loss', val_loss, epoch)
            writer.add_scalar('Validation AUC', auc_score, epoch)
            writer.add_scalar('Validation f1-score', f1, epoch)
            writer.add_scalar('Validation accuracy', accuracy, epoch)
            writer.add_scalar('Average Proba 0', proba_avg_zero, epoch)
            writer.add_scalar('Average Proba 1', proba_avg_one, epoch)

            print_file(output_file, f"learning rate : {optimizer.param_groups[0]['lr']}", new_line=True)

        # The schedulers are stepped once per epoch, after the optimizer updates.
        lr_scheduler.step()
finally:
    # Flush pending TensorBoard events and release the event file, even if training is interrupted.
    writer.close()


# Save the weights only (state_dict) ...
torch.save(model.state_dict(), log_model_path_torch)
print('Model saved to %s' % log_model_path_torch)
# Print time
print(f"Training finished: {int(time.time() - start_time_training)} seconds")
start_time = time.time()

# ... and the whole model object as a pickle.
with open(log_model_path_pickle, 'wb') as file:
    pickle.dump(model, file)
print('Model saved to %s' % log_model_path_pickle)

# Keep the data objects of the run next to the model.
dic_data = {
    'patient_list':patient_list,
    'data' : dataT
}
with open(log_data_path_pickle, 'wb') as file:
    pickle.dump(dic_data, file)
print('Data saved to %s' % log_data_path_pickle)


In [None]:

# Mean of column 1 of the predicted probabilities, restricted to samples whose true label is 0.
probas_array = np.array(predicted_probas_list)
labels_array = np.array(true_labels_list)
np.mean(probas_array[:, 1][labels_array == 0])

In [None]:
# Toy example: two random integer-valued logits stored as float64.
logits = torch.randint(4, (2,), dtype=float)
# `dim` is given explicitly: calling softmax without it relies on a deprecated implicit
# choice and emits a warning. For a 1-D tensor, dim=0 is what the implicit choice picked.
probas = F.softmax(logits, dim=0)


In [None]:
# Scalar target: class index 0, used with the two-element `logits` in the cross-entropy check below.
target = torch.tensor(0)

In [None]:
# Display the sampled logits next to the target class index.
logits, target

In [None]:
# Cross-entropy of the single (unbatched) example. With target class 0 this is the
# negative log of the softmax probability of class 0.
F.cross_entropy(logits, target)

In [None]:
# Softmax probabilities of class 0 and class 1.
p1, p2 = probas[0], probas[1]

In [None]:
# Negative log-probability of class 0, to compare with the F.cross_entropy value for target 0.
-np.log(p1)

In [None]:
### Other implementations

In [None]:
## Open tensor board writer
# Alternative training loop with a tqdm bar over the batches.
# It now uses the names defined earlier in the notebook (log_info_path, eval_epochs_interval,
# log_model_path_torch) instead of the undefined log_path / eval_interval / model_path.
writer = SummaryWriter(log_info_path)
# Training Loop
start_time_training = time.time()

for epoch in range(total_epochs):
    # Make sure the model is in train mode, whatever state a previous evaluation left it in.
    model.train()

    start_time_epoch = time.time()
    total_loss = 0.0

    for (batch_sentences, batch_counts, batch_labels) in tqdm(dataloader_train,desc=f"Processing batch", unit="group"):
        batch_sentences = batch_sentences.to(device)
        batch_counts = batch_counts.to(device)
        batch_labels = batch_labels.to(device)

        # forward pass returns both the logits and the loss
        logits, loss = model(batch_sentences, batch_counts, batch_labels)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        total_loss += loss.item()

        optimizer.step()

    lr_scheduler.step()

    writer.add_scalar('Training loss', total_loss/len(dataloader_train), epoch)

    if epoch % eval_epochs_interval == 0:
        print(f"Epoch {epoch} finished: {int(time.time() - start_time_epoch)} seconds")
        print(f"    Training loss : {(total_loss/len(dataloader_train)):.4f}")
        # model.evaluate() returns more than these five values elsewhere in the notebook;
        # the extra ones are collected and ignored so the unpacking cannot fail.
        f1, accuracy, report, auc_score, val_loss, *extra_eval_outputs = model.evaluate(dataloader_test)
        print(" Evaluation on validation")
        print(f"        Validation loss : {val_loss:.4f}")
        writer.add_scalar('Validation loss', val_loss, epoch)
        writer.add_scalar('Validation AUC', auc_score, epoch)
        writer.add_scalar('Validation f1-score', f1, epoch)
        writer.add_scalar('Validation accuracy', accuracy, epoch)

        # The scheduler was already stepped, so this is the rate of the next epoch.
        print(f"learning rate : {optimizer.param_groups[0]['lr']}")


torch.save(model.state_dict(), log_model_path_torch)
print('Model saved to %s' % log_model_path_torch)
# Print time
print(f"Training finished: {int(time.time() - start_time_training)} seconds")
start_time = time.time()


In [None]:
# Training Loop
# Alternative loop with a manually updated tqdm bar (one bar per epoch). It relies on
# `writer` and `start_time_training` created by an earlier cell, and now uses the names
# defined in the notebook (eval_epochs_interval, log_model_path_torch) instead of the
# undefined eval_interval / model_path.
for epoch in range(total_epochs):
    # Make sure the model is in train mode, whatever state a previous evaluation left it in.
    model.train()

    start_time_epoch = time.time()
    total_loss = 0.0
    # The bar advances once per batch, so its total is the number of batches (it was a fixed 100).
    with tqdm(total=len(dataloader_train), desc="Progress") as pbar:

        for batch_sentences, batch_counts, batch_labels in dataloader_train:

            batch_sentences = batch_sentences.to(device)
            batch_counts = batch_counts.to(device)
            batch_labels = batch_labels.to(device)

            # forward pass returns both the logits and the loss
            logits, loss = model(batch_sentences, batch_counts, batch_labels)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()

            total_loss += loss.item()

            optimizer.step()

            pbar.update(1)

        lr_scheduler.step()

        writer.add_scalar('Training loss', total_loss/len(dataloader_train), epoch)

        if epoch % eval_epochs_interval == 0:
            print(f"Epoch {epoch} finished: {int(time.time() - start_time_epoch)} seconds")
            print(f"    Training loss : {(total_loss/len(dataloader_train)):.4f}")
            # model.evaluate() returns more than these five values elsewhere in the notebook;
            # the extra ones are collected and ignored so the unpacking cannot fail.
            f1, accuracy, report, auc_score, val_loss, *extra_eval_outputs = model.evaluate(dataloader_test)
            print(" Evaluation on validation")
            print(f"        Validation loss : {val_loss:.4f}")
            writer.add_scalar('Validation loss', val_loss, epoch)
            writer.add_scalar('Validation AUC', auc_score, epoch)
            writer.add_scalar('Validation f1-score', f1, epoch)
            writer.add_scalar('Validation accuracy', accuracy, epoch)

            # The scheduler was already stepped, so this is the rate of the next epoch.
            print(f"learning rate : {optimizer.param_groups[0]['lr']}")


torch.save(model.state_dict(), log_model_path_torch)
print('Model saved to %s' % log_model_path_torch)
# Print time
print(f"Training finished: {int(time.time() - start_time_training)} seconds")
start_time = time.time()


In [None]:

# Training Loop
# Alternative loop with two rich progress bars (epochs and batches of the current epoch).
# It relies on `writer` and `start_time_training` created by an earlier cell, and now uses
# the names defined in the notebook (eval_epochs_interval, log_model_path_torch) instead
# of the undefined eval_interval / model_path.
with Progress() as progress:
    task1 = progress.add_task("Training model", total=total_epochs)
    task2 = progress.add_task("epoch", total=len(dataloader_train))

    for epoch in range(total_epochs):
        # Make sure the model is in train mode, whatever state a previous evaluation left it in.
        model.train()

        start_time_epoch = time.time()
        total_loss = 0.0

        for batch_sentences, batch_counts, batch_labels in dataloader_train:

            batch_sentences = batch_sentences.to(device)
            batch_counts = batch_counts.to(device)
            batch_labels = batch_labels.to(device)

            # forward pass returns both the logits and the loss
            logits, loss = model(batch_sentences, batch_counts, batch_labels)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()

            total_loss += loss.item()

            optimizer.step()

            progress.update(task2, advance=1)

        lr_scheduler.step()

        writer.add_scalar('Training loss', total_loss/len(dataloader_train), epoch)

        if epoch % eval_epochs_interval == 0:
            print(f"Epoch {epoch} finished: {int(time.time() - start_time_epoch)} seconds")
            print(f"    Training loss : {(total_loss/len(dataloader_train)):.4f}")
            # model.evaluate() returns more than these five values elsewhere in the notebook;
            # the extra ones are collected and ignored so the unpacking cannot fail.
            f1, accuracy, report, auc_score, val_loss, *extra_eval_outputs = model.evaluate(dataloader_test)
            print(" Evaluation on validation")
            print(f"        Validation loss : {val_loss:.4f}")
            writer.add_scalar('Validation loss', val_loss, epoch)
            writer.add_scalar('Validation AUC', auc_score, epoch)
            writer.add_scalar('Validation f1-score', f1, epoch)
            writer.add_scalar('Validation accuracy', accuracy, epoch)

            # The scheduler was already stepped, so this is the rate of the next epoch.
            print(f"learning rate : {optimizer.param_groups[0]['lr']}")

        # Restart the per-epoch bar and advance the global one.
        progress.reset(task2)
        progress.update(task1, advance=1)

    torch.save(model.state_dict(), log_model_path_torch)
    print('Model saved to %s' % log_model_path_torch)
    # Print time
    print(f"Training finished: {int(time.time() - start_time_training)} seconds")
    start_time = time.time()



## Hyper-parameters of the run, gathered for TensorBoard
# NOTE(review): this cell only builds the dictionary; nothing here writes it to the writer yet.
hyperparams = {"CHR": CHR, "SNP": SNP, "ROLLUP LEVEL": rollup_depth,
               'PHENO_METHOD':pheno_method, 'EMBEDDING_METHOD':embedding_method,
              'EMBEDDING SIZE': Embedding_size, 'ATTENTION HEADS': n_head, 'BLOCKS': n_layer,
              # Report the configured peak learning rate (the optimizer's base lr) instead of a hard-coded 1.
              'LR': learning_rate_max, 'DROPOUT': p_dropout, 'NUM_EPOCHS': total_epochs,
              'BATCH_SIZE': batch_size,
              'PADDING_MASKING':masking_padding,
              'VERSION': model_version,
              'NB_Patients' : len(patient_list),
              'LOSS_VERSION' : loss_version,

            }


In [None]:
import time

from rich.progress import Progress

# Demo of rich's Progress: three bars advancing at different speeds until all are complete.
with Progress() as progress:

    task1 = progress.add_task("[red]Downloading...", total=1000)
    task2 = progress.add_task("[green]Processing...", total=1000)
    task3 = progress.add_task("[cyan]Cooking...", total=1000)

    # Amount added to each bar on every tick, in the same order as before.
    tick_advances = ((task1, 0.5), (task2, 0.3), (task3, 0.9))

    while not progress.finished:
        for task_id, step in tick_advances:
            progress.update(task_id, advance=step)
        time.sleep(0.02)

In [None]:
# Demo of two nested rich progress bars: an outer bar over `liste` and an inner bar that
# is reset after every outer step.
liste = [1,2,3]
with Progress() as progress:
    task1 = progress.add_task("[red]Downloading...", total=3)
    # The colour tag needs square brackets; "blueDownloading..." was shown literally.
    task2 = progress.add_task("[blue]Downloading...", total=3)

    # The loop values are unused, and distinct throwaway names keep the inner loop from
    # overwriting the outer loop variable.
    for _outer in liste:
        time.sleep(1)
        for _inner in liste:
            progress.update(task2, advance=1)
            time.sleep(1)
        progress.reset(task2)
        progress.update(task1, advance=1)

In [None]:
# Number of batches the training DataLoader yields per epoch.
len(dataloader_train)

In [None]:

# Training Loop
# Second copy of the rich-based loop (epoch bar + batch bar).
# It relies on `writer` and `start_time_training` created by an earlier cell, and now uses
# the names defined in the notebook (eval_epochs_interval, log_model_path_torch) instead
# of the undefined eval_interval / model_path.
with Progress() as progress:
    task1 = progress.add_task("Training model", total=total_epochs)
    task2 = progress.add_task("epoch", total=len(dataloader_train))

    for epoch in range(total_epochs):
        # Make sure the model is in train mode, whatever state a previous evaluation left it in.
        model.train()

        start_time_epoch = time.time()
        total_loss = 0.0

        for batch_sentences, batch_counts, batch_labels in dataloader_train:

            batch_sentences = batch_sentences.to(device)
            batch_counts = batch_counts.to(device)
            batch_labels = batch_labels.to(device)

            # forward pass returns both the logits and the loss
            logits, loss = model(batch_sentences, batch_counts, batch_labels)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()

            total_loss += loss.item()

            optimizer.step()

            progress.update(task2, advance=1)

        lr_scheduler.step()

        writer.add_scalar('Training loss', total_loss/len(dataloader_train), epoch)

        if epoch % eval_epochs_interval == 0:
            print(f"Epoch {epoch} finished: {int(time.time() - start_time_epoch)} seconds")
            print(f"    Training loss : {(total_loss/len(dataloader_train)):.4f}")
            # model.evaluate() returns more than these five values elsewhere in the notebook;
            # the extra ones are collected and ignored so the unpacking cannot fail.
            f1, accuracy, report, auc_score, val_loss, *extra_eval_outputs = model.evaluate(dataloader_test)
            print(" Evaluation on validation")
            print(f"        Validation loss : {val_loss:.4f}")
            writer.add_scalar('Validation loss', val_loss, epoch)
            writer.add_scalar('Validation AUC', auc_score, epoch)
            writer.add_scalar('Validation f1-score', f1, epoch)
            writer.add_scalar('Validation accuracy', accuracy, epoch)

            # The scheduler was already stepped, so this is the rate of the next epoch.
            print(f"learning rate : {optimizer.param_groups[0]['lr']}")

        # Restart the per-epoch bar and advance the global one.
        progress.reset(task2)
        progress.update(task1, advance=1)

    torch.save(model.state_dict(), log_model_path_torch)
    print('Model saved to %s' % log_model_path_torch)
    # Print time
    print(f"Training finished: {int(time.time() - start_time_training)} seconds")
    start_time = time.time()


In [None]:
# Random integer tensor with values in [0, 8) and shape (2, 2, 1), displayed as the cell output.
p = torch.randint(low=0, high=8, size=(2, 2, 1))
p

In [None]:
# View the (2, 2, 1) tensor as (2, 2), i.e. without its trailing singleton dimension.
p.view(2, 2)

In [None]:
import pandas as pd

In [None]:
# Load and display a CSV stored under the project's models folder.
# NOTE(review): judging by the file name it lists the saved test instances -- confirm.
# The path is absolute and specific to this cluster account.
pd.read_csv('/gpfs/commons/groups/gursoy_lab/mstoll/codes/models/list_models_class/list_instance_tests_saved.csv')