In [1]:
%matplotlib widget

import numpy as np
from wigner import Su2Group

import plotly.express as px
import plotly.graph_objects as go

from ipywidgets import interact, IntSlider, FloatSlider

In [2]:
# Length of the state vector (and side of the density matrix) used below.
# It also sets the slider ranges for i and j, the argument to Su2Group,
# and the colour range of the surface.
dim = 4

In [None]:
# Number of samples along each angular direction of the sphere grid.
density = 100

# Polar angle sampled on [0, pi] and azimuthal angle on [0, 2*pi]; the
# meshgrid output is the pair of 2-D angle grids used both to parametrise
# the surface and as the evaluation points passed to wigner_transform.
thetas_1d = np.linspace(0, np.pi, density)
phis_1d = np.linspace(0, 2 * np.pi, density)
thetas, phis = np.meshgrid(thetas_1d, phis_1d)


# Fixed-size 3-D figure widget without margins and without axis tick labels.
fig: go.FigureWidget = go.FigureWidget()
fig.update_layout(
    width=800,
    height=600,
    autosize=False,
    margin={"l": 0, "r": 0, "b": 0, "t": 0, "pad": 4},
    scene_camera={"eye": {"x": 1, "y": 1, "z": 1.2}},
    scene={
        axis_name: {"showticklabels": False}
        for axis_name in ("xaxis", "yaxis", "zaxis")
    },
)

# Unit sphere in Cartesian coordinates built from the angle grids.
# cmin/cmax are symmetric about zero, so the diverging colourscale is
# centred on 0 without an explicit cmid.
fig.add_surface(
    x=np.sin(thetas) * np.cos(phis),
    y=np.sin(thetas) * np.sin(phis),
    z=np.cos(thetas),
    cmin=-0.8 * np.sqrt(dim),
    cmax=+0.8 * np.sqrt(dim),
    colorscale=px.colors.diverging.RdBu,
)


@interact(
    i=IntSlider(min=0, max=dim - 1, step=1),
    j=IntSlider(min=0, max=dim - 1, step=1),
    theta=FloatSlider(min=0, max=np.pi, step=np.pi / 20, value=np.pi / 2),
    phi=FloatSlider(min=0, max=2 * np.pi, step=2 * np.pi / 20, value=0),
    purity=FloatSlider(min=0, max=1, step=0.05, value=1),
)
def transition(i: int, j: int, theta: float, phi: float, purity: float):
    """Recolour the sphere surface for a state built from basis indices i and j.

    For ``i != j`` the state vector is ``cos(theta/2)`` on component ``i`` and
    ``exp(1j*phi) * sin(theta/2)`` on component ``j``; its outer product with
    its own conjugate is then blended with the diagonal-only copy of that
    matrix. For ``i == j`` the state is the single basis vector ``i`` and
    ``theta``, ``phi`` and ``purity`` are ignored.

    Args:
        i: Index of the first basis component (0 .. dim - 1).
        j: Index of the second basis component (0 .. dim - 1).
        theta: Mixing angle between the two components.
        phi: Relative phase applied to component ``j``.
        purity: Weight of the full outer-product matrix; ``1 - purity`` is the
            weight of its diagonal-only part (1 keeps all off-diagonal terms,
            0 leaves only the diagonal).

    Side effects:
        Assigns the result of ``Su2Group(dim).wigner_transform(thetas, phis,
        rho)`` to ``fig.data[0].surfacecolor``. Returns nothing.
    """
    state = np.zeros((dim,), dtype=complex)

    if i == j:
        # Single basis state: the density matrix is used as-is.
        state[i] = 1
        rho = np.outer(state, state.conj())
    else:
        half_angle = theta / 2
        state[i] = np.cos(half_angle)
        state[j] = np.exp(1j * phi) * np.sin(half_angle)

        coherent = np.outer(state, state.conj())
        # Same matrix with every off-diagonal entry zeroed.
        dephased = np.diag(np.diag(coherent))
        rho = purity * coherent + (1 - purity) * dephased

    fig.data[0].surfacecolor = Su2Group(dim).wigner_transform(thetas, phis, rho)


# Show the figure widget; the `transition` callback above updates its
# surface colour in place whenever a slider changes.
display(fig)