# Library

In [None]:
# My library
from molgraph.dataset import *
from molgraph.graphmodel import *
from molgraph.training import *
from molgraph.testing import *
from molgraph.visualize import *
from molgraph.interpret import *
from molgraph.experiment import *
# General library
import argparse
import numpy as np
import os
# pytorch
import torch
import pytorch_lightning as pl

# Reproducibility on GPU: make cuDNN select deterministic convolution algorithms
# and disable the cuDNN auto-tuner (its algorithm choice can vary between runs).
# The previous spelling `determinstic` silently created a new attribute and never
# enabled deterministic mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Arguments

In [None]:
# Parse a fixed, in-notebook command line instead of the real sys.argv.
parser = ArgumentParser()
cli_tokens = [
    '--file', 'bbbp',
    '--model', 'GIN',
    '--schema', 'AR_0',
    '--reduced', 'pharmacophore',
    '--vocab_len', '100',
    '--mol_embedding', '256',
    '--batch_normalize',
    '--fold', '5',
]
args = parser.getArgument(cli_tokens)

args

# Dataset

In [None]:
# Dataset configuration taken from the parsed arguments.
file = args.file
smiles = args.smiles
task = args.task
splitting = args.splitting
splitting_fold = args.fold
splitting_seed = args.splitting_seed

# Validated raw dataset.
datasets = getDataset(file, smiles, task, splitting)

# Classification runs also need the positive-class weight.
if args.graphtask == 'classification':
    args.pos_weight = getPosWeight(datasets)
    print('pos_weight:', args.pos_weight)

# Fold assignment and the full (atom-level) graph dataset.
datasets_splitted = generateDatasetSplitting(file, splitting, splitting_fold, splitting_seed)
datasets_graph = generateGraphDataset(file)

# One reduced-graph dictionary per requested reduction type.
dict_reducedgraph = {}
for g in args.reduced:
    if g != 'substructure':
        dict_reducedgraph[g] = generateReducedGraphDict(file, g)
        continue
    # 'substructure' needs a vocabulary file per fold; build any that are missing.
    for i in range(splitting_fold):
        vocab_file = f'{file}_{i}'
        if not os.path.exists(f'vocab/{vocab_file}.txt'):
            generateVocabTrain(file, splitting_seed, splitting_fold, vocab_len=args.vocab_len)
        # NOTE(review): this entry is overwritten on every fold, so only the last
        # fold's dictionary survives the loop -- confirm that is intended.
        dict_reducedgraph[g] = generateReducedGraphDict(file, g, vocab_file=vocab_file)

# Test

In [None]:
# Identify the trained run to load: <file>/<model>_<schema>_<reductions>/<timestamp>.
# NOTE(review): the run timestamp is hard-coded; change it to evaluate another run.
ts = "2023-Apr-29-17:22:22"
reduced_list = '_'.join(args.reduced)
args_test = {
    'log_folder_name': os.path.join(args.file, f'{args.model}_{args.schema}_{reduced_list}', ts),
    'exp_name': args.experiment_number,
    'fold_number': 1,
    'seed': args.seed,
}

In [None]:
# Copy the hyper-parameters saved with the trained run onto `args`, so the Tester
# created later uses the training-time settings.
# NOTE(review): torch.load unpickles this file -- only point it at trusted runs. On
# torch >= 2.6 `weights_only` defaults to True, which may reject a pickled args
# object; confirm the installed torch version.
training_bin = torch.load(f"./dataset/{args_test['log_folder_name']}/checkpoints/training_args.bin")

_restored_hparams = (
    'batch_size',
    'num_layers',
    'num_layers_reduced',
    'in_channels',
    'hidden_channels',
    'out_channels',
    'edge_dim',
    'num_layers_self',
    'num_layers_self_reduced',
    'dropout',
    'lr',
    'weight_decay',
)
for _hparam in _restored_hparams:
    setattr(args, _hparam, getattr(training_bin, _hparam))

args

