In [None]:
import pickle
from collections import OrderedDict

import numpy as np
from matplotlib import pyplot as plt
from periodictable import elements
from tqdm import tqdm

In [None]:
from exfor_tools.distribution import AngularDistribution

In [None]:
from jitr.optical_potentials import kduq, wlh

In [None]:
import elm

In [None]:
# Projectile identifiers handed to kduq.get_samples_federal in the next cell.
# NOTE(review): presumably (A, Z) = (mass number, charge number) tuples — confirm against jitr.
neutron = (1, 0)
proton = (1, 1)

In [None]:
# KDUQ parameter samples, one collection per projectile. Later cells iterate over each
# collection sample by sample and take len(samples[0]) as the number of model parameters.
# NOTE(review): "federal" presumably names a specific KDUQ posterior variant — confirm in jitr docs.
kduq_params_nn = kduq.get_samples_federal(neutron)
kduq_params_pp = kduq.get_samples_federal(proton)

## read in measurements and pre-computed workspaces

In [None]:
from pathlib import Path

# Directory the pickled measurements and workspaces are read from in the following cells.
# mkdir only guarantees the directory itself exists; the *.pkl files inside it must already
# have been produced, otherwise the open() calls below raise FileNotFoundError.
output_dir = Path("./corpus/")
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Read the pickled (n,n) and (p,p) elastic-scattering measurements.
# NOTE: unpickling executes arbitrary code — only load files you produced yourself.
with (output_dir / "nn_elastic_data.pkl").open(mode="rb") as nn_data_file:
    nn_elastic_data = pickle.load(nn_data_file)

with (output_dir / "pp_elastic_data.pkl").open(mode="rb") as pp_data_file:
    pp_elastic_data = pickle.load(pp_data_file)

In [None]:
# Read the pickled pre-computed workspaces that accompany the measurements.
# NOTE: unpickling executes arbitrary code — only load files you produced yourself.
with (output_dir / "nn_elastic_workspaces.pkl").open(mode="rb") as nn_workspace_file:
    nn_elastic_workspaces = pickle.load(nn_workspace_file)

with (output_dir / "pp_elastic_workspaces.pkl").open(mode="rb") as pp_workspace_file:
    pp_elastic_workspaces = pickle.load(pp_workspace_file)

## set up corpora from workspaces

In [None]:
%%time
# Build the (p,p) corpus from the pre-computed workspaces and the loaded measurements,
# with kduq.calculate_diff_xs as the model callable. Each entry of `.constraints` is
# later called as constraint.model(params) and compared against constraint.y.
# NOTE(review): "dXS/dRuth" presumably selects the ratio-to-Rutherford observable — confirm in elm.
kduq_pp_corpus = elm.corpus.ElasticAngularCorpus(
    kduq.calculate_diff_xs,
    "dXS/dRuth",
    pp_elastic_workspaces,
    pp_elastic_data,
)

In [None]:
%%time
# Build the (n,n) corpus from the pre-computed workspaces and the loaded measurements,
# with kduq.calculate_diff_xs as the model callable. Each entry of `.constraints` is
# later called as constraint.model(params) and compared against constraint.y.
# NOTE(review): "dXS/dA" presumably selects the differential cross section per solid angle — confirm in elm.
kduq_nn_corpus = elm.corpus.ElasticAngularCorpus(
    kduq.calculate_diff_xs,
    "dXS/dA",
    nn_elastic_workspaces,
    nn_elastic_data,
)

In [None]:
# Number of model parameters, read off the length of the first sample of each set.
nparams_nn, nparams_pp = len(kduq_params_nn[0]), len(kduq_params_pp[0])

# Degrees of freedom used to normalise the summed chi^2: data points minus parameters.
ndof_nn, ndof_pp = (
    kduq_nn_corpus.n_data_pts - nparams_nn,
    kduq_pp_corpus.n_data_pts - nparams_pp,
)

## Propagate KDUQ into constraint observables

In [None]:
# Widths (in percent) of the central intervals: 10, 20, ..., 90.
inner_pctls = 10 * np.arange(1, 10)

# Percentile edges of those central intervals, ascending: the lower edges
# (50 - w/2, widest interval first) followed by the upper edges (50 + w/2).
confidence_intervals_ec = np.concatenate(
    [
        50 - inner_pctls[::-1] / 2,
        50 + inner_pctls / 2,
    ]
)

# Percentile edges of the band drawn in the plots (inner 95%).
confidence_intervals_plot = [2.5, 97.5]

In [None]:
# Display the central-interval widths (in percent) defined in the previous cell.
inner_pctls

