In [4]:
import torch
import torch_geometric

import pandas as pd
import numpy as np

import rdkit
import rdkit.Chem
import rdkit.Chem.AllChem
from rdkit import Chem

from tqdm import tqdm
from copy import deepcopy
import random
import re
import os
import shutil
import argparse
import sys
import ast

from datasets.dataset_3D import *
from models.dimenetpp import *


In [5]:
# Registry of pretrained checkpoints.
# Key:   (mol_type, property_type, property_name, aggregation)
# Value: path to the matching `model_best.pt` file.
# Atom- and mol-level models live under `combined_amines`, bond-level models
# under `amines`; the 'Boltz' and 'lowE' aggregations are stored in the
# `boltz` and `min_E` sub-directories respectively.
model_dictionary = {
    (mol_type, property_type, prop, agg): f'trained_models/{model_root}/{prop}/{agg_dir}/model_best.pt'
    for mol_type, property_type, model_root, props in [
        ('amine', 'atom', 'combined_amines', ['NBO_LP_energy', 'NBO_LP_occupancy', 'pyr_agranat', 'Vbur']),
        ('amine', 'bond', 'amines', ['Sterimol_B1', 'Sterimol_B5', 'Sterimol_L']),
        ('amine', 'mol', 'combined_amines', ['dipole']),
    ]
    for prop in props
    for agg, agg_dir in [('Boltz', 'boltz'), ('max', 'max'), ('min', 'min'), ('lowE', 'min_E')]
}

In [6]:
# Settings shared by every evaluation below: `num_workers` is handed to the
# DataLoader and `use_atom_features` to the model constructor.
num_workers = 4
use_atom_features = 1

# Conformer table; each row stores its geometry as a mol-block string.
conformer_data_file = 'data/3D_model_primaryamine_rdkit_conformers.csv'
conformers_df = pd.read_csv(conformer_data_file)
conformers_df = conformers_df.reset_index(drop = True)

# Parse every mol block twice over: once keeping explicit hydrogens, and once
# with the hydrogens stripped from that parsed molecule.
conformers_df['mols'] = [
    rdkit.Chem.MolFromMolBlock(mol_block, removeHs = False)
    for mol_block in conformers_df['mol_block']
]
conformers_df['mols_noHs'] = list(map(rdkit.Chem.RemoveHs, conformers_df['mols']))


In [None]:
# testing loop definition

def loop(model, batch, property_type = 'bond', device = None):
    """Run one forward pass and return targets and predictions as NumPy arrays.

    Args:
        model: network called as ``model(x, pos, batch, atom_features, ...)``;
            the first element of its output is taken as the prediction.
        batch: mini-batch exposing ``to``, ``x``, ``pos``, ``batch``,
            ``atom_features`` and ``targets``. For ``'bond'`` it must also
            provide ``bond_start_ID_index`` / ``bond_end_ID_index``, and for
            ``'atom'`` it must provide ``atom_ID_index``.
        property_type: one of ``'bond'``, ``'atom'`` or ``'mol'``; selects which
            index arguments are forwarded to the model.
        device: device the batch is moved to. When omitted, the device of the
            model's first parameter is used (CPU if the model has none), so the
            function no longer depends on a module-level ``device`` variable.

    Returns:
        Tuple ``(targets, predictions)`` of NumPy arrays. Predictions are always
        at least 1-D, so per-batch results can be concatenated even when a batch
        holds a single sample.

    Raises:
        ValueError: if ``property_type`` is not a supported value.
    """
    if property_type not in ('bond', 'atom', 'mol'):
        raise ValueError(
            f"Unsupported property_type {property_type!r}; expected 'bond', 'atom' or 'mol'."
        )

    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    batch = batch.to(device)

    # Extra keyword arguments telling the model which atoms/bonds to read out.
    if property_type == 'bond':
        selection = {
            'select_bond_start_atom_index': batch.bond_start_ID_index,
            'select_bond_end_atom_index': batch.bond_end_ID_index,
        }
    elif property_type == 'atom':
        selection = {'select_atom_index': batch.atom_ID_index}
    else:
        selection = {}

    out = model(
        batch.x.squeeze(),
        batch.pos,
        batch.batch,
        batch.atom_features,
        **selection,
    )

    targets = batch.targets
    pred_targets = out[0].squeeze()
    # squeeze() turns a one-sample batch into a 0-d tensor, which
    # np.concatenate refuses to join; restore the batch dimension.
    if pred_targets.dim() == 0:
        pred_targets = pred_targets.reshape(1)

    return targets.detach().cpu().numpy(), pred_targets.detach().cpu().numpy()


