In [18]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import recall_score, precision_score, accuracy_score

def index_data(path='../data/drugPair2effect_idx.pkl'):
    """Load the pair -> effect mapping and build label lookup tables.

    Each distinct value of the pickled dict is treated as one "label" and
    is assigned an integer index according to its sorted order.

    Args:
        path: Pickle file containing a dict. Its values must be hashable,
            sortable and iterable (presumably tuples of integer effect ids,
            keyed by drug pair -- TODO confirm against the data file).
            Defaults to the location this function originally hard-coded.

    Returns:
        A 4-tuple ``(mlb, label2idx, idx2label, drugPair2effectIdx)``:
            mlb: ``MultiLabelBinarizer`` fitted on all dict values.
            label2idx: dict mapping each distinct value to its index.
            idx2label: object-dtype array; position ``i`` holds
                ``np.array(label)`` for the label with index ``i``.
            drugPair2effectIdx: dict mapping each original key to the
                index of its value.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    drugPair2effect = pd.read_pickle(path)
    y_all = list(drugPair2effect.values())

    # Only the fitted binarizer is returned, so fit() is enough; the
    # indicator matrix fit_transform() would build was being discarded.
    mlb = MultiLabelBinarizer()
    mlb.fit(y_all)

    labels = sorted(set(y_all))

    print(labels)

    label2idx = {label: i for i, label in enumerate(labels)}

    drugPair2effectIdx = {
        pair: label2idx[effect] for pair, effect in drugPair2effect.items()
    }

    # Inverse table, indexable with an array of label indices.
    idx2label = np.zeros(len(label2idx), dtype='O')
    for label, i in label2idx.items():
        idx2label[i] = np.array(label)

    return mlb, label2idx, idx2label, drugPair2effectIdx


def _pair_features(sim_mat, x_idx):
    """Gather the rows of ``sim_mat`` picked by ``x_idx``, one flat row per sample."""
    # The target width of 2 * len(sim_mat) means each entry of x_idx must
    # select two rows of a matrix whose rows have len(sim_mat) columns.
    gathered = sim_mat[x_idx].reshape(len(x_idx), len(sim_mat) * 2)
    return torch.tensor(gathered).float()


def convert_tensor(x_idx, y_idx, SS_mat, TS_mat, GS_mat, mlb, idx2label):
    """Turn index arrays into float feature tensors and a binarized target.

    Args:
        x_idx: Index array used to select rows from each similarity matrix
            (presumably shape ``(n_samples, 2)``, one row index per drug of
            the pair -- TODO confirm against the caller).
        y_idx: Indices into ``idx2label`` for each sample.
        SS_mat, TS_mat, GS_mat: Similarity matrices indexable by ``x_idx``.
        mlb: Fitted binarizer exposing ``transform``.
        idx2label: Array mapping a label index back to its label.

    Returns:
        ``(SS, TS, GS, y)`` as ``torch.float32`` tensors. The feature tensors
        have shape ``(len(x_idx), 2 * len(matrix))``; ``y`` is whatever
        ``mlb.transform`` produces for the selected labels. The tensors are
        created with ``torch.tensor`` and are not moved to another device
        here.
    """
    SS, TS, GS = (_pair_features(mat, x_idx) for mat in (SS_mat, TS_mat, GS_mat))
    y = torch.tensor(mlb.transform(idx2label[y_idx])).float()

    return SS, TS, GS, y

def evaluate_model(answer, prediction):
    """Score predictions against the ground truth.

    Args:
        answer: True labels, in a form accepted by the sklearn metrics.
        prediction: Predicted labels, in the same form as ``answer``.

    Returns:
        ``(accuracy, macro_recall, macro_precision, micro_recall,
        micro_precision)``.
    """
    scores = [accuracy_score(answer, prediction)]

    # Recall then precision, first macro-averaged and then micro-averaged.
    for averaging in ('macro', 'micro'):
        scores.append(recall_score(answer, prediction, average=averaging))
        scores.append(precision_score(answer, prediction, average=averaging))

    return tuple(scores)


In [19]:
# Load the structural, target and GO similarity matrices, in that order.
SS_mat, TS_mat, GS_mat = [
    pd.read_pickle(pickle_path)
    for pickle_path in (
        '../data/structural_similarity_matrix.pkl',
        '../data/target_similarity_matrix.pkl',
        '../data/GO_similarity_matrix.pkl',
    )
]



[(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,), (11,), (12,), (13,), (14,), (14, 73), (14, 100), (15,), (16,), (17,), (18,), (19,), (20,), (21,), (22,), (23,), (24,), (25,), (26,), (27,), (27, 73), (28,), (29,), (30,), (31,), (32,), (33,), (34,), (35,), (36,), (37,), (38,), (39,), (40,), (41,), (42,), (43,), (44,), (45,), (46,), (47,), (48,), (49,), (50,), (51,), (52,), (53,), (54,), (55,), (56,), (56, 73), (57,), (58,), (59,), (60,), (61,), (62,), (63,), (64,), (65,), (66,), (67,), (67, 73), (67, 100), (68,), (68, 73), (68, 100), (69,), (69, 99), (70,), (71,), (72,), (73,), (73, 100), (74,), (75,), (76,), (77,), (78,), (79,), (80,), (81,), (82,), (83,), (84,), (85,), (86,), (87,), (88,), (89,), (90,), (91,), (92,), (93,), (94,), (95,), (96,), (97,), (98,), (99,), (99, 100), (99, 104), (100,), (100, 104), (101,), (102,), (103,), (104,), (105,)]


(MultiLabelBinarizer(classes=None, sparse_output=False),
 {(0,): 0,
  (1,): 1,
  (2,): 2,
  (3,): 3,
  (4,): 4,
  (5,): 5,
  (6,): 6,
  (7,): 7,
  (8,): 8,
  (9,): 9,
  (10,): 10,
  (11,): 11,
  (12,): 12,
  (13,): 13,
  (14,): 14,
  (14, 73): 15,
  (14, 100): 16,
  (15,): 17,
  (16,): 18,
  (17,): 19,
  (18,): 20,
  (19,): 21,
  (20,): 22,
  (21,): 23,
  (22,): 24,
  (23,): 25,
  (24,): 26,
  (25,): 27,
  (26,): 28,
  (27,): 29,
  (27, 73): 30,
  (28,): 31,
  (29,): 32,
  (30,): 33,
  (31,): 34,
  (32,): 35,
  (33,): 36,
  (34,): 37,
  (35,): 38,
  (36,): 39,
  (37,): 40,
  (38,): 41,
  (39,): 42,
  (40,): 43,
  (41,): 44,
  (42,): 45,
  (43,): 46,
  (44,): 47,
  (45,): 48,
  (46,): 49,
  (47,): 50,
  (48,): 51,
  (49,): 52,
  (50,): 53,
  (51,): 54,
  (52,): 55,
  (53,): 56,
  (54,): 57,
  (55,): 58,
  (56,): 59,
  (56, 73): 60,
  (57,): 61,
  (58,): 62,
  (59,): 63,
  (60,): 64,
  (61,): 65,
  (62,): 66,
  (63,): 67,
  (64,): 68,
  (65,): 69,
  (66,): 70,
  (67,): 71,
  (67, 73): 72

NameError: name 'SS' is not defined