In [1]:
import torch
import torch_geometric

import pandas as pd
import numpy as np

import rdkit
import rdkit.Chem
import rdkit.Chem.AllChem
from rdkit import Chem

from tqdm import tqdm
from copy import deepcopy
import random
import re
import os
import shutil
import argparse
import sys
import ast

from datasets.dataset_3D import *
from models.dimenetpp import *


In [5]:
# Checkpoints to evaluate, keyed by (mol_type, property_type, property, aggregation).
# Every acid checkpoint lives at trained_models/acids/<property>/<agg dir>/model_best.pt,
# where the 'lowE' aggregation is stored under 'min_E' and every other aggregation
# under its lower-cased name ('Boltz' -> 'boltz', 'min' -> 'min', 'max' -> 'max').
# Entries are generated in the listed order (dict insertion order is preserved).
model_dictionary = {
    ('acid', level, prop, agg):
        f"trained_models/acids/{prop}/{'min_E' if agg == 'lowE' else agg.lower()}/model_best.pt"
    for level, prop, aggs in (
        # atom-level properties
        ('atom', 'C1_NBO_charge', ('Boltz',)),
        ('atom', 'C1_Vbur',       ('Boltz', 'min', 'max', 'lowE')),
        ('atom', 'O3_NBO_charge', ('Boltz', 'min', 'max', 'lowE')),
        ('atom', 'H5_NBO_charge', ('Boltz', 'min', 'max', 'lowE')),
        ('atom', 'H5_NMR_shift',  ('Boltz', 'min', 'max', 'lowE')),
        # bond-level properties
        ('bond', 'IR_freq',       ('Boltz', 'min', 'max', 'lowE')),
        ('bond', 'Sterimol_B1',   ('Boltz', 'min', 'max', 'lowE')),
        ('bond', 'Sterimol_B5',   ('Boltz', 'min', 'max', 'lowE')),
        ('bond', 'Sterimol_L',    ('Boltz', 'min', 'max', 'lowE')),
        # molecule-level properties
        ('mol',  'dipole',        ('Boltz', 'min', 'max', 'lowE')),
    )
    for agg in aggs
}

In [3]:
num_workers = 4        # DataLoader worker processes used for the test loader
use_atom_features = 1  # truthy: DimeNetPlusPlus is built with the dataset's atom_features width

# Input conformers; presumably one row per conformer, keyed by 'Name_int' (TODO confirm).
# MolFromMolBlock(..., removeHs=False) keeps any hydrogens present in the mol block.
conformer_data_file = 'data/3D_model_acid_rdkit_conformers.csv'
conformers_df = pd.read_csv(conformer_data_file)
conformers_df['mols'] = [rdkit.Chem.MolFromMolBlock(m, removeHs = False) for m in conformers_df.mol_block]

# MolFromMolBlock returns None for a block it cannot parse (stereo warnings such as
# "Conflicting single bond directions" are only logged, not failures). Passing None to
# RemoveHs would fail with an unhelpful argument-type error, so report bad rows here.
unparsable_rows = [i for i, mol in enumerate(conformers_df['mols']) if mol is None]
if unparsable_rows:
    raise ValueError(
        f'{conformer_data_file}: {len(unparsable_rows)} mol_block(s) could not be parsed '
        f'(row positions, first 10: {unparsable_rows[:10]})'
    )
conformers_df['mols_noHs'] = [rdkit.Chem.RemoveHs(m) for m in conformers_df['mols']]


[19:04:24] Conflicting single bond directions around double bond at index 25.
[19:04:24]   BondStereo set to STEREONONE and single bond directions set to NONE.
[19:04:24] Conflicting single bond directions around double bond at index 22.
[19:04:24]   BondStereo set to STEREONONE and single bond directions set to NONE.


In [None]:
# testing loop definition

