In [None]:
import sys
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)

import pandas as pd
import numpy as np 
from functools import partial
import torch.nn as nn
import time
import torch 
import torch.nn.functional as F
from codes.models.data_form.DataForm import DataTransfo_1SNP, PatientList
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from codes.models.metrics import calculate_roc_auc

In [None]:
### data constants:
# SNP / phenotype selection
SNP = 'rs673604'
CHR = 1
pheno_method = 'Abby'  # accepted values: Paul, Abby
rollup_depth = 4
Classes_nb = 2  # number of classes related to an SNP (here 0 or 1)
vocab_size = None  # to be defined with data

# Options forwarded to DataTransfo_1SNP when the data is built
remove_none = True
decorelate = False
threshold_corr = 0.9
remove_rare = 'all'  # accepted values: None, 'all', 'one_class'
threshold_rare = 50
equalize_label = False
compute_features = True
list_env_features = ['age', 'sex']

# Padding options
padding = False
padding_token = 0

# Load / save switches forwarded to DataTransfo_1SNP
load_data = False
save_data = False

### data format
prop_train_test = 0.8  # the test split receives 1 - prop_train_test of the rows
batch_size = 20
data_share = 1 / 10_000

# Training schedule: number of epochs and how often the test set is evaluated
nb_epochs = 20
eval_epochs_interval = 2

In [None]:
# Build the data handler for the configured SNP. Every option comes from the
# configuration cell so that changing a constant there is enough to change the run
# (remove_none used to be hard-coded to True here, ignoring the constant).
dataT = DataTransfo_1SNP(
    SNP=SNP,
    CHR=CHR,
    method=pheno_method,
    padding=padding,
    pad_token=padding_token,
    load_data=load_data,
    save_data=save_data,
    compute_features=compute_features,
    prop_train_test=prop_train_test,
    remove_none=remove_none,
    equalize_label=equalize_label,
    rollup_depth=rollup_depth,
    decorelate=decorelate,
    threshold_corr=threshold_corr,
    threshold_rare=threshold_rare,
    remove_rare=remove_rare,
    list_env_features=list_env_features,
    data_share=data_share,
)

# indices_env and name_envs are returned by the handler but are not used by the
# cells visible in this notebook.
data, labels, indices_env, name_envs = dataT.get_tree_data(with_env=False)

# NOTE(review): labels are equalized here unconditionally, even though the
# `equalize_label` configuration flag is False — confirm this is intended.
data, labels = DataTransfo_1SNP.equalize_label(data, labels)

# One column of `data` per phenotype; `phenos` enumerates the column indices.
nb_phenos = data.shape[1]
phenos = np.arange(nb_phenos)

In [None]:
def get_risk_pheno(data, labels, pheno_nb):
    """Ratio of the label==1 frequency between rows with and without a phenotype.

    Args:
        data: 2-D array indexed as ``data[:, pheno_nb]``; entries equal to 1
            mark rows where the phenotype is active, entries equal to 0 rows
            where it is not.
        labels: array aligned with the rows of ``data``.
        pheno_nb: index of the phenotype column to analyse.

    Returns:
        (share of label==1 among active rows) / (share of label==1 among
        inactive rows). No guard is applied when a group is empty or when the
        denominator is zero.
    """
    pheno_column = data[:, pheno_nb]

    def mutation_rate(flag):
        # Share of label == 1 among the rows whose phenotype entry equals `flag`.
        selected = labels[pheno_column == flag]
        return np.sum(selected == 1) / len(selected)

    rate_active = mutation_rate(1)
    rate_inactive = mutation_rate(0)
    return rate_active / rate_inactive
