In [None]:
from __future__ import annotations

import json
import tempfile
from pathlib import Path
from typing import Literal

import numpy as np  # noqa: F401 (imported for future extensions)
from ase.calculators.calculator import all_changes
from ase.io import read
from ase.optimize import BFGS
from rdkit import Chem
from rdkit.Chem import AllChem

from cct.energy.consts import ANGSTROM_TO_BOHR, EH2KCALMOL
from cct.utils import run_command


def xtb_singlepoint(
    xyz_file: Path,
    charge: int = 0,
    mult: int = 1,
    method: Literal["gfn2"] = "gfn2",
    solvent: Literal["water"] | None = None,
) -> float:
    """Run a ``tblite run`` single-point calculation and return its energy.

    The command is executed through ``run_command`` with a throw-away
    temporary directory as ``cwd``; the energy is read back from the JSON
    file requested via ``--json``.

    Args:
        xyz_file: Geometry file handed to ``tblite run``. A plain ``str`` is
            accepted as well; the path is made absolute before use.
        charge: Value passed to ``--charge``.
        mult: Spin multiplicity; ``mult - 1`` is passed to ``--spin``.
        method: Value passed to ``--method``.
        solvent: When truthy, passed to ``--alpb``.

    Returns:
        The ``"energy"`` entry of the JSON output multiplied by
        ``EH2KCALMOL`` (presumably Hartree -> kcal/mol, going by the
        constant's name).

    Raises:
        KeyError: If the JSON output has no ``"energy"`` entry.
    """
    # The command is launched with the scratch directory as ``cwd`` (assuming
    # ``run_command`` treats ``cwd`` as the process working directory), so a
    # relative input path has to be anchored to the caller's directory first.
    xyz_path = Path(xyz_file).resolve()

    with tempfile.TemporaryDirectory() as tmpdir:
        scratch = Path(tmpdir)
        json_file = scratch / "out.json"

        cmd = [
            "tblite",
            "run",
            "--method",
            str(method),
            "--charge",
            str(charge),
            "--spin",
            str(mult - 1),
            "--json",
            str(json_file),
        ]
        if solvent:
            cmd += ["--alpb", solvent]
        cmd.append(str(xyz_path))

        run_command(cmd, cwd=scratch)

        # Parse the result before the scratch directory is removed.
        with open(json_file) as f:
            output = json.load(f)

    return output["energy"] * EH2KCALMOL


def crest_cregen(
    xyz_in: Path,
    charge: int = 0,
    mult: int = 1,
    threads: int = 0,
    workdir: Path | None = None,
) -> Path:
    """Run ``crest --cregen`` on ``xyz_in`` and return the ensemble path.

    Args:
        xyz_in: Input file passed right after ``--cregen``.
        charge: Value passed to ``--chrg``.
        mult: Spin multiplicity; ``mult - 1`` is passed to ``--uhf``.
        threads: When non-zero, passed to ``-T`` (same convention as
            ``crest_opt``).
        workdir: Directory the command runs in; defaults to the parent
            directory of ``xyz_in``.

    Returns:
        ``<run directory>/crest_ensemble.xyz``. The path is returned without
        checking that the file exists.

    Note:
        ``RMSD_THRESHOLD``, ``ROTCONST_FRAC`` and ``ENERGY_WINDOW`` are
        module-level names that are not defined in the visible part of this
        file; they must exist when the function is called.
    """
    xyz_in = Path(xyz_in)
    run_dir = workdir or xyz_in.parent
    cmd = [
        "crest",
        "--cregen",
        # Absolute path: the command runs with ``run_dir`` as cwd, so a
        # relative input would otherwise be looked up in the wrong place.
        str(xyz_in.resolve()),
        "--chrg",
        str(charge),
        "--uhf",
        str(mult - 1),
        "--rthr",
        str(RMSD_THRESHOLD),
        "--rot",
        str(ROTCONST_FRAC),
        "--ethr",
        str(ENERGY_WINDOW),
    ]
    if threads:
        cmd += ["-T", str(threads)]

    # Run in ``run_dir``; the previous ``cwd=tmpdir`` referenced a name that
    # does not exist in this function.
    run_command(cmd, cwd=run_dir)

    return run_dir / "crest_ensemble.xyz"  # CREST writes here by default