In [7]:
def _load_descriptor_df(mol_type, property_type, prop, agg):
    """Read the descriptor CSV for one (mol_type, property_type, prop, agg) selection.

    Amine selections try the `combined_amine` directory first and fall back to
    the type-specific directory only when that file does not exist; any other
    read error is propagated instead of being silently swallowed.

    Returns:
        Tuple ``(descriptor_data_file, descriptor_df)``: the path that was read
        and the parsed frame.

    Raises:
        ValueError: if `mol_type` is not 'acid', 'amine' or 'sec_amine'.
        FileNotFoundError: if none of the candidate files exists.
    """
    fallback_dirs = {'amine': 'primary_amine', 'sec_amine': 'secondary_amine'}
    if mol_type == 'acid':
        data_dirs = ['acid']
    elif mol_type in fallback_dirs:
        data_dirs = ['combined_amine', fallback_dirs[mol_type]]
    else:
        raise ValueError(
            f"Unknown mol_type {mol_type!r}; expected 'acid', 'amine' or 'sec_amine'."
        )

    candidate_files = [f'data/{data_dir}/{property_type}/{prop}_{agg}.csv' for data_dir in data_dirs]
    for position, data_file in enumerate(candidate_files):
        try:
            # The converter turns the stringified tuple in `bond_atom_tuple`
            # back into a Python tuple.
            return data_file, pd.read_csv(data_file, converters={"bond_atom_tuple": ast.literal_eval})
        except FileNotFoundError:
            if position == len(candidate_files) - 1:
                raise


# Evaluation runs on CPU. Kept as a module-level name on purpose: helper code
# defined in other cells may refer to `device`.
device = "cpu" #torch.device("cuda" if torch.cuda.is_available() else "cpu")

for model_selection, pretrained_model in model_dictionary.items():

    mol_type, property_type, prop, agg = model_selection

    # Hydrogen-centred properties need the explicit-H molecules.
    keep_explicit_hydrogens = prop in ['H5_NBO_charge', 'H5_NMR_shift', 'NBO_charge_H_avg', 'NBO_charge_H_min', 'NMR_shift_H_avg', 'NBO_charge_H', 'NMR_shift_H']
    remove_Hs_except_functional = 2 if keep_explicit_hydrogens else False

    # -------------------------------------------
    # Loading test data (regression targets and input conformers)

    descriptor_data_file, descriptor_df = _load_descriptor_df(mol_type, property_type, prop, agg)

    merged_df = conformers_df.merge(descriptor_df, on = 'Name_int')
    test_dataframe = merged_df[merged_df.split == 'test'].reset_index(drop = True)
    if test_dataframe.empty:
        raise ValueError(f'No test rows found for {model_selection} in {descriptor_data_file}')

    # -------------------------------------------
    # creating model and dataloader

    if property_type == 'atom':
        atom_ID_test = np.array(list(test_dataframe.atom_index), dtype = int)
    else:
        atom_ID_test = None

    if property_type == 'bond':
        # Each entry is a (start, end) pair; split it into two index arrays.
        bond_atom_pairs = np.array(list(test_dataframe.bond_atom_tuple), dtype = int)
        bond_ID_1_test = bond_atom_pairs[:, 0]
        bond_ID_2_test = bond_atom_pairs[:, 1]
    else:
        bond_ID_1_test = None
        bond_ID_2_test = None

    if keep_explicit_hydrogens:
        mols_test = list(test_dataframe.mols)
    else:
        mols_test = list(test_dataframe.mols_noHs)

    test_dataset = Dataset_3D(
        property_type = property_type,
        mols = mols_test,
        mol_types = list(test_dataframe['mol_type']),
        targets = list(test_dataframe['y']),
        ligand_ID = np.array(test_dataframe['Name_int']),
        atom_ID = atom_ID_test,
        bond_ID_1 = bond_ID_1_test,
        bond_ID_2 = bond_ID_2_test,
        remove_Hs_except_functional = remove_Hs_except_functional,
    )
    # shuffle must stay False: predictions are matched back to
    # `test_dataframe.Name_int` purely by row position further down.
    test_loader = torch_geometric.loader.DataLoader(
        dataset = test_dataset,
        batch_size = 100,
        shuffle = False,
        num_workers = num_workers,
    )

    # The atom-feature width is read off the first sample instead of being hard-coded.
    example_data = test_dataset[0]
    atom_feature_dim = int(example_data.atom_features.shape[-1]) # 53

    model = DimeNetPlusPlus(
        property_type = property_type,
        use_atom_features = use_atom_features,
        atom_feature_dim = atom_feature_dim if use_atom_features else 1,
    )
    if pretrained_model != '':
        # NOTE(review): torch.load unpickles the checkpoint; only load files from a trusted source.
        model.load_state_dict(torch.load(pretrained_model, map_location=next(model.parameters()).device), strict = True)

    model.to(device)

    # -------------------------------------------
    # testing

    model.eval()
    test_targets = []
    test_pred_targets = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            target, pred_target = loop(
                model,
                batch,
                property_type = property_type,
            )
            test_targets.append(target)
            test_pred_targets.append(pred_target)

    test_targets = np.concatenate(test_targets)
    test_pred_targets = np.concatenate(test_pred_targets)

    test_results = pd.DataFrame({
        'Name_int': test_dataframe.Name_int,
        'targets': test_targets,
        'predictions': test_pred_targets,
    })

    # Average all rows sharing a Name_int (presumably the conformers of one
    # molecule - confirm against the dataset). Name_int becomes the index.
    # A plain groupby mean replaces `.apply(lambda x: x.mean())`, which operated
    # on the grouping column and is deprecated in recent pandas releases.
    test_results = test_results.groupby('Name_int')[['targets', 'predictions']].mean()

    #test_results.to_csv(f'test_set_predictions/{mol_type}_{property_type}_{prop}_{agg}.csv')

    test_MAE = np.mean(np.abs(np.array(test_results['targets']) - np.array(test_results['predictions'])))
    # Reported "R2" is the squared Pearson correlation between targets and
    # predictions, not the coefficient of determination.
    test_R2 = np.corrcoef(np.array(test_results['targets']), np.array(test_results['predictions']))[0][1] ** 2

    print(model_selection)
    print('    ','MAE:', test_MAE, 'R2:', test_R2)
    

