In [None]:
import ase.io as ase_io
import matplotlib.pyplot as plt

from e3nn.io import CartesianTensor

from soprano.properties import nmr
from soprano.properties.nmr import MSTensor

import torch
import numpy as np

In [None]:
# Converters between 3x3 Cartesian tensors and their irreducible components,
# one for the symmetric index formula ("ij=ji") and one for the antisymmetric
# formula ("ij=-ji"). Both are used by get_contrib below.
ct_symm = CartesianTensor("ij=ji")
ct_antisymm = CartesianTensor("ij=-ji")

In [None]:
def get_rmse(a, b, perc=True):
    """Root-mean-square error between ``a`` and ``b``.

    Parameters
    ----------
    a, b : numpy.ndarray
        Arrays that broadcast against each other; ``b`` is treated as the
        reference and must provide ``.std``.
    perc : bool, default True
        When true, express the RMSE as a percentage of the sample standard
        deviation (``ddof=1``) of ``b`` taken over all of its entries.

    Returns
    -------
    float
        The plain RMSE if ``perc`` is false, otherwise ``100 * RMSE / std(b)``.
    """
    residual = a - b
    root_mean_square = np.sqrt(np.mean(residual**2))
    if not perc:
        return root_mean_square
    # Normalise by the spread of the reference values.
    return 100 * root_mean_square / b.std(ddof=1)

In [None]:
def get_contrib(frm, tag):
    """Split the 3x3 tensors stored under ``tag`` into irreducible parts.

    Parameters
    ----------
    frm : ase.Atoms
        Frame whose ``arrays[tag]`` entry can be reshaped to ``(-1, 3, 3)``,
        i.e. one Cartesian tensor per row. The frame is only read.
    tag : str
        Key of the tensor array in ``frm.arrays``.

    Returns
    -------
    l0 : numpy.ndarray
        First column of the symmetric ("ij=ji") decomposition, one value per
        tensor.
    l1 : numpy.ndarray
        Output of the antisymmetric ("ij=-ji") decomposition.
    l2 : numpy.ndarray
        Remaining columns of the symmetric decomposition.

    Raises
    ------
    KeyError
        If ``tag`` is not present in ``frm.arrays``.
    """
    # Convert once and reuse the same tensor for both decompositions; the
    # frame itself does not need to be copied because nothing is written to it.
    cartesian = torch.tensor(frm.arrays[tag].reshape(-1, 3, 3))

    l1 = ct_antisymm.from_cartesian(cartesian).numpy()

    symm = ct_symm.from_cartesian(cartesian).numpy()
    l0 = symm[:, 0]
    l2 = symm[:, 1:]

    # NOTE: soprano-based descriptors (MSAnisotropy, MSAsymmetry, MSSkew,
    # MSSpan and the equivalent Euler angles of MSTensor) were prototyped here
    # but are disabled. Re-enabling them requires a copy of the frame with the
    # reshaped tensors stored under arrays["ms"].
    return l0, l1, l2

In [None]:
# Read every frame of the trajectory file (index ":" selects all of them).
# The per-atom tensor labels used in the cells below are expected to be stored
# in this file (see the key listing in the next cell).
frames = ase_io.read("../data/train_test/test_with_all_labels.xyz", ":")

In [None]:
# Inspect which arrays are attached to the first frame; these names are the
# `tag` values passed to get_contrib in the cells below.
list(frames[0].arrays.keys())

In [None]:
# Decompose the reference tensors (array "QM_ms") of every frame, then stack
# each component across frames so the leading axis indexes the frame.
qm_contribs = [get_contrib(frame, "QM_ms") for frame in frames]

qm_l0 = np.array([contrib[0] for contrib in qm_contribs])
qm_l1 = np.array([contrib[1] for contrib in qm_contribs])
qm_l2 = np.array([contrib[2] for contrib in qm_contribs])

In [None]:
# Values substituted for the size field of the "ML_LC_*_{size}_{j}_ms" array
# names; they are also the x values of the final plot ("train set size").
# The longer list below is kept as an alternative selection.
# sizes = [10, 25, 100, 250, 500]
sizes = [10, 25, 100]

In [None]:
# Decompose the tensors stored under "ML_LC_ISD_{size}_{j}_ms" for every
# entry of `sizes`, every index j and every frame.
n_runs = 4  # indices j = 0..3 appear in the array names
n_frames = len(frames)

ml_isd_l0 = []
ml_isd_l1 = []
ml_isd_l2 = []

for s in sizes:
    for j in range(n_runs):
        for frm in frames:
            res = get_contrib(frm, f"ML_LC_ISD_{s}_{j}_ms")
            ml_isd_l0.append(res[0])
            ml_isd_l1.append(res[1])
            ml_isd_l2.append(res[2])

