In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

from mrsimulator import Simulator, SpinSystem, Site
from mrsimulator import signal_processor as sp
from mrsimulator.method import SpectralDimension
from mrsimulator.method.lib import ThreeQ_VAS, SSB2D
from mrsimulator.spin_system.tensors import SymmetricTensor

import ase.io as ase_io

from soprano.properties import nmr
from soprano.calculate.nmr.nmr import NMRCalculator, NMRFlags

from copy import deepcopy

from scipy.stats import gaussian_kde, vonmises
from scipy.interpolate import CubicSpline

In [None]:
def interpolate_color(start_color, end_color, num_steps):
    """
    Build a linear RGB gradient between two colors.

    Parameters:
    - start_color: (r, g, b) sequence; only the first three entries are read
    - end_color: (r, g, b) sequence; only the first three entries are read
    - num_steps: int, number of colors to return (start and end included)

    Returns:
    - list of num_steps (r, g, b) tuples of Python ints; each channel is
      truncated toward zero (int()), not rounded.
    """
    # Index the three channels explicitly so too-short inputs still fail loudly.
    start = np.array([start_color[k] for k in range(3)], dtype=float)
    end = np.array([end_color[k] for k in range(3)], dtype=float)

    # One row per step, one column per channel.
    ramp = np.linspace(start, end, num_steps)

    return [tuple(int(channel) for channel in row) for row in ramp]


def hex_to_rgb(hex_color):
    """
    Converts a hex color string to an RGB tuple.

    Parameters:
    - hex_color: str, a hex color with or without leading '#', in 6-digit
      ('#FF5733') or 3-digit shorthand ('#F53') form, case-insensitive.
      8-digit RRGGBBAA and 4-digit RGBA shorthand are accepted too; the
      alpha channel is ignored.

    Returns:
    - tuple of integers (r, g, b), each in 0-255

    Raises:
    - ValueError if hex_color is not a 3, 4, 6 or 8 digit hex color.
    """
    digits = hex_color.lstrip("#")

    # int(..., 16) would accept signs, whitespace and underscores, so validate
    # the characters explicitly.
    if not digits or not set(digits) <= set("0123456789abcdefABCDEF"):
        raise ValueError(f"not a hex color string: {hex_color!r}")

    # CSS shorthand: every digit is doubled ('F53' -> 'FF5533').
    if len(digits) in (3, 4):
        digits = "".join(ch * 2 for ch in digits)

    if len(digits) not in (6, 8):
        raise ValueError(
            f"expected 3, 4, 6 or 8 hex digits, got {len(digits)} in {hex_color!r}"
        )

    return tuple(int(digits[i : i + 2], 16) for i in (0, 2, 4))

In [None]:
def circular_kde(r, angles, bw):
    """
    Von Mises kernel density estimate for circular (angular) data.

    Parameters:
    - r: float or array of angles (radians) at which to evaluate the density
    - angles: 1D sequence of sample angles (radians)
    - bw: von Mises concentration parameter kappa passed to
      scipy.stats.vonmises. Despite the name this is NOT a bandwidth:
      larger values give narrower kernels.

    Returns:
    - density at each point of r (same shape as r): the mean over samples
      of vonmises.pdf(r, bw, loc=sample)

    Raises:
    - ValueError if angles is empty (the mean over zero samples is undefined).
    """
    samples = np.asarray(angles, dtype=float).ravel()
    if samples.size == 0:
        raise ValueError("circular_kde needs at least one sample angle")

    points = np.asarray(r, dtype=float)
    # Broadcast evaluation points against samples -> shape (*points.shape, n).
    kernels = vonmises.pdf(points[..., np.newaxis], bw, loc=samples)
    return kernels.mean(axis=-1)

