In [1]:
import torch
import argparse
from model import Model, MaskLoss
from train_10f import load_data
from data_utils import parseDataFea, parseA2
from model import calculate_pearsonr, calculate_spearmanr, calculate_r2score

In [2]:
import os

# Expose only physical GPU 2 to this process (it is then addressed as cuda:0).
# This has to happen before the first CUDA call made by torch in this kernel.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

In [3]:
from sklearn.metrics import mean_squared_error
from sklearn import metrics


def _masked_values(ipt, tar, mask):
    """Gather the entries of ``ipt`` and ``tar`` at the nonzero positions of ``mask``.

    Boolean-mask indexing visits elements in the same row-major order as
    iterating ``torch.nonzero(mask)``, without a per-element Python loop.
    The selections are detached, cast to float32 and moved to the CPU so
    sklearn can convert them; booleans therefore become 0.0 / 1.0, just as
    building ``torch.Tensor`` from a list of scalars did.

    NOTE(review): assumes ``ipt``, ``tar`` and ``mask`` share the same shape.

    Returns:
        ``(a, b)``: 1-D float32 CPU tensors with the selected ``ipt`` and ``tar`` values.
    """
    selected = mask != 0
    a = ipt[selected].detach().float().cpu()
    b = tar[selected].detach().float().cpu()
    return a, b


def calculate_mse(ipt, tar, mask):
    """Mean squared error between predictions ``ipt`` and targets ``tar`` on masked entries.

    Returns whatever ``sklearn.metrics.mean_squared_error`` returns (a scalar
    whose exact type depends on the installed sklearn version).
    """
    preds, targets = _masked_values(ipt, tar, mask)
    return mean_squared_error(targets, preds)


def calculate_auc(ipt, tar, mask):
    """ROC AUC of scores ``ipt`` against binary labels ``tar`` on masked entries.

    sklearn's signature is ``roc_auc_score(y_true, y_score)``: the targets are
    the ground truth, so they must be the first argument. sklearn raises
    ``ValueError`` when the selected labels contain a single class.
    """
    scores, labels = _masked_values(ipt, tar, mask)
    return metrics.roc_auc_score(labels, scores)


def calculate_acc(ipt, tar, mask):
    """Accuracy of discrete predictions ``ipt`` against labels ``tar`` on masked entries."""
    preds, labels = _masked_values(ipt, tar, mask)
    return metrics.accuracy_score(labels, preds)


def calculate_aupr(ipt, tar, mask):
    """Area under the precision-recall curve of scores ``ipt`` vs. labels ``tar``.

    ``precision_recall_curve`` takes ``(y_true, scores)``, so the targets go
    first; the area is integrated over recall with ``metrics.auc``.
    """
    scores, labels = _masked_values(ipt, tar, mask)
    precision, recall, _thresholds = metrics.precision_recall_curve(labels, scores)
    return metrics.auc(recall, precision)


def calculate_kappa(ipt, tar, mask):
    """Cohen's kappa between discrete predictions ``ipt`` and labels ``tar`` on masked entries."""
    preds, labels = _masked_values(ipt, tar, mask)
    return metrics.cohen_kappa_score(labels, preds)