# Restore the loop structure as leading axes: (size, j, frame, atom[, component]).
# The frame count comes from the data and the atom axis is inferred (-1), so
# nothing depends on a fixed data-set size. All frames must still contain the
# same number of atoms for the stacking to work.
ml_isd_l0 = np.array(ml_isd_l0).reshape(len(sizes), n_runs, n_frames, -1)
ml_isd_l1 = np.array(ml_isd_l1).reshape(len(sizes), n_runs, n_frames, -1, 3)
ml_isd_l2 = np.array(ml_isd_l2).reshape(len(sizes), n_runs, n_frames, -1, 5)

In [None]:
# Decompose the tensors stored under "ML_LC_TP_single_{size}_{j}_ms" for every
# entry of `sizes`, every index j and every frame.
n_runs = 4  # indices j = 0..3 appear in the array names
n_frames = len(frames)

ml_tp_single_l0 = []
ml_tp_single_l1 = []
ml_tp_single_l2 = []

for s in sizes:
    for j in range(n_runs):
        for frm in frames:
            res = get_contrib(frm, f"ML_LC_TP_single_{s}_{j}_ms")
            ml_tp_single_l0.append(res[0])
            ml_tp_single_l1.append(res[1])
            ml_tp_single_l2.append(res[2])

# Restore the loop structure as leading axes: (size, j, frame, atom[, component]).
# The frame count comes from the data and the atom axis is inferred (-1), so
# nothing depends on a fixed data-set size. All frames must still contain the
# same number of atoms for the stacking to work.
ml_tp_single_l0 = np.array(ml_tp_single_l0).reshape(
    len(sizes), n_runs, n_frames, -1
)
ml_tp_single_l1 = np.array(ml_tp_single_l1).reshape(
    len(sizes), n_runs, n_frames, -1, 3
)
ml_tp_single_l2 = np.array(ml_tp_single_l2).reshape(
    len(sizes), n_runs, n_frames, -1, 5
)

In [None]:
# Decompose the tensors stored under "ML_LC_TP_4096_{size}_{j}_ms" for every
# entry of `sizes`, every index j and every frame.
n_runs = 4  # indices j = 0..3 appear in the array names
n_frames = len(frames)

ml_tp_4096_l0 = []
ml_tp_4096_l1 = []
ml_tp_4096_l2 = []

for s in sizes:
    for j in range(n_runs):
        for frm in frames:
            res = get_contrib(frm, f"ML_LC_TP_4096_{s}_{j}_ms")
            ml_tp_4096_l0.append(res[0])
            ml_tp_4096_l1.append(res[1])
            ml_tp_4096_l2.append(res[2])

# Restore the loop structure as leading axes: (size, j, frame, atom[, component]).
# The frame axis is pinned to len(frames) and the atom axis is inferred (-1);
# inferring the frame axis from a hard-coded atom count would silently produce
# a wrong frame axis if the atom count ever differed.
ml_tp_4096_l0 = np.array(ml_tp_4096_l0).reshape(len(sizes), n_runs, n_frames, -1)
ml_tp_4096_l1 = np.array(ml_tp_4096_l1).reshape(len(sizes), n_runs, n_frames, -1, 3)
ml_tp_4096_l2 = np.array(ml_tp_4096_l2).reshape(len(sizes), n_runs, n_frames, -1, 5)

In [None]:
# Boolean per-atom masks built from the atomic numbers of the first frame
# (14 = silicon, 8 = oxygen). The same masks are applied to every frame in the
# cells below.
# NOTE(review): this assumes all frames share the atom ordering of frames[0] — confirm.
si_idx = frames[0].numbers == 14
o_idx = frames[0].numbers == 8

In [None]:
# %RMSE of the ISD predictions against the QM reference, per component.
# For each entry of `sizes` the error is computed separately for every index j
# (second axis of the prediction arrays) and then averaged over j (axis=1).
isd_vs_qm = ((ml_isd_l0, qm_l0), (ml_isd_l1, qm_l1), (ml_isd_l2, qm_l2))

# Silicon atoms only.
si_err_isd_l0, si_err_isd_l1, si_err_isd_l2 = (
    np.mean(
        [
            [get_rmse(run[:, si_idx], ref[:, si_idx]) for run in per_size]
            for per_size in pred
        ],
        axis=1,
    )
    for pred, ref in isd_vs_qm
)

# Oxygen atoms only.
o_err_isd_l0, o_err_isd_l1, o_err_isd_l2 = (
    np.mean(
        [
            [get_rmse(run[:, o_idx], ref[:, o_idx]) for run in per_size]
            for per_size in pred
        ],
        axis=1,
    )
    for pred, ref in isd_vs_qm
)

In [None]:
# %RMSE of the single-TP predictions against the QM reference, per component.
# For each entry of `sizes` the error is computed separately for every index j
# (second axis of the prediction arrays) and then averaged over j (axis=1).
tp_single_vs_qm = (
    (ml_tp_single_l0, qm_l0),
    (ml_tp_single_l1, qm_l1),
    (ml_tp_single_l2, qm_l2),
)

