In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.nn.acts import swish

from tqdm import tqdm
from copy import deepcopy
import json

from model.params_interpreter import string_to_object 
from model.alpha_encoder import Encoder

from model.gnn_3D.schnet import SchNet
from model.gnn_3D.dimenet_pp import DimeNetPlusPlus
from model.gnn_3D.spherenet import SphereNet

from model.train_functions import classification_loop_alpha
from model.train_functions import evaluate_classification_loop_alpha
from model.gnn_3D.train_functions import classification_loop
from model.gnn_3D.train_functions import evaluate_classification_loop

from model.datasets_samplers import Dataset_3D_GNN, MaskedGraphDataset, StereoBatchSampler, SiameseBatchSampler, Sample_Map_To_Positives, Sample_Map_To_Negatives, NegativeBatchSampler, SingleConformerBatchSampler


In [None]:
# Use the current CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def get_accuracy_ChIRo(path_to_results = None, path_to_params_file = None, path_to_model_dict = None):
    """Return test-set predictions and accuracy for a trained ChIRo R/S classifier.

    If ``path_to_results`` names a readable CSV of cached predictions it is used
    directly. Otherwise an ``Encoder`` is rebuilt from the params JSON, its weights
    are loaded from ``path_to_model_dict``, and it is evaluated on
    ``params['test_datafile']``.

    Args:
        path_to_results: Optional path to a results CSV with ``targets`` and
            ``outputs`` columns. ``None`` skips the cache.
        path_to_params_file: Path to the params JSON used for training (only read
            when the cache is not used).
        path_to_model_dict: Path to the saved ``state_dict`` (only read when the
            cache is not used).

    Returns:
        ``(results_df, accuracy)``. ``accuracy`` is the Python float
        ``1 - mean(|targets - round(sigmoid(outputs))|)``, which is the
        classification accuracy for 0/1 targets.
    """
    results_df = None
    if path_to_results is not None:
        try:
            results_df = pd.read_csv(str(path_to_results))
        except (OSError, pd.errors.EmptyDataError):
            # No usable cache (missing/unreadable/empty file): recompute below.
            # Narrowed from a bare ``except`` so Ctrl-C and real parse errors surface.
            results_df = None

    if results_df is None:
        print('creating model...')
        best_state_dict = str(path_to_model_dict)

        with open(path_to_params_file) as f: # should contain path to params.json file
            params = json.load(f)

        layers_dict = deepcopy(params['layers_dict'])

        # Convert activation names from the JSON into python objects/functions
        # using the pre-defined string_to_object mapping.
        activation_dict = {key: string_to_object[value] for key, value in params['activation_dict'].items()}

        # Sizes of the initial node / edge feature vectors. These must match the
        # checkpoint, otherwise the strict load_state_dict below fails.
        num_node_features = 52
        num_edge_features = 14

        model = Encoder(
            F_z_list = params['F_z_list'], # dimension of latent space
            F_H = params['F_H'], # dimension of final node embeddings, after EConv and GAT layers
            F_H_embed = num_node_features, # dimension of initial node feature vector
            F_E_embed = num_edge_features, # dimension of initial edge feature vector
            F_H_EConv = params['F_H_EConv'], # dimension of node embedding after EConv layer
            layers_dict = layers_dict,
            activation_dict = activation_dict,
            GAT_N_heads = params['GAT_N_heads'],
            chiral_message_passing = params['chiral_message_passing'],
            CMP_EConv_MLP_hidden_sizes = params['CMP_EConv_MLP_hidden_sizes'],
            CMP_GAT_N_layers = params['CMP_GAT_N_layers'],
            CMP_GAT_N_heads = params['CMP_GAT_N_heads'],
            c_coefficient_normalization = params['c_coefficient_normalization'], # None, or one of ['sigmoid','softmax']
            sinusoidal_shift = params['sinusoidal_shift'], # true or false
            encoder_reduction = params['encoder_reduction'], #mean or sum
            output_concatenation_mode = params['output_concatenation_mode'], # none (if contrastive), conformer, molecule, or z_alpha (if regression)
            EConv_bias = params['EConv_bias'],
            GAT_bias = params['GAT_bias'],
            encoder_biases = params['encoder_biases'],
            dropout = params['dropout'], # applied to hidden layers (not input/output layer) of Encoder MLPs, hidden layers (not input/output layer) of EConv MLP, and all GAT layers (using their dropout parameter)
            )

        # torch.load unpickles the file: only load trusted checkpoints.
        model.load_state_dict(torch.load(best_state_dict, map_location=next(model.parameters()).device), strict=True)
        model.to(device)
        # Inference: make sure dropout is off (idempotent if the evaluation loop also calls eval()).
        model.eval()

        test_dataframe = pd.read_pickle(params['test_datafile'])
        test_dataset = MaskedGraphDataset(test_dataframe,
                                            regression = 'RS_label_binary', # top_score,  RS_label_binary, sign_rotation
                                            stereoMask = params['stereoMask'],
                                            mask_coordinates = params['mask_coordinates'],
                                            )
        test_loader = torch_geometric.data.DataLoader(test_dataset, num_workers = 0, batch_size = 1, shuffle = False)

        targets, outputs = evaluate_classification_loop_alpha(model, test_loader, device, batch_size = 1, dataset_size = len(test_dataset))

        results_df = deepcopy(test_dataframe[['ID', 'SMILES_nostereo', 'RS_label_binary']])
        results_df['targets'] = targets
        results_df['outputs'] = outputs

    y = torch.tensor(np.array(results_df.targets))
    output = torch.tensor(np.array(results_df.outputs))

    # Threshold probabilities at 0.5 (torch.round maps exactly 0.5 to 0). For 0/1
    # targets, |y - prediction| is 1 on a miss and 0 on a hit.
    predictions = torch.round(torch.sigmoid(output.squeeze().detach()))
    acc = 1.0 - (torch.sum(torch.abs(y.squeeze().detach() - predictions)) / y.shape[0])

    return results_df, float(acc)