In [11]:
def evaluate_ckp(ckp_file):
    """Rebuild the model saved in ``ckp_file`` and score it on its validation fold.

    The fold index is the first character after the last ``-`` in the file
    name (so only single-digit folds are supported). Data is reloaded through
    ``load_data`` with the hyper-parameters stored in ``checkpoint['args']``,
    then the model is run once in eval mode without gradients.

    Returns:
        ``[mse, pearson, spearman, r2, valid_mask, valid_target, pred]`` where
        ``mse`` is a Python float and the last three entries are CPU tensors.
    """
    cur_fold = int(ckp_file.split('-')[-1][0])
    # Deserialize onto the CPU: with CUDA_VISIBLE_DEVICES restricted, tensors
    # saved from a different GPU index could not be mapped back to their
    # original device. load_state_dict copies the weights onto `device` below.
    checkpoint = torch.load(ckp_file, map_location='cpu')
    args = checkpoint['args']
    emb, A, target, mask = load_data('./data/', args.data_set, args.th_rate, args.edge_mask, 10, cur_fold)

    # NOTE(review): args.device is whatever was saved at training time; it must
    # still be valid under the current CUDA_VISIBLE_DEVICES setting.
    device = args.device
    dru_emb, dis_emb, tar_emb = (t.to(device) for t in emb)
    dru_dru_A, dis_dis_A, tar_tar_A, dru_dis_A, dru_tar_A = (t.to(device) for t in A)

    # Only the validation split is used, so the training tensors are not
    # copied to the device.
    _train_target, valid_target = target
    _train_mask, valid_mask = mask
    valid_target = valid_target.to(device)
    valid_mask = valid_mask.to(device)

    model = Model(
        dru_emb=len(dru_emb[0]),
        dis_emb=len(dis_emb[0]),
        tar_emb=len(tar_emb[0]),
        dru_hid=args.dru_hid_size,
        dis_hid=args.dis_hid_size,
        tar_hid=args.tar_hid_size,
        edge_dim=args.edge_dim,
        g_hid=args.g_hid,
        layer=args.layer,
        dp=args.dp,
        decoder=args.decoder,
        dru_agg=args.dru_agg,
        device=args.device
    ).to(device)
    model.load_state_dict(checkpoint['model'])

    model.eval()
    with torch.no_grad():
        pred = model(dru_emb, dis_emb, tar_emb, dru_dru_A, dis_dis_A, tar_tar_A, dru_dis_A, dru_tar_A, valid_mask)

        loss = calculate_mse(pred, valid_target, valid_mask)
        pear = calculate_pearsonr(pred, valid_target, valid_mask)
        spea = calculate_spearmanr(pred, valid_target, valid_mask)
        r2sc = calculate_r2score(pred, valid_target, valid_mask)

    # float() accepts both NumPy scalars and plain Python floats; newer
    # scikit-learn releases return the latter, which has no `.item()`.
    return [float(loss), pear, spea, r2sc, valid_mask.cpu(), valid_target.cpu(), pred.cpu()]

In [41]:
import os
import pickle

# Cell-line names from the pickled dict that also appear as entries under ./data/.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('./data/cell_line_vec_dict.pt', 'rb') as f:
    matchmaker_cell_line = pickle.load(f)
exists_cell_line = os.listdir('./data/')
cell_line = [name for name in matchmaker_cell_line if name in exists_cell_line]

from collections import defaultdict

data_all_score = defaultdict(lambda: [])

# Each candidate is [path, run_name, group_key]:
#   run_name  = file name with its first 12 characters dropped
#               (NOTE(review): presumably a fixed-width prefix such as a timestamp — confirm),
#   group_key = run_name up to the first '_[' (used later to group runs).
ckp_candidates = sorted(
    (['./result/' + name, name[12:], name[12:].split('_[')[0]] for name in os.listdir('./result/')),
    key=lambda entry: entry[1],
)

# Keep the first checkpoint for each run_name. The sorted candidates live in
# their own variable so that resetting ckp_list does not empty the input of
# the loop.
ckp_list, ckp_name_list = [], []
seen_names = set()
for ckp in ckp_candidates:
    if ckp[1] not in seen_names:
        seen_names.add(ckp[1])
        ckp_list.append(ckp)
        ckp_name_list.append(ckp[1])

In [13]:
from tqdm import tqdm

In [8]:
# Cache of previously computed evaluations, keyed by checkpoint path.
# Start from an empty cache when the file has not been written yet.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
if os.path.exists('res-grn-1201.pkl'):
    with open('res-grn-1201.pkl', 'rb') as f:
        save_list = pickle.load(f)
else:
    save_list = {}

In [None]:
# Evaluate every checkpoint whose path is not yet in the cache, recording the
# [path, run_name, group_key, result] entry both in this run's list and in the cache.
ckp_res_list = []
for ckp in tqdm(ckp_list):
    ckp_path = ckp[0]
    if ckp_path not in save_list:
        res = evaluate_ckp(ckp_file=ckp_path)
        ckp_res_list.append(ckp + [res])
        save_list[ckp_path] = ckp + [res]

In [None]:
# Re-key the cached entries by run name (entry[1]), then re-evaluate every
# checkpoint and overwrite its entry with the fresh result.
# NOTE(review): every checkpoint must already be cached (see the assert), yet
# all of them are evaluated again — confirm the refresh is intended.
nsave_list = {entry[1]: entry for entry in save_list.values()}
ckp_res_list = []
for ckp in tqdm(ckp_list):
    run_name = ckp[1]
    assert run_name in nsave_list
    res = evaluate_ckp(ckp_file=ckp[0])
    ckp_res_list.append(ckp + [res])
    nsave_list[run_name] = ckp + [res]

