In [2]:
from collections.abc import Iterator
from typing import Any, Union, Optional

from torch_geometric.data import Data, Batch

import numpy as np
from ase import Atoms
from ase.optimize import BFGS
from loguru import logger as logging
from rdkit import Chem


In [None]:
class BFGSBatched(BFGS):
    """BFGS optimiser with exit conditions for strain relief.

    Exit conditions:
    1. Maximum force on any atom > fexit (dynamics exploding).
    2. Number of steps exceeds max_steps.
    3. Forces have converged (max force < fmax).

    Attributes:
        fmax: Force convergence threshold, or ``None`` if unset.
        fexit: Force threshold above which the run is abandoned, or ``None``.
        max_steps: Step limit; ``-1`` means "no limit, run until converged".
        calculator: Must be assigned by the caller before ``run`` is called.
    """

    def __init__(self,
                 atoms: Atoms | Iterator[Atoms],
                 fmax: Optional[float] = None,
                 fexit: Optional[float] = None,
                 max_steps: int = -1,
                 **kwargs: Any) -> None:
        """Create the optimiser and validate the stopping criteria.

        Args:
            atoms: Structure(s) handed to the ``BFGS`` base initialiser.
            fmax: Force convergence threshold.
            fexit: Force threshold that aborts the run.
            max_steps: Maximum number of steps, ``-1`` for unlimited.
            **kwargs: Forwarded unchanged to ``BFGS.__init__``.

        Raises:
            ValueError: If neither ``fmax`` nor ``max_steps`` is set.
        """
        # NOTE(review): ASE's BFGS is understood to take ``fmax`` in ``run()``
        # rather than in its constructor, so it is kept on this instance
        # instead of being forwarded -- confirm against the installed ASE.
        super().__init__(atoms, **kwargs)

        # Assigned after the base initialiser so that any same-named
        # attributes it may set (e.g. ``fmax``, ``max_steps``) cannot
        # overwrite the values requested here.
        self.fmax = fmax
        self.fexit = fexit
        self.max_steps = max_steps
        self.calculator = None

        # This class is not a dataclass, so ``__post_init__`` is never invoked
        # automatically; call it explicitly to run the validation.
        self.__post_init__()

    def __post_init__(self):
        """Reject configurations that have no way of terminating.

        Raises:
            ValueError: If ``max_steps`` is ``-1`` and ``fmax`` is unset/zero.
        """
        if self.max_steps == -1 and not self.fmax:
            raise ValueError("Either fmax or max_steps must be set to define convergence.")

    def run(self, batch: Union[Data, Batch]) -> None:
        """Run the optimisation on ``batch`` (dynamics not implemented yet).

        Currently this only checks that a calculator is attached and calls
        ``_check_batch``.

        Args:
            batch: The PyG ``Data``/``Batch`` object to optimise.

        Raises:
            AttributeError: If ``self.calculator`` has not been set.
        """
        # TODO: if self.max_steps == -1, run until convergence

        if self.calculator is None:
            raise AttributeError("BFGSBatched.calculator must be set before running dynamics.")

        self._check_batch(batch)

    def _check_batch(self, batch):
        """Validate ``batch`` before running (placeholder, no checks yet)."""
        # TODO: check that the batch has the correct attributes
        pass


# NOTE(review): as written this cell raises TypeError -- BFGSBatched.__init__
# requires an `atoms` argument and BFGSBatched.run requires a `batch` argument.
dyn = BFGSBatched()
dyn.run()

In [None]:
class MACECalculator:
    """Placeholder calculator class; no attributes or methods defined yet.

    NOTE(review): presumably intended to wrap a MACE model -- TODO confirm.
    """

class FAIRChemCalculator:
    """Placeholder calculator class; no attributes or methods defined yet.

    NOTE(review): presumably intended to wrap a FAIRChem model -- TODO confirm.
    """

In [None]:
class Conformer(Data):
    """A single conformer stored as a ``torch_geometric`` ``Data`` object.

    Intended attributes, per the author's notes:
        .pos         [n_atoms, 3]
        .atom_types  [n_atoms]

    All conversion methods below are unimplemented stubs that currently
    return ``None``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``Data.__init__``."""
        super().__init__(*args, **kwargs)

    @classmethod
    def from_ase(cls, atoms: Atoms) -> "Conformer":
        """Build a ``Conformer`` from an ASE ``Atoms`` object (not implemented)."""
        pass  # TODO: implement; currently returns None

    @classmethod
    def from_rdkit(cls, mol: Chem.Mol) -> "Conformer":
        """Build a ``Conformer`` from an RDKit molecule (not implemented)."""
        pass  # TODO: implement; currently returns None

    def to_ase(self) -> Atoms:
        """Convert this conformer to an ASE ``Atoms`` object (not implemented)."""
        pass  # TODO: implement; currently returns None

    def to_rdkit(self) -> Chem.Mol:
        """Convert this conformer to an RDKit molecule (not implemented)."""
        pass  # TODO: implement; currently returns None


class ConformerBatch(Batch):
    """A ``torch_geometric`` ``Batch`` holding several ``Conformer`` objects.

    Intended attributes, per the author's notes:
        batch.batch          [n_atoms]           which conformer each atom belongs to
        batch.ptr            [n_conformers + 1]  start index of each conformer in
                                                 .pos and .atom_types
        batch.molecule_idxs  [n_atoms]           which molecule each atom belongs to
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``Batch.__init__``."""
        super().__init__(*args, **kwargs)

    @classmethod
    def from_rdkit(cls, mols: list[Chem.Mol]|Chem.Mol) -> "ConformerBatch":
        """Collect every conformer of the given molecule(s) into one batch.

        Args:
            mols: A single RDKit molecule or a list of them.

        Returns:
            The result of ``cls.from_data_list`` over one ``Conformer`` per
            RDKit conformer, in molecule order then conformer order.
        """
        # Accept a bare molecule as shorthand for a one-element list.
        mol_list = [mols] if isinstance(mols, Chem.Mol) else mols

        # NOTE(review): each item from GetConformers() is passed to
        # Conformer.from_rdkit, whose parameter is annotated as Chem.Mol --
        # confirm which of the two is intended.
        data_list = [
            Conformer.from_rdkit(rd_conformer)
            for mol in mol_list
            for rd_conformer in mol.GetConformers()
        ]
        return cls.from_data_list(data_list)

    def to_rdkit(self) -> list[Chem.Mol]:
        """Convert the batch back to RDKit molecules (not implemented).

        Open design question from the author: emit one molecule per
        conformer, or group multiple conformers per molecule?
        """
        pass  # TODO: implement; currently returns None

In [None]:
# example with rdkit molecules
# run with fmax

# example with graphein.Protein objects
# run for n steps

# example with ase.Atoms objects
# run with fmax and n steps and fexit