In [None]:
def get_accuracy_schnet(path_to_results = None, path_to_params_file = None, path_to_model_dict = None):
    """Return test-set predictions and accuracy for a trained SchNet R/S classifier.

    If ``path_to_results`` names a readable CSV of cached predictions it is used
    directly. Otherwise a ``SchNet`` is rebuilt from the params JSON, its weights
    are loaded from ``path_to_model_dict``, and it is evaluated on
    ``params['test_datafile']``.

    Args:
        path_to_results: Optional path to a results CSV with ``targets`` and
            ``outputs`` columns. ``None`` skips the cache.
        path_to_params_file: Path to the params JSON used for training (only read
            when the cache is not used).
        path_to_model_dict: Path to the saved ``state_dict`` (only read when the
            cache is not used).

    Returns:
        ``(results_df, accuracy)``. ``accuracy`` is the Python float
        ``1 - mean(|targets - round(sigmoid(outputs))|)``, which is the
        classification accuracy for 0/1 targets.
    """
    results_df = None
    if path_to_results is not None:
        try:
            results_df = pd.read_csv(str(path_to_results))
        except (OSError, pd.errors.EmptyDataError):
            # No usable cache (missing/unreadable/empty file): recompute below.
            # Narrowed from a bare ``except`` so Ctrl-C and real parse errors surface.
            results_df = None

    if results_df is None:
        print('creating model...')
        params_file = str(path_to_params_file)
        best_state_dict = str(path_to_model_dict)

        with open(params_file) as f: # should contain path to params.json file
            params = json.load(f)

        model = SchNet(hidden_channels = params['hidden_channels'],
                       num_filters = params['num_filters'],
                       num_interactions = params['num_interactions'],
                       num_gaussians = params['num_gaussians'],
                       cutoff = params['cutoff'],
                       max_num_neighbors = params['max_num_neighbors'],
                       out_channels = params['out_channels'],
                       readout = 'add',
                       dipole = False,
                       mean = None,
                       std = None,
                       atomref = None,
                       MLP_hidden_sizes = params['MLP_hidden_sizes'], # [] for contrastive
            )

        # torch.load unpickles the file: only load trusted checkpoints.
        model.load_state_dict(torch.load(best_state_dict, map_location=next(model.parameters()).device), strict=True)
        model.to(device)
        # Inference: put the model in eval mode (idempotent if the evaluation loop also does so).
        model.eval()

        test_dataframe = pd.read_pickle(params['test_datafile'])
        test_dataset = Dataset_3D_GNN(test_dataframe,
                                regression = 'RS_label_binary', # top_score,  RS_label_binary, sign_rotation
                            )
        test_loader = torch_geometric.data.DataLoader(test_dataset, num_workers = 0, batch_size = 1, shuffle = False)

        targets, outputs = evaluate_classification_loop(model, test_loader, device, batch_size = 1, dataset_size = len(test_dataset))

        results_df = deepcopy(test_dataframe[['ID', 'SMILES_nostereo', 'RS_label_binary']])
        results_df['targets'] = targets
        results_df['outputs'] = outputs

    y = torch.tensor(np.array(results_df.targets))
    output = torch.tensor(np.array(results_df.outputs))

    # Threshold probabilities at 0.5 (torch.round maps exactly 0.5 to 0). For 0/1
    # targets, |y - prediction| is 1 on a miss and 0 on a hit.
    predictions = torch.round(torch.sigmoid(output.squeeze().detach()))
    acc = 1.0 - (torch.sum(torch.abs(y.squeeze().detach() - predictions)) / y.shape[0])

    return results_df, float(acc)

