In [1]:
import torch
import os
from torch_geometric.data import Data
import pandas as pd
import numpy as np
from torchvision.transforms import ToTensor
from torch_geometric.loader import DataLoader
from tqdm.auto import trange, tqdm
from torch_geometric import seed_everything
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmeagoodboy[0m ([33mitsagoodteam[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
from encoder_models import *
from node_decoder_models import *
from test_train_validate import *
from data_processing.process_data import *
from utils import *
from wrappers import *

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [5]:
clear_cache()

In [6]:
seed_everything(36912)

In [7]:
omics_file = "../data/raw/sm/kirp_sm251_csn.csv"
network_file = "../data/network/CancerSubnetwork.txt"

learning_rate = 3e-4
num_features = 1
out_channels = 1
num_epochs = 50
batch_size = 8
alpha = 1
beta = 0.5

cancer = "KIRP"
omic = "SM"
gmodel = "L2GravNetConv"
optim = "ADAM"

savefolder = "./new_res/" + cancer + "/" + omic + "/" + gmodel + "/" +  optim + "_" + str(learning_rate) +"/"
savemodels = "./new_mod/" + cancer + "/" + omic + "/" + gmodel + "/" + optim + "_" +  str(learning_rate) +"/"
savename = cancer + "_" + omic + "_" + gmodel + "_" + optim + "_" + str(learning_rate) +"_"

savename = savemodels + savename 
summaryin = savemodels + "runs"
bestmodel = savename + "bestmodel.pt"
finalmodel = savename + "model.pt" 
configf = savename + "config.yml" 
fencsave = savefolder + "final.csv"
bencsave = savefolder + "best.csv"

if not os.path.exists(savefolder):
    os.makedirs(savefolder)
if not os.path.exists(savemodels):
    os.makedirs(savemodels)

In [8]:
wandb.init(project="TEST1_KIRP_GCNM2")

cfg = wandb.config
cfg.update({"epochs" : num_epochs, "batch_size": batch_size, "lr" : learning_rate,"optim" : optim,"data_type" : omic , "cancer" : cancer,"save":savefolder,"model_type":gmodel})

In [9]:
data = SingleOmicData(network_file, omics_file, 1)
num_nodes = len(data.node_order)

Loading the Network :  ../data/network/CancerSubnetwork.txt


100%|██████████| 2291/2291 [00:00<00:00, 2300.66it/s]


Loaded the Network 


In [10]:
train_size = int(0.8 * len(data))
x,y = torch.utils.data.random_split(data, lengths=[train_size, len(data) - train_size], generator=torch.Generator())

In [11]:
train_loader = DataLoader(x, shuffle=True, batch_size=batch_size,num_workers=8)
val_loader = DataLoader(y, shuffle=True, batch_size=batch_size, num_workers=8)
encode_loader = DataLoader(data, shuffle=False, batch_size=1)

In [12]:
model = GAEM( encoder = L2GravNetConv(in_channels = num_features, out_channels = out_channels),
             node_decoder = L2Linear(out_channels = out_channels, num_nodes = num_nodes, batch_size=batch_size) )

In [13]:
model = model.to(device)

In [14]:
!nvidia-smi

Tue Nov  8 11:20:39 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 32%   35C    P2    55W / 250W |   1364MiB / 11264MiB |      4%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:03:00.0 Off |                  N/A |
| 31%   36C    P2    57W / 250W |   3986MiB / 11264MiB |      0%      Default |
|       

In [15]:
print("total parameters in the model : ", calculate_num_params(model))
print("total parameters in the encoder : ", calculate_num_params(model.encoder))
print("total parameters in the node_decoder : ", calculate_num_params(model.node_decoder))
print("total parameters in the decoder : ", calculate_num_params(model.decoder))

total parameters in the model :  10506552
total parameters in the encoder :  26
total parameters in the node_decoder :  10506526
total parameters in the decoder :  0


In [16]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [17]:
lossfn = torchvision.ops.focal_loss.sigmoid_focal_loss

In [18]:
all_val_loss = []
for e in range(num_epochs):
    
    print("Epoch : ", e + 1,  " / " , num_epochs)
    
    train_loss = train(model,train_loader,optimizer, device,lossfn )
    val_loss = validate(model,val_loader,device, lossfn)
    all_val_loss.append(val_loss)
    wandb.log({"validation loss" : val_loss,"train loss" : train_loss})
    
    if e > 10:
        if val_loss < min(all_val_loss):
            torch.save(model.state_dict(), bestmodel)
            print("Saved best model weights")
    
    if (e+1) % 20 == 0:
        print("Saving Model")
        torch.save(model.state_dict(), finalmodel)

wandb.finish()

Epoch :  1  /  50


100%|██████████| 25/25 [00:43<00:00,  1.72s/it]


Train Loss :  0.0012976529635488987 1.3862943124771119 1.387591962814331


100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Validation Loss :    1.4076206173215593
Epoch :  4  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0008526658965274692 1.3862942934036255 1.3871469545364379


100%|██████████| 7/7 [00:11<00:00,  1.63s/it]


Validation Loss :    1.390024985585894
Epoch :  5  /  50


100%|██████████| 25/25 [00:42<00:00,  1.72s/it]


Train Loss :  0.0006455068720970302 1.3862942886352538 1.3869397974014281


100%|██████████| 7/7 [00:11<00:00,  1.69s/it]


Validation Loss :    1.3883266959871565
Epoch :  6  /  50


100%|██████████| 25/25 [00:42<00:00,  1.72s/it]


Train Loss :  0.0005019978852942586 1.386294298171997 1.386796293258667


100%|██████████| 7/7 [00:11<00:00,  1.65s/it]


Validation Loss :    1.3882971491132463
Epoch :  7  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  0.0003569502092432231 1.3862942743301392 1.3866512203216552


100%|██████████| 7/7 [00:11<00:00,  1.65s/it]


Validation Loss :    1.3883284160069056
Epoch :  8  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  0.00029424693842884154 1.3862942504882811 1.3865884971618652


100%|██████████| 7/7 [00:11<00:00,  1.64s/it]


Validation Loss :    1.3882932833262853
Epoch :  9  /  50


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Train Loss :  0.0002450203022453934 1.3862942600250243 1.386539297103882


100%|██████████| 7/7 [00:11<00:00,  1.65s/it]


Validation Loss :    1.3883098704474313
Epoch :  10  /  50


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Train Loss :  0.00020579013304086402 1.3862942457199097 1.3865000343322753


100%|██████████| 7/7 [00:11<00:00,  1.69s/it]


Validation Loss :    1.3883302552359444
Epoch :  11  /  50


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Train Loss :  0.00017579678009497002 1.3862942457199097 1.3864700412750244


100%|██████████| 7/7 [00:11<00:00,  1.69s/it]


Validation Loss :    1.3882764066968645
Epoch :  12  /  50


100%|██████████| 25/25 [00:42<00:00,  1.69s/it]


Train Loss :  0.00014328168268548324 1.3862942218780518 1.3864374923706055


100%|██████████| 7/7 [00:12<00:00,  1.76s/it]


Validation Loss :    1.3883192368916102
Epoch :  13  /  50


100%|██████████| 25/25 [00:42<00:00,  1.69s/it]


Train Loss :  0.0001447309546347242 1.3862941980361938 1.3864389276504516


100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Validation Loss :    1.3884512526648385
Epoch :  14  /  50


100%|██████████| 25/25 [00:42<00:00,  1.70s/it]


Train Loss :  0.00011760332548874431 1.3862942028045655 1.386411805152893


100%|██████████| 7/7 [00:12<00:00,  1.72s/it]


Validation Loss :    1.388496654374259
Epoch :  15  /  50


100%|██████████| 25/25 [00:43<00:00,  1.73s/it]


Train Loss :  0.00010831388994120061 1.3862941837310792 1.3864025115966796


100%|██████████| 7/7 [00:12<00:00,  1.73s/it]


Validation Loss :    1.388575860432216
Epoch :  16  /  50


100%|██████████| 25/25 [00:43<00:00,  1.72s/it]


Train Loss :  8.396899051149376e-05 1.3862941884994506 1.3863781547546388


100%|██████████| 7/7 [00:11<00:00,  1.71s/it]


Validation Loss :    1.388577835900443
Epoch :  17  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  7.66692132310709e-05 1.3862941646575928 1.3863708305358886


100%|██████████| 7/7 [00:11<00:00,  1.69s/it]


Validation Loss :    1.3885797943387712
Epoch :  18  /  50


100%|██████████| 25/25 [00:43<00:00,  1.73s/it]


Train Loss :  7.789654788211919e-05 1.3862941598892211 1.3863720512390136


100%|██████████| 7/7 [00:11<00:00,  1.69s/it]


Validation Loss :    1.3885730334690638
Epoch :  19  /  50


100%|██████████| 25/25 [00:43<00:00,  1.73s/it]


Train Loss :  6.77653153979918e-05 1.3862941694259643 1.3863619327545167


100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Validation Loss :    1.388521943773542
Epoch :  20  /  50


100%|██████████| 25/25 [00:44<00:00,  1.79s/it]


Train Loss :  6.663183339696843e-05 1.3862941694259643 1.38636079788208


100%|██████████| 7/7 [00:11<00:00,  1.68s/it]


Validation Loss :    1.3885357209614344
Saving Model
Epoch :  21  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  5.5596994643565266e-05 1.3862941789627075 1.3863497829437257


100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Validation Loss :    1.3884973866598946
Epoch :  22  /  50


100%|██████████| 25/25 [00:43<00:00,  1.75s/it]


Train Loss :  5.847467626153957e-05 1.3862941694259643 1.3863526487350464


100%|██████████| 7/7 [00:11<00:00,  1.68s/it]


Validation Loss :    1.3885841369628906
Epoch :  23  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  4.552404217974981e-05 1.3862941646575928 1.3863396883010863


100%|██████████| 7/7 [00:11<00:00,  1.68s/it]


Validation Loss :    1.3887273924691337
Epoch :  24  /  50


100%|██████████| 25/25 [00:43<00:00,  1.73s/it]


Train Loss :  5.9698864897654855e-05 1.3862941408157348 1.3863538360595704


100%|██████████| 7/7 [00:11<00:00,  1.68s/it]


Validation Loss :    1.388588581766401
Epoch :  25  /  50


100%|██████████| 25/25 [00:43<00:00,  1.75s/it]


Train Loss :  4.600098960509058e-05 1.386294150352478 1.3863401460647582


100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Validation Loss :    1.3887748207364763
Epoch :  26  /  50


100%|██████████| 25/25 [00:43<00:00,  1.75s/it]


Train Loss :  4.777437592565548e-05 1.3862941408157348 1.3863419055938722


100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Validation Loss :    1.388684936932155
Epoch :  27  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  4.297653333196649e-05 1.3862941455841065 1.3863371181488038


100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Validation Loss :    1.38857901096344
Epoch :  28  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  4.340029568993486e-05 1.3862941455841065 1.3863375520706176


100%|██████████| 7/7 [00:11<00:00,  1.68s/it]


Validation Loss :    1.3889392273766654
Epoch :  29  /  50


100%|██████████| 25/25 [00:43<00:00,  1.75s/it]


Train Loss :  3.591155433241511e-05 1.3862941312789916 1.386330051422119


100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Validation Loss :    1.3887360436575753
Epoch :  30  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  3.701485842611874e-05 1.3862941265106201 1.38633113861084


100%|██████████| 7/7 [00:11<00:00,  1.69s/it]


Validation Loss :    1.3887004000799996
Epoch :  31  /  50


100%|██████████| 25/25 [00:43<00:00,  1.76s/it]


Train Loss :  3.564077149349032e-05 1.3862941455841065 1.3863297843933104


100%|██████████| 7/7 [00:11<00:00,  1.68s/it]


Validation Loss :    1.3888708863939558
Epoch :  32  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  3.4369261866231684e-05 1.3862941360473633 1.386328511238098


100%|██████████| 7/7 [00:11<00:00,  1.70s/it]


Validation Loss :    1.388830031667437
Epoch :  33  /  50


100%|██████████| 25/25 [00:43<00:00,  1.75s/it]


Train Loss :  3.43863036687253e-05 1.3862941360473633 1.386328525543213


100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Validation Loss :    1.3889728103365218
Epoch :  34  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  2.7557297034945804e-05 1.3862941360473633 1.3863216876983642


100%|██████████| 7/7 [00:11<00:00,  1.66s/it]


Validation Loss :    1.3888721806662423
Epoch :  35  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  2.711795976210851e-05 1.386294116973877 1.3863212394714355


100%|██████████| 7/7 [00:11<00:00,  1.68s/it]


Validation Loss :    1.388831513268607
Epoch :  36  /  50


100%|██████████| 25/25 [00:43<00:00,  1.73s/it]


Train Loss :  3.071025512326742e-05 1.3862941551208496 1.3863248634338379


100%|██████████| 7/7 [00:12<00:00,  1.72s/it]


Validation Loss :    1.3887499400547572
Epoch :  37  /  50


100%|██████████| 25/25 [00:42<00:00,  1.72s/it]


Train Loss :  2.991224366269307e-05 1.3862941408157348 1.3863240575790405


100%|██████████| 7/7 [00:12<00:00,  1.73s/it]


Validation Loss :    1.3888063430786133
Epoch :  38  /  50


100%|██████████| 25/25 [00:42<00:00,  1.72s/it]


Train Loss :  3.0074437272560317e-05 1.3862941265106201 1.386324200630188


100%|██████████| 7/7 [00:12<00:00,  1.75s/it]


Validation Loss :    1.388884459223066
Epoch :  39  /  50


100%|██████████| 25/25 [00:43<00:00,  1.73s/it]


Train Loss :  2.275428651046241e-05 1.3862941074371338 1.3863168573379516


100%|██████████| 7/7 [00:12<00:00,  1.72s/it]


Validation Loss :    1.388811469078064
Epoch :  40  /  50


100%|██████████| 25/25 [00:43<00:00,  1.73s/it]


Train Loss :  2.9336733978198025e-05 1.3862941312789916 1.3863234615325928


100%|██████████| 7/7 [00:12<00:00,  1.73s/it]


Validation Loss :    1.388912217957633
Saving Model
Epoch :  41  /  50


100%|██████████| 25/25 [00:43<00:00,  1.72s/it]


Train Loss :  2.7013230301236036e-05 1.3862941455841065 1.3863211584091186


100%|██████████| 7/7 [00:12<00:00,  1.72s/it]


Validation Loss :    1.388870699065072
Epoch :  42  /  50


100%|██████████| 25/25 [00:42<00:00,  1.71s/it]


Train Loss :  3.1164370666374454e-05 1.3862941265106201 1.386325283050537


100%|██████████| 7/7 [00:11<00:00,  1.69s/it]


Validation Loss :    1.3889001607894897
Epoch :  43  /  50


100%|██████████| 25/25 [00:43<00:00,  1.73s/it]


Train Loss :  2.513133893444319e-05 1.3862941360473633 1.3863192749023439


100%|██████████| 7/7 [00:11<00:00,  1.71s/it]


Validation Loss :    1.3888198818479265
Epoch :  44  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  2.3347648111666784e-05 1.3862941360473633 1.3863174867630006


100%|██████████| 7/7 [00:12<00:00,  1.72s/it]


Validation Loss :    1.3888253654752458
Epoch :  45  /  50


100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


Train Loss :  2.358261534027406e-05 1.3862941408157348 1.3863177251815797


100%|██████████| 7/7 [00:12<00:00,  1.73s/it]


Validation Loss :    1.3888649259294783
Epoch :  46  /  50


100%|██████████| 25/25 [00:47<00:00,  1.90s/it]


Train Loss :  2.5049002142623068e-05 1.3862941265106201 1.3863191747665404


100%|██████████| 7/7 [00:14<00:00,  2.12s/it]


Validation Loss :    1.389125841004508
Epoch :  47  /  50


100%|██████████| 25/25 [00:47<00:00,  1.90s/it]


Train Loss :  2.3876469740571337e-05 1.3862941217422486 1.3863180112838744


100%|██████████| 7/7 [00:11<00:00,  1.67s/it]


Validation Loss :    1.3888424464634486
Epoch :  48  /  50


100%|██████████| 25/25 [00:46<00:00,  1.84s/it]


Train Loss :  2.3164708036347292e-05 1.3862941265106201 1.3863172960281371


100%|██████████| 7/7 [00:12<00:00,  1.74s/it]


Validation Loss :    1.3890127284186227
Epoch :  49  /  50


100%|██████████| 25/25 [00:49<00:00,  1.98s/it]


Train Loss :  2.2222020616027293e-05 1.3862941217422486 1.3863163375854493


100%|██████████| 7/7 [00:13<00:00,  1.91s/it]


Validation Loss :    1.3890413727079118
Epoch :  50  /  50


100%|██████████| 25/25 [00:49<00:00,  1.98s/it]


Train Loss :  2.4545479300286387e-05 1.3862941408157348 1.3863186836242676


100%|██████████| 7/7 [00:13<00:00,  1.87s/it]


Validation Loss :    1.3888755185263497


0,1
train loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation loss,█▆▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train loss,1.38632
validation loss,1.38888


In [19]:
final_vectors = encode(model,encode_loader,device)
final_vectors = np.array(final_vectors).reshape(len(data.patients), -1)

100%|██████████| 251/251 [00:01<00:00, 194.10it/s]


In [20]:
final_df = pd.DataFrame(final_vectors, index=data.patients, columns = data.node_order)
final_df.to_csv(fencsave)

In [21]:
!nvidia-smi

Tue Nov  8 12:07:20 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 32%   44C    P2    56W / 250W |   1756MiB / 11264MiB |      8%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:03:00.0 Off |                  N/A |
| 32%   43C    P2    59W / 250W |   3986MiB / 11264MiB |      0%      Default |
|       

In [22]:
wandb.finish()