In [1]:
import numpy as np
import pandas as pd
import sys
import csv
import copy
import time 
import pickle
from sklearn.model_selection import train_test_split
from DTIDataset import *
from GCADTI import *
import argparse
import torch
from torch.utils import data
import torch.nn.functional as F
from torch.autograd import Variable
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from prettytable import PrettyTable
from rdkit import Chem
import json

from sklearn.metrics import accuracy_score,roc_auc_score,f1_score,balanced_accuracy_score,\
        recall_score,precision_score,precision_recall_curve,confusion_matrix,average_precision_score
class Model:
    """Evaluation / inference wrapper around a ``GCADTI`` network.

    The instance owns the network, the device it should run on, and a few
    output paths derived from ``modeldir``.  It offers metric evaluation
    (``test``), batched scoring of a dataset (``predict``) and checkpoint
    loading (``load_pretrained``).
    """

    def __init__(self, modeldir, device):
        """Create a fresh ``GCADTI`` network and derive the record paths.

        Args:
            modeldir: Directory under which the validation/test markdown
                tables and the loss-curve pickle paths are placed.
            device: Anything ``torch.device`` accepts (e.g. ``'cpu'``).
        """
        self.model = GCADTI()
        self.device = torch.device(device)
        self.modeldir = modeldir
        self.record_fileval = os.path.join(self.modeldir, "valid_markdowntable.txt")
        self.record_filetest = os.path.join(self.modeldir, "test_markdowntable.txt")
        self.pkl_file = os.path.join(self.modeldir, "loss_curve_iter.pkl")

    def test(self, datagenerator, model, mode=None):
        """Run ``model`` over ``datagenerator`` and score the outputs.

        Args:
            datagenerator: Iterable yielding ``(drug, target, label)`` batches.
            model: Callable invoked as ``model(target, drug)`` returning a pair
                whose second element is the score tensor (squeezed on dim 1).
            mode: ``'predict'``, ``'val'`` or ``'test'``.

        Returns:
            * ``'predict'``: ``(y_label, y_pred)`` as flat Python lists of the
              labels and the raw model scores.  The model is left in eval mode.
            * ``'val'``: ``(mean_loss, auroc, auprc)``.
            * ``'test'``: ``(y_label, y_pred_binary, accuracy, auroc, auprc,
              sensitivity, specificity, f1, recall, balanced_accuracy,
              precision, mean_loss)`` where the binary predictions use the
              F1-optimal threshold from the precision-recall curve.
            * any other value: ``None`` (metrics are computed but discarded).

        Raises:
            ValueError: If ``mode`` is not ``'predict'`` and the generator
                yields no batches.
        """
        needs_loss = mode != 'predict'
        loss_fct = torch.nn.BCEWithLogitsLoss()
        y_label = []
        y_pred = []
        loss_s = 0
        n_batches = 0

        model.eval()
        # Evaluation never back-propagates, so do not record autograd history.
        with torch.no_grad():
            for drug, target, label in datagenerator:
                drug = drug.to(self.device)
                target = target.to(self.device)
                label = label.float().to(self.device)

                _, score = model(target, drug)
                score = torch.squeeze(score, 1)
                if needs_loss:
                    loss_s = loss_s + loss_fct(score, label)

                y_label.extend(label.cpu().numpy().flatten().tolist())
                y_pred.extend(score.detach().cpu().numpy().flatten().tolist())
                n_batches += 1

        if mode == 'predict':
            return y_label, y_pred

        if n_batches == 0:
            raise ValueError("datagenerator yielded no batches; nothing to evaluate")

        loss_m = loss_s / n_batches
        auroc = roc_auc_score(y_label, y_pred)
        auprc = average_precision_score(y_label, y_pred)
        # Hand the model back in training mode, as the training loop expects.
        model.train()

        if mode == 'val':
            return loss_m, auroc, auprc

        if mode == 'test':
            prec, recall, thresholds = precision_recall_curve(y_label, y_pred)
            f1 = 2 * prec * recall / (prec + recall + 0.00001)
            # prec/recall carry one trailing point that has no threshold, so
            # restrict the search to indices that exist in `thresholds`.
            thred_optim = thresholds[np.argmax(f1[:-1])]
            print(thred_optim)
            y_pred_s = (np.asarray(y_pred) >= thred_optim).astype(int).tolist()

            # confusion_matrix rows are true classes, columns predicted:
            # [[TN, FP], [FN, TP]].
            cm1 = confusion_matrix(y_label, y_pred_s)
            specificity = cm1[0, 0] / (cm1[0, 0] + cm1[0, 1])
            sensitivity = cm1[1, 1] / (cm1[1, 0] + cm1[1, 1])
            return (y_label, y_pred_s, accuracy_score(y_label, y_pred_s), auroc, auprc,
                    sensitivity, specificity, f1_score(y_label, y_pred_s),
                    recall_score(y_label, y_pred_s), balanced_accuracy_score(y_label, y_pred_s),
                    precision_score(y_label, y_pred_s), loss_m)

        return None

    def predict(self, dataset, batch_size=128, num_workers=4):
        """Score every item of ``dataset`` with the wrapped model.

        Args:
            dataset: Dataset compatible with ``graph_collate_func``.
            batch_size: Batch size for the internal ``DataLoader``.
            num_workers: Worker processes for the internal ``DataLoader``.

        Returns:
            ``(y_label, y_pred)`` flat lists, in dataset order (no shuffling,
            no dropped last batch).
        """
        print('predicting...')
        self.model = self.model.to(self.device)
        generator = DataLoader(dataset,
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=num_workers,
                               drop_last=False,
                               collate_fn=graph_collate_func)
        return self.test(generator, self.model, mode='predict')

    def load_pretrained(self, path, device):
        """Load weights from a checkpoint file into ``self.model``.

        Args:
            path: Path to a state-dict file written by ``torch.save``.
            device: Kept for interface compatibility.  The checkpoint is always
                deserialised on CPU; ``load_state_dict`` then copies the values
                into the model's own parameters, wherever those live.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
        """
        # A missing checkpoint is an error; never create a directory in its place.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Pretrained checkpoint not found: {path}")

        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        # Mapping to CPU keeps GPU-saved checkpoints loadable on CPU-only hosts.
        state_dict = torch.load(path, map_location=torch.device('cpu'))

        # Checkpoints saved from a data-parallel wrapper prefix every key with
        # `module.`; strip it so the keys match the bare model.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        self.model.load_state_dict(state_dict)