def loop(model, batch, property_type = 'bond', device = None):
    """Run one inference step of a property model on a single batch.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(x, pos, batch, atom_features, **selectors)``; its
        first output (``out[0]``) is taken as the predictions.
    batch : torch_geometric batch (or any object with a ``.to(device)`` method)
        Must provide ``x``, ``pos``, ``batch``, ``atom_features`` and
        ``targets``; for ``'bond'`` also ``bond_start_ID_index`` and
        ``bond_end_ID_index``; for ``'atom'`` also ``atom_ID_index``.
    property_type : {'bond', 'atom', 'mol'}
        Selects which atom-index selector arguments are forwarded to the model.
    device : str or torch.device, optional
        Where to move the batch. Defaults to the device of the model's first
        parameter (CPU if it has none), so this function no longer depends on
        a notebook-global ``device`` defined in a later cell.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Targets and predictions, both at least 1-D so that per-batch results
        can be joined with ``np.concatenate`` even for a single-sample batch.

    Raises
    ------
    ValueError
        If ``property_type`` is not ``'bond'``, ``'atom'`` or ``'mol'``.
    """
    if property_type not in ('bond', 'atom', 'mol'):
        raise ValueError(f"property_type must be 'bond', 'atom' or 'mol', got {property_type!r}")

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    batch = batch.to(device)

    # Extra keyword arguments telling the model which atoms/bonds to read out.
    if property_type == 'bond':
        selectors = {
            'select_bond_start_atom_index': batch.bond_start_ID_index,
            'select_bond_end_atom_index': batch.bond_end_ID_index,
        }
    elif property_type == 'atom':
        selectors = {'select_atom_index': batch.atom_ID_index}
    else:  # 'mol': whole-molecule property, no selection
        selectors = {}

    out = model(
        batch.x.squeeze(),
        batch.pos,
        batch.batch,
        batch.atom_features,
        **selectors,
    )

    targets = batch.targets
    pred_targets = out[0].squeeze()

    # squeeze() collapses a single-sample batch to a 0-d tensor; atleast_1d keeps
    # the result concatenable with the other batches.
    return (
        np.atleast_1d(targets.detach().cpu().numpy()),
        np.atleast_1d(pred_targets.detach().cpu().numpy()),
    )


In [6]:
# Evaluate every checkpoint in model_dictionary on its test split and report
# MAE and squared Pearson correlation after averaging rows per Name_int.

device = "cpu" #torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Descriptor CSV directories per mol_type, in lookup order: for amines the
# combined-amine file is preferred when it exists, else the type-specific one.
_DESCRIPTOR_DIRS = {
    'acid': ('data/acid',),
    'amine': ('data/combined_amine', 'data/primary_amine'),
    'sec_amine': ('data/combined_amine', 'data/secondary_amine'),
}

# Properties whose datasets keep explicit hydrogen atoms.
_H_PROPERTIES = {'H5_NBO_charge', 'H5_NMR_shift', 'NBO_charge_H', 'NMR_shift_H'}


def _descriptor_csv_path(mol_type, property_type, prop, agg):
    """Return the descriptor CSV path for one (mol_type, property_type, prop, agg) selection.

    Candidate directories are tried in ``_DESCRIPTOR_DIRS`` order; the first
    existing file wins and the last candidate is returned unchecked, so a
    missing file surfaces as ``FileNotFoundError`` from ``pd.read_csv``.

    Raises ``ValueError`` for an unknown ``mol_type`` instead of silently
    reusing the path left over from a previous loop iteration.
    """
    if mol_type not in _DESCRIPTOR_DIRS:
        raise ValueError(f'unknown mol_type {mol_type!r}; expected one of {sorted(_DESCRIPTOR_DIRS)}')
    candidates = [f'{d}/{property_type}/{prop}_{agg}.csv' for d in _DESCRIPTOR_DIRS[mol_type]]
    for path in candidates[:-1]:
        if os.path.exists(path):
            return path
    return candidates[-1]


