In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from model.GNN_Model import GNN_Model
import os
from os import listdir
from os.path import isfile, join
from ops.os_operation import mkdir
import shutil
import  numpy as np
from data_processing.Prepare_Input import Prepare_Input
from data_processing.Single_Dataset import Single_Dataset
from torch.utils.data import DataLoader
from data_processing.collate_fn import collate_fn_Jake
from torch.utils.data import DataLoader

In [None]:

def predict_multi_input(input_path, params, collate=None):
    """Score every ``.pdb`` structure in ``input_path`` and write the results.

    Output is written under
    ``<cwd>/Predict_Result/Multi_Target/Fold_<fold>_Result/<input folder name>/``:
    one sub-directory per structure (the structure is copied there as
    ``Input.pdb`` and handed to ``Prepare_Input``), plus ``Predict.txt``
    (tab-separated, input order) and ``Predict_sort.txt`` (same rows, highest
    score first, header on top).

    Args:
        input_path: directory holding the ``.pdb`` files to score.
        params: settings dict; ``'fold'``, ``'batch_size'`` and
            ``'num_workers'`` are read here and the whole dict is passed to
            ``init_model``.  ``fold == -1`` averages the predictions of the
            ``best_model/fold1..fold3`` checkpoints; any other value loads
            ``best_model/fold<fold>/checkpoint.pth.tar``.
        collate: optional ``collate_fn`` for the DataLoader.  Defaults to
            ``data_processing.collate_fn.collate_fn``.
    """
    if collate is None:
        # The body used to reference a bare ``collate_fn`` that this notebook
        # never imports (only ``collate_fn_Jake``), which raised NameError.
        from data_processing.collate_fn import collate_fn as collate

    fold_choice = params['fold']
    input_path = os.path.abspath(input_path)
    folder_name = os.path.basename(input_path)

    # ``mkdir`` is called once per level, as before, in case it does not
    # create intermediate directories itself.
    save_path = os.getcwd()
    for part in ("Predict_Result", "Multi_Target",
                 "Fold_" + str(fold_choice) + "_Result", folder_name):
        save_path = os.path.join(save_path, part)
        mkdir(save_path)

    # Load the requested fold, or folds 1-3 for an averaged ensemble.
    model_root = os.path.join(os.getcwd(), "best_model")
    folds = [fold_choice] if fold_choice != -1 else [1, 2, 3]
    models = []
    for fold in folds:
        checkpoint = os.path.join(model_root, "fold" + str(fold), "checkpoint.pth.tar")
        model, device = init_model(checkpoint, params)
        models.append(model)

    # Require a real ``.pdb`` suffix: a substring test also matched names such
    # as ``x.pdb.bak``, for which ``item[:-4]`` gives a wrong study name.
    pdb_files = sorted(x for x in os.listdir(input_path) if x.endswith(".pdb"))
    Study_Name = []
    Input_File_List = []
    for item in pdb_files:
        study = item[:-4]
        cur_root_path = os.path.join(save_path, study)
        mkdir(cur_root_path)
        structure_path = os.path.join(cur_root_path, "Input.pdb")
        shutil.copy(os.path.join(input_path, item), structure_path)
        Study_Name.append(study)
        Input_File_List.append(Prepare_Input(structure_path))

    dataset = Single_Dataset(Input_File_List)
    dataloader = DataLoader(dataset, params['batch_size'], shuffle=False,
                            num_workers=params['num_workers'],
                            drop_last=False, collate_fn=collate)

    per_model = [Get_Predictions(dataloader, device, m) for m in models]
    Final_Pred = per_model[0] if len(per_model) == 1 else np.mean(per_model, axis=0)

    # Write both files in Python instead of piping through the shell `sort`:
    # the old string-built command broke on paths with spaces/metacharacters,
    # required an external `sort`, and numerically sorted the header line
    # in among the scores.
    rows = list(zip(Study_Name, Final_Pred))
    ranked = sorted(rows, key=lambda row: float(row[1]), reverse=True)
    pred_path = os.path.join(save_path, 'Predict.txt')
    pred_sort_path = os.path.join(save_path, "Predict_sort.txt")
    for path, ordered in ((pred_path, rows), (pred_sort_path, ranked)):
        with open(path, 'w') as file:
            file.write("Input\tScore\n")
            for name, score in ordered:
                file.write(name + "\t%.4f\n" % score)