In [None]:
def get_3qmas(frm, tag, efg_tag=None, magnetic_flux_density=14.1):
    """
    Simulate a normalised 17O 3QMAS (ThreeQ_VAS) spectrum for one structure.

    Every oxygen atom becomes its own single-site spin system (no couplings),
    built from the shielding tensors in ``frm.arrays[tag]`` and the EFG
    tensors in ``frm.arrays[efg_tag]``.

    Parameters
    ----------
    frm : ase.Atoms-like
        Structure. It is copied; the input is not modified.
    tag : str
        Key of the per-atom shielding tensors (reshaped to (-1, 3, 3)).
    efg_tag : str, optional
        Key of the per-atom EFG tensors. By default this is "QM_efg" when "QM"
        appears in `tag`, otherwise "ML_ISD_efg".
    magnetic_flux_density : float, optional
        Field in T given to ThreeQ_VAS (default 14.1).

    Returns
    -------
    The real part of the processed 2D dataset, scaled so that its maximum is 1.

    Raises
    ------
    ValueError
        If the structure contains no oxygen atoms.
    """
    x = frm.copy()
    x.arrays["ms"] = x.arrays[tag].reshape(-1, 3, 3)
    if efg_tag is None:
        efg_tag = "QM_efg" if "QM" in tag else "ML_ISD_efg"
    x.arrays["efg"] = x.arrays[efg_tag].reshape(-1, 3, 3)

    o_indices = np.flatnonzero(x.numbers == 8)
    if o_indices.size == 0:
        # Otherwise the simulation is empty and normalisation divides by zero.
        raise ValueError("get_3qmas: structure contains no oxygen atoms")

    # NOTE(review): MSIsotropy is called without a reference, so this is
    # presumably isotropic *shielding* fed in as the chemical shift; the
    # spectral axis is then not a referenced shift scale -- confirm.
    ms_iso = nmr.MSIsotropy().get(x)
    ms_zeta = nmr.MSAnisotropy().get(x)
    ms_eta = nmr.MSAsymmetry().get(x)
    cq = nmr.EFGQuadrupolarConstant().get(x)
    efg_eta = nmr.EFGAsymmetry().get(x)

    spin_systems = [
        SpinSystem(
            sites=[
                Site(
                    isotope="17O",
                    isotropic_chemical_shift=ms_iso[ii],  # in ppm
                    shielding_symmetric=SymmetricTensor(
                        zeta=ms_zeta[ii], eta=ms_eta[ii]
                    ),  # zeta in ppm
                    quadrupolar=SymmetricTensor(Cq=cq[ii], eta=efg_eta[ii]),
                )
            ]
        )
        for ii in o_indices
    ]
    sim = Simulator(spin_systems=spin_systems)

    method = ThreeQ_VAS(
        channels=["17O"],
        magnetic_flux_density=magnetic_flux_density,  # in T
        spectral_dimensions=[
            SpectralDimension(
                count=1000,
                spectral_width=4e4,  # in Hz
                reference_offset=0,  # in Hz
                label="Isotropic dimension",
            ),
            # The last spectral dimension block is the direct dimension.
            SpectralDimension(
                count=1000,
                spectral_width=4.5e4,  # in Hz
                reference_offset=0,  # in Hz
                label="MAS dimension",
            ),
        ],
    )
    sim.methods = [method]
    sim.run()

    # Gaussian line broadening along both dimensions, applied in the time domain.
    processor = sp.SignalProcessor(
        operations=[
            sp.IFFT(dim_index=(0, 1)),
            sp.apodization.Gaussian(FWHM="0.3 kHz", dim_index=0),
            sp.apodization.Gaussian(FWHM="0.15 kHz", dim_index=1),
            sp.FFT(dim_index=(0, 1)),
        ]
    )

    # Take the real part *before* normalising (as get_pass does). Dividing
    # complex data by a complex maximum would mix the imaginary channel into
    # the real spectrum whenever that maximum has a non-zero imaginary part.
    spectrum = processor.apply_operations(dataset=sim.methods[0].simulation).real
    spectrum /= spectrum.max()
    return spectrum

