In [None]:
import sys
path = '/gpfs/commons/groups/gursoy_lab/mstoll/'
sys.path.append(path)

import os
import numpy as np
import pandas as pd
import time
import torch
import pickle
import shap
import tensorboard


from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, classification_report
from functools import partial
import shutil
from tqdm.auto import tqdm

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.ensemble import RandomForestClassifier

from sklearn.linear_model import Lasso
from sklearn.preprocessing import StandardScaler

from codes.models.data_form.DataForm import DataTransfo_1SNP
from codes.models.metrics import calculate_roc_auc

import featurewiz as gwiz

import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve

from codes.models.Decision_tree.utils import get_indice, get_name

In [None]:
### framework constants
model_type = 'decision_tree'
model_version = 'gradient_boosting'
test_name = '1_test_train_transfo_V1'
tryout = True  # True for a quick tryout run, False for a full run

### data constants
# SNP / phenotype selection
CHR, SNP = 1, 'rs673604'
pheno_method = 'Paul'  # one of: 'Paul', 'Abby'
ld = 'no'
rollup_depth = 4
list_env_features = ['age', 'sex']
list_pheno_ids = None  # alternative: list(np.load(f'/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/phewas/list_associations_snps/{SNP}_paul.npy'))

# labels / vocabulary / padding
binary_classes = True  # number of classes for the SNP label is two (0 or 1)
vocab_size = None  # set once the data is loaded
padding, padding_token = False, 0

# split, caching and feature computation
prop_train_test = 0.8
load_data, save_data = False, True
compute_features = True

# row / column filtering
remove_none = True
decorelate = False
equalize_label = False
threshold_corr = 0.9
threshold_rare = 50
remove_rare = 'all'  # one of: None, 'all', 'one_class'

### data format
batch_size = 20
data_share = 1

### model constants (none yet)

### training constants
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Build the data-transformation helper for the configured SNP; every option
# comes from the configuration cell.
dataT = DataTransfo_1SNP(
    # target SNP and phenotype extraction
    SNP=SNP,
    CHR=CHR,
    method=pheno_method,
    ld=ld,
    rollup_depth=rollup_depth,
    list_pheno_ids=list_pheno_ids,
    list_env_features=list_env_features,
    binary_classes=binary_classes,
    # padding
    padding=padding,
    pad_token=padding_token,
    # caching / feature computation
    load_data=load_data,
    save_data=save_data,
    compute_features=compute_features,
    # split and filtering
    prop_train_test=prop_train_test,
    data_share=data_share,
    remove_none=remove_none,
    equalize_label=equalize_label,
    decorelate=decorelate,
    threshold_corr=threshold_corr,
    threshold_rare=threshold_rare,
    remove_rare=remove_rare,
)
#patient_list = dataT.get_patientlist()

In [None]:
# Fetch the tree-model inputs (environment features excluded).
(data, labels_patients, indices_env, name_envs, eids) = dataT.get_tree_data(
    with_env=False,
    only_relevant=False,
    load_possible=True,
)

In [None]:
# Append the labels as one extra, last column of `data`.
# NOTE: not idempotent — re-running this cell appends the label column again.
data = np.hstack((data, labels_patients.reshape(-1, 1)))

In [None]:
# Column-wise totals of `data`, label column included. For 0/1-encoded columns
# this is a per-column patient count (assumes binary encoding — TODO confirm).
frequencies_ini = data.sum(axis=0)

In [None]:
# Switches for the data-selection and column-filter cells below (all off).
equalized = interest = keep = scaled = remove = False

In [None]:
# Select the patients (rows) used downstream. `data` already carries the label
# as its last column, which the GAN cells below rely on.
if interest:
    # NOTE(review): `nb_patients_interest` is not defined anywhere in this
    # notebook, and this slice drops the label column that later cells expect
    # in last position — define it and confirm intent before enabling `interest`.
    data_use, labels_use = data[:nb_patients_interest, :-1], labels_patients[:nb_patients_interest]
    eids_use = eids[:nb_patients_interest]
else:
    # `data` is the full feature + label matrix built above.
    data_use, labels_use = data, labels_patients
    eids_use = eids
