# Part 2: Training a surrogate model with `rholearn` (supporting notebook)

In [1]:
import torch

from rholearn.utils import cube, system
from rholearn.utils.io import unpickle_dict

## 2.3: Train a model

In [None]:
# Load the pickled cross-validation index splits and display them as the cell output.
# The "test" entry of this mapping is used further down to pick a test structure.
crossval_idxs_path = "output/crossval_idxs.pickle"
idxs = unpickle_dict(crossval_idxs_path)
idxs

**Run the cells below only after training is complete.**

In [None]:
from dft_settings import XYZ

# Load the frames from the XYZ file named in the DFT settings
frames = system.read_frames_from_xyz(XYZ)

# Load the model checkpoint stored under `epoch_best`.
# The loaded object is used directly as a model (`model.predict` is called on it in the
# next cell), i.e. the checkpoint is a fully pickled object rather than a bare state
# dict. From PyTorch 2.6 on, `torch.load` defaults to `weights_only=True` and refuses to
# unpickle such objects, so full unpickling is requested explicitly.
# NOTE: full unpickling can execute arbitrary code - only load checkpoints you trust.
model = torch.load("checkpoint/epoch_best/model.pt", weights_only=False)

In [None]:
# Predict coefficients for a single test structure: take the first index listed under
# the "test" split and look up the matching frame.
test_idx = idxs["test"][0]
test_frame = frames[test_idx]
print("testIdx: ", test_idx)

# `predict` takes lists, so the single frame and its index are each wrapped in one
test_ml_coeffs = model.predict(
    frames=[test_frame],
    frame_idxs=[test_idx],
)
test_ml_coeffs

## 2.4: Evaluate the model

In [None]:
# WARNING: execute with care!
# Display the electron density volumetric data. This relies on py3Dmol, which is often
# unreliable in jupyter notebooks. Better to use another external software, such as
# VESTA.

# Calculate and display the delta electron density (i.e. ML error) of one of the test
# structures. `test_idx` must already be defined by the prediction cell.

# Cube files for the SCF reference and the ML prediction of the same test structure,
# plus the output path for their difference
scf_cube_path = f"../part-1-dft/data/raw/{test_idx}/cube_001_total_density.cube"
ml_cube_path = f"evaluation/epoch_best/{test_idx}/cube_001_total_density.cube"
delta_path = f"evaluation/epoch_best/{test_idx}/cube_delta_ml_scf.cube"

# Load the cube files for the SCF reference and the ML prediction
rhocube_scf = cube.RhoCube(scf_cube_path)
rhocube_ml = cube.RhoCube(ml_cube_path)

# Show the ML-predicted density
rhocube_ml.show_volumetric(isovalue=0.02)

# Create the delta density cube - (ML - SCF). The subtraction is done on a separate cube
# object read from the ML file, so that `rhocube_ml` keeps holding the ML density
# instead of being overwritten in place with the error field.
rhocube_delta = cube.RhoCube(ml_cube_path)
rhocube_delta.data -= rhocube_scf.data

# Write to a new cube file
rhocube_delta.write_cube(delta_path)

# Show delta density, re-read from the file that was just written
rhocube_delta = cube.RhoCube(delta_path)
rhocube_delta.show_volumetric(isovalue=0.02)