In [None]:
"""Jupyter notebook demo for pymatviz widgets."""
# /// script
# dependencies = [
#     "pymatgen>=2024.1.1",
#     "ase>=3.22.0",
#     "phonopy>=2.20.0",
# ]
# ///

# %%
import itertools
from typing import Final

import numpy as np
from ase.build import bulk, molecule
from ipywidgets import GridBox, Layout
from phonopy.structure.atoms import PhonopyAtoms
from pymatgen.core import Composition, Lattice, Structure

import pymatviz as pmv


# Fixed-seed NumPy generator so the random numbers drawn from it are reproducible.
np_rng = np.random.default_rng(0)

In [None]:
# %% Convex Hull — compute stability from PhaseDiagram entries
from pymatgen.analysis.phase_diagram import PDEntry, PhaseDiagram


# Formula -> energy value handed to PDEntry (hard-coded demo numbers).
formula_energies = {
    "Li": -1.9,
    "Fe": -4.2,
    "O": -3.0,
    "Li2O": -14.3,
    "Fe2O3": -25.5,
    "LiFeO2": -18.0,
    "FeO": -8.5,
}
hull_entries = [
    PDEntry(Composition(formula), energy)
    for formula, energy in formula_energies.items()
]
phase_diag = PhaseDiagram(hull_entries)

# The PhaseDiagram object itself is passed as the widget's `entries`.
pmv.ConvexHullWidget(entries=phase_diag, style="height: 500px;")

In [None]:
# %% Structure Widget — wurtzite GaN (hexagonal, more interesting than cubic)
gan_lattice = Lattice.hexagonal(3.19, 5.19)
gan_species = ["Ga", "Ga", "N", "N"]
# One coordinate triple per entry in gan_species, in the same order.
gan_coords = [
    [1 / 3, 2 / 3, 0],
    [2 / 3, 1 / 3, 0.5],
    [1 / 3, 2 / 3, 0.375],
    [2 / 3, 1 / 3, 0.875],
]
# `struct` is reused by the Brillouin zone cell further down.
struct = Structure(lattice=gan_lattice, species=gan_species, coords=gan_coords)

pmv.StructureWidget(
    structure=struct,
    show_bonds=True,
    style="height: 400px;",
)

In [None]:
# %% Brillouin Zone — hexagonal BZ from GaN
# Reuses `struct`, the GaN structure built in the Structure Widget cell.
pmv.BrillouinZoneWidget(
    structure=struct,
    show_vectors=True,
    style="height: 400px;",
)

In [None]:
# %% XRD Pattern — rutile TiO2 (tetragonal, richer peak pattern than cubic Si)
from pymatgen.analysis.diffraction.xrd import XRDCalculator


tio2_lattice = Lattice.tetragonal(4.594, 2.959)
tio2_species = ["Ti"] * 2 + ["O"] * 4
# One coordinate triple per species entry: two Ti sites followed by four O sites.
tio2_coords = [
    [0, 0, 0],
    [0.5, 0.5, 0.5],
    [0.305, 0.305, 0],
    [0.695, 0.695, 0],
    [0.195, 0.805, 0.5],
    [0.805, 0.195, 0.5],
]
tio2_struct = Structure(tio2_lattice, tio2_species, tio2_coords)

# Calculator is created with its default settings (no arguments passed).
xrd_calculator = XRDCalculator()
xrd_pattern = xrd_calculator.get_pattern(tio2_struct)

pmv.XrdWidget(patterns=xrd_pattern, style="height: 350px;")

In [None]:
# %% Trajectory Widget — randomly displaced atoms with energy/force properties
# Builds a synthetic trajectory: every frame is a snapshot of a two-atom Fe cell whose
# sites are randomly perturbed once more per step (the lattice itself stays fixed).
n_steps = 20
base_struct = Structure(
    lattice=Lattice.cubic(3.0),
    species=("Fe", "Fe"),
    coords=((0, 0, 0), (0.5, 0.5, 0.5)),
)

trajectory = []
for idx in range(n_steps):
    # perturb() modifies base_struct in place, so displacements accumulate from frame
    # to frame. Call it as a statement rather than chaining on its return value, which
    # is not guaranteed to be the structure in every pymatgen release.
    base_struct.perturb(distance=0.2)
    # Snapshot so the stored frame does not alias the structure mutated next step.
    struct_frame = base_struct.copy()

    # Synthetic energy: random decrease whose possible magnitude grows with the step.
    energy = n_steps / 2 - idx * np_rng.random()

    # Pairwise site distances; the zero self-distances on the diagonal are masked with
    # inf so that min() yields the shortest distance between two different sites.
    pair_dists = struct_frame.distance_matrix
    np.fill_diagonal(pair_dists, np.inf)

    trajectory.append(
        {
            "structure": struct_frame,
            "energy": energy,
            # made-up force proxy: inverse of the shortest interatomic distance
            "force_max": 1 / pair_dists.min(),
        }
    )

pmv.TrajectoryWidget(
    trajectory=trajectory,
    display_mode="structure+scatter",
    show_force_vectors=True,
    style="height: 600px;",
)