def crest_opt(
    xyz_in: Path,
    charge: int = 0,
    mult: int = 1,
    threads: int = 0,
    workdir: Path | None = None,
) -> Path:
    """Run ``crest --mdopt`` on ``xyz_in`` and return the ensemble path.

    Args:
        xyz_in: Input file passed right after ``--mdopt``.
        charge: Value passed to ``--chrg``.
        mult: Spin multiplicity; ``mult - 1`` is passed to ``--uhf``.
        threads: When non-zero, passed to ``-T``.
        workdir: Directory the command runs in; defaults to the parent
            directory of ``xyz_in``.

    Returns:
        ``<run directory>/crest_ensemble.xyz``. The path is returned without
        checking that the file exists.
    """
    xyz_in = Path(xyz_in)
    run_dir = workdir or xyz_in.parent
    cmd = [
        "crest",
        "--mdopt",
        # Absolute path: the command runs with ``run_dir`` as cwd, so a
        # relative input would otherwise be looked up in the wrong place.
        str(xyz_in.resolve()),
        "--chrg",
        str(charge),
        "--uhf",
        str(mult - 1),
    ]
    if threads:
        cmd += ["-T", str(threads)]

    # Run in ``run_dir``; the previous ``cwd=tmpdir`` referenced a name that
    # does not exist in this function.
    run_command(cmd, cwd=run_dir)

    return run_dir / "crest_ensemble.xyz"  # CREST writes here by default


def filter_by_energy(mol, energy_window=1.0, charge=0, mult=1):
    """Return a copy of ``mol`` holding only its low-energy conformers.

    Each conformer is written to a temporary XYZ file and scored with
    ``xtb_singlepoint``. Conformers whose energy relative to the lowest one
    is strictly below ``energy_window`` are kept.

    Args:
        mol: RDKit molecule with zero or more conformers.
        energy_window: Upper bound (exclusive) on the relative energy, in the
            units returned by ``xtb_singlepoint``.
        charge: Forwarded to ``xtb_singlepoint``.
        mult: Forwarded to ``xtb_singlepoint``.

    Returns:
        A copy of ``mol`` containing only the retained conformers, with IDs
        reassigned by ``AddConformer(..., assignId=True)``. If ``mol`` has no
        conformers the copy is returned without any.
    """
    filtered_mol = Chem.Mol(mol)  # Copy the molecule
    filtered_mol.RemoveAllConformers()

    # Iterate the conformers themselves rather than range(n): conformer IDs
    # are not guaranteed to be 0..n-1.
    conformers = list(mol.GetConformers())
    if not conformers:
        # Nothing to score; np.min() on an empty array would raise.
        return filtered_mol

    with tempfile.TemporaryDirectory() as tmpdir:
        raw_xyz = Path(tmpdir) / "conf.xyz"

        energies = []
        for conf in conformers:
            # The RDKit writer takes a string filename, not a Path.
            Chem.MolToXYZFile(mol, str(raw_xyz), confId=conf.GetId())
            energies.append(xtb_singlepoint(raw_xyz, charge=charge, mult=mult))

    energies = np.asarray(energies)
    relative_energies = energies - energies.min()

    # Add back only the low-energy conformers, preserving their order.
    for conf, rel_energy in zip(conformers, relative_energies):
        if rel_energy < energy_window:
            filtered_mol.AddConformer(conf, assignId=True)

    return filtered_mol


def generate_conformers(
    mol, charge=0, mult=1, energy_window=1.0, num_confs=300, threads=10
):
    """Generate, prune and refine conformers; return them sorted by energy.

    Steps, in order:
      1. Add hydrogens and embed ``num_confs`` conformers via ``embed``.
      2. Discard high-energy conformers with ``filter_by_energy``.
      3. Write the survivors to a multi-frame XYZ file and pass it through
         ``crest_opt`` and then ``crest_cregen`` in a temporary directory.
      4. Relax every remaining frame with ASE ``BFGS`` (``fmax=0.05``) and
         record its energy from ``calculator``.

    Args:
        mol: RDKit molecule (hydrogens are added here).
        charge: Forwarded to the CREST helpers.
        mult: Forwarded to the CREST helpers.
        energy_window: Forwarded to ``filter_by_energy``.
        num_confs: Number of conformers requested from ``embed``.
        threads: Forwarded to the CREST helpers.

    Returns:
        List of molecules produced by ``update_confs``, one per frame,
        ordered by increasing energy.
    """
    mol = Chem.AddHs(mol)
    # Honour the caller's ``num_confs``; it used to be overridden by a
    # hard-coded 100.
    mol = embed(mol, num_confs=num_confs)
    # NOTE(review): ``charge``/``mult`` are not forwarded to this pre-screen,
    # so it scores conformers with xtb_singlepoint's defaults — confirm that
    # is intended for charged or open-shell species.
    mol = filter_by_energy(mol, energy_window=energy_window)

    scored = []
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)
        raw_xyz = tmpdir / "ensemble.xyz"
        write_multi_xyz(mol, raw_xyz)
        kept_xyz = crest_opt(raw_xyz, charge, mult, threads=threads, workdir=tmpdir)
        kept_xyz = crest_cregen(kept_xyz, charge, mult, threads=threads, workdir=tmpdir)

        frames = read(kept_xyz, index=":")
        for atoms in frames:
            # NOTE(review): ``calculator`` is not defined in the visible part
            # of this file; it must exist as a module-level ASE calculator.
            atoms.calc = calculator
            dyn = BFGS(atoms)
            dyn.run(fmax=0.05)

            calculator.calculate(atoms, ["energy"], all_changes)
            energy = calculator.results["energy"]

            new_mol = update_confs(mol, atoms.get_positions())
            scored.append((energy, new_mol))

    # Stable sort on the energy only, so molecule objects are never compared
    # and never pushed through a numpy object array.
    scored.sort(key=lambda pair: pair[0])

    return [new_mol for _, new_mol in scored]

