In [26]:
import h5py
import pandas as pd
import numpy as np

In [23]:

def parse_database_file_into_data_frame(path_to_db):
    """Flatten an HDF5 molecule database into one DataFrame row per conformation.

    Every top-level key of the file is treated as a molecule group. A group may
    hold a ``smiles`` dataset (its first element is used) and a ``conformations``
    dataset that is unpacked as ``(num_conformations, num_atoms, <coords>)``.

    Parameters
    ----------
    path_to_db : str or path-like
        Location of the HDF5 file; opened read-only.

    Returns
    -------
    pandas.DataFrame
        Columns ``id``, ``smiles``, ``conformation_index``, ``num_atoms`` and
        ``positions``. A molecule without a ``conformations`` dataset yields a
        single row whose last three columns are ``None``.
    """
    # Initialize an empty list to collect data
    molecule_data = []

    print(path_to_db)

    # Open the HDF5 file
    with h5py.File(path_to_db, "r") as f:
        print("Loading file")
        print(f.keys())
        for molecule_id in f.keys():
            molecule_group = f[molecule_id]

            # Retrieve SMILES string
            smiles = molecule_group['smiles'][0] if 'smiles' in molecule_group else None
            # h5py >= 3 hands variable-length strings back as bytes; normalise to
            # str so downstream consumers always see text.
            if isinstance(smiles, bytes):
                smiles = smiles.decode("utf-8")

            # Retrieve conformations (if they exist)
            if 'conformations' in molecule_group:
                conformations = molecule_group['conformations'][:]
                num_conformations, num_atoms, _ = conformations.shape

                # Add each conformation separately
                for i in range(num_conformations):
                    molecule_data.append({
                        'id': molecule_id,
                        'smiles': smiles,
                        'conformation_index': i,
                        'num_atoms': num_atoms,
                        'positions': conformations[i]
                    })
            else:
                # If there are no conformations, add molecule data without positions
                molecule_data.append({
                    'id': molecule_id,
                    'smiles': smiles,
                    'conformation_index': None,
                    'num_atoms': None,
                    'positions': None
                })

    return pd.DataFrame(molecule_data)


In [24]:
# Absolute location of the HDF5 shard that this notebook reads
hdf5_path = "/home/sebidom/dom/manifold_contgfn/spice-dataset/pubchem/pubchem-1-2500.hdf5"

# Parse the shard into a flat DataFrame (one row per conformation)
df = parse_database_file_into_data_frame(hdf5_path)

# Header line followed by a preview of the leading rows
print("Data loaded into DataFrame:", df.head(), sep="\n")

/home/sebidom/dom/manifold_contgfn/spice-dataset/pubchem/pubchem-1-2500.hdf5
Loading file
KeysView(<HDF5 file "pubchem-1-2500.hdf5" (mode r)>)
Data loaded into DataFrame:
          id                                             smiles  \
