In [None]:
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.nn.acts import swish

import sklearn
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import numpy as np
import scipy
import matplotlib.pyplot as plt
import matplotlib

import rdkit
import rdkit.Chem
from rdkit.Chem import TorsionFingerprints
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole

import datetime
import gzip
import math
from tqdm import tqdm
from copy import deepcopy
import random
from collections import OrderedDict
import os
import sys
import json

from model.params_interpreter import string_to_object 

from model.alpha_encoder import Encoder

from model.gnn_3D.schnet import SchNet
from model.gnn_3D.dimenet_pp import DimeNetPlusPlus
from model.gnn_3D.spherenet import SphereNet

from model.train_functions import contrastive_loop_alpha
from model.train_models import train_contrastive_model

from model.gnn_3D.train_functions import contrastive_loop
from model.gnn_3D.train_models import train_contrastive_model

from model.datasets_samplers import Dataset_3D_GNN, MaskedGraphDataset, StereoBatchSampler, SiameseBatchSampler, Sample_Map_To_Positives, Sample_Map_To_Negatives, NegativeBatchSampler, SingleConformerBatchSampler


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# NOTE(review): pd.read_pickle unpickles arbitrary Python objects; only load trusted files.
test_split_path = 'final_data_splits/test_contrastive_MOL_448017_89914_38659.pkl'
test_dataframe = pd.read_pickle(test_split_path)

In [None]:
def get_ChIRo_model(path_to_params_file = None, path_to_model_dict = None):
    """
    Build a ChIRo ``Encoder`` from a JSON hyperparameter file and load its trained weights.

    Args:
        path_to_params_file: path to the JSON file holding the model hyperparameters.
        path_to_model_dict: path to the saved state_dict; it is loaded onto ``device``.

    Returns:
        The ``Encoder`` with weights loaded (strict key matching), moved to ``device``.
    """
    with open(str(path_to_params_file)) as params_file:
        params = json.load(params_file)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    state_dict = torch.load(str(path_to_model_dict), map_location=device)

    # The JSON stores activations by name; translate them with the project's lookup table.
    activations = {
        layer_name: string_to_object[activation_name]
        for layer_name, activation_name in params['activation_dict'].items()
    }

    # Hyperparameters whose JSON key matches the Encoder keyword and are forwarded unchanged.
    passthrough_keys = (
        'F_z_list',                      # dimension of latent space
        'F_H',                           # dimension of final node embeddings, after EConv and GAT layers
        'F_H_EConv',                     # dimension of node embedding after EConv layer
        'GAT_N_heads',
        'chiral_message_passing',
        'CMP_EConv_MLP_hidden_sizes',
        'CMP_GAT_N_layers',
        'CMP_GAT_N_heads',
        'c_coefficient_normalization',   # None, or one of ['softmax']
        'sinusoidal_shift',              # true or false
        'encoder_reduction',             # mean or sum
        'output_concatenation_mode',     # none or 'contrastive' (if contrastive), conformer, molecule, or z_alpha (if regression)
        'EConv_bias',
        'GAT_bias',
        'encoder_biases',
        'dropout',                       # dropout rate used inside the Encoder's MLP/EConv/GAT layers
    )
    encoder_kwargs = {key: params[key] for key in passthrough_keys}

    model = Encoder(
        F_H_embed = 52,  # length of the initial node feature vector
        F_E_embed = 14,  # length of the initial edge feature vector
        layers_dict = deepcopy(params['layers_dict']),
        activation_dict = activations,
        **encoder_kwargs,
    )
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    return model

In [None]:
def get_schnet_model(path_to_params_file = None, path_to_model_dict = None):
    """
    Build a SchNet model from a JSON hyperparameter file and load its trained weights.

    Args:
        path_to_params_file: path to the JSON hyperparameter file.
        path_to_model_dict: path to the saved state_dict; it is loaded onto ``device``.

    Returns:
        The SchNet model with weights loaded (strict key matching), moved to ``device``.
    """
    with open(str(path_to_params_file)) as params_file:
        params = json.load(params_file)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    state_dict = torch.load(str(path_to_model_dict), map_location=device)

    # Hyperparameters read from the JSON and passed under the same keyword names.
    tuned_keys = ('hidden_channels', 'num_filters', 'num_interactions', 'num_gaussians',
                  'cutoff', 'max_num_neighbors', 'out_channels')
    tuned = {key: params[key] for key in tuned_keys}

    # Settings fixed in code rather than read from the JSON.
    fixed = dict(
        readout = 'add',
        dipole = False,
        mean = None,
        std = None,
        atomref = None,
        MLP_hidden_sizes = [],  # [] for contrastive
    )

    schnet = SchNet(**tuned, **fixed)
    schnet.load_state_dict(state_dict, strict=True)
    schnet.to(device)
    return schnet

