### Import necessary libraries

In [1]:
from collections import defaultdict
import os
import pickle
import sys
import timeit

import numpy as np

from rdkit import Chem
from rdkit.Chem import rdDepictor, Descriptors
from rdkit.Chem import MACCSkeys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, auc, roc_curve

### Check if GPU is available

In [2]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Helper functions

In [3]:
# dictionary of atoms where a new element gets a new index
def create_atoms(mol):
    atoms = [atom_dict[a.GetSymbol()] for a in mol.GetAtoms()]
    return np.array(atoms)

# format from_atomIDx : [to_atomIDx, bondDict]
def create_ijbonddict(mol):
    i_jbond_dict = defaultdict(lambda: [])
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bond = bond_dict[str(b.GetBondType())]
        i_jbond_dict[i].append((j, bond))
        i_jbond_dict[j].append((i, bond))
    return i_jbond_dict


def create_fingerprints(atoms, i_jbond_dict, radius):
    """Extract the r-radius subgraphs (i.e., fingerprints)
    from a molecular graph using WeisfeilerLehman-like algorithm."""

    if (len(atoms) == 1) or (radius == 0):
        fingerprints = [fingerprint_dict[a] for a in atoms]

    else:
        vertices = atoms
        for _ in range(radius):
            fingerprints = []
            for i, j_bond in i_jbond_dict.items():
                neighbors = [(vertices[j], bond) for j, bond in j_bond]
                fingerprint = (vertices[i], tuple(sorted(neighbors)))
                fingerprints.append(fingerprint_dict[fingerprint])
            vertices = fingerprints

    return np.array(fingerprints)


def create_adjacency(mol):
    adjacency  = Chem.GetAdjacencyMatrix(mol)
    n          = adjacency.shape[0]
    
    adjacency  = adjacency + np.eye(n)
    degree     = sum(adjacency)
    d_half     = np.sqrt(np.diag(degree))
    d_half_inv = np.linalg.inv(d_half)
    adjacency  = np.matmul(d_half_inv,np.matmul(adjacency,d_half_inv))
    return np.array(adjacency)


def dump_dictionary(dictionary, file_name):
    with open(file_name, 'wb') as f:
        pickle.dump(dict(dictionary), f)
        

def load_tensor(file_name, dtype):
    return [dtype(d).to(device) for d in np.load(file_name + '.npy', allow_pickle=True)]


def load_numpy(file_name):
    return np.load(file_name + '.npy', allow_pickle=True)


def load_pickle(file_name):
    with open(file_name, 'rb') as f:
        return pickle.load(f)


def shuffle_dataset(dataset, seed):
    np.random.seed(seed)
    np.random.shuffle(dataset)
    return dataset


def split_dataset(dataset, ratio):
    n = int(ratio * len(dataset))
    dataset_1, dataset_2 = dataset[:n], dataset[n:]
    return dataset_1, dataset_2

### Read data

In [4]:
radius = 2

with open('kegg_classes.txt', 'r') as f:
    data_list = f.read().strip().split('\n')
    
"""Exclude the data contains "." in the smiles, which correspond to non-bonds"""
data_list = list(filter(lambda x: '.' not in x.strip().split()[0], data_list))
N = len(data_list)

print('Total number of molecules : %d' %(N))

atom_dict = defaultdict(lambda: len(atom_dict))
bond_dict = defaultdict(lambda: len(bond_dict))
fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))

Molecules, Adjacencies, Properties, MACCS_list = [], [], [], []

max_MolMR, min_MolMR     = -1000, 1000
max_MolLogP, min_MolLogP = -1000, 1000
max_MolWt, min_MolWt     = -1000, 1000
max_NumRotatableBonds, min_NumRotatableBonds = -1000, 1000
max_NumAliphaticRings, min_NumAliphaticRings = -1000, 1000
max_NumAromaticRings, min_NumAromaticRings   = -1000, 1000
max_NumSaturatedRings, min_NumSaturatedRings = -1000, 1000

