## Perform Inference

In [None]:
from functools import partial
from os.path import join

import matplotlib.pyplot as plt
import numpy as np
import torch

import chemiscope

from rhocalc.aims import aims_fields, aims_predictor
from rhocalc.cube import rho_cube

from dft_settings import *
from ml_settings import *

In [None]:
# Flatten the per-subset structure IDs into one evaluation list, look up the
# matching structures, and show them for a quick visual check.
eval_id = []
for _subset in ALL_SUBSET_ID:
    eval_id.extend(_subset)
eval_frames = [ALL_STRUCTURE[sid] for sid in eval_id]
# Left as the cell's final expression so Jupyter displays its return value.
chemiscope.show(eval_frames, mode="structure")

In [None]:
# Load the trained model from the epoch-2000 checkpoint. The file holds a whole
# pickled model object (we call `.predict` on it directly), not a state_dict, so
# `weights_only=False` is required: since PyTorch 2.6 `torch.load` defaults to
# `weights_only=True`, which refuses to unpickle arbitrary classes.
# NOTE: unpickling can execute arbitrary code — only load checkpoints you trust.
# (`weights_only` needs torch >= 1.13.)
model = torch.load("checkpoint/epoch_2000/model.pt", weights_only=False)

# Predict the density coefficients for every evaluation structure.
eval_preds = model.predict(frames=eval_frames, system_id=eval_id)

In [None]:
# Start from a shallow copy of the base aims settings and layer the
# rebuild-specific overrides on top (REBUILD wins on shared keys).
# BASE_AIMS itself is left untouched.
aims_kwargs = dict(BASE_AIMS)
aims_kwargs.update(REBUILD)

# Per-system output directory: EVAL_DIR with the epoch pinned to "eval".
eval_save_dir = partial(EVAL_DIR, epoch="eval")

# Rebuild real-space fields from the predicted coefficients. Later cells read
# rho_rebuilt.out / rho_rebuilt.cube from EVAL_DIR(A, epoch="eval"), so the
# outputs are expected there. hpc/sbatch kwargs presumably control job
# submission — confirm against rhocalc.aims.aims_predictor.
# Kept as the cell's final expression, as in the original.
aims_predictor.field_builder(
    system_id=list(eval_id),
    system=eval_frames,
    predicted_coeffs=eval_preds,
    save_dir=eval_save_dir,
    return_field=False,
    aims_kwargs=aims_kwargs,
    aims_path=AIMS_PATH,
    basis_set=TARGET_BASIS,
    cube=CUBE,
    hpc_kwargs=HPC,
    sbatch_kwargs=SBATCH,
)

## Evaluate the percentage MAE in the surface region

In [None]:
# Percentage MAE between the ML-rebuilt and RI-reference densities for each
# system, restricted to grid points near the top of the structure.
for A, frame in zip(eval_id, eval_frames):
    # Raw aims outputs, loaded in order: grid, reference density, ML density.
    raw_fields = [
        np.loadtxt(join(RI_DIR(A), "partition_tab.out")),
        np.loadtxt(join(RI_DIR(A), "rho_ri.out")),
        np.loadtxt(join(EVAL_DIR(A, epoch="eval"), "rho_rebuilt.out")),
    ]
    # Sort each array by grid point, presumably so rows line up across files.
    grid, rho_ref, rho_ml = (
        aims_fields.sort_field_by_grid_points(field) for field in raw_fields
    )

    # Grid column 2 is compared with atomic z-coordinates, so it is z. Zero
    # column 3 (assumed to be the integration weight — TODO confirm) at every
    # point at or below slab_depth + interphase_depth beneath the topmost atom.
    z_cutoff = frame.positions[:, 2].max() - (
        DESCRIPTOR_HYPERS["slab_depth"] + DESCRIPTOR_HYPERS["interphase_depth"]
    )
    above_cutoff = grid[:, 2] > z_cutoff
    grid[:, 3] *= above_cutoff

    mae = aims_fields.get_percent_mae_between_fields(
        input=rho_ml, target=rho_ref, grid=grid
    )
    print(f"MAE for system {A}: {mae:.2f}%")

## Plot STM-like images: reference vs. ML density slices

In [None]:
# For each system, draw a grid of contour panels comparing the RI-reference and
# ML-rebuilt densities, sliced near the top of the structure:
#   diagonal (row == col)   -> tanh(slice) of each density on its own
#   upper triangle (row<col) -> tanh(reference - ML), the signed difference
#   lower triangle           -> left empty
for A in eval_id:
    paths = [
        join(RI_DIR(A), "rho_ri.cube"),
        join(EVAL_DIR(A, epoch="eval"), "rho_rebuilt.cube"),
    ]
    # One row and one column per density file (2x2 here), sharing both axes.
    fig, axes = plt.subplots(
        len(paths), len(paths), figsize=(10, 8), sharey=True, sharex=True
    )

    # Slice each cube along z: a slab of thickness 4 (cube length units —
    # TODO confirm Angstrom vs bohr) centred on that file's highest atom.
    X, Y, Z = [], [], []
    for path in paths:
        q = rho_cube.RhoCube(path)
        x, y, z = q.get_slab_slice(
            axis=2,
            center_coord=q.ase_frame.positions[:, 2].max(),
            thickness=4,
        )
        X.append(x)
        Y.append(y)
        Z.append(z)

    for row, row_ax in enumerate(axes):
        for col, ax in enumerate(row_ax):
            if row == col:
                x, y, z = X[row], Y[row], np.tanh(Z[row])
            elif row < col:
                x, y, z = X[row], Y[col], np.tanh(Z[row] - Z[col])
            else:
                continue
            cs = ax.contourf(x, y, z, cmap="gray")
            # Attach the colorbar to the panel it describes. Without `ax`,
            # Matplotlib < 3.6 steals space from the *current* axes instead.
            fig.colorbar(cs, ax=ax)

    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig(join(EVAL_DIR(A, epoch="eval"), "stm_scatter.png"))