In [2]:
def count_parameters(model):
    """Return how many trainable scalar weights ``model`` has.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for weight in model.parameters():
        if not weight.requires_grad:
            continue
        total += weight.numel()
    return total

def initialize_model(model, device, load_save_file=False):
    """Initialise or load ``model`` and move it to ``device``.

    Args:
        model: the ``nn.Module`` to prepare.
        device: ``torch.device`` the model is moved to.
        load_save_file: path of a saved state dict to load.  When falsy, every
            parameter with two or more dimensions is re-initialised with
            Xavier-normal; 0-d and 1-D parameters keep their existing values.

    Returns:
        The model on ``device``.  It is wrapped in ``nn.DataParallel`` when
        more than one CUDA device is visible.
    """
    if load_save_file:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(load_save_file, map_location=device))
    else:
        for param in model.parameters():
            # Xavier needs fan-in/fan-out, so only >=2-D tensors qualify; 0-d
            # tensors used to reach xavier_normal_ and raise.  The former
            # zero-init of 1-D tensors sat after `continue` and never ran, so
            # 1-D tensors stay at their defaults as before.
            if param.dim() > 1:
                nn.init.xavier_normal_(param)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # NOTE: DataParallel prefixes state-dict keys with 'module.' and does
        # not forward custom methods such as test_model (use `.module`).
        model = nn.DataParallel(model)
    model.to(device)
    return model

def init_model(model_path, params):
    """Build a ``GNN_Model`` from ``params`` and load trained weights for inference.

    Args:
        model_path: file readable by ``torch.load``.  Its content is either a
            bare state dict or a dict holding the weights under
            ``'state_dict'``; the ``predict_single_input.py`` variant in this
            repo expects the latter.
        params: hyper-parameter dict forwarded to ``GNN_Model``.

    Returns:
        ``(model, device)``: the model in eval mode on ``device`` (CUDA:0 if
        available, otherwise CPU).
    """
    model = GNN_Model(params)
    print('    Total params: %.10fM' % (count_parameters(model) / 1000000.0))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    checkpoint = torch.load(model_path, map_location=device)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    # Load into the bare module and do NOT wrap it in nn.DataParallel.  The
    # saved keys carry no 'module.' prefix, so loading into a wrapped model
    # failed on multi-GPU hosts.  Callers also invoke model.test_model(),
    # which DataParallel does not expose.  Random re-initialisation is
    # skipped because load_state_dict overwrites every parameter anyway.
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, device

In [3]:
def Get_Predictions(dataloader,device,model):
    Final_pred = []
    with torch.no_grad():
        for batch_idx, sample in enumerate(dataloader):
            H, A1, A2, V, Atom_count = sample
            batch_size = H.size(0)
            H, A1, A2, V = H.to(device), A1.to(device), A2.to(device), V.to(device)
            pred= model.test_model((H, A1, A2, V, Atom_count), device)
            pred1 = pred.detach().cpu().numpy()
            Final_pred += list(pred1)
    return Final_pred

def test_falcon_gnn(input_path, params, model_path=None, max_test_files=10,
                    batch_size=10, test_fold=2):
    '''
    Score the held-out fold's complexes in ``input_path`` with a saved GNN.

    No training happens here; the function only runs inference with a
    previously saved model.  "FALCON_GNN" is just a project nickname.

    Args:
        input_path: directory of preprocessed ``.npz`` files.  The first four
            characters of each file name are taken as the PDB ID that selects
            its fold.
        params: model settings forwarded to ``init_model``.
            ``params['num_workers']`` is also used for the DataLoader.
        model_path: saved weights to load.  Defaults to the original
            hard-coded checkpoint.
        max_test_files: cap on the number of test files scored (10 keeps the
            original quick-check behaviour); ``None`` scores all of them.
        batch_size: DataLoader batch size.
        test_fold: which fold (1-4) is scored; the others are the training
            folds (by default folds 1, 3 and 4).

    Returns:
        ``(test_data_list, Final_Pred)``: the scored ``.npz`` file names in
        sorted order, and the predictions returned by ``Get_Predictions``.
    '''
    fold_ids = {
        1: ['1a2k', '1e96', '1he1', '1he8', '1wq1', '1f6m', '1ma9', '2btf', '1g20', '1ku6', '1t6g', '1ugh', '1yvb', '2ckh', '3pro'],
        2: ['1akj', '1p7q', '2bnq', '1dfj', '1nbf', '1r4m', '1xd3', '2bkr', '1gpw', '1hxy', '1u7f', '1uex', '1zy8', '2goo', '1ewy'],
        3: ['1avw', '1bth', '1bui', '1cho', '1ezu', '1ook', '1oph', '1ppf', '1tx6', '1xx9', '2fi4', '2kai', '1r0r', '2sni', '3sic'],
        4: ['1bvn', '1tmq', '1f51', '1fm9', '1a2y', '1g6v', '1gpq', '1jps', '1wej', '1l9b', '1s6v', '1w1i', '2a5t', '3fap'],
    }
    test_ids = set(fold_ids[test_fold])

    # Sorted so the truncated subset is the same on every run (listdir order
    # is arbitrary).
    npz_names = sorted(
        f for f in listdir(input_path)
        if f.endswith(".npz") and isfile(join(input_path, f))
    )
    test_data_list = [f for f in npz_names if f[:4] in test_ids]
    if max_test_files is not None:
        test_data_list = test_data_list[:max_test_files]

    # Give Single_Dataset full paths, like the Prepare_Input outputs it gets
    # in predict_multi_input.  Bare listdir names only resolve when the
    # working directory happens to be input_path.
    test_data = Single_Dataset([join(input_path, f) for f in test_data_list])
    test_loader = DataLoader(test_data, batch_size, shuffle=False,
                             num_workers=params['num_workers'],
                             drop_last=False, collate_fn=collate_fn_Jake)

    if model_path is None:
        model_path = "/mnt/c/Users/jaket/Documents/GNN_DOVE_DATA/full_train_DG1_random_batch_jake_params_3.pt"
    model, device = init_model(model_path, params)

    Final_Pred = Get_Predictions(test_loader, device, model)
    return test_data_list, Final_Pred

## Actually testing the code

In [4]:
## Getting params in the same format as the actual program
# The data root can be overridden with the GNN_DOVE_DATA_DIR environment
# variable, so the notebook can run on machines other than the author's.
DATA_DIR = os.environ.get("GNN_DOVE_DATA_DIR", "/mnt/c/Users/jaket/Documents/GNN_DOVE_DATA")

params = {
    'F' : 'example',
    'mode' : '3',
    'gpu' : '0',
    'batch_size' : 32,
    'num_workers' : 7,
    'n_graph_layer' : 3,
    'd_graph_layer' : 1024,
    'n_FC_layer' : 4,
    'd_FC_layer' : 128,
    'initial_mu' : 0.0,
    'initial_dev' : 1.0,
    'dropout_rate' : 0.3,
    'seed' : 888,
    'fold' : -1,
    'receptor_units' : 1,
}


input_path = os.path.join(DATA_DIR, "dockground_1_processed_npz")
# Keep the results in named variables so later cells can use them.
test_names, test_preds = test_falcon_gnn(input_path, params)
list(zip(test_names, test_preds))

    Total params: 9.5805500000M


NameError: name 'dataloader' is not defined

In [20]:
# Leftover post-mortem debugging cell (IPython %debug).  It was used to
# inspect the init_model failure shown in the output below.  Remove it
# before sharing the notebook, because it blocks "Restart & Run All" waiting
# for input.
%debug

> [0;32m/mnt/c/Users/jaket/Documents/GitHub/GNN_DOVE/predict/predict_single_input.py[0m(53)[0;36minit_model[0;34m()[0m
[0;32m     51 [0;31m    [0mmodel[0m [0;34m=[0m [0minitialize_model[0m[0;34m([0m[0mmodel[0m[0;34m,[0m [0mdevice[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 [0;31m    [0mstate_dict[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mload[0m[0;34m([0m[0mmodel_path[0m[0;34m,[0m [0mmap_location[0m [0;34m=[0m [0mdevice[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 53 [0;31m    [0mmodel[0m[0;34m.[0m[0mload_state_dict[0m[0;34m([0m[0mstate_dict[0m[0;34m[[0m[0;34m'state_dict'[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     54 [0;31m    [0mmodel[0m[0;34m.[0m[0meval[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 [0;31m    [0;32mreturn[0m [0mmodel[0m[0;34m,[0m[0mdevice[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> p state_dict
OrderedDict([('mu', tensor([-

ipdb> p state_dict.keys()
odict_keys(['mu', 'dev', 'gconv1.0.A', 'gconv1.0.W.weight', 'gconv1.0.W.bias', 'gconv1.0.gate.weight', 'gconv1.0.gate.bias', 'gconv1.1.A', 'gconv1.1.W.weight', 'gconv1.1.W.bias', 'gconv1.1.gate.weight', 'gconv1.1.gate.bias', 'gconv1.2.A', 'gconv1.2.W.weight', 'gconv1.2.W.bias', 'gconv1.2.gate.weight', 'gconv1.2.gate.bias', 'FC.0.weight', 'FC.0.bias', 'FC.1.weight', 'FC.1.bias', 'FC.2.weight', 'FC.2.bias', 'FC.3.weight', 'FC.3.bias'])
--KeyboardInterrupt--

KeyboardInterrupt: Interrupted by user