In [None]:
def get_pass(frm, tag, magnetic_flux_density=9.7, rotor_frequency=1000):
    """
    Simulate a normalised 29Si 2D PASS (SSB2D) spectrum for one structure.

    Every silicon atom becomes its own single-site spin system (no couplings),
    built from the shielding tensors in ``frm.arrays[tag]``.

    Parameters
    ----------
    frm : ase.Atoms-like
        Structure. It is copied; the input is not modified.
    tag : str
        Key of the per-atom shielding tensors (reshaped to (-1, 3, 3)).
    magnetic_flux_density : float, optional
        Field in T given to SSB2D (default 9.7).
        NOTE(review): the 1D 29Si spectra in this notebook use
        larmor_frequency=79.459 (MHz), which for 29Si corresponds to ~9.4 T,
        not 9.7 T -- confirm which field is intended.
    rotor_frequency : float, optional
        Spinning rate in Hz (default 1000). The anisotropic spectral width is
        30 * rotor_frequency, as in the original hard-coded values.

    Returns
    -------
    The real part of the processed 2D dataset, scaled so that its maximum is 1.

    Raises
    ------
    ValueError
        If the structure contains no silicon atoms.
    """
    x = frm.copy()
    x.arrays["ms"] = x.arrays[tag].reshape(-1, 3, 3)

    si_indices = np.flatnonzero(x.numbers == 14)
    if si_indices.size == 0:
        # Otherwise the simulation is empty and normalisation divides by zero.
        raise ValueError("get_pass: structure contains no silicon atoms")

    ms_iso = nmr.MSIsotropy().get(x)
    ms_zeta = nmr.MSAnisotropy().get(x)
    ms_eta = nmr.MSAsymmetry().get(x)

    spin_systems = [
        SpinSystem(
            sites=[
                Site(
                    isotope="29Si",
                    isotropic_chemical_shift=ms_iso[ii],  # in ppm
                    shielding_symmetric=SymmetricTensor(
                        zeta=ms_zeta[ii], eta=ms_eta[ii]
                    ),  # zeta in ppm
                )
            ]
        )
        for ii in si_indices
    ]
    sim = Simulator(spin_systems=spin_systems)

    PASS = SSB2D(
        channels=["29Si"],
        magnetic_flux_density=magnetic_flux_density,  # in T
        rotor_frequency=rotor_frequency,  # in Hz
        spectral_dimensions=[
            SpectralDimension(
                count=30 * 4,
                spectral_width=rotor_frequency * 30,  # in Hz; tied to the rotor frequency
                label="Anisotropic dimension",
            ),
            SpectralDimension(
                # NOTE(review): 18192 is not a power of two -- confirm intended.
                count=18192,
                spectral_width=1e5,  # in Hz
                reference_offset=0,  # in Hz
                label="Isotropic dimension",
            ),
        ],
    )
    sim.methods = [PASS]
    sim.config.number_of_sidebands = 8
    sim.run()

    # Exponential line broadening along the anisotropic dimension only.
    processor = sp.SignalProcessor(
        operations=[
            sp.IFFT(dim_index=0),
            sp.apodization.Exponential(FWHM="100 Hz", dim_index=0),
            sp.FFT(dim_index=0),
        ]
    )
    spectrum = processor.apply_operations(dataset=sim.methods[0].simulation).real
    spectrum /= spectrum.max()
    return spectrum

In [None]:
# Endpoint colors used throughout the figures: alpha in blue, beta in red.
colors = ["#0000ff", "#FF0000"]
colors_rgb = list(map(hex_to_rgb, colors))

In [None]:
# Load every frame (index ":") of the alpha <-> beta inversion trajectory;
# each frame carries per-atom ML_ISD_ms / ML_ISD_efg arrays.
frames = ase_io.read(
    "../data/cristobalite/cristobalite_alpha_beta_inversion_ml_isd_ms_efg.xyz",
    index=":",
)

In [None]:
# Reshape the shielding and EFG arrays to (-1, 3, 3): one 3x3 tensor per atom.
for frame in frames:
    for key in ("ML_ISD_ms", "ML_ISD_efg"):
        frame.arrays[key] = frame.arrays[key].reshape(-1, 3, 3)

In [None]:
# Per-frame MLSI alpha score, read from each frame's info dict.
scores = np.array([frame.info["MLSI_alpha_score"] for frame in frames])