In [None]:
# Propagate every KDUQ (n,n) parameter sample through every constraint.
# chi2_nn[row, col] is the quadratic form r^T @ cov_inv @ r of constraint `row`
# evaluated with parameter sample `col`; the two lists collect, per constraint,
# the percentile bands of the predictions (same order as the corpus constraints).
chi2_nn = np.zeros(
    (len(kduq_nn_corpus.constraints), len(kduq_params_nn)),
    dtype=float,
)
kduq_nn_ci_plot, kduq_nn_ci_ec = [], []

for row, measurement in enumerate(tqdm(kduq_nn_corpus.constraints)):
    # One stacked row of model output per parameter sample.
    predictions = np.vstack([measurement.model(sample) for sample in kduq_params_nn])
    deviations = predictions - measurement.y

    chi2_nn[row, :] = [
        delta.T @ measurement.cov_inv @ delta for delta in deviations
    ]

    # Percentiles are taken across samples (axis 0), i.e. separately at each data point.
    kduq_nn_ci_plot.append(
        np.percentile(predictions, confidence_intervals_plot, axis=0)
    )
    kduq_nn_ci_ec.append(
        np.percentile(predictions, confidence_intervals_ec, axis=0)
    )

In [None]:
# Distribution over parameter samples of the corpus-wide chi^2 per degree of freedom.
# Values outside the hard-coded bin range 10^1..10^3 are not drawn.
fig, ax = plt.subplots()
chi2_per_dof_nn = chi2_nn.sum(axis=0) / ndof_nn
ax.hist(chi2_per_dof_nn, bins=np.logspace(1, 3, 50), density=True)
ax.set_title(f"KDUQ against {len(kduq_nn_corpus.constraints)} (n,n) measurements")
ax.set_xlabel(r"generalized $\chi^2/\rm{DOF}$")
ax.set_xscale("log")

In [None]:
# Propagate every KDUQ (p,p) parameter sample through every constraint.
# chi2_pp[row, col] is the quadratic form r^T @ cov_inv @ r of constraint `row`
# evaluated with parameter sample `col`; the two lists collect, per constraint,
# the percentile bands of the predictions (same order as the corpus constraints).
chi2_pp = np.zeros(
    (len(kduq_pp_corpus.constraints), len(kduq_params_pp)),
    dtype=float,
)
kduq_pp_ci_plot, kduq_pp_ci_ec = [], []

for row, measurement in enumerate(tqdm(kduq_pp_corpus.constraints)):
    # One stacked row of model output per parameter sample.
    predictions = np.vstack([measurement.model(sample) for sample in kduq_params_pp])
    deviations = predictions - measurement.y

    chi2_pp[row, :] = [
        delta.T @ measurement.cov_inv @ delta for delta in deviations
    ]

    # Percentiles are taken across samples (axis 0), i.e. separately at each data point.
    kduq_pp_ci_plot.append(
        np.percentile(predictions, confidence_intervals_plot, axis=0)
    )
    kduq_pp_ci_ec.append(
        np.percentile(predictions, confidence_intervals_ec, axis=0)
    )

In [None]:
# Distribution over parameter samples of the corpus-wide chi^2 per degree of freedom.
# Values outside the hard-coded bin range 10^3..10^7 are not drawn.
fig, ax = plt.subplots()
chi2_per_dof_pp = chi2_pp.sum(axis=0) / ndof_pp
ax.hist(chi2_per_dof_pp, bins=np.logspace(3, 7, 50), density=True)
ax.set_title(f"KDUQ against {len(kduq_pp_corpus.constraints)} (p,p) measurements")
ax.set_xlabel(r"generalized $\chi^2/\rm{DOF}$")
ax.set_xscale("log")

## Empirical coverage

## Plotting constraints and confidence intervals

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

## $(p,p)$

In [None]:
def sort_by_subentry(constraints, other_list):
    """Sort two parallel sequences together by each constraint's ``subentry``.

    Parameters
    ----------
    constraints : iterable
        Objects exposing a ``subentry`` attribute that supports ordering.
    other_list : iterable
        Items paired position-by-position with ``constraints``.

    Returns
    -------
    tuple
        ``(sorted_constraints, sorted_others)``, each a tuple. The sort is
        stable, so entries with equal ``subentry`` keep their relative order.
        Empty input yields ``((), ())`` so that two-name unpacking at the call
        site still works (``zip(*[])`` alone would yield nothing to unpack).

    Notes
    -----
    Pairing uses ``zip``, so if the inputs differ in length the surplus items
    of the longer one are dropped.
    """
    paired = sorted(zip(constraints, other_list), key=lambda pair: pair[0].subentry)
    if not paired:
        return (), ()
    return tuple(zip(*paired))

