In [4]:
import os
import argparse
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from rdkit import Chem
from tqdm import tqdm

from model import Model
from utils import set_seed
from torch.utils.data import DataLoader
from preprocessing import load_dataset, dataset_split
from utils import collate
from model_interpret import conformer_attn

## Step 1: Load dataset and pretrained model

In [275]:
# Run configuration for the interpretation notebook.
# An empty argv is parsed so that `args` is a plain argparse.Namespace whose
# attributes are filled in by hand below; `args` is then handed to the project
# helpers (load_dataset here, conformer_attn later in the notebook).
parser = argparse.ArgumentParser()
args = parser.parse_args(args=[])

# Task / data selection.
args.task = 'cdk2'
args.task_type = 'regression'
args.split = 'random'
args.data_path = '../data'

# Model hyper-parameters; they are passed to the model constructor further
# down and therefore presumably have to match the stored checkpoint.
args.hidden_size = 128
args.num_layer = 2
args.topk = 3
args.dropout = 0.05
args.batch_size = 128

# Runtime settings. device_id is either a CUDA device index or the string 'cpu'.
args.device_id = 0
args.seed = 0

# Name of the training run directory; the seed is appended to obtain the
# directory that holds the checkpoint (args.save_path).
# NOTE: `saved_path` (run name) and `save_path` (full directory) are two
# different attributes.
args.saved_path = '0.001_0_200_0.05_128_128_2'
args.save_path = os.path.join('../result', args.task,
                              args.saved_path + '_' + str(args.seed))

set_seed(args.seed)
dataset = load_dataset(args)
print(f'Load dataset done! Total {len(dataset)} samples.')

# Only select a CUDA device when one is actually usable; otherwise fall back
# to the CPU instead of failing later when tensors are moved to the device.
use_gpu = args.device_id != 'cpu' and torch.cuda.is_available()
if use_gpu:
    print('Validating on GPU')
    device = torch.device('cuda:{}'.format(args.device_id))
else:
    if args.device_id != 'cpu':
        print('CUDA is not available, falling back to CPU.')
    print('Validating on CPU')
    device = torch.device('cpu')

seed = 0
Load dataset done! Total 29869 samples.
Validating on GPU


In [301]:
# Build a deterministic (unshuffled, nothing dropped) loader over the whole
# dataset so that row i of any downstream result refers to sample i.
data_loader = DataLoader(dataset,
                         batch_size=args.batch_size,
                         collate_fn=collate,
                         shuffle=False, drop_last=False)

# Feature widths are read from the graph stored at position 1 of the first
# sample, so the dataset is indexed only once.
sample_graph = dataset[0][1]

# NOTE(review): Model_2 is not among the names imported at the top of this
# notebook (only `Model` is imported from `model`) -- confirm where it is
# defined before running on a fresh kernel.
model = Model_2(in_size=sample_graph.ndata['node_feat'].shape[-1],
                hidden_size=args.hidden_size,
                edge_feat_size=sample_graph.edata['edge_feat'].shape[-1],
                num_layer=args.num_layer,
                topk=args.topk,
                dropout=args.dropout)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can also be loaded on CPU or on another GPU index.
state_dict = torch.load(os.path.join(args.save_path, 'model.pkl'),
                        map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)

# The network contains Dropout layers; switch to eval mode so that the
# interpretation below is run without random dropout.
model.eval()
print('Load model done!')

Load model done!


In [144]:
# Display the module repr to inspect the architecture of the loaded model.
model