In [None]:
# MLSI alpha score along the whole trajectory.
plt.plot(scores)
plt.xlabel("frame index")
plt.ylabel("MLSI alpha score");

In [None]:
# Thermal averages of the alpha and beta phases: the per-atom mean of the
# ML_ISD_ms (shielding) and ML_ISD_efg tensors over the trajectory frames
# whose MLSI_alpha_score falls in an alpha-like or a beta-like window.
ALPHA_SCORE_MIN, ALPHA_SCORE_MAX = 0.9, 0.99  # alpha window: [min, max)
# NOTE(review): the upper cut at 0.99 excludes the most alpha-like frames;
# confirm this is intended.
BETA_SCORE_MAX = 0.1  # beta window: score <= max


def _thermal_average(template, selected):
    """Return a deep copy of `template` whose "ms"/"efg" arrays (and their
    "ML_ISD_ms"/"ML_ISD_efg" aliases) hold the per-atom mean over `selected`.

    Tensors are averaged index by index, so this assumes every selected frame
    has the same atom ordering as `template` -- TODO confirm.
    Raises ValueError if `selected` is empty (the mean would silently be NaN).
    """
    if len(selected) == 0:
        raise ValueError("no frames fall in the requested score window")
    averaged = deepcopy(template)
    for src, dst in (("ML_ISD_ms", "ms"), ("ML_ISD_efg", "efg")):
        # Allocate from the template itself (the original sized beta's
        # buffers from alpha's arrays).
        total = np.zeros_like(template.arrays[src])
        for frame in selected:
            total += frame.arrays[src]
        averaged.arrays[dst] = total / len(selected)
        averaged.arrays[src] = averaged.arrays[dst]
    return averaged


alpha_mask = (scores >= ALPHA_SCORE_MIN) & (scores < ALPHA_SCORE_MAX)
alph_frames = [frames[i] for i in np.flatnonzero(alpha_mask)]
beta_frames = [frames[i] for i in np.flatnonzero(scores <= BETA_SCORE_MAX)]

print(len(alph_frames), len(beta_frames))

alpha = _thermal_average(frames[0], alph_frames)
beta = _thermal_average(frames[0], beta_frames)

In [None]:
# Crystal limits: static ML predictions for the pristine structures, with the
# tensors reshaped to (-1, 3, 3) and aliased as "ms"/"efg" for soprano.
pristine_pred = ase_io.read("../data/cristobalite/crystals.xyz", ":")
for crystal in pristine_pred:
    for src, dst in (("ML_ISD_ms", "ms"), ("ML_ISD_efg", "efg")):
        crystal.arrays[src] = crystal.arrays[src].reshape(-1, 3, 3)
        crystal.arrays[dst] = crystal.arrays[src]

# NOTE(review): entry 1 is taken as alpha and the last entry as beta --
# confirm against the ordering of crystals.xyz.
pristine_alpha_pred, pristine_beta_pred = pristine_pred[1], pristine_pred[-1]

si_idx_pristine_alpha_pred = pristine_alpha_pred.numbers == 14
o_idx_pristine_alpha_pred = pristine_alpha_pred.numbers == 8
si_idx_pristine_beta_pred = pristine_beta_pred.numbers == 14
o_idx_pristine_beta_pred = pristine_beta_pred.numbers == 8

In [None]:
# Select a path of strictly decreasing scores. Within the WINDOW frames
# leading up to the global score minimum, start at the highest score and keep
# every later frame that sets a new running minimum.
WINDOW = 810
i_min = int(scores.argmin())
# Clamp at 0: a negative start would silently wrap to the end of the array.
initial = np.arange(max(i_min - WINDOW, 0), i_min + 1)
start = initial[0] + scores[initial].argmax()
idx_traj = [start]
# Stop at the window's last frame (the global minimum). The old bound
# initial[-1] + 2 read one frame past it and raised IndexError when the
# minimum was the final frame.
for j in range(start + 1, initial[-1] + 1):
    if scores[j] < scores[idx_traj[-1]]:
        idx_traj.append(j)
idx_traj = np.array(idx_traj)