if remove:
    eids_remove = np.load('/gpfs/commons/groups/gursoy_lab/mstoll/codes/Data_Files/UKBB/eids_remove_1.npy')
    # Keep only patients whose eid is NOT in the removal list. The mask is
    # built from `eids_use` so it always has the same length as `data_use`.
    indices_eids = ~np.isin(eids_use, eids_remove)
    eids_use = eids_use[indices_eids]
    data_use = data_use[indices_eids]
    labels_use = labels_use[indices_eids]

if equalized:
    pheno, labels = DataTransfo_1SNP.equalize_label(data=data_use, labels=labels_use)
else:
    pheno, labels = data_use, labels_use


In [None]:
# Split the selected rows into train / test. Uses `pheno`/`labels` from the
# selection cell so the `remove` / `equalized` options take effect and the
# rows stay aligned with the (possibly filtered) labels.
diseases_patients_train, diseases_patients_test, label_patients_train, label_patients_test = train_test_split(
    pheno, labels, test_size=1 - prop_train_test, random_state=42
)

In [None]:
# Optional column filter: keep only columns whose total over the full `data`
# matrix (features + trailing label column) exceeds `min_column_count`.
min_column_count = 100
indices_keep = frequencies_ini > min_column_count  # "> 100" already implies "> 0"
# The last column is the label that later cells randomise and reconstruct:
# never drop it, even when it is rare.
indices_keep[-1] = True
#indices_keep = shaps!=0
diseases_patients_train_keep = diseases_patients_train[:, indices_keep]
diseases_patients_test_keep = diseases_patients_test[:, indices_keep]
if keep:
    diseases_patients_train_model = diseases_patients_train_keep
    diseases_patients_test_model = diseases_patients_test_keep
else:
    diseases_patients_train_model = diseases_patients_train
    diseases_patients_test_model = diseases_patients_test



In [None]:
### masking setup
# Width of each model row (presumably features + the trailing label column).
nb_features = diseases_patients_train_model.shape[-1]


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class Generator(nn.Module):
    """MLP that reconstructs a full row from an input row.

    The last column is treated as the label. ``forward`` returns the
    reconstruction together with a loss against the ground-truth row. The loss
    has two terms: one over the phenotype columns (all but the last) and one
    over the label column (the last).
    """

    def __init__(self, latent_dim, feature_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, feature_dim),
            nn.Sigmoid()  # Assuming input data range [0, 1]
        )

    def _reconstruction_loss(self, data_gen, data_truth):
        """Euclidean error on the phenotype columns plus on the label column.

        Both terms are divided by sqrt(data_gen.numel()).
        NOTE(review): the label term is scaled by the *total* element count,
        which down-weights it relative to a per-column RMSE — confirm intended.
        """
        scale = data_gen.numel() ** 0.5
        # Norm of the difference: the second positional argument of torch.norm
        # is the order `p`, so the previous torch.norm(a, b) always raised.
        loss_pheno = torch.linalg.vector_norm(data_gen[:, :-1] - data_truth[:, :-1]) / scale
        loss_labels = torch.linalg.vector_norm(data_gen[:, -1] - data_truth[:, -1]) / scale
        return loss_pheno + loss_labels

    def forward(self, data_random, data_truth):
        """Return ``(data_gen, loss)`` for a batch of input rows and ground-truth rows."""
        data_gen = self.model(data_random)
        return data_gen, self._reconstruction_loss(data_gen, data_truth)

    def eval(self, data_mask=None, data_truth=None):
        """Switch to evaluation mode, optionally also reconstructing a batch.

        With no arguments this is exactly ``nn.Module.eval()`` and returns
        ``self``, so the standard PyTorch idiom keeps working. With
        ``data_mask`` it returns ``(data_gen, loss)`` computed without
        gradients; ``loss`` is None when ``data_truth`` is not given.
        """
        super().eval()
        if data_mask is None:
            return self
        with torch.no_grad():
            data_gen = self.model(data_mask)
            loss = None if data_truth is None else self._reconstruction_loss(data_gen, data_truth)
        return data_gen, loss