100%|██████████| 49/49 [00:21<00:00,  2.32it/s]


('amine', 'atom', 'NBO_LP_energy', 'Boltz')
     MAE: 0.0021559334838921244 R2: 0.8744058369844313


100%|██████████| 49/49 [00:33<00:00,  1.47it/s]


('amine', 'atom', 'NBO_LP_energy', 'max')
     MAE: 0.002756369198382142 R2: 0.881383000730126


100%|██████████| 49/49 [00:58<00:00,  1.20s/it]


('amine', 'atom', 'NBO_LP_energy', 'min')
     MAE: 0.0021762033464454928 R2: 0.854570859277051


100%|██████████| 49/49 [00:22<00:00,  2.19it/s]


('amine', 'atom', 'NBO_LP_energy', 'lowE')
     MAE: 0.003323667203849144 R2: 0.7591541927223369


100%|██████████| 49/49 [00:28<00:00,  1.72it/s]


('amine', 'atom', 'NBO_LP_occupancy', 'Boltz')
     MAE: 0.001559238684804816 R2: 0.7166697087367544


100%|██████████| 49/49 [00:32<00:00,  1.52it/s]


('amine', 'atom', 'NBO_LP_occupancy', 'max')
     MAE: 0.0012873555967199658 R2: 0.673442586688349


100%|██████████| 49/49 [00:49<00:00,  1.00s/it]


('amine', 'atom', 'NBO_LP_occupancy', 'min')
     MAE: 0.001816798800881575 R2: 0.6734128354577366


100%|██████████| 49/49 [00:19<00:00,  2.45it/s]


('amine', 'atom', 'NBO_LP_occupancy', 'lowE')
     MAE: 0.002374567242286466 R2: 0.4875954423923603


100%|██████████| 49/49 [00:29<00:00,  1.65it/s]


('amine', 'atom', 'pyr_agranat', 'Boltz')
     MAE: 0.005175400359427881 R2: 0.678986041296753


100%|██████████| 49/49 [00:30<00:00,  1.60it/s]


('amine', 'atom', 'pyr_agranat', 'max')
     MAE: 0.00464439355892691 R2: 0.7635226096649155


100%|██████████| 49/49 [00:43<00:00,  1.13it/s]