In [None]:
# %% Band Structure Widget — dict-based demo
# Synthetic data: n_bands sine-shaped bands, each offset by its band index and sampled
# at n_kpoints points along the first reciprocal axis.
n_kpoints = 50
n_bands = 4
sine_bands = [
    [0.5 * np.sin(k_idx / 10) + band_idx for k_idx in range(n_kpoints)]
    for band_idx in range(n_bands)
]
kpoint_path = [[k_idx / n_kpoints, 0, 0] for k_idx in range(n_kpoints)]
high_sym_labels = {"\\Gamma": [0, 0, 0], "X": [0.5, 0, 0], "M": [0.5, 0.5, 0]}

band_data = {
    "@module": "pymatgen.electronic_structure.bandstructure",
    "@class": "BandStructureSymmLine",
    "bands": {"1": sine_bands},
    "efermi": 0.0,
    "kpoints": kpoint_path,
    "labels_dict": high_sym_labels,
    "lattice_rec": {"matrix": [[1, 0, 0], [0, 1, 0], [0, 0, 1]]},
}

# Display the dict-based band structure built earlier in this cell.
pmv.BandStructureWidget(
    band_structure=band_data,
    style="height: 400px;",
)

In [None]:
# %% DOS Widget — dict-based demo
# Synthetic data: a Gaussian exp(-E^2 / 2) sampled on 200 evenly spaced energies
# between -5 and 5. The grid is computed once and shared by both entries.
energy_grid = np.linspace(-5, 5, 200)
gaussian_density = np.exp(-0.5 * energy_grid**2)

dos_data = {
    "@module": "pymatgen.electronic_structure.dos",
    "@class": "Dos",
    "energies": energy_grid.tolist(),
    "densities": {"1": gaussian_density.tolist()},
    "efermi": 0.0,
}

# Display the dict-based DOS built earlier in this cell.
pmv.DosWidget(
    dos=dos_data,
    style="height: 400px;",
)

In [None]:
# %% Combined Bands + DOS
# Reuses `band_data` and `dos_data` from the two preceding demo cells.
pmv.BandsAndDosWidget(
    band_structure=band_data,
    dos=dos_data,
    style="height: 500px;",
)

In [None]:
# %% Composition Widget — grid of compositions x modes
# Compositions are deliberately given in mixed formats: a formula string, a pymatgen
# Composition, and plain element -> amount dicts.
comps = (
    "Fe2 O3",
    Composition("Li P O4"),
    {"Co": 20, "Cr": 20, "Fe": 20, "Mn": 20, "Ni": 20},
    {"Ti": 20, "Zr": 20, "Nb": 20, "Mo": 20, "V": 20},
)
modes = ("pie", "bar", "bubble")
size = 100

children = []
for comp, mode in itertools.product(comps, modes):
    # "bar" widgets are twice as wide as the square pie/bubble widgets
    width = 2 * size if mode == "bar" else size
    children.append(
        pmv.CompositionWidget(
            composition=comp,
            mode=mode,
            style=f"width: {width}px; height: {size}px;",
        )
    )

# One grid column per mode, so each row shows a single composition in every mode.
layout = Layout(
    grid_template_columns=f"repeat({len(modes)}, auto)",
    grid_gap="2em 4em",
    padding="2em",
)
GridBox(children=children, layout=layout)

In [None]:
# %% Render remote ASE trajectory file
# Base URL is pinned to a specific matterviz commit so the file path stays stable.
matterviz_traj_dir_url: Final = (
    "https://github.com/janosh/matterviz/raw/6288721042/src/site/trajectories"
)
file_name = "Cr0.25Fe0.25Co0.25Ni0.25-mace-omat-qha.xyz.gz"
remote_traj_url = f"{matterviz_traj_dir_url}/{file_name}"

# The widget receives only the URL (data_url); no trajectory is built in Python here.
pmv.TrajectoryWidget(
    data_url=remote_traj_url,
    display_mode="structure+scatter",
    # force vector options
    show_force_vectors=True,
    force_vector_scale=0.5,
    force_vector_color="#ff4444",
    # bond options
    show_bonds=True,
    bonding_strategy="nearest_neighbor",
    style="height: 600px;",
)

In [None]:
# %% ASE Atoms — auto-rendered via MIME type recognition
supercell_repeats = (2, 2, 2)
ase_atoms = bulk("Al", "fcc", a=4.05)
# In-place repetition along the three cell vectors.
ase_atoms *= supercell_repeats
# Bare last expression: the notebook displays the Atoms object itself.
ase_atoms

In [None]:
# %% ASE molecule
vacuum_padding = 3.0
ase_molecule = molecule("H2O")
# center() modifies the Atoms object in place, using the given vacuum padding.
ase_molecule.center(vacuum=vacuum_padding)
# Bare last expression: the notebook displays the Atoms object itself.
ase_molecule

In [None]:
# %% PhonopyAtoms — auto-rendered via MIME type recognition
# The coordinates below are fractional (second atom at the body center of the cubic
# cell), so they are passed as `scaled_positions`. PhonopyAtoms interprets `positions`
# as Cartesian, which would place Cl only ~0.87 length units from Na in a cell of
# edge length 4.
PhonopyAtoms(
    symbols=["Na", "Cl"],
    scaled_positions=[[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]],
    cell=[[4, 0, 0], [0, 4, 0], [0, 0, 4]],
)