def get_pred_naive(data, labels, pheno_nb, proba=False):
    """Naive single-phenotype predictor built from the rows where the phenotype is active.

    Args:
        data: 2-D array indexed as ``data[:, pheno_nb]``; entries equal to 1
            mark rows where the phenotype is active.
        labels: array of 0/1 labels aligned with the rows of ``data``.
        pheno_nb: index of the phenotype column to analyse.
        proba: kept for backward compatibility; it is not used, both values are
            always returned.

    Returns:
        ``(proba_one, label)`` where ``proba_one`` is the share of label==1
        among the active rows and ``label`` is 1 when ones strictly outnumber
        zeros, else 0. When no row has the phenotype active, ``(nan, 0)`` is
        returned.
    """
    labels_ac = labels[data[:, pheno_nb] == 1]
    nb_active = len(labels_ac)
    if nb_active == 0:
        # No row carries this phenotype: nothing to estimate, avoid dividing by zero.
        return float('nan'), 0

    nb_ones_ac = np.sum(labels_ac == 1)
    nb_zeros_ac = np.sum(labels_ac == 0)
    # Probability of label 1, so that "proba > 0.5" agrees with the majority
    # label below (the share of zeros was returned previously, which inverted
    # any rule thresholding this value at 0.5).
    proba_one = nb_ones_ac / nb_active
    label = 1 if nb_ones_ac > nb_zeros_ac else 0
    return proba_one, label
# Freeze the dataset into both helpers: from here on they only take a phenotype index.
get_risk_pheno = partial(get_risk_pheno, data, labels)
get_pred_naive = partial(get_pred_naive, data, labels)

# Per-phenotype statistics, in the order of `phenos`.
odds_ratios = [get_risk_pheno(pheno) for pheno in phenos]
pred_naive = np.array([get_pred_naive(pheno) for pheno in phenos])
# Column 0 holds the first returned value (proba), column 1 the second (label).
probas_pred_naive, labels_pred_naive = pred_naive[:, 0], pred_naive[:, 1]

def get_pred_sentence(probas_pred_naive, labels_pred_naive, sentence, method='max', default_label=0):
    """Predict one label for a patient from the per-phenotype naive predictions.

    Args:
        probas_pred_naive: 1-D array with one probability per phenotype.
        labels_pred_naive: 1-D array with one 0/1 label per phenotype.
        sentence: 1-D array with one entry per phenotype; non-zero entries mark
            the phenotypes that are active for this patient.
        method: ``'mean'`` predicts 1 when the mean probability of the active
            phenotypes is above 0.5, else 0. ``'max'`` returns the label of the
            active phenotype whose probability is farthest from 0.5.
        default_label: value returned when no phenotype is active.

    Returns:
        The predicted label as an int (or ``default_label`` for an empty sentence).

    Raises:
        ValueError: if ``method`` is neither ``'mean'`` nor ``'max'``.
    """
    if method not in ('mean', 'max'):
        raise ValueError(f"unknown method {method!r}, expected 'mean' or 'max'")

    active = np.asarray(sentence).astype(bool)
    # Probabilities must stay floats: casting them to bool (as was done before)
    # makes (p - 0.5)**2 constant, so 'max' always picked the first active phenotype.
    probas_naive = np.asarray(probas_pred_naive, dtype=float)[active]
    labels_naive = np.asarray(labels_pred_naive)[active]

    if probas_naive.size == 0:
        # No active phenotype: argmax / mean are undefined on an empty selection.
        return default_label

    if method == 'mean':
        return 1 if np.mean(probas_naive) > 0.5 else 0

    # method == 'max': trust the most confident phenotype.
    most_confident = np.argmax((probas_naive - 0.5) ** 2)
    return int(labels_naive[most_confident])
    
# Bind the per-phenotype tables: the predictor now only needs one row of `data`.
get_pred_sentence = partial(get_pred_sentence, probas_pred_naive, labels_pred_naive)

# Column sums of the phenotype matrix.
frequencies = data.sum(axis=0)
# One prediction per row of `data` (axis=1 hands each row to the predictor).
labels_pred = np.apply_along_axis(get_pred_sentence, 1, data)
# Share of rows whose prediction equals the label; displayed as the cell output.
(labels_pred == labels).sum() / len(labels)

In [None]:
import sys
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)

import torch
import torch.nn as nn
import time
import numpy as np
from torch.utils.data import Dataset
import torch.nn.functional as F
from sklearn.metrics import f1_score, accuracy_score

