In [None]:
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objs as go

"""Visualization functions: scatter plots are drawn with plotly since it seems to be more efficient."""


def get_plotly_scatter_plot(
    data_in: np.ndarray,
    lat_mat: np.ndarray,
    factor: int = 5,
    logcolor: bool = False,
    mask: np.ndarray = None,
    opacity: float = 0.5,
    marker_size: int = 5,
) -> go.Figure:
    """
    Returns a plotly fig object showing a down-sampled scatter plot of grid data.

    Every retained grid point is drawn at its Cartesian position: the fractional
    coordinate ``i / n`` of grid index ``i`` along an axis of length ``n`` is
    mapped through the lattice matrix.

    Args:
        data_in: Structured grid data to be plotted (1, 2 or 3 dimensions);
            only the real part is used.
        lat_mat: Lattice vectors of the cell, one per row; must be
            ``ndim x ndim`` with ``ndim = len(data_in.shape)``.
        factor: reduction factor of the grid points for plotting, only show
            [::factor] in each direction. Must be >= 1.
        logcolor: If True, assign the color in log scale (for 1D data the
            y values are log-scaled instead). Non-positive values give
            NaN/-inf under the log.
        mask: Filter the points to plot; an array with the same shape as
            ``data_in`` that is interpreted as booleans (True = keep).
        opacity: opacity of each point being plotted (2D and 3D only)
        marker_size: size of the markers in the scatter plot

    Returns:
        plotly Figure object

    Raises:
        NotImplementedError: if ``data_in`` does not have 1, 2, or 3 dimensions.
        ValueError: if ``factor`` is smaller than 1.
    """
    ndim = len(data_in.shape)
    # 0-D input would otherwise fall through every branch below with no trace built.
    if not 1 <= ndim <= 3:
        raise NotImplementedError("Can only render data of 1, 2, or 3 dimensions.")
    if factor < 1:
        raise ValueError(f"factor must be a positive integer, got {factor!r}")

    all_slices = (slice(0, None, factor),) * ndim
    trimmed_data = np.real(data_in)[all_slices]

    if mask is not None:
        # Force boolean semantics so an int 0/1 mask is not used as fancy indices.
        flat_mask = np.asarray(mask, dtype=bool)[all_slices].ravel()
    else:
        flat_mask = np.ones(trimmed_data.size, dtype=bool)

    # Fractional coordinates of the *kept* indices, computed per axis from the
    # original axis length: correct for non-square grids and for strides that
    # do not divide the axis length evenly.
    frac_axes = [np.arange(0, n, factor) / n for n in data_in.shape]
    gridded = np.meshgrid(*frac_axes, indexing="ij")  # "ij" matches the array's index order
    cart = np.dot(lat_mat.T, [g_.ravel() for g_ in gridded])[:, flat_mask]

    values = trimmed_data.ravel()
    cc = (np.log(values) if logcolor else values)[flat_mask]

    if ndim == 1:
        trace = go.Scatter(
            x=cart[0],
            y=cc,
            mode="markers",
            marker=dict(
                size=marker_size,
                color="red",
            ),
        )
    elif ndim == 2:
        trace = go.Scatter(
            x=cart[0],
            y=cart[1],
            mode="markers",
            marker=dict(
                size=marker_size,
                color=cc,  # colour each point by its (optionally log) value
                colorscale="Viridis",
                opacity=opacity,
            ),
        )
    else:
        trace = go.Scatter3d(
            x=cart[0],
            y=cart[1],
            z=cart[2],
            mode="markers",
            marker=dict(
                size=marker_size,
                color=cc,
                colorscale="Viridis",
                opacity=opacity,
            ),
        )
    fig = go.Figure(data=[trace])

    if ndim >= 2:
        # Lock the aspect ratio so the cell geometry is not distorted.
        fig.update_layout(
            width=800,
            height=800,
            template="plotly_white",
            yaxis=dict(scaleanchor="x", scaleratio=1),
        )
    if ndim == 3:
        fig.update_layout(scene_aspectmode="data")
    return fig