In [None]:
def get_accuracy_dimenetpp(path_to_results = None, path_to_params_file = None, path_to_model_dict = None):
    """Return test-set predictions and accuracy for a trained DimeNet++ R/S classifier.

    If ``path_to_results`` names a readable CSV of cached predictions it is used
    directly. Otherwise a ``DimeNetPlusPlus`` is rebuilt from the params JSON, its
    weights are loaded from ``path_to_model_dict``, and it is evaluated on
    ``params['test_datafile']``.

    Args:
        path_to_results: Optional path to a results CSV with ``targets`` and
            ``outputs`` columns. ``None`` skips the cache.
        path_to_params_file: Path to the params JSON used for training (only read
            when the cache is not used).
        path_to_model_dict: Path to the saved ``state_dict`` (only read when the
            cache is not used).

    Returns:
        ``(results_df, accuracy)``. ``accuracy`` is the Python float
        ``1 - mean(|targets - round(sigmoid(outputs))|)``, which is the
        classification accuracy for 0/1 targets.
    """
    results_df = None
    if path_to_results is not None:
        try:
            results_df = pd.read_csv(str(path_to_results))
        except (OSError, pd.errors.EmptyDataError):
            # No usable cache (missing/unreadable/empty file): recompute below.
            # Narrowed from a bare ``except`` so Ctrl-C and real parse errors surface.
            results_df = None

    if results_df is None:
        print('creating model...')
        params_file = str(path_to_params_file)
        best_state_dict = str(path_to_model_dict)

        with open(params_file) as f: # should contain path to params.json file
            params = json.load(f)

        model = DimeNetPlusPlus(
                hidden_channels = params['hidden_channels'],
                out_channels = params['out_channels'],
                num_blocks = params['num_blocks'],
                int_emb_size = params['int_emb_size'],
                basis_emb_size = params['basis_emb_size'],
                out_emb_channels = params['out_emb_channels'],
                num_spherical = params['num_spherical'],
                num_radial = params['num_radial'],
                cutoff=params['cutoff'],
                envelope_exponent=params['envelope_exponent'],
                num_before_skip=params['num_before_skip'],
                num_after_skip=params['num_after_skip'],
                num_output_layers=params['num_output_layers'],
                act=swish,
                MLP_hidden_sizes = params['MLP_hidden_sizes'], # [] for contrastive
            )

        # torch.load unpickles the file: only load trusted checkpoints.
        model.load_state_dict(torch.load(best_state_dict, map_location=next(model.parameters()).device), strict=True)
        model.to(device)
        # Inference: put the model in eval mode (idempotent if the evaluation loop also does so).
        model.eval()

        test_dataframe = pd.read_pickle(params['test_datafile'])
        test_dataset = Dataset_3D_GNN(test_dataframe,
                                    regression = 'RS_label_binary', # sign_rotation, top_score, RS_label_binary
                            )
        test_loader = torch_geometric.data.DataLoader(test_dataset, num_workers = 0, batch_size = 1, shuffle = False)

        targets, outputs = evaluate_classification_loop(model, test_loader, device, batch_size = 1, dataset_size = len(test_dataset))

        results_df = deepcopy(test_dataframe[['ID', 'SMILES_nostereo', 'RS_label_binary']])
        results_df['targets'] = targets
        results_df['outputs'] = outputs

    y = torch.tensor(np.array(results_df.targets))
    output = torch.tensor(np.array(results_df.outputs))

    # Threshold probabilities at 0.5 (torch.round maps exactly 0.5 to 0). For 0/1
    # targets, |y - prediction| is 1 on a miss and 0 on a hit.
    predictions = torch.round(torch.sigmoid(output.squeeze().detach()))
    acc = 1.0 - (torch.sum(torch.abs(y.squeeze().detach() - predictions)) / y.shape[0])

    return results_df, float(acc)

