In [None]:
import os
import re

from importlib import reload

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pylab as plt

from luescher_nd.database import utilities as ut
from luescher_nd.database.utilities import DATA_FOLDER

from luescher_nd.operators.a1g import projector

from luescher_nd.operators import get_projector_to_not_a1g
from luescher_nd.database.utilities import get_degeneracy

from luescher_nd.hamiltonians.contact import MomentumContactHamiltonian

from luescher_nd.database import utilities as ut

import numpy as np
from scipy.sparse.linalg import eigsh

In [None]:
%load_ext blackcellmagic

In [None]:
file_name = (
    "contact-fitted_a-inv=+0.0_zeta=cartesian_projector=a1g_n-eigs=200.sqlite"
)
# Load the fitted-contact table (no pole/state/degeneracy filtering), keep only
# rows with nlevel == 0 and nstep == -1, and restrict to the columns used below.
df = (
    ut.read_table(
        os.path.join(DATA_FOLDER, file_name),
        zeta=None,
        round_digits=2,
        filter_poles=False,
        filter_by_nstates=False,
        filter_degeneracy=False,
    )
    .query("nlevel == 0 and nstep == -1")
    .loc[:, ["n1d", "epsilon", "nstep", "L", "x", "nlevel", "contact_strength", "E"]]
)

In [None]:
# Lookup table (n1d, epsilon) -> fitted contact strength used by the overlap loop.
# If a (n1d, epsilon) pair occurs more than once, to_dict keeps the last value.
interactions = (
    df.set_index(keys=["n1d", "epsilon"])
    .loc[:, "contact_strength"]
    .to_dict()
)
interactions

In [None]:
# ndim: spatial dimension passed to the projectors and the Hamiltonian below.
# NOTE(review): contact_strength is not read by any other cell in this notebook;
# the overlap loop takes its contact strength from `interactions` instead.
ndim, contact_strength = 3, -0.10822073617155972

In [None]:
def get_a1g_basis(n1d, ndim=3):
    """Build a unit-norm A1g basis vector for every entry of the degeneracy table.

    Parameters
    ----------
    n1d : int
        Lattice size per dimension; also the base of the flat index below.
    ndim : int, optional
        Number of spatial dimensions (default 3).

    Returns
    -------
    dict
        Maps each entry of every group in ``get_degeneracy(n1d, ndim)`` (in
        iteration order) to the matching row of ``projector(n1d, ndim).T``,
        densified to a 1-D array and scaled to unit Euclidean norm.
    """
    degeneracies = get_degeneracy(n1d, ndim)
    a1g_projector = projector(n1d, ndim)
    # Digit weights (1, n1d, n1d**2, ...) that map a vector to a flat index,
    # with the first component as the least significant digit.
    # NOTE(review): a negative component would produce a negative, wrapping
    # index; this assumes components lie in [0, n1d) — confirm.
    weights = n1d ** np.arange(ndim)

    basis = {}
    for vec in (member for group in degeneracies.values() for member in group):
        column = a1g_projector.T[weights @ np.array(vec)].toarray().flatten()
        # NOTE(review): an all-zero row would give 0/0 here, i.e. NaN entries.
        column /= np.sqrt(column @ column)
        basis[vec] = column
    return basis

In [None]:
# Overlap of the lowest eigenstates with the A1g basis states.
#
# For each lattice (n1d, epsilon) in `interactions`, diagonalize the contact
# Hamiltonian (built with the non-A1g projector as `filter_out`) and record
# every (eigenstate, A1g basis state) pair whose squared overlap exceeds
# OVERLAP_THRESHOLD. The records replace `df`.
N_EIGS = 20  # number of lowest ("SA") eigenpairs requested from eigsh
OVERLAP_THRESHOLD = 1.0e-2  # squared overlaps at or below this are dropped
FILTER_CUTOFF = 3.0e2  # passed to the Hamiltonian as `filter_cutoff`