In [40]:
from collections import defaultdict

In [44]:
# One evaluation entry per checkpoint, looked up by run name, in ckp_list order.
ckp_res_list = [nsave_list[run[1]] for run in ckp_list]

In [53]:
from data.data_v2 import parseA, parseA2, parseDataFea
def norm_deno(ckp_file):
    """Return ``max - min`` of the drug-drug matrix for a checkpoint's dataset.

    The checkpoint is opened only to read ``checkpoint['args'].data_set``. The
    matrix is then rebuilt from ``./data/<data_set>/drugdrug_extract.csv``
    with the ``parseDataFea`` / ``parseA`` names in scope; the cell imports
    them from ``data.data_v2``, which shadows the earlier ``data_utils``
    import of ``parseDataFea``.

    NOTE(review): callers multiply normalized values by this range. If the data
    were min-max scaled as (x - min) / (max - min), that recovers x - min, not
    x — confirm that min is 0 or add it back.
    """
    # Only the saved hyper-parameters are needed; mapping to CPU avoids
    # placing the stored weights on (or requiring) their original GPU.
    checkpoint = torch.load(ckp_file, map_location='cpu')
    data_dir = './data/' + checkpoint['args'].data_set + '/'
    dru_dict, *_ = parseDataFea('drugfeature1_finger_extract.csv', path=data_dir)
    dru_dru_mat = parseA('drugdrug_extract.csv', dru_dict, path=data_dir)
    return dru_dru_mat.max() - dru_dru_mat.min()

In [None]:
# For each threshold (largest first): rescale predictions and targets by the
# group's deno, binarize them at the threshold, and score classification
# metrics on the validation mask. Results go to res[threshold][checkpoint_path].
deno_dict = {}
error_dict = []
res = defaultdict(dict)
for threshold in (10, 5, 0):
    for item in tqdm(ckp_res_list):
        ckp_path, run_name, group_key, eval_out = item[0], item[1], item[2], item[-1]
        if ckp_path in res[threshold]:
            continue
        if group_key not in deno_dict:
            deno_dict[group_key] = norm_deno(ckp_path)
        deno = deno_dict[group_key]
        # eval_out ends with [..., valid_mask, valid_target, pred].
        mask, tar, ipt = eval_out[-3], eval_out[-2], eval_out[-1]
        ipt = ipt * deno >= threshold
        tar = tar * deno >= threshold
        # NOTE(review): AUC/AUPR receive thresholded predictions, so they reflect a
        # single operating point; the continuous `pred * deno` scores would give
        # the usual threshold-free curves.

        try:
            acc = calculate_acc(ipt, tar, mask)
            auc = calculate_auc(ipt, tar, mask)
            aupr = calculate_aupr(ipt, tar, mask)
            kappa = calculate_kappa(ipt, tar, mask)
        except Exception as e:
            # Dump context for the failing run before propagating the error.
            print(run_name)
            print(tar.max())
            print(deno)
            print(tar.max() * deno)
            print(e.args)
            raise e

        res[threshold][ckp_path] = [ckp_path, run_name, group_key, [acc, auc, aupr, kappa]]

In [55]:
# nres[threshold][group_key] -> list of per-checkpoint result entries.
nres = defaultdict(lambda: defaultdict(list))
for threshold in (0, 5, 10):
    per_threshold = res[threshold]
    for item in ckp_res_list:
        nres[threshold][item[2]].append(per_threshold[item[0]])

def mean(data):
    """Arithmetic mean of a non-empty sized collection of numbers.

    Raises ZeroDivisionError when ``data`` is empty.
    """
    count = len(data)
    return sum(data) / count

# Average each metric over the runs in a group: `scores` is a list of
# [acc, auc, aupr, kappa] lists, and zip(*scores) walks it metric by metric.
res_score = {}
for threshold in (0, 5, 10):
    per_group = defaultdict(lambda: [])
    res_score[threshold] = per_group
    for key, entries in nres[threshold].items():
        scores = [entry[-1] for entry in entries]
        per_group[key] = [mean(column) for column in zip(*scores)]

In [57]:
# Persist the averaged scores per threshold. Each res_score entry is a
# defaultdict with a lambda factory, which pickle cannot serialize, so it is
# converted to a plain dict first.
for threshold in (0, 5, 10):
    with open(f'res-grn-th={threshold}.pkl', 'wb') as f:
        pickle.dump(dict(res_score[threshold]), f)