In [None]:
import os
import re

from importlib import reload

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pylab as plt

import gvar as gv
import lsqfit

from luescher_nd.database import utilities as ut
from luescher_nd.database.utilities import DATA_FOLDER

# Global plot styling for every figure below: seaborn's "paper" context and
# "ticks" axis style, with font_scale=1.
sns.set(context="paper", style="ticks", font_scale=1)

In [None]:
# IPython extension used while editing (presumably the `%%black` cell formatter);
# none of the cells below invoke it, so the results do not depend on it.
%load_ext blackcellmagic

In [None]:
# Enumerate the available result databases, skipping temporary files.
# The listing is sorted because `os.listdir` returns entries in arbitrary order,
# which would make the printed indices differ between runs and machines.
files = sorted(
    f for f in os.listdir(DATA_FOLDER) if f.endswith(".sqlite") and "tmp" not in f
)
print("\n".join(f"{n:2d} {f}" for n, f in enumerate(files)))

In [None]:
# Load the spectrum stored in `file_name` and keep the columns used below.
file_name = "contact-fitted_a-inv=+0.0_zeta=spherical_projector=a1g_n-eigs=200.sqlite"
df = (
    ut.read_table(
        os.path.join(DATA_FOLDER, file_name),
        zeta=None,
        round_digits=2,
        filter_poles=False,
        filter_by_nstates=False,
        filter_degeneracy=False,
    )
    # Round L *before* filtering: the exact comparison `L == 1.0` below would
    # otherwise drop rows whose stored box size carries floating-point noise.
    .assign(L=lambda frame: frame.L.round(7))
    .query("nlevel < 20 and epsilon < 0.2 and L == 1.0")
    .loc[:, ["n1d", "epsilon", "nstep", "L", "x", "nlevel", "contact_strength", "E"]]
    # Own copy, so later column assignments on `df` do not hit
    # SettingWithCopyWarning.
    .copy()
)
df.head()


In [None]:
def nstep_label(nstep):
    r"""Return the facet label used for an ``nstep`` value.

    Positive values are rendered as their integer part (e.g. ``4.0`` -> ``"4"``);
    every non-positive value (e.g. ``-1``) is rendered as the LaTeX ``\infty``.
    """
    # Raw string: "\i" is not a valid Python escape and warns in a plain literal.
    return f"{int(nstep)}" if nstep > 0 else r"\infty"
# Per-row facet label. A column-wise `map` replaces the row-wise
# `apply(axis=1)`: it is much faster and always yields a Series (the row-wise
# apply can return a DataFrame when `df` is empty).
df["nstep_label"] = df["nstep"].map(nstep_label)

In [None]:
def plot(x_lablel, y_label, **kwargs):
    """Draw one facet/hue subset for ``FacetGrid.map_dataframe``.

    Parameters
    ----------
    x_lablel, y_label : str
        Names of the columns of ``data`` plotted on the x and y axes.
    **kwargs
        Must contain ``data`` (the subset DataFrame supplied by seaborn).
        All other keywords (``label``, ``color``, styling) are forwarded to
        ``plt.plot``.

    For the hue level whose label is ``0``, this additionally draws dotted
    horizontal reference lines at every value returned by
    ``ut.get_degeneracy(n1d)`` that is below 20, for each ``n1d`` present in
    ``data``.
    """
    data = kwargs.pop("data")
    plt.plot(data[x_lablel], data[y_label], **kwargs)

    # Reference lines are drawn only once per facet (on hue level 0).
    # `.get` keeps the function usable when no hue (hence no label) is set.
    if kwargs.get("label") != 0:
        return

    # Iterate over n1d directly. The former zip with `epsilon.unique()` never
    # used epsilon, and it silently truncated the loop whenever the two
    # unique-value arrays had different lengths.
    degs = {
        n2
        for n1d in data.n1d.unique()
        for n2 in ut.get_degeneracy(n1d)
        if n2 < 20
    }
    for d in sorted(degs):
        plt.axhline(d, ls=":", color="black", lw=0.5)

In [None]:
# Spectrum x versus epsilon, one column per nstep, one colour per level.
grid = sns.FacetGrid(
    data=df.reset_index().sort_values("epsilon"),
    col="nstep_label",
    hue="nlevel",
    sharey=False,
    margin_titles=True,
    col_order=[nstep_label(nstep) for nstep in [1, 2, 4, -1]],
)

grid.map_dataframe(plot, "epsilon", "x", marker=".", ls=":", zorder=10)

# Raw strings: the LaTeX backslashes (\epsilon, \,, \mathrm) are not valid
# Python escape sequences and trigger SyntaxWarning in plain literals.
grid.set_ylabels("$ x $")
grid.set_xlabels(r"$ \epsilon \, [\mathrm{fm}]$")
grid.set_titles(col_template=r"$n_\mathrm{{step}} = {col_name}$")

