In [None]:
import ase.io as ase_io
import matplotlib.pyplot as plt
import numpy as np
from rings import rings_distribution

In [None]:
def get_min_rings(atoms):
    """Return the smallest ring size found in each structure.

    Only C and O atoms are kept; every other species (e.g. H) is removed
    from a structure before it is handed to ``rings_distribution``.

    Parameters
    ----------
    atoms : iterable
        Structures that provide ``get_chemical_symbols()`` and can be
        indexed with a list of atom indices.

    Returns
    -------
    numpy.ndarray
        One entry per structure: the smallest ring length with a non-zero
        count, or ``0`` when no ring is reported at all.
    """
    # Per-species radii; pair cutoffs below are sums of two radii.
    # NOTE(review): presumably in Angstrom -- confirm against rings_distribution.
    r_c = 2.04 / 2
    r_o = 1.6 / 2
    r_h = 0  # 1.20/2
    # Pair cutoffs passed to rings_distribution. The H entries are kept even
    # though H atoms are filtered out before the ring search.
    cutoff_dict = {
        "CC": r_c * 2,
        "OO": r_o * 2,
        "HH": r_h * 2,
        "HC": r_c + r_h,
        "CH": r_c + r_h,
        "OH": r_o + r_h,
        "HO": r_o + r_h,
        "CO": r_c + r_o,  # 1.85,
        "OC": r_c + r_o,
    }

    # Smallest ring size per structure (0 means "no rings").
    min_rings = []
    for i, structure in enumerate(atoms):
        print("Starting structure" + str(i))
        symbols = structure.get_chemical_symbols()
        # Keep only the C and O atoms for the ring search.
        heavy_indices = [j for j, x in enumerate(symbols) if x in ("C", "O")]
        heavy_atoms = structure[heavy_indices]

        # rings_distribution returns five values; only the ring lengths and
        # their counts are needed here.
        rings_length, rings_dist, _, _, _ = rings_distribution(
            heavy_atoms, cutoff=cutoff_dict
        )
        if np.sum(rings_dist) == 0:
            min_rings.append(0)
        else:
            # Smallest ring length whose count is non-zero.
            min_rings.append(min(rings_length[rings_dist.astype(bool)]))
    return np.array(min_rings)

In [None]:
# Chemical formulas naming the QM7-X xyz files to load.
qm7x = ["C6H10O", "C5H8O2", "C6H8O", "C6H12O", "C5H10O2", "C5H6O2"]

# Collect every frame of every file into one flat list, in file order.
frames = []
for mol in qm7x:
    frames.extend(ase_io.read(f"../../data/QM7-X/{mol}.xyz", index=":"))

In [None]:
# Smallest ring size of every loaded frame (0 where no ring was found).
minrings = get_min_rings(frames)

In [None]:
# Per-frame errors of the `gomace_*` values against the `castep_*` values.
# force_rmse: RMSE over all force components of a frame.
# energy_rmse: absolute energy difference divided by the number of atoms
#              (one value per frame, so an absolute error despite the name).
force_rmse = []
energy_rmse = []

# Iterate over the frames themselves so no frame count is hard-coded.
for frame in frames:
    force_go = frame.arrays["gomace_forces"]
    force_ca = frame.arrays["castep_forces"]

    energy_go = frame.info["gomace_energy"]
    energy_ca = frame.info["castep_energy"]

    force_rmse.append(np.sqrt(((force_go - force_ca) ** 2).mean()))
    energy_rmse.append(np.abs(energy_go - energy_ca) / len(frame))

force_rmse = np.array(force_rmse)
energy_rmse = np.array(energy_rmse)

In [None]:
# Sorted distinct smallest-ring sizes occurring in `minrings`.
unique = np.unique(minrings)

In [None]:
# Errors grouped by smallest-ring size.
# *_per_ring_size:        one RMSE per ring size, pooled over all its frames.
# *_per_ring_size_per_st: one error value per frame, grouped by ring size
#                         (these lists feed the violin plots).
force_rmse_per_ring_size = {}
energy_rmse_per_ring_size = {}


force_rmse_per_ring_size_per_st = {j: [] for j in unique}
energy_rmse_per_ring_size_per_st = {j: [] for j in unique}

