In [None]:
import numpy as np
import pandas as pd
import seaborn as sea
import matplotlib.pyplot as plt

from sklearn.model_selection import StratifiedKFold

# Notebook configuration.
# seaborn "whitegrid" theme for all figures.
# NOTE(review): set before `%matplotlib inline`; on some ipykernel versions the inline
# backend re-applies its own rc defaults (e.g. figure facecolor), so keep this order
# unless the figures are re-checked.
sea.set_style("whitegrid")
%matplotlib inline
# Reload edited local modules (e.g. meth_model_utils / meth_model_classes) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Widen the notebook's `.container` element to 90% of the browser window
# (classic Notebook UI; JupyterLab does not use this class).
# https://stackoverflow.com/questions/21971449/how-do-i-increase-the-cell-width-of-the-jupyter-ipython-notebook-in-my-browser
# `display` and `HTML` come from IPython.display: importing them from
# IPython.core.display is deprecated. `HTML` is reused by the K-fold cell.

from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [None]:
import warnings
warnings.simplefilter('ignore')

In [None]:
import os
import gc
import time
import copy
import shutil

import torch
import torch.nn as nn
import meth_model_utils as u
import meth_model_classes as c
import torch.nn.functional as F

from torchvision import utils
from torch.utils.data import Dataset, DataLoader

In [None]:
# Deterministic execution (https://pytorch.org/docs/stable/notes/randomness.html).
# cuBLAS reads CUBLAS_WORKSPACE_CONFIG when its handle is created, so the variable is
# set first — before deterministic mode is switched on and before any CUDA op runs.
# Per the PyTorch notes, deterministic mode otherwise raises a RuntimeError on
# cuBLAS-backed ops with CUDA >= 10.2.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(True)

In [None]:
# Run on the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# https://pytorch.org/docs/stable/notes/randomness.html
# Obtain the run's seed from the project helper and seed the RNGs with it.
# `seed` is reused later for the row shuffle, the StratifiedKFold split, per-fold
# re-seeding and the seed-specific output folder of the accuracy plot.
seed = u.get_seed()
print('Seed = ', seed)
u.set_all_seeds(seed)

# Setup dataframe

In [None]:
# Root folder of the TCGA (Xena) methylation CSVs, saved models and result summaries.
# Set the XENA_DATA_PATH environment variable to override the machine-specific
# default without editing the notebook. Downstream cells build paths as PATH + "...",
# so a trailing separator is enforced.
PATH = os.environ.get("XENA_DATA_PATH", "D:/CANCER BIOLOGY/DATASET/TCGA/FROM Xena/")
if not PATH.endswith(("/", "\\")):
    PATH += "/"

In [None]:
# Read the two methylation matrices (first CSV column becomes the index), remove the
# 'MBD3L2' index entry from the lusu frame, then combine them with the project helper.
# NOTE(review): 'MBD3L2' is dropped only from the lusu frame — presumably it is absent
# from the luad frame; confirm against the raw files.
df_luad = pd.read_csv(PATH + "meth_luad.csv", index_col=0)
df_lusu = pd.read_csv(PATH + "meth_lusu.csv", index_col=0).drop(index=['MBD3L2'])
df_final = u.meth_data_preprocess(df_luad, df_lusu)

In [None]:
# Shuffle the rows reproducibly, then split the frame into the label list and the
# feature-only frame. This cell consumes the 'label' column, so it must be preceded by
# a fresh run of the data-loading cell; fail loudly instead of with a KeyError.
assert 'label' in df_final.columns, "'label' already removed - re-run the data-loading cell first"
df_final = df_final.sample(frac=1, random_state=seed).reset_index(drop=True)
labels = list(df_final['label'])
# Only the 'label' column is dropped here.
# NOTE(review): an earlier comment mentioned dropping sample_id as well — presumably
# u.meth_data_preprocess already removes it; confirm.
df_final = df_final.drop(columns=['label'])
columns = list(df_final.columns)

In [None]:
# Model inputs: the feature frame as a NumPy array, plus the row-aligned label list.
xtrain, ytrain = df_final.to_numpy(), labels

---
---
---

In [None]:
# Re-seed via the project helper (presumably Python/NumPy/torch RNGs — see
# meth_model_utils) so the training-configuration and K-fold cells start from a
# known RNG state even if earlier cells were re-run.
u.set_all_seeds(seed)

In [None]:
# Classifier-stage hyper-parameters.
input_dim = xtrain.shape[1]   # one input unit per feature column of xtrain
output_dim = 2048             # passed to METH_AutoEncoder and METH_Classifier; presumably the encoder output width — confirm
batch_size = 32
epochs = 20
learning_rate = 5e-7

In [None]:
# Load the state dict saved by the XENA_LUNG_METH_AutoEncoder notebook.
# It is re-applied at the start of every K-Fold iteration, so it acts as the reset weights.
# map_location="cpu" lets a checkpoint written on a GPU load on CPU-only machines too
# (`device` may fall back to 'cpu'). load_state_dict later copies these tensors into
# each freshly built model, which is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
saved_model = torch.load(PATH + "models/LUNG_METH_Autoencoder.kd", map_location="cpu")

In [None]:
# Setup the Stratified K-Fold Cross Validation
cumulative_train_acc, cumulative_test_acc = 0., 0.
k = 10
kfold = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)

