In [1]:
import torch
import os
from torch_geometric.data import Data
import pandas as pd
import numpy as np
from torchvision.transforms import ToTensor
from torch_geometric.loader import DataLoader
from tqdm.auto import trange, tqdm
from torch_geometric import seed_everything
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmeagoodboy[0m ([33mitsagoodteam[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
from encoder_models import *
from node_decoder_models import *
from test_train_validate import *
from data_processing.process_data import *
from utils import *
from wrappers import *

In [4]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(device)

cuda:1


In [5]:
clear_cache()

In [6]:
seed_everything(36912)

In [7]:
omics_file = "../data/raw/sm/kirp_sm251_csn.csv"
network_file = "../data/network/CancerSubnetwork.txt"

learning_rate = 3e-5
num_features = 1
out_channels = 1
num_epochs = 50
batch_size = 16
alpha = 1
beta = 0.5

cancer = "KIRP"
omic = "SM"
gmodel = "L3GravNetConv"
optim = "ADAM"

savefolder = "./new_res/" + cancer + "/" + omic + "/" + gmodel + "/" +  optim + "_" + str(learning_rate) +"/"
savemodels = "./new_mod/" + cancer + "/" + omic + "/" + gmodel + "/" + optim + "_" +  str(learning_rate) +"/"
savename = cancer + "_" + omic + "_" + gmodel + "_" + optim + "_" + str(learning_rate) +"_"

savename = savemodels + savename 
summaryin = savemodels + "runs"
bestmodel = savename + "bestmodel.pt"
finalmodel = savename + "model.pt" 
configf = savename + "config.yml" 
fencsave = savefolder + "final.csv"
bencsave = savefolder + "best.csv"

if not os.path.exists(savefolder):
    os.makedirs(savefolder)
if not os.path.exists(savemodels):
    os.makedirs(savemodels)

In [8]:
wandb.init(project="TEST1_KIRP_GCNM2")

cfg = wandb.config
cfg.update({"epochs" : num_epochs, "batch_size": batch_size, "lr" : learning_rate,"optim" : optim,"data_type" : omic , "cancer" : cancer,"save":savefolder,"model_type":gmodel})

In [9]:
data = SingleOmicData(network_file, omics_file, 1)
num_nodes = len(data.node_order)

Loading the Network :  ../data/network/CancerSubnetwork.txt


100%|██████████| 2291/2291 [00:01<00:00, 2257.72it/s]

Loaded the Network 





In [10]:
train_size = int(0.8 * len(data))
x,y = torch.utils.data.random_split(data, lengths=[train_size, len(data) - train_size], generator=torch.Generator())

In [11]:
train_loader = DataLoader(x, shuffle=True, batch_size=batch_size,num_workers=8)
val_loader = DataLoader(y, shuffle=True, batch_size=batch_size, num_workers=8)
encode_loader = DataLoader(data, shuffle=False, batch_size=1)

In [12]:
model = GAEM( encoder = L3GravNetConv(in_channels = num_features, out_channels = out_channels),
             node_decoder = L2Linear(out_channels = out_channels, num_nodes = num_nodes, batch_size=batch_size) )

In [13]:
model = model.to(device)

In [14]:
!nvidia-smi

Tue Nov  8 11:20:04 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 31%   31C    P8    15W / 250W |   3400MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:03:00.0 Off |                  N/A |
| 32%   33C    P2    56W / 250W |   1364MiB / 11264MiB |      4%      Default |
|       

In [15]:
print("total parameters in the model : ", calculate_num_params(model))
print("total parameters in the encoder : ", calculate_num_params(model.encoder))
print("total parameters in the node_decoder : ", calculate_num_params(model.node_decoder))
print("total parameters in the decoder : ", calculate_num_params(model.decoder))

total parameters in the model :  10507027
total parameters in the encoder :  501
total parameters in the node_decoder :  10506526
total parameters in the decoder :  0


In [16]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [17]:
lossfn = torchvision.ops.focal_loss.sigmoid_focal_loss

In [18]:
all_val_loss = []
for e in range(num_epochs):
    
    print("Epoch : ", e + 1,  " / " , num_epochs)
    
    train_loss = train(model,train_loader,optimizer, device,lossfn )
    val_loss = validate(model,val_loader,device, lossfn)
    all_val_loss.append(val_loss)
    wandb.log({"validation loss" : val_loss,"train loss" : train_loss})
    
    if e > 10:
        if val_loss < min(all_val_loss):
            torch.save(model.state_dict(), bestmodel)
            print("Saved best model weights")
    
    if (e+1) % 20 == 0:
        print("Saving Model")
        torch.save(model.state_dict(), finalmodel)

wandb.finish()

Epoch :  1  /  50


100%|██████████| 13/13 [00:44<00:00,  3.45s/it]


Train Loss :  0.062035504155434094 1.3862941265106201 1.4483296229289129


100%|██████████| 4/4 [00:12<00:00,  3.09s/it]


Validation Loss :    1.450637549161911
Epoch :  2  /  50


100%|██████████| 4/4 [00:12<00:00,  3.06s/it]]


Validation Loss :    1.448566496372223
Epoch :  4  /  50


100%|██████████| 13/13 [00:44<00:00,  3.45s/it]


Train Loss :  0.013074325397610664 1.3862919899133535 1.3993663054246168


100%|██████████| 4/4 [00:12<00:00,  3.05s/it]


Validation Loss :    1.4418471455574036
Epoch :  5  /  50


100%|██████████| 13/13 [00:44<00:00,  3.43s/it]


Train Loss :  0.008821150073065208 1.3862899358455951 1.3951110748144298


100%|██████████| 4/4 [00:12<00:00,  3.06s/it]


Validation Loss :    1.4274495244026184
Epoch :  6  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.0067704858330006785 1.3862843146690955 1.3930548062691321


100%|██████████| 4/4 [00:12<00:00,  3.06s/it]


Validation Loss :    1.4086434245109558
Epoch :  7  /  50


100%|██████████| 13/13 [00:44<00:00,  3.46s/it]


Train Loss :  0.0051793476495032124 1.3862857543505156 1.3914651045432458


100%|██████████| 4/4 [00:12<00:00,  3.09s/it]


Validation Loss :    1.3980606496334076
Epoch :  8  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.004429574661816542 1.3862875699996948 1.3907171487808228


100%|██████████| 4/4 [00:12<00:00,  3.00s/it]


Validation Loss :    1.3941633701324463
Epoch :  9  /  50


100%|██████████| 13/13 [00:44<00:00,  3.41s/it]


Train Loss :  0.003664303570985794 1.386288312765268 1.3899526045872614


100%|██████████| 4/4 [00:12<00:00,  3.04s/it]


Validation Loss :    1.3916078507900238
Epoch :  10  /  50


100%|██████████| 13/13 [00:44<00:00,  3.44s/it]


Train Loss :  0.0034617943904147698 1.3862784092242901 1.3897402194830089


100%|██████████| 4/4 [00:12<00:00,  3.17s/it]


Validation Loss :    1.391873300075531
Epoch :  11  /  50


100%|██████████| 13/13 [00:45<00:00,  3.46s/it]


Train Loss :  0.0029788367772618164 1.3862881843860333 1.3892670136231642


100%|██████████| 4/4 [00:12<00:00,  3.14s/it]


Validation Loss :    1.3902656435966492
Epoch :  12  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.002840454672248318 1.3862859652592585 1.3891264200210571


100%|██████████| 4/4 [00:12<00:00,  3.15s/it]


Validation Loss :    1.3897106051445007
Epoch :  13  /  50


100%|██████████| 13/13 [00:44<00:00,  3.43s/it]


Train Loss :  0.002643699811484951 1.3862859377494225 1.3889296329938448


100%|██████████| 4/4 [00:12<00:00,  3.18s/it]


Validation Loss :    1.3895274102687836
Epoch :  14  /  50


100%|██████████| 13/13 [00:44<00:00,  3.44s/it]


Train Loss :  0.0024855299136386467 1.3862855709516084 1.3887710938086877


100%|██████████| 4/4 [00:12<00:00,  3.19s/it]


Validation Loss :    1.3891453444957733
Epoch :  15  /  50


100%|██████████| 13/13 [00:44<00:00,  3.43s/it]


Train Loss :  0.0023944253782526804 1.3862861211483295 1.3886805497683012


100%|██████████| 4/4 [00:12<00:00,  3.12s/it]


Validation Loss :    1.389522761106491
Epoch :  16  /  50


100%|██████████| 13/13 [00:44<00:00,  3.45s/it]


Train Loss :  0.002286267807134069 1.38627866598276 1.3885649259273822


100%|██████████| 4/4 [00:12<00:00,  3.13s/it]


Validation Loss :    1.388946771621704
Epoch :  17  /  50


100%|██████████| 13/13 [00:44<00:00,  3.43s/it]


Train Loss :  0.0021977090151407397 1.3862835719035222 1.3884812868558443


100%|██████████| 4/4 [00:12<00:00,  3.13s/it]


Validation Loss :    1.388727456331253
Epoch :  18  /  50


100%|██████████| 13/13 [00:44<00:00,  3.46s/it]


Train Loss :  0.0021463940009618034 1.386283305975107 1.3884297059132502


100%|██████████| 4/4 [00:12<00:00,  3.13s/it]


Validation Loss :    1.389023333787918
Epoch :  19  /  50


100%|██████████| 13/13 [00:44<00:00,  3.46s/it]


Train Loss :  0.002077981784868126 1.386287469130296 1.3883654796160185


100%|██████████| 4/4 [00:12<00:00,  3.09s/it]


Validation Loss :    1.3888517618179321
Epoch :  20  /  50


100%|██████████| 13/13 [00:44<00:00,  3.45s/it]


Train Loss :  0.0020502999854775574 1.3862800323046172 1.3883303220455463


100%|██████████| 4/4 [00:12<00:00,  3.05s/it]


Validation Loss :    1.3888018429279327
Saving Model
Epoch :  21  /  50


100%|██████████| 13/13 [00:44<00:00,  3.46s/it]


Train Loss :  0.0020239049962793407 1.386264599286593 1.3882884979248047


100%|██████████| 4/4 [00:12<00:00,  3.14s/it]


Validation Loss :    1.3887199461460114
Epoch :  22  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.001958585218884624 1.3862833701647246 1.3882419421122625


100%|██████████| 4/4 [00:12<00:00,  3.09s/it]


Validation Loss :    1.38862544298172
Epoch :  23  /  50


100%|██████████| 13/13 [00:44<00:00,  3.45s/it]


Train Loss :  0.0019485144249091928 1.3862823247909546 1.3882308373084435


100%|██████████| 4/4 [00:12<00:00,  3.03s/it]


Validation Loss :    1.3888052701950073
Epoch :  24  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.0019257587756818305 1.3862790694603553 1.3882048221734853


100%|██████████| 4/4 [00:12<00:00,  3.09s/it]


Validation Loss :    1.3884048461914062
Epoch :  25  /  50


100%|██████████| 13/13 [00:45<00:00,  3.46s/it]


Train Loss :  0.0019008025007608992 1.3862642049789429 1.38816499710083


100%|██████████| 4/4 [00:12<00:00,  3.06s/it]


Validation Loss :    1.388388752937317
Epoch :  26  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.0018714043455055128 1.386266910112821 1.388138303389916


100%|██████████| 4/4 [00:12<00:00,  3.08s/it]


Validation Loss :    1.3885233104228973
Epoch :  27  /  50


100%|██████████| 13/13 [00:44<00:00,  3.43s/it]


Train Loss :  0.0018602870673371048 1.3862743377685547 1.388134635411776


100%|██████████| 4/4 [00:12<00:00,  3.06s/it]


Validation Loss :    1.3881837129592896
Epoch :  28  /  50


100%|██████████| 13/13 [00:44<00:00,  3.46s/it]


Train Loss :  0.001837338786572218 1.3862703580122728 1.3881077032822828


100%|██████████| 4/4 [00:12<00:00,  3.02s/it]


Validation Loss :    1.3883332908153534
Epoch :  29  /  50


100%|██████████| 13/13 [00:44<00:00,  3.42s/it]


Train Loss :  0.0018057381698431878 1.386267377780034 1.3880731050784771


100%|██████████| 4/4 [00:12<00:00,  3.06s/it]


Validation Loss :    1.3884852826595306
Epoch :  30  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.0018099352824859894 1.3862608120991633 1.3880707392325768


100%|██████████| 4/4 [00:12<00:00,  3.09s/it]


Validation Loss :    1.3886152505874634
Epoch :  31  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.0017551915326084082 1.3862700829139123 1.3880252746435313


100%|██████████| 4/4 [00:12<00:00,  3.03s/it]


Validation Loss :    1.388453185558319
Epoch :  32  /  50


100%|██████████| 13/13 [00:45<00:00,  3.47s/it]


Train Loss :  0.001732136606453703 1.38626923927894 1.3880013961058397


100%|██████████| 4/4 [00:12<00:00,  3.10s/it]


Validation Loss :    1.3883866369724274
Epoch :  33  /  50


100%|██████████| 13/13 [00:45<00:00,  3.46s/it]


Train Loss :  0.001771493349224329 1.3862602343926063 1.3880317302850576


100%|██████████| 4/4 [00:12<00:00,  3.09s/it]


Validation Loss :    1.3884109854698181
Epoch :  34  /  50


100%|██████████| 13/13 [00:44<00:00,  3.45s/it]


Train Loss :  0.0017608583439141512 1.3862155217390795 1.3879763713249793


100%|██████████| 4/4 [00:12<00:00,  3.07s/it]


Validation Loss :    1.3883496522903442
Epoch :  35  /  50


100%|██████████| 13/13 [00:45<00:00,  3.48s/it]


Train Loss :  0.001780963995350668 1.3860226044288049 1.387803554534912


100%|██████████| 4/4 [00:12<00:00,  3.12s/it]


Validation Loss :    1.3900082111358643
Epoch :  36  /  50


100%|██████████| 13/13 [00:44<00:00,  3.44s/it]


Train Loss :  0.0020774943539156364 1.3725372094374437 1.3746146972362812


100%|██████████| 4/4 [00:12<00:00,  3.16s/it]


Validation Loss :    1.3789452910423279
Epoch :  37  /  50


100%|██████████| 13/13 [00:44<00:00,  3.46s/it]


Train Loss :  0.0018340955183912928 1.3651690116295447 1.367003110738901


100%|██████████| 4/4 [00:12<00:00,  3.13s/it]


Validation Loss :    1.3612324595451355
Epoch :  38  /  50


100%|██████████| 13/13 [00:44<00:00,  3.45s/it]


Train Loss :  0.001794656586403457 1.3661645100666926 1.3679591509012075


100%|██████████| 4/4 [00:12<00:00,  3.12s/it]


Validation Loss :    1.3469296097755432
Epoch :  39  /  50


100%|██████████| 13/13 [00:44<00:00,  3.41s/it]


Train Loss :  0.0017903101755879247 1.3669330431864812 1.3687233466368456


100%|██████████| 4/4 [00:12<00:00,  3.11s/it]


Validation Loss :    1.3678284585475922
Epoch :  40  /  50


100%|██████████| 13/13 [00:44<00:00,  3.41s/it]


Train Loss :  0.0016494307470006438 1.3645584399883564 1.3662078839081984


100%|██████████| 4/4 [00:12<00:00,  3.14s/it]


Validation Loss :    1.3597362339496613
Saving Model
Epoch :  41  /  50


100%|██████████| 13/13 [00:44<00:00,  3.43s/it]


Train Loss :  0.0015844614334547748 1.3667939901351929 1.3683784374823937


100%|██████████| 4/4 [00:12<00:00,  3.16s/it]


Validation Loss :    1.4470335841178894
Epoch :  42  /  50


100%|██████████| 13/13 [00:45<00:00,  3.48s/it]


Train Loss :  0.0015398590700127757 1.3652846079606276 1.3668244710335364


100%|██████████| 4/4 [00:12<00:00,  3.17s/it]


Validation Loss :    1.3574867248535156
Epoch :  43  /  50


100%|██████████| 13/13 [00:44<00:00,  3.43s/it]


Train Loss :  0.0015734969100986535 1.3656434737719023 1.3672169630344098


100%|██████████| 4/4 [00:12<00:00,  3.18s/it]


Validation Loss :    1.3436557054519653
Epoch :  44  /  50


100%|██████████| 13/13 [00:44<00:00,  3.45s/it]


Train Loss :  0.0015324711119039701 1.368386378655067 1.369918859921969


100%|██████████| 4/4 [00:12<00:00,  3.14s/it]


Validation Loss :    1.3734877705574036
Epoch :  45  /  50


100%|██████████| 13/13 [00:50<00:00,  3.90s/it]


Train Loss :  0.0015288304376344269 1.3660030915186956 1.367531941487239


100%|██████████| 4/4 [00:15<00:00,  3.88s/it]


Validation Loss :    1.3557331264019012
Epoch :  46  /  50


100%|██████████| 13/13 [00:47<00:00,  3.62s/it]


Train Loss :  0.0015175461482543212 1.3662078839081984 1.3677254181641798


100%|██████████| 4/4 [00:12<00:00,  3.10s/it]


Validation Loss :    1.3581383526325226
Epoch :  47  /  50


100%|██████████| 13/13 [00:47<00:00,  3.68s/it]


Train Loss :  0.0015265067878107612 1.3670963782530565 1.3686228807155902


100%|██████████| 4/4 [00:12<00:00,  3.18s/it]


Validation Loss :    1.3554022312164307
Epoch :  48  /  50


100%|██████████| 13/13 [00:51<00:00,  3.98s/it]


Train Loss :  0.0014586364671301383 1.3666900029549232 1.3681486386519213


100%|██████████| 4/4 [00:13<00:00,  3.48s/it]


Validation Loss :    1.363096296787262
Epoch :  49  /  50


100%|██████████| 13/13 [00:49<00:00,  3.81s/it]


Train Loss :  0.0014887878384727698 1.3676777527882502 1.3691665484355047


100%|██████████| 4/4 [00:12<00:00,  3.10s/it]


Validation Loss :    1.3572361469268799
Epoch :  50  /  50


100%|██████████| 13/13 [00:43<00:00,  3.36s/it]


Train Loss :  0.0013868975488898845 1.3653201598387499 1.3667070407133837


100%|██████████| 4/4 [00:11<00:00,  2.95s/it]


Validation Loss :    1.3580348491668701


0,1
train loss,█▆▅▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▁▁▁▁▁▁▁▁▁▁▁
validation loss,███▇▅▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▃▂▁▃█▂▁▃▂▂▂▂

0,1
train loss,1.36671
validation loss,1.35803


In [19]:
final_vectors = encode(model,encode_loader,device)
final_vectors = np.array(final_vectors).reshape(len(data.patients), -1)

100%|██████████| 251/251 [00:01<00:00, 191.61it/s]


In [20]:
final_df = pd.DataFrame(final_vectors, index=data.patients, columns = data.node_order)
final_df.to_csv(fencsave)

In [21]:
!nvidia-smi

Tue Nov  8 12:08:33 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 31%   36C    P8    15W / 250W |   1756MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:03:00.0 Off |                  N/A |
| 32%   43C    P2    59W / 250W |   3986MiB / 11264MiB |     15%      Default |
|       

In [22]:
wandb.finish()