In [None]:
def get_dimenetpp_model(path_to_params_file = None, path_to_model_dict = None):
    """
    Build a DimeNet++ model from a JSON hyperparameter file and load its trained weights.

    Args:
        path_to_params_file: path to the JSON hyperparameter file.
        path_to_model_dict: path to the saved state_dict; it is loaded onto ``device``.

    Returns:
        The DimeNetPlusPlus model with weights loaded (strict key matching), moved to ``device``.
    """
    with open(str(path_to_params_file)) as params_file:
        params = json.load(params_file)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    state_dict = torch.load(str(path_to_model_dict), map_location=device)

    # Hyperparameters read from the JSON and passed under the same keyword names.
    tuned_keys = (
        'hidden_channels', 'out_channels', 'num_blocks', 'int_emb_size', 'basis_emb_size',
        'out_emb_channels', 'num_spherical', 'num_radial', 'cutoff', 'envelope_exponent',
        'num_before_skip', 'num_after_skip', 'num_output_layers',
    )
    tuned = {key: params[key] for key in tuned_keys}

    dimenetpp = DimeNetPlusPlus(
        act = swish,
        MLP_hidden_sizes = [],  # [] for contrastive
        **tuned,
    )
    dimenetpp.load_state_dict(state_dict, strict=True)
    dimenetpp.to(device)
    return dimenetpp

In [None]:
def get_spherenet_model(path_to_params_file = None, path_to_model_dict = None):
    """
    Build a SphereNet model from a JSON hyperparameter file and load its trained weights.

    Args:
        path_to_params_file: path to the JSON hyperparameter file.
        path_to_model_dict: path to the saved state_dict; it is loaded onto ``device``.

    Returns:
        The SphereNet model with weights loaded (strict key matching), moved to ``device``.
    """
    with open(str(path_to_params_file)) as params_file:
        params = json.load(params_file)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    state_dict = torch.load(str(path_to_model_dict), map_location=device)

    # Hyperparameters read from the JSON and passed under the same keyword names.
    tuned_keys = (
        'cutoff', 'num_layers', 'hidden_channels', 'out_channels', 'int_emb_size',
        'basis_emb_size_dist', 'basis_emb_size_angle', 'basis_emb_size_torsion',
        'out_emb_channels', 'num_spherical', 'num_radial', 'envelope_exponent',
        'num_before_skip', 'num_after_skip', 'num_output_layers',
    )
    tuned = {key: params[key] for key in tuned_keys}

    # Settings fixed in code rather than read from the JSON.
    fixed = dict(
        energy_and_force = False,
        act = swish,
        output_init = 'GlorotOrthogonal',
        use_node_features = True,
        MLP_hidden_sizes = [],  # [] for contrastive
    )

    spherenet = SphereNet(**fixed, **tuned)
    spherenet.load_state_dict(state_dict, strict=True)
    spherenet.to(device)
    return spherenet

In [None]:
def show_atom_number(mol, label = 'atomNote'):
    """
    Store each atom's index as a string property named `label` on that atom.

    The molecule is modified in place and also returned, so the call can be the last
    expression of a notebook cell. Presumably the default 'atomNote' is used so RDKit
    depictions show the indices next to the atoms.
    """
    atoms = mol.GetAtoms()
    for atom in atoms:
        index_text = str(atom.GetIdx())
        atom.SetProp(label, index_text)
    return mol

In [None]:
def rotate_bond(mol, dihedral = [1,2,3,4], rot = 0.0): # rot in radians
    """
    Return a deep copy of `mol` with one dihedral angle increased by `rot` radians.

    The input molecule is not modified. The dihedral is read from and written back to
    the copy's default conformer through RDKit's rdMolTransforms.

    Args:
        mol: RDKit molecule with at least one conformer.
        dihedral: atom indices of the torsion; only the first four entries are used.
        rot: rotation increment in radians.
    """
    transforms = rdkit.Chem.rdMolTransforms
    rotated = deepcopy(mol)
    conformer = rotated.GetConformer()
    a, b, c, d = dihedral[0], dihedral[1], dihedral[2], dihedral[3]
    current_angle = transforms.GetDihedralRad(conformer, a, b, c, d)
    transforms.SetDihedralRad(conformer, a, b, c, d, current_angle + rot)
    return rotated