for no, data in enumerate(data_list):
    
    print('/'.join(map(str, [no+1, N])))
    
    smiles, property_indices = data.strip().split('\t')
    property_s = property_indices.strip().split(',')
    
    property = np.zeros((1,11))
    for prop in property_s:
        property[0,int(prop)] = 1
        
    Properties.append(property)
    
    mol = Chem.MolFromSmiles(smiles)
    atoms = create_atoms(mol)
    i_jbond_dict = create_ijbonddict(mol)

    fingerprints = create_fingerprints(atoms, i_jbond_dict, radius)
    Molecules.append(fingerprints)
    
    adjacency = create_adjacency(mol)
    Adjacencies.append(adjacency)
    
    MACCS         = MACCSkeys.GenMACCSKeys(Chem.MolFromSmiles(smiles))
    MACCS_ids     = np.zeros((20,))
    MACCS_ids[0]  = Descriptors.MolMR(mol)
    MACCS_ids[1]  = Descriptors.MolLogP(mol)
    MACCS_ids[2]  = Descriptors.MolWt(mol)
    MACCS_ids[3]  = Descriptors.NumRotatableBonds(mol)
    MACCS_ids[4]  = Descriptors.NumAliphaticRings(mol)
    MACCS_ids[5]  = MACCS[108]
    MACCS_ids[6]  = Descriptors.NumAromaticRings(mol)
    MACCS_ids[7]  = MACCS[98]
    MACCS_ids[8]  = Descriptors.NumSaturatedRings(mol)
    MACCS_ids[9]  = MACCS[137]
    MACCS_ids[10] = MACCS[136]
    MACCS_ids[11] = MACCS[145]
    MACCS_ids[12] = MACCS[116]
    MACCS_ids[13] = MACCS[141]
    MACCS_ids[14] = MACCS[89]
    MACCS_ids[15] = MACCS[50]
    MACCS_ids[16] = MACCS[160]
    MACCS_ids[17] = MACCS[121]
    MACCS_ids[18] = MACCS[149]
    MACCS_ids[19] = MACCS[161]
    
    if max_MolMR < MACCS_ids[0]:
        max_MolMR = MACCS_ids[0]
    if min_MolMR > MACCS_ids[0]:
        min_MolMR = MACCS_ids[0]
        
    if max_MolLogP < MACCS_ids[1]:
        max_MolLogP = MACCS_ids[1]
    if min_MolLogP > MACCS_ids[1]:
        min_MolLogP = MACCS_ids[1]
        
    if max_MolWt < MACCS_ids[2]:
        max_MolWt = MACCS_ids[2]
    if min_MolWt > MACCS_ids[2]:
        min_MolWt = MACCS_ids[2]
        
    if max_NumRotatableBonds < MACCS_ids[3]:
        max_NumRotatableBonds = MACCS_ids[3]
    if min_NumRotatableBonds > MACCS_ids[3]:
        min_NumRotatableBonds = MACCS_ids[3]
        
    if max_NumAliphaticRings < MACCS_ids[4]:
        max_NumAliphaticRings = MACCS_ids[4]
    if min_NumAliphaticRings > MACCS_ids[4]:
        min_NumAliphaticRings = MACCS_ids[4]
        
    if max_NumAromaticRings < MACCS_ids[6]:
        max_NumAromaticRings = MACCS_ids[6]
    if min_NumAromaticRings > MACCS_ids[6]:
        min_NumAromaticRings = MACCS_ids[6]
    
    if max_NumSaturatedRings < MACCS_ids[8]:
        max_NumSaturatedRings = MACCS_ids[8]
    if min_NumSaturatedRings > MACCS_ids[8]:
        min_NumSaturatedRings = MACCS_ids[8]
    
    MACCS_list.append(MACCS_ids)

dir_input = ('pathway/input'+str(radius)+'/')
os.makedirs(dir_input, exist_ok=True)

for n in range(N):
    for b in range(20):
        if b==0:
            MACCS_list[n][b] = (MACCS_list[n][b]-min_MolMR)/(max_MolMR-min_MolMR)
        elif b==1:
            MACCS_list[n][b] = (MACCS_list[n][b]-min_MolLogP)/(max_MolMR-min_MolLogP)
        elif b==2:
            MACCS_list[n][b] = (MACCS_list[n][b]-min_MolWt)/(max_MolMR-min_MolWt)
        elif b==3:
            MACCS_list[n][b] = (MACCS_list[n][b]-min_NumRotatableBonds)/(max_MolMR-min_NumRotatableBonds)
        elif b==4:
            MACCS_list[n][b] = (MACCS_list[n][b]-min_NumAliphaticRings)/(max_MolMR-min_NumAliphaticRings)
        elif b==6:
            MACCS_list[n][b] = (MACCS_list[n][b]-min_NumAromaticRings)/(max_MolMR-min_NumAromaticRings)
        elif b==8:
            MACCS_list[n][b] = (MACCS_list[n][b]-min_NumSaturatedRings)/(max_NumSaturatedRings-min_NumSaturatedRings)