In [None]:
# Plot the (strictly decreasing) scores along the selected path.
plt.plot(scores[idx_traj], ".-")
plt.xlabel("position along selected path")
plt.ylabel("MLSI alpha score");

In [None]:
# Pick three frames along the path: alpha-like, roughly 50:50, and beta-like.
# This overwrites idx_traj, so it only works once after the path-selection
# cell has run; fail with a clear message rather than an opaque IndexError.
SELECTED_PATH_POSITIONS = [3, 36, 66]
if len(idx_traj) <= max(SELECTED_PATH_POSITIONS):
    raise IndexError(
        f"idx_traj has only {len(idx_traj)} entries; re-run the path-selection "
        "cell first (this cell overwrites idx_traj)."
    )
idx_traj = idx_traj[SELECTED_PATH_POSITIONS]

In [None]:
# 2D spectra of each selected frame: 17O 3QMAS first, then 29Si PASS.
single_3qmas = [get_3qmas(frames[i], tag="ML_ISD_ms") for i in idx_traj]
single_pass = [get_pass(frames[i], tag="ML_ISD_ms") for i in idx_traj]

In [None]:
# Crystal limits: 1D 29Si MAS spectra (soprano) plus 29Si PASS and 17O 3QMAS
# spectra (mrsimulator) for the alpha and beta pristine structures.
N = 16  # passed to calc.set_powder
gb = 0.5  # passed as freq_broad
minfreq, maxfreq = 390, 490  # spectrum window
nbins = 10 * int((maxfreq - minfreq) / gb)

spectra_si_pristine_pred = []
pass_pristine_pred = []
mqmas_pristine_pred = []

for x in (pristine_alpha_pred, pristine_beta_pred):
    x.arrays["ms"] = x.arrays["ML_ISD_ms"]
    calc = NMRCalculator(x, larmor_frequency=79.459)
    calc.set_powder(N=N)
    spectrum, freqs_si_pristine_pred = calc.spectrum_1d(
        "29Si",
        effects=NMRFlags.MAS,
        min_freq=minfreq,
        max_freq=maxfreq,
        bins=nbins,
        freq_broad=gb,
    )
    spectra_si_pristine_pred.append(spectrum)
    pass_pristine_pred.append(get_pass(x, tag="ML_ISD_ms"))
    mqmas_pristine_pred.append(get_3qmas(x, tag="ML_ISD_ms"))

spectra_si_pristine_pred = np.array(spectra_si_pristine_pred)

In [None]:
# Thermally averaged alpha: 1D 29Si MAS spectrum (soprano) plus 29Si PASS and
# 17O 3QMAS spectra (mrsimulator).
N = 16  # passed to calc.set_powder
gb = 0.5  # passed as freq_broad
minfreq, maxfreq = 390, 490  # spectrum window
nbins = 10 * int((maxfreq - minfreq) / gb)

calc = NMRCalculator(alpha, larmor_frequency=79.459)
calc.set_powder(N=N)
spectrum_si_alpha, freqs_si_alpha = calc.spectrum_1d(
    "29Si",
    effects=NMRFlags.MAS,
    min_freq=minfreq,
    max_freq=maxfreq,
    bins=nbins,
    freq_broad=gb,
)
spectra_si_alpha = np.array(spectrum_si_alpha)

pass_alpha = get_pass(alpha, tag="ML_ISD_ms")
mqmas_alpha = get_3qmas(alpha, tag="ML_ISD_ms")

In [None]:
# Thermally averaged beta: 1D 29Si MAS spectrum (soprano) plus 29Si PASS and
# 17O 3QMAS spectra (mrsimulator).
N = 16  # passed to calc.set_powder
gb = 0.5  # passed as freq_broad
minfreq, maxfreq = 390, 490  # spectrum window
nbins = 10 * int((maxfreq - minfreq) / gb)

calc = NMRCalculator(beta, larmor_frequency=79.459)
calc.set_powder(N=N)
spectrum_si_beta, freqs_si_beta = calc.spectrum_1d(
    "29Si",
    effects=NMRFlags.MAS,
    min_freq=minfreq,
    max_freq=maxfreq,
    bins=nbins,
    freq_broad=gb,
)
spectra_si_beta = np.array(spectrum_si_beta)