In [None]:
def reflect_mol(mol):
    """
    Return the mirror image of `mol`, obtained by negating every atom's z coordinate.

    The coordinates are mirrored on a deep copy (the input is not modified), and the copy
    is then rebuilt by round-tripping it through a MolBlock.
    NOTE(review): presumably the round-trip is there so stereochemistry is re-perceived from
    the mirrored 3D geometry. Note that MolFromMolBlock removes explicit hydrogens by default,
    so reflected molecules may differ in H content from the originals — confirm this is intended.
    """
    mirrored = deepcopy(mol)
    conformer = mirrored.GetConformer()
    for atom_idx in range(mirrored.GetNumAtoms()):
        x, y, z = conformer.GetAtomPosition(atom_idx)
        conformer.SetAtomPosition(atom_idx, [x, y, z * -1])
    return rdkit.Chem.MolFromMolBlock(rdkit.Chem.MolToMolBlock(mirrored))

In [None]:
# Choose example conformers of a single molecule from the test set

In [None]:
smiles = 'CC(C)C(C)(Cc1nncn1C)C(=O)O'
same_molecule = test_dataframe['SMILES_nostereo'] == smiles
conformers_df = test_dataframe.loc[same_molecule].reset_index(drop = True)
conformers = list(conformers_df['rdkit_mol_cistrans_stereo'])
IDs = list(conformers_df['ID'])
# Annotate atom indices on one example conformer (presumably used to pick the torsion atom indices).
show_atom_number(conformers[6])

In [None]:
# Reflect each conformer across the xy plane (z -> -z) to create its mirror-image conformer

In [None]:
# First half: the original conformers; second half: their mirror images, in the same order.
mirrored_conformers = [reflect_mol(conf) for conf in conformers]
reflected_conformers = conformers + mirrored_conformers
reflected_df = pd.DataFrame({
    # Label every conformer by its RDKit SMILES (presumably stereo-aware, so mirror images can get distinct IDs).
    'ID': [rdkit.Chem.MolToSmiles(conf) for conf in reflected_conformers],
    'SMILES_nostereo': [smiles] * (2 * len(IDs)),
    'rdkit_mol_cistrans_stereo': reflected_conformers,
})

In [None]:
# Rotate bonds near the chiral center to generate additional conformers

In [None]:
def get_rotated_conformers(smile_IDs, conformers, dihedral, rotations, smiles_nostereo = None):
    """
    Rotate one dihedral of every conformer by each angle in `rotations`.

    Args:
        smile_IDs: per-conformer IDs (list or pandas Series), aligned by position with `conformers`.
        conformers: RDKit molecules with a conformer; `rotate_bond` works on deep copies,
            so they are not modified.
        dihedral: atom indices of the torsion to rotate (passed to `rotate_bond`).
        rotations: iterable of rotation increments in radians.
        smiles_nostereo: value for the 'SMILES_nostereo' column. When None, falls back to the
            notebook-level `smiles` variable (the previous behaviour).

    Returns:
        DataFrame with columns 'ID', 'SMILES_nostereo' and 'rdkit_mol_cistrans_stereo' and
        len(conformers) * len(rotations) rows, grouped by input conformer in input order.

    Raises:
        ValueError: if `smile_IDs` and `conformers` have different lengths.
    """
    if smiles_nostereo is None:
        smiles_nostereo = smiles  # backward compatibility: notebook-level global

    # Positional access, so a Series with a non-default index is handled correctly.
    ids = list(smile_IDs)
    conformers = list(conformers)
    if len(ids) != len(conformers):
        raise ValueError('smile_IDs and conformers must have the same length, got %d and %d'
                         % (len(ids), len(conformers)))
    # Materialize once so one-shot iterables can be reused for every conformer.
    rotations = list(rotations)

    all_conformers = []
    all_IDs = []
    for conf_id, conf in zip(ids, conformers):
        for rot in rotations:
            all_conformers.append(rotate_bond(conf, dihedral = dihedral, rot = rot))
            all_IDs.append(conf_id)

    rotated_df = pd.DataFrame()
    rotated_df['ID'] = all_IDs
    rotated_df['SMILES_nostereo'] = [smiles_nostereo]*len(all_IDs)
    rotated_df['rdkit_mol_cistrans_stereo'] = all_conformers

    return rotated_df

