# Crystal Toolkit Relaxation Viewer

This notebook shows how to visualize a CHGNet relaxation trajectory in a Plotly Dash app using Crystal Toolkit.


Running the last cell in this notebook should spin up a `dash` app that looks like this:

![Crystaltoolkit Relaxation Viewer Screenshot](https://user-images.githubusercontent.com/30958850/230510639-2e659c9b-3a99-438b-9668-628299171602.png)


In [None]:
try:
    import chgnet
except ImportError:
    # install CHGNet with extra dependencies to run the dash app in this notebook
    # https://github.com/materialsproject/crystaltoolkit
    # (only needed on Google Colab or if you didn't install these packages yet)
    !git clone --depth 1 https://github.com/CederGroupHub/chgnet
    !pip install './chgnet[examples]'

Cloning into 'chgnet'...
remote: Enumerating objects: 50, done.[K
remote: Counting objects: 100% (50/50), done.[K
remote: Compressing objects: 100% (47/47), done.[K
remote: Total 50 (delta 1), reused 17 (delta 0), pack-reused 0[K
Receiving objects: 100% (50/50), 4.25 MiB | 2.70 MiB/s, done.
Resolving deltas: 100% (1/1), done.
zsh:1: no matches found: ./chgnet[crystal-toolkit]


In [None]:
import numpy as np
from pymatgen.core import Structure

In [None]:
try:
    # Preferred source: the example CIF shipped inside the installed CHGNet package.
    from chgnet import ROOT

    structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")
except Exception:
    # Deliberate best-effort fallback: whatever went wrong locally (CHGNet not
    # importable, file missing, ...), fetch the same CIF from GitHub instead.
    from urllib.request import urlopen

    url = "https://github.com/CederGroupHub/chgnet/raw/-/examples/mp-18767-LiMnO2.cif"
    # The context manager closes the HTTP response even if reading/decoding fails.
    with urlopen(url) as response:
        cif = response.read().decode("utf-8")
    structure = Structure.from_str(cif, fmt="cif")

In [None]:
print(f"original: {structure.get_space_group_info()}")

# Seeded generator so the perturbation (and therefore the relaxation run on it)
# is reproducible after Restart Kernel -> Run All.
rng = np.random.default_rng(seed=0)

# perturb all atom positions by a small amount (Gaussian noise scaled by 0.3 on
# each Cartesian component).
# NOTE: this mutates `structure` in place, so re-running this cell perturbs the
# already-perturbed structure again.
for site in structure:
    site.coords += rng.normal(size=3) * 0.3

# stretch the cell by a small amount (10 % larger volume)
structure.scale_lattice(structure.volume * 1.1)

print(f"perturbed: {structure.get_space_group_info()}")

original: ('Pmmn', 59)
perturbed: ('P1', 1)


In [None]:
import pandas as pd

from chgnet.model import StructOptimizer

# Relax the perturbed structure with a default-constructed optimizer and keep only
# the recorded trajectory from the returned result mapping.
relax_result = StructOptimizer().relax(structure)
trajectory = relax_result["trajectory"]

CHGNet initialized with 400,438 parameters
CHGNet will run on cpu
      Step     Time          Energy         fmax
*Force-consistent energies used in optimization.
FIRE:    0 14:01:10      -51.912251*      27.2278
FIRE:    1 14:01:10      -54.259518*      12.3964
FIRE:    2 14:01:10      -54.778671*       8.5672
FIRE:    3 14:01:11      -55.339821*       5.5388
FIRE:    4 14:01:11      -55.653206*       7.1592
FIRE:    5 14:01:11      -56.225849*       6.6752
FIRE:    6 14:01:11      -56.975388*       4.2375
FIRE:    7 14:01:11      -57.431259*       4.4837
FIRE:    8 14:01:11      -57.696171*       5.3055
FIRE:    9 14:01:11      -57.933193*       3.3038
FIRE:   10 14:01:11      -57.887894*       6.1535
FIRE:   11 14:01:11      -57.981998*       4.7339
FIRE:   12 14:01:12      -58.107471*       3.2390
FIRE:   13 14:01:12      -58.196518*       2.3609
FIRE:   14 14:01:12      -58.237015*       2.6211
FIRE:   15 14:01:12      -58.271477*       3.3198
FIRE:   16 14:01:12      -58.323418*

In [None]:
# Column labels, reused by the plotting code further down.
e_col = "Energy (eV)"
force_col = "Force (eV/Å)"

# One scalar per relaxation step: the norm of the force on each atom, averaged
# over all atoms.
mean_force_norms = [
    np.linalg.norm(step_forces, axis=1).mean() for step_forces in trajectory.forces
]

# Energies form the first column, the averaged force norms the second; the row
# index is the relaxation step.
df_traj = pd.DataFrame(trajectory.energies, columns=[e_col])
df_traj[force_col] = mean_force_norms
df_traj.index.name = "step"

In [None]:
# Materials Project ID of the LiMnO2 example structure used in this notebook;
# also used to build the link shown in the figure title.
mp_id = "mp-18767"

# Reference energy in eV, drawn as the "DFT final energy" line in the plot.
# NOTE(review): presumably copied by hand from the Materials Project entry linked
# below - confirm the value and that it is on the same basis as the CHGNet energies.
dft_energy = -59.09
print(f"{dft_energy=:.2f} eV (see https://materialsproject.org/materials/{mp_id})")

dft_energy=-59.09 eV (see https://materialsproject.org/materials/mp-18767)


In [None]:
import crystal_toolkit.components as ctc
import plotly.graph_objects as go
from crystal_toolkit.settings import SETTINGS
from dash import Dash, dcc, html
from dash.dependencies import Input, Output
from pymatgen.core import Structure

# Dash app; static assets are served from the path configured in Crystal Toolkit's
# settings. Callbacks only fire on user interaction, not on initial page load.
app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH)

# Coarsen the slider for long trajectories so it does not get an excessive number
# of stops: the step grows with trajectory length (one stop per ~1/20 of the
# trajectory), but is never smaller than 1.
n_steps = len(trajectory)
step_size = max(1, n_steps // 20)
slider = dcc.Slider(
    id="slider",
    min=0,
    max=n_steps - 1,
    step=step_size,
    updatemode="drag",
)


def plot_energy_and_forces(
    df: pd.DataFrame, step: int, e_col: str, force_col: str, title: str
) -> go.Figure:
    """Plot energy and forces as a function of relaxation step.

    Args:
        df: Trajectory data indexed by relaxation step. Must contain the columns
            named by ``e_col`` and ``force_col``.
        step: Relaxation step at which the dashed vertical marker line is drawn.
        e_col: Name of the energy column; also used as the left y-axis title.
        force_col: Name of the force column; also used as the right y-axis title.
        title: Figure title.

    Returns:
        A figure with the energy on the primary y-axis, the forces on a secondary
        (right-hand) y-axis, a dashed vertical line at ``step`` and a dotted
        horizontal line at the module-level ``dft_energy``.
    """
    import plotly.io as pio

    template = "plotly_white"
    # A trace's line color reads back as None unless it was set explicitly (the
    # colorway color is only picked at render time). Look up the first colorway
    # color of the template and set it on the energy trace, so the DFT reference
    # line below really shares the energy line's color.
    line_color = pio.templates[template].layout.colorway[0]

    fig = go.Figure()
    # energy trace = primary y-axis
    fig.add_trace(
        go.Scatter(
            x=df.index,
            y=df[e_col],
            mode="lines",
            name="Energy",
            line=dict(color=line_color),
        )
    )

    # forces trace = secondary y-axis
    fig.add_trace(
        go.Scatter(x=df.index, y=df[force_col], mode="lines", name="Forces", yaxis="y2")
    )

    fig.update_layout(
        template=template,
        title=title,
        xaxis=dict(title="Relaxation Step"),
        yaxis=dict(title=e_col),
        yaxis2=dict(title=force_col, overlaying="y", side="right"),
        legend=dict(yanchor="top", y=1, xanchor="right", x=1),
    )

    # vertical line at the specified step
    fig.add_vline(x=step, line=dict(dash="dash", width=1))

    # horizontal line for DFT final energy, in the same color as the energy trace
    anno = dict(text="DFT final energy", yanchor="top")
    fig.add_hline(
        y=dft_energy,
        line=dict(dash="dot", width=1, color=line_color),
        annotation=anno,
    )

    return fig


def make_title(spg_symbol: str, spg_num: int) -> str:
    """Build the figure title: a link to the Materials Project page of the
    module-level ``mp_id``, followed by the space group symbol and number.
    """
    url = f"https://materialsproject.org/materials/{mp_id}/"
    # !r quotes the URL exactly like the `{name=}` f-string shorthand would
    link = f"<a href={url!r}>{mp_id}</a>"
    return f"{link} - {spg_symbol} ({spg_num})"


# Title and figure for the initial view (vertical marker at step 0).
title = make_title(*structure.get_space_group_info())
initial_fig = plot_energy_and_forces(df_traj, 0, e_col, force_col, title)
graph = dcc.Graph(id="fig", figure=initial_fig, style={"maxWidth": "50%"})

# Crystal Toolkit component displaying the structure.
struct_comp = ctc.StructureMoleculeComponent(id="structure", struct_or_mol=structure)

# Page pieces, top to bottom: heading, usage hint, slider, then the structure
# viewer and the energy/force plot side by side.
heading = html.H1(
    "Structure Relaxation Trajectory", style={"margin": "1em", "fontSize": "2em"}
)
hint = html.P("Drag slider to see structure at different relaxation steps.")
viewer_row = html.Div(
    [struct_comp.layout(), graph], style={"display": "flex", "gap": "2em"}
)

app.layout = html.Div(
    [heading, hint, slider, viewer_row],
    style={
        "margin": "auto",
        "textAlign": "center",
        "maxWidth": "1200px",
        "padding": "2em",
    },
)

ctc.register_crystal_toolkit(app=app, layout=app.layout)


@app.callback(
    Output(struct_comp.id(), "data"),
    Output(graph, "figure"),
    Input(slider, "value"),
)
def update_structure(step: int) -> tuple[Structure, go.Figure]:
    """Slider callback: move the displayed structure to relaxation step ``step``
    and redraw the energy/force figure with its vertical marker at that step.

    Returns the updated structure (for the StructureMoleculeComponent) and the
    new figure.
    """
    cell = trajectory.cells[step]
    positions = trajectory.atom_positions[step]

    # The shared module-level structure is modified in place (rather than building
    # a new one on every slider move) for efficiency.
    structure.lattice = cell
    assert len(structure) == len(positions)
    for idx, site in enumerate(structure):
        site.coords = positions[idx]

    # The space group may change along the trajectory, so rebuild the title too.
    spg_symbol, spg_num = structure.get_space_group_info()
    new_title = make_title(spg_symbol, spg_num)
    figure = plot_energy_and_forces(df_traj, step, e_col, force_col, new_title)

    return structure, figure


# Show the app inline in the notebook output. `dash.Dash.run` takes the notebook
# display options as `jupyter_mode` / `jupyter_height`; the bare `mode` / `height`
# names belonged to the legacy `jupyter_dash` package and are not applied by
# `Dash.run`. `use_reloader=False` is passed through to the underlying server.
app.run(jupyter_mode="inline", jupyter_height=800, use_reloader=False)

  warn("The TEMDiffractionComponent requires the py4DSTEM package.")
  warn(


No module named 'phonopy'