pass_beta = get_pass(beta, tag="ML_ISD_ms")
mqmas_beta = get_3qmas(beta, tag="ML_ISD_ms")

In [None]:
# Per-atom properties of the thermal averages: EFG asymmetry and MS Euler
# angles (zyz convention, passive); they are restricted to O sites when plotted.
efg_asymm_alpha, efg_asymm_beta = (nmr.EFGAsymmetry().get(s) for s in (alpha, beta))
ms_euler_alpha, ms_euler_beta = (
    nmr.MSEuler().get(s, convention="zyz", passive=True) for s in (alpha, beta)
)

In [None]:
# Per-atom properties of the crystalline limits: EFG asymmetry and MS Euler
# angles (zyz convention, passive).
efg_asymm_pristine_alpha_pred, efg_asymm_pristine_beta_pred = (
    nmr.EFGAsymmetry().get(s) for s in (pristine_alpha_pred, pristine_beta_pred)
)
ms_euler_pristine_alpha_pred, ms_euler_pristine_beta_pred = (
    nmr.MSEuler().get(s, convention="zyz", passive=True)
    for s in (pristine_alpha_pred, pristine_beta_pred)
)

In [None]:
# Experimental 29Si spectrum: column 0 is the axis, column 1 the intensity.
exp = np.loadtxt(
    "../data/cristobalite//cristo-double-peak-rev.csv", skiprows=1, delimiter=","
)
# CubicSpline needs increasing abscissae; the rows are reversed first
# (presumably the file is stored in descending order).
cs = CubicSpline(*exp[::-1, :2].T)
xs = np.linspace(-115, -108, 1000)

# Offset such that xs - shift maps the top of the window (xs.max()) to 443.2.
shift = xs.max() - 443.2

In [None]:
# Boolean Si / O masks taken from the first trajectory frame; they are applied
# to the thermal averages (deep copies of that frame) as well.
si_idx, o_idx = (frames[0].numbers == z for z in (14, 8))

In [None]:
# One color per selected frame, blended from the alpha color to the beta
# color. Matplotlib expects channels in [0, 1]; 8-bit channels top out at 255,
# so divide by 255 (dividing by 256 capped pure channels at 255/256).
color_path = [
    np.array(c) / 255
    for c in interpolate_color(colors_rgb[0], colors_rgb[1], len(idx_traj))
]

In [None]:
# Six-panel summary figure. Solid lines: thermal averages over the ML MD
# trajectory; dotted lines: static ML predictions of the crystal limits;
# colors[0]/colors[1]: alpha/beta.
fig = plt.figure(figsize=(7.2, 3.6), constrained_layout=True)

levels = [0.03, 0.10, 0.50, 0.90, 0.97]
options = dict(levels=levels, alpha=0.75, linewidths=0.5)  # contour options

# --- (a) 29Si PASS of the three selected trajectory frames -------------------
ax = fig.add_subplot(231, projection="csdm")
for i, x in enumerate(single_pass):
    ax.contour(x, colors=color_path[i], **options, linestyles="solid")
ax.set_ylabel("$^{29}$Si anisotropic dimension (ppm)", fontsize=9)
ax.set_xlabel("$^{29}$Si isotropic dimension (ppm)", fontsize=9)
# NOTE(review): these ratios are hard-coded for the frames picked in the
# frame-selection cell; update them if that selection changes.
ax.text(469, -25, "$\\alpha:\\beta$=99:1", color=color_path[0], fontsize=8)
ax.text(469, -15, "$\\alpha:\\beta$=48:52", color=color_path[1], fontsize=8)
ax.text(469, -5, "$\\alpha:\\beta$=2:98", color=color_path[2], fontsize=8)
ax.tick_params(which="both", labelsize=9)
ax.set_ylim(35, -35)
ax.set_xlim(470, 420)
# Raw strings: "\m", "\e", "\p", "\s" are invalid escapes (SyntaxWarning on
# Python >= 3.12); the runtime text is unchanged.
ax.set_title(r"$\mathbf{a}$", loc="left", fontsize=9, fontweight="bold")