np.save(dir_input + 'molecules', Molecules)
np.save(dir_input + 'adjacencies', Adjacencies)
np.save(dir_input + 'properties', Properties)
np.save(dir_input + 'maccs', np.asarray(MACCS_list))

dump_dictionary(fingerprint_dict, dir_input + 'fingerprint_dict.pickle')

print('The preprocess has finished!')

Total number of molecules : 6669
1/6669
2/6669
3/6669
4/6669
5/6669
6/6669
7/6669
8/6669
9/6669
10/6669
11/6669
12/6669
13/6669
14/6669
15/6669
16/6669
17/6669
18/6669
19/6669
20/6669
21/6669
22/6669
23/6669
24/6669
25/6669
26/6669
27/6669
28/6669
29/6669
30/6669
31/6669
32/6669
33/6669
34/6669
35/6669
36/6669
37/6669
38/6669
39/6669
40/6669
41/6669
42/6669
43/6669
44/6669
45/6669
46/6669
47/6669
48/6669
49/6669
50/6669
51/6669
52/6669
53/6669
54/6669
55/6669
56/6669
57/6669
58/6669
59/6669
60/6669
61/6669
62/6669
63/6669
64/6669
65/6669
66/6669
67/6669
68/6669
69/6669
70/6669
71/6669
72/6669
73/6669
74/6669
75/6669
76/6669
77/6669
78/6669
79/6669
80/6669
81/6669
82/6669
83/6669
84/6669
85/6669
86/6669
87/6669
88/6669
89/6669
90/6669
91/6669
92/6669
93/6669
94/6669
95/6669
96/6669
97/6669
98/6669
99/6669
100/6669
101/6669
102/6669
103/6669
104/6669
105/6669
106/6669
107/6669
108/6669
109/6669
110/6669
111/6669
112/6669
113/6669
114/6669
115/6669
116/6669
117/6669
118/6669
119/6669
120/

932/6669
933/6669
934/6669
935/6669
936/6669
937/6669
938/6669
939/6669
940/6669
941/6669
942/6669
943/6669
944/6669
945/6669
946/6669
947/6669
948/6669
949/6669
950/6669
951/6669
952/6669
953/6669
954/6669
955/6669
956/6669
957/6669
958/6669
959/6669
960/6669
961/6669
962/6669
963/6669
964/6669
965/6669
966/6669
967/6669
968/6669
969/6669
970/6669
971/6669
972/6669
973/6669
974/6669
975/6669
976/6669
977/6669
978/6669
979/6669
980/6669
981/6669
982/6669
983/6669
984/6669
985/6669
986/6669
987/6669
988/6669
989/6669
990/6669
991/6669
992/6669
993/6669
994/6669
995/6669
996/6669
997/6669
998/6669
999/6669
1000/6669
1001/6669
1002/6669
1003/6669
1004/6669
1005/6669
1006/6669
1007/6669
1008/6669
1009/6669
1010/6669
1011/6669
1012/6669
1013/6669
1014/6669
1015/6669
1016/6669
1017/6669
1018/6669
1019/6669
1020/6669
1021/6669
1022/6669
1023/6669
1024/6669
1025/6669
1026/6669
1027/6669
1028/6669
1029/6669
1030/6669
1031/6669
1032/6669
1033/6669
1034/6669
1035/6669
1036/6669
1037/6669
1038/666

1769/6669
1770/6669
1771/6669
1772/6669
1773/6669
1774/6669
1775/6669
1776/6669
1777/6669
1778/6669
1779/6669
1780/6669
1781/6669
1782/6669
1783/6669
1784/6669
1785/6669
1786/6669
1787/6669
1788/6669
1789/6669
1790/6669
1791/6669
1792/6669
1793/6669
1794/6669
1795/6669
1796/6669
1797/6669
1798/6669
1799/6669
1800/6669
1801/6669
1802/6669
1803/6669
1804/6669
1805/6669
1806/6669
1807/6669
1808/6669
1809/6669
1810/6669
1811/6669
1812/6669
1813/6669
1814/6669
1815/6669
1816/6669
1817/6669
1818/6669
1819/6669
1820/6669
1821/6669
1822/6669
1823/6669
1824/6669
1825/6669
1826/6669
1827/6669
1828/6669
1829/6669
1830/6669
1831/6669
1832/6669
1833/6669
1834/6669
1835/6669
1836/6669
1837/6669
1838/6669
1839/6669
1840/6669
1841/6669
1842/6669
1843/6669
1844/6669
1845/6669
1846/6669
1847/6669
1848/6669
1849/6669
1850/6669
1851/6669
1852/6669
1853/6669
1854/6669
1855/6669
1856/6669
1857/6669
1858/6669
1859/6669
1860/6669
1861/6669
1862/6669
1863/6669
1864/6669
1865/6669
1866/6669
1867/6669
1868/6669


