In [None]:
import ase.io as ase_io
import matplotlib.pyplot as plt

from e3nn.io import CartesianTensor

from soprano.properties import nmr
from soprano.properties.nmr import MSTensor

import torch
import numpy as np

In [None]:
# e3nn converters for rank-2 Cartesian tensors, selected by index-symmetry
# formula: "ij=ji" is the symmetric case, "ij=-ji" the antisymmetric case.
# Both are used by get_contrib to decompose the 3x3 shielding tensors.
ct_symm = CartesianTensor("ij=ji")
ct_antisymm = CartesianTensor("ij=-ji")

In [None]:
def get_rmse(a, b, perc=True):
    """Root-mean-square error between ``a`` and ``b``.

    Parameters
    ----------
    a, b : numpy.ndarray
        Arrays that can be subtracted from each other; ``b`` acts as the
        reference and must provide ``.std``.
    perc : bool, default True
        When true, return the RMSE as a percentage of the sample standard
        deviation of ``b`` (``ddof=1``); otherwise return the plain RMSE.

    Returns
    -------
    float
        The (optionally normalised) RMSE over all elements.
    """
    residual = a - b
    root_mean_square = np.sqrt(np.mean(residual**2))
    if not perc:
        return root_mean_square
    return 100 * root_mean_square / b.std(ddof=1)


def get_rmse_euler(ml_all, dft_all):
    """Per-angle error between predicted and reference Euler angles.

    ``ml_all`` and ``dft_all`` are iterated in parallel (one entry per
    frame). Each entry is indexed by atom and must yield a 2-D array whose
    columns 0, 1 and 2 hold the three angles; the rows are presumably the
    equivalent angle sets of one tensor (verify against the caller).

    For every atom and every column independently, all (reference row,
    predicted row) pairs are compared by plain absolute difference (no
    periodic wrapping). Among the pairs that are ``np.isclose`` to the
    smallest difference, the smallest reference value and the smallest
    predicted value are kept. Because columns are matched independently, the
    three kept angles may come from different rows.

    Returns
    -------
    tuple
        ``(alpha, beta, gamma)`` errors, each computed with ``get_rmse``
        (prediction first, reference second, default ``perc``).
    """
    reference_rows = []
    predicted_rows = []

    for ml_frame, dft_frame in zip(ml_all, dft_all):
        for atom, ml_atom in enumerate(ml_frame):
            dft_atom = dft_frame[atom]
            reference_angles = []
            predicted_angles = []
            for col in range(3):
                dft_col = dft_atom[:, col]
                ml_col = ml_atom[:, col]
                # Pairwise |reference - prediction|: rows follow the
                # reference candidates, columns the predicted candidates.
                gap = np.abs(dft_col[:, np.newaxis] - ml_col)
                dft_hits, ml_hits = np.where(np.isclose(gap, gap.min()))
                reference_angles.append(dft_col[dft_hits].min())
                predicted_angles.append(ml_col[ml_hits].min())
            reference_rows.append(reference_angles)
            predicted_rows.append(predicted_angles)

    reference = np.array(reference_rows)
    predicted = np.array(predicted_rows)
    return tuple(get_rmse(predicted[:, col], reference[:, col]) for col in range(3))