In [None]:
def get_accuracy_spherenet(path_to_results = None, path_to_params_file = None, path_to_model_dict = None, missing_IDs = None):
    """Return test-set predictions and accuracy for a trained SphereNet R/S classifier.

    If ``path_to_results`` names a readable CSV of cached predictions it is used
    directly. Otherwise a ``SphereNet`` is rebuilt from the params JSON, its
    weights are loaded from ``path_to_model_dict``, and it is evaluated on
    ``params['test_datafile']``.

    Args:
        path_to_results: Optional path to a results CSV with ``targets`` and
            ``outputs`` columns. ``None`` skips the cache.
        path_to_params_file: Path to the params JSON used for training (only read
            when the cache is not used).
        path_to_model_dict: Path to the saved ``state_dict`` (only read when the
            cache is not used).
        missing_IDs: Optional iterable of ``ID`` values. When given, only those
            test rows are evaluated; ``None`` (default) evaluates the whole test
            set. Previously this was read from an undefined global, which raised
            NameError on a fresh kernel.

    Returns:
        ``(results_df, accuracy)``. ``accuracy`` is the Python float
        ``1 - mean(|targets - round(sigmoid(outputs))|)``, which is the
        classification accuracy for 0/1 targets.
    """
    results_df = None
    if path_to_results is not None:
        try:
            results_df = pd.read_csv(str(path_to_results))
        except (OSError, pd.errors.EmptyDataError):
            # No usable cache (missing/unreadable/empty file): recompute below.
            # Narrowed from a bare ``except`` so Ctrl-C and real parse errors surface.
            results_df = None

    if results_df is None:
        print('creating model...')
        params_file = str(path_to_params_file)
        best_state_dict = str(path_to_model_dict)

        with open(params_file) as f: # should contain path to params.json file
            params = json.load(f)

        model = SphereNet(
                    energy_and_force = False,
                    cutoff = params['cutoff'],
                    num_layers = params['num_layers'],
                    hidden_channels = params['hidden_channels'],
                    out_channels = params['out_channels'],
                    int_emb_size = params['int_emb_size'],
                    basis_emb_size_dist = params['basis_emb_size_dist'],
                    basis_emb_size_angle = params['basis_emb_size_angle'],
                    basis_emb_size_torsion = params['basis_emb_size_torsion'],
                    out_emb_channels = params['out_emb_channels'],
                    num_spherical = params['num_spherical'],
                    num_radial = params['num_radial'],
                    envelope_exponent = params['envelope_exponent'],
                    num_before_skip = params['num_before_skip'],
                    num_after_skip = params['num_after_skip'],
                    num_output_layers = params['num_output_layers'],
                    act=swish,
                    output_init='GlorotOrthogonal',
                    use_node_features = True,
                    MLP_hidden_sizes = params['MLP_hidden_sizes'], # [] for contrastive
            )

        # torch.load unpickles the file: only load trusted checkpoints.
        model.load_state_dict(torch.load(best_state_dict, map_location=next(model.parameters()).device), strict=True)
        model.to(device)
        # Inference: put the model in eval mode (idempotent if the evaluation loop also does so).
        model.eval()

        full_test_dataframe = pd.read_pickle(params['test_datafile'])
        if missing_IDs is None:
            test_dataframe = full_test_dataframe
        else:
            # Restrict evaluation to the requested molecule IDs.
            test_dataframe = full_test_dataframe[full_test_dataframe.ID.isin(missing_IDs)]

        # NOTE(review): rotate_augmentation is taken from the training params and is
        # also applied to the test set here; confirm this is intended for evaluation.
        test_dataset = Dataset_3D_GNN(test_dataframe,
                                    regression = 'RS_label_binary', # sign_rotation, top_score, score, score_range_binary, relative_score_range_binary, RS_label_binary
                                    getDataFrom = params['getDataFrom'], #'source' or 'mol'
                                    rotate_augmentation = params['rotate_augmentation'], # range from 0.0 to 360.0
                            )

        test_loader = torch_geometric.data.DataLoader(test_dataset, num_workers = 0, batch_size = 1, shuffle = False)

        targets, outputs = evaluate_classification_loop(model, test_loader, device, batch_size = 1, dataset_size = len(test_dataset))

        results_df = deepcopy(test_dataframe[['ID', 'SMILES_nostereo', 'RS_label_binary']])
        results_df['targets'] = targets
        results_df['outputs'] = outputs

    y = torch.tensor(np.array(results_df.targets))
    output = torch.tensor(np.array(results_df.outputs))

    # Threshold probabilities at 0.5 (torch.round maps exactly 0.5 to 0). For 0/1
    # targets, |y - prediction| is 1 on a miss and 0 on a hit.
    predictions = torch.round(torch.sigmoid(output.squeeze().detach()))
    acc = 1.0 - (torch.sum(torch.abs(y.squeeze().detach() - predictions)) / y.shape[0])

    return results_df, float(acc)