# --- (b) 29Si MAS spectra ----------------------------------------------------
ax = fig.add_subplot(234)
# Crystal limits (dotted): each spectrum divided by its number of Si atoms,
# then summed; drawn in the alpha color below 442 and the beta color above.
s = (
    spectra_si_pristine_pred[0] / si_idx_pristine_alpha_pred.sum()
    + spectra_si_pristine_pred[1] / si_idx_pristine_beta_pred.sum()
)
f = freqs_si_pristine_pred
ax.plot(f[f < 442], s[f < 442], ":", lw=0.75, c=color_path[0])
ax.plot(f[f > 442], s[f > 442], ":", lw=0.75, c=color_path[-1])

# Thermal averages (solid): same treatment, split at 439.5. Both averages were
# computed with identical window/bins, so freqs_si_alpha serves for both.
f = freqs_si_alpha
s = spectra_si_alpha / si_idx.sum() + spectra_si_beta / si_idx.sum()
ax.plot(f[f < 439.5], s[f < 439.5], lw=0.75, c=color_path[0])
ax.plot(f[f > 439.5], s[f > 439.5], lw=0.75, c=color_path[-1])

# Experiment (dashed): intensities reversed along the axis and the axis offset
# by `shift`; 5.5 is presumably a hand-tuned intensity scale -- confirm.
ax.plot(xs - shift, cs(xs)[::-1] * 5.5, "k--", lw=0.75)

ax.set_xlabel("$^{29}$Si NMR frequency (ppm)", fontsize=9)
ax.set_ylabel("Intensity (arb. units)", fontsize=9)
ax.set_yticks(())
ax.tick_params(labelsize=9)
ax.set_xlim(462, 434)
ax.set_title(r"$\mathbf{b}$", loc="left", fontsize=9, fontweight="bold")

# --- (c) 29Si PASS: thermal averages vs crystal limits -----------------------
ax = fig.add_subplot(232, projection="csdm")
ax.contour(pass_alpha, colors=colors[0], **options, linestyles="solid")
ax.contour(pass_beta, colors=colors[1], **options, linestyles="solid")
ax.contour(pass_pristine_pred[0], colors=colors[0], **options, linestyles=":")
ax.contour(pass_pristine_pred[1], colors=colors[1], **options, linestyles=":")
ax.set_ylabel("$^{29}$Si anisotropic dimension (ppm)", fontsize=9)
ax.set_xlabel("$^{29}$Si isotropic dimension (ppm)", fontsize=9)
ax.set_xlim((465, 434))
ax.set_ylim((4, -4))
ax.tick_params(labelsize=9)
ax.set_title(r"$\mathbf{c}$", loc="left", fontsize=9, fontweight="bold")

# --- (d) 17O 3QMAS: thermal averages vs crystal limits -----------------------
ax = fig.add_subplot(235, projection="csdm")
ax.contour(mqmas_alpha, colors=colors[0], **options, linestyles="solid")
ax.contour(mqmas_beta, colors=colors[1], **options, linestyles="solid")
ax.contour(mqmas_pristine_pred[0], colors=colors[0], **options, linestyles=":")
ax.contour(mqmas_pristine_pred[1], colors=colors[1], **options, linestyles=":")
ax.set_ylabel("$^{17}$O isotropic dimension (ppm)", fontsize=9)
ax.set_xlabel("$^{17}$O MAS dimension (ppm)", fontsize=9)
ax.tick_params(labelsize=9)
ax.set_xlim((228, 147))
ax.set_ylim((-126, -142))
ax.set_title(r"$\mathbf{d}$", loc="left", fontsize=9, fontweight="bold")

