# Benchmarking MLPs
This notebook is a template for the "ID" of a potential: it collects and plots all the relevant evaluation metrics produced by `benchmark.py`.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import yaml

In [None]:
def plot_values_vs_rel_error(filenames, datalabels):
    """
    Plot the absolute value of the relative errors stored in one or more
    benchmark YAML files, one line per file, on a shared categorical axis.

    Parameters
    ----------
    filenames : list of str
        Benchmark YAML files. Each file must provide
        ``fcc_bulk`` with ``a0_rel_error``, ``e0_rel_error`` and
        ``B_rel_error``, and ``fcc111`` / ``fcc110`` / ``fcc100`` each with
        ``rel_error``.
    datalabels : list of str
        Legend label for each file. Paired with ``filenames`` via ``zip``,
        so extra entries in the longer list are ignored.

    Raises
    ------
    KeyError
        If a YAML file lacks one of the expected entries.
    """
    # (tick label, YAML section, key inside that section), in plotting order.
    properties = [
        ("lattice constant",     "fcc_bulk", "a0_rel_error"),
        ("cohesive energy",      "fcc_bulk", "e0_rel_error"),
        ("bulk modulus",         "fcc_bulk", "B_rel_error"),
        ("(111) surface energy", "fcc111",   "rel_error"),
        ("(110) surface energy", "fcc110",   "rel_error"),
        ("(100) surface energy", "fcc100",   "rel_error"),
    ]
    labels = [tick for tick, _, _ in properties]

    # Tick positions depend only on the property list, so compute them once
    # here; this keeps xticks() consistent even when `filenames` is empty.
    x = np.arange(len(properties))

    plt.figure(dpi=200, figsize=(4.5, 4))

    for fname, lab in zip(filenames, datalabels):
        with open(fname, "r") as f:
            data = yaml.safe_load(f)

        y = [abs(data[section][key]) for _, section, key in properties]
        plt.plot(x, y, label=lab, marker='s')

    plt.xticks(x, labels, rotation=90.)
    plt.ylabel("absolute error wrt DFT (%)")
    plt.legend()
    plt.tight_layout()
    plt.show()


