In [1]:
import torch
import os
from torch_geometric.data import Data
import pandas as pd
import numpy as np
from torchvision.transforms import ToTensor
from torch_geometric.loader import DataLoader
from tqdm.auto import trange, tqdm
from torch_geometric import seed_everything
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmeagoodboy[0m ([33mitsagoodteam[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
from encoder_models import *
from node_decoder_models import *
from test_train_validate import *
from data_processing.process_data import *
from utils import *
from wrappers import *

In [4]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(device)

cuda:1


In [5]:
clear_cache()

In [6]:
seed_everything(36912)

In [7]:
omics_file = "../data/raw/sm/kirp_sm251_csn.csv"
network_file = "../data/network/CancerSubnetwork.txt"

learning_rate = 3e-5
num_features = 1
out_channels = 1
num_epochs = 50
batch_size = 16
alpha = 1
beta = 0.5

cancer = "KIRP"
omic = "SM"
gmodel = "L3GravNetConv"
optim = "ADAM"

savefolder = "./new_res/" + cancer + "/" + omic + "/" + gmodel + "/" +  optim + "_" + str(learning_rate) +"/"
savemodels = "./new_mod/" + cancer + "/" + omic + "/" + gmodel + "/" + optim + "_" +  str(learning_rate) +"/"
savename = cancer + "_" + omic + "_" + gmodel + "_" + optim + "_" + str(learning_rate) +"_"

savename = savemodels + savename 
summaryin = savemodels + "runs"
bestmodel = savename + "bestmodel.pt"
finalmodel = savename + "model.pt" 
configf = savename + "config.yml" 
fencsave = savefolder + "final.csv"
bencsave = savefolder + "best.csv"

if not os.path.exists(savefolder):
    os.makedirs(savefolder)
if not os.path.exists(savemodels):
    os.makedirs(savemodels)

In [8]:
wandb.init(project="TEST1_KIRP_GCNM2")

cfg = wandb.config
cfg.update({"epochs" : num_epochs, "batch_size": batch_size, "lr" : learning_rate,"optim" : optim,"data_type" : omic , "cancer" : cancer,"save":savefolder,"model_type":gmodel})

In [9]:
data = SingleOmicData(network_file, omics_file, 1)
num_nodes = len(data.node_order)

Loading the Network :  ../data/network/CancerSubnetwork.txt


100%|██████████| 2291/2291 [00:01<00:00, 2188.13it/s]

Loaded the Network 





In [10]:
train_size = int(0.8 * len(data))
x,y = torch.utils.data.random_split(data, lengths=[train_size, len(data) - train_size], generator=torch.Generator())

In [11]:
train_loader = DataLoader(x, shuffle=True, batch_size=batch_size,num_workers=8)
val_loader = DataLoader(y, shuffle=True, batch_size=batch_size, num_workers=8)
encode_loader = DataLoader(data, shuffle=False, batch_size=1)

In [12]:
model = GAEM( encoder = L3GravNetConv(in_channels = num_features, out_channels = out_channels),
             node_decoder = L2Linear(out_channels = out_channels, num_nodes = num_nodes, batch_size=batch_size) )

Passed edge index will not be used while learning


In [13]:
model = model.to(device)

In [14]:
!nvidia-smi

Tue Nov  8 12:10:58 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 32%   37C    P2    55W / 250W |   1796MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:03:00.0 Off |                  N/A |
| 32%   35C    P2    57W / 250W |   1364MiB / 11264MiB |      4%      Default |
|       

In [15]:
print("total parameters in the model : ", calculate_num_params(model))
print("total parameters in the encoder : ", calculate_num_params(model.encoder))
print("total parameters in the node_decoder : ", calculate_num_params(model.node_decoder))
print("total parameters in the decoder : ", calculate_num_params(model.decoder))

total parameters in the model :  10513811
total parameters in the encoder :  7285
total parameters in the node_decoder :  10506526
total parameters in the decoder :  0


In [16]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [17]:
lossfn = torchvision.ops.focal_loss.sigmoid_focal_loss

In [None]:
all_val_loss = []
for e in range(num_epochs):
    
    print("Epoch : ", e + 1,  " / " , num_epochs)
    
    train_loss = train(model,train_loader,optimizer, device,lossfn )
    val_loss = validate(model,val_loader,device, lossfn)
    all_val_loss.append(val_loss)
    wandb.log({"validation loss" : val_loss,"train loss" : train_loss})
    
    if e > 10:
        if val_loss < min(all_val_loss):
            torch.save(model.state_dict(), bestmodel)
            print("Saved best model weights")
    
    if (e+1) % 20 == 0:
        print("Saving Model")
        torch.save(model.state_dict(), finalmodel)

wandb.finish()

Epoch :  1  /  50


100%|██████████| 13/13 [00:46<00:00,  3.59s/it]


Train Loss :  0.06478372732034096 1.3862944749685435 1.451078204008249


100%|██████████| 4/4 [00:14<00:00,  3.55s/it]


Validation Loss :    1.4510782063007355
Epoch :  2  /  50


100%|██████████| 13/13 [00:58<00:00,  4.47s/it]


Train Loss :  0.06472643865988804 1.3862944749685435 1.4510209193596473


100%|██████████| 4/4 [00:15<00:00,  3.90s/it]


Validation Loss :    1.451000839471817
Epoch :  3  /  50


100%|██████████| 13/13 [01:08<00:00,  5.28s/it]


Train Loss :  0.06466017147669426 1.3862944749685435 1.4509546481646025


100%|██████████| 4/4 [00:12<00:00,  3.05s/it]


Validation Loss :    1.4509678184986115
Epoch :  4  /  50


100%|██████████| 13/13 [00:46<00:00,  3.61s/it]


Train Loss :  0.06458776214948067 1.3862944749685435 1.450882233106173


100%|██████████| 4/4 [00:16<00:00,  4.04s/it]


Validation Loss :    1.4508999288082123
Epoch :  5  /  50


100%|██████████| 13/13 [00:53<00:00,  4.10s/it]


Train Loss :  0.06450487845219098 1.3862944749685435 1.450799355140099


100%|██████████| 4/4 [00:15<00:00,  3.81s/it]


Validation Loss :    1.4507819712162018
Epoch :  6  /  50


100%|██████████| 13/13 [00:49<00:00,  3.84s/it]


Train Loss :  0.06441295720063724 1.3862944749685435 1.4507074356079102


100%|██████████| 4/4 [00:12<00:00,  3.04s/it]


Validation Loss :    1.4506526589393616
Epoch :  7  /  50


100%|██████████| 13/13 [00:52<00:00,  4.04s/it]


Train Loss :  0.0643064219218034 1.3862944749685435 1.4506008991828332


100%|██████████| 4/4 [00:15<00:00,  3.79s/it]


Validation Loss :    1.4505089223384857
Epoch :  8  /  50


100%|██████████| 13/13 [00:45<00:00,  3.50s/it]


Train Loss :  0.06419024845728508 1.3862944749685435 1.4504847343151386


100%|██████████| 4/4 [00:12<00:00,  3.06s/it]


Validation Loss :    1.450398862361908
Epoch :  9  /  50


100%|██████████| 13/13 [00:45<00:00,  3.48s/it]


Train Loss :  0.06405924547177094 1.3862944749685435 1.4503537049660316


100%|██████████| 4/4 [00:12<00:00,  3.07s/it]


Validation Loss :    1.4502079486846924
Epoch :  10  /  50


100%|██████████| 13/13 [00:45<00:00,  3.51s/it]


Train Loss :  0.06391635881020473 1.3862944749685435 1.4502108188775868


100%|██████████| 4/4 [00:12<00:00,  3.01s/it]


Validation Loss :    1.450006365776062
Epoch :  11  /  50


100%|██████████| 13/13 [00:45<00:00,  3.50s/it]


Train Loss :  0.06375880080919999 1.3862944749685435 1.4500532883864183


100%|██████████| 4/4 [00:12<00:00,  3.02s/it]


Validation Loss :    1.4497928023338318
Epoch :  12  /  50


100%|██████████| 13/13 [00:45<00:00,  3.49s/it]


Train Loss :  0.06359308958053589 1.3862944749685435 1.449887569134052


100%|██████████| 4/4 [00:12<00:00,  3.11s/it]


Validation Loss :    1.449685424566269
Epoch :  13  /  50


100%|██████████| 13/13 [00:45<00:00,  3.46s/it]


Train Loss :  0.0634097084403038 1.3862944749685435 1.4497041702270508


100%|██████████| 4/4 [00:12<00:00,  3.07s/it]


Validation Loss :    1.4494941532611847
Epoch :  14  /  50


100%|██████████| 13/13 [00:45<00:00,  3.48s/it]


Train Loss :  0.06321795284748077 1.3862944749685435 1.4495124358397264


100%|██████████| 4/4 [00:12<00:00,  3.10s/it]


Validation Loss :    1.4492649734020233
Epoch :  15  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.06301239132881165 1.3862944749685435 1.4493068640048687


100%|██████████| 4/4 [00:12<00:00,  3.08s/it]


Validation Loss :    1.4489562213420868
Epoch :  16  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.06279614739693128 1.3862944749685435 1.4490906183536236


100%|██████████| 4/4 [00:12<00:00,  3.06s/it]


Validation Loss :    1.4487537145614624
Epoch :  17  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.0625648616025081 1.3862944749685435 1.4488593339920044


100%|██████████| 4/4 [00:12<00:00,  3.13s/it]


Validation Loss :    1.4485979676246643
Epoch :  18  /  50


100%|██████████| 13/13 [00:45<00:00,  3.48s/it]


Train Loss :  0.06232685653062967 1.3862944749685435 1.4486213372303889


100%|██████████| 4/4 [00:12<00:00,  3.14s/it]


Validation Loss :    1.4483719170093536
Epoch :  19  /  50


100%|██████████| 13/13 [00:45<00:00,  3.48s/it]


Train Loss :  0.06207472085952759 1.3862944749685435 1.4483691912430983


100%|██████████| 4/4 [00:12<00:00,  3.05s/it]


Validation Loss :    1.4480499625205994
Epoch :  20  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.061813995242118835 1.3862944749685435 1.4481084805268507


100%|██████████| 4/4 [00:12<00:00,  3.02s/it]


Validation Loss :    1.4477530717849731
Saving Model
Epoch :  21  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.061538735834451824 1.3862944749685435 1.4478331895974965


100%|██████████| 4/4 [00:12<00:00,  3.04s/it]


Validation Loss :    1.4475047588348389
Epoch :  22  /  50


100%|██████████| 13/13 [00:45<00:00,  3.49s/it]


Train Loss :  0.061258756483976655 1.3862944749685435 1.44755324950585


100%|██████████| 4/4 [00:12<00:00,  3.04s/it]


Validation Loss :    1.447207361459732
Epoch :  23  /  50


100%|██████████| 13/13 [00:45<00:00,  3.53s/it]


Train Loss :  0.06096174997779039 1.3862944749685435 1.4472562074661255


100%|██████████| 4/4 [00:12<00:00,  3.03s/it]


Validation Loss :    1.4469716846942902
Epoch :  24  /  50


100%|██████████| 13/13 [00:45<00:00,  3.49s/it]


Train Loss :  0.06065853455891976 1.3862944749685435 1.4469530123930712


100%|██████████| 4/4 [00:12<00:00,  3.07s/it]


Validation Loss :    1.4464816451072693
Epoch :  25  /  50


100%|██████████| 13/13 [00:45<00:00,  3.48s/it]


Train Loss :  0.06034500724994219 1.3862944749685435 1.4466394919615526


100%|██████████| 4/4 [00:12<00:00,  3.02s/it]


Validation Loss :    1.4462055265903473
Epoch :  26  /  50


100%|██████████| 13/13 [00:45<00:00,  3.49s/it]


Train Loss :  0.060026480028262504 1.3862944749685435 1.4463209463999822


100%|██████████| 4/4 [00:12<00:00,  3.00s/it]


Validation Loss :    1.4459540843963623
Epoch :  27  /  50


100%|██████████| 13/13 [00:46<00:00,  3.54s/it]


Train Loss :  0.059690166264772415 1.3862944749685435 1.4459846386542687


100%|██████████| 4/4 [00:12<00:00,  3.04s/it]


Validation Loss :    1.4453816413879395
Epoch :  28  /  50


100%|██████████| 13/13 [00:45<00:00,  3.50s/it]


Train Loss :  0.059352774173021317 1.3862944749685435 1.4456472488550038


100%|██████████| 4/4 [00:12<00:00,  3.05s/it]


Validation Loss :    1.445024847984314
Epoch :  29  /  50


100%|██████████| 13/13 [00:45<00:00,  3.49s/it]


Train Loss :  0.059003345954876676 1.3862944749685435 1.4452978097475493


100%|██████████| 4/4 [00:12<00:00,  3.08s/it]


Validation Loss :    1.4441677629947662
Epoch :  30  /  50


100%|██████████| 13/13 [00:45<00:00,  3.48s/it]


Train Loss :  0.058644470400535144 1.3862944749685435 1.4449389347663293


100%|██████████| 4/4 [00:12<00:00,  3.07s/it]


Validation Loss :    1.4436251819133759
Epoch :  31  /  50


100%|██████████| 13/13 [00:44<00:00,  3.46s/it]


Train Loss :  0.058281040822084136 1.3862944749685435 1.4445755298321064


100%|██████████| 4/4 [00:12<00:00,  3.13s/it]


Validation Loss :    1.4434880316257477
Epoch :  32  /  50


100%|██████████| 13/13 [00:44<00:00,  3.46s/it]


Train Loss :  0.05790732800960541 1.3862944749685435 1.4442018178793101


100%|██████████| 4/4 [00:12<00:00,  3.10s/it]


Validation Loss :    1.443765640258789
Epoch :  33  /  50


100%|██████████| 13/13 [00:45<00:00,  3.50s/it]


Train Loss :  0.05752801150083542 1.3862944749685435 1.443822484750014


100%|██████████| 4/4 [00:12<00:00,  3.17s/it]


Validation Loss :    1.4439160525798798
Epoch :  34  /  50


100%|██████████| 13/13 [00:44<00:00,  3.46s/it]


Train Loss :  0.057138168467925146 1.3862944749685435 1.4434326520332923


100%|██████████| 4/4 [00:12<00:00,  3.20s/it]


Validation Loss :    1.443743735551834
Epoch :  35  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.05674797439804444 1.3862944749685435 1.4430424616887019


100%|██████████| 4/4 [00:12<00:00,  3.14s/it]


Validation Loss :    1.4427925050258636
Epoch :  36  /  50


100%|██████████| 13/13 [00:45<00:00,  3.46s/it]


Train Loss :  0.056346737708036713 1.3862944749685435 1.4426412215599647


100%|██████████| 4/4 [00:12<00:00,  3.11s/it]


Validation Loss :    1.442739337682724
Epoch :  37  /  50


100%|██████████| 13/13 [00:45<00:00,  3.48s/it]


Train Loss :  0.0559407607294046 1.3862944749685435 1.4422352313995361


100%|██████████| 4/4 [00:12<00:00,  3.02s/it]


Validation Loss :    1.4420102536678314
Epoch :  38  /  50


 62%|██████▏   | 8/13 [00:29<00:17,  3.48s/it]

In [None]:
final_vectors = encode(model,encode_loader,device)
final_vectors = np.array(final_vectors).reshape(len(data.patients), -1)

In [None]:
final_df = pd.DataFrame(final_vectors, index=data.patients, columns = data.node_order)
final_df.to_csv(fencsave)

In [None]:
!nvidia-smi

In [None]:
wandb.finish()