In [None]:
# ChIRo: test-set predictions and R/S classification accuracy
results_df, accuracy_RS = get_accuracy_ChIRo(
    path_to_results = 'paper_results/RS_experiment/ChIRo/best_model_test_results.csv',
    path_to_params_file = 'paper_results/RS_experiment/ChIRo/params_RS_ChIRo.json',
    path_to_model_dict = 'paper_results/RS_experiment/ChIRo/best_model.pt',
)
# Display the accuracy; the next cell reuses these names.
accuracy_RS


In [None]:
# schnet: test-set predictions and R/S classification accuracy
results_df, accuracy_RS = get_accuracy_schnet(
    path_to_results = 'paper_results/RS_experiment/schnet/best_model_test_results.csv',
    path_to_params_file = 'paper_results/RS_experiment/schnet/params_RS_schnet.json',
    path_to_model_dict = 'paper_results/RS_experiment/schnet/best_model.pt',
)
# Display the accuracy; the next cell reuses these names.
accuracy_RS


In [None]:
# dimenetpp: test-set predictions and R/S classification accuracy
results_df, accuracy_RS = get_accuracy_dimenetpp(
    path_to_results = 'paper_results/RS_experiment/dimenetpp/best_model_test_results.csv',
    path_to_params_file = 'paper_results/RS_experiment/dimenetpp/params_RS_dimenetpp.json',
    path_to_model_dict = 'paper_results/RS_experiment/dimenetpp/best_model.pt',
)
# Display the accuracy; the next cell reuses these names.
accuracy_RS


In [None]:
# spherenet: test-set predictions and R/S classification accuracy
results_df, accuracy_RS = get_accuracy_spherenet(
    path_to_results = 'paper_results/RS_experiment/spherenet/best_model_test_results.csv',
    path_to_params_file = 'paper_results/RS_experiment/spherenet/params_RS_spherenet.json',
    path_to_model_dict = 'paper_results/RS_experiment/spherenet/best_model.pt',
)
# Display the accuracy (these names are shared with the earlier model cells).
accuracy_RS