('amine', 'atom', 'pyr_agranat', 'min')
     MAE: 0.006828647876075405 R2: 0.5346238918058296


100%|██████████| 49/49 [00:45<00:00,  1.08it/s]


('amine', 'atom', 'pyr_agranat', 'lowE')
     MAE: 0.007824058112827872 R2: 0.5111512026550787


100%|██████████| 49/49 [00:29<00:00,  1.65it/s]


('amine', 'atom', 'Vbur', 'Boltz')
     MAE: 0.7652666810070455 R2: 0.8531777790153916


100%|██████████| 49/49 [00:33<00:00,  1.47it/s]


('amine', 'atom', 'Vbur', 'max')
     MAE: 1.3585899677353832 R2: 0.7807469476908689


100%|██████████| 49/49 [00:31<00:00,  1.57it/s]


('amine', 'atom', 'Vbur', 'min')
     MAE: 0.3764705812400169 R2: 0.938014194451951


100%|██████████| 49/49 [00:56<00:00,  1.16s/it]


('amine', 'atom', 'Vbur', 'lowE')
     MAE: 1.276686046770227 R2: 0.680498664022833


100%|██████████| 49/49 [00:30<00:00,  1.63it/s]


('amine', 'bond', 'Sterimol_B1', 'Boltz')
     MAE: 0.06347293096032702 R2: 0.8682309555357262


100%|██████████| 49/49 [00:32<00:00,  1.49it/s]


('amine', 'bond', 'Sterimol_B1', 'max')
     MAE: 0.11714816720862138 R2: 0.785294302957384


100%|██████████| 49/49 [01:00<00:00,  1.23s/it]


('amine', 'bond', 'Sterimol_B1', 'min')
     MAE: 0.021661810064122745 R2: 0.9588853803955562


100%|██████████| 49/49 [00:40<00:00,  1.22it/s]


('amine', 'bond', 'Sterimol_B1', 'lowE')
     MAE: 0.08855620931517258 R2: 0.7323202355876627


100%|██████████| 49/49 [00:30<00:00,  1.62it/s]


('amine', 'bond', 'Sterimol_B5', 'Boltz')
     MAE: 0.3732281440665365 R2: 0.8247837243147924


100%|██████████| 49/49 [01:03<00:00,  1.29s/it]


('amine', 'bond', 'Sterimol_B5', 'max')
     MAE: 0.2702093553929194 R2: 0.9415711442082542


100%|██████████| 49/49 [00:46<00:00,  1.06it/s]


('amine', 'bond', 'Sterimol_B5', 'min')
     MAE: 0.3145484340335676 R2: 0.8254921770739141


100%|██████████| 49/49 [00:35<00:00,  1.37it/s]


('amine', 'bond', 'Sterimol_B5', 'lowE')
     MAE: 0.6326232559767812 R2: 0.632962845281481


100%|██████████| 49/49 [00:52<00:00,  1.08s/it]


('amine', 'bond', 'Sterimol_L', 'Boltz')
     MAE: 0.4622803776852998 R2: 0.5485634733660868


100%|██████████| 49/49 [00:41<00:00,  1.17it/s]


('amine', 'bond', 'Sterimol_L', 'max')
     MAE: 0.49409236888653835 R2: 0.829760226281683


100%|██████████| 49/49 [00:28<00:00,  1.70it/s]


('amine', 'bond', 'Sterimol_L', 'min')
     MAE: 0.2029010129843646 R2: 0.731000131885836


100%|██████████| 49/49 [00:25<00:00,  1.92it/s]


('amine', 'bond', 'Sterimol_L', 'lowE')
     MAE: 0.8314848245396788 R2: 0.26382518881885475


100%|██████████| 49/49 [00:37<00:00,  1.29it/s]


('amine', 'mol', 'dipole', 'Boltz')
     MAE: 0.30824603601867856 R2: 0.8334346624210046


100%|██████████| 49/49 [00:29<00:00,  1.67it/s]


('amine', 'mol', 'dipole', 'max')
     MAE: 0.3140860234013936 R2: 0.8865966872100713


100%|██████████| 49/49 [00:45<00:00,  1.08it/s]


('amine', 'mol', 'dipole', 'min')
     MAE: 0.31920078408862895 R2: 0.7311354429918763


100%|██████████| 49/49 [00:45<00:00,  1.09it/s]


('amine', 'mol', 'dipole', 'lowE')
     MAE: 0.4911697904469996 R2: 0.6835344215960967
