In [1]:
import torch
import torch_geometric

import pandas as pd
import numpy as np

import rdkit
import rdkit.Chem
import rdkit.Chem.AllChem
from rdkit import Chem

from tqdm import tqdm
from copy import deepcopy
import random
import re
import os
import shutil
import argparse
import sys
import ast

from datasets.dataset_3D import *
from models.dimenetpp import *


In [2]:
# Checkpoint registry keyed by (mol_type, property_type, property, aggregation).
# Every entry follows the same path pattern, so the table is generated from the
# per-property rows (outer loop) and the aggregation -> sub-directory pairs
# (inner loop). Insertion order is property-major, aggregation-minor.
model_dictionary = {
    ('sec_amine', property_type, prop, agg):
        f'trained_models/{training_set}/{prop}/{agg_dir}/model_best.pt'
    for property_type, prop, training_set in (
        ('atom', 'NBO_LP_energy', 'combined_amines'),
        ('atom', 'NBO_LP_occupancy', 'combined_amines'),
        ('atom', 'NBO_charge_H', 'sec_amines'),
        ('atom', 'NMR_shift_H', 'sec_amines'),
        ('atom', 'Vbur', 'combined_amines'),
        ('atom', 'pyr_agranat', 'combined_amines'),
        ('mol', 'dipole', 'combined_amines'),
    )
    # Note that the 'lowE' aggregation is stored under the 'min_E' directory.
    for agg, agg_dir in (
        ('Boltz', 'boltz'),
        ('max', 'max'),
        ('min', 'min'),
        ('lowE', 'min_E'),
    )
}

In [3]:
# Evaluation settings consumed further down the notebook:
#   num_workers       -> passed to the test DataLoader
#   use_atom_features -> passed to the model constructor (1 = on)
num_workers = 4
use_atom_features = 1

# One row per conformer; the `mol_block` column holds the MOL-block text.
conformer_data_file = 'data/3D_model_secondaryamine_rdkit_conformers.csv'
conformers_df = pd.read_csv(conformer_data_file).reset_index(drop = True)

# Parse with explicit hydrogens kept so both variants can be derived from one parse.
conformers_df['mols'] = [Chem.MolFromMolBlock(m, removeHs = False) for m in conformers_df.mol_block]

# MolFromMolBlock signals a parse failure by returning None instead of raising.
# Stop here with the offending rows named, rather than letting RemoveHs fail on
# None with an unrelated-looking error.
unparsed_rows = [i for i, m in enumerate(conformers_df['mols']) if m is None]
if unparsed_rows:
    raise ValueError(
        f'{len(unparsed_rows)} mol block(s) in {conformer_data_file} could not be parsed by RDKit; '
        f'first offending row indices: {unparsed_rows[:10]}'
    )

# Hydrogen-stripped copies, used for properties that do not need explicit Hs.
conformers_df['mols_noHs'] = [Chem.RemoveHs(m) for m in conformers_df['mols']]


In [None]:
# testing loop definition

def loop(model, batch, property_type = 'bond', device = None):
    """Run one forward pass on ``batch`` and return targets and predictions.

    Parameters
    ----------
    model : callable
        Called as ``model(x, pos, batch, atom_features, **selectors)``. Only
        the first element of its return value is used as the prediction.
    batch : object
        Must provide ``to(device)``, ``x``, ``pos``, ``batch``,
        ``atom_features`` and ``targets``. For ``property_type='atom'`` it
        must also provide ``atom_ID_index``; for ``'bond'`` it must provide
        ``bond_start_ID_index`` and ``bond_end_ID_index``.
    property_type : {'bond', 'atom', 'mol'}
        Chooses which index selectors are forwarded to the model.
    device : optional
        Device the batch is moved to. When omitted, the device of the model's
        first parameter is used; if the model exposes no parameters the batch
        is left where it is.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``(targets, predictions)``, detached and copied to CPU. Both are at
        least 1-D so per-batch results can always be concatenated.

    Raises
    ------
    ValueError
        If ``property_type`` is not one of the supported values.
    """
    # Reject unknown property types up front; otherwise no branch would run
    # and the failure would surface later as an unbound-variable error.
    if property_type not in ('bond', 'atom', 'mol'):
        raise ValueError(
            f"Unknown property_type {property_type!r}; expected 'bond', 'atom' or 'mol'"
        )

    # Keep the batch on the same device as the model without relying on a
    # notebook-level variable being defined before this function is called.
    if device is None:
        parameters = getattr(model, 'parameters', None)
        first_parameter = next(iter(parameters()), None) if callable(parameters) else None
        if first_parameter is not None:
            device = first_parameter.device
    if device is not None:
        batch = batch.to(device)

    # The three property types share the positional inputs and differ only in
    # the index selectors passed by keyword.
    if property_type == 'bond':
        selectors = {
            'select_bond_start_atom_index': batch.bond_start_ID_index,
            'select_bond_end_atom_index': batch.bond_end_ID_index,
        }
    elif property_type == 'atom':
        selectors = {'select_atom_index': batch.atom_ID_index}
    else:  # 'mol'
        selectors = {}

    out = model(
        batch.x.squeeze(),
        batch.pos,
        batch.batch,
        batch.atom_features,
        **selectors,
    )

    targets = batch.targets
    pred_targets = out[0].squeeze()

    # squeeze() collapses a single-sample batch to a 0-d tensor, which
    # np.concatenate refuses to join; restore the length-1 axis in that case.
    if pred_targets.dim() == 0:
        pred_targets = pred_targets.reshape(1)
    if targets.dim() == 0:
        targets = targets.reshape(1)

    return targets.detach().cpu().numpy(), pred_targets.detach().cpu().numpy()


