In [1]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from plink_datasets import *
from ANN_MCC import *
import sys
import sklearn.metrics as metrics
from ray import tune
import scikitplot as skplt
import pickle
import matplotlib.image as mpimg

In [2]:
# Identifiers used to build the input and output file names in this notebook.
preprocess = "severe_clumped_test"
prefix = "_deep_mlp_m"
path = "/home/kasper/data/wheeze/additional_gwas/severe/data/" + preprocess

# Read the genotypes from the .bed/.bim/.fam files that share the `path` stem.
plink_files = [path + extension for extension in (".bed", ".bim", ".fam")]
G = read_plink1_bin(*plink_files, verbose=False)

# Read the space-separated phenotype table.
phenotype_file = "/home/kasper/data/wheeze/additional_gwas/severe/data/phenotypes.txt"
phenotype = pd.read_csv(phenotype_file, sep=" ")

In [98]:
# rsIDs of the variants to extract from the genotype array.
snp = ["rs3795262", "rs1817914"]

# Flattened view of the genotype array, indexed by (sample, variant).
G_frame = G.to_series()

# Restrict the genotypes to the requested rsIDs. The filter is applied to the
# labelled array rather than to `G_frame`, because the flattened series holds
# only genotype values and has no "snp" column to test against.
# NOTE(review): assumes `G` carries a per-variant `snp` coordinate, as the
# arrays returned by read_plink1_bin do — confirm.
G_new = G.where(G.snp.isin(snp), drop=True)

Unnamed: 0_level_0,Unnamed: 1_level_0,genotype
sample,variant,Unnamed: 2_level_1
21228_A1,variant0,2.0
21228_A1,variant1,2.0
21228_A1,variant2,2.0
21228_A1,variant3,2.0
21228_A1,variant4,2.0
...,...,...
10016077,variant3002,0.0
10016077,variant3003,0.0
10016077,variant3004,2.0
10016077,variant3005,1.0


In [87]:
# NOTE(review): `G_new` has to be assigned by an earlier cell before this one
# runs; presumably it is the genotype subset for the rsIDs in `snp` — confirm.
G_new

In [11]:
# Build the training dataset from the full genotype array and the phenotypes.
# No `G_top = G.values[:, top_indices]` subset is created here: `top_indices`
# is not assigned by any cell of this notebook (NameError on a fresh kernel)
# and nothing below reads `G_top`.
gwas_data = PlinkDataset(G, phenotype, scale=True, shuffle=True)
N, P = G.shape

# Fraction of samples carrying each trait label, counted with a vectorised
# sum instead of Python's builtin `sum` over the boolean array.
trait_labels = G.trait.values
weight1 = (trait_labels == "2").sum() / N
weight2 = (trait_labels == "1").sum() / N

# Ratio of "1"-labelled to "2"-labelled samples; used as `pos_weight` of the
# loss in the training cell.
# NOTE(review): presumably "2" marks cases and "1" controls — confirm.
w = torch.tensor(weight2 / weight1)

In [12]:
# Build the network.
# NOTE(review): the positional arguments are presumably the input width, the
# hidden layer sizes and the output width — confirm against ANN_MCC.ANN.
net = ANN(P, [140, 100, 60, 20], 1, act_func=nn.ReLU, mlp_m=True)

# Hyperparameters and callbacks for this run, collected in one place.
training_options = dict(
    batch_size=300,
    nepochs=100,
    criterion=nn.BCEWithLogitsLoss(pos_weight=w),
    evaluate=MCCLoss_bin,
    test=test_mcc_bin,
    learning_rate=3e-4,
    l1_const=5e-5,
    l2_const=1e-2,
    early_stopping=False,
    verbose=True,
)

# Train on the full dataset and unpack the four values returned by `train`.
training_result = train(net=net, dataset=gwas_data, **training_options)
score, trainloss_list, valloss_list, model = training_result