for ax in grid.axes.flatten():
    # `base` replaces the `basex` keyword, which was removed in Matplotlib 3.5.
    ax.set_xscale("log", base=2)
    ax.set_xlim(2**-6, 2**-3)

# `plt.show` takes no figure argument; passing `grid.fig` positionally only
# forwarded it to an unrelated backend parameter (e.g. `block`/`close`).
plt.show()

In [None]:
# Continuum extrapolations with an odd (`odd_poly=True`, up to order 10) and an
# even (`odd_poly=False`, up to order 5) polynomial ansatz. Fit statistics are
# kept because their logGBF is compared in the next cell.
odd = ut.get_continuum_extrapolation(
    df,
    odd_poly=True,
    include_statistics=True,
    n_poly_max=10,
)
even = ut.get_continuum_extrapolation(
    df,
    odd_poly=False,
    include_statistics=True,
    n_poly_max=5,
)

# Stack both ansätze with a fresh RangeIndex; `sort=False` keeps columns in
# order of first appearance instead of sorting them.
ff = pd.concat(
    [odd, even],
    sort=False,
    ignore_index=True,
)

In [None]:
# Weight each extrapolation by its Gaussian Bayes factor relative to the best
# one in the same (L, epsilon, nstep, nlevel) group:
#     P = exp(logGBF - max_group(logGBF))
# The best fit of every group therefore has P == exp(0.0) == 1.0 exactly.
# A vectorised `transform("max")` replaces the former
# `groupby(as_index=False).apply(...).reset_index(0)`, whose index layout
# depends on the pandas version.
fff = (
    ff.assign(
        P=lambda frame: np.exp(
            frame.logGBF
            - frame.groupby(["L", "epsilon", "nstep", "nlevel"])["logGBF"].transform(
                "max"
            )
        ),
        n1d=np.inf,
    )
    .set_index(["L", "epsilon", "nstep", "nlevel", "n_poly_max", "even"])
    .sort_values(["L", "epsilon", "nstep", "nlevel", "P"], ascending=False)
)
# Last expression of the cell, so the preview is actually displayed.
fff.head()

In [None]:
def plot(x_lablel, y_label, **kwargs):
    """Plot ``data[y_label]`` against ``data[x_lablel]`` for one facet/hue subset.

    NOTE: this redefines the earlier ``plot`` (which also drew degeneracy
    reference lines). From here on, ``FacetGrid.map_dataframe(plot, ...)`` draws
    only the data; the next figure adds its reference lines itself.

    ``kwargs`` must contain ``data`` (supplied by seaborn); the remaining
    keywords are forwarded to ``plt.plot``.
    """
    frame = kwargs.pop("data")
    xs, ys = frame[x_lablel], frame[y_label]
    plt.plot(xs, ys, **kwargs)

In [None]:
# Degeneracy values below 20 for n1d = 10; drawn as horizontal reference lines
# in the next figure.
degs = list(filter(lambda n2: n2 < 20, ut.get_degeneracy(10)))

In [None]:
# Spectrum on a linear epsilon axis, overlaid with the best (P == 1) continuum
# extrapolation polynomial of every level.
grid = sns.FacetGrid(
    data=df.reset_index().sort_values("epsilon"),
    col="nstep_label",
    hue="nlevel",
    sharey=False,
    margin_titles=True,
    col_order=[nstep_label(nstep) for nstep in [1, 2, 4, -1]],
)

grid.map_dataframe(plot, "epsilon", "x", marker=".", zorder=10, ls="None")

# Raw strings: LaTeX backslashes are not valid Python escape sequences.
grid.set_ylabels("$ x $")
grid.set_xlabels(r"$ \epsilon \, [\mathrm{fm}]$")
grid.set_titles(col_template=r"$n_\mathrm{{step}} = {col_name}$")

for ax in grid.axes.flatten():
    ax.set_xlim(2**-6, 2**-5)
    for deg in degs:
        ax.axhline(deg, color="grey", ls=":", zorder=-1, lw=0.5)

eps = np.linspace(min(2**-6, df.epsilon.min()), df.epsilon.max(), 1000)
# Loop-invariant: reset the index once instead of once per panel.
best_fits = fff.reset_index().query("P == 1")

for axs in grid.axes:
    for nstep, ax in zip([1, 2, 4, -1], axs):
        # Iterate directly. The former module-level `fit = ...` overwrote the
        # `fit` function defined later whenever this cell was re-run.
        for _, row in best_fits.query("nstep == @nstep").iterrows():
            # `_colors` is ordered like `hue_names`; look the level up instead
            # of assuming nlevel values are exactly the ints 0, 1, 2, ...
            color = grid._colors[grid.hue_names.index(row["nlevel"])]
            # Only GVar-valued x<digits> columns belong to this row's fit.
            coefficients = np.array(
                [
                    row[col]
                    for col in row.index
                    if re.match("x[0-9]+", col) and isinstance(row[col], gv.GVar)
                ]
            )
            xfit = ut._poly(eps, {"x": coefficients})
            ax.plot(eps, gv.mean(xfit), color=color, zorder=0, ls="-", lw=0.5)

