In [1]:
import numpy as np
import pickle
from sklearn.metrics import roc_auc_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
import timeit
import sys

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
class GraphNeuralNetwork(nn.Module):
    def __init__(self, dim, n_fingerprint, hidden_layer, output_layer, update_func, output_func):
        super(GraphNeuralNetwork, self).__init__()
        self.embed_fingerprint = nn.Embedding(n_fingerprint, dim)
        self.W_fingerprint = nn.ModuleList([nn.Linear(dim, dim)] * hidden_layer)
        self.W_output = nn.ModuleList([nn.Linear(dim, dim)] * output_layer)
        self.W_property = nn.Linear(dim, 2)
        self.update_func = update_func
        self.output_func = output_func

    def pad(self, matrices, pad_value):
        '''Pad adjacency matrices for batch processing.'''
        sizes = [m.shape[0] for m in matrices]
        M = sum(sizes)
        pad_matrices = pad_value + np.zeros((M, M))
        i = 0
        for j, m in enumerate(matrices):
            j = sizes[j]
            pad_matrices[i:i+j, i:i+j] = m
            i += j
        return torch.FloatTensor(pad_matrices)

    def sum_axis(self, xs, axis):
        y = [torch.sum(x, 0) for x in torch.split(xs, axis)]
        return torch.stack(y)

    def mean_axis(self, xs, axis):
        y = [torch.mean(x, 0) for x in torch.split(xs, axis)]
        return torch.stack(y)

    def update(self, xs, A, M, i):
        '''Update the node vectors in a graph considering their neighboring node vectors (i.e., sum or mean),
        which are non-linear transformed by neural network.'''
        hs = torch.relu(self.W_fingerprint[i](xs))
        if self.update_func == 'sum':
            return xs + torch.matmul(A, hs)
        if self.update_func == 'mean':
            return xs + torch.matmul(A, hs) / (M-1)

    def forward(self, inputs, device):
        Smiles, fingerprints, adjacencies = inputs
        axis = [len(f) for f in fingerprints]

        M = np.concatenate([np.repeat(len(f), len(f)) for f in fingerprints])
        M = torch.unsqueeze(torch.FloatTensor(M), 1)

        fingerprints = torch.cat(fingerprints)
        fingerprint_vectors = self.embed_fingerprint(fingerprints)

        adjacencies = self.pad(adjacencies, 0).to(device)

        # GNN updates the fingerprint vectors.
        for i in range(len(self.W_fingerprint)):
            fingerprint_vectors = self.update(fingerprint_vectors, adjacencies, M, i)

        if self.output_func == 'sum':
            molecular_vectors = self.sum_axis(fingerprint_vectors, axis)
        if self.output_func == 'mean':
            molecular_vectors = self.mean_axis(fingerprint_vectors, axis)

        for j in range(len(self.W_output)):
            molecular_vectors = torch.relu(self.W_output[j](molecular_vectors))

        molecular_properties = self.W_property(molecular_vectors)

        return Smiles, molecular_properties

In [4]:
def train(dataset, model, optimizer, batch, device):
    model.train()
    loss_total = 0
    for i in range(0, len(dataset), batch):
        data_batch = list(zip(*dataset[i:i+batch]))       
        inputs = data_batch[:-1]
        correct_properties = torch.cat(data_batch[-1])
        Smiles, predicted_properties = model.forward(inputs, device)
        loss = F.cross_entropy(predicted_properties, correct_properties)    
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_total += loss.to('cpu').data.numpy()
    return loss_total

In [5]:
def test(dataset, model, batch, device):
    model.eval()
    SMILES, Ts, Ys, Ss = '', [], [], []

    for i in range(0, len(dataset), batch):
        data_batch = list(zip(*dataset[i:i+batch]))
        inputs = data_batch[:-1]
        correct_properties = torch.cat(data_batch[-1])
        Smiles, predicted_properties = model.forward(inputs, device)

        correct_labels = correct_properties.to('cpu').data.numpy()
        ys = F.softmax(predicted_properties, 1).to('cpu').data.numpy()
        predicted_labels = [np.argmax(y) for y in ys]
        predicted_scores = [x[1] for x in ys]
        
        SMILES += ' '.join(Smiles) + ' '
        Ts.append(correct_labels)
        Ys.append(predicted_labels)
        Ss.append(predicted_scores)

    SMILES = SMILES.strip().split()
    T = np.concatenate(Ts)
    Y = np.concatenate(Ys)
    S = np.concatenate(Ss)

    AUC = roc_auc_score(T, S)
    precision = precision_score(T, Y)
    recall = recall_score(T, Y)

    T, Y, S = map(str, T), map(str, Y), map(str, S)
    predictions = '\n'.join(['\t'.join(p) for p in zip(SMILES, T, Y, S)])

    return AUC, precision, recall, predictions

In [6]:
def load_tensor(filename, dtype, device):
    return [dtype(d).to(device) for d in np.load(filename + '.npy', allow_pickle=True)]