In [4]:
# Data sub-directories to try, in priority order, for each supported mol_type.
_DESCRIPTOR_DIRS = {
    'acid': ['acid'],
    'amine': ['combined_amine', 'primary_amine'],
    'sec_amine': ['combined_amine', 'secondary_amine'],
}


def _resolve_descriptor_file(mol_type, property_type, prop, agg):
    """Return the first existing descriptor CSV path for a model selection.

    Candidates are ``data/<dir>/<property_type>/<prop>_<agg>.csv`` for each
    directory listed for ``mol_type`` in ``_DESCRIPTOR_DIRS``, tried in order.

    Raises
    ------
    ValueError
        If ``mol_type`` is not a key of ``_DESCRIPTOR_DIRS``.
    FileNotFoundError
        If none of the candidate files exists.
    """
    if mol_type not in _DESCRIPTOR_DIRS:
        raise ValueError(
            f'Unsupported mol_type {mol_type!r}; expected one of {sorted(_DESCRIPTOR_DIRS)}'
        )
    candidates = [
        f'data/{data_dir}/{property_type}/{prop}_{agg}.csv'
        for data_dir in _DESCRIPTOR_DIRS[mol_type]
    ]
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(f'No descriptor file found; tried: {candidates}')


# Inference runs on CPU. The name is assigned once at notebook level (it does
# not change between models) so code defined in other cells can still see it.
device = "cpu" #torch.device("cuda" if torch.cuda.is_available() else "cpu")

