In [1]:
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Descriptors
from rdkit.Chem import Draw
import numpy as np
import matplotlib.pyplot as plt
from collections import deque

In [2]:
class SimpleEmbed:
    
    
    # Clean and optimize molecule
    @staticmethod
    def prepare_crystal(crystal):
        
        crystal_molecule = Chem.MolFromSmiles(crystal)
        crystal_molecule = Chem.AddHs(crystal_molecule)
        AllChem.EmbedMolecule(crystal_molecule)
        AllChem.UFFOptimizeMolecule(crystal_molecule)
        
        return crystal_molecule
        
        
    # Visualize the molecule
    @staticmethod
    def visualise_crystal(crystal_molecule):

        Draw.MolToImage(crystal_molecule)
        
        
    # Get a conformer
    @staticmethod
    def get_conformer(crystal_molecule):
        
        conformer = crystal_molecule.GetConformer()
        return conformer
        
        
    #get 3D and charge effect information
    @staticmethod
    def prepare_information(crystal_molecule):
        
        coulomb_matrix = AllChem.CalcCoulombMat(crystal_molecule)
        coulomb_matrix_array = np.array(coulomb_matrix)
        
        conformer = SimpleEmbed.get_conformer(crystal_molecule)
        atom_positions = conformer.GetPositions()
        
        info_type, info_3d, info_coulomb = dict(), dict(), dict()
        for i, pos in enumerate(atom_positions):
            atom_index = i + 1  # Atom indices in RDKit start from 1
            atom_type = crystal_molecule.GetAtomWithIdx(i).GetSymbol()
            coulomb_matrix_row = coulomb_matrix_array[i]

            info_type[i] = atom_type
            info_3d[i] = pos
            info_coulomb[i] = coulomb_matrix_row
        
        return info_type, info_3d, info_coulomb
    
    
    # Computes relative vectors
    @staticmethod
    def calculate_relative_vector(start, destination):
    
        if len(start) != 3 or len(destination) != 3:
            raise ValueError("Both start and destination must be 3D coordinates.")

        relative_vector = [destination[i] - start[i] for i in range(3)]
        return relative_vector
    
    
    # Converts molecule to matrix graph
    @staticmethod
    def generate_adjacency_matrix(crystal_molecule):
        
        num_atoms = crystal_molecule.GetNumAtoms()
        adjacency_matrix = [[0] * num_atoms for _ in range(num_atoms)]

        for bond in crystal_molecule.GetBonds():
            atom1_index = bond.GetBeginAtomIdx()
            atom2_index = bond.GetEndAtomIdx()

            adjacency_matrix[atom1_index-1][atom2_index-1] = 1
            adjacency_matrix[atom2_index-1][atom1_index-1] = 1  

        return adjacency_matrix
    
    
    # Breadth first traversal of molecule
    @staticmethod
    def bfs(adj_matrix, start_node):
    
        path = []
        visited = set()
        queue = deque([start_node])

        while queue:
            node = queue.popleft()
            if node not in visited:
            
                path.append(node)
                visited.add(node)
                neighbors = [i for i, value in enumerate(adj_matrix[node]) if value == 1 and i not in visited]
                queue.extend(neighbors)

        return path

    
    # Depth first traversal of molecule
    @staticmethod
    def dfs(adj_matrix, start_node, path, visited=None):

        if visited is None:
            visited = set()

        path.append(start_node)
        visited.add(start_node)

        neighbors = [i for i, value in enumerate(adj_matrix[start_node]) if value == 1 and i not in visited]
        for neighbor in neighbors:
            if neighbor not in visited:
                dfs(adj_matrix, neighbor, path, visited)

        return path
    
    
    # Finding cyclicity to for determine traversal starting points
    @staticmethod
    def find_all_cyclic_paths(adjacency_matrix):
        
        def dfs(node, path):
            visited[node] = True
            path.append(node)

            for neighbor in range(num_nodes):
                if adjacency_matrix[node][neighbor] == 1:
                    if not visited[neighbor]:
                        dfs(neighbor, path.copy())
                    elif neighbor == path[0]:
                        cyclic_paths.append(path.copy())

        num_nodes = len(adjacency_matrix)
        visited = [False] * num_nodes
        cyclic_paths = []

        for start_node in range(num_nodes):
            dfs(start_node, [])

        return cyclic_paths
    
    
    # 3D relative vector embedding
    @staticmethod
    def embed_3d(path, dict_3d, embed_size = 300):
    
        embedding = []
        eos_token = 0.00
        start = [0.00, 0.00, 0.00]

        for i in path:

            destination = dict_3d[i]
            relative_vector = SimpleEmbed.calculate_relative_vector(start, destination)
            embedding.extend(relative_vector)
            embedding.append(eos_token)
            start = destination

        padding = embed_size - len(embedding)
        l0 = [0 for i in range(padding)]
        embedding.extend(l0)

        return np.array(embedding)


    # Coulomb relative charge embedding
    @staticmethod
    def embed_coulomb(path, dict_coulomb, max_atom = 50):

        embed_size = max_atom**2

        embedding = []
        eos_token = -1.00

        for i in path:

            relative_charges = dict_coulomb[i]
            embedding.extend(relative_charges)
            embedding.append(eos_token)

        padding = embed_size - len(embedding)
        l0 = [0 for i in range(padding)]
        embedding.extend(l0)

        return np.array(embedding)


In [4]:
crystal = "CC(=O)OC1=CC=CC=C1C(=O)O"
crystal_molecule = SimpleEmbed.prepare_crystal(crystal)
info_type, info_3d, info_coulomb = SimpleEmbed.prepare_information(crystal_molecule)
crystal_molecule = SimpleEmbed.generate_adjacency_matrix(crystal_molecule)
crystal_path = SimpleEmbed.bfs(crystal_molecule, 0)
embedding_3d = SimpleEmbed.embed_3d(crystal_path, info_3d)
embedding_coulomb = SimpleEmbed.embed_coulomb(crystal_path, info_coulomb)

In [7]:
embedding_coulomb.tolist()

[36.85810519942594,
 24.150540840255488,
 20.172103970361704,
 19.23952563256662,
 9.43026391056734,
 8.159351555129081,
 6.202949725812838,
 5.480257746544277,
 5.805069044368527,
 7.315137365543111,
 7.04626932697485,
 11.054506960816889,
 7.422417876733587,
 5.406657830713434,
 5.402949229263167,
 5.404798776128775,
 1.47467090765742,
 0.9306919560187847,
 0.7853306935432824,
 0.8462298925985668,
 0.8729846682962361,
 -1.0,
 24.150540840255488,
 36.85810519942594,
 38.06932955712109,
 34.26528966414215,
 14.443473705127248,
 11.961640761825855,
 8.213330644305756,
 6.9782047771541915,
 7.361964015358443,
 9.70039214016131,
 8.566793005397107,
 12.801997262797538,
 8.590273538915163,
 2.8110351661749036,
 2.798844584669498,
 2.803569015590264,
 2.162119617372045,
 1.187075549273478,
 0.9647075829117993,
 1.0283489100092424,
 0.9737796023488487,
 -1.0,
 20.172103970361704,
 38.06932955712109,
 73.51669471981023,
 27.605921395620452,
 16.641515553149407,
 16.364675954877672,
 11.476505