In [1]:
import torch
import os
from torch_geometric.data import Data
import pandas as pd
import numpy as np
from torchvision.transforms import ToTensor
from torch_geometric.loader import DataLoader
from tqdm.auto import trange, tqdm
from torch_geometric import seed_everything
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# import wandb
# wandb.login()

In [3]:
from encoder_models import *
from node_decoder_models import *
from test_train_validate import *
from data_processing.process_data import *
from utils import *
from wrappers import *

In [4]:
# Select the compute device.
# The original setup targets the second GPU (index 1). `torch.cuda.is_available()`
# is also True on single-GPU machines, where "cuda:1" is an invalid device
# ordinal and only fails later, at `model.to(device)`. So use GPU 1 only when it
# exists, otherwise fall back to GPU 0, and to the CPU when there is no GPU.
if torch.cuda.is_available():
    preferred_gpu = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{preferred_gpu}")
else:
    device = torch.device("cpu")
print(device)

cuda:1


In [5]:
# Call the project helper `clear_cache` before any model is built.
# NOTE(review): `clear_cache` arrives through one of the star imports above;
# presumably it frees cached (GPU) memory — confirm against its definition.
clear_cache()

In [6]:
# Fix the global random seed (torch_geometric.seed_everything) so that runs
# are repeatable; 36912 is an arbitrary fixed seed value.
seed_everything(36912)

In [7]:
# ---------------------------------------------------------------------------
# Run configuration: input files, hyperparameters and output locations.
# ---------------------------------------------------------------------------

# Input data (paths are relative to the notebook's working directory).
omics_file = "../data/raw/sm/kirp_sm251_csn.csv"
network_file = "../data/network/CancerSubnetwork.txt"

# Hyperparameters.
learning_rate = 3e-4
num_features = 1
out_channels = 1
num_epochs = 1
batch_size = 2
# NOTE(review): alpha / beta are not referenced by any other cell visible in
# this notebook — confirm whether they are still needed.
alpha = 1
beta = 0.5

# Labels identifying this run; they are only used to build output paths.
cancer = "KIRP"
omic = "SM"
gmodel = "L3EGConv"
optim = "ADAM"

# Output directories: <root>/<cancer>/<omic>/<model>/<optimizer>_<lr>/
savefolder = f"./new_res/{cancer}/{omic}/{gmodel}/{optim}_{learning_rate}/"
savemodels = f"./new_mod/{cancer}/{omic}/{gmodel}/{optim}_{learning_rate}/"

# Common file-name prefix (including the model directory) for all artefacts
# written next to the model weights.
savename = f"{savemodels}{cancer}_{omic}_{gmodel}_{optim}_{learning_rate}_"

summaryin = savemodels + "runs"
bestmodel = savename + "bestmodel.pt"
finalmodel = savename + "model.pt"
configf = savename + "config.yml"
fencsave = savefolder + "final.csv"
bencsave = savefolder + "best.csv"

# Create both output directories. `exist_ok=True` makes the cell safe to
# re-run and avoids the check-then-create race of `os.path.exists` + `makedirs`.
os.makedirs(savefolder, exist_ok=True)
os.makedirs(savemodels, exist_ok=True)

In [8]:
# wandb.init(project="TEST1_KIRP_GCNM2")

# cfg = wandb.config
# cfg.update({"epochs" : num_epochs, "batch_size": batch_size, "lr" : learning_rate,"optim" : optim,"data_type" : omic , "cancer" : cancer,"save":savefolder,"model_type":gmodel})

In [13]:
# Load the dataset from the network and omics files.
# These two lines were commented out, yet `data` and `num_nodes` are used by
# the cells below (train/validation split, data loaders, node decoder size,
# and the final encoding table). With them disabled the notebook only ran on
# a kernel that still held both names and failed with NameError after
# "Restart Kernel -> Run All".
# NOTE(review): the meaning of the third positional argument (1) is not
# visible in this notebook — confirm against `SingleOmicData`.
data = SingleOmicData(network_file, omics_file, 1)
num_nodes = len(data.node_order)

In [14]:
# Random 80/20 train/validation split of the dataset.
# The split draws from a freshly constructed `torch.Generator()`, i.e. a
# separate generator object rather than the global RNG.
n_samples = len(data)
train_size = int(0.8 * n_samples)
split_sizes = [train_size, n_samples - train_size]
x, y = torch.utils.data.random_split(
    data,
    split_sizes,
    generator=torch.Generator(),
)