Epoch: 1 Training loss: 1.404422640800476 Correlation: 0.04634247347712517
Epoch: 2 Training loss: 1.1656630039215088 Correlation: 0.1594761610031128
Epoch: 3 Training loss: 1.0336060523986816 Correlation: 0.19110655784606934
Epoch: 4 Training loss: 1.0657905340194702 Correlation: 0.11130324751138687
Epoch: 5 Training loss: 1.1037344932556152 Correlation: 0.24917516112327576
Epoch: 6 Training loss: 0.9656583666801453 Correlation: 0.3899216055870056
Epoch: 7 Training loss: 0.9780314564704895 Correlation: 0.179626002907753
Epoch: 8 Training loss: 0.8899168372154236 Correlation: 0.5674626231193542
Epoch: 9 Training loss: 0.9120393991470337 Correlation: 0.29599636793136597
Epoch: 10 Training loss: 0.7657333016395569 Correlation: 0.6769374012947083
Epoch: 11 Training loss: 0.7891891002655029 Correlation: 0.543427586555481
Epoch: 12 Training loss: 0.7166086435317993 Correlation: 0.6572362184524536
Epoch: 13 Training loss: 0.7321398854255676 Correlation: 0.47451797127723694
Epoch: 14 Training

Epoch: 109 Training loss: 0.12196021527051926 Correlation: 0.9351791143417358
Epoch: 110 Training loss: 0.12950272858142853 Correlation: 0.9346030950546265
Epoch: 111 Training loss: 0.14555081725120544 Correlation: 0.9335007667541504
Epoch: 112 Training loss: 0.11585690826177597 Correlation: 0.935286819934845
Epoch: 113 Training loss: 0.1404419243335724 Correlation: 0.9304472208023071
Epoch: 114 Training loss: 0.13848340511322021 Correlation: 0.934125542640686
Epoch: 115 Training loss: 0.12087622284889221 Correlation: 0.9349665641784668
Epoch: 116 Training loss: 0.10577797144651413 Correlation: 0.9304078221321106
Epoch: 117 Training loss: 0.16617245972156525 Correlation: 0.9336511492729187
Epoch: 118 Training loss: 0.1429567188024521 Correlation: 0.9311250448226929
Epoch: 119 Training loss: 0.11249266564846039 Correlation: 0.9274007678031921
Epoch: 120 Training loss: 0.12399937957525253 Correlation: 0.9295927882194519
Epoch: 121 Training loss: 0.1061420813202858 Correlation: 0.92552024

KeyboardInterrupt: 

In [None]:
plot_path = "/home/kasper/data/wheeze/additional_gwas/severe/torch_models/Plots/"


def _plot_and_pickle(values, ylabel, title, pickle_file):
    """Plot `values` against epoch on a new figure, pickle it, and return it.

    Parameters
    ----------
    values : sequence
        One value per epoch, drawn with ``plt.plot``.
    ylabel : str
        Label of the y axis.
    title : str
        Figure title.
    pickle_file : str
        Path the pickled figure is written to.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn and pickled.
    """
    fig = plt.figure()
    plt.plot(values)
    plt.ylabel(ylabel)
    plt.xlabel("Epoch")
    plt.title(title)
    # The context manager closes the file even if pickling fails; the bare
    # `open(...)` used before left the handle open.
    with open(pickle_file, "wb") as handle:
        pickle.dump(fig, handle)
    return fig


train_plot = _plot_and_pickle(
    trainloss_list,
    "Binary Cross-Entropy",
    "Training loss",
    plot_path + preprocess + prefix + "_train.pickle",
)

val_plot = _plot_and_pickle(
    valloss_list,
    "Matthew's Correlation Coefficient",
    "Validation 'loss'",
    plot_path + preprocess + prefix + "_val.pickle",
)

# Persist the trained model object next to the plots directory.
model_path = "/home/kasper/data/wheeze/additional_gwas/severe/torch_models/" + preprocess + prefix + ".pth"
torch.save(model, model_path)