In [None]:
# # test with dataset
# # test_loader, datasets_test =  generateDataLoaderTesting(datasets_graph[1], args.batch_size)
# # test_loader, datasets_test =  generateDataLoaderTesting([datasets_graph[1][360]], 1)
# test_loader, datasets_test =  generateDataLoaderTesting(args.file, 1)
# sample_to_test = datasets_test[0]
# test_loader, datasets_test =  generateDataLoaderListing([sample_to_test], 1)

# molecule_test = datasets_test[0]
# smiles_processes = molecule_test.smiles
# print(molecule_test)


In [None]:
# test with t = sample
smiles_processes = mol_to_smiles(smiles_to_mol('COC(N)=O', with_atom_index=False))
molecule_test = [constructGraph(smiles_processes, 0.74)]
test_loader = DataLoader(molecule_test, batch_size=1, shuffle=True, follow_batch=['x_g', 'x_r'])
molecule_test = molecule_test[0]
print(molecule_test)

In [None]:
# Build the tester for the selected run and score the single sample,
# keeping the attention weights for the inspection cells below.
tester = Tester(args, args_test)
tester.test_single(test_loader, return_attention_weights=True)

In [None]:
# Split the per-molecule attention into (atom-graph, first-reduced-graph) parts;
# a part is None when that attention is not available.
att_mol = tester.getAttentionMol()
if 'atom' not in att_mol:
    # No atom-graph attention: use the first reduced graph's attention only.
    sample_att = (None, att_mol[args.reduced[0]])
elif len(args.reduced) >= 1:
    sample_att = (att_mol['atom'], att_mol[args.reduced[0]])
else:
    sample_att = (att_mol['atom'], None)
sample_graph = molecule_test

In [None]:
# Plot the (atom-graph, reduced-graph) attention pair for the sample molecule.
plot_attentions(args, sample_graph, sample_att)

In [None]:
# Show per-atom importance for the sample molecule, for every model schema.
smiles = smiles_processes
mol = smiles_to_mol(smiles, with_atom_index=False)

# Reduced graph; its cliques are used to map reduced-node importance back to atoms.
if args.schema == 'A':
    reduced_graph, cliques, edges = getReducedGraph(args, ['atom'], smiles, normalize=False)
else:
    reduced_graph, cliques, edges = getReducedGraph(args, args.reduced, smiles, normalize=False)

sample_att_g, sample_att_r = sample_att
if args.schema in ['A', 'R_N', 'AR', 'AR_0', 'AR_N']:
    mask_graph_g = mask_graph(sample_att_g)
if args.schema in ['R', 'R_0', 'R_N', 'AR', 'AR_0', 'AR_N']:
    mask_graph_r = mask_reduced(sample_att_r)

if args.schema == 'A':
    # Atom-only model: the atom-graph mask is used directly.
    mask_graph_x = mask_graph_g
else:
    # Project reduced-node importance onto the atoms of each clique.
    mask_graph_x = mask_rtog(smiles, cliques, mask_graph_r)
    if args.schema in ['AR', 'AR_0', 'AR_N']:
        # Combined models: merge atom-graph and projected reduced-graph importance.
        mask_graph_x = mask_gandr(mask_graph_g, mask_graph_x)

# Display for every schema, not only the combined AR ones.
display_interpret_weight(mol, None, None, mask_graph_x, None, scale=True)

## Running

In [None]:
# The whole graph dataset is scored by the loop in the next cell.
all_dataset = datasets_graph
print('Number of dataset:', len(all_dataset))

In [None]:
# Score every molecule with the trained model and collect, keyed by SMILES:
#   feature_embedding[smiles][graph]  -> slice of emb_mol[0] for that graph
#   feature_importance[smiles][graph] -> per-atom importance (mask_graph_x['atom'])
feature_embedding = dict()
feature_importance = dict()

count = 0  # molecules for which a prediction was obtained
failed_keys = []  # dataset keys whose prediction raised an exception
tester = Tester(args, args_test)

# Width of one graph's slot in emb_mol[0], as used by the slicing below.
# NOTE(review): presumably equals args.mol_embedding (256 here) -- confirm.
emb_width = 256