In [9]:
def main():
    '''Hyperparameters.'''
    DATASET = 'HIV'
    #DATASET = yourdata

    #radius = 1
    radius = 2
    #radius = 3

    update_func = 'sum'
    #update_func = 'mean'

    #output_func = 'sum'
    output_func = 'mean'

    dim = 25
    hidden_layer = 6
    output_layer = 3
    batch = 32
    lr = 1e-3
    lr_decay = 0.9
    decay_interval = 10
    weight_decay = 1e-6
    
    iteration = 30
    
    setting = 'default'

    # CPU or GPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print('The code uses GPU...')
    else:
        device = torch.device('cpu')
        print('The code uses CPU...')

    # Load preprocessed data.
    dir_input = '../../dataset/classification/%s/input/radius%d/' % (DATASET, radius)
    with open(dir_input + 'Smiles.txt') as f:
        Smiles = f.read().strip().split()
    molecules = load_tensor(dir_input + 'molecules', torch.LongTensor, device)
    adjacencies = np.load(dir_input + 'adjacencies' + '.npy', allow_pickle=True)
    properties = load_tensor(dir_input + 'properties', torch.LongTensor, device)
    with open(dir_input + 'fingerprint_dict.pkl', 'rb') as f:
        fingerprint_dict = pickle.load(f)
    n_fingerprint = len(fingerprint_dict)

    # Create a dataset and split it into train/test.
    dataset = list(zip(Smiles, molecules, adjacencies, properties))
    np.random.shuffle(dataset)
    dataset_train, dataset_test = train_test_split(dataset, train_size=0.8, test_size=0.2)
    print(len(dataset), len(dataset_train), len(dataset_test))

    # Set a model.
    model = GraphNeuralNetwork(dim, n_fingerprint, hidden_layer, output_layer, update_func, output_func)
    model = model.to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    # Output files.
    file_AUCs = '../../output/result/AUCs--%s.txt' % setting
    file_predictions = '../../output/result/predictions--%s.txt' % setting
    file_model = '../../output/model/%s.pth' % setting
    columns = ['Epoch', 'Time(sec)', 'Loss_train', 'AUC_test', 'Precision_test', 'Recall_test']
    AUCs = '\t'.join(columns)

    # Start training.
    print('Training...')
    print(AUCs)
    start = timeit.default_timer()

    for epoch in range(1, iteration):
        if epoch % decay_interval == 0:
            optimizer.param_groups[0]['lr'] *= lr_decay

        loss_train = train(dataset_train, model, optimizer, batch, device)
        AUC_test, precision_test, recall_test, predictions_test = test(dataset_test, model, batch, device)

        time = timeit.default_timer() - start

        values = [time, loss_train, AUC_test, precision_test, recall_test]
        AUCs = str(epoch) + '\t' + '\t'.join(map(lambda x: '%.3f' % x, values))
        print(AUCs)

    with open(file_predictions, 'w') as out:
        out.write('\t'.join(['Smiles', 'Correct', 'Predict', 'Score']) + '\n')
        out.write(predictions_test + '\n')
    torch.save(model.state_dict(), file_model)

In [10]:
if __name__ == '__main__':
    np.random.seed(123)
    torch.manual_seed(123)
    main()

The code uses GPU...
38775 31020 7755
Training...
Epoch	Time(sec)	Loss_train	AUC_test	Precision_test	Recall_test


  'precision', 'predicted', average, warn_for)


1	87.231	138.032	0.695	0.000	0.000
2	196.894	112.589	0.747	0.800	0.017
3	306.881	90.401	0.767	0.756	0.130
4	416.464	75.353	0.771	0.787	0.155
5	526.385	64.130	0.774	0.662	0.197
6	635.701	55.514	0.775	0.622	0.192
7	750.605	48.403	0.773	0.512	0.264
8	868.812	42.889	0.772	0.472	0.280
9	988.080	38.130	0.771	0.434	0.301
10	1106.810	34.559	0.768	0.463	0.289
11	1225.976	32.697	0.769	0.412	0.314
12	1311.126	30.445	0.772	0.373	0.339
13	1371.624	29.736	0.771	0.351	0.301
14	1432.867	27.174	0.772	0.344	0.326
15	1493.543	26.141	0.773	0.362	0.314
16	1553.911	24.163	0.775	0.446	0.259
17	1621.277	23.137	0.769	0.397	0.289
18	1689.844	23.509	0.770	0.477	0.218
19	1750.227	22.008	0.768	0.404	0.230
20	1810.595	21.805	0.769	0.468	0.272
21	1870.876	20.075	0.768	0.423	0.243
22	1931.487	19.065	0.766	0.446	0.259
23	1992.425	20.041	0.770	0.396	0.272
24	2052.943	18.371	0.769	0.438	0.264
25	2113.278	17.863	0.772	0.383	0.268
26	2173.751	17.383	0.768	0.424	0.280
27	2234.771	17.504	0.765	0.397	0.305
28	2295.346	17.711