# Demonstration of the regridding algorithm in 2D

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pymatgen.io.vasp import Chgcar

from pyrho.core.chargeDensity import ChargeDensity
from pyrho.core.pgrid import PGrid
from pyrho.core.utils import get_padded_array


In [None]:
# Load the silicon unit-cell charge density (a pymatgen Chgcar stored as HDF5)
# and convert it into a pyrho ChargeDensity before reorienting its axes.
chgcar = ChargeDensity.from_pmg_volumetric_data(
    Chgcar.from_hdf5("../test_files/Si.uc.hdf5")
)
chgcar.reorient_axis()

# Upper-left 2x2 block of the lattice matrix, used as the 2D cell below.
a_mat = chgcar.lattice[:2, :2]

# Plane at index 24 of the first grid axis: full resolution, then every 5th
# point along both in-plane axes (the down-sampled copy is what gets regridded).
# NOTE(review): this plane spans grid axes 1 and 2 (lattice rows 1 and 2),
# whereas a_mat is built from lattice rows/columns 0 and 1 -- confirm this
# pairing is intended for the demo.
data = chgcar.grid_data[24]
data_timmed = data[::5, ::5]  # name ("timmed" = trimmed) kept: used by later cells

In [None]:
# Wrap the down-sampled 2D slice together with its 2x2 lattice matrix in a
# PGrid; the regridding calls (get_transformed_obj) below operate on this object.
pg_2D = PGrid(data_timmed, a_mat)

In [None]:
# Show the shape of the down-sampled grid. The later cells request output grids
# of 12 * 4 points per axis, presumably 4x this resolution -- confirm here.
data_timmed.shape

In [None]:
# Quick look at the coarse grid: each value is drawn at the centre of its grid
# cell, i.e. at fractional coordinates ((i + 0.5) / na, (j + 0.5) / nb), and
# coloured by log(charge density).
na, nb = pg_2D.grid_data.shape[0], pg_2D.grid_data.shape[1]
av = np.linspace(0, 1, na, endpoint=False)
bv = np.linspace(0, 1, nb, endpoint=False)
AA, BB = np.meshgrid(av, bv, indexing="ij")  # "ij" so AA/BB line up with grid_data
xx, yy = np.dot(pg_2D.lattice.T[:2, :2], [AA.flatten(), BB.flatten()])
# Half a grid step along each lattice vector moves the points from the cell
# corners to the cell centres. 0.5 / n equals (av[1] - av[0]) / 2 but also
# works when an axis has a single point.
xshift, yshift = np.dot(pg_2D.lattice.T[:2, :2], (0.5 / na, 0.5 / nb))
plt.scatter(
    xx + xshift,
    yy + yshift,
    c=np.log(pg_2D.grid_data.flatten()),
    edgecolors="black",
    alpha=0.4,
)
# The lattice is in general non-orthogonal: equal axis scaling keeps the
# plotted angles and lengths undistorted (as in the figure cell below).
plt.gca().set_aspect("equal")

In [None]:
# Figure: the coarse grid (large, black-edged markers) overlaid with two
# regridded versions of it (small markers).
#   top panel    -- same cell, resampled onto a 48x48 grid;
#   bottom panel -- sheared supercell (sc_mat) with a shifted origin (frac_shift).
from matplotlib import gridspec
import matplotlib.patches as mpatches

fig = plt.figure(figsize=(20, 22))
gs = gridspec.GridSpec(2, 1, height_ratios=[2, 2])
ax1 = plt.subplot(gs[0])
ax2 = plt.subplot(gs[1])


def _padded_cartesian_points(pgrid, frac_offset=(0.0, 0.0)):
    """Return Cartesian (x, y) arrays for the padded grid of ``pgrid``.

    The points form a regular ``(na + 1) x (nb + 1)`` mesh of fractional
    coordinates spanning [0, 1] inclusive (both cell edges), flattened in "ij"
    order so they pair element-wise with
    ``get_padded_array(pgrid.grid_data).flatten()`` (assumes the padding adds
    one periodic row/column per axis). The mesh is translated by
    ``frac_offset``, given in units of ``pgrid.lattice``.
    """
    na, nb = pgrid.grid_data.shape[0], pgrid.grid_data.shape[1]
    av = np.linspace(0, 1, na + 1, endpoint=True) + frac_offset[0]
    bv = np.linspace(0, 1, nb + 1, endpoint=True) + frac_offset[1]
    AA, BB = np.meshgrid(av, bv, indexing="ij")  # "ij" matches the array layout
    return np.dot(pgrid.lattice.T, [AA.flatten(), BB.flatten()])


