In [None]:
import os
import re

from importlib import reload

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pylab as plt

from luescher_nd.database import utilities as ut
from luescher_nd.database.utilities import DATA_FOLDER

from luescher_nd.operators.a1g import projector

from luescher_nd.operators import get_projector_to_not_a1g
from luescher_nd.database.utilities import get_degeneracy

from luescher_nd.hamiltonians.contact import MomentumContactHamiltonian

from luescher_nd.database import utilities as ut

import numpy as np
from scipy.sparse.linalg import eigsh

In [None]:
# Load the blackcellmagic IPython extension (presumably for formatting cells
# with black while editing); no cell in this notebook depends on it.
%load_ext blackcellmagic

In [None]:
# Every HDF5 result file in the data folder. os.listdir returns entries in an
# arbitrary, filesystem-dependent order, and that order determines the row
# order / index of the concatenated frame built from these files, so sort the
# names to make re-runs reproducible.
file_names = sorted(f for f in os.listdir(DATA_FOLDER) if f.endswith(".h5"))

In [None]:
# Read the "overlap" table of every result file (NaNs filled with 0) and stack
# them into a single frame with a fresh RangeIndex.
dfs = [
    pd.read_hdf(os.path.join(DATA_FOLDER, file_name), key="overlap").fillna(0)
    for file_name in file_names
]

df = pd.concat(dfs, ignore_index=True)

# Pull the integer tuple "(nx, ny, nz)" out of each a1g label and store its
# squared norm nx^2 + ny^2 + nz^2 as the "n2" column. The pattern is a raw
# string: "\(" in a plain literal is an invalid escape sequence (warning on
# current Pythons, SyntaxWarning from 3.12).
n2 = (
    df.a1g.str.extractall(r"\((?P<nx>[0-9]+), (?P<ny>[0-9]+), (?P<nz>[0-9]+)\)")
    .reset_index(1, drop=True)  # drop the per-row "match" level
    .astype(int)
)
# After dropping the match level, a label containing several tuples would
# leave duplicate row labels, and the column assignment below would fail
# with an opaque reindex error; fail with a clear message instead.
assert n2.index.is_unique, "expected at most one (nx, ny, nz) tuple per a1g label"
df["n2"] = (n2 ** 2).sum(axis=1)


In [None]:
# Ten-colour "BuGn_r" seaborn palette with its order reversed (plain list);
# used as the heatmap colour map.
cmap = list(reversed(sns.color_palette("BuGn_r", n_colors=10)))
# Overlap threshold shared by the plotting cells: entries with overlap at or
# below it are excluded from the facet data, and per-column leftover weight
# smaller than it is not shown in the heatmaps.
CUTOFF = 0.02

def heatmap(**kwargs):
    """Draw an annotated overlap heatmap for one FacetGrid facet.

    Intended as the callback for ``FacetGrid.map_dataframe``; only the
    ``data`` keyword is used and any other keywords are ignored.

    The facet's rows are pivoted into an ``a1g`` x ``n1d`` table of
    ``overlap`` values. For every ``n1d`` column the remainder
    ``1 - sum(overlap)`` is computed; if any remainder has magnitude of at
    least ``CUTOFF``, an extra ``"(...)"`` row holding those remainders is
    appended so the weight not shown in the table stays visible.

    Args:
        **kwargs: must contain ``data``, a DataFrame with ``a1g``, ``n1d``
            and ``overlap`` columns in which each (a1g, n1d) pair is unique
            (a requirement of ``DataFrame.pivot``).
    """
    frame = kwargs["data"]
    pivot = frame.pivot(values="overlap", index="a1g", columns="n1d")

    # Overlap weight not accounted for by the displayed rows; negligible
    # remainders are blanked so they do not produce a "(...)" cell.
    missing = 1 - pivot.sum(axis=0)
    missing[missing.abs() < CUTOFF] = np.nan
    if missing.notna().any():
        missing_row = pd.DataFrame(
            data=[missing.values], columns=missing.index, index=["(...)"]
        )
        pivot = pd.concat([pivot, missing_row])

    ax = plt.gca()
    sns.heatmap(
        pivot,
        vmin=1.e-2,
        vmax=1,
        cmap=cmap,
        cbar=False,
        linewidths=1,
        annot=True,
        fmt="0.2f",
        ax=ax,
    )
    # Pin the y-range to exactly one unit per row, first row at the top. The
    # previous "+/- 0.5" tweak only compensated for the matplotlib 3.1.1
    # heatmap clipping bug; on fixed versions it added half an empty row
    # above and below the map. Explicit limits are correct on all versions.
    ax.set_ylim(len(pivot), 0)
    ax.tick_params(axis="both", which="both", length=0)
    ax.tick_params(axis=u'both', which=u'both',length=0)