def plot_parity_with_yaml(folders, benchmark_yaml_files, titles=None, dpi=200, figsize=(7.9, 7.4)):
    """
    Draw a 3 x len(folders) grid of parity plots (energy, forces, stress),
    one column per test set, each annotated with the MAE read from YAML.

    Parameters
    ----------
    folders : list of str
        Locations of 'e-parity.dat', 'f-parity.dat' and 's-parity.dat'.
        The file names are appended by plain string concatenation, so each
        entry must end with a path separator (e.g. "./").
    benchmark_yaml_files : list of str
        One YAML file per folder (same order). Values are read from
        ``test_set -> energy_per_atom | forces | stress -> mae | mav``.
        Missing or null entries are reported as "n/a" in the annotation.
    titles : list of str, optional
        Column titles, one per folder.
    dpi : int
        Figure DPI.
    figsize : tuple
        Figure size.
    """
    # squeeze=False keeps `axs` two-dimensional even for a single folder;
    # without it plt.subplots(3, 1) returns a 1-D array and axs[:, i] fails.
    fig, axs = plt.subplots(3, len(folders), dpi=dpi, figsize=figsize, squeeze=False)

    # (file prefix, YAML section, axis label, marker colour, unit, annotation name)
    quantities = [
        ('e', 'energy_per_atom', 'E/at [eV]',    'r', 'eV/at',  'Energy'),
        ('f', 'forces',          '|F| (eV/A)',   'g', 'eV/Å',   'Force (|F|)'),
        ('s', 'stress',          '|σ| [eV/A^3]', 'b', 'eV/Å^3', 'Stress (|σ|)'),
    ]

    for i, fol in enumerate(folders):
        with open(benchmark_yaml_files[i], 'r') as f:
            ydata = yaml.safe_load(f)
        # An empty YAML document loads as None; treat it like "no metrics".
        test_set = (ydata or {}).get("test_set") or {}

        for ax, (prefix, section, label, c, unit, name) in zip(axs[:, i], quantities):
            # ndmin=2 keeps column indexing valid for single-row files.
            data = np.loadtxt(fol + prefix + '-parity.dat', ndmin=2)

            ax.plot(data[:, 0], data[:, 0], 'k--')        # parity line
            ax.plot(data[:, 0], data[:, 1], 'o', color=c) # column 1 vs column 0
            ax.set_xlabel('DFT ' + label)
            ax.set_ylabel('MLP ' + label)

            # Normalise missing/null YAML values to NaN so isnan() is safe.
            metrics = test_set.get(section) or {}
            mae = metrics.get("mae")
            mav = metrics.get("mav")
            mae = np.nan if mae is None else mae
            mav = np.nan if mav is None else mav

            if np.isnan(mae):
                ann = f"MAE ({name}): n/a"
            else:
                # Relative MAE in percent; undefined when MAV is 0 or NaN.
                pct = 100. * mae / mav if (mav and not np.isnan(mav)) else np.nan
                pct_str = f"{pct:.2g}%" if not np.isnan(pct) else "n/a"
                ann = f"MAE ({name}): {mae:.2g} {unit} ({pct_str})"

            ax.text(0.05, 0.95, ann,
                    transform=ax.transAxes, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        if titles:
            axs[0, i].set_title(titles[i])

    plt.tight_layout()
    plt.show()

def plot_far_atom_systems(folder):
    """
    Plot the three two-column curves found in `folder` side by side:
    'adsorbate-curve.dat', 'dimer-curve.dat' and 'large_eos_curve.dat'.

    Column 0 of each file is drawn on the x-axis and column 1 on the y-axis
    (presumably distance / lattice constant vs. energy, judging by the axis
    labels -- confirm against the script that writes these files).

    Parameters
    ----------
    folder : str
        Location of the three files. File names are appended by plain string
        concatenation, so it must end with a path separator (e.g. "./").
    """
    files  = [folder+'adsorbate-curve.dat', folder+'dimer-curve.dat', folder+'large_eos_curve.dat']
    titles = ['atom on surface', 'dimer', 'bulk fcc']
    # ndmin=2 keeps column indexing valid for single-row files.
    all_data = [np.loadtxt(f, ndmin=2) for f in files]

    fig, axs = plt.subplots(1, len(all_data), dpi=300)

    for ax, data, tit in zip(axs, all_data, titles):
        ax.plot(data[:, 0], data[:, 1])
        ax.set_xlabel('distance [A]')
        ax.set_title(tit)

    # Shared y-axis meaning is labelled once, on the left-most panel.
    axs[0].set_ylabel('energy [eV]')
    # The bulk curve is scanned over the lattice constant, not a distance.
    axs[2].set_xlabel('lattice constant [A]')

    plt.tight_layout()
    plt.show()

def plot_excess_energy(files):
    """
    Placeholder for the excess-energy plot; not implemented yet.

    The stub has an explicit body so the cell parses and the other plotting
    functions defined alongside it remain usable.

    Parameters
    ----------
    files
        Input file(s) the plot will read (format still to be defined).

    Raises
    ------
    NotImplementedError
        Always, until the plot is implemented.
    """
    raise NotImplementedError("plot_excess_energy is not implemented yet")





In [None]:
# Inputs for the report, one entry per potential: the folder that is passed
# to the parity / far-atom plots, its legend label and its benchmark YAML.
folders    = ["./"]
datalabels = ["flare FF"]
filenames  = ["_benchmark.yaml"]

# Summary of relative errors, then parity plots, then the far-atom curves.
plot_values_vs_rel_error(filenames=filenames, datalabels=datalabels)
plot_parity_with_yaml(folders=folders, benchmark_yaml_files=filenames)
plot_far_atom_systems(folder=folders[0])