In [None]:
# NOTE(review): pg_2D is created in a later cell; on a fresh kernel run that cell first.
get_plotly_scatter_plot(data_in=pg_2D.grid_data, lat_mat=pg_2D.lattice)

In [None]:
# NOTE(review): chgcar is created in a later cell; on a fresh kernel run that cell first.
get_plotly_scatter_plot(data_in=chgcar.grid_data, lat_mat=chgcar.lattice)

In [None]:
# NOTE(review): identical to an earlier plot cell; candidate for removal.
get_plotly_scatter_plot(data_in=pg_2D.grid_data, lat_mat=pg_2D.lattice)

In [None]:

# Demo: two traces sharing the x axis but drawn against separate y axes —
# exp(x) on the left axis and log(x) on an overlaid right-hand axis.
x = np.arange(1, 11)
y1 = np.exp(x)
y2 = np.log(x)
trace1 = go.Scatter(x=x, y=y1, name="exp")
trace2 = go.Scatter(x=x, y=y2, name="log", yaxis="y2")  # bind to the second y axis
# Use a specific name rather than the generic global ``data`` so this list of
# traces is not confused with (or clobbered by) grid arrays elsewhere in the notebook.
dual_axis_traces = [trace1, trace2]
layout = go.Layout(
    title="Double Y Axis Example",
    yaxis=dict(
        title="exp",
        zeroline=True,
        showline=True,
    ),
    yaxis2=dict(
        title="log",
        zeroline=True,
        showline=True,
        overlaying="y",  # share the plot area of the primary y axis
        side="right",
    ),
)

fig = go.Figure(data=dual_axis_traces, layout=layout)
fig.show()

In [None]:
from pyrho.core.chargeDensity import ChargeDensity
from pymatgen.io.vasp import Chgcar
from pyrho.core.pgrid import PGrid
from pyrho.core.utils import get_padded_array
from pyrho.core.fourier import PFourier

# Load the Si unit-cell charge density and convert it to pyrho's representation.
chgcar_pmg = Chgcar.from_hdf5("../test_files/Si.uc.hdf5")
chgcar = ChargeDensity.from_pmg_volumetric_data(chgcar_pmg)
# Re-orient the lattice in place. NOTE(review): presumably puts a along x and b in
# the xy-plane, which taking ``lattice[:2, :2]`` below relies on — confirm.
chgcar.reorient_axis()

SLICE_INDEX = 24  # grid index along the third (c) axis at which the 2D cut is taken

# The 2D lattice keeps rows 0 and 1 (vectors a and b), so the matching cut must keep
# grid axes 0 and 1 and fix axis 2. Fixing axis 0 instead would pair the b-c plane
# of the data with the a-b lattice.
a_mat = chgcar.lattice[:2, :2]
plane_data = chgcar.grid_data[:, :, SLICE_INDEX]

pg_2D = PGrid(plane_data, a_mat)

In [None]:
# Inspect the lattice matrix of the 2D grid.
pg_2D.lattice

In [None]:
# Scratch: alias the 2D grid data for the step-by-step exploration in the next cell.
data_in = pg_2D.grid_data

In [None]:
# Scratch: reproduce the first down-sampling steps of get_plotly_scatter_plot with a
# stride of 4 and show the shape of the resulting grid.
ndim = data_in.ndim
if not ndim <= 3:
    raise NotImplementedError("Can only render data of 1, 2, or 3 dimensions.")

ss = slice(0, None, 4)
all_slices = tuple(ss for _ in range(ndim))
trimmed_data = np.real(data_in).copy()[all_slices]
trimmed_data.shape