# Silicon atoms only.
si_err_tp_single_l0, si_err_tp_single_l1, si_err_tp_single_l2 = (
    np.mean(
        [
            [get_rmse(run[:, si_idx], ref[:, si_idx]) for run in per_size]
            for per_size in pred
        ],
        axis=1,
    )
    for pred, ref in tp_single_vs_qm
)

# Oxygen atoms only.
o_err_tp_single_l0, o_err_tp_single_l1, o_err_tp_single_l2 = (
    np.mean(
        [
            [get_rmse(run[:, o_idx], ref[:, o_idx]) for run in per_size]
            for per_size in pred
        ],
        axis=1,
    )
    for pred, ref in tp_single_vs_qm
)

In [None]:
# %RMSE of the 4096-TP predictions against the QM reference, per component.
# For each entry of `sizes` the error is computed separately for every index j
# (second axis of the prediction arrays) and then averaged over j (axis=1).
tp_4096_vs_qm = (
    (ml_tp_4096_l0, qm_l0),
    (ml_tp_4096_l1, qm_l1),
    (ml_tp_4096_l2, qm_l2),
)

# Silicon atoms only.
si_err_tp_4096_l0, si_err_tp_4096_l1, si_err_tp_4096_l2 = (
    np.mean(
        [
            [get_rmse(run[:, si_idx], ref[:, si_idx]) for run in per_size]
            for per_size in pred
        ],
        axis=1,
    )
    for pred, ref in tp_4096_vs_qm
)

# Oxygen atoms only.
o_err_tp_4096_l0, o_err_tp_4096_l1, o_err_tp_4096_l2 = (
    np.mean(
        [
            [get_rmse(run[:, o_idx], ref[:, o_idx]) for run in per_size]
            for per_size in pred
        ],
        axis=1,
    )
    for pred, ref in tp_4096_vs_qm
)

In [None]:
# Learning curves: %RMSE versus training-set size on a 3x2 grid of panels.
# Columns are the two elements (left silicon, right oxygen) and rows are the
# tensor components sigma^(0), sigma^(1), sigma^(2) from top to bottom.
fig = plt.figure(figsize=(3.5, 4.5))

n_rows, n_cols = 3, 2

# One entry per column; within a column one triple of curves per row, ordered
# like `curve_styles` below (ISD, 1 TP, 4096 TPs).
columns = [
    (
        "silicon",
        [
            (si_err_isd_l0, si_err_tp_single_l0, si_err_tp_4096_l0),
            (si_err_isd_l1, si_err_tp_single_l1, si_err_tp_4096_l1),
            (si_err_isd_l2, si_err_tp_single_l2, si_err_tp_4096_l2),
        ],
    ),
    (
        "oxygen",
        [
            (o_err_isd_l0, o_err_tp_single_l0, o_err_tp_4096_l0),
            (o_err_isd_l1, o_err_tp_single_l1, o_err_tp_4096_l1),
            (o_err_isd_l2, o_err_tp_single_l2, o_err_tp_4096_l2),
        ],
    ),
]

# (format string, colour, legend label) for the three models.
curve_styles = [
    (".-", "C1", "ISD"),
    (".--", "C0", "1 TP"),
    (".-", "C0", "4096 TPs"),
]

for col, (element, curves_by_order) in enumerate(columns):
    for row, curves in enumerate(curves_by_order):
        # Subplot indices run row by row, starting at 1.
        ax = fig.add_subplot(n_rows, n_cols, n_cols * row + col + 1)

        # Only the first panel carries legend labels so that fig.legend()
        # lists every model exactly once.
        is_first_panel = row == 0 and col == 0
        for errors, (fmt, color, label) in zip(curves, curve_styles):
            ax.plot(
                sizes,
                errors,
                fmt,
                color=color,
                lw=0.75,
                label=label if is_first_panel else "_nolegend_",
            )

        ax.set_xscale("log")
        ax.set_yscale("log")
        # Tick labels must be cleared after switching the scales to log.
        if row < n_rows - 1:
            ax.set_xticklabels(())  # x tick labels only on the bottom row
        if col > 0:
            ax.set_yticklabels(())  # y tick labels only on the left column
        ax.set_ylim((7, 110))
        if col == 0:
            ax.set_ylabel(f"$\\sigma^{{({row})}}$ %RMSE", fontsize=9)
        ax.tick_params(axis="both", labelsize=9)
        if row == 0:
            ax.set_title(element, fontsize=9)

fig.supxlabel("train set size", fontsize=9, y=0.00, x=0.55)

fig.legend(fontsize=9, ncols=3, bbox_to_anchor=(0.97, 1.05), frameon=False)

fig.tight_layout(rect=[0, -0.065, 1.0, 0.995], w_pad=1)

# fig.savefig("./lc_l0l1l2.pdf", dpi=300, bbox_inches="tight")