In [None]:
# Rotation increments 0, 30, ..., 330 degrees, converted to radians.
rots = np.arange(0, 360, 30) * (np.pi/180)

# Torsion atom indices, read off the atom-numbered conformer drawing.
torsion_1 = [12, 7, 5, 8]
torsion_2 = [0, 11, 5, 8]
# NOTE(review): torsion_3 is identical to torsion_2, so the third stage re-rotates the same
# dihedral and only produces duplicates of stage-2 geometries — confirm the intended torsion.
torsion_3 = [0, 11, 5, 8]

# Each stage rotates every conformer produced by the previous stage, so the conformer
# count grows by a factor of len(rots) per stage.
rotated_df_1 = get_rotated_conformers(
    smile_IDs = IDs,
    conformers = conformers,
    dihedral = torsion_1,
    rotations = rots,
)
rotated_df_2 = get_rotated_conformers(
    smile_IDs = rotated_df_1['ID'],
    conformers = list(rotated_df_1['rdkit_mol_cistrans_stereo']),
    dihedral = torsion_2,
    rotations = rots,
)
rotated_df_3 = get_rotated_conformers(
    smile_IDs = rotated_df_2['ID'],
    conformers = list(rotated_df_2['rdkit_mol_cistrans_stereo']),
    dihedral = torsion_3,
    rotations = rots,
)