In [15]:
# Mini-batch loaders for the training (`x`) and validation (`y`) subsets;
# both use the same batch size, shuffling and worker count.
loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=8)
train_loader = DataLoader(x, **loader_kwargs)
val_loader = DataLoader(y, **loader_kwargs)

# Loader over the full dataset, one sample at a time and in dataset order,
# used later to encode every sample.
encode_loader = DataLoader(data, batch_size=1, shuffle=False)

In [16]:
# Assemble the model from its two configurable parts:
#   - the graph encoder selected by name via `gmodel`,
#   - the node decoder, sized by the number of graph nodes.
gae_encoder = get_encoder(
    gmodel,
    in_channels=num_features,
    out_channels=out_channels,
)
gae_node_decoder = L2Linear(
    out_channels=out_channels,
    num_nodes=num_nodes,
    batch_size=batch_size,
)
model = GAEM(encoder=gae_encoder, node_decoder=gae_node_decoder)

In [17]:
# Move the model to the selected device; batches must end up on the same
# device before they are fed to it.
model = model.to(device)

In [18]:
!nvidia-smi

Sun Nov 13 00:55:37 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 36%   31C    P8    22W / 250W |      3MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:82:00.0 Off |                  N/A |
| 36%   32C    P2    47W / 250W |   1364MiB / 11264MiB |      4%      Default |
|       

In [19]:
# Report parameter counts for the whole model and for each sub-module.
param_report = [
    ("total parameters in the model : ", model),
    ("total parameters in the encoder : ", model.encoder),
    ("total parameters in the node_decoder : ", model.node_decoder),
    ("total parameters in the decoder : ", model.decoder),
]
for report_label, report_module in param_report:
    print(report_label, calculate_num_params(report_module))

total parameters in the model :  10527893
total parameters in the encoder :  21367
total parameters in the node_decoder :  10506526
total parameters in the decoder :  0


In [20]:
# Display the module tree (the cell's last expression is rendered as output).
model

GAEM(
  (encoder): L3EGConv(
    (conv1): EGConv(1, 100, aggregators=['symnorm'])
    (conv2): EGConv(100, 50, aggregators=['symnorm'])
    (conv3): EGConv(50, 1, aggregators=['symnorm'])
    (drop): Dropout(p=0.1, inplace=False)
  )
  (node_decoder): L2Linear(
    (feat_lin): Sequential(
      (0): Linear(in_features=2291, out_features=2291, bias=True)
      (1): BatchNorm1d(2291, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=2291, out_features=2291, bias=True)
    )
  )
  (decoder): InnerProductDecoder()
)

In [21]:
# Adam over all model parameters with the configured learning rate.
# Note: the `optim` config string is only used to build output paths; it does
# not select the optimizer class used here.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [22]:
# Loss function handed to `train` / `validate`: torchvision's sigmoid focal loss.
# NOTE(review): how the loss is called (targets, alpha/gamma, reduction) is
# decided inside `train` / `validate`, which are not visible here — confirm.
lossfn = torchvision.ops.focal_loss.sigmoid_focal_loss

In [23]:
# Training loop.
#
# Per epoch: one pass of `train`, one pass of `validate`, then checkpointing:
#   * best model  -> `bestmodel`, when the validation loss beats every earlier
#                    epoch (only after the first 11 warm-up epochs, e > 10);
#   * final model -> `finalmodel`, every 20 epochs and after the last epoch.
#
# Two fixes compared with the previous version:
#   1. The improvement check now compares against the best loss of the
#      PREVIOUS epochs. Before, `val_loss` was appended to `all_val_loss`
#      first, so `val_loss < min(all_val_loss)` could never be true and the
#      best checkpoint was never written.
#   2. The final weights are also saved after the last epoch, so a run whose
#      epoch count is not a multiple of 20 (e.g. num_epochs = 1) still
#      produces `finalmodel`.
all_val_loss = []
best_val_loss = float("inf")  # best validation loss seen so far
for e in range(num_epochs):

    print("Epoch : ", e + 1,  " / " , num_epochs)

    train_loss = train(model, train_loader, optimizer, device, lossfn)
    val_loss = validate(model, val_loader, device, lossfn)

    # Decide on improvement BEFORE recording the current loss.
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
    all_val_loss.append(val_loss)
    # wandb.log({"validation loss" : val_loss,"train loss" : train_loss})

    if e > 10 and improved:
        torch.save(model.state_dict(), bestmodel)
        print("Saved best model weights")

    is_last_epoch = (e + 1) == num_epochs
    if (e + 1) % 20 == 0 or is_last_epoch:
        print("Saving Model")
        torch.save(model.state_dict(), finalmodel)