Model_2(
  (egnn): ModuleDict(
    (EGNNLayer_0): EGNNConv_2(
      (dropout): Dropout(p=0.05, inplace=False)
      (edge_mlp): Sequential(
        (0): Linear(in_features=270, out_features=128, bias=True)
        (1): SiLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
        (3): SiLU()
      )
      (node_mlp): Sequential(
        (0): Linear(in_features=256, out_features=128, bias=True)
        (1): SiLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
      )
      (coord_mlp): Sequential(
        (0): Linear(in_features=128, out_features=128, bias=True)
        (1): SiLU()
        (2): Linear(in_features=128, out_features=1, bias=False)
      )
    )
    (EGNNLayer_1): EGNNConv_2(
      (dropout): Dropout(p=0.05, inplace=False)
      (edge_mlp): Sequential(
        (0): Linear(in_features=270, out_features=128, bias=True)
        (1): SiLU()
        (2): Linear(in_features=128, out_features=128, bias=True)
        (3): SiLU()
      )
     

## Step 2: Model interpretation for conformer discovery

In [281]:
# Run the interpretation helper over the whole dataset. It returns a table
# (df) and a second result (all_pred) that is inspected further below.
df, all_pred = conformer_attn(args, data_loader, model, device)

# Cast every column except the first one to float in a single astype call.
# (The first column is presumably the SMILES string -- confirm against
# conformer_attn.)
df = df.astype({column: float for column in df.columns[1:]})

In [285]:
# Select every row of the result table that belongs to the case-study molecule.
df.loc[df['SMILES'].eq('Nc1ccc(-c2cc(Nc3ccc(S(N)(=O)=O)cc3)[nH]n2)cc1')]

Unnamed: 0,SMILES,Label,Predict,Attn_0,Attn_1,Attn_2,Attn_3,Attn_4,Attn_5,Attn_6,Attn_7,Attn_8,Attn_9
1889,Nc1ccc(-c2cc(Nc3ccc(S(N)(=O)=O)cc3)[nH]n2)cc1,9.0,8.890817,0.113669,0.08537,0.08177,0.124741,0.102913,0.072717,0.080095,0.094559,0.133758,0.110406
3118,Nc1ccc(-c2cc(Nc3ccc(S(N)(=O)=O)cc3)[nH]n2)cc1,9.0,8.890817,0.113669,0.08537,0.08177,0.124741,0.102913,0.072717,0.080095,0.094559,0.133758,0.110406
4347,Nc1ccc(-c2cc(Nc3ccc(S(N)(=O)=O)cc3)[nH]n2)cc1,9.0,8.890817,0.113669,0.08537,0.08177,0.124741,0.102913,0.072717,0.080095,0.094559,0.133758,0.110406
5576,Nc1ccc(-c2cc(Nc3ccc(S(N)(=O)=O)cc3)[nH]n2)cc1,9.0,8.890817,0.113669,0.08537,0.08177,0.124741,0.102913,0.072717,0.080095,0.094559,0.133758,0.110406
6805,Nc1ccc(-c2cc(Nc3ccc(S(N)(=O)=O)cc3)[nH]n2)cc1,9.0,8.890817,0.113669,0.08537,0.08177,0.124741,0.102913,0.072717,0.080095,0.094559,0.133758,0.110406
8034,Nc1ccc(-c2cc(Nc3ccc(S(N)(=O)=O)cc3)[nH]n2)cc1,9.0,8.890817,0.113669,0.08537,0.08177,0.124741,0.102913,0.072717,0.080095,0.094559,0.133758,0.110406
9263,Nc1ccc(-c2cc(Nc3ccc(S(N)(=O)=O)cc3)[nH]n2)cc1,9.0,8.890817,0.11367,0.08537,0.08177,0.124741,0.102913,0.072717,0.080095,0.094559,0.133758,0.110406
10492,Nc1ccc(-c2cc(Nc3ccc(S(N)(=O)=O)cc3)[nH]n2)cc1,9.0,8.890817,0.113669,0.08537,0.08177,0.124741,0.102913,0.072717,0.080095,0.094559,0.133758,0.110406
11721,Nc1ccc(-c2cc(Nc3ccc(S(N)(=O)=O)cc3)[nH]n2)cc1,9.0,8.890817,0.113669,0.08537,0.08177,0.124741,0.102913,0.072717,0.080095,0.094559,0.133758,0.110406


In [286]:
# Entries of all_pred for row 1889, the first index returned by the SMILES
# lookup on df. Presumably one prediction per conformer -- confirm against
# conformer_attn.
np.asarray(all_pred)[1889]

array([8.92426872, 8.85383034, 8.91606331, 8.94812584, 8.93545151,
       8.88488102, 8.79909325, 8.93889618, 8.95968819, 8.93560696])

In [149]:
# Persist the table returned by conformer_attn (SMILES, label, prediction and
# attention columns) so it can be inspected outside the notebook.
df.to_csv('../interpret_case_conformer.csv')

## Step 3: Manually align the conformers and the attention coefficients

In [267]:
# Look up where the conformer file of the case-study molecule is stored.
# NOTE(review): the name `mol` is reused as the loop variable of the SDF
# export cell, so this cell must be re-run before `mol` is used as a table.
mol = pd.read_csv('../data/molecule_structure.csv')
mol.loc[mol['SMILES'] == 'Nc1ccc(-c2cc(Nc3ccc(S(N)(=O)=O)cc3)[nH]n2)cc1',
        'Conformer_path'].values

array(['/home/gyw/master thesis/data/molecule_structure/440800.sdf'],
      dtype=object)

In [163]:
# Split the SDF of the case-study molecule into one file per record
# (presumably one conformer per record) so that the structures can be aligned
# manually with the attention coefficients.
# NOTE(review): this is a machine-specific absolute path taken from the
# 'Conformer_path' lookup; it has to be adapted on any other machine.
mol_suppl = Chem.SDMolSupplier('/home/gyw/master thesis/data/'
                               'molecule_structure/440800.sdf')
for idx, mol in enumerate(mol_suppl):
    # SDMolSupplier yields None for records RDKit cannot parse; writing None
    # would raise. The record is skipped but idx is kept, so the number in the
    # file name still is the position of the record inside the source SDF.
    if mol is None:
        print(f'Skipping record {idx}: RDKit could not parse it.')
        continue
    writer = Chem.SDWriter(f'../{idx}_mol.sdf')
    try:
        writer.write(mol)
    finally:
        # Always flush and release the file handle, even if write() fails.
        writer.close()