In [None]:
def get_ChIRo_latent_space(model, conformer_df):
    """
    Embed every conformer in `conformer_df` with a trained ChIRo encoder.

    The model is put in eval mode (so its dropout layers are disabled) and run under
    ``torch.no_grad()``. Its previous train/eval mode is restored afterwards, even on error.

    Args:
        model: ChIRo ``Encoder`` (e.g. from ``get_ChIRo_model``), already on ``device``.
        conformer_df: dataframe accepted by ``MaskedGraphDataset``.

    Returns:
        Detached CPU tensor with one row per dataset entry, in loader order (no shuffling).
        Each row is the last third of the encoder's first output.
        NOTE(review): the result is allocated with 2 columns, so this assumes that slice is
        2-dimensional for the loaded configuration — confirm for other configs.
    """
    test_dataset_model = MaskedGraphDataset(conformer_df, 
                                    regression = '', #'', score, score_range_binary, relative_score_range_binary, RS_label_binary
                                    stereoMask = True,
                                    mask_coordinates = False, 
                                    )
    test_loader_model = torch_geometric.data.DataLoader(test_dataset_model, shuffle = False, batch_size = 100)

    def get_local_structure_map(psi_indices):
        """
        Group dihedrals by the atom pair in positions 1 and 2 of each index column.

        Returns (LS_map, alpha_indices): LS_map[i] is the group id of dihedral i, and
        alpha_indices[:, g] is the atom pair of group g; ids follow first appearance.
        """
        LS_dict = OrderedDict()
        LS_map = torch.zeros(psi_indices.shape[1], dtype = torch.long)
        for i, indices in enumerate(psi_indices.T):
            tupl = (int(indices[1]), int(indices[2]))
            if tupl not in LS_dict:
                LS_dict[tupl] = len(LS_dict)
            LS_map[i] = LS_dict[tupl]

        alpha_indices = torch.zeros((2, len(LS_dict)), dtype = torch.long)
        for i, tupl in enumerate(LS_dict):
            alpha_indices[:, i] = torch.LongTensor(tupl)

        return LS_map, alpha_indices

    latent_space = torch.zeros((len(test_dataset_model), 2))

    was_training = model.training
    model.eval()  # deterministic embeddings: disables the Encoder's dropout
    start = 0
    try:
        with torch.no_grad():  # inference only: no autograd graph is needed
            for batch_data in tqdm(test_loader_model):
                LS_map, alpha_indices = get_local_structure_map(batch_data.dihedral_angle_index)

                batch_data = batch_data.to(device)
                LS_map = LS_map.to(device)
                alpha_indices = alpha_indices.to(device)

                (latent_vector, phase_shift_norm, z_alpha, mol_embedding, c_tensor,
                 phase_cos, phase_sin, sin_cos_psi, sin_cos_alpha) = model(batch_data, LS_map, alpha_indices)

                latent_vector = latent_vector[:, latent_vector.shape[1]//3 * 2 :]
                # detach + cpu so the result is a plain CPU tensor (convertible to numpy for plotting)
                latent_space[start:start + latent_vector.shape[0]] = latent_vector.detach().cpu()
                start += latent_vector.shape[0]
    finally:
        model.train(was_training)

    return latent_space

In [None]:
def get_3D_GNN_latent_space(model, conformer_df):
    """
    Embed every conformer in `conformer_df` with a trained 3D GNN (SchNet / DimeNet++ / SphereNet).

    The model is put in eval mode and run under ``torch.no_grad()``. Its previous train/eval
    mode is restored afterwards. If the model raises on a batch, the error is printed and that
    batch's rows are left as zeros. This is a best-effort fallback that keeps later rows aligned.

    Args:
        model: 3D GNN called as ``model(z, pos, batch)``, already on ``device``.
        conformer_df: dataframe accepted by ``Dataset_3D_GNN``.

    Returns:
        CPU tensor of shape (len(dataset), 2), one row per conformer in loader order (no shuffling).
        NOTE(review): assumes the model's output width is 2 — confirm for other configs.
    """
    test_dataset_3D_GNN = Dataset_3D_GNN(conformer_df, 
                                    regression = '',
                              )
    test_loader_3D_GNN = torch_geometric.data.DataLoader(test_dataset_3D_GNN, shuffle = False, batch_size = 100)

    latent_space = torch.zeros((len(test_dataset_3D_GNN), 2))

    was_training = model.training
    model.eval()
    start = 0
    try:
        with torch.no_grad():  # inference only: no autograd graph is needed
            for batch_data in tqdm(test_loader_3D_GNN):
                batch_data = batch_data.to(device)

                # copies, so the model cannot alter the batch tensors in place
                node_batch = deepcopy(batch_data.batch)
                z = deepcopy(batch_data.x)
                pos = deepcopy(batch_data.pos)

                try:
                    latent_vector = model(z.squeeze(), pos, node_batch)
                except Exception as e:
                    print('3D GNN failed to process batch: ', start)
                    print(e)
                    # One zero row per graph in the batch. num_graphs avoids the off-by-one of
                    # max(batch) and does not require moving a CUDA tensor to numpy.
                    latent_vector = torch.zeros((batch_data.num_graphs, 2))

                latent_space[start:start + latent_vector.shape[0]] = latent_vector.detach().cpu()
                start += latent_vector.shape[0]
    finally:
        model.train(was_training)

    return latent_space

In [None]:
# Initialize a model, compute a latent vector for each conformer in the chosen dataframe, and plot the latent space

In [None]:
chiro_params_path = 'paper_results/contrastive_experiment/ChIRo/params_contrastive_ChIRo.json'
chiro_weights_path = 'paper_results/contrastive_experiment/ChIRo/best_model.pt'
chiro = get_ChIRo_model(path_to_params_file = chiro_params_path,
                        path_to_model_dict = chiro_weights_path)

# Embed the original conformers together with their reflections.
plot_df = reflected_df
latent_space = get_ChIRo_latent_space(chiro, plot_df)

In [None]:
matplotlib.rcParams['pdf.fonttype'] = 42  # embed TrueType (Type 42) fonts in exported PDFs

# Blue-white-red palette: sample 40 colours, drop the first two samples and the alpha column.
cmap = matplotlib.cm.bwr(np.linspace(0.,1,40,))
cmap = matplotlib.colors.ListedColormap(cmap[2:,:-1])

# matplotlib converts inputs with numpy; a tensor that requires grad (or lives on the GPU)
# cannot be converted, so detach and move it to the CPU first.
latent_points = torch.as_tensor(latent_space).detach().cpu().numpy()

fig, ax = plt.subplots(1, 1, figsize = [4, 4])

# One colour per distinct conformer ID.
le = LabelEncoder()
labels = le.fit_transform(plot_df.ID) 

plot = ax.scatter(latent_points[:, 0], latent_points[:, 1], c = labels, cmap=cmap, s = 400, alpha = 0.5, edgecolors = 'black')

ax.set_xlabel('latent dimension 1')
ax.set_ylabel('latent dimension 2')
ax.ticklabel_format(scilimits = (-1, 1))
fig.tight_layout(pad = 1.0)
plt.show()