In [None]:
def get_contrib(frm, tag):
    """Derive tensor descriptors from the 3x3 tensors stored under ``tag``.

    Works on a copy of ``frm``: the array ``frm.arrays[tag]`` is reshaped to
    ``(-1, 3, 3)`` and stored on the copy as ``"ms"``, which is what the
    soprano property classes are then evaluated on.

    Returns
    -------
    tuple
        ``(l0, l1, l2, aniso, asymm, skew, span, euler)`` where

        * ``l1`` is the output of ``ct_antisymm.from_cartesian``,
        * ``l0`` / ``l2`` are the first column / remaining columns of
          ``ct_symm.from_cartesian``,
        * ``aniso`` is the absolute value of ``nmr.MSAnisotropy``,
        * ``asymm``, ``skew``, ``span`` come from ``nmr.MSAsymmetry``,
          ``nmr.MSSkew`` and ``nmr.MSSpan``,
        * ``euler`` is a list with, per tensor from ``MSTensor.get``, the
          result of ``equivalent_euler_angles("zyz", passive=True)``.
    """
    frame = frm.copy()
    frame.arrays["ms"] = frame.arrays[tag].reshape(-1, 3, 3)

    # One torch copy of the Cartesian tensors feeds both decompositions.
    ms_torch = torch.tensor(frame.arrays["ms"])
    l1 = ct_antisymm.from_cartesian(ms_torch).numpy()
    symmetric_part = ct_symm.from_cartesian(ms_torch).numpy()
    l0, l2 = symmetric_part[:, 0], symmetric_part[:, 1:]

    aniso = np.abs(nmr.MSAnisotropy.get(frame))
    asymm = nmr.MSAsymmetry.get(frame)
    skew = nmr.MSSkew.get(frame)
    span = nmr.MSSpan.get(frame)

    euler = [
        tensor.equivalent_euler_angles("zyz", passive=True)
        for tensor in MSTensor.get(frame)
    ]

    return l0, l1, l2, aniso, asymm, skew, span, euler

In [None]:
# Load every frame of the XYZ file (the ":" index selects all images).
frames = ase_io.read("../data/train_test/test_with_all_labels.xyz", ":")

In [None]:
# Show which per-atom arrays the first frame carries; the tags passed to
# get_contrib ("QM_ms", "ML_ISD_ms", "ML_TP_4096_4e_ms") must be among them.
list(frames[0].arrays.keys())

In [None]:
# Collect the eight get_contrib outputs for the "QM_ms" tensors of every
# frame, one accumulator per output, then stack each one along a leading
# frame axis.
qm_columns = [[] for _ in range(8)]

for frm in frames:
    res = get_contrib(frm, "QM_ms")
    for column, value in zip(qm_columns, res):
        column.append(value)

(
    qm_l0,
    qm_l1,
    qm_l2,
    qm_aniso,
    qm_asymm,
    qm_skew,
    qm_span,
    qm_euler,
) = (np.array(column) for column in qm_columns)

In [None]:
# Collect the eight get_contrib outputs for the "ML_ISD_ms" tensors of every
# frame, one accumulator per output, then stack each one along a leading
# frame axis.
ml_isd_columns = [[] for _ in range(8)]

for frm in frames:
    res = get_contrib(frm, "ML_ISD_ms")
    for column, value in zip(ml_isd_columns, res):
        column.append(value)

(
    ml_isd_l0,
    ml_isd_l1,
    ml_isd_l2,
    ml_isd_aniso,
    ml_isd_asymm,
    ml_isd_skew,
    ml_isd_span,
    ml_isd_euler,
) = (np.array(column) for column in ml_isd_columns)

In [None]:
# Collect the eight get_contrib outputs for the "ML_TP_4096_4e_ms" tensors of
# every frame, one accumulator per output, then stack each one along a
# leading frame axis.
ml_tp_columns = [[] for _ in range(8)]

for frm in frames:
    res = get_contrib(frm, "ML_TP_4096_4e_ms")
    for column, value in zip(ml_tp_columns, res):
        column.append(value)

(
    ml_tp_l0,
    ml_tp_l1,
    ml_tp_l2,
    ml_tp_aniso,
    ml_tp_asymm,
    ml_tp_skew,
    ml_tp_span,
    ml_tp_euler,
) = (np.array(column) for column in ml_tp_columns)

In [None]:
# Boolean per-atom masks by atomic number (14 = Si, 8 = O), built from the
# first frame only.
# NOTE(review): the masks are applied to the atom axis of every frame in the
# cells below, which assumes all frames share the same atom ordering - confirm.
si_idx = frames[0].numbers == 14
o_idx = frames[0].numbers == 8

