In [1]:
import torch
import os
from torch_geometric.data import Data
import pandas as pd
import numpy as np
from torchvision.transforms import ToTensor
from torch_geometric.loader import DataLoader
from tqdm.auto import trange, tqdm
from torch_geometric import seed_everything
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import wandb

# Authenticate with Weights & Biases up front; the cell displays the value
# returned by wandb.login().
wandb_logged_in = wandb.login()
wandb_logged_in

[34m[1mwandb[0m: Currently logged in as: [33mmeagoodboy[0m ([33mitsagoodteam[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
from encoder_models import *
from node_decoder_models import *
from test_train_validate import *
from data_processing.process_data import *
from utils import *
from wrappers import *

In [3]:
# Prefer the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


In [4]:
# `clear_cache` is provided by one of the project star imports in the previous cell.
# NOTE(review): presumably frees cached GPU memory before the run — confirm against its definition.
clear_cache()

In [5]:
# Fix the global RNG seeds (torch_geometric.seed_everything) so model
# initialisation and loader shuffling are repeatable across runs.
GLOBAL_SEED = 36912
seed_everything(seed=GLOBAL_SEED)

In [6]:
# --- Data sources -----------------------------------------------------------
omics_file = "../data/raw/sm/kirp_sm251_csn.csv"
network_file = "../data/network/CancerSubnetwork.txt"

# --- Hyper-parameters -------------------------------------------------------
learning_rate = 3e-4
num_features = 1
out_channels = 1
# NOTE(review): with num_epochs = 1 the best-model checkpoint (only after
# epoch index 10) can never trigger; the saved training output shows a
# 50-epoch run, so this looks like a debugging value.
num_epochs = 1
batch_size = 8
# NOTE(review): alpha and beta are not referenced by any visible cell —
# confirm they are consumed elsewhere or remove them.
alpha = 1
beta = 0.5

# --- Run identity (every output path below is derived from these) ----------
cancer = "KIRP"
omic = "SM"
gmodel = "L2GravNetConv"
optim = "ADAM"

# One tag shared by the folder names and the file prefix, e.g. "ADAM_0.0003".
run_tag = f"{optim}_{learning_rate}"
run_dir = f"{cancer}/{omic}/{gmodel}/{run_tag}/"

savefolder = "./new_res/" + run_dir  # encoded CSV outputs
savemodels = "./new_mod/" + run_dir  # model weights and config
# Full path prefix for per-run files inside `savemodels`.
savename = savemodels + f"{cancer}_{omic}_{gmodel}_{run_tag}_"

summaryin = savemodels + "runs"
bestmodel = savename + "bestmodel.pt"
finalmodel = savename + "model.pt"
configf = savename + "config.yml"
fencsave = savefolder + "final.csv"
bencsave = savefolder + "best.csv"

# exist_ok=True avoids the check-then-create race and keeps the cell idempotent.
os.makedirs(savefolder, exist_ok=True)
os.makedirs(savemodels, exist_ok=True)

In [7]:
# wandb.init(project="TEST1_KIRP_GCNM2")

# cfg = wandb.config
# cfg.update({"epochs" : num_epochs, "batch_size": batch_size, "lr" : learning_rate,"optim" : optim,"data_type" : omic , "cancer" : cancer,"save":savefolder,"model_type":gmodel})

In [8]:
# Load the single-omics table together with the network it is mapped onto.
# NOTE(review): the meaning of the trailing positional `1` is not visible
# here — document it at the SingleOmicData definition.
data = SingleOmicData(
    network_file,
    omics_file,
    1,
)
# Node count of the network; used to size the node decoder.
num_nodes = len(data.node_order)

Loading the Network :  ../data/network/CancerSubnetwork.txt


100%|██████████| 2291/2291 [00:00<00:00, 2331.89it/s]


Loaded the Network 


In [11]:
# 80/20 train/validation split (x = training subset, y = validation subset;
# names kept for the loader cell). The generator is seeded explicitly: a bare
# torch.Generator() starts from its own fixed default seed, so the split would
# otherwise ignore the seed passed to seed_everything earlier in the notebook.
split_seed = 36912
train_size = int(0.8 * len(data))
val_size = len(data) - train_size
x, y = torch.utils.data.random_split(
    data,
    lengths=[train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

In [12]:
# Training batches are reshuffled every epoch. Validation and encoding keep a
# fixed order: shuffling validation gives nothing and consumes global RNG.
# NOTE(review): if validate() averages per-batch losses, shuffling would also
# make the reported loss depend on which samples land in the short last batch.
train_loader = DataLoader(x, shuffle=True, batch_size=batch_size, num_workers=8)
val_loader = DataLoader(y, shuffle=False, batch_size=batch_size, num_workers=8)
encode_loader = DataLoader(data, shuffle=False, batch_size=1)

In [9]:
# Assemble the graph auto-encoder: an encoder chosen by `gmodel`, plus a node
# decoder sized to the network (num_nodes) and the training batch size.
gae_encoder = get_encoder(gmodel, in_channels=num_features, out_channels=out_channels)
gae_node_decoder = L2Linear(out_channels=out_channels, num_nodes=num_nodes, batch_size=batch_size)
model = GAEM(encoder=gae_encoder, node_decoder=gae_node_decoder)

In [13]:
# Move all parameters and buffers onto the selected device.
model = model.to(device=device)

In [14]:
# Diagnostic snapshot of GPU memory/utilisation after the model was moved to the device; nothing below depends on it.
!nvidia-smi

Sat Nov 12 22:19:15 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 36%   34C    P2    54W / 250W |   1364MiB / 11264MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:82:00.0 Off |                  N/A |
| 37%   27C    P8     9W / 250W |      3MiB / 11264MiB |      0%      Default |
|       

In [15]:
# Report parameter counts for the whole model and for each of its sub-modules.
_param_report = (
    ("total parameters in the model : ", model),
    ("total parameters in the encoder : ", model.encoder),
    ("total parameters in the node_decoder : ", model.node_decoder),
    ("total parameters in the decoder : ", model.decoder),
)
for _caption, _module in _param_report:
    print(_caption, calculate_num_params(_module))

total parameters in the model :  10511827
total parameters in the encoder :  5301
total parameters in the node_decoder :  10506526
total parameters in the decoder :  0


In [16]:
# Display the module tree (encoder / node_decoder / decoder) as an architecture sanity check.
# NOTE(review): the saved output shows an L3GCNConv encoder although gmodel = "L2GravNetConv" — it is stale; re-run to refresh.
model

GAEM(
  (encoder): L3GCNConv(
    (conv1): GCNConv(1, 100)
    (conv2): GCNConv(100, 50)
    (conv3): GCNConv(50, 1)
    (drop): Dropout(p=0.1, inplace=False)
  )
  (node_decoder): L2Linear(
    (feat_lin): Sequential(
      (0): Linear(in_features=2291, out_features=2291, bias=True)
      (1): BatchNorm1d(2291, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=2291, out_features=2291, bias=True)
    )
  )
  (decoder): InnerProductDecoder()
)

In [17]:
# Build the optimizer named by the `optim` config string, so the label baked
# into every save path always matches the optimizer that actually ran.
_OPTIMIZERS = {
    "ADAM": torch.optim.Adam,
    "ADAMW": torch.optim.AdamW,
    "SGD": torch.optim.SGD,
}
if optim.upper() not in _OPTIMIZERS:
    raise ValueError(f"Unsupported optimizer {optim!r}; expected one of {sorted(_OPTIMIZERS)}")
optimizer = _OPTIMIZERS[optim.upper()](model.parameters(), lr=learning_rate)

In [18]:
# Sigmoid focal loss from torchvision, used with its library defaults
# (no alpha/gamma/reduction overrides are bound here).
lossfn = torchvision.ops.sigmoid_focal_loss

In [None]:
# Train for `num_epochs`, recording the validation loss of every epoch.
# - After a warm-up (epoch index > 10), save `bestmodel` whenever the
#   validation loss beats every earlier epoch.
# - Save `finalmodel` every 20 epochs and always after the last epoch, so a
#   run whose length is not a multiple of 20 still writes its final weights.
all_val_loss = []
for e in range(num_epochs):

    print("Epoch : ", e + 1,  " / " , num_epochs)

    train_loss = train(model,train_loader,optimizer, device,lossfn )
    val_loss = validate(model,val_loader,device, lossfn)
    # Decide "improved" BEFORE appending: once val_loss is in the list,
    # min(all_val_loss) <= val_loss and the comparison can never succeed.
    improved = bool(all_val_loss) and bool(val_loss < min(all_val_loss))
    all_val_loss.append(val_loss)
    # wandb.log({"validation loss" : val_loss,"train loss" : train_loss})

    if e > 10 and improved:
        torch.save(model.state_dict(), bestmodel)
        print("Saved best model weights")

    is_last_epoch = (e + 1) == num_epochs
    if (e + 1) % 20 == 0 or is_last_epoch:
        print("Saving Model")
        torch.save(model.state_dict(), finalmodel)

# wandb.finish()

Epoch :  1  /  50


 76%|███████▌  | 19/25 [00:31<00:09,  1.55s/it]

In [None]:
# Run the trained model over encode_loader (batch_size=1, fixed order) and
# lay the result out as one row per patient.
raw_vectors = encode(model, encode_loader, device)
n_patients = len(data.patients)
final_vectors = np.array(raw_vectors).reshape(n_patients, -1)

In [None]:
# Label the encoded matrix: rows are patients, columns follow the network node order.
final_df = pd.DataFrame(
    data=final_vectors,
    index=data.patients,
    columns=data.node_order,
)
# Writing to disk is currently disabled:
# final_df.to_csv(fencsave)

In [None]:
# Preview the encoded matrix: one row per patient (data.patients), one column per node (data.node_order).
final_df

In [None]:
# Diagnostic GPU snapshot after training/encoding, for comparison with the earlier nvidia-smi cell.
!nvidia-smi

In [None]:
# Close the Weights & Biases run if one is active (wandb.init is commented
# out above, so normally there is nothing to finish).
if wandb.run is not None:
    wandb.finish()