for d in tqdm(all_dataset):
    # Single-molecule data loader.
    test_loader, datasets_test = generateDataLoaderListing([all_dataset[d]], 1)
    molecule_test = datasets_test[0]

    # Best-effort: a molecule the model cannot process is skipped and recorded.
    # `except Exception` rather than a bare except, so KeyboardInterrupt and
    # SystemExit still stop this long loop.
    try:
        predicted = tester.test_single(test_loader, return_attention_weights=True, print_result=False)
    except Exception:
        predicted = None
        failed_keys.append(d)

    if predicted is None:
        continue

    # Reduce the prediction to a scalar: `.item()` when available, otherwise
    # the first element of a nested sequence.
    try:
        predicted = predicted.item()
    except Exception:
        predicted = predicted[0][0]

    count += 1

    # Embedding and attention from the last test_single call.
    emb_mol = tester.getXEmbed()
    att_mol = tester.getAttentionMol()
    sample_att = att_mol
    sample_graph = molecule_test
    sample_att_g = sample_att['atom'] if 'atom' in sample_att else None
    sample_att_r = sample_att[args.reduced[0]] if len(args.reduced) != 0 else None
    if args.schema in ['A', 'R_N', 'AR', 'AR_0', 'AR_N']:
        mask_graph_g = mask_graph(sample_att_g)
    if args.schema in ['R', 'R_0', 'R_N', 'AR', 'AR_0', 'AR_N']:
        mask_graph_r = mask_reduced(sample_att_r)

    # Molecule
    smiles = sample_graph.smiles
    mol = Chem.MolFromSmiles(smiles)

    if smiles not in feature_importance:
        feature_embedding[smiles] = dict()
        feature_importance[smiles] = dict()

    mask_graph_x = None

    # Schemas containing 'A' (A, AR, AR_0, AR_N) have an atom-graph branch.
    if 'A' in args.schema:
        mask_graph_x = mask_graph_g
        reduced_graph, cliques, edges = getReducedGraph(args, ['atom'], smiles, normalize=False)
        # NOTE(review): [:512] overlaps the first reduced slot [256:512] assigned
        # below when emb_width is 256 -- confirm the atom embedding width.
        feature_embedding[smiles]['atom'] = emb_mol[0][:512]
        feature_importance[smiles]['atom'] = mask_graph_x['atom']

    # Schemas containing 'R' (R, R_0, R_N, AR, AR_0, AR_N) have reduced-graph branches.
    if 'R' in args.schema:
        reduced_graph, cliques, edges = getReducedGraph(args, args.reduced, smiles, normalize=False)
        # Project reduced-node importance back onto atoms via the cliques
        # (a schema containing 'R' is never 'A', so this always applies here).
        mask_graph_x = mask_rtog(smiles, cliques, mask_graph_r)

        # With an atom branch, the atom embedding occupies the first slot.
        first_slot = 1 if 'A' in args.schema else 0
        for i, r in enumerate(args.reduced):
            slot = first_slot + i
            feature_embedding[smiles][r] = emb_mol[0][emb_width * slot:emb_width * (slot + 1)]
            # NOTE(review): every reduced graph receives the same mask, which is
            # derived from args.reduced[0]'s attention -- confirm this is intended.
            feature_importance[smiles][r] = mask_graph_x['atom']

print('Molecules scored:', count, '| failed:', len(failed_keys))

# feature_embedding

In [None]:
# Imported explicitly: the file never imported pickle and relied on a star-import
# happening to re-export it.
import pickle

# Save the per-molecule embeddings alongside the evaluated run's folder.
path = f"./dataset/{args_test['log_folder_name']}/embedding{args_test['fold_number']}.pickle"
with open(path, 'wb') as handle:
    pickle.dump(feature_embedding, handle, protocol=pickle.HIGHEST_PROTOCOL)

# feature_importance

In [None]:
# Imported explicitly: the file never imported pickle and relied on a star-import
# happening to re-export it.
import pickle

# Save the per-molecule atom-importance masks alongside the evaluated run's folder.
path = f"./dataset/{args_test['log_folder_name']}/attention{args_test['fold_number']}.pickle"
with open(path, 'wb') as handle:
    pickle.dump(feature_importance, handle, protocol=pickle.HIGHEST_PROTOCOL)