for j in unique:
    # Frame indices whose smallest ring has size j. Taken from the mask
    # itself, so the number of frames is not hard-coded.
    idx = np.flatnonzero(minrings == j)
    forc_go = []
    forc_ca = []
    energy_ca = []
    energy_go = []
    for i in idx:
        frame_forc_go = np.asarray(frames[i].arrays["gomace_forces"])
        frame_forc_ca = np.asarray(frames[i].arrays["castep_forces"])
        forc_go.append(frame_forc_go)
        forc_ca.append(frame_forc_ca)

        energy_go.append(frames[i].info["gomace_energy"] / len(frames[i]))
        energy_ca.append(frames[i].info["castep_energy"] / len(frames[i]))

        # Per-structure force RMSE: use only the current frame's forces.
        # (Previously the running, cumulative force lists were used here,
        # so each entry mixed in every earlier frame of the group.)
        force_rmse_per_ring_size_per_st[j].append(
            np.sqrt(((frame_forc_go - frame_forc_ca) ** 2).mean())
        )
        # NOTE(review): the plot labels this quantity in meV per atom; if the
        # stored energies are in eV the factor would be 1000, not 100.
        # Left unchanged -- confirm the units.
        energy_rmse_per_ring_size_per_st[j].append(
            100 * np.abs(energy_go[-1] - energy_ca[-1])
        )

    # Pool all atoms of the group into (n_atoms_total, 3) arrays.
    forc_go = np.concatenate(forc_go)
    forc_ca = np.concatenate(forc_ca)

    energy_ca = np.array(energy_ca)
    energy_go = np.array(energy_go)

    force_rmse_per_ring_size[j] = np.sqrt(((forc_go - forc_ca) ** 2).mean())
    energy_rmse_per_ring_size[j] = np.sqrt(((energy_go - energy_ca) ** 2).mean())

In [None]:
# Hex colour palette for the figures; the violin plots below use c3.
c0 = "#f18f01"
c1 = "#033f63"
c1bis = "#39b1f9"
c2 = "#95b46a"
c3 = "#ee4266"

In [None]:
def _style_violin(parts, color, linewidth=0.75):
    """Draw all lines and bodies of a ``violinplot`` result in one colour."""
    for partname in ("cbars", "cmins", "cmaxes", "cmedians"):
        vp = parts[partname]
        vp.set_edgecolor(color)
        vp.set_linewidth(linewidth)
    for pc in parts["bodies"]:
        pc.set_facecolor(color)
        pc.set_edgecolor(color)


# Tick positions shared by both panels; 0 stands for "no rings".
ring_ticks = (0, 3, 4, 5, 6, 7)

fig = plt.figure(figsize=(3.5, 4.0), constrained_layout=True)

# Top panel: per-structure energy errors grouped by smallest ring size.
ax = fig.add_subplot(211)
c = ax.violinplot(
    [energy_rmse_per_ring_size_per_st[j] for j in unique],
    positions=unique,
    showmedians=True,
)
_style_violin(c, c3)
ax.set_xlabel("")
ax.set_ylabel("Energy RMSE (meV at.$^{-1}$)", fontsize=8)
ax.tick_params(axis="both", labelsize=8)
ax.set_ylim([-2.5, 30])
ax.set_xticks(ring_ticks)
ax.set_xticklabels([])  # x labels are shown on the bottom panel only

# Bottom panel: per-structure force errors grouped by smallest ring size.
ax = fig.add_subplot(212)
c = ax.violinplot(
    [force_rmse_per_ring_size_per_st[j] for j in unique],
    positions=unique,
    showmedians=True,
)
_style_violin(c, c3)
ax.set_xlabel("smallest ring size", fontsize=8)
# Raw string: "\m" and "\A" are not valid Python escapes, so the mathtext
# backslashes must not go through escape processing.
ax.set_ylabel(r"Force RMSE (eV $\mathrm{\AA}^{-1}$)", fontsize=8)
ax.tick_params(axis="both", labelsize=8)
ax.set_ylim([0.1, 1.1])
ax.set_xticks(ring_ticks)
ax.set_xticklabels(["no rings", 3, 4, 5, 6, 7])

# fig.savefig("fig3.svg", dpi=300, bbox_inches="tight")