In [None]:
# Plot every (p,p) constraint, in order of EXFOR subentry, together with the KDUQ
# 95% band and an inset histogram of that constraint's chi^2 per data point.
#
# The corpus and the per-constraint result containers are deliberately NOT re-sorted:
# chi2_pp rows, kduq_pp_ci_plot and kduq_pp_ci_ec were all filled in the corpus'
# original constraint order, so re-sorting only some of them would pair each figure
# with another measurement's chi^2. Instead we walk a sorted permutation of indices.
pp_plot_order = sorted(
    range(len(kduq_pp_corpus.constraints)),
    key=lambda k: kduq_pp_corpus.constraints[k].subentry,
)

for idx in pp_plot_order:
    constraint = kduq_pp_corpus.constraints[idx]
    lower, upper = kduq_pp_ci_plot[idx]
    theta_deg = constraint.x * 180 / np.pi  # constraint.x is converted rad -> deg for display

    fig, ax = plt.subplots()
    ax.errorbar(
        theta_deg,
        constraint.y,
        yerr=np.sqrt(np.diag(constraint.covariance)),  # 1-sigma from the covariance diagonal
        marker="s",
        markersize=2,
        linestyle="none",
        elinewidth=3,
    )
    ax.fill_between(theta_deg, lower, upper, alpha=0.3)

    vis_ws = constraint.model.workspace.visualization_workspace
    rxn = vis_ws.reaction.reaction_latex
    Elab = vis_ws.kinematics.Elab
    # NOTE(review): assumes pp_elastic_data[idx] lines up with the corpus' original
    # constraint order (the same order chi2_pp was filled in) — confirm against elm.
    ax.set_title(
        f"{constraint.subentry}: ${rxn}$ at {Elab} MeV || reported as ${pp_elastic_data[idx][1].quantity}$"
    )
    ax.set_xlabel(r"$\theta$ [deg]")
    ax.set_ylabel(r"$d\sigma/d \sigma_{\rm{Ruth}}$ [dimensionless]")
    ax.set_yscale("log")

    # Inset: distribution over parameter samples of chi^2 / N for this constraint.
    ax_inset = inset_axes(ax, width="30%", height="30%", loc="upper right")
    ax_inset.hist(chi2_pp[idx, :] / len(constraint.x), bins=20, color="gray", alpha=0.7)
    ax_inset.set_xlabel(r"$\chi^2/N$")
    ax_inset.set_yticks([])
    ax_inset.patch.set_alpha(0.5)

    plt.show()
    plt.close(fig)  # one figure per constraint: release it so memory does not accumulate

## $(n,n)$

In [None]:
# Plot every (n,n) constraint, in order of EXFOR subentry, together with the KDUQ
# 95% band and an inset histogram of that constraint's chi^2 per data point.
#
# The corpus and the per-constraint result containers are deliberately NOT re-sorted:
# chi2_nn rows, kduq_nn_ci_plot and kduq_nn_ci_ec were all filled in the corpus'
# original constraint order, so re-sorting only some of them would pair each figure
# with another measurement's chi^2. Instead we walk a sorted permutation of indices.
nn_plot_order = sorted(
    range(len(kduq_nn_corpus.constraints)),
    key=lambda k: kduq_nn_corpus.constraints[k].subentry,
)

for idx in nn_plot_order:
    constraint = kduq_nn_corpus.constraints[idx]
    lower, upper = kduq_nn_ci_plot[idx]
    theta_deg = constraint.x * 180 / np.pi  # constraint.x is converted rad -> deg for display

    fig, ax = plt.subplots()
    ax.errorbar(
        theta_deg,
        constraint.y,
        yerr=np.sqrt(np.diag(constraint.covariance)),  # 1-sigma from the covariance diagonal
        marker="s",
        markersize=2,
        linestyle="none",
        elinewidth=3,
    )
    ax.fill_between(theta_deg, lower, upper, alpha=0.3)

    vis_ws = constraint.model.workspace.visualization_workspace
    rxn = vis_ws.reaction.reaction_latex
    Elab = vis_ws.kinematics.Elab

    ax.set_title(f"{constraint.subentry}: ${rxn}$ at {Elab} MeV")
    ax.set_xlabel(r"$\theta$ [deg]")
    ax.set_ylabel(r"$d\sigma/d\Omega$ [b/Sr]")
    ax.set_yscale("log")

    # Inset: distribution over parameter samples of chi^2 / N for this constraint.
    ax_inset = inset_axes(ax, width="30%", height="30%", loc="upper right")
    ax_inset.hist(chi2_nn[idx, :] / len(constraint.x), bins=20, color="gray", alpha=0.7)
    # Raw string: "\c" in a normal string literal is an invalid escape sequence.
    ax_inset.set_xlabel(r"$\chi^2/N$")
    ax_inset.set_yticks([])
    ax_inset.patch.set_alpha(0.5)

    plt.show()
    plt.close(fig)  # one figure per constraint: release it so memory does not accumulate