for model_selection in model_dictionary:

    mol_type, property_type, prop, agg = model_selection

    keep_explicit_hydrogens = prop in _H_PROPERTIES
    # Forwarded to Dataset_3D; the meaning of the value 2 is defined there (TODO confirm).
    remove_Hs_except_functional = 2 if keep_explicit_hydrogens else False

    pretrained_model = model_dictionary[model_selection]

    # -------------------------------------------
    # Loading regression targets and joining them with the input conformers

    descriptor_data_file = _descriptor_csv_path(mol_type, property_type, prop, agg)
    descriptor_df = pd.read_csv(descriptor_data_file, converters={"bond_atom_tuple": ast.literal_eval})

    merged_df = conformers_df.merge(descriptor_df, on = 'Name_int')
    test_dataframe = merged_df[merged_df.split == 'test'].reset_index(drop = True)

    if test_dataframe.empty:
        # Nothing to evaluate; test_dataset[0] below would raise IndexError.
        print(model_selection)
        print('    ', 'no test-split rows; skipped')
        continue

    # -------------------------------------------
    # creating dataset, dataloader and model

    atom_ID_test = np.array(list(test_dataframe.atom_index), dtype = int) if property_type == 'atom' else None

    if property_type == 'bond':
        bond_pairs = np.array(list(test_dataframe.bond_atom_tuple), dtype = int)
        bond_ID_1_test, bond_ID_2_test = bond_pairs[:, 0], bond_pairs[:, 1]
    else:
        bond_ID_1_test = bond_ID_2_test = None

    mols_test = list(test_dataframe.mols if keep_explicit_hydrogens else test_dataframe.mols_noHs)

    test_dataset = Dataset_3D(
        property_type = property_type,
        mols = mols_test,
        mol_types = list(test_dataframe['mol_type']),
        targets = list(test_dataframe['y']),
        ligand_ID = np.array(test_dataframe['Name_int']),
        atom_ID = atom_ID_test,
        bond_ID_1 = bond_ID_1_test,
        bond_ID_2 = bond_ID_2_test,
        remove_Hs_except_functional = remove_Hs_except_functional,
    )
    test_loader = torch_geometric.loader.DataLoader(
        dataset = test_dataset,
        batch_size = 100,
        shuffle = False,  # keep row order aligned with test_dataframe for the merge below
        num_workers = num_workers,
    )

    example_data = test_dataset[0]
    atom_feature_dim = int(example_data.atom_features.shape[-1]) # 53

    model = DimeNetPlusPlus(
        property_type = property_type,
        use_atom_features = use_atom_features,
        atom_feature_dim = atom_feature_dim if use_atom_features else 1,
    )
    if pretrained_model != '':
        # torch.load unpickles the file: only point it at trusted local checkpoints.
        model.load_state_dict(torch.load(pretrained_model, map_location = device), strict = True)

    model.to(device)

    # -------------------------------------------
    # testing

    model.eval()
    test_targets = []
    test_pred_targets = []
    for batch in tqdm(test_loader):
        with torch.no_grad():
            target, pred_target = loop(
                model,
                batch,
                property_type = property_type,
            )
        # atleast_1d: a single-sample batch may come back as a 0-d array.
        test_targets.append(np.atleast_1d(target))
        test_pred_targets.append(np.atleast_1d(pred_target))

    test_targets = np.concatenate(test_targets)
    test_pred_targets = np.concatenate(test_pred_targets)

    # Average over all rows sharing a Name_int (presumably the conformers of one molecule).
    test_results = (
        pd.DataFrame({
            'Name_int': test_dataframe['Name_int'].to_numpy(),
            'targets': test_targets,
            'predictions': test_pred_targets,
        })
        .groupby('Name_int')[['targets', 'predictions']]
        .mean()
    )

    #test_results.to_csv(f'test_set_predictions/{mol_type}_{property_type}_{prop}_{agg}.csv')

    y_true = test_results['targets'].to_numpy()
    y_pred = test_results['predictions'].to_numpy()
    test_MAE = np.mean(np.abs(y_true - y_pred))
    # "R2" here is the squared Pearson correlation, not the coefficient of determination.
    test_R2 = np.corrcoef(y_true, y_pred)[0, 1] ** 2

    print(model_selection)
    print('    ','MAE:', test_MAE, 'R2:', test_R2)
    

100%|██████████| 44/44 [00:42<00:00,  1.03it/s]


