In [None]:
import ase.io as ase_io
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

In [None]:
# hex colour palette; the per-molecule colour dictionaries pick from these
c0, c1, c2, c3 = "#f18f01", "#033f63", "#95b46a", "#ee4266"

In [None]:
def get_mae(a, b):
    """Return the mean absolute error between two array-likes.

    Parameters
    ----------
    a, b : array-like
        Values to compare. Anything ``np.asarray`` accepts (ndarrays, nested
        lists/tuples, scalars) as long as the two shapes broadcast.

    Returns
    -------
    numpy scalar
        Mean of ``|a - b|`` taken over all elements.
    """
    # Coerce first so that plain Python sequences work too; ndarray inputs
    # pass through np.asarray unchanged.
    return np.mean(np.abs(np.asarray(a) - np.asarray(b)))

In [None]:
# molecule identifiers; each one is also the stem of the trajectory file
# ("<name>.xyz") loaded further down
molecules_main = [
    "h2o",
    "ch4",
    "ch2o",
]
molecules_si = [
    "h2o",
    "co2",
    "ch3cho",
]

In [None]:
# line colour used for each molecule of the main-text set
colors_main = {
    "h2o": c2,
    "ch2o": c1,
    "ch4": c0,
}

In [None]:
# line colour used for each molecule of the SI set
colors_si = {
    "h2o": c2,
    "co2": c0,
    "ch3cho": c1,
}

In [None]:
# prefixes of the per-frame energy entries ("<name>_energy") that get
# re-referenced when the trajectories are loaded
models = [
    "gomace",
    "maceoff23",
    "maceoff24",
    "castep",
    "spice",
]

In [None]:
# Load one trajectory per molecule and shift every stored energy so that the
# first frame of each trajectory sits at zero.
trajs_main = {}
ref_main = {}
for mol in molecules_main:
    frames = ase_io.read(f"../../data/c60-traj/{mol}.xyz", ":")
    trajs_main[mol] = frames
    # energies of frame 0, captured before any frame is modified
    ref_main[mol] = {name: frames[0].info[name + "_energy"] for name in models}
    for frame in frames:
        for name, offset in ref_main[mol].items():
            frame.info[name + "_energy"] -= offset

In [None]:
# Load one trajectory per SI molecule and shift every stored energy so that
# the first frame of each trajectory sits at zero.
trajs_si = {}
ref_si = {}
for mol in molecules_si:
    frames = ase_io.read(f"../../data/c60-traj/{mol}.xyz", ":")
    trajs_si[mol] = frames
    # energies of frame 0, captured before any frame is modified
    ref_si[mol] = {name: frames[0].info[name + "_energy"] for name in models}
    for frame in frames:
        for name, offset in ref_si[mol].items():
            frame.info[name + "_energy"] -= offset

In [None]:
# Each model is compared against the reference data it is paired with:
# gomace against castep, both maceoff models against spice.
reference_of = {"gomace": "castep", "maceoff23": "spice", "maceoff24": "spice"}

# absolute energy error per atom for every frame, scaled by 1e3
# (presumably eV -> meV, matching the axis label of the figure; confirm)
energy_main = {}
for mol in molecules_main:
    energy_main[mol] = {
        model: np.abs(
            [
                (x.info[f"{model}_energy"] - x.info[f"{ref}_energy"]) / len(x) * 1e3
                for x in trajs_main[mol]
            ]
        )
        for model, ref in reference_of.items()
    }


# mean absolute force error for every frame, scaled by 1e3
forces_main = {}
for mol in molecules_main:
    forces_main[mol] = {
        model: [
            1e3 * get_mae(x.arrays[f"{model}_forces"], x.arrays[f"{ref}_forces"])
            for x in trajs_main[mol]
        ]
        for model, ref in reference_of.items()
    }

In [None]:
# Each model is compared against the reference data it is paired with:
# gomace against castep, both maceoff models against spice.
reference_of = {"gomace": "castep", "maceoff23": "spice", "maceoff24": "spice"}

# absolute energy error per atom for every frame, scaled by 1e3
# (presumably eV -> meV, matching the axis label of the figure; confirm)
energy_si = {}
for mol in molecules_si:
    energy_si[mol] = {
        model: np.abs(
            [
                1e3 * (x.info[f"{model}_energy"] - x.info[f"{ref}_energy"]) / len(x)
                for x in trajs_si[mol]
            ]
        )
        for model, ref in reference_of.items()
    }


# mean absolute force error for every frame; unlike the energies these are
# NOT scaled by 1e3
forces_si = {}
for mol in molecules_si:
    forces_si[mol] = {
        model: [
            get_mae(x.arrays[f"{model}_forces"], x.arrays[f"{ref}_forces"])
            for x in trajs_si[mol]
        ]
        for model, ref in reference_of.items()
    }

In [None]:
# x axis shared by the figures: one point per stored frame, spaced by
# 2 / 100 = 0.02 in the units of the axis label used below ("time (ns)").
# NOTE(review): this cell was previously annotated "we sample every 50ps",
# but a spacing of 0.02 ns is 20 ps; confirm which of the two is correct.
n_frames = len(energy_main[molecules_main[0]]["gomace"])
xaxis = np.arange(n_frames) * 2 / 100

In [None]:
# Main-text figure: a 2 x 3 grid with energy errors on the top row and force
# errors on the bottom row, one column per model and one line per molecule.
fig = plt.figure(figsize=(3.5, 3.1), constrained_layout=True)

# y-limits shared by all panels of a row
emin, emax = 0.01, 200
fmin, fmax = 100, 3000

# x position (axes coordinates) that lines up the two y-axis labels
labelx = -0.35