def predict_dti_vector(smiles: str,
                       csv_path: str = './IC_test.csv',
                       model_path: str = './model.pt',
                       modeldir: str = './',
                       device: str = 'cpu') -> torch.Tensor:
    """Score one compound against every unique protein sequence in a CSV.

    The CSV and the checkpoint are re-read on every call.

    Args:
        smiles: SMILES string of the query compound.
        csv_path: CSV file with a ``sequence`` column (first column is the index).
        model_path: Checkpoint passed to ``Model.load_pretrained``.
        modeldir: Directory handed to ``Model`` for its record paths.
        device: Device the model runs on.

    Returns:
        1-D tensor of the model's raw output scores, one per unique sequence,
        ordered by ascending sequence string.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    # Validate the cheap input first, before touching the CSV or the checkpoint.
    if Chem.MolFromSmiles(smiles) is None:
        raise ValueError(f"Invalid SMILES: {smiles}")

    IC_test = pd.read_csv(csv_path, index_col=0)
    prot_list = IC_test['sequence'].unique()

    net = Model(modeldir=modeldir, device=device)
    net.load_pretrained(model_path, device)

    # Pair the query compound with each sequence; `label` is a placeholder
    # that is overwritten with the predicted score below.
    pred_dti = pd.DataFrame({'smiles': [smiles]*len(prot_list), 'sequence': prot_list, 'label': [0]*len(prot_list)})

    # Silence numpy divide/invalid warnings only for the duration of this call
    # instead of changing the global numpy error state.
    with np.errstate(divide='ignore', invalid='ignore'), torch.no_grad():
        pred_set = DTIDataset(pred_dti)
        _, y_pred = net.predict(pred_set)

    pred_dti['label'] = y_pred
    df_sorted = pred_dti.sort_values(by=['sequence'], ascending=[True])
    # Go through a plain list so the values are taken in positional (sorted)
    # order, independent of how torch would walk a re-indexed pandas Series.
    return torch.tensor(df_sorted['label'].tolist())

In [2]:
# Score a single query compound against every unique protein sequence.
query_smiles = 'COCCOC1=C(C=C2C(=C1)C(=NC=N2)NC3=CC=CC(=C3)C#C)OCCOC'
dti = predict_dti_vector(query_smiles)

predicting...


In [3]:
# Length of the returned score vector.
dti.size()

torch.Size([1572])