2596/6669
2597/6669
2598/6669
2599/6669
2600/6669
2601/6669
2602/6669
2603/6669
2604/6669
2605/6669
2606/6669
2607/6669
2608/6669
2609/6669
2610/6669
2611/6669
2612/6669
2613/6669
2614/6669
2615/6669
2616/6669
2617/6669
2618/6669
2619/6669
2620/6669
2621/6669
2622/6669
2623/6669
2624/6669
2625/6669
2626/6669
2627/6669
2628/6669
2629/6669
2630/6669
2631/6669
2632/6669
2633/6669
2634/6669
2635/6669
2636/6669
2637/6669
2638/6669
2639/6669
2640/6669
2641/6669
2642/6669
2643/6669
2644/6669
2645/6669
2646/6669
2647/6669
2648/6669
2649/6669
2650/6669
2651/6669
2652/6669
2653/6669
2654/6669
2655/6669
2656/6669
2657/6669
2658/6669
2659/6669
2660/6669
2661/6669
2662/6669
2663/6669
2664/6669
2665/6669
2666/6669
2667/6669
2668/6669
2669/6669
2670/6669
2671/6669
2672/6669
2673/6669
2674/6669
2675/6669
2676/6669
2677/6669
2678/6669
2679/6669
2680/6669
2681/6669
2682/6669
2683/6669
2684/6669
2685/6669
2686/6669
2687/6669
2688/6669
2689/6669
2690/6669
2691/6669
2692/6669
2693/6669
2694/6669
2695/6669


3425/6669
3426/6669
3427/6669
3428/6669
3429/6669
3430/6669
3431/6669
3432/6669
3433/6669
3434/6669
3435/6669
3436/6669
3437/6669
3438/6669
3439/6669
3440/6669
3441/6669
3442/6669
3443/6669
3444/6669
3445/6669
3446/6669
3447/6669
3448/6669
3449/6669
3450/6669
3451/6669
3452/6669
3453/6669
3454/6669
3455/6669
3456/6669
3457/6669
3458/6669
3459/6669
3460/6669
3461/6669
3462/6669
3463/6669
3464/6669
3465/6669
3466/6669
3467/6669
3468/6669
3469/6669
3470/6669
3471/6669
3472/6669
3473/6669
3474/6669
3475/6669
3476/6669
3477/6669
3478/6669
3479/6669
3480/6669
3481/6669
3482/6669
3483/6669
3484/6669
3485/6669
3486/6669
3487/6669
3488/6669
3489/6669
3490/6669
3491/6669
3492/6669
3493/6669
3494/6669
3495/6669
3496/6669
3497/6669
3498/6669
3499/6669
3500/6669
3501/6669
3502/6669
3503/6669
3504/6669
3505/6669
3506/6669
3507/6669
3508/6669
3509/6669
3510/6669
3511/6669
3512/6669
3513/6669
3514/6669
3515/6669
3516/6669
3517/6669
3518/6669
3519/6669
3520/6669
3521/6669
3522/6669
3523/6669
3524/6669