In [None]:
from cct.utils import run_command

# ALPB(water) vs. gas-phase single-point energy of one embedded water molecule,
# computed through the tblite command line via xtb_singlepoint.
# The former `calcs = CalculatorFactory()` line was dropped from this cell:
# CalculatorFactory is only imported in a later cell, so it failed on a fresh
# kernel, and `calcs` is not used here.
mol = Chem.MolFromSmiles("O")
mol = Chem.AddHs(mol)

print(Chem.MolToSmiles(mol))
# Fixed seed so the embedded geometry (and the energies) are reproducible.
AllChem.EmbedMolecule(mol, randomSeed=42)

Chem.MolToXYZFile(mol, "/tmp/mol.xyz")
xtb_singlepoint("/tmp/mol.xyz") - xtb_singlepoint("/tmp/mol.xyz", solvent="water")

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem

from cct.energy.energy import CalculatorFactory

# Single-point and optimisation through CalculatorFactory on one water molecule.
calcs = CalculatorFactory()
mol = Chem.MolFromSmiles("O")
# Add explicit hydrogens before embedding, as the previous cell does; without
# them the embedded geometry holds the oxygen atom only.
mol = Chem.AddHs(mol)
# Fixed seed so the embedded geometry is reproducible.
AllChem.EmbedMolecule(mol, randomSeed=42)

print(calcs.singlepoint(mol))
print(calcs.optimise(mol))

In [None]:
import tblite
import tblite.ase  # submodules are not loaded by `import tblite` alone
import tblite.interface as tb

# Displays the module repr, confirming the ASE bindings are importable.
tblite.ase

In [None]:
# Same solvation comparison through the tblite Python API: energy of one
# geometry before and after adding the "cpcm-solvation" interaction (78.4).
# NOTE(review): no visible cell writes /tmp/conf.xyz — confirm where this
# file comes from before relying on Restart & Run All.
def _energy_scaled(calc):
    """Run a single point on ``calc`` and scale its energy by EH2KCALMOL."""
    return calc.singlepoint().get("energy") * EH2KCALMOL


atoms = read("/tmp/conf.xyz")
atomic_numbers = atoms.get_atomic_numbers()
# Positions are scaled by ANGSTROM_TO_BOHR before being handed to tblite.
positions = atoms.get_positions() * ANGSTROM_TO_BOHR
xtb = tb.Calculator("GFN2-xTB", atomic_numbers, positions)

# Energy before the solvation interaction is added.
g = _energy_scaled(xtb)

# Energy after adding it to the same calculator.
xtb.add("cpcm-solvation", 78.4)
s = _energy_scaled(xtb)

s - g

In [None]:
# Print /tmp/crest_ensemble.xyz for manual inspection.
# NOTE(review): no visible cell writes this path (generate_conformers runs
# CREST inside a temporary directory) — confirm where the file comes from.
!cat /tmp/crest_ensemble.xyz

In [None]:
# Show the help text of `tblite run`, the command that xtb_singlepoint builds.
!tblite run --help

In [None]:
# NOTE(review): `relative_energies` is only assigned as a local inside
# filter_by_energy in the visible code, so this cell presumably raises
# NameError on a fresh kernel — confirm whether it is leftover debugging.
relative_energies.max()