# wandb.finish()

Epoch :  1  /  1


100%|██████████| 100/100 [00:43<00:00,  2.32it/s]


Train Loss :  0.019447598982369526 1.3862909173965454 1.4057385087013246


100%|██████████| 26/26 [00:11<00:00,  2.21it/s]

Validation Loss :    1.4498558044433594





In [24]:
# Encode every sample with the model, then arrange the result as a 2-D array
# with one row per patient (the column count is inferred by `reshape`).
final_vectors = np.array(
    encode(model, encode_loader, device)
).reshape(len(data.patients), -1)

100%|██████████| 251/251 [00:02<00:00, 95.87it/s] 


In [25]:
# Label the encoded matrix: rows are patients, columns follow the dataset's
# node order. Writing the table to `fencsave` is currently disabled.
final_df = pd.DataFrame(
    data=final_vectors,
    index=data.patients,
    columns=data.node_order,
)
# final_df.to_csv(fencsave)

In [26]:
# Show the encoded table (pandas truncates the rendering of large frames).
final_df

Unnamed: 0,HSPA2,RPN1,GK2,HSPA6,PPP3R1,DLG1,YWHAH,HIST1H4I,HSPA8,PCSK6,...,JMJD7-PLA2G4B,MALAT1,REG3G,NUTM2A,TRB,IGL,HES3,CCL15,CCL4L2,SCUBE1
TCGA-B9-A8YI,0.01183,0.012383,0.012383,0.013066,0.012407,0.012407,0.013058,0.012407,0.013788,0.012383,...,0.010676,0.010430,0.010462,0.010775,0.010567,0.011374,0.010571,0.010365,0.010454,0.010388
TCGA-UZ-A9PK,0.01183,0.012383,0.012383,0.013066,0.012407,0.012407,0.013058,0.012407,0.013788,0.012383,...,0.010735,0.010464,0.010471,0.010823,0.010563,0.011373,0.010591,0.010367,0.010478,0.010424
TCGA-Y8-A897,0.01183,0.012383,0.012383,0.013066,0.012407,0.012407,0.013058,0.012407,0.013788,0.012383,...,0.010654,0.010418,0.010449,0.010775,0.010555,0.011378,0.010557,0.010351,0.010437,0.010378
TCGA-2Z-A9J3,0.01183,0.012383,0.012383,0.013066,0.012407,0.012407,0.013058,0.012407,0.013788,0.012383,...,0.010684,0.010432,0.010461,0.010803,0.010564,0.011374,0.010593,0.010359,0.010452,0.010400
TCGA-UZ-A9PN,0.01183,0.012383,0.012383,0.013066,0.012407,0.012407,0.013058,0.012407,0.013788,0.012383,...,0.010667,0.010423,0.010454,0.010785,0.010562,0.011373,0.010568,0.010351,0.010446,0.010387
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TCGA-BQ-7053,0.01183,0.012383,0.012383,0.013066,0.012407,0.012407,0.013058,0.012407,0.013788,0.012383,...,0.010662,0.010423,0.010450,0.010786,0.010556,0.011376,0.010561,0.010348,0.010438,0.010384
TCGA-UZ-A9PL,0.01183,0.012383,0.012383,0.013066,0.012407,0.012407,0.013058,0.012407,0.013788,0.012383,...,0.010722,0.010462,0.010461,0.010827,0.010565,0.011373,0.010574,0.010360,0.010462,0.010410
TCGA-MH-A854,0.01183,0.012383,0.012383,0.013066,0.012407,0.012407,0.013058,0.012407,0.013788,0.012383,...,0.010714,0.010472,0.010472,0.010821,0.010572,0.011906,0.010585,0.010373,0.010469,0.010404
TCGA-HE-7129,0.01183,0.012383,0.012383,0.013066,0.012407,0.012407,0.013058,0.012407,0.013788,0.012383,...,0.010646,0.010415,0.010444,0.010775,0.010552,0.011373,0.010557,0.010344,0.010433,0.010373


In [27]:
!nvidia-smi

Sun Nov 13 00:56:35 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 36%   31C    P8    21W / 250W |      3MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:82:00.0 Off |                  N/A |
| 40%   40C    P2    75W / 250W |   3322MiB / 11264MiB |     55%      Default |
|       

In [28]:
# wandb.finish()