4256/6669
4257/6669
4258/6669
4259/6669
4260/6669
4261/6669
4262/6669
4263/6669
4264/6669
4265/6669
4266/6669
4267/6669
4268/6669
4269/6669
4270/6669
4271/6669
4272/6669
4273/6669
4274/6669
4275/6669
4276/6669
4277/6669
4278/6669
4279/6669
4280/6669
4281/6669
4282/6669
4283/6669
4284/6669
4285/6669
4286/6669
4287/6669
4288/6669
4289/6669
4290/6669
4291/6669
4292/6669
4293/6669
4294/6669
4295/6669
4296/6669
4297/6669
4298/6669
4299/6669
4300/6669
4301/6669
4302/6669
4303/6669
4304/6669
4305/6669
4306/6669
4307/6669
4308/6669
4309/6669
4310/6669
4311/6669
4312/6669
4313/6669
4314/6669
4315/6669
4316/6669
4317/6669
4318/6669
4319/6669
4320/6669
4321/6669
4322/6669
4323/6669
4324/6669
4325/6669
4326/6669
4327/6669
4328/6669
4329/6669
4330/6669
4331/6669
4332/6669
4333/6669
4334/6669
4335/6669
4336/6669
4337/6669
4338/6669
4339/6669
4340/6669
4341/6669
4342/6669
4343/6669
4344/6669
4345/6669
4346/6669
4347/6669
4348/6669
4349/6669
4350/6669
4351/6669
4352/6669
4353/6669
4354/6669
4355/6669


5095/6669
5096/6669
5097/6669
5098/6669
5099/6669
5100/6669
5101/6669
5102/6669
5103/6669
5104/6669
5105/6669
5106/6669
5107/6669
5108/6669
5109/6669
5110/6669
5111/6669
5112/6669
5113/6669
5114/6669
5115/6669
5116/6669
5117/6669
5118/6669
5119/6669
5120/6669
5121/6669
5122/6669
5123/6669
5124/6669
5125/6669
5126/6669
5127/6669
5128/6669
5129/6669
5130/6669
5131/6669
5132/6669
5133/6669
5134/6669
5135/6669
5136/6669
5137/6669
5138/6669
5139/6669
5140/6669
5141/6669
5142/6669
5143/6669
5144/6669
5145/6669
5146/6669
5147/6669
5148/6669
5149/6669
5150/6669
5151/6669
5152/6669
5153/6669
5154/6669
5155/6669
5156/6669
5157/6669
5158/6669
5159/6669
5160/6669
5161/6669
5162/6669
5163/6669
5164/6669
5165/6669
5166/6669
5167/6669
5168/6669
5169/6669
5170/6669
5171/6669
5172/6669
5173/6669
5174/6669
5175/6669
5176/6669
5177/6669
5178/6669
5179/6669
5180/6669
5181/6669
5182/6669
5183/6669
5184/6669
5185/6669
5186/6669
5187/6669
5188/6669
5189/6669
5190/6669
5191/6669
5192/6669
5193/6669
5194/6669


5933/6669
5934/6669
5935/6669
5936/6669
5937/6669
5938/6669
5939/6669
5940/6669
5941/6669
5942/6669
5943/6669
5944/6669
5945/6669
5946/6669
5947/6669
5948/6669
5949/6669
5950/6669
5951/6669
5952/6669
5953/6669
5954/6669
5955/6669
5956/6669
5957/6669
5958/6669
5959/6669
5960/6669
5961/6669
5962/6669
5963/6669
5964/6669
5965/6669
5966/6669
5967/6669
5968/6669
5969/6669
5970/6669
5971/6669
5972/6669
5973/6669
5974/6669
5975/6669
5976/6669
5977/6669
5978/6669
5979/6669
5980/6669
5981/6669
5982/6669
5983/6669
5984/6669
5985/6669
5986/6669
5987/6669
5988/6669
5989/6669
5990/6669
5991/6669
5992/6669
5993/6669
5994/6669
5995/6669
5996/6669
5997/6669
5998/6669
5999/6669
6000/6669
6001/6669
6002/6669
6003/6669
6004/6669
6005/6669
6006/6669
6007/6669
6008/6669
6009/6669
6010/6669
6011/6669
6012/6669
6013/6669
6014/6669
6015/6669
6016/6669
6017/6669
6018/6669
6019/6669
6020/6669
6021/6669
6022/6669
6023/6669
6024/6669
6025/6669
6026/6669
6027/6669
6028/6669
6029/6669
6030/6669
6031/6669
6032/6669


### Data preparation

In [5]:
dir_input = ('pathway/input'+str(radius)+'/')

molecules  = load_tensor(dir_input + 'molecules', torch.FloatTensor)
properties = load_numpy(dir_input + 'properties')
maccs      = load_numpy(dir_input + 'maccs')


with open(dir_input + 'fingerprint_dict.pickle', 'rb') as f:
    fingerprint_dict = pickle.load(f)
    
fingerprint_dict = load_pickle(dir_input + 'fingerprint_dict.pickle')
unknown          = 100
n_fingerprint    = len(fingerprint_dict) + unknown