In [None]:
def show_2d(pg_2D, ax=None):
    """Scatter-plot a 2D grid in Cartesian space, coloured by log(value).

    Each marker sits at the centre of its grid cell: the fractional coordinates
    ``i / n`` are shifted by half a grid step along both lattice vectors before
    being mapped through the in-plane part of the lattice.

    Args:
        pg_2D: object exposing a 2D ``grid_data`` array and a ``lattice`` matrix
            whose top-left 2x2 block describes the in-plane cell.
        ax: Matplotlib axes to draw on; a new figure and axes are created if None.

    Returns:
        The matplotlib axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    n_a, n_b = pg_2D.grid_data.shape
    av = np.linspace(0, 1, n_a, endpoint=False)
    bv = np.linspace(0, 1, n_b, endpoint=False)
    AA, BB = np.meshgrid(av, bv, indexing="ij")  # "ij" so the flatten order matches grid_data
    in_plane = pg_2D.lattice.T[:2, :2]
    xx, yy = np.dot(in_plane, [AA.ravel(), BB.ravel()])
    # Half a grid step along each lattice vector; 0.5 / n also works for 1-point axes.
    xshift, yshift = np.dot(in_plane, (0.5 / n_a, 0.5 / n_b))
    ax.scatter(
        xx + xshift,
        yy + yshift,
        c=np.log(pg_2D.grid_data.ravel()),
        edgecolors="black",
        alpha=0.1,
    )
    # Equal aspect so the cell geometry is not distorted.
    ax.set_aspect("equal")
    return ax


show_2d(pg_2D);


In [None]:
fft_data = np.abs(np.fft.fftn(pg_2D.grid_data))
pf = PFourier(fft_data, pg_2D.lattice)


def show_2d_f(pf, thresh=12, ax=None):
    """Scatter-plot the Fourier components of ``pf`` whose magnitude exceeds ``thresh``.

    Args:
        pf: ``PFourier``-like object exposing ``fourier_data`` and
            ``fft_pos_centered_cartesian_s``. The latter is assumed to unpack into
            flattened Cartesian x/y positions aligned with
            ``fourier_data.flatten()`` — TODO confirm.
        thresh: Only components with ``fourier_data > thresh`` are drawn.
        ax: Matplotlib axes to draw on; a new figure and axes are created if None.

    Returns:
        The matplotlib axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    xx, yy = pf.fft_pos_centered_cartesian_s
    values = pf.fourier_data.ravel()
    mask = values > thresh
    ax.scatter(xx[mask], yy[mask], c=values[mask], edgecolors="black", alpha=0.4)
    # Apply the aspect to the axes holding the scatter; ``plt.axes()`` would
    # instead add a new, empty axes and leave this one distorted.
    ax.set_aspect("equal")
    return ax


In [None]:
# Magnitude spectrum of the 2D slice; plot its strongest components.
fft_data = np.absolute(np.fft.fftn(pg_2D.grid_data))
pf = PFourier(fft_data, pg_2D.lattice)
show_2d_f(pf, thresh=12);

In [None]:
# Build a supercell of the 2D grid and plot it.
# NOTE(review): det(sc_mat) = 1*(-2) - 1*1 = -3, i.e. the new cell vectors have the
# opposite handedness to the original ones — confirm this is intended.
g1, g2 = pg_2D.grid_data.shape
pg_super = pg_2D.get_transformed_obj(
    sc_mat=[[1, 1], [1, -2]],
    frac_shift=[0, 0],
    grid_out=[3 * g1, 2 * g2],
    up_sample=2,
)
show_2d(pg_super);


In [None]:
# Magnitude spectrum of the transformed supercell grid.
fft_data_super = np.absolute(np.fft.fftn(pg_super.grid_data))
pf_super = PFourier(fft_data_super, pg_super.lattice)


In [None]:
# Re-inspect the lattice matrix of the original 2D grid (for comparison with the supercell).
pg_2D.lattice

In [None]:
# Supercell spectrum, drawn with a higher magnitude cut-off (27 = 3 * 9).
show_2d_f(pf=pf_super, thresh=27);

In [None]:
# Original-cell spectrum with a low cut-off, for comparison with the supercell plot.
show_2d_f(thresh=3, pf=pf);