In [1]:
import torch
import os
from torch_geometric.data import Data
import pandas as pd
import numpy as np
from torchvision.transforms import ToTensor
from torch_geometric.loader import DataLoader
from tqdm.auto import trange, tqdm
from torch_geometric import seed_everything
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmeagoodboy[0m ([33mitsagoodteam[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
from encoder_models import *
from node_decoder_models import *
from test_train_validate import *
from data_processing.process_data import *
from utils import *
from wrappers import *

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [5]:
clear_cache()

In [6]:
seed_everything(36912)

In [7]:
omics_file = "../data/raw/sm/kirp_sm251_csn.csv"
network_file = "../data/network/CancerSubnetwork.txt"

learning_rate = 3e-4
num_features = 1
out_channels = 1
num_epochs = 50
batch_size = 8
alpha = 1
beta = 0.5

cancer = "KIRP"
omic = "SM"
gmodel = "L2GravNetConv"
optim = "ADAM"

savefolder = "./new_res/" + cancer + "/" + omic + "/" + gmodel + "/" +  optim + "_" + str(learning_rate) +"/"
savemodels = "./new_mod/" + cancer + "/" + omic + "/" + gmodel + "/" + optim + "_" +  str(learning_rate) +"/"
savename = cancer + "_" + omic + "_" + gmodel + "_" + optim + "_" + str(learning_rate) +"_"

savename = savemodels + savename 
summaryin = savemodels + "runs"
bestmodel = savename + "bestmodel.pt"
finalmodel = savename + "model.pt" 
configf = savename + "config.yml" 
fencsave = savefolder + "final.csv"
bencsave = savefolder + "best.csv"

if not os.path.exists(savefolder):
    os.makedirs(savefolder)
if not os.path.exists(savemodels):
    os.makedirs(savemodels)

In [8]:
wandb.init(project="TEST1_KIRP_GCNM2")

cfg = wandb.config
cfg.update({"epochs" : num_epochs, "batch_size": batch_size, "lr" : learning_rate,"optim" : optim,"data_type" : omic , "cancer" : cancer,"save":savefolder,"model_type":gmodel})

In [9]:
data = SingleOmicData(network_file, omics_file, 1)
num_nodes = len(data.node_order)

Loading the Network :  ../data/network/CancerSubnetwork.txt


100%|██████████| 2291/2291 [00:01<00:00, 1958.39it/s]


Loaded the Network 


In [10]:
train_size = int(0.8 * len(data))
x,y = torch.utils.data.random_split(data, lengths=[train_size, len(data) - train_size], generator=torch.Generator())

In [11]:
train_loader = DataLoader(x, shuffle=True, batch_size=batch_size,num_workers=8)
val_loader = DataLoader(y, shuffle=True, batch_size=batch_size, num_workers=8)
encode_loader = DataLoader(data, shuffle=False, batch_size=1)

In [12]:
model = GAEM( encoder = L2GravNetConv(in_channels = num_features, out_channels = out_channels),
             node_decoder = L2Linear(out_channels = out_channels, num_nodes = num_nodes, batch_size=batch_size) )

Passed edge index will not be used while learning


In [13]:
model = model.to(device)

In [14]:
!nvidia-smi

Tue Nov  8 12:10:44 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 32%   36C    P2    55W / 250W |   1364MiB / 11264MiB |      3%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:03:00.0 Off |                  N/A |
| 31%   33C    P8    18W / 250W |      3MiB / 11264MiB |      0%      Default |
|       

In [15]:
print("total parameters in the model : ", calculate_num_params(model))
print("total parameters in the encoder : ", calculate_num_params(model.encoder))
print("total parameters in the node_decoder : ", calculate_num_params(model.node_decoder))
print("total parameters in the decoder : ", calculate_num_params(model.decoder))

total parameters in the model :  10508154
total parameters in the encoder :  1628
total parameters in the node_decoder :  10506526
total parameters in the decoder :  0


In [16]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [17]:
lossfn = torchvision.ops.focal_loss.sigmoid_focal_loss

In [None]:
all_val_loss = []
for e in range(num_epochs):
    
    print("Epoch : ", e + 1,  " / " , num_epochs)
    
    train_loss = train(model,train_loader,optimizer, device,lossfn )
    val_loss = validate(model,val_loader,device, lossfn)
    all_val_loss.append(val_loss)
    wandb.log({"validation loss" : val_loss,"train loss" : train_loss})
    
    if e > 10:
        if val_loss < min(all_val_loss):
            torch.save(model.state_dict(), bestmodel)
            print("Saved best model weights")
    
    if (e+1) % 20 == 0:
        print("Saving Model")
        torch.save(model.state_dict(), finalmodel)

wandb.finish()

Epoch :  1  /  50


100%|██████████| 25/25 [00:44<00:00,  1.76s/it]


Train Loss :  0.009452477167360484 1.386307454109192 1.3957599401474


100%|██████████| 7/7 [00:12<00:00,  1.73s/it]


Validation Loss :    1.4177264826638358
Epoch :  2  /  50


100%|██████████| 25/25 [00:47<00:00,  1.88s/it]


Train Loss :  0.0020663418294861913 1.3862940073013306 1.3883603382110596


100%|██████████| 7/7 [00:17<00:00,  2.52s/it]


Validation Loss :    1.4337471553257533
Epoch :  3  /  50


100%|██████████| 25/25 [01:09<00:00,  2.77s/it]


Train Loss :  0.001777491131797433 1.3862943696975707 1.3880718564987182


100%|██████████| 7/7 [00:16<00:00,  2.30s/it]


Validation Loss :    1.4333449772426061
Epoch :  4  /  50


100%|██████████| 25/25 [00:44<00:00,  1.79s/it]


Train Loss :  0.0017423838702961803 1.3862943792343139 1.3880367612838744


100%|██████████| 7/7 [00:11<00:00,  1.70s/it]


Validation Loss :    1.4159549134118217
Epoch :  5  /  50


100%|██████████| 25/25 [00:53<00:00,  2.13s/it]


Train Loss :  0.0017271271720528602 1.3862943840026856 1.3880215215682983


100%|██████████| 7/7 [00:11<00:00,  1.65s/it]


Validation Loss :    1.404775619506836
Epoch :  6  /  50


100%|██████████| 25/25 [00:53<00:00,  2.15s/it]


Train Loss :  0.0017335002636536957 1.386294388771057 1.388027892112732


100%|██████████| 7/7 [00:11<00:00,  1.62s/it]


Validation Loss :    1.4009373017719813
Epoch :  7  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  0.0017312920140102506 1.3862943792343139 1.3880256605148316


100%|██████████| 7/7 [00:13<00:00,  1.90s/it]


Validation Loss :    1.3984499147960119
Epoch :  8  /  50


100%|██████████| 25/25 [00:52<00:00,  2.08s/it]


Train Loss :  0.0017536655627191067 1.3862943840026856 1.3880480527877808


100%|██████████| 7/7 [00:11<00:00,  1.64s/it]


Validation Loss :    1.3955227988106864
Epoch :  9  /  50


100%|██████████| 25/25 [00:43<00:00,  1.75s/it]


Train Loss :  0.0017684862017631532 1.3862943983078002 1.3880628871917724


100%|██████████| 7/7 [00:11<00:00,  1.65s/it]


Validation Loss :    1.39213946887425
Epoch :  10  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017214056383818388 1.386294355392456 1.3880157566070557


100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Validation Loss :    1.3937766551971436
Epoch :  11  /  50


100%|██████████| 25/25 [00:43<00:00,  1.72s/it]


Train Loss :  0.0017685867194086314 1.386294388771057 1.3880629682540893


100%|██████████| 7/7 [00:11<00:00,  1.64s/it]


Validation Loss :    1.3939355782100133
Epoch :  12  /  50


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Train Loss :  0.0017227774811908604 1.3862943744659424 1.3880171489715576


100%|██████████| 7/7 [00:11<00:00,  1.68s/it]


Validation Loss :    1.3933345420019967
Epoch :  13  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017605960695073009 1.3862943649291992 1.3880549669265747


100%|██████████| 7/7 [00:12<00:00,  1.75s/it]


Validation Loss :    1.3914442913872855
Epoch :  14  /  50


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Train Loss :  0.0017764647072181106 1.3862943696975707 1.3880708312988281


100%|██████████| 7/7 [00:12<00:00,  1.72s/it]


Validation Loss :    1.3916833571025304
Epoch :  15  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017250369815155864 1.3862943840026856 1.3880194187164308


100%|██████████| 7/7 [00:11<00:00,  1.70s/it]


Validation Loss :    1.391499706677028
Epoch :  16  /  50


100%|██████████| 25/25 [00:42<00:00,  1.69s/it]


Train Loss :  0.0017814680142328144 1.3862943792343139 1.3880758571624756


100%|██████████| 7/7 [00:11<00:00,  1.69s/it]


Validation Loss :    1.3903991154261999
Epoch :  17  /  50


100%|██████████| 25/25 [00:43<00:00,  1.72s/it]


Train Loss :  0.0017556175868958235 1.3862943792343139 1.3880500030517577


100%|██████████| 7/7 [00:12<00:00,  1.72s/it]


Validation Loss :    1.3899864639554704
Epoch :  18  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017646737722679972 1.3862943696975707 1.3880590534210204


100%|██████████| 7/7 [00:11<00:00,  1.69s/it]


Validation Loss :    1.3901661123548235
Epoch :  19  /  50


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Train Loss :  0.001768951271660626 1.3862943696975707 1.3880633115768433


100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Validation Loss :    1.3898605108261108
Epoch :  20  /  50


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Train Loss :  0.001730298036709428 1.3862943649291992 1.3880246591567993


100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Validation Loss :    1.39029826436724
Saving Model
Epoch :  21  /  50


100%|██████████| 25/25 [00:43<00:00,  1.72s/it]


Train Loss :  0.0017410535411909222 1.386294388771057 1.3880354499816894


100%|██████████| 7/7 [00:11<00:00,  1.64s/it]


Validation Loss :    1.3897355794906616
Epoch :  22  /  50


100%|██████████| 25/25 [00:43<00:00,  1.72s/it]


Train Loss :  0.0017536223633214832 1.3862943744659424 1.388047981262207


100%|██████████| 7/7 [00:11<00:00,  1.63s/it]


Validation Loss :    1.3892358711787633
Epoch :  23  /  50


100%|██████████| 25/25 [00:42<00:00,  1.72s/it]


Train Loss :  0.0017815219983458518 1.3862943696975707 1.3880758857727051


100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Validation Loss :    1.3891009092330933
Epoch :  24  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017510797688737511 1.3862943649291992 1.3880454397201538


100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Validation Loss :    1.3888611793518066
Epoch :  25  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017757545877248049 1.3862943649291992 1.3880701303482055


100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Validation Loss :    1.3887216023036413
Epoch :  26  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017027935246005654 1.3862943649291992 1.3879971647262572


100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Validation Loss :    1.388879554612296
Epoch :  27  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017334413714706898 1.3862943649291992 1.3880277919769286


100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Validation Loss :    1.3888496501105172
Epoch :  28  /  50


100%|██████████| 25/25 [00:43<00:00,  1.72s/it]


Train Loss :  0.001746287657879293 1.3862943649291992 1.3880406618118286


100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Validation Loss :    1.3887204613004411
Epoch :  29  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017525723949074745 1.3862943649291992 1.3880469417572021


100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Validation Loss :    1.3885644674301147
Epoch :  30  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0018000595271587371 1.3862943649291992 1.388094425201416


100%|██████████| 7/7 [00:11<00:00,  1.65s/it]


Validation Loss :    1.3888846465519495
Epoch :  31  /  50


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Train Loss :  0.0017783115198835731 1.3862943696975707 1.388072681427002


100%|██████████| 7/7 [00:11<00:00,  1.71s/it]


Validation Loss :    1.388561589377267
Epoch :  32  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017622611485421658 1.3862943649291992 1.388056640625


100%|██████████| 7/7 [00:11<00:00,  1.70s/it]


Validation Loss :    1.3882689135415214
Epoch :  33  /  50


100%|██████████| 25/25 [00:42<00:00,  1.72s/it]


Train Loss :  0.001742770296987146 1.3862943649291992 1.3880371379852294


100%|██████████| 7/7 [00:12<00:00,  1.72s/it]


Validation Loss :    1.3883713143212455
Epoch :  34  /  50


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Train Loss :  0.0018018242437392474 1.3862943649291992 1.3880961847305298


100%|██████████| 7/7 [00:11<00:00,  1.71s/it]


Validation Loss :    1.3881145545414515
Epoch :  35  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0018033548956736922 1.3862943696975707 1.3880977249145507


100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Validation Loss :    1.3880703278950282
Epoch :  36  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  0.001772857028990984 1.3862943696975707 1.388067226409912


100%|██████████| 7/7 [00:11<00:00,  1.68s/it]


Validation Loss :    1.3883189473833357
Epoch :  37  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017844161204993725 1.3862943696975707 1.388078784942627


100%|██████████| 7/7 [00:11<00:00,  1.70s/it]


Validation Loss :    1.3882717405046736
Epoch :  38  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017435615742579103 1.3862943649291992 1.3880379295349121


100%|██████████| 7/7 [00:11<00:00,  1.71s/it]


Validation Loss :    1.3881022759846278
Epoch :  39  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0017870901431888341 1.3862943649291992 1.3880814599990845


100%|██████████| 7/7 [00:11<00:00,  1.64s/it]


Validation Loss :    1.3882065159933907
Epoch :  40  /  50


100%|██████████| 25/25 [00:43<00:00,  1.72s/it]


Train Loss :  0.001778088789433241 1.3862943649291992 1.3880724573135377


 14%|█▍        | 1/7 [00:02<00:17,  2.84s/it]

In [None]:
final_vectors = encode(model,encode_loader,device)
final_vectors = np.array(final_vectors).reshape(len(data.patients), -1)

In [None]:
final_df = pd.DataFrame(final_vectors, index=data.patients, columns = data.node_order)
final_df.to_csv(fencsave)

In [None]:
!nvidia-smi

In [None]:
wandb.finish()