In [None]:
# get_rmse (default perc=True) of the ISD-derived descriptors against the QM
# reference, restricted to silicon atoms through the atom-axis mask.
(
    si_err_isd_l0,
    si_err_isd_l1,
    si_err_isd_l2,
    si_err_isd_skew,
    si_err_isd_span,
    si_err_isd_aniso,
    si_err_isd_asymm,
) = (
    get_rmse(ml_values[:, si_idx], qm_values[:, si_idx])
    for ml_values, qm_values in (
        (ml_isd_l0, qm_l0),
        (ml_isd_l1, qm_l1),
        (ml_isd_l2, qm_l2),
        (ml_isd_skew, qm_skew),
        (ml_isd_span, qm_span),
        (ml_isd_aniso, qm_aniso),
        (ml_isd_asymm, qm_asymm),
    )
)

# Euler angles go through the dedicated matcher, which returns one error per
# angle.
si_err_isd_alpha, si_err_isd_beta, si_err_isd_gamma = get_rmse_euler(
    ml_isd_euler[:, si_idx], qm_euler[:, si_idx]
)

# Same quantities restricted to oxygen atoms.
(
    o_err_isd_l0,
    o_err_isd_l1,
    o_err_isd_l2,
    o_err_isd_skew,
    o_err_isd_span,
    o_err_isd_aniso,
    o_err_isd_asymm,
) = (
    get_rmse(ml_values[:, o_idx], qm_values[:, o_idx])
    for ml_values, qm_values in (
        (ml_isd_l0, qm_l0),
        (ml_isd_l1, qm_l1),
        (ml_isd_l2, qm_l2),
        (ml_isd_skew, qm_skew),
        (ml_isd_span, qm_span),
        (ml_isd_aniso, qm_aniso),
        (ml_isd_asymm, qm_asymm),
    )
)

o_err_isd_alpha, o_err_isd_beta, o_err_isd_gamma = get_rmse_euler(
    ml_isd_euler[:, o_idx], qm_euler[:, o_idx]
)

In [None]:
# get_rmse (default perc=True) of the TP-derived descriptors against the QM
# reference, restricted to silicon atoms through the atom-axis mask.
(
    si_err_tp_l0,
    si_err_tp_l1,
    si_err_tp_l2,
    si_err_tp_skew,
    si_err_tp_span,
    si_err_tp_aniso,
    si_err_tp_asymm,
) = (
    get_rmse(ml_values[:, si_idx], qm_values[:, si_idx])
    for ml_values, qm_values in (
        (ml_tp_l0, qm_l0),
        (ml_tp_l1, qm_l1),
        (ml_tp_l2, qm_l2),
        (ml_tp_skew, qm_skew),
        (ml_tp_span, qm_span),
        (ml_tp_aniso, qm_aniso),
        (ml_tp_asymm, qm_asymm),
    )
)

# Euler angles go through the dedicated matcher, which returns one error per
# angle.
si_err_tp_alpha, si_err_tp_beta, si_err_tp_gamma = get_rmse_euler(
    ml_tp_euler[:, si_idx], qm_euler[:, si_idx]
)

# Same quantities restricted to oxygen atoms.
(
    o_err_tp_l0,
    o_err_tp_l1,
    o_err_tp_l2,
    o_err_tp_skew,
    o_err_tp_span,
    o_err_tp_aniso,
    o_err_tp_asymm,
) = (
    get_rmse(ml_values[:, o_idx], qm_values[:, o_idx])
    for ml_values, qm_values in (
        (ml_tp_l0, qm_l0),
        (ml_tp_l1, qm_l1),
        (ml_tp_l2, qm_l2),
        (ml_tp_skew, qm_skew),
        (ml_tp_span, qm_span),
        (ml_tp_aniso, qm_aniso),
        (ml_tp_asymm, qm_asymm),
    )
)