# Coarse reference grid, drawn in both panels.
xxb, yyb = _padded_cartesian_points(pg_2D)
coarse_log_rho = np.log(get_padded_array(pg_2D.grid_data).flatten())

# --- Top panel: identity transform, only the sampling density changes.
pg_shifted = pg_2D.get_transformed_obj(
    sc_mat=[[1, 0], [0, 1]], frac_shift=[0, 0], grid_out=[12 * 4, 12 * 4]
)
xx, yy = _padded_cartesian_points(pg_shifted)
ax1.scatter(
    xx,
    yy,
    c=np.log(get_padded_array(pg_shifted.grid_data).flatten()),
    s=10,
    edgecolors=None,
    alpha=0.9,
    cmap="viridis",
)
ax1.scatter(
    xxb,
    yyb,
    c=coarse_log_rho,
    s=180,
    edgecolors="black",
    alpha=1,
    cmap="viridis",
)

# --- Bottom panel: supercell with its origin moved to frac_shift (fractional
# coordinates of the original cell).
sc_mat = [[2, 0], [-1, 2]]
frac_shift = [0.5, 0.5]
pg_shifted = pg_2D.get_transformed_obj(
    sc_mat=sc_mat, frac_shift=frac_shift, grid_out=[12 * 4, 12 * 4]
)
# Express the same origin r0 = frac_shift @ lattice in the new cell. Assuming
# pg_shifted.lattice == sc_mat @ pg_2D.lattice (row vectors), r0 = g @ sc_mat @
# lattice, so g = frac_shift @ inv(sc_mat), i.e. solve sc_mat.T @ g = frac_shift.
# For the values above g == [0.375, 0.25].
new_frac_origin = np.linalg.solve(np.array(sc_mat, dtype=float).T, frac_shift)
xx, yy = _padded_cartesian_points(pg_shifted, frac_offset=new_frac_origin)
ax2.scatter(
    xx,
    yy,
    c=np.log(get_padded_array(pg_shifted.grid_data).flatten()),
    s=20,
    edgecolors=None,
    alpha=0.9,
    cmap="viridis",
)
# Handle kept so a colorbar can be attached to the coarse-grid markers.
scatter2 = ax2.scatter(
    xxb,
    yyb,
    c=coarse_log_rho,
    s=180,
    edgecolors="black",
    alpha=1,
    cmap="viridis",
)

ax1.set_aspect("equal")
ax2.set_aspect("equal")

ax1.axis("off")
ax2.axis("off")


def axis_arrow(a, b, orig=None):
    """Build a thick arrow patch for the 2D vector ``(a, b)``.

    Without ``orig`` the arrow starts at (0, 0) and is filled "#ff073a";
    with ``orig`` it starts at ``orig`` and is filled "#107ab0". In this
    notebook the first form draws the original lattice vectors and the second
    the transformed ones.
    """
    arrow_style = "Simple,head_length=16,head_width=16,tail_width=5"
    if orig is not None:
        tail, fill = orig, "#107ab0"
        head = (orig[0] + a, orig[1] + b)
    else:
        tail, head, fill = (0, 0), (a, b), "#ff073a"
    return mpatches.FancyArrowPatch(
        tail, head, arrowstyle=arrow_style, ec=None, fc=fill
    )


# Lattice-vector arrows: from the origin for the original cell (both panels),
# and from the shifted origin for the transformed cell (bottom panel only).
ax1.add_patch(axis_arrow(*pg_2D.lattice[0]))
ax1.add_patch(axis_arrow(*pg_2D.lattice[1]))
ax2.add_patch(axis_arrow(*pg_2D.lattice[0]))
ax2.add_patch(axis_arrow(*pg_2D.lattice[1]))
# Cartesian origin of the transformed cell; the 0.5 / 0.5 coefficients must
# match the frac_shift passed to get_transformed_obj when building pg_shifted.
new_origin = 0.5 * pg_2D.lattice[0] + 0.5 * pg_2D.lattice[1]
ax2.add_patch(axis_arrow(*pg_shifted.lattice[0], orig=new_origin))
ax2.add_patch(axis_arrow(*pg_shifted.lattice[1], orig=new_origin))

# Save relative to the working directory instead of a user-specific absolute
# path, so the notebook runs on any machine.
fig.savefig(
    "resample_grid.pdf",
    bbox_inches="tight",
    transparent=True,
    pad_inches=0,
)