Plotting
========

ERLabPy provides a number of plotting functions to help visualize data and create
publication-quality figures.

Importing
---------

In [None]:
import matplotlib.pyplot as plt
import erlab.plotting.erplot as eplt

In [None]:
%config InlineBackend.figure_formats = ["svg", "pdf"]
plt.rcParams["figure.dpi"] = 96
import xarray as xr

xr.set_options(display_expand_data=False)
nb_execution_mode = "cache"

First, let us generate some example data from a simple tight-binding model of graphene.
The bands are rigidly shifted down by 200 meV (`bandshift=-0.2`) so that the Dirac cone lies below the Fermi level and is visible.

In [None]:
from erlab.io.exampledata import generate_data

dat = generate_data(bandshift=-0.2, seed=1).T
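The internals of `generate_data` are not shown here. As a sketch of the kind of model involved, a nearest-neighbor tight-binding description of graphene gives two bands $E_\pm(\mathbf{k}) = E_0 \pm t\,|f(\mathbf{k})|$ with $f(\mathbf{k}) = 1 + e^{i\mathbf{k}\cdot\mathbf{a}_1} + e^{i\mathbf{k}\cdot\mathbf{a}_2}$, where $E_0$ is the rigid shift. The hopping $t = 2.7$ eV, the lattice constant $a = 2.46$ Å, and the lattice-vector orientation below are typical textbook assumptions, not necessarily the values `generate_data` uses, and `graphene_bands` is a hypothetical helper.

```python
import numpy as np

def graphene_bands(kx, ky, t=2.7, a=2.46, shift=0.0):
    """Nearest-neighbor tight-binding bands of graphene (hypothetical helper).

    kx, ky in 1/Angstrom; t and shift in eV; a in Angstrom.
    Lattice vectors are assumed to be a1 = a(1, 0) and a2 = a(1/2, sqrt(3)/2).
    """
    kx = np.asarray(kx, dtype=float)
    ky = np.asarray(ky, dtype=float)
    # Structure factor f(k) = 1 + exp(i k.a1) + exp(i k.a2)
    f = 1 + np.exp(1j * kx * a) + np.exp(1j * a * (0.5 * kx + np.sqrt(3) / 2 * ky))
    lower = shift - t * np.abs(f)
    upper = shift + t * np.abs(f)
    return lower, upper

# Usage: bands on a small momentum grid, rigidly shifted by -0.2 eV
a = 2.46
kx_grid, ky_grid = np.meshgrid(np.linspace(-2, 2, 101), np.linspace(-2, 2, 101))
lower, upper = graphene_bands(kx_grid, ky_grid, a=a, shift=-0.2)
k_point = 4 * np.pi / (3 * a)  # K lies at (4*pi/(3a), 0) for this orientation
```

Because $f$ vanishes at the K point, both bands meet at $E_0$ there, so a negative $E_0$ places the Dirac point below the Fermi level.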

In [None]:
dat

In [None]:
cut = dat.qsel(ky=0.3)
cut
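`dat.qsel(ky=0.3)` reduces the 3D array to a 2D cut at a single $k_y$. Assuming that, without a width, the selection picks the grid point nearest to the requested value (a simplification; consult the `qsel` documentation for the exact rules), the lookup boils down to the NumPy sketch below. The helper `nearest_index` and the toy grid are hypothetical.

```python
import numpy as np

def nearest_index(coords, value):
    """Return the index of the coordinate closest to `value` (hypothetical helper)."""
    coords = np.asarray(coords, dtype=float)
    return int(np.argmin(np.abs(coords - value)))

# Toy 2D map on a regular (ky, eV) grid; the values are random placeholders
grid_ky = np.linspace(-0.9, 0.9, 91)  # 0.02 step
toy_map = np.random.default_rng(1).random((grid_ky.size, 50))
i = nearest_index(grid_ky, 0.3)
toy_cut = toy_map[i]  # 1D slice along the remaining (eV) axis
```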

Plotting 2D data
----------------

In [None]:
eplt.plot_array(cut)

In [None]:
eplt.plot_array(
    cut, cmap="Greys", gamma=0.5, colorbar=True, colorbar_kw=dict(width=10, ticks=[])
)
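The `gamma` keyword applies a nonlinear color normalization that brings out weak features. The exact normalization ERLabPy uses is not reproduced here. For comparison only, matplotlib's own `PowerNorm` maps a value $x$ to $((x - v_\mathrm{min})/(v_\mathrm{max} - v_\mathrm{min}))^\gamma$, so $\gamma < 1$ brightens low intensities:

```python
import numpy as np
from matplotlib.colors import PowerNorm

# gamma < 1 stretches the low end of the intensity range
demo_norm = PowerNorm(gamma=0.5, vmin=0.0, vmax=1.0)
values = np.array([0.0, 0.25, 1.0])
mapped = np.asarray(demo_norm(values))  # 0.25 is mapped to 0.5
```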

The same kind of plot can also be made with the `qplot` accessor:

In [None]:
cut.qplot(cmap="Greys", gamma=0.5)

Next, let's add some annotations! The following code adds a line indicating the Fermi level, labels the high-symmetry points, and adds a colorbar. Unlike the previous example, the colorbar is added after plotting. Adding elements separately like this, instead of through keyword arguments, can make the code for complex plots more readable.

In [None]:
eplt.plot_array(cut, cmap="Greys", gamma=0.5)

eplt.fermiline()
eplt.mark_points([-0.525, 0.525], ["K", "K"], fontsize=10, pad=(0, 10))
eplt.nice_colorbar(width=10, ticks=[])
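To see what these annotations amount to, here is a rough plain-matplotlib counterpart on toy data. It is not how ERLabPy implements `fermiline`, `mark_points`, or `nice_colorbar`; it only approximates their visual result with `Axes.axhline`, `Axes.text`, and `Figure.colorbar`.

```python
import numpy as np
import matplotlib.pyplot as plt

# Toy image standing in for `cut`: axes are (eV, kx), values are random
rng = np.random.default_rng(0)
kx_vals = np.linspace(-0.8, 0.8, 161)
ev_vals = np.linspace(-0.6, 0.1, 71)
toy_image = rng.random((ev_vals.size, kx_vals.size))

fig, ax = plt.subplots()
mesh = ax.pcolormesh(kx_vals, ev_vals, toy_image, cmap="Greys", shading="auto")
# Dashed horizontal line at the Fermi level (E - E_F = 0)
ax.axhline(0.0, color="k", linestyle="--", linewidth=0.75)
# High-symmetry labels just above the axes: x in data units, y in axes units
for k, label in zip([-0.525, 0.525], ["K", "K"]):
    ax.text(k, 1.02, label, transform=ax.get_xaxis_transform(), ha="center", va="bottom")
# Colorbar without ticks
fig.colorbar(mesh, ax=ax, ticks=[])
```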

Slices
------

What if we want to plot multiple slices at once? We can create subplots to hold the
slices. `plt.subplots` is very useful for managing multiple axes and figures. If you
are unfamiliar with the syntax, see the [relevant matplotlib
documentation](https://matplotlib.org/stable/gallery/subplots_axes_and_figures/subplots_demo.html).

Suppose we want to plot constant-energy surfaces at specific binding energies, say, at `[-0.4, -0.2, 0.0]`. We could create three subplots and iterate over the axes.

In [None]:
energies = [-0.4, -0.2, 0.0]

fig, axs = plt.subplots(1, 3, layout="compressed", sharey=True)
for energy, ax in zip(energies, axs):
    const_energy_surface = dat.qsel(eV=energy)
    eplt.plot_array(const_energy_surface, ax=ax, gamma=0.5, aspect="equal")

In [None]:
fig, axs = plt.subplots(1, 3, layout="compressed", sharey=True)
for energy, ax in zip(energies, axs):
    const_energy_surface = dat.qsel(eV=energy)
    eplt.plot_array(const_energy_surface, ax=ax, gamma=0.5, aspect="equal")

eplt.clean_labels(axs)  # removes shared y labels
eplt.label_subplot_properties(axs, values={"Eb": energies})  # annotates energy
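Outside ERLabPy, matplotlib offers related tools: `Axes.label_outer` hides the axis labels and tick labels of inner subplots, and `Axes.text` with `transform=ax.transAxes` places an annotation in axes-relative coordinates. The sketch below is a rough analogue, not a reimplementation of `clean_labels` or `label_subplot_properties`; the label format `E = ... eV` is an arbitrary choice.

```python
import matplotlib.pyplot as plt

demo_energies = [-0.4, -0.2, 0.0]
fig, axs = plt.subplots(1, 3, sharey=True)
for energy, ax in zip(demo_energies, axs):
    ax.set_xlabel("kx")
    ax.set_ylabel("ky")
    ax.label_outer()  # inner subplots lose their y label and y tick labels
    # Annotation in the upper-left corner, in axes coordinates (0 to 1)
    ax.text(0.05, 0.95, f"E = {energy} eV", transform=ax.transAxes, ha="left", va="top")
```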

The same kind of figure can be produced more concisely with `plot_slices`:

In [None]:
fig, axs = eplt.plot_slices([dat], eV=[-0.4, -0.2, 0.0], gamma=0.5, axis="image")

We can also plot the data integrated over an energy window by adding the `eV_width` argument; here, the window is 200 meV wide:

In [None]:
fig, axs = eplt.plot_slices(
    [dat], eV=[-0.4, -0.2, 0.0], eV_width=0.2, gamma=0.5, axis="image"
)
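Integrating over an energy window amounts to combining all slices whose energies fall inside the window. The NumPy sketch below assumes the window is centered on each requested energy and that the slices are averaged; whether ERLabPy averages or sums is not stated here, but an average keeps the intensity scale comparable to a single slice. The helper `window_mean` is hypothetical.

```python
import numpy as np

def window_mean(data, coords, center, width, axis=0, atol=1e-9):
    """Average `data` along `axis` over coordinates within center +/- width/2 (hypothetical helper).

    `atol` adds a small tolerance so grid points exactly on the window edges
    are not dropped because of floating-point rounding.
    """
    coords = np.asarray(coords, dtype=float)
    mask = np.abs(coords - center) <= width / 2 + atol
    if not mask.any():
        raise ValueError("the window contains no coordinate values")
    return np.take(data, np.flatnonzero(mask), axis=axis).mean(axis=axis)

# Toy stack whose value at every point equals its energy coordinate
energy_grid = np.linspace(-0.5, 0.1, 61)  # 10 meV steps
toy_stack = np.broadcast_to(energy_grid[:, None, None], (61, 4, 4))
window_avg = window_mean(toy_stack, energy_grid, center=-0.2, width=0.2)
```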

Cuts along constant $k_y$ can be plotted analogously.

In [None]:
fig, axs = eplt.plot_slices([dat], ky=[0.0, 0.1, 0.3], gamma=0.5, figsize=(6, 2))

Notice that the first two cuts pass through regions with less spectral weight, so the colors in the three subplots are not on the same scale. This can be misleading when intensities must be compared across slices. The `unify_clim` function sets common color limits for multiple axes.

The same effect can be achieved by passing `same_limits=True` to `plot_slices`.

In [None]:
fig, axs = eplt.plot_slices([dat], ky=[0.0, 0.1, 0.3], gamma=0.5, figsize=(6, 2))
eplt.unify_clim(axs)

We can also take the color limits from a specific reference axis by passing it as `target`.

In [None]:
fig, axs = eplt.plot_slices([dat], ky=[0.0, 0.1, 0.3], gamma=0.5, figsize=(6, 2))
eplt.unify_clim(axs, target=axs.flat[1])
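Conceptually, unifying color limits means giving every image the same `vmin` and `vmax`. The plain-matplotlib sketch below assumes each panel holds its image in `ax.images`, as images drawn with `imshow` do (whether ERLabPy's panels store their artists this way is an assumption). Without a target, the limits span the union of all current limits; with a target, they are copied from it. The helper name `unify_image_clim` is hypothetical.

```python
import numpy as np
import matplotlib.pyplot as plt

def unify_image_clim(axes, target=None):
    """Give all images in `axes` the same color limits (hypothetical helper)."""
    images = [im for ax in np.ravel(axes) for im in ax.images]
    if target is not None:
        vmin, vmax = target.images[0].get_clim()  # copy limits from the reference axis
    else:
        vmin = min(im.get_clim()[0] for im in images)  # union of all limits
        vmax = max(im.get_clim()[1] for im in images)
    for im in images:
        im.set_clim(vmin, vmax)
    return vmin, vmax

rng = np.random.default_rng(0)
fig, axs = plt.subplots(1, 3)
for scale, ax in zip([0.2, 0.5, 1.0], axs):
    ax.imshow(scale * rng.random((20, 20)))  # each image autoscales its own limits
unify_image_clim(axs)
```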

What if we want to plot constant-energy surfaces and cuts in the same figure? We can create the subplots first and pass them to the `axes` argument of `plot_slices`.

In [None]:
fig, axs = plt.subplots(2, 3, layout="compressed", sharex=True, sharey="row")
eplt.plot_slices([dat], eV=[-0.4, -0.2, 0.0], gamma=0.5, axes=axs[0, :], axis="image")
eplt.plot_slices([dat], ky=[0.0, 0.1, 0.3], gamma=0.5, axes=axs[1, :])
eplt.clean_labels(axs)

2D colormaps
------------

Here, we generate two constant-energy surfaces, at -0.3 eV and 0.3 eV, and plot them side by side.

In [None]:
dat0, dat1 = generate_data(
    shape=(250, 250, 2), Erange=(-0.3, 0.3), temp=0.0, seed=1, count=1e6
).T

_, axs = eplt.plot_slices(
    [dat0, dat1],
    order="F",
    subplot_kw={"layout": "compressed", "sharey": "row"},
    axis="scaled",
    label=True,
)
# eplt.label_subplot_properties(axs, values=dict(Eb=[-0.3, 0.3]))

Suppose we want to visualize the sum and the normalized difference between the two. The simplest way is to plot them side by side.

In [None]:
dat_sum = dat0 + dat1
dat_ndiff = (dat0 - dat1) / dat_sum

eplt.plot_slices(
    [dat_sum, dat_ndiff],
    order="F",
    subplot_kw={"layout": "compressed", "sharey": "row"},
    cmap=["viridis", "bwr"],
    axis="scaled",
)
eplt.proportional_colorbar()
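The normalized difference becomes unreliable where the sum is small. A quick simulation illustrates why, assuming Poisson (shot) noise in both channels with equal mean counts: the spread of $(a - b)/(a + b)$ shrinks roughly as $1/\sqrt{a + b}$. The count levels below are arbitrary, and `ndiff_spread` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(42)

def ndiff_spread(mean_counts, n=100_000):
    """Std. dev. of (a - b) / (a + b) for two Poisson channels with equal means."""
    a = rng.poisson(mean_counts, n).astype(float)
    b = rng.poisson(mean_counts, n).astype(float)
    total = a + b
    valid = total > 0  # skip samples where the sum is zero
    return np.std((a[valid] - b[valid]) / total[valid])

for counts in (5, 50, 500):
    # measured spread vs. the 1/sqrt(total) estimate
    print(counts, ndiff_spread(counts), 1 / np.sqrt(2 * counts))
```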

The difference array is noisy where the sum is small. Instead, we can use a 2D
colormap, where `dat_ndiff` is mapped to the color along the colormap and `dat_sum` is
mapped to the lightness of the colormap.

In [None]:
eplt.plot_array_2d(dat_sum, dat_ndiff)
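One simple way to build such a 2D colormap is to look up the color for the normalized difference in an ordinary colormap and then scale its brightness by the normalized sum. This is an assumption for illustration: `plot_array_2d` may blend colors differently (for example, in a perceptual color space), and the helper `blend_2d` is hypothetical.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

def blend_2d(lightness, color, cmap="bwr", lnorm=None, cnorm=None):
    """Return an RGB image: hue from `color`, brightness from `lightness` (hypothetical helper)."""
    if lnorm is None:
        lnorm = Normalize()  # autoscales to the data range on first call
    if cnorm is None:
        cnorm = Normalize(vmin=-1.0, vmax=1.0)  # a normalized difference lies in [-1, 1]
    rgb = plt.get_cmap(cmap)(cnorm(color))[..., :3]  # look up colors, drop alpha
    light = np.clip(np.asarray(lnorm(lightness)), 0.0, 1.0)
    return rgb * light[..., None]  # black where the lightness is 0

# Toy counts standing in for dat0 and dat1
rng = np.random.default_rng(0)
toy_a = rng.poisson(20, (50, 50)).astype(float)
toy_b = rng.poisson(20, (50, 50)).astype(float)
toy_sum = toy_a + toy_b
toy_ndiff = (toy_a - toy_b) / toy_sum
rgb_image = blend_2d(toy_sum, toy_ndiff)  # shape (50, 50, 3)
```

With the data above, `blend_2d(dat_sum.values, dat_ndiff.values)` would give an RGB array that can be displayed with `plt.imshow`.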

The color normalization for each axis can be set independently with `lnorm` and `cnorm`.
The appearance of the colorbar axes can be customized with the returned `Colorbar`
object.

In [None]:
_, cb = eplt.plot_array_2d(
    dat_sum,
    dat_ndiff,
    lnorm=eplt.InversePowerNorm(0.5),
    cnorm=eplt.CenteredInversePowerNorm(0.7, vcenter=0.0, halfrange=1.0),
)
cb.ax.set_xticks(cb.ax.get_xlim())
cb.ax.set_xticklabels(["Min", "Max"])

Styling figures
---------------

You can control the look and feel of matplotlib figures with [*style sheets* and *rcParams*](https://matplotlib.org/stable/users/explain/customizing.html). In addition to the [options provided by matplotlib](https://matplotlib.org/stable/gallery/style_sheets/style_sheets_reference.html), ERLabPy provides the style sheets listed below. Note that style sheets that change the default font require the font to be installed on the system. To see how each one looks, try running [the code provided by matplotlib](https://matplotlib.org/stable/gallery/style_sheets/style_sheets_reference.html).

| Style Name | Description                                                                                         |
|------------|-----------------------------------------------------------------------------------------------------|
| khan       | Personal preferences of the package author.                                                         |
| fira       | Changes the default font to Fira Sans.                                                              |
| firalight  | Changes the default font to Fira Sans Light.                                                        |
| times      | Changes the default font to Times New Roman.                                                        |
| nature     | Changes the default font to Arial, and tweaks some aspects such as padding and default figure size. |


In [None]:
with plt.style.context(["nature"]):
    eplt.plot_array(cut, cmap="Greys", gamma=0.5)
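Style sheets can also be combined with your own overrides: `plt.style.context` accepts a list mixing style names, paths to `.mplstyle` files, and plain dictionaries of rcParams, applied from left to right. The sketch below uses matplotlib's built-in `default` style plus a dictionary so that it runs without ERLabPy's styles; substituting `"nature"` for `"default"`, as in the cell above, should work the same way. The override values are arbitrary examples.

```python
import matplotlib.pyplot as plt

overrides = {"font.size": 8, "axes.linewidth": 0.5}  # arbitrary example values
with plt.style.context(["default", overrides]):
    fig, ax = plt.subplots(figsize=(3, 2))
    ax.plot([0, 1], [0, 1])
    inside = plt.rcParams["font.size"]  # the override is active inside the block
# rcParams are restored when the block exits
```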

Tips
----

- The Python ecosystem has several libraries that provide great colormaps, such as [cmasher](https://cmasher.readthedocs.io), [cmocean](https://matplotlib.org/cmocean/), and [colorcet](https://colorcet.holoviz.org).

- Although matplotlib is a powerful library, it is relatively heavy and slow, and is better suited for static plots. For interactive plots, libraries such as [Plotly](https://github.com/plotly/plotly.py) or [Bokeh](https://github.com/bokeh/bokeh) are popular.

  The hvplot library is a high-level plotting library that provides a simple interface to Bokeh, Plotly, and Matplotlib. It is particularly useful for interactive plots and works directly with xarray objects. Here are some examples that use the Bokeh backend:

In [None]:
import hvplot.xarray

cut.hvplot(x="kx", y="eV", cmap="Greys", aspect=1.5)

In [None]:
dat.hvplot(x="kx", y="ky", cmap="Greys", aspect="equal", widget_location="bottom")