0  103914790  [N:1]1=[C:2]2[N:3]([C:5]([H:17])([H:18])[C:4]1...   
1  103914790  [N:1]1=[C:2]2[N:3]([C:5]([H:17])([H:18])[C:4]1...   
2  103914790  [N:1]1=[C:2]2[N:3]([C:5]([H:17])([H:18])[C:4]1...   
3  103914790  [N:1]1=[C:2]2[N:3]([C:5]([H:17])([H:18])[C:4]1...   
4  103914790  [N:1]1=[C:2]2[N:3]([C:5]([H:17])([H:18])[C:4]1...   

   conformation_index  num_atoms  \
0                   0         32   
1                   1         32   
2                   2         32   
3                   3         32   
4                   4         32   

                                           positions  
0  [[-0.058113288, 0.26829422, 0.07045778], [-0.0...  
1  [[-0.1205943, -0.17438053, 0.18773785], [-0.05...  
2  [[0.11087129, 0.0039972714, 0.25160453], 

In [25]:
# Lift the column-width cap only for this cell so long values are not truncated
with pd.option_context('display.max_colwidth', None):
    # Heading (preceded by a blank line) and then the complete first row
    print("\nFirst row in full:", df.iloc[0], sep="\n")


First row in full:
id                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  

In [29]:
def create_conformer_dataset_from_data_frame(df, output_path="smiles_strings.npy"):
    """Save the distinct SMILES strings of ``df`` to a NumPy ``.npy`` file.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``smiles`` column. Missing values (rows for molecules
        that had no SMILES) are ignored.
    output_path : str or path-like, optional
        Destination file. Defaults to ``"smiles_strings.npy"`` in the current
        working directory, which is where earlier versions always wrote.

    Returns
    -------
    numpy.ndarray
        1-D unicode array of the unique SMILES strings in order of first
        appearance; the same array that was written to disk.
    """
    # Drop nulls first: converting None to text would store the literal "None".
    unique_smiles = df['smiles'].dropna().unique()

    # Store a fixed-width unicode array rather than an object array, so the file
    # can be read back with a plain np.load() (no allow_pickle=True needed).
    smiles_strings = np.array(
        [s.decode("utf-8") if isinstance(s, bytes) else str(s) for s in unique_smiles],
        dtype=str,
    )

    # Save as a .npy file for later use
    np.save(output_path, smiles_strings)
    return smiles_strings

# Write the unique SMILES strings found in `df` to smiles_strings.npy (current working directory)
create_conformer_dataset_from_data_frame(df)


In [4]:
class ConformerDataset:
    """Collection of molecule records loaded from a pickle file.

    Attributes
    ----------
    conformers
        Whatever object the pickle file contained; it must support ``len()``
        and integer indexing for :meth:`get_random_molecule` to work.
    """

    def __init__(self, data_pkl_path: str):
        """Load the pickled collection stored at ``data_pkl_path``."""
        # Imported here so the class does not rely on another cell having
        # imported pickle beforehand.
        import pickle

        # SECURITY: unpickling can execute arbitrary code. Only point this at
        # files you created yourself or otherwise trust.
        with open(data_pkl_path, "rb") as handle:
            self.conformers = pickle.load(handle)

    def get_random_molecule(self):
        """Return one record picked uniformly at random.

        Sampling is done on the index rather than on the records themselves, so
        the stored object is returned unchanged and records that are themselves
        sequences (tuples, lists) are handled correctly.

        Raises
        ------
        ValueError
            If the collection is empty.
        """
        n_records = len(self.conformers)
        if n_records == 0:
            raise ValueError("ConformerDataset is empty; cannot sample a molecule")
        return self.conformers[np.random.randint(n_records)]

class Molecule:
    """Plain record describing one molecule.

    The three constructor arguments are stored unchanged as attributes of the
    same name: ``smiles``, ``n_torsion_angles`` and ``torsion_indices``.
    """

    def __init__(self, smiles, n_torsion_angles, torsion_indices):
        # Straight pass-through: no validation or copying is performed.
        self.smiles, self.n_torsion_angles, self.torsion_indices = (
            smiles,
            n_torsion_angles,
            torsion_indices,
        )

In [27]:
import copy
from typing import List, Optional, Tuple, Union
import pickle

import dgl
import numpy as np
import numpy.typing as npt
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from torchtyping import TensorType

from gflownet.envs.ctorus import ContinuousTorus
from gflownet.utils.molecule.constants import ad_atom_types
from gflownet.utils.molecule.featurizer import MolDGLFeaturizer
from gflownet.utils.molecule.rdkit_conformer import RDKitConformer
from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smiles

class Conformers(ContinuousTorus):
    """
    Extension of continuous torus to conformer generation. Based on AlanineDipeptide,
    but accepts any molecule (defined by SMILES and freely rotatable torsion angles).

    The molecule is drawn from a ``ConformerDataset``; the environment keeps an
    RDKit-backed conformer (``self.conformer``) and a DGL graph (``self.graph``)
    for the currently loaded molecule.
    """

    def __init__(
        self,
        conformer_dataset: ConformerDataset,
        remove_hs: bool = True,
        **kwargs,
    ):
        """
        Args
        ----
        conformer_dataset : ConformerDataset
            Source of molecules; one is sampled at construction time and the
            dataset is kept for later calls to ``reset_with_random_molecule``.
        remove_hs : bool
            If True, hydrogen nodes are removed from ``self.graph`` and from the
            positions returned by ``statebatch2policy_gnn``.
        **kwargs
            Forwarded unchanged to ``ContinuousTorus.__init__``.
        """
        # Kept so that reset_with_random_molecule() can draw a new molecule.
        self.conformer_dataset = conformer_dataset
        # Needed by _build_graph(), so it has to be set before a molecule is loaded.
        self.remove_hs = remove_hs

        self._load_molecule(conformer_dataset.get_random_molecule())

        # Conversions
        self.statebatch2oracle = self.statebatch2proxy
        self.statetorch2oracle = self.statetorch2proxy

        super().__init__(n_dim=len(self.conformer.freely_rotatable_tas), **kwargs)

        self.sync_conformer_with_state()

    def reset_with_random_molecule(self):
        """Sample a new molecule from the dataset and load it into the environment."""
        mol = self.conformer_dataset.get_random_molecule()
        self.reset_conformer(mol)

    def reset_conformer(self, molecule: Molecule):
        """
        Replace the current molecule with ``molecule`` and rebuild the conformer
        and graph.

        NOTE(review): the parent class is not re-initialised here, so ``n_dim``
        and ``self.state`` still reflect the molecule the environment was
        constructed with. If the new molecule has more freely rotatable torsion
        angles than the state has entries, the final sync will fail.
        """
        self._load_molecule(molecule)

        # TODO: Initialise the parent class

        self.sync_conformer_with_state()

    def _load_molecule(self, molecule: Molecule) -> None:
        """Set all molecule-dependent attributes (SMILES, positions, torsions, conformer, graph)."""
        self.smiles = molecule.smiles
        # NOTE(review): molecule.torsion_indices is not used; the first
        # n_torsion_angles rotors found in the SMILES are selected instead.
        self.torsion_indices = list(range(molecule.n_torsion_angles))
        self.atom_positions = Conformers._get_positions(self.smiles)
        self.torsion_angles = Conformers._get_torsion_angles(
            self.smiles, self.torsion_indices
        )
        self.set_conformer()
        self._build_graph()

    def _build_graph(self) -> None:
        """Build the DGL graph for the current conformer, mask non-rotatable edges, handle hydrogens."""
        self.graph = MolDGLFeaturizer(ad_atom_types).mol2dgl(self.conformer.rdk_mol)
        # TODO: use DGL conformer instead
        # The two middle atoms of each torsion angle define the rotatable bond.
        # Normalised to tuples so the membership test below works whether the
        # torsion angles are stored as tuples or lists.
        rotatable_edges = {tuple(ta[1:3]) for ta in self.torsion_angles}
        # Fetch the edge endpoints once instead of on every iteration.
        edge_src, edge_dst = self.graph.edges()
        for i in range(self.graph.num_edges()):
            if (edge_src[i].item(), edge_dst[i].item()) not in rotatable_edges:
                self.graph.edata["rotatable_edges"][i] = False

        # Hydrogen removal
        # presumably column 0 of "atom_features" flags hydrogens — confirm against MolDGLFeaturizer
        self.hs = torch.where(self.graph.ndata["atom_features"][:, 0] == 1)[0]
        self.non_hs = torch.where(self.graph.ndata["atom_features"][:, 0] != 1)[0]
        if self.remove_hs:
            self.graph = dgl.remove_nodes(self.graph, self.hs)

    def set_conformer(self, state: Optional[List] = None) -> RDKitConformer:
        """
        Create a fresh ``RDKitConformer`` from the stored positions, SMILES and
        torsion angles, optionally applying ``state`` to it, and return it.
        """
        self.conformer = RDKitConformer(
            self.atom_positions, self.smiles, self.torsion_angles
        )

        if state is not None:
            self.sync_conformer_with_state(state)

        return self.conformer

    @staticmethod
    def _get_positions(smiles: str) -> npt.NDArray:
        """
        Embed the molecule (with explicit hydrogens, fixed random seed 0) and
        return the atom positions of the resulting conformer.

        Raises
        ------
        ValueError
            If RDKit cannot parse ``smiles`` or fails to embed the molecule.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
        mol = Chem.AddHs(mol)
        # EmbedMolecule returns the new conformer id, or -1 when embedding fails.
        if AllChem.EmbedMolecule(mol, randomSeed=0) < 0:
            raise ValueError(f"RDKit could not embed a 3D conformer for: {smiles!r}")
        return mol.GetConformer().GetPositions()

    @staticmethod
    def _get_torsion_angles(
        smiles: str, indices: Optional[List[int]]
    ) -> List[Tuple[int]]:
        """
        Return the rotors found in ``smiles``; if ``indices`` is given, only the
        rotors at those positions (in that order) are returned.
        """
        torsion_angles = find_rotor_from_smiles(smiles)
        if indices is not None:
            torsion_angles = [torsion_angles[i] for i in indices]
        return torsion_angles

    def sync_conformer_with_state(self, state: Optional[List] = None):
        """
        Set every freely rotatable torsion angle of the conformer to the
        corresponding entry of ``state`` (``self.state`` when omitted) and
        return the conformer.
        """
        if state is None:
            state = self.state
        for idx, ta in enumerate(self.conformer.freely_rotatable_tas):
            self.conformer.set_torsion_angle(ta, state[idx])
        return self.conformer

    def statebatch2proxy(self, states: List[List]) -> npt.NDArray:
        """
        Returns a list of proxy states, each being a numpy array with dimensionality
        (n_atoms, 4), in which the first column encodes atomic number, and the last
        three columns encode atom positions.
        """
        states_proxy = []
        for st in states:
            conf = self.sync_conformer_with_state(st)
            states_proxy.append(
                np.concatenate(
                    [
                        conf.get_atomic_numbers()[..., np.newaxis],
                        conf.get_atom_positions(),
                    ],
                    axis=1,
                )
            )
        return np.array(states_proxy)

    def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray:
        """Tensor variant of ``statebatch2proxy``: moves ``states`` to CPU/numpy first."""
        return self.statebatch2proxy(states.cpu().numpy())

    def statebatch2policy_gnn(self, states: List[List]) -> npt.NDArray[np.float32]:
        """
        Returns an array of GNN-format policy inputs with dimensionality
        (n_states, n_atoms, 4), in which the first three columns encode atom positions,
        and the last column encodes current timestep.
        """
        policy_input = []
        for state in states:
            conformer = self.sync_conformer_with_state(state)
            positions = conformer.get_atom_positions()
            if self.remove_hs:
                positions = positions[self.non_hs]
            policy_input.append(
                np.concatenate(
                    # The last state entry is broadcast to every atom as a 4th column.
                    [positions, np.full((positions.shape[0], 1), state[-1])],
                    axis=1,
                )
            )
        return np.array(policy_input)

    def statebatch2kde(self, states: List[List]) -> npt.NDArray[np.float32]:
        """Return ``states`` as an array with the last column dropped."""
        return np.array(states)[:, :-1]

    def statetorch2kde(
        self, states: TensorType["batch_size", "state_dim"]
    ) -> npt.NDArray:
        """Tensor variant of ``statebatch2kde``; the result is a numpy array, not a tensor."""
        return states.cpu().numpy()[:, :-1]

    def __deepcopy__(self, memo):
        """
        Copy the environment with a shallow copy of every attribute, except
        ``conformer`` which is shared by reference between original and copy.
        """
        cls = self.__class__
        new_instance = cls.__new__(cls)

        for attr_name, attr_value in self.__dict__.items():
            if attr_name != "conformer":
                setattr(new_instance, attr_name, copy.copy(attr_value))

        new_instance.conformer = self.conformer

        return new_instance


ModuleNotFoundError: No module named 'dgl'