In [None]:
# Stratified K-Fold evaluation of the frozen-encoder classifier.
# Every fold rebuilds the model from the same pretrained weights (`saved_model`),
# freezes the encoder, trains the remaining parameters and records the mean of the
# accuracy lists returned by `u.train_classifier` for that fold.
#
# All accumulators are (re)initialised here, so re-running this cell on its own does
# not add a second run's accuracies on top of the first.
list_avg_train_acc_per_fold = []
list_avg_valid_acc_per_fold = []
cumulative_train_acc, cumulative_test_acc = 0., 0.

# Convert the labels once instead of twice per fold.
labels_arr = np.array(ytrain)

for fold, (train_index, test_index) in enumerate(kfold.split(xtrain, ytrain)):
    display(HTML("<h1>Fold: {}</h1>".format(fold + 1)))
    ##------------------------------------------------------------------------------------##

    ## collect the rows for train and test
    ## https://stackoverflow.com/questions/19155718/select-pandas-rows-based-on-list-index
    k_xtrain, k_xtest = xtrain[train_index], xtrain[test_index]
    k_ytrain, k_ytest = labels_arr[train_index], labels_arr[test_index]

    ## create train_dataset and test_dataset of class LUNG_Meth
    train_dataset = c.LUNG_Meth(k_ytrain, k_xtrain)
    test_dataset = c.LUNG_Meth(k_ytest, k_xtest)

    # Re-seed so each fold's loader shuffling is reproducible.
    u.set_all_seeds(seed)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # NOTE(review): shuffling the validation loader does not change its accuracy but
    # does draw from the global RNG; kept so results match earlier seed-wise runs.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    ##------------------------------------------------------------------------------------##

    torch.cuda.empty_cache()

    # Build a fresh autoencoder and reset it to the pretrained LEVEL 1 weights.
    only_encoder = c.METH_AutoEncoder(input_dim, output_dim)
    only_encoder.load_state_dict(saved_model)

    # Keep every top-level child module except the last one.
    # NOTE(review): assumes the decoder is the final child of METH_AutoEncoder — confirm.
    only_encoder = nn.Sequential(*list(only_encoder.children())[:-1])

    # Wrap the encoder in the classifier and freeze the encoder's parameters.
    classifier = c.METH_Classifier(only_encoder, output_dim)
    for params in classifier.encoder.parameters():
        params.requires_grad = False
    classifier.to(device)
    # NOTE(review): requires_grad=False does not affect Dropout/BatchNorm behaviour, which
    # follows train()/eval() mode — confirm how u.train_classifier sets the mode.

    ##------------------------------------------------------------------------------------##

    ## optimise only the parameters that still require gradients
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, classifier.parameters()), lr=learning_rate)

    u.set_all_seeds(seed)
    log_dict = u.train_classifier(
        num_epochs=epochs,
        model=classifier,
        optimizer=optimizer,
        device=device,
        train_loader=train_loader,
        valid_loader=test_loader,
        # NOTE(review): patience equals `epochs`; if it counts epochs without improvement,
        # early stopping can never fire — confirm intent.
        patience=20,
    )

    # Mean of the logged accuracy lists (presumably one entry per epoch), computed once.
    avg_train_acc = np.mean(list(log_dict['train_acc']))
    avg_valid_acc = np.mean(list(log_dict['valid_acc']))
    print(f"Train accuracy avg: {avg_train_acc}, Valid accuracy avg: {avg_valid_acc}")
    list_avg_train_acc_per_fold.append(avg_train_acc)
    list_avg_valid_acc_per_fold.append(avg_valid_acc)

    cumulative_train_acc += avg_train_acc
    cumulative_test_acc += avg_valid_acc

    ##------------------------------------------------------------------------------------##


In [None]:
# Persist the classifier's state dict.
# NOTE: `classifier` is the model built in the last iteration of the K-fold loop, i.e.
# the final fold's model only — not an ensemble and not a refit on all data.
torch.save(obj=classifier.state_dict(), f=PATH + "models/LUNG_METH_Classifier.kd")

In [None]:
# Per-fold train vs. validation accuracy chart, with the across-fold averages in the title.
# The tick labels and overall averages are derived from `k` and the per-fold lists, so
# they stay correct if `k` changes or the K-fold cell is re-run on its own.
u.plot_train_test_k_fold_accuracy(
    list_avg_train_acc_per_fold,
    list_avg_valid_acc_per_fold,
    N=k,
    width=0.45,
    width_mult=1,
    fig_size=(28, 8),
    title='K-FOLD Accuracy Chart ===> Overall avg_train_acc: {:.4f}, Overall avg_valid_acc: {:.4f}'.format(
        np.mean(list_avg_train_acc_per_fold), np.mean(list_avg_valid_acc_per_fold)),
    x_ticks=tuple(f'Fold={i}' for i in range(1, k + 1)),
    legends=('Train', 'Validation'),
    file_path=f"{PATH}project_summary_seed_wise_meth/seed={seed}/classifier_on_k_fold",
)

In [None]:
# Play an audible alert once execution reaches this cell.
from scipy.io.wavfile import read
import sounddevice as sd

fs, data = read('alert.wav', mmap=True)  # fs - sampling frequency read from the file
# sounddevice takes (frames, channels) arrays; only mono (1-D) data needs a channel axis.
# Reshaping multi-channel data to (-1, 1) would interleave the channels into a single
# track of double length.
if data.ndim == 1:
    data = data.reshape(-1, 1)
# Play at the file's own sampling rate instead of a hard-coded 44100 Hz, so files
# recorded at other rates are not played back at the wrong speed/pitch.
sd.play(data, fs)

---
---
---

# LEVEL 2 complete !!