('acid', 'bond', 'IR_freq', 'Boltz')
     MAE: 2.8184686228006828 R2: 0.9172257441907982


100%|██████████| 44/44 [01:09<00:00,  1.59s/it]


('acid', 'bond', 'IR_freq', 'min')
     MAE: 3.7841630182346377 R2: 0.817582645305739


100%|██████████| 44/44 [00:58<00:00,  1.33s/it]


('acid', 'bond', 'IR_freq', 'max')
     MAE: 3.0489183954831933 R2: 0.9227014787032608


100%|██████████| 44/44 [00:57<00:00,  1.30s/it]


('acid', 'bond', 'IR_freq', 'lowE')
     MAE: 4.360769576385241 R2: 0.8689578228709121


100%|██████████| 44/44 [00:43<00:00,  1.02it/s]


('acid', 'bond', 'Sterimol_B1', 'Boltz')
     MAE: 0.07710167540221655 R2: 0.8405513955268187


100%|██████████| 44/44 [01:09<00:00,  1.58s/it]


('acid', 'bond', 'Sterimol_B1', 'min')
     MAE: 0.04891798651518942 R2: 0.9121633775027215


100%|██████████| 44/44 [00:59<00:00,  1.35s/it]


('acid', 'bond', 'Sterimol_B1', 'max')
     MAE: 0.1388681743826185 R2: 0.7659256392147735


100%|██████████| 44/44 [01:12<00:00,  1.66s/it]


('acid', 'bond', 'Sterimol_B1', 'lowE')
     MAE: 0.12191491517700068 R2: 0.6470598725302308


100%|██████████| 44/44 [01:10<00:00,  1.60s/it]


('acid', 'bond', 'Sterimol_B5', 'Boltz')
     MAE: 0.41340818825890036 R2: 0.8957372323849371


100%|██████████| 44/44 [01:16<00:00,  1.74s/it]


('acid', 'bond', 'Sterimol_B5', 'min')
     MAE: 0.36123269596019714 R2: 0.8167726041996152


100%|██████████| 44/44 [01:20<00:00,  1.82s/it]


('acid', 'bond', 'Sterimol_B5', 'max')
     MAE: 0.3002559121917276 R2: 0.9678505690138206


100%|██████████| 44/44 [01:09<00:00,  1.58s/it]


('acid', 'bond', 'Sterimol_B5', 'lowE')
     MAE: 0.6712102048537311 R2: 0.7476588932203901


100%|██████████| 44/44 [01:03<00:00,  1.45s/it]


('acid', 'bond', 'Sterimol_L', 'Boltz')
     MAE: 0.5063205586761987 R2: 0.8044965225463881


100%|██████████| 44/44 [00:41<00:00,  1.06it/s]


('acid', 'bond', 'Sterimol_L', 'min')
     MAE: 0.19576590802489208 R2: 0.9690634671519615


100%|██████████| 44/44 [00:55<00:00,  1.27s/it]


('acid', 'bond', 'Sterimol_L', 'max')
     MAE: 0.49023742936238524 R2: 0.8824687250873694


100%|██████████| 44/44 [01:05<00:00,  1.50s/it]


('acid', 'bond', 'Sterimol_L', 'lowE')
     MAE: 0.8000724045168451 R2: 0.63096389900824


100%|██████████| 44/44 [00:59<00:00,  1.36s/it]


('acid', 'mol', 'dipole', 'Boltz')
     MAE: 0.4161579473679807 R2: 0.7377818833196705


100%|██████████| 44/44 [00:32<00:00,  1.34it/s]


('acid', 'mol', 'dipole', 'min')
     MAE: 0.4274945945612022 R2: 0.6996448648638428


100%|██████████| 44/44 [00:33<00:00,  1.30it/s]


('acid', 'mol', 'dipole', 'max')
     MAE: 0.39166391021063346 R2: 0.8557187131160934


100%|██████████| 44/44 [00:32<00:00,  1.34it/s]


('acid', 'mol', 'dipole', 'lowE')
     MAE: 0.7086338995635009 R2: 0.5445997724692987
