[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lowdanie/hartree-fock-solver/blob/main/notebooks/geometry_optimization.ipynb)

# Molecular Geometry Optimization

This notebook demonstrates differentiable quantum chemistry using
[slaterform](https://github.com/lowdanie/hartree-fock-solver).

We optimize the geometry of various molecules by differentiating through `slaterform`'s Hartree-Fock SCF solver using `jax`.

An animation of the electronic density trajectory is generated at the end.

## System Setup

Run the following cells to initialize the optimization loop controller, visualization utilities and molecule definitions.

In [None]:
# @title Pip Installs

!pip install -qq py3Dmol
!pip install -qq git+https://github.com/lowdanie/hartree-fock-solver


In [None]:
# @title Imports { display-mode: "form" }

import jax

jax.config.update("jax_enable_x64", True)

import dataclasses
import io
import time
from typing import Callable, NamedTuple
from collections.abc import Sequence

import jax.numpy as jnp
import numpy as np
import optax

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML

import py3Dmol

import slaterform as sf
import slaterform.hartree_fock.scf as scf

In [None]:
# Report the first device JAX sees, so it is clear which backend the notebook runs on.
print(f"JAX Backend: {jax.devices()[0]}")

In [None]:
# @title Optimization Loop { display-mode: "form" }


class ElectronicState(NamedTuple):
    """The converged electronic state for a given geometry.

    Returned as the auxiliary output of `total_energy`, which fills both
    fields from the result of `scf.solve`.
    """

    # The `density` field of the SCF result.
    density: jax.Array
    # The `basis.basis_blocks` field of the SCF result; needed together with
    # `density` to evaluate the electron density on a grid.
    basis_blocks: Sequence[sf.BasisBlock]


class Snapshot(NamedTuple):
    """Snapshot of one optimization step.

    Recorded by `GeometryOptimizer` inside its update step.
    """

    # Atomic positions at which the energy and gradient were evaluated
    # (i.e. the positions *before* this step's update was applied).
    params: jax.Array
    # Total energy at `params`.
    energy: jax.Array
    # Gradient of the energy with respect to `params`, before the
    # fixed-atom mask is applied.
    grad: jax.Array
    # Electronic state returned alongside the energy.
    electronic_state: ElectronicState


def update_molecule_geometry(
    template_mol: sf.Molecule, new_positions: jax.Array
) -> sf.Molecule:
    """Return a copy of `template_mol` with its atoms moved to `new_positions`.

    Symbol, atomic number and shells are taken from each template atom
    unchanged; only the position is replaced, using row `i` of
    `new_positions` for the i-th atom.
    """

    def _moved(index, source_atom):
        # Rebuild the atom field by field so only the position differs.
        return sf.Atom(
            symbol=source_atom.symbol,
            number=source_atom.number,
            position=new_positions[index],
            shells=source_atom.shells,
        )

    relocated = [
        _moved(index, source_atom)
        for index, source_atom in enumerate(template_mol.atoms)
    ]
    return sf.Molecule(relocated)


class GeometryOptimizer:
    """Gradient-based geometry optimization loop.

    The atomic positions are the optimization parameters. Each step evaluates
    ``energy_and_grad_fn`` at the current positions, zeroes the gradient rows
    of the fixed atoms and applies an ``optax.adam`` update.

    Args:
        energy_and_grad_fn: Called as ``energy_and_grad_fn(params, template_mol)``
            and expected to return ``((energy, electronic_state), grad)``.
        start_mol: Molecule supplying the initial positions. It is also kept
            as the template handed to ``energy_and_grad_fn`` on every step.
        fixed_atomic_indices: Indices of the atoms that must not move.
        lr: Learning rate forwarded to ``optax.adam`` (a float or a schedule).
    """

    def __init__(
        self,
        energy_and_grad_fn: Callable,
        start_mol: sf.Molecule,
        fixed_atomic_indices: Sequence[int] = (),
        lr: float | Callable = 0.05,
    ):
        self.template_mol = start_mol
        self.initial_params = jnp.array([a.position for a in start_mol.atoms])
        self.params = self.initial_params

        # One mask row per atom; it broadcasts over the position components.
        # Rows of fixed atoms are set to zero so their gradient is discarded.
        self.mask = jnp.ones((len(start_mol.atoms), 1))
        # Materialize as a list so any index container (tuple, array, ...)
        # can be tested for emptiness without ambiguous truthiness.
        fixed = list(fixed_atomic_indices)
        if fixed:
            self.mask = self.mask.at[jnp.array(fixed)].set(0.0)

        self.optimizer = optax.adam(learning_rate=lr)
        self.opt_state = self.optimizer.init(self.params)

        @jax.jit
        def update_step(params, opt_state):
            """Performs one optimizer step and records a snapshot of it."""
            (E, electronic_state), grad = energy_and_grad_fn(
                params, self.template_mol
            )

            # Only the masked gradient reaches the optimizer.
            updates, opt_state = self.optimizer.update(
                grad * self.mask, opt_state
            )
            new_params = optax.apply_updates(params, updates)

            # The snapshot describes the point that was evaluated: the
            # positions before the update and the unmasked gradient.
            snapshot = Snapshot(
                params=params,
                energy=E,
                grad=grad,
                electronic_state=electronic_state,
            )

            return new_params, opt_state, snapshot

        self.update_step = update_step

    def warmup(self):
        """Triggers JIT compilation.

        Calls the update step once and discards the result, so the stored
        parameters and optimizer state are left untouched.
        """
        print("Warming up JAX kernel (this may take a few minutes)...")
        _ = self.update_step(self.params, self.opt_state)
        print("Warmup complete.")

    def run(self, steps=50, callback: Callable | None = None):
        """Runs `steps` optimization steps from the current state.

        The parameters and optimizer state persist on the instance, so a
        second call continues where the previous one stopped; call `reset`
        first to start again from the initial geometry.

        Args:
            steps: Number of update steps to perform.
            callback: Optional callable invoked after every step as
                ``callback(step_index, snapshot)``.

        Returns:
            A list with one host-side `Snapshot` per step.
        """
        history = []

        for i in range(steps):
            self.params, self.opt_state, snapshot = self.update_step(
                self.params, self.opt_state
            )

            # Copy the snapshot off the device before storing or plotting it.
            history.append(jax.device_get(snapshot))
            if callback is not None:
                callback(i, history[-1])

        return history

    def reset(self):
        """Resets the parameters and optimizer state to their initial values."""
        self.params = self.initial_params
        self.opt_state = self.optimizer.init(self.params)

In [None]:
# @title Optimization Dashboard { display-mode: "form" }


def _get_lims(values, min_span=0.1, padding_fraction=0.05):
    if not values:
        return 0, 1

    min_val = min(values)
    max_val = max(values)
    span = max_val - min_val

    if span < min_span:
        mid = (max_val + min_val) / 2.0
        return mid - (min_span / 2.0), mid + (min_span / 2.0)

    pad = span * padding_fraction
    return min_val - pad, max_val + pad


class Dashboard:
    """Live energy-versus-step plot, callable as ``dashboard(step, snapshot)``.

    The notebook passes an instance as the per-step callback of the
    optimization run.
    """

    def __init__(self, total_steps, min_y_span=0.1):
        """Builds the (initially empty) figure without displaying it."""
        self.total_steps = total_steps
        self.min_y_span = min_y_span
        self.energies = []
        self.steps = []
        line_color = "tab:blue"

        self.fig, self.ax = plt.subplots(figsize=(8, 4))
        axes = self.ax
        axes.set_title("Optimization Progress")
        axes.set_xlabel("Step")
        axes.set_ylabel("Total Energy (Hartree)", color=line_color)
        axes.grid(True, alpha=0.3, linestyle="--")
        axes.tick_params(axis="y", labelcolor=line_color)

        # Plain tick labels: no offset and no scientific notation.
        plain_formatter = ticker.ScalarFormatter(useOffset=False)
        plain_formatter.set_scientific(False)
        axes.yaxis.set_major_formatter(plain_formatter)

        plotted = axes.plot([], [], color=line_color, lw=2, label="Energy")
        (self.line,) = plotted
        axes.legend(loc="upper right")

        # Close the figure here; it is shown through a display handle that
        # is created on the first call instead.
        plt.close(self.fig)
        self.display_handle = None

    def __call__(self, step, snapshot):
        """Records the energy of `snapshot` and redraws on selected steps."""
        self.energies.append(snapshot.energy)
        self.steps.append(len(self.energies))

        if self.display_handle is None:
            self.display_handle = display(self.fig, display_id=True)

        # Redraw on even steps and always on the final step.
        is_final_step = step == self.total_steps - 1
        if not (step % 2 == 0 or is_final_step):
            return

        self.line.set_data(self.steps, self.energies)

        lower, upper = _get_lims(self.energies, self.min_y_span)
        self.ax.set_ylim(lower, upper)
        x_upper = max(self.total_steps, len(self.steps))
        self.ax.set_xlim(0, x_upper)

        self.display_handle.update(self.fig)

In [None]:
# @title Molecule Renderer  { display-mode: "form" }


@dataclasses.dataclass
class Frame:
    """One animation frame of the optimization trajectory."""

    # Index of the snapshot within the optimization history.
    step: int
    # The optimization snapshot this frame was built from.
    snapshot: Snapshot
    # Text produced by `sf.analysis.write_cube_data` for this step; it is
    # handed to the 3D viewer with the "cube" format.
    density_data: str


def build_frames(
    template_mol: sf.Molecule, history: Sequence[Snapshot], resolution=25
) -> Sequence[Frame]:
    """Builds one renderable `Frame` per snapshot in `history`.

    The evaluation grid is built once, around the geometry of the first
    snapshot, and reused for every frame.

    Args:
        template_mol: Molecule providing everything except the positions.
        history: Snapshots of the optimization, in step order.
        resolution: Sets the grid spacing to ``10.0 / resolution``.

    Returns:
        A list of frames, empty when `history` is empty.
    """
    # Without a first snapshot there is no geometry to build the grid from.
    if not history:
        return []

    first_mol = update_molecule_geometry(template_mol, history[0].params)
    grid = sf.analysis.build_bounding_grid(
        first_mol, padding=3.0, spacing=10.0 / resolution
    )
    frames = []

    for step, snapshot in enumerate(history):
        mol = update_molecule_geometry(template_mol, snapshot.params)
        rho = sf.analysis.evaluate_density(
            snapshot.electronic_state.basis_blocks,
            snapshot.electronic_state.density,
            grid,
        )

        # Write the cube data into memory and keep it as a string.
        with io.StringIO() as buffer:
            sf.analysis.write_cube_data(
                mol=mol,
                grid=grid,
                data=rho,
                description=f"Step {step}",
                f=buffer,
            )
            frames.append(Frame(step, snapshot, buffer.getvalue()))

    return frames


def _render(frame, view):
    view.removeAllModels()
    view.removeAllShapes()
    view.removeAllLabels()

    view.addModel(frame.density_data, "cube")
    view.setStyle({"stick": {"radius": 0.15}, "sphere": {"scale": 0.3}})

    view.addVolumetricData(
        frame.density_data,
        "cube",
        {
            "algo": "volume",
            "transferfn": [
                {"value": 0.00, "color": "white", "opacity": 0.0},
                {"value": 0.005, "color": "blue", "opacity": 0.002},
                {"value": 0.05, "color": "blue", "opacity": 0.01},
                {"value": 0.20, "color": "blue", "opacity": 0.05},
            ],
            "smoothness": 5,
        },
    )


def _update_status(frame, status_label):
    energy = frame.snapshot.energy
    status_label.value = (
        f"<div style='font-family: monospace; font-size: 14px;'>"
        f"<b>Energy:</b> {energy:.6f} Ha"
        f"</div>"
    )


def molecule_viewer_app(frames):
    """Builds an interactive viewer for a sequence of frames.

    A play button and a slider select the frame; the 3D view and the energy
    label are refreshed whenever the selected frame changes.

    Args:
        frames: Non-empty sequence of `Frame` objects.

    Returns:
        The widget containing the controls and the 3D view.

    Raises:
        ValueError: If `frames` is empty.
    """
    # An empty sequence would give the slider a maximum below its minimum
    # and leave no first frame to render, so fail early and clearly.
    if not frames:
        raise ValueError("molecule_viewer_app requires at least one frame.")

    # Load the icon font stylesheet into the notebook page.
    display(
        HTML(
            '<link rel="stylesheet" href="//stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"/>'
        )
    )
    last_index = len(frames) - 1
    slider = widgets.IntSlider(
        min=0, max=last_index, step=1, description="Step:"
    )
    play = widgets.Play(
        value=0,
        min=0,
        max=last_index,
        step=1,
        interval=150,
        description="Play",
        show_repeat=False,
    )
    # Keep the play control and the slider on the same value.
    widgets.jslink((play, "value"), (slider, "value"))

    status_label = widgets.HTML(
        value="<b>Initializing...</b>",
        layout=widgets.Layout(margin="0 0 0 20px", padding="5px"),
    )
    view = py3Dmol.view(width=700, height=500)
    output_container = widgets.Output()

    def on_change(change):
        """Shows the frame selected by the slider."""
        frame = frames[change["new"]]
        _render(frame, view)
        _update_status(frame, status_label)
        view.update()

    # Subscribe to the slider's value only, instead of filtering by trait
    # name inside the handler.
    slider.observe(on_change, names="value")

    with output_container:
        _render(frames[0], view)
        _update_status(frames[0], status_label)
        view.show()
        view.zoomTo()

    controls = widgets.HBox([play, slider, status_label])
    layout = widgets.VBox([controls, output_container])

    return layout

In [None]:
# @title Molecule Zoo { display-mode: "form" }


@dataclasses.dataclass
class ExperimentConfig:
    """Configuration for a single experiment."""

    # Display name of the experiment; `load_experiment` stores the lookup
    # key it was given here.
    name: str
    # Starting molecule for the optimization.
    molecule: sf.Molecule
    # Indices of the atoms that are held fixed during the optimization.
    fixed_indices: list[int]


def build_water():
    """Builds the starting water molecule and its fixed-atom indices.

    The oxygen sits at the origin and the hydrogens at ``x = -/+ oh_dist``
    with a small common y offset, so the start is almost linear.
    Coordinates are presumably in atomic units (Bohr) -- TODO confirm.

    Returns:
        ``(molecule, fixed_indices)`` where the only fixed atom is the oxygen.
    """
    oh_dist = 1.8
    oxygen = sf.Atom("O", 8, jnp.array([0.0, 0.0, 0.0]))
    left_hydrogen = sf.Atom("H", 1, jnp.array([-oh_dist, 0.1, 0.0]))
    right_hydrogen = sf.Atom("H", 1, jnp.array([oh_dist, 0.1, 0.0]))

    molecule = sf.Molecule.from_geometry(
        [oxygen, left_hydrogen, right_hydrogen],
        basis_name="sto-3g",
    )
    fixed_indices = [0]  # fix oxygen
    return molecule, fixed_indices


def build_methane():
    """Builds the starting methane molecule and its fixed-atom indices.

    The carbon sits at the origin and the four hydrogens on the +/-x and
    +/-y axes at distance `ch_dist`, each nudged slightly out of the xy
    plane. Coordinates are presumably in atomic units (Bohr) -- TODO confirm.

    Returns:
        ``(molecule, fixed_indices)`` where the only fixed atom is the carbon.
    """
    ch_dist = 2.0

    carbon = sf.Atom("C", 6, jnp.array([0.0, 0.0, 0.0]))
    hydrogens = [
        sf.Atom("H", 1, jnp.array([ch_dist, 0.0, 0.1])),
        sf.Atom("H", 1, jnp.array([-ch_dist, 0.0, 0.1])),
        sf.Atom("H", 1, jnp.array([0.0, ch_dist, -0.1])),
        sf.Atom("H", 1, jnp.array([0.0, -ch_dist, 0.1])),
    ]

    molecule = sf.Molecule.from_geometry(
        [carbon, *hydrogens],
        basis_name="sto-3g",
    )
    fixed_indices = [0]  # fix carbon
    return molecule, fixed_indices


def build_ammonia():
    """Builds the starting ammonia molecule and its fixed-atom indices.

    The nitrogen sits at the origin. The three hydrogens are placed 120
    degrees apart around it at distance `nh_dist` in the xy plane, each
    nudged slightly along z. Coordinates are presumably in atomic units
    (Bohr) -- TODO confirm.

    Returns:
        ``(molecule, fixed_indices)`` where the only fixed atom is the nitrogen.
    """
    nh_dist = 1.9
    # x and y offsets of the two hydrogens that are not on the x axis.
    nh_x = nh_dist * np.sin(np.pi / 6)
    nh_y = nh_dist * np.cos(np.pi / 6)

    nitrogen = sf.Atom("N", 7, jnp.array([0.0, 0.0, 0.0]))
    hydrogens = [
        sf.Atom("H", 1, jnp.array([nh_dist, 0.0, 0.1])),
        sf.Atom("H", 1, jnp.array([-nh_x, nh_y, 0.1])),
        sf.Atom("H", 1, jnp.array([-nh_x, -nh_y, -0.1])),
    ]

    molecule = sf.Molecule.from_geometry(
        [nitrogen, *hydrogens],
        basis_name="sto-3g",
    )
    fixed_indices = [0]  # fix nitrogen
    return molecule, fixed_indices


def build_ethylene():
    """Builds the starting ethylene molecule and its fixed-atom indices.

    The molecule starts twisted by 90 degrees (broken pi-bond): the
    hydrogens on the left carbon lie in the xy plane and those on the right
    carbon lie in the xz plane. Coordinates are presumably in atomic units
    (Bohr) -- TODO confirm.

    Returns:
        ``(molecule, fixed_indices)`` where the only fixed atom is the first
        carbon.
    """
    cc_dist = 2.5
    ch_dist = 2.0
    # Each hydrogen is displaced from its carbon by this amount along two axes.
    h_delta = ch_dist / np.sqrt(2)
    right_x = cc_dist + h_delta

    carbons = [
        sf.Atom("C", 6, jnp.array([0.0, 0.0, 0.0])),
        sf.Atom("C", 6, jnp.array([cc_dist, 0.0, 0.0])),
    ]
    # Left hydrogens (flat, in the xy plane).
    left_hydrogens = [
        sf.Atom("H", 1, jnp.array([-h_delta, h_delta, 0.0])),
        sf.Atom("H", 1, jnp.array([-h_delta, -h_delta, 0.0])),
    ]
    # Right hydrogens (twisted 90 degrees, up/down in the xz plane).
    right_hydrogens = [
        sf.Atom("H", 1, jnp.array([right_x, 0.0, h_delta])),
        sf.Atom("H", 1, jnp.array([right_x, 0.0, -h_delta])),
    ]

    molecule = sf.Molecule.from_geometry(
        [*carbons, *left_hydrogens, *right_hydrogens],
        basis_name="sto-3g",
    )

    # Fix one carbon
    fixed_indices = [0]
    return molecule, fixed_indices


# Maps an experiment's display name to the function that builds its starting
# molecule and fixed-atom indices.
EXPERIMENT_FACTORIES = {
    "Water (H2O)": build_water,
    "Methane (CH4)": build_methane,
    # Ammonia is NH3 (one nitrogen, three hydrogens, as built by
    # build_ammonia).
    "Ammonia (NH3)": build_ammonia,
    # Legacy, chemically incorrect label; kept so that existing selections
    # using this string keep resolving.
    "Ammonia (NH4)": build_ammonia,
    "Ethylene (C2H4)": build_ethylene,
}


def load_experiment(name: str) -> ExperimentConfig:
    """Builds the experiment registered under `name`.

    Raises:
        KeyError: If `name` is not a key of `EXPERIMENT_FACTORIES`.
    """
    molecule, fixed = EXPERIMENT_FACTORIES[name]()
    return ExperimentConfig(name=name, molecule=molecule, fixed_indices=fixed)

# Differentiable Energy

Use `slaterform`'s Hartree-Fock SCF solver to define a differentiable molecular energy function.

In [None]:
def total_energy(positions: jax.Array, template_mol: sf.Molecule):
    """Total energy of the molecule with the specified atomic positions.

    The template molecule determines the atomic numbers and basis set.

    Returns:
        ``(energy, electronic_state)``: the total energy from the SCF solve,
        and the density and basis blocks from the same solve.
    """
    # NOTE(review): a fixed iteration count is presumably chosen so the solve
    # can be traced and differentiated by JAX -- confirm against slaterform.
    solver_options = scf.Options(
        max_iterations=20,
        execution_mode=scf.ExecutionMode.FIXED,
        integral_strategy=scf.IntegralStrategy.CACHED,
        perturbation=1e-10,
    )

    moved_mol = update_molecule_geometry(template_mol, positions)
    solution = scf.solve(moved_mol, solver_options)

    electronic_state = ElectronicState(
        solution.density, solution.basis.basis_blocks
    )
    return solution.total_energy, electronic_state


# Differentiate `total_energy` with respect to its first argument (the atomic
# positions). With `has_aux=True` the electronic state is carried through as
# auxiliary output, so a call returns ((energy, electronic_state), grad).
total_energy_and_grad = jax.value_and_grad(total_energy, has_aux=True)

# Run Simulation

Select a molecule below and watch the geometry evolve to minimize the energy.

In [None]:
# @title Select Molecule { display-mode: "form" }
# @markdown Choose a simulation scenario from the dropdown.

# The selected name is used as the lookup key into EXPERIMENT_FACTORIES by
# load_experiment, so every dropdown option must match a key there exactly.
experiment_name = "Water (H2O)"  # @param ["Water (H2O)", "Methane (CH4)", "Ammonia (NH4)", "Ethylene (C2H4)"]
experiment_cfg = load_experiment(experiment_name)

# Summarize what was loaded: the atom symbols and which atoms stay fixed.
print(f"✅ Loaded: {experiment_name}")
print(f"   • Atoms: {[atom.symbol for atom in experiment_cfg.molecule.atoms]}")
print(f"   • Fixed Atoms Indices: {experiment_cfg.fixed_indices}")

In [None]:
# @title Configure Optimizer

# Learning-rate schedule: cosine decay starting at 0.05 over 100 steps.
# NOTE(review): per the optax docs `alpha` sets the final value as a fraction
# of `init_value` -- confirm against the installed optax version.
scheduler = optax.cosine_decay_schedule(
    init_value=0.05, decay_steps=100, alpha=0.05
)

# Optimize the geometry using the differentiable energy function.
# The schedule is passed as `lr`, which GeometryOptimizer forwards to
# optax.adam as its learning rate.
geo_optimizer = GeometryOptimizer(
    total_energy_and_grad,
    experiment_cfg.molecule,
    experiment_cfg.fixed_indices,
    lr=scheduler,
)

print(f"Configured optimizer for: {experiment_cfg.name}")

In [None]:
# @title JAX Kernel Warmup

# Call the jitted update step once, ahead of the timed run, so compilation
# happens here. The result is discarded and the optimizer state is unchanged.
geo_optimizer.warmup()

In [None]:
# @title Run The Optimizer

# Number of optimization steps; also the x-axis extent of the dashboard.
n_steps = 80
# The dashboard is the per-step callback and plots the energy as the run proceeds.
dashboard = Dashboard(n_steps)
# `history` holds one Snapshot per step.
history = geo_optimizer.run(n_steps, dashboard)

In [None]:
# @title Electronic State Trajectory  { display-mode: "form" }

# build_frames uses a grid spacing of 10.0 / resolution, so a larger value
# gives a finer density grid.
resolution = 20
frames = build_frames(experiment_cfg.molecule, history, resolution)
# Clear this cell's previous output once the new viewer is ready to be shown.
clear_output(wait=True)
display(molecule_viewer_app(frames))