from codes.models.metrics import calculate_roc_auc, get_proba
class NaiveModelWeights(nn.Module):
    """Two linear heads whose outputs are mixed into a single pair of class logits.

    One head maps the input to ``pheno_nb`` scores that are soft-maxed into
    mixing weights; the other maps it to ``pheno_nb`` pairs of logits. The
    model output is the weighted sum of those pairs.
    """

    def __init__(self, pheno_nb):
        """Create both heads for inputs with ``pheno_nb`` features (parameters use ``dtype=float``)."""
        super().__init__()

        self.linear_weights_predictor = nn.Linear(pheno_nb, pheno_nb, bias=True, dtype=float)
        self.logits_predictor = nn.Linear(pheno_nb, 2 * pheno_nb, bias=True, dtype=float)

    def forward(self, x, labels_target=None):
        """Compute the mixed logits and, when labels are given, the cross-entropy loss.

        Args:
            x: 2-D tensor of shape (B, P) whose dtype matches the layer parameters.
            labels_target: optional tensor with B class indices.

        Returns:
            ``(logits, err)`` where ``logits`` has shape (B, 2, 1) and ``err``
            is the cross-entropy loss, or ``None`` when ``labels_target`` is
            not provided.
        """
        B, P = x.shape
        weights = self.linear_weights_predictor(x)
        # Normalize over the phenotype axis; the axis is explicit because the
        # implicit choice of F.softmax depends on the tensor rank.
        prob_weights = F.softmax(weights, dim=1).view(B, P, 1)
        logits = self.logits_predictor(x).view(B, P, 2)
        # (B, 2, P) @ (B, P, 1) -> (B, 2, 1): weighted sum of the per-phenotype pairs.
        logits = logits.transpose(1, 2) @ prob_weights

        # Defined up-front so that inference without labels does not hit an unbound name.
        err = None
        if labels_target is not None:
            # Input (B, C=2, 1) with target (B, 1); targets must be integer class indices.
            err = F.cross_entropy(logits, labels_target.reshape(B, 1).long())
        return logits, err

    def eval_model(self, dataloader_test):
        """Evaluate the model on every batch of ``dataloader_test``.

        Each batch must be a mapping with the keys ``'data'`` and ``'label'``.
        The model is switched to eval mode for the evaluation and back to train
        mode afterwards, even if the evaluation fails.

        Returns:
            ``(f1, accuracy, auc_score, mean_loss, proba_avg_zero,
            proba_avg_one, predicted_probas_list, true_labels_list)``: macro
            F1, accuracy, the value returned by ``calculate_roc_auc``, the
            loss averaged over batches (nan for an empty loader), the two
            values returned by ``get_proba``, then the per-sample class
            probabilities and true labels.
        """
        self.eval()
        print('beginning inference evaluation')
        start_time_inference = time.time()
        predicted_labels_list = []
        predicted_probas_list = []
        true_labels_list = []

        total_loss = 0.
        try:
            with torch.no_grad():
                for batch in dataloader_test:
                    batch_data = batch['data']
                    batch_labels = batch['label']

                    logits, loss = self(batch_data, batch_labels)
                    total_loss += loss.item()
                    # forward() returns (B, 2, 1). Flatten to (B, 2) and normalize over
                    # the class axis; the implicit dim of a 3-D softmax is 0, which
                    # normalized over the batch instead.
                    class_logits = logits.reshape(logits.shape[0], -1)
                    predicted_probas = F.softmax(class_logits, dim=1).cpu().numpy()
                    predicted_labels = (predicted_probas[:, 1] > 0.5).astype(int)

                    predicted_labels_list.extend(predicted_labels)
                    predicted_probas_list.extend(predicted_probas)
                    true_labels_list.extend(batch_labels.cpu().numpy())
        finally:
            self.train()

        f1 = f1_score(true_labels_list, predicted_labels_list, average='macro')
        accuracy = accuracy_score(true_labels_list, predicted_labels_list)
        # NOTE(review): the helpers below now receive one (2,) probability row per
        # sample instead of (2, 1) — confirm this is the layout they expect.
        auc_score = calculate_roc_auc(true_labels_list, np.array(predicted_probas_list), return_nan=True)
        proba_avg_zero, proba_avg_one = get_proba(true_labels_list, predicted_probas_list)
        print(f'end inference evaluation in {time.time() - start_time_inference}s')

        nb_batches = len(dataloader_test)
        mean_loss = total_loss / nb_batches if nb_batches else float('nan')
        return f1, accuracy, auc_score, mean_loss, proba_avg_zero, proba_avg_one, predicted_probas_list, true_labels_list

    