# (key into the error dictionaries, column title)
panels = [
    ("gomace", "GO-MACE-23"),
    ("maceoff23", "MACE-OFF23"),
    ("maceoff24", "MACE-OFF24"),
]

# --- top row: energy absolute errors --------------------------------------
for col, (model, title) in enumerate(panels):
    ax = fig.add_subplot(2, 3, 1 + col)
    for mol, err in energy_main.items():
        ax.plot(xaxis, err[model], lw=0.75, color=colors_main[mol])

    ax.tick_params(axis="both", labelsize=8)
    ax.set_yscale("log")
    ax.set_xticklabels(())  # x tick labels are only shown on the bottom row
    if col == 0:
        ax.set_ylabel("Energy AE (meV at.$^{-1}$)", fontsize=8)
        ax.yaxis.set_label_coords(labelx, 0.5, transform=ax.transAxes)
    else:
        # hide y tick labels on inner panels; kept after set_yscale, as in
        # the original cell, so the scale change cannot bring them back
        ax.set_yticklabels(())
    ax.set_ylim(emin, emax)
    ax.set_title(title, fontsize=8)

# --- bottom row: force mean absolute errors -------------------------------
for col, (model, _title) in enumerate(panels):
    ax = fig.add_subplot(2, 3, 4 + col)
    for mol, err in forces_main.items():
        ax.plot(xaxis, err[model], lw=0.75, color=colors_main[mol])

    ax.set_yscale("log")
    if col == 0:
        # raw string: "\m" and "\A" are not valid Python escapes
        ax.set_ylabel(r"Forces MAE (meV $\mathrm{\AA}^{-1}$)", fontsize=8)
        ax.tick_params(axis="both", labelsize=8, which="both")
        ax.yaxis.set_label_coords(labelx, 0.5, transform=ax.transAxes)
    else:
        ax.tick_params(axis="both", labelsize=8)
        ax.set_yticklabels(())
    ax.set_ylim(fmin, fmax)


fig.supxlabel("time (ns)", fontsize=8)


# Legend handles: thick lines in the colours of the plotted molecules, in the
# same order as `labels`.
lines = [
    Line2D([0], [0], lw=3.75, color=colors_main[mol_name], ls="-")
    for mol_name in molecules_main
]

labels = ["H$_2$O@C$_{60}$", "CH$_4$@C$_{60}$", "CH$_2$O@C$_{60}$"]
ccc = fig.legend(
    lines, labels, bbox_to_anchor=(0.95, 1.37), fontsize=8, frameon=False, ncols=3
)


# fig.savefig("./fig6.svg", dpi=300, bbox_inches="tight")

In [None]:
# SI figure: a 2 x 3 grid with energy errors on the top row and force errors
# on the bottom row, one column per model and one line per molecule.
# NOTE(review): `xaxis` was built from the main-text trajectories; this cell
# assumes the SI trajectories contain the same number of frames (confirm).
fig = plt.figure(figsize=(3.5, 3.1), constrained_layout=True)


# y-limits shared by all panels of a row
emin, emax = 9e-3, 250
fmin, fmax = 0.1, 3.8

# x position (axes coordinates) that lines up the two y-axis labels
labelx = -0.35

# (key into the error dictionaries, column title)
panels = [
    ("gomace", "GO-MACE-23"),
    ("maceoff23", "MACE-OFF23"),
    ("maceoff24", "MACE-OFF24"),
]

# --- top row: energy absolute errors --------------------------------------
for col, (model, title) in enumerate(panels):
    ax = fig.add_subplot(2, 3, 1 + col)
    for mol, err in energy_si.items():
        ax.plot(xaxis, err[model], lw=0.75, color=colors_si[mol])

    ax.tick_params(axis="both", labelsize=8)
    ax.set_yscale("log")
    ax.set_xticklabels(())  # x tick labels are only shown on the bottom row
    if col == 0:
        ax.set_ylabel("Energy AE (meV at.$^{-1}$)", fontsize=8)
        ax.yaxis.set_label_coords(labelx, 0.5, transform=ax.transAxes)
    else:
        # hide y tick labels on inner panels; kept after set_yscale, as in
        # the original cell, so the scale change cannot bring them back
        ax.set_yticklabels(())
    ax.set_ylim(emin, emax)
    ax.set_title(title, fontsize=8)

# --- bottom row: force mean absolute errors -------------------------------
for col, (model, _title) in enumerate(panels):
    ax = fig.add_subplot(2, 3, 4 + col)
    for mol, err in forces_si.items():
        ax.plot(xaxis, err[model], lw=0.75, color=colors_si[mol])

    ax.set_yscale("log")
    if col == 0:
        # raw string: "\m" and "\A" are not valid Python escapes
        ax.set_ylabel(r"Forces MAE (eV $\mathrm{\AA}^{-1}$)", fontsize=8)
        ax.tick_params(axis="both", labelsize=8, which="both")
        ax.yaxis.set_label_coords(labelx, 0.5, transform=ax.transAxes)
    else:
        ax.tick_params(axis="both", labelsize=8)
        ax.set_yticklabels(())
    ax.set_ylim(fmin, fmax)


fig.supxlabel("time (ns)", fontsize=8)


# Legend handles: thick lines in the colours of the plotted molecules, in the
# same order as `labels`.
lines = [
    Line2D([0], [0], lw=3.75, color=colors_si[mol_name], ls="-")
    for mol_name in molecules_si
]

labels = ["H$_2$O@C$_{60}$", "CO$_2$@C$_{60}$", "CH$_3$CHO@C$_{60}$"]
ccc = fig.legend(
    lines, labels, bbox_to_anchor=(0.95, 1.37), fontsize=8, frameon=False, ncols=3
)


# fig.savefig("./fig_si_S2.svg", dpi=300, bbox_inches="tight")