In [None]:
# One facet per (nlevel, nstep) combination, columns ordered n_step = 1, 4, -1.
# Only the selected levels / lattice sizes with overlap above CUTOFF are kept,
# sorted by n1d and then n2.
grid = sns.FacetGrid(
    data=(
        df.query(
            "nlevel in [0, 1, 2, 8, 9]"
            " and n1d in [4, 10, 20, 30, 40, 50]"
            " and overlap > @CUTOFF"
        ).sort_values(["n1d", "n2"])
    ),
    row="nlevel",
    col="nstep",
    col_order=[1, 4, -1],
    margin_titles=True,
    sharex=False,
    sharey=False,
    aspect=1.2,
)

grid.map_dataframe(heatmap)

# With margin_titles=True the row titles are text annotations (ax.texts), not
# axes titles. Blank the default "nlevel = ..." annotations so that only the
# LaTeX row titles drawn by set_titles below remain (presumably this is to
# avoid overlapping labels on seaborn versions that keep the old texts).
row_labels = [
    text for ax in grid.axes.flat for text in ax.texts if "nlevel" in text.get_text()
]
plt.setp(row_labels, text="")

# Raw strings: "\m" in a plain literal is an invalid escape sequence. Doubled
# braces survive seaborn's str.format as literal LaTeX braces.
grid.set_titles(
    col_template=r"$n_\mathrm{{step}} = {col_name}$",
    row_template=r"$n_\mathrm{{level}} = {row_name}$",
)

# The nstep == -1 column stands for n_step = infinity. Relabel every axes
# title carrying "-1" (a no-op instead of an IndexError if none matches).
# Text.set_text keeps the font settings seaborn applied; ax.set_title would
# reset them to the rcParams defaults.
for ax in grid.axes.flat:
    title = ax.title.get_text()
    if "-1" in title:
        ax.title.set_text(title.replace("-1", r"\infty"))

grid.fig.subplots_adjust(wspace=0.4, hspace=0.5)

# Give every facet a secondary x-axis on top that labels each n1d column with
# the mean of "x" over all rows of df with that (nlevel, nstep, n1d).
for nlevel, axes in zip(grid.row_names, grid.axes):
    for nstep, ax in zip(grid.col_names, axes):
        x_map = (
            df.query("nlevel == @nlevel and nstep == @nstep")
            .groupby("n1d")["x"]
            .mean()
            .to_dict()
        )

        ax.set_xlabel("$n_{1d}$")

        topax = ax.twiny()
        topax.xaxis.set_ticks_position("top")
        topax.xaxis.set_label_position("top")
        topax.set_xlabel("$x$")
        # Fix the tick positions before assigning labels: labels set while the
        # axis still has an automatic locator give a FixedFormatter without a
        # FixedLocator (matplotlib warns, and labels depend on later ticks).
        topax.set_xticks(ax.get_xticks())
        topax.set_xticklabels(
            [
                "{0:2.2f}".format(x_map[int(label.get_text())])
                for label in ax.get_xticklabels()
            ]
        )
        # Align the twin axis with the heatmap columns.
        topax.set_xlim(ax.get_xlim())

sns.despine(grid.fig, left=True, bottom=True)

# Typeset all tick labels (momentum tuples and n1d values) in math mode.
for ax in grid.axes.flat:
    ax.set_yticklabels(
        [f"${label.get_text()}$" for label in ax.get_yticklabels()], rotation=0
    )
    ax.set_xticklabels([f"${label.get_text()}$" for label in ax.get_xticklabels()])

grid.set_ylabels(r"$\left\vert [p] \right\rangle \in A_{1g}$")

# pyplot.show takes no figure argument (documented signature:
# show(*, block=None)). Passing the figure positionally raises TypeError on
# regular backends, and the inline backend misreads it as its `close` flag.
plt.show()