data = []

for (n1d, epsilon), c0 in interactions.items():
    # Only even lattices with fewer than 30 sites per dimension and
    # L = n1d * epsilon <= 1.
    if n1d >= 30 or n1d % 2 != 0 or n1d * epsilon > 1:
        continue

    pnot = get_projector_to_not_a1g(n1d, ndim)
    basis = get_a1g_basis(n1d, ndim=ndim)
    if not basis:
        # No basis states means no overlap rows can be recorded.
        continue

    # Rows of basis_matrix are the basis vectors, ordered like basis_keys.
    basis_keys = list(basis)
    basis_matrix = np.array([basis[key] for key in basis_keys])

    h = MomentumContactHamiltonian(
        n1d,
        epsilon=epsilon,
        ndim=ndim,
        nstep=None,
        contact_strength=c0,
        filter_out=pnot,
        filter_cutoff=FILTER_CUTOFF,
    )
    E, v = eigsh(h.op, k=N_EIGS, which="SA", tol=1.e-16)
    # Sort explicitly so that `nlevel` is the ascending-energy index regardless
    # of the order in which ARPACK hands back the eigenpairs.
    order = np.argsort(E)
    E, v = E[order], v[:, order]
    # x = 2 * (h.mass / 2) * E * L**2 / (2 * pi)**2
    x = 2 * h.mass / 2 * E * h.L**2 / 4 / np.pi**2

    # coeffs[i, j] = <basis_keys[i] | eigenvector j>, all pairs in one product.
    coeffs = basis_matrix @ v

    for nlevel, xx in enumerate(x):
        for key, coeff in zip(basis_keys, coeffs[:, nlevel]):
            overlap = coeff**2
            if overlap > OVERLAP_THRESHOLD:
                data.append(
                    {
                        "n1d": n1d,
                        "epsilon": epsilon,
                        "L": epsilon * n1d,
                        "x": xx,
                        "nlevel": nlevel,
                        "overlap": overlap,
                        "coeff": coeff,
                        "a1g": key,
                    }
                )

df = pd.DataFrame(data)

In [None]:
df.head(n=5)  # first five overlap records

In [None]:
df[df["nlevel"] == 9]  # every recorded overlap of eigenstate nlevel 9

In [None]:
# Ten-color palette: the "BuGn_r" palette in reversed order.
cmap = sns.color_palette("BuGn_r", n_colors=10)[::-1]

def heatmap(**kwargs):
    """Draw the overlap heatmap for one FacetGrid panel.

    Meant to be passed to ``FacetGrid.map_dataframe``, which supplies the
    panel's rows as the ``data`` keyword; every other keyword is ignored.
    Heatmap rows are ``a1g`` keys, columns are ``epsilon`` values and cells
    hold ``overlap``. The color scale is fixed to [1e-2, 1] using ``cmap``,
    with no colorbar.
    """
    table = kwargs["data"].pivot(index="a1g", columns="epsilon", values="overlap")
    sns.heatmap(
        table,
        cmap=cmap,
        vmin=1.e-2,
        vmax=1,
        linewidths=1,
        cbar=False,
    )


In [None]:
# One heatmap per selected eigenstate level for the L == 1 lattices.
# L is stored as epsilon * n1d, so it is rounded before the equality test to
# avoid dropping rows that differ from 1 only by floating-point noise.
grid = sns.FacetGrid(
    data=df.assign(L=df["L"].round(6)).query("L == 1 and nlevel in [0, 5, 9]"),
    row="nlevel",
    col="L",
    sharey=False,
    sharex=False,
    margin_titles=False,
)

grid.map_dataframe(heatmap)

for ax in grid.axes.flat:
    # Keep the A1g basis labels on the y-axis horizontal.
    ax.tick_params(axis="y", labelrotation=0)

# pyplot.show takes no figure argument; it displays all open figures.
plt.show()