my_maccs = []
for i in range(len(molecules)):
    target_mol = (n_fingerprint-1)*torch.ones([259], dtype=torch.float, device=device)
    target_mol[:molecules[i].size()[0]] = molecules[i]
    my_maccs.append(np.concatenate((target_mol.cpu().data.numpy(),maccs[i]), axis=0))

dataset = list(zip(properties, my_maccs))
dataset = shuffle_dataset(dataset, 4123)
dataset_train, dataset_   = split_dataset(dataset, 0.8)
dataset_dev, dataset_test = split_dataset(dataset_, 0.5)


data_batch = list(zip(*dataset_train))
properties_train, maccs_train = data_batch[-2], data_batch[-1]

data_batch = list(zip(*dataset_dev))
properties_dev, maccs_dev = data_batch[-2], data_batch[-1]

data_batch = list(zip(*dataset_test))
properties_test, maccs_test = data_batch[-2], data_batch[-1]

train_len, dev_len, test_len = len(dataset_train), len(dataset_dev), len(dataset_test)

feature_len = maccs_train[0].shape[0]

X_train, X_dev, X_test = np.zeros((train_len,feature_len)), np.zeros((dev_len,feature_len)), np.zeros((test_len,feature_len))
Y_train, Y_dev, Y_test = np.zeros((train_len,11)), np.zeros((dev_len,11)), np.zeros((test_len,11))

for i in range(train_len):
    X_train[i,:] = maccs_train[i]
    Y_train[i] = properties_train[i][0]
    
for i in range(dev_len):
    X_dev[i,:]   = maccs_dev[i]
    Y_dev[i]   = properties_dev[i][0]
    
for i in range(test_len):
    X_test[i,:]  = maccs_test[i]
    Y_test[i]  = properties_test[i][0]

### Train and analyze classifier

In [6]:
clf = RandomForestClassifier(n_estimators=300, criterion = 'gini', max_depth=60, random_state=0)
multi_target_forest = MultiOutputClassifier(clf, n_jobs=-1)
multi_target_forest.fit(X_train, Y_train)

MultiOutputClassifier(estimator=RandomForestClassifier(bootstrap=True,
                                                       class_weight=None,
                                                       criterion='gini',
                                                       max_depth=60,
                                                       max_features='auto',
                                                       max_leaf_nodes=None,
                                                       min_impurity_decrease=0.0,
                                                       min_impurity_split=None,
                                                       min_samples_leaf=1,
                                                       min_samples_split=2,
                                                       min_weight_fraction_leaf=0.0,
                                                       n_estimators=300,
                                                       n_jobs=None,
                      

### Test set prediction accuracy

In [7]:
Y_pred = multi_target_forest.predict(X_test)

acc_score, prec_score, rec_score = 0., 0., 0.
for i in range(Y_test.shape[0]):
    acc_score  += accuracy_score(Y_test[i],Y_pred[i])
    prec_score += precision_score(Y_test[i],Y_pred[i])
    rec_score  += recall_score(Y_test[i],Y_pred[i])

acc_score  = acc_score/Y_test.shape[0]
prec_score = prec_score/Y_test.shape[0]
rec_score  = rec_score/Y_test.shape[0]

print('Accuracy : %.4f%%, \t Precision : %.4f%%, \t, Recall : %.4f%%' %(acc_score, prec_score, rec_score))

  'precision', 'predicted', average, warn_for)


Accuracy : 0.9757%, 	 Precision : 0.8342%, 	, Recall : 0.8353%


### Dev set prediction accuracy

In [8]:
Y_pred = multi_target_forest.predict(X_dev)

acc_score, prec_score, rec_score = 0., 0., 0.
for i in range(Y_dev.shape[0]):
    acc_score  += accuracy_score(Y_dev[i],Y_pred[i])
    prec_score += precision_score(Y_dev[i],Y_pred[i])
    rec_score  += recall_score(Y_dev[i],Y_pred[i])

acc_score  = acc_score/Y_dev.shape[0]
prec_score = prec_score/Y_dev.shape[0]
rec_score  = rec_score/Y_dev.shape[0]

print('Accuracy : %.4f%%, \t Precision : %.4f%%, \t, Recall : %.4f%%' %(acc_score, prec_score, rec_score))

  'precision', 'predicted', average, warn_for)


Accuracy : 0.9790%, 	 Precision : 0.8478%, 	, Recall : 0.8445%