grid.fig.set_dpi(150)

# `plt.show` takes no figure argument.
plt.show()

In [None]:
# Best (P == 1) extrapolation per (L, nlevel, nstep), looked up by `fit` via
# `fitframe.loc[(L, nlevel, nstep)]`. Sorting the MultiIndex avoids pandas
# PerformanceWarnings on those lookups.
fitframe = (
    fff.reset_index()
    .query("P == 1")
    .set_index(["L", "nlevel", "nstep"])
    .sort_index()
)
# `.loc` with a full key yields a single row (a Series) only when the key is
# unique. Ties in logGBF, or several epsilon entries per key, would break `fit`.
assert fitframe.index.is_unique, "several best fits share the same (L, nlevel, nstep)"

In [None]:
def fit(inp_row):
    """Evaluate the best continuum-extrapolation polynomial at a data point.

    Parameters
    ----------
    inp_row : pandas.Series
        A row providing ``L``, ``nlevel``, ``nstep`` and ``epsilon``
        (rows of ``df`` when used with ``DataFrame.apply(fit, axis=1)``).

    Returns
    -------
    ``ut._poly(epsilon, {"x": coefficients})``, where ``coefficients`` are the
    ``gvar.GVar`` entries of the ``x<digits>`` columns of the matching
    ``fitframe`` row.

    Raises
    ------
    KeyError
        If ``fitframe`` has no entry for ``(L, nlevel, nstep)``.
    ValueError
        If ``fitframe`` has several entries for that key (ambiguous best fit).
    """
    key = (float(inp_row["L"]), int(inp_row["nlevel"]), int(inp_row["nstep"]))
    row = fitframe.loc[key]
    # A non-unique key makes `.loc` return a DataFrame. Without this check, that
    # would fail further down with an unrelated TypeError from `re.match`.
    if isinstance(row, pd.DataFrame):
        raise ValueError(
            f"Expected one best fit for (L, nlevel, nstep)={key}, found {len(row)}."
        )
    # Only GVar-valued entries belong to this row's fit; coefficient columns of
    # the other polynomial ansatz are presumably NaN after the concat.
    coefficients = np.array(
        [
            row[col]
            for col in row.index
            if re.match("x[0-9]+", col) and isinstance(row[col], gv.GVar)
        ]
    )
    return ut._poly(inp_row["epsilon"], {"x": coefficients})



In [None]:
# Residuals between the x values in `df` and the best continuum-extrapolation
# fit evaluated at the same epsilon (`xfit` and `diff` hold gvar objects).
diff = df[["n1d", "epsilon", "nstep", "L", "nlevel", "x"]].copy()
diff["xfit"] = diff.apply(fit, axis=1)
diff["diff"] = diff["x"] - diff["xfit"]
# Last expression of the cell, so the preview (now including the residuals)
# is actually displayed.
diff.head()

In [None]:
# Absolute residuals |x - x_fit| versus epsilon for all excited levels.
# Facets use `nstep_label`, like the other figures, so nstep = -1 is titled as
# infinity rather than "-1".
grid = sns.FacetGrid(
    data=(
        diff.reset_index()
        .sort_values("epsilon")
        .query("nlevel > 0")
        .assign(nstep_label=lambda frame: frame["nstep"].map(nstep_label))
    ),
    col="nstep_label",
    hue="nlevel",
    sharey=False,
    margin_titles=True,
    col_order=[nstep_label(nstep) for nstep in [1, 2, 4, -1]],
)


def gvarplot(*args, **kwargs):
    """Plot ``|mean(y)|`` against ``x`` for ``FacetGrid.map``.

    ``args[0]`` and ``args[1]`` are the x and y Series that seaborn passes for
    each facet/hue subset. ``y`` holds gvar objects, of which only the mean is
    plotted. ``kwargs`` are forwarded to ``plt.plot``.
    """
    x, y = args[0], args[1]
    plt.plot(x, np.abs(gv.mean(y.values)), **kwargs)


grid.map(gvarplot, "epsilon", "diff", marker=".", zorder=10, ls="-")

# The plotted quantity is the absolute residual, not x itself.
grid.set_ylabels(r"$ |x - x_\mathrm{fit}| $")
grid.set_xlabels(r"$ \epsilon \, [\mathrm{fm}]$")
grid.set_titles(col_template=r"$n_\mathrm{{step}} = {col_name}$")

for ax in grid.axes.flatten():
    # `base` replaces the `basex` keyword, which was removed in Matplotlib 3.5.
    ax.set_xscale("log", base=2)
    ax.set_yscale("log")

grid.fig.set_dpi(150)

# `plt.show` takes no figure argument.
plt.show()