In [1]:
%matplotlib widget

import numpy as np
from wigner import Su2Group

import plotly.express as px
import plotly.graph_objects as go

from ipywidgets import interact, IntSlider, FloatSlider

In [2]:
# Dimension of the state space used throughout the notebook: it is the length
# of the state vector, the argument passed to Su2Group, the upper bound
# (dim - 1) of the level sliders, and it scales the heatmap colour range.
dim = 4

In [None]:
# Number of sample points along each axis of the square plotting grid.
density = 200

# Cartesian grid covering [-pi, pi] x [-pi, pi].
xs_1d = np.linspace(-np.pi, np.pi, density)
ys_1d = np.linspace(-np.pi, np.pi, density)
xs, ys = np.meshgrid(xs_1d, ys_1d)

# Polar coordinates of every grid point: the radial distance from the grid
# centre becomes `thetas`, the polar angle (in (-pi, pi]) becomes `phis`.
# Both arrays are later passed to the Wigner transform; points with
# thetas > pi are blanked out there.
# NOTE(review): this looks like an azimuthal-equidistant style map of the
# sphere onto a disc of radius pi -- confirm against the `wigner` package.
thetas = np.hypot(ys, xs)
# np.arctan2 rather than np.atan2: the `atan2` alias only exists in
# NumPy >= 2.0, while `arctan2` is available in every NumPy version and
# returns the same values.
phis = np.arctan2(ys, xs)


# Interactive plotly figure widget that hosts the heatmap; its data is filled
# in later by the slider callback.
fig: go.FigureWidget = go.FigureWidget()

# Fixed 800x600 canvas with zero margins. Anchoring the y-axis scale to the
# x axis keeps one unit the same length in both directions.
fig.update_layout(
    autosize=False,
    width=800,
    height=600,
    margin={"l": 0, "r": 0, "b": 0, "t": 0, "pad": 4},
    yaxis={"scaleanchor": "x"},
)

# Empty heatmap trace (no z yet) with a symmetric colour range of
# +/- 0.8 * sqrt(dim) on a diverging red/blue scale.
fig.add_trace(
    go.Heatmap(
        zmin=-0.8 * np.sqrt(dim),
        zmax=0.8 * np.sqrt(dim),
        colorscale=px.colors.diverging.RdBu,
        zsmooth="best",
    )
)


@interact(
    i=IntSlider(min=0, max=dim - 1, step=1),
    j=IntSlider(min=0, max=dim - 1, step=1),
    theta=FloatSlider(min=0, max=np.pi, step=np.pi / 20, value=np.pi / 2),
    phi=FloatSlider(min=0, max=2 * np.pi, step=2 * np.pi / 20, value=0),
    purity=FloatSlider(min=0, max=1, step=0.05, value=1),
)
def transition(i: int, j: int, theta: float, phi: float, purity: float):
    """Recompute the heatmap for a state built from basis levels ``i`` and ``j``.

    Called by ``interact`` whenever a slider changes. Reads the module-level
    ``dim``, ``thetas``, ``phis`` and ``fig``; writes the result into the
    ``z`` attribute of the first trace of ``fig``. Returns nothing.

    Parameters
    ----------
    i, j : int
        Indices of the two basis levels. If they are equal, the state is the
        single basis vector ``i`` and ``theta``, ``phi`` and ``purity`` are
        ignored.
    theta : float
        Mixing angle: amplitude ``cos(theta / 2)`` on level ``i`` and
        ``sin(theta / 2)`` on level ``j``.
    phi : float
        Relative phase ``exp(1j * phi)`` applied to the amplitude on ``j``.
    purity : float
        Weight of the pure-state density matrix; ``1 - purity`` weights the
        matrix that keeps only its diagonal (off-diagonal terms removed).
    """
    state = np.zeros((dim,), dtype=complex)

    if i == j:
        # Single basis state: nothing to superpose or dephase.
        state[i] = 1
        rho = np.outer(state, state.conj())
    else:
        state[i] = np.cos(theta / 2)
        state[j] = np.exp(1j * phi) * np.sin(theta / 2)

        coherent = np.outer(state, state.conj())
        # Same populations as `coherent`, but with the off-diagonals dropped.
        dephased = np.diag(coherent.diagonal())
        rho = purity * coherent + (1 - purity) * dephased

    # NOTE(review): presumably evaluates the SU(2) Wigner function of `rho`
    # at every (theta, phi) grid point -- confirm against the `wigner` package.
    values = np.asarray(
        Su2Group(dim).wigner_transform(thetas, phis, rho),
        dtype=object,
    )
    # Object dtype lets grid points with radial coordinate beyond pi be set
    # to None, so they carry no value in the heatmap.
    values[thetas > np.pi] = None
    fig.data[0].z = values


# Show the figure widget as the cell output; the `transition` callback above
# updates its heatmap data whenever a slider changes.
display(fig)