class Discriminator(nn.Module):
    """MLP critic: maps each row of `feature_dim` values to one sigmoid score in (0, 1)."""

    def __init__(self, feature_dim):
        super().__init__()
        # Layer order is kept fixed so the `model.<idx>.*` state_dict keys stay stable.
        layers = [
            nn.Linear(feature_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, features):
        """Return a ``(batch, 1)`` tensor of scores for ``features``."""
        return self.model(features)



In [None]:
# Hyperparameters
# The generator maps a row to a row of the same width (number of input features).
latent_dim = feature_dim = nb_features
lr = 2e-4
batch_size = 64  # NOTE: overrides the batch_size set in the configuration cell
epochs = 10

In [None]:
# Seed torch so weight initialisation and the random label column are reproducible.
torch.manual_seed(42)

generator = Generator(latent_dim, feature_dim)
discriminator = Discriminator(feature_dim)

# Loss function and optimizer
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)


def _to_random_and_truth(rows):
    """Return ``(random_input, truth)`` float32 tensors built from ``rows``.

    ``truth`` is a float32 copy of ``rows``. ``random_input`` is the same data
    with its last (label) column replaced by uniform noise in [0, 1).
    """
    truth = torch.tensor(rows, dtype=torch.float32)
    random_input = truth.clone()
    random_input[:, -1] = torch.rand(random_input.shape[0])
    return random_input, truth


# Convert the train / test matrices (label in the last column) to tensors.
data_tensor_train_random, data_tensor_train_truth = _to_random_and_truth(diseases_patients_train_model)
data_tensor_test_random, data_tensor_test_truth = _to_random_and_truth(diseases_patients_test_model)


In [None]:
# Each batch yields (input row with noisy label column, ground-truth row),
# matching Generator.forward(data_random, data_truth).
dataloader_train = DataLoader(
    TensorDataset(data_tensor_train_random, data_tensor_train_truth),
    batch_size=batch_size,
    shuffle=True,
    # BatchNorm1d cannot normalise a training batch of size 1, so drop the
    # last batch only when it would be a single row.
    drop_last=len(data_tensor_train_random) % batch_size == 1,
)
dataloader_test = DataLoader(
    TensorDataset(data_tensor_test_random, data_tensor_test_truth),
    batch_size=batch_size,
)

In [None]:
# Peek at one training batch to check tensor shapes before training
# (the loop variables do not exist yet on a fresh kernel).
sample_random, sample_truth = next(iter(dataloader_train))
sample_random.shape, sample_truth.shape

In [None]:

# Training Loop: only the generator is trained for now. It learns to
# reconstruct each ground-truth row, label included, from the copy whose label
# column was replaced by noise.
generator_losses = []  # mean generator loss per epoch
for epoch in range(epochs):
    generator.train()  # make sure BatchNorm uses batch statistics while training
    epoch_losses = []
    progress = tqdm(dataloader_train, desc=f"epoch {epoch + 1}/{epochs}", leave=False)
    # `labels_batch_train` holds the ground-truth rows (labels in the last column).
    for data_batch_train, labels_batch_train in progress:
        # Train Generator
        optimizer_G.zero_grad()
        _, loss = generator(data_batch_train, labels_batch_train)
        loss.backward()
        optimizer_G.step()
        epoch_losses.append(loss.item())
        progress.set_postfix(g_loss=f"{epoch_losses[-1]:.4f}")

        # TODO: adversarial step, not wired up yet (needs `features`,
        # `gen_features`, `valid` and `fake` tensors):
        # optimizer_D.zero_grad()
        # d_real_loss = adversarial_loss(discriminator(features), valid)
        # d_fake_loss = adversarial_loss(discriminator(gen_features.detach()), fake)
        # d_loss = (d_real_loss + d_fake_loss) / 2
        # d_loss.backward()
        # optimizer_D.step()

    generator_losses.append(float(np.mean(epoch_losses)) if epoch_losses else float('nan'))
    print("[Epoch %d/%d] [G mean loss: %f]" % (epoch, epochs, generator_losses[-1]))
