### Importing Libraries

In [None]:
import sys,os
import random
import numpy as np
import json
from collections import OrderedDict
from utils import *
from emetrics import *
from data import create_dataset_for_train
import torch
import torch.nn as nn
from torch_geometric.data import DataLoader
import rdkit as rd
from torch_sparse import SparseTensor,transpose
import deepchem
import tensorflow as tf
import pandas as pd
import pickle
from dnn import GNNNet,GNNNet_prod,GNNNet_prod_conc

### Loading the dataset: Davis [0] or KIBA [1]

In [None]:
# Benchmark(s) to train on: index 0 selects Davis, index 1 selects KIBA.
datasets = [('davis', 'kiba')[0]]
datasets  # echo the selection in the notebook output

### Select the ligand encoding method and contact map method for protein encoding

In [None]:
# Contact-map technique used for the protein encoding.
# Options: 0 = pconsc4, 1 = esm_cmaps, 2 = alpha_fold_cmaps, 3 = rand_cmaps
method = ('pconsc4', 'esm_cmaps', 'alpha_fold_cmaps', 'rand_cmaps')[0]
method  # echo the selection in the notebook output

In [None]:
# Ligand encoding method.
# Options: 0 = original, 1 = point_random, 2 = random_node, 3 = random_sample
method1 = ('original', 'point_random', 'random_node', 'random_sample')[0]
method1  # echo the selection in the notebook output

### Select the method to combine the encodings

In [None]:
# How the ligand and protein encodings are combined; each option maps to a
# model class from dnn: 'conc' -> GNNNet, 'prod' -> GNNNet_prod,
# 'conc+prod' -> GNNNet_prod_conc.
comb = ['conc', 'prod', 'conc+prod'][0]

if comb == 'conc':
    model = GNNNet()
elif comb == 'prod':
    model = GNNNet_prod()
elif comb == 'conc+prod':
    model = GNNNet_prod_conc()
else:
    # Fail loudly: otherwise `model` is left undefined, or, in a notebook,
    # silently keeps whatever model an earlier run of this cell created.
    raise ValueError(
        f"Unknown combination method {comb!r}; "
        "expected one of 'conc', 'prod', 'conc+prod'"
    )

### Initialising the model

In [None]:
# Move the selected GNN model onto the compute device and prepare the
# naming/paths used when saving checkpoints.

# GPU to use when CUDA is available.
cuda_name = ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3'][0]
# Fall back to the CPU when CUDA is absent. Previously the CPU fallback was
# immediately overwritten with `cuda_name`, so model.to() failed on
# CPU-only machines.
device = torch.device(cuda_name if torch.cuda.is_available() else 'cpu')
model.to(device)

# Name checkpoints after the class actually instantiated for `comb`, so the
# 'conc', 'prod' and 'conc+prod' runs do not share (and overwrite) one file name.
model_st = type(model).__name__

# Fold index (0-4) passed to create_dataset_for_train.
fold = [0, 1, 2, 3, 4][0]

# Directory for trained models.
# NOTE(review): the training loop saves checkpoints under 'models_sample/',
# not this directory -- confirm which location is intended.
models_dir = 'models'
os.makedirs(models_dir, exist_ok=True)

### Hyperparameter settings

In [None]:
# Training / evaluation settings.
TRAIN_BATCH_SIZE, TEST_BATCH_SIZE = 128, 128  # DataLoader batch sizes
LR = 0.001          # Adam learning rate
NUM_EPOCHS = 2000   # number of passes over the training set

# Mean-squared-error loss.
# NOTE(review): loss_fn is not passed to train() in the training cell -- confirm
# whether train() uses it or defines its own loss.
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

In [None]:
# Train on each selected dataset for NUM_EPOCHS epochs, saving the weights
# whenever the validation MSE reaches a new minimum.
# NOTE(review): `model` and `optimizer` are reused across iterations, so if
# `datasets` lists more than one benchmark, training continues from the
# previous dataset's weights -- confirm that is intended.
for dataset in datasets:
    train_data, valid_data = create_dataset_for_train(dataset, fold, method, method1)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                                               collate_fn=collate)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=TEST_BATCH_SIZE, shuffle=False,
                                               collate_fn=collate)

    best_mse = 1000  # sentinel; replaced by the first validation MSE below 1000
    best_test_mse = 1000  # NOTE(review): never updated or read in this cell
    best_epoch = -1
    # Checkpoint path.
    # NOTE(review): the literal '_random_node_' is appended regardless of
    # method1 -- confirm this suffix is intended.
    model_file_name = 'models_sample/model_' + method + '_' + model_st + '_' + method1 + '_' + dataset + '_random_node_' + str(fold) + '.model'
    # torch.save does not create missing directories, and the setup cell only
    # creates 'models', so ensure 'models_sample' exists before the first save.
    os.makedirs(os.path.dirname(model_file_name), exist_ok=True)
    print(model_file_name)
    mse_list1 = []  # validation MSE per epoch, in epoch order
    for epoch in range(NUM_EPOCHS):
        train(model, device, train_loader, optimizer, epoch + 1)
        print('predicting for valid data')
        G, P = predicting(model, device, valid_loader)
        val = get_mse(G, P)
        mse_list1.append(val)
        print('valid result:', val, best_mse)
        if val < best_mse:
            best_mse = val
            best_epoch = epoch + 1
            torch.save(model.state_dict(), model_file_name)
            print('rmse improved at epoch ', best_epoch, '; best_test_mse', best_mse, model_st, dataset, fold)
        else:
            print('No improvement since epoch ', best_epoch, '; best_test_mse', best_mse, model_st, dataset, fold)


In [None]:
# Save the per-epoch validation MSEs of the last trained dataset, one value
# per line. The file name is derived from the active configuration: the
# previous hard-coded 'mse_davis_pconcs4_random_node_3.txt' did not match the
# selected method/method1/fold (and misspelled pconsc4), so results from
# different runs were mislabelled and overwrote each other.
mse_log_name = f'mse_{dataset}_{method}_{method1}_{fold}.txt'
with open(mse_log_name, 'w', encoding='utf-8') as f:
    f.writelines(f"{item}\n" for item in mse_list1)

In [None]:
import matplotlib.pyplot as plt

# Plot the validation MSE curve. Derive the epoch axis from the recorded
# values instead of hard-coding 1..2000, so the plot still works (instead of
# raising a length-mismatch error) when NUM_EPOCHS is changed.
ep = list(range(1, len(mse_list1) + 1))

plt.plot(ep, mse_list1)
plt.xlabel('Epoch')
plt.ylabel('Validation MSE')