o_err_tp_alpha, o_err_tp_beta, o_err_tp_gamma = get_rmse_euler(
    ml_tp_euler[:, o_idx], qm_euler[:, o_idx]
)

In [None]:
# Two stacked panels (silicon on top, oxygen below) comparing the %RMSE of
# the ISD-derived and TP-derived descriptors.
fig = plt.figure(figsize=(3.5, 3.7))

# Raw strings keep the LaTeX backslashes literal. Sequences such as "\s" or
# "\z" in a normal string are invalid escapes (deprecated, and a SyntaxWarning
# on recent Python), and raw strings also remove the need for "\\alpha".
xlabels = [
    r"$\sigma^{(0)}$",
    r"$\sigma^{(1)}$",
    r"$\sigma^{(2)}$",
    r"$\zeta_\sigma$",
    r"$\eta_\sigma$",
    r"$\kappa_\sigma$",
    r"$\Omega_\sigma$",
    r"$\alpha_\sigma$",
    r"$\beta_\sigma$",
    r"$\gamma_\sigma$",
]

# --- Top panel: silicon ---------------------------------------------------
ax = fig.add_subplot(211)
x = np.arange(len(xlabels))
# The order of y and z must follow the order of xlabels.
y = [
    si_err_isd_l0,
    si_err_isd_l1,
    si_err_isd_l2,
    si_err_isd_aniso,
    si_err_isd_asymm,
    si_err_isd_skew,
    si_err_isd_span,
    si_err_isd_alpha,
    si_err_isd_beta,
    si_err_isd_gamma,
]
z = [
    si_err_tp_l0,
    si_err_tp_l1,
    si_err_tp_l2,
    si_err_tp_aniso,
    si_err_tp_asymm,
    si_err_tp_skew,
    si_err_tp_span,
    si_err_tp_alpha,
    si_err_tp_beta,
    si_err_tp_gamma,
]

ax.plot(x, y, ".-", lw=0.75, label="from ISD", c="C1")
ax.plot(x, z, ".-", lw=0.75, label="from TP", c="C0")

# Fix the tick positions first, then blank their labels: the top panel relies
# on the labels drawn under the bottom panel.
ax.set_xticks(x)
ax.set_xticklabels([])
ax.set_ylabel("%RMSE", fontsize=9)
ax.set_ylim(0, 45)
# The legend is drawn once, above the top panel, and serves both panels.
ax.legend(
    loc="upper center", fontsize=9, bbox_to_anchor=(0.5, 1.25), ncols=2, frameon=False
)
ax.tick_params(axis="both", labelsize=9)
ax.text(0, 40, "silicon", fontsize=9)

# --- Bottom panel: oxygen -------------------------------------------------
ax = fig.add_subplot(212)
y = [
    o_err_isd_l0,
    o_err_isd_l1,
    o_err_isd_l2,
    o_err_isd_aniso,
    o_err_isd_asymm,
    o_err_isd_skew,
    o_err_isd_span,
    o_err_isd_alpha,
    o_err_isd_beta,
    o_err_isd_gamma,
]
z = [
    o_err_tp_l0,
    o_err_tp_l1,
    o_err_tp_l2,
    o_err_tp_aniso,
    o_err_tp_asymm,
    o_err_tp_skew,
    o_err_tp_span,
    o_err_tp_alpha,
    o_err_tp_beta,
    o_err_tp_gamma,
]

ax.plot(x, y, ".-", lw=0.75, c="C1")
ax.plot(x, z, ".-", lw=0.75, c="C0")

ax.set_xticks(x)
ax.set_xticklabels(xlabels, fontsize=9)
ax.set_ylabel("%RMSE", fontsize=9)
ax.tick_params(axis="both", labelsize=9)
ax.set_ylim(0, 45)
ax.text(0, 40, "oxygen", fontsize=9)

fig.tight_layout()

# fig.savefig("./cs_spher_decomp_vs_tp_v4.pdf", dpi=300, bbox_inches="tight")