# --- (e) oxygen EFG asymmetry distributions (Gaussian KDE) -------------------
ax = fig.add_subplot(233)
# Thermal averages (solid), default gaussian_kde bandwidth.
a = efg_asymm_alpha[o_idx]
b = efg_asymm_beta[o_idx]
mi = 0.0  # evaluation grid starts at 0
ma = max(a.max(), b.max()) * 1.3
r = np.linspace(mi, ma, 1000)
kde_a = gaussian_kde(a)(r)
kde_b = gaussian_kde(b)(r)
ax.plot(r, kde_a, lw=0.75, color=colors[0])
ax.plot(r, kde_b, lw=0.75, color=colors[1])
ax.set_yticks(())
ax.tick_params(labelsize=9)
ax.set_xlabel(r"oxygen <$\eta_Q$>", fontsize=9)
ax.set_ylabel("distribution (arb. units)", fontsize=9)
# Crystal limits (dotted); the large bandwidth factors are presumably needed
# because the crystals have few distinct O environments -- confirm.
a = efg_asymm_pristine_alpha_pred[o_idx_pristine_alpha_pred]
b = efg_asymm_pristine_beta_pred[o_idx_pristine_beta_pred]
mi = 0.0
ma = max(a.max(), b.max()) * 1.3
r = np.linspace(mi, ma, 1000)
kde_a = gaussian_kde(a, bw_method=10)(r)
kde_b = gaussian_kde(b, bw_method=7)(r)
ax.plot(r, kde_a, lw=0.75, color=colors[0], ls=":")
ax.plot(r, kde_b, lw=0.75, color=colors[1], ls=":")
ax.set_title(r"$\mathbf{e}$", loc="left", fontsize=9, fontweight="bold")

# --- (f) oxygen MS Euler angle (column 0) distributions (von Mises KDE) ------
ax = fig.add_subplot(236)
# circular_kde's bw is the von Mises concentration: larger = narrower kernels.
r = np.linspace(-np.pi, np.pi, 1000)

a = ms_euler_alpha[:, 0][o_idx]
b = ms_euler_beta[:, 0][o_idx]
kde_a = circular_kde(r, a, bw=100)
kde_b = circular_kde(r, b, bw=100)
ax.plot(r, kde_a, lw=0.75, color=colors[0])
ax.plot(r, kde_b, lw=0.75, color=colors[1])

a = ms_euler_pristine_alpha_pred[:, 0][o_idx_pristine_alpha_pred]
b = ms_euler_pristine_beta_pred[:, 0][o_idx_pristine_beta_pred]
kde_a = circular_kde(r, a, bw=100)
kde_b = circular_kde(r, b, bw=100)
ax.plot(r, kde_a, lw=0.75, color=colors[0], ls=":")
ax.plot(r, kde_b, lw=0.75, color=colors[1], ls=":")

ax.set_yticks(())
# Fix the tick positions before labelling them; labelling first triggers
# matplotlib's FixedFormatter-without-FixedLocator warning.
ax.set_xticks([-np.pi, 0, np.pi])
ax.set_xticklabels([r"-$\pi$", "0", r"$\pi$"])
ax.set_xlabel(r"oxygen <$\alpha_{\sigma}$>", fontsize=9)
ax.set_ylabel("distribution (arb. units)", fontsize=9)
ax.tick_params(labelsize=9)
ax.set_title(r"$\mathbf{f}$", loc="left", fontsize=9, fontweight="bold")

# --- shared legend -----------------------------------------------------------
lines = [Line2D([0], [0], color=c, linewidth=3.75, linestyle="-") for c in colors]
labels = ["$\\alpha$", "$\\beta$"]
custom_lines = lines + [
    Line2D([0], [0], color=colors[0], lw=0.75, ls="-"),
    Line2D([0], [0], color=colors[1], lw=0.75, ls="-"),
    Line2D([0], [0], color=colors[0], lw=0.75, ls=":"),
    Line2D([0], [0], color=colors[1], lw=0.75, ls=":"),
]

fig.legend(
    custom_lines,
    labels + ["ML MD", "ML MD", "ML static", "ML static"],
    fontsize=9,
    ncols=3,
    loc="upper left",
    bbox_to_anchor=(0.3, 1.2),
    frameon=False,
);

# Uncomment to export:
# fig.savefig("./cristo_inversion_full_pred-prop_v4.svg", dpi=300, bbox_inches="tight")