In [None]:
import numpy as np
import matplotlib.pyplot as plt


In [None]:
import cp2k_spm_tools.postprocess.overlap as pp

In [None]:
# Folder holding the PDOS output files that get parsed below.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
pdos_folder = "/home/kristjan/local_work/uks_pdos_tests/benz-diradical-scf-au-uks3/"

# `dos` is indexed by series key (e.g. 'tdos', 'mol') and then by spin channel
# in the plotting cell.
dos = pp.process_pdos_files(
    pdos_folder
)

In [None]:
# Archive with the overlap data that is loaded below.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
npz_path = "/home/kristjan/local_work/uks_pdos_tests/overlap_run/overlap-uks3-0.2.npz"

# `od` is the loaded overlap data; the plotting cell reads its
# 'nspin_g2', 'energies_g1', 'energies_g2', 'orb_indexes_g2' and 'homo_i_g2' entries.
od = pp.load_overlap_npz(npz_path)

# `om` is the overlap matrix after spin-channel matching/reduction; the plotting
# cell indexes it as om[i_spin][:, i_orb].
overlap_matrix_raw = od["overlap_matrix"]
om = pp.match_and_reduce_spin_channels(overlap_matrix_raw)

In [None]:
# Broadening width handed to pp.create_series_w_broadening in the plotting cell.
fwhm = 0.10

# Grid spacing: a fifth of the broadening width, but never coarser than 0.005.
de = np.min([fwhm / 5, 0.005])

# Energy grid used as the x axis of the plot (half-open interval [e_min, e_max)).
e_min, e_max = -3.0, 3.0
energy_arr = np.arange(e_min, e_max, de)

In [None]:
# Colours of matplotlib's active property cycle; one is picked per orbital when
# the overlap contributions are drawn.
_prop_cycle = plt.rcParams['axes.prop_cycle']
mpl_def_colors = _prop_cycle.by_key()['color']

In [None]:
# PDOS curves to draw. Each entry is
#   [key into `dos`, line/fill colour, scale factor, legend name]
# The 'mol' entry is additionally used to set the y limits and to cap the
# cumulative overlap curves in the plotting cell.
pdos_series = [
    ["tdos", "lightgray", 0.02, "TDOS"],
    ["mol", "black", 1.0, "molecule PDOS"],
]

In [None]:
# PDOS / orbital-overlap figure.
# The first spin channel is drawn upwards; a second channel (if present) is
# mirrored below the zero line.
fig, ax1 = plt.subplots(figsize=(12, 6))
ylim = [None, None]

# Broadened, scaled molecule PDOS for each spin channel, in channel order.
# Used further down to cap the cumulative overlap curves.
mol_series = []

### PDOS
for pdos_ser in pdos_series:
    dos_key, color, scale, name = pdos_ser[:4]
    # Only mention the scale factor in the legend when the curve is actually scaled.
    label = name if scale == 1.0 else fr"${scale}\cdot${name}"
    d = dos[dos_key]
    for i_spin in range(len(d)):
        spin_sign = -2 * i_spin + 1  # +1 for channel 0, -1 for channel 1

        series = pp.create_series_w_broadening(d[i_spin][:, 0], d[i_spin][:, 1], energy_arr, fwhm)
        series *= scale

        kwargs = {}
        if i_spin == 0:
            # One legend entry per PDOS series, not one per spin channel.
            kwargs['label'] = label
        if dos_key == 'mol':
            # Keep the molecule PDOS outline above the filled overlap areas.
            kwargs['zorder'] = 300
            # The y range follows the molecule PDOS with 20 % headroom.
            if i_spin == 0:
                ylim[1] = 1.2 * np.max(series)
            else:
                ylim[0] = 1.2 * np.min(-series)

            mol_series.append(series)

        mirrored = series * spin_sign
        ax1.plot(energy_arr, mirrored, color=color, **kwargs)
        ax1.fill_between(energy_arr, 0.0, mirrored, color=color, alpha=0.2)

### Overlap
n_colors = len(mpl_def_colors)
for i_spin in range(od['nspin_g2']):
    spin_sign = -2 * i_spin + 1
    cumulative = None
    for i_orb, energy in enumerate(od['energies_g2'][i_spin]):
        index = od['orb_indexes_g2'][i_spin][i_orb]
        i_wrt_homo = i_orb - od['homo_i_g2'][i_spin]
        label = pp.get_orbital_label(i_wrt_homo)

        spin_letter = ""
        if od['nspin_g2'] == 2:
            spin_letter = "a-" if i_spin == 0 else "b-"

        full_label = f'MO{index:2} {spin_letter}{label:6} (E={energy:5.2f})'

        series = pp.create_series_w_broadening(od['energies_g1'][i_spin], om[i_spin][:, i_orb], energy_arr, fwhm)

        # Running sum over orbitals. Start from a copy so that the capping below
        # never writes into the array returned for the first orbital.
        if cumulative is None:
            cumulative = series.copy()
        else:
            cumulative = cumulative + series

        # possibly due to numerical precision, the cumulative orbital makeup can slightly
        # surpass molecule PDOS. reduce it to the PDOS level
        # (skipped when no molecule PDOS was collected for this spin channel)
        if i_spin < len(mol_series):
            surpass = cumulative > mol_series[i_spin]
            cumulative[surpass] = mol_series[i_spin][surpass]

        # Later orbitals get a lower zorder, so each cumulative area is drawn
        # beneath the ones before it. The colour index wraps around so that
        # more orbitals than colours in the cycle do not raise an IndexError.
        ax1.fill_between(energy_arr, 0.0, cumulative * spin_sign,
                         facecolor=mpl_def_colors[i_orb % n_colors], alpha=1.0,
                         zorder=-i_orb + 100, label=full_label)

    if i_spin == 0 and od['nspin_g2'] == 2:
        # add empty legend entries to align the spin channels
        for i in range(len(pdos_series)):
            ax1.fill_between([0.0], 0.0, [0.0], color='w', alpha=0, label=' ')

ax1.legend(ncol=od['nspin_g2'], loc='center left', bbox_to_anchor=(1.01, 0.5))

ax1.set_xlim(np.min(energy_arr), np.max(energy_arr))

if od['nspin_g2'] == 1:
    # Without a mirrored second channel there is nothing below zero to show.
    ylim[0] = 0.0
# A None entry leaves the corresponding limit unchanged.
ax1.set_ylim(ylim[0], ylim[1])

# Zero line separating the two spin channels.
ax1.axhline(0.0, color='k', lw=2.0, zorder=200)

ax1.set_ylabel("Density of States [a.u.]")
ax1.set_xlabel("$E-E_F$ [eV]")
plt.show()