class CustomDatasetWithLabels(Dataset):
    """Map-style dataset that serves each sample together with its label."""

    def __init__(self, data, labels):
        # Both containers are stored as given and later indexed with the same position.
        self.data, self.labels = data, labels

    def __getitem__(self, idx):
        # One item is a dict with the keys 'data' and 'label'.
        return dict(data=self.data[idx], label=self.labels[idx])

    def __len__(self):
        # The length is taken from the samples only; labels are not checked.
        return len(self.data)

In [None]:
# Split samples and labels together so that every row keeps its own label.
data_train, data_test, labels_train, labels_test = train_test_split(
    data, labels, test_size=1 - prop_train_test, random_state=42
)

# Each dataset must be paired with the labels of its own split: passing the full
# `labels` array (as was done before) pairs rows with labels of unrelated patients.
data_train = CustomDatasetWithLabels(data_train, labels_train)
dataloader_train = DataLoader(data_train, batch_size=batch_size, shuffle=True)
data_test = CustomDatasetWithLabels(data_test, labels_test)
# Evaluation does not need shuffling; a fixed order keeps the returned
# per-sample lists comparable between evaluations.
dataloader_test = DataLoader(data_test, batch_size=batch_size, shuffle=False)

naive_model = NaiveModelWeights(pheno_nb=nb_phenos)
optimizer = torch.optim.AdamW(naive_model.parameters(), lr=0.0001)



In [None]:
# Total number of scalar parameters of the model, reported in millions.
print(sum(map(lambda param: param.numel(), naive_model.parameters())) / 1e6, 'M parameters')


In [None]:
# Train for `nb_epochs` epochs, evaluating on the test loader every
# `eval_epochs_interval` epochs.
for epoch in range(nb_epochs):
    start_time = time.time()
    loss_tot = 0.
    # Batch-local names: the loop no longer overwrites the `data_train` /
    # `labels_train` variables created by the split cell.
    for batch in dataloader_train:
        batch_data = batch['data']
        batch_labels = batch['label']
        # evaluate the loss
        _logits, loss = naive_model(batch_data, batch_labels)
        # Accumulate a plain float: adding the tensor itself keeps the autograd
        # graph of every batch alive until the end of the epoch.
        loss_tot += loss.item()
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    if epoch % eval_epochs_interval == 0:
        (f1, accuracy, auc_score, loss_val, proba_avg_zero, proba_avg_one,
         predicted_probas_list, true_labels_list) = naive_model.eval_model(dataloader_test)
        print(f'loss_val = {loss_val}')

    print(f'epoch {epoch} ended in {time.time() - start_time}')
    print(f'loss_train = {loss_tot / len(dataloader_train)}')


In [None]:
# Stand-alone copies of the two linear heads, sized for a 10-feature toy input.
pheno_nb = 10
linear_weights_predictor = nn.Linear(
    in_features=pheno_nb, out_features=pheno_nb, bias=True, dtype=torch.float64
)
logits_predictor = nn.Linear(
    in_features=pheno_nb, out_features=2 * pheno_nb, bias=True, dtype=torch.float64
)

In [None]:
# Random toy batch: 5 rows of 10 features, in double precision.
x = torch.rand((5, 10), dtype=torch.float64)
# B = number of rows, P = number of features.
B, P = x.size()


In [None]:
# Mixing weights: one score per feature, normalized across features for each row.
# The softmax axis is explicit; relying on the implicit axis is deprecated.
weights = linear_weights_predictor(x)
prob_weights = F.softmax(weights, dim=1).view(B, P, 1)

# One pair of logits per feature.
logits = logits_predictor(x).view(B, P, 2)


In [None]:
# Display the shapes of the logits tensor and of the raw (pre-softmax) weight scores.
logits.shape, weights.shape

In [None]:
# Weighted sum of the per-feature logit pairs: (B, 2, P) @ (B, P, 1) -> (B, 2, 1).
# Recomputed from `logits_predictor(x)` instead of from `logits` itself so that
# re-running this cell does not feed the already-mixed tensor back into the product.
logits = logits_predictor(x).view(B, P, 2).transpose(1, 2) @ prob_weights


In [None]:
# Display the shape of `logits` as it stands at this point of the notebook.
logits.shape

In [None]:
# Class probabilities per row: reshape the logits to (B, 2) and normalize over
# the two classes. B replaces the hard-coded batch size, and the softmax axis is explicit.
pred_probas = F.softmax(logits.view(B, 2), dim=1)


In [None]:
# Display the probabilities computed in the previous cell.
pred_probas