for model_selection in model_dictionary:

    mol_type, property_type, prop, agg = model_selection
    pretrained_model = model_dictionary[model_selection]

    # Properties located on a hydrogen need the explicit-H molecules; for those
    # the dataset is asked to keep functional hydrogens (mode 2).
    keep_explicit_hydrogens = prop in ['H5_NBO_charge', 'H5_NMR_shift', 'NBO_charge_H', 'NMR_shift_H']
    remove_Hs_except_functional = 2 if keep_explicit_hydrogens else False

    # -------------------------------------------
    # Loading test data (regression targets and input conformers)

    # Only a missing file triggers the fallback to the next directory; a file
    # that exists but cannot be read raises instead of being silently skipped.
    descriptor_data_file = _resolve_descriptor_file(mol_type, property_type, prop, agg)
    descriptor_df = pd.read_csv(descriptor_data_file, converters={"bond_atom_tuple": ast.literal_eval})

    merged_df = conformers_df.merge(descriptor_df, on = 'Name_int')
    test_dataframe = merged_df[merged_df.split == 'test'].reset_index(drop = True)
    if test_dataframe.empty:
        raise ValueError(f'No test-split rows for {model_selection} after merging with {descriptor_data_file}')

    # -------------------------------------------
    # creating model and dataloader

    if property_type == 'atom':
        atom_ID_test = np.array(list(test_dataframe.atom_index), dtype = int)
    else:
        atom_ID_test = None

    if property_type == 'bond':
        # Each entry is a (start, end) atom pair; build the array once and split it.
        bond_atom_pairs = np.array(list(test_dataframe.bond_atom_tuple), dtype = int)
        bond_ID_1_test = bond_atom_pairs[:, 0]
        bond_ID_2_test = bond_atom_pairs[:, 1]
    else:
        bond_ID_1_test = None
        bond_ID_2_test = None

    if keep_explicit_hydrogens:
        mols_test = list(test_dataframe.mols)
    else:
        mols_test = list(test_dataframe.mols_noHs)

    test_dataset = Dataset_3D(
        property_type = property_type,
        mols = mols_test,
        mol_types = list(test_dataframe['mol_type']),
        targets = list(test_dataframe['y']),
        ligand_ID = np.array(test_dataframe['Name_int']),
        atom_ID = atom_ID_test,
        bond_ID_1 = bond_ID_1_test,
        bond_ID_2 = bond_ID_2_test,
        remove_Hs_except_functional = remove_Hs_except_functional,
    )
    # shuffle must stay False: predictions are matched back to test_dataframe
    # rows by position further down.
    test_loader = torch_geometric.loader.DataLoader(
        dataset = test_dataset,
        batch_size = 100,
        shuffle = False,
        num_workers = num_workers,
    )

    # Size the model's atom-feature input from an actual sample.
    example_data = test_dataset[0]
    atom_feature_dim = int(example_data.atom_features.shape[-1]) # 53

    model = DimeNetPlusPlus(
        property_type = property_type,
        use_atom_features = use_atom_features,
        atom_feature_dim = atom_feature_dim if use_atom_features else 1,
    )
    if pretrained_model != '':
        # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
        model.load_state_dict(torch.load(pretrained_model, map_location=next(model.parameters()).device), strict = True)

    model.to(device)

    # -------------------------------------------
    # testing

    model.eval()
    test_targets = []
    test_pred_targets = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            target, pred_target = loop(
                model,
                batch,
                property_type = property_type,
            )
            test_targets.append(target)
            test_pred_targets.append(pred_target)

    test_targets = np.concatenate(test_targets)
    test_pred_targets = np.concatenate(test_pred_targets)

    # One row per conformer, in the same order as test_dataframe.
    test_results = pd.DataFrame({
        'Name_int': test_dataframe.Name_int,
        'targets': test_targets,
        'predictions': test_pred_targets,
    })

    # Average over all conformers sharing a Name_int; the result is indexed by Name_int.
    test_results = test_results.groupby('Name_int')[['targets', 'predictions']].mean()

    #test_results.to_csv(f'test_set_predictions/{mol_type}_{property_type}_{prop}_{agg}.csv')

    true_values = test_results['targets'].to_numpy()
    predicted_values = test_results['predictions'].to_numpy()

    test_MAE = np.mean(np.abs(true_values - predicted_values))
    # The value reported as "R2" is the squared Pearson correlation coefficient.
    test_R2 = np.corrcoef(true_values, predicted_values)[0][1] ** 2

    print(model_selection)
    print('    ','MAE:', test_MAE, 'R2:', test_R2)
    

100%|██████████| 57/57 [00:40<00:00,  1.39it/s]


('sec_amine', 'atom', 'NBO_LP_energy', 'Boltz')
     MAE: 0.0024081594853037334 R2: 0.845303778662989


100%|██████████| 57/57 [01:03<00:00,  1.11s/it]


('sec_amine', 'atom', 'NBO_LP_energy', 'max')
     MAE: 0.0038074796578012795 R2: 0.8037607601520149


100%|██████████| 57/57 [00:40<00:00,  1.41it/s]


('sec_amine', 'atom', 'NBO_LP_energy', 'min')
     MAE: 0.002079426763526886 R2: 0.8853669807245566


100%|██████████| 57/57 [00:45<00:00,  1.25it/s]


('sec_amine', 'atom', 'NBO_LP_energy', 'lowE')
     MAE: 0.003170346699086538 R2: 0.7791230817558149


100%|██████████| 57/57 [00:56<00:00,  1.01it/s]


('sec_amine', 'atom', 'NBO_LP_occupancy', 'Boltz')
     MAE: 0.001378415578819183 R2: 0.7841043126400434


100%|██████████| 57/57 [00:41<00:00,  1.37it/s]


('sec_amine', 'atom', 'NBO_LP_occupancy', 'max')
     MAE: 0.0015730067908045757 R2: 0.45890631884191585


100%|██████████| 57/57 [00:43<00:00,  1.30it/s]


('sec_amine', 'atom', 'NBO_LP_occupancy', 'min')
     MAE: 0.0021843216026643194 R2: 0.741270307226808


100%|██████████| 57/57 [00:51<00:00,  1.10it/s]


('sec_amine', 'atom', 'NBO_LP_occupancy', 'lowE')
     MAE: 0.00195109796332547 R2: 0.7273960123105808


100%|██████████| 57/57 [01:14<00:00,  1.30s/it]


('sec_amine', 'atom', 'NBO_charge_H', 'Boltz')
     MAE: 0.0026984641470583566 R2: 0.7985223848416709


100%|██████████| 57/57 [01:10<00:00,  1.23s/it]


('sec_amine', 'atom', 'NBO_charge_H', 'max')
     MAE: 0.0033872152069007538 R2: 0.762107592821542


100%|██████████| 57/57 [01:11<00:00,  1.25s/it]


('sec_amine', 'atom', 'NBO_charge_H', 'min')
     MAE: 0.002854032449454189 R2: 0.6263734014241007


100%|██████████| 57/57 [01:08<00:00,  1.20s/it]


('sec_amine', 'atom', 'NBO_charge_H', 'lowE')
     MAE: 0.0038267479364173 R2: 0.639389525647986


100%|██████████| 57/57 [01:12<00:00,  1.28s/it]


('sec_amine', 'atom', 'NMR_shift_H', 'Boltz')
     MAE: 0.2196028491100633 R2: 0.6830932863073548


100%|██████████| 57/57 [01:12<00:00,  1.27s/it]


('sec_amine', 'atom', 'NMR_shift_H', 'max')
     MAE: 0.22333299874301896 R2: 0.6403560829866938


100%|██████████| 57/57 [01:15<00:00,  1.32s/it]


('sec_amine', 'atom', 'NMR_shift_H', 'min')
     MAE: 0.23051672479713778 R2: 0.7514558973085477


100%|██████████| 57/57 [01:20<00:00,  1.42s/it]


('sec_amine', 'atom', 'NMR_shift_H', 'lowE')
     MAE: 0.30812358090197706 R2: 0.5412609145020595


100%|██████████| 57/57 [01:01<00:00,  1.07s/it]


('sec_amine', 'atom', 'Vbur', 'Boltz')
     MAE: 0.641250135429413 R2: 0.9352503027577399


100%|██████████| 57/57 [00:48<00:00,  1.19it/s]


('sec_amine', 'atom', 'Vbur', 'max')
     MAE: 1.2012775068780983 R2: 0.8785779029492834


100%|██████████| 57/57 [00:43<00:00,  1.30it/s]


('sec_amine', 'atom', 'Vbur', 'min')
     MAE: 0.40453091969930505 R2: 0.9571938124314442


100%|██████████| 57/57 [00:49<00:00,  1.14it/s]


('sec_amine', 'atom', 'Vbur', 'lowE')
     MAE: 0.915331568583906 R2: 0.8916808934044417


100%|██████████| 57/57 [00:39<00:00,  1.43it/s]


('sec_amine', 'atom', 'pyr_agranat', 'Boltz')
     MAE: 0.008351488525130183 R2: 0.7918630058053182


100%|██████████| 57/57 [00:49<00:00,  1.14it/s]


('sec_amine', 'atom', 'pyr_agranat', 'max')
     MAE: 0.007136650832302599 R2: 0.8625775767593991


100%|██████████| 57/57 [00:52<00:00,  1.08it/s]


('sec_amine', 'atom', 'pyr_agranat', 'min')
     MAE: 0.01588471718581326 R2: 0.7024449225644213


100%|██████████| 57/57 [00:32<00:00,  1.74it/s]


('sec_amine', 'atom', 'pyr_agranat', 'lowE')
     MAE: 0.01220511013724239 R2: 0.6100966854712868


100%|██████████| 57/57 [00:32<00:00,  1.73it/s]


('sec_amine', 'mol', 'dipole', 'Boltz')
     MAE: 0.29944628267642487 R2: 0.8791248076565209


100%|██████████| 57/57 [00:31<00:00,  1.81it/s]


('sec_amine', 'mol', 'dipole', 'max')
     MAE: 0.29921099381992616 R2: 0.9248958201124422


100%|██████████| 57/57 [00:28<00:00,  1.97it/s]


('sec_amine', 'mol', 'dipole', 'min')
     MAE: 0.2929454683270081 R2: 0.8103380493954161


100%|██████████| 57/57 [00:29<00:00,  1.96it/s]


('sec_amine', 'mol', 'dipole', 'lowE')
     MAE: 0.42748689603614043 R2: 0.8084231008413366
