# Visualisation

Visualising data is a core skill of any scientist. It enables you to explore and understand your data and calculations, and to tell a story about the data to your audience. Apart from static plots, it can also be very useful to use animations or interactive plots to present and explore your data.

## Matplotlib

There are many excellent plotting libraries for Python, several of which produce high-quality interactive visualisations. Prominent examples are [Plotly](https://plotly.com/python/), [Bokeh](https://bokeh.org), [Vega-Altair](https://altair-viz.github.io) and [seaborn](https://seaborn.pydata.org). It is fun to play around with these libraries and get inspiration for how to present data in clever, enlightening ways.

Good old [Matplotlib](https://matplotlib.org) is, however, still the "industry standard" and go-to plotting tool for most scientists working in Python.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import interact, interactive
from matplotlib.animation import ArtistAnimation
from matplotlib import colormaps
from IPython.display import HTML

Let's explore the cavity-QED system from class 4:

In [None]:
import qutip as q

# Composite atom ⊗ cavity space: a two-level atom tensored with an
# N-level truncated Fock space for the cavity mode.
N = 50

# Atomic operators act on the first tensor factor, identity on the cavity.
sz, sm, sp = (q.tensor(atom_op, q.identity(N))
              for atom_op in (q.sigmaz(), q.sigmam(), q.sigmap()))

# Cavity ladder operators act on the second factor, identity on the atom.
a, ad = (q.tensor(q.identity(2), field_op)
         for field_op in (q.destroy(N), q.create(N)))

pi2 = 2*np.pi  # 2π; solve_qed uses it to turn cyclic frequencies into angular ones

def solve_qed(om_a, om_c, Om, gamma=0, alpha=0, t=None, return_states=True):
    """Solve the cavity QED system using qutip.mesolve.

    The Hamiltonian is built from the module-level operators ``sz``, ``sm``,
    ``sp``, ``a``, ``ad`` and the Fock-space truncation ``N``; these globals
    must be defined with matching dimensions before calling.

    Parameters:
    om_a  : atomic transition frequency (cyclic; multiplied by 2π internally)
    om_c  : cavity resonance frequency (cyclic; multiplied by 2π internally)
    Om    : vacuum Rabi frequency (coupling strength; multiplied by 2π internally)
    gamma : atomic decay rate (multiplied by 2π internally); values <= 0
            mean no decay (no collapse operator)
    alpha : amplitude of the initial coherent state of the cavity
    t     : array of time values at which to evaluate the state evolution;
            defaults to np.linspace(0, 100, 200)
    return_states : if True, the result stores the state at every time in
            ``t``; if False, it stores only ``result.expect``, where index 0
            is <sp*sm> (atomic population) and index 1 is <ad*a> (cavity
            photon number)

    Returns:
    The result object returned by ``q.mesolve``.
    """
    if t is None:
        # Build the default grid per call. An array default in the signature
        # would be a single object shared by every call, and mesolve may hand
        # it back to the caller as ``result.times``.
        t = np.linspace(0, 100, 200)

    # Work in angular frequencies internally.
    om_a = pi2 * om_a
    om_c = pi2 * om_c
    Om = pi2 * Om
    gamma = pi2 * gamma
    # Jaynes-Cummings Hamiltonian: atom + cavity mode + exchange coupling.
    H = om_a/2 * sz + om_c * ad*a + Om/2 * (sp * a + sm * ad)
    # Atom starts in basis(2, 0), the +1 eigenstate of sigmaz (the upper level
    # of om_a/2*sz); the cavity starts in the coherent state |alpha>.
    rho0 = q.tensor(q.basis(2,0), q.coherent(N, alpha))

    # With e_ops=None mesolve keeps the states; otherwise it records only the
    # expectation values of the listed operators.
    if return_states:
        e_ops = None
    else:
        e_ops = [sp*sm, ad*a]

    # Spontaneous emission of the atom, only for a positive rate.
    if gamma > 0:
        c_ops = [np.sqrt(gamma) * sm]
    else:
        c_ops = None

    result = q.mesolve(H, rho0, t, c_ops, e_ops)

    return result

In [None]:
# Single run with a coherent cavity field (alpha = 3): plot the atomic
# population <sp*sm> versus time on a fresh set of axes.
result = solve_qed(om_a=1, om_c=1, Om=.05, gamma=0, alpha=3,
                   t=np.linspace(0, 300, 500), return_states=False)
plt.subplots()[1].plot(result.times, result.expect[0])

In [None]:
# Grid of runs: rows sweep the initial coherent amplitude alpha, columns sweep
# the vacuum Rabi frequency Om; each panel shows the atomic population <sp*sm>.
alphas = np.linspace(0, 4, 4)
Oms = [.025, .05, .1]
om_a = 1
om_c = 1
t = np.linspace(0, 200, 200)

fig, axs = plt.subplots(len(alphas), len(Oms),
                        sharex=True, sharey=True,
                        figsize=(9,9), gridspec_kw=dict(wspace=0, hspace=0))
for i, alpha in enumerate(alphas):
    for j, Om in enumerate(Oms):
        result = solve_qed(om_a, om_c, Om, gamma=0, alpha=alpha, t=t, return_states=False)
        axs[i,j].plot(t, result.expect[0])
        # Om is a cyclic frequency (solve_qed multiplies it by 2π), whereas
        # alpha is a dimensionless coherent-state amplitude with no 2π factor.
        axs[i,j].text(.5, .9, f'Ω = 2π × {Om:.3f}\nα = {alpha:.1f}',
                      size=8, ha='left', va='top',
                      bbox=dict(ec=(.5,.5,.5), fc=(.8,.8,1,.8)), transform=axs[i,j].transAxes)
        # Shared x axis: only the bottom row gets a label.
        if i==len(alphas)-1:
            axs[i,j].set_xlabel('t')
    axs[i,0].set_ylabel('population')


### Polishing plots

Matplotlib plots don't look amazing out of the box, but there are a million ways to customise their look.

Below is an example from my own work: a 3D Wigner-function plot that involves extensive customisation of the plot type, view angle, colormap, labels, etc.

A great resource for getting inspiration on how to pretty up your Matplotlib plots is the [Python Graph Gallery](https://python-graph-gallery.com/), which also has instructions for other plotting libraries. The tutorials in its [BEST](https://python-graph-gallery.com/best-python-chart-examples/) section are especially excellent!

In [None]:
# Apply the annihilation operator to a squeezed vacuum S(0.3)|0>, then
# renormalise, since applying `a` does not preserve the norm.
N = 30
s = q.destroy(N) * q.squeeze(N, .3) * q.basis(N)
s = s.unit()
# Sample the state's Wigner function on a square 121 x 121 phase-space grid.
xvec = np.linspace(-4, 4, 121)
wig = q.wigner(s, xvec, xvec)

In [None]:
# Cartesian ('xy') coordinate grids for the 3D plots: X varies along columns, P along rows.
X, P = np.meshgrid(xvec, xvec, indexing='xy')

In [None]:
# From: https://matplotlib.org/stable/gallery/mplot3d/pathpatch3d.html
from matplotlib.patches import Circle, PathPatch
from matplotlib.text import TextPath
from matplotlib.transforms import Affine2D
import mpl_toolkits.mplot3d.art3d as art3d

def text3d(ax, xyz, s, zdir="z", size=None, angle=0, flip=False, usetex=False, **kwargs):
    """
    Plot the string *s* as a flat text patch on the 3D Axes *ax*.

    The text is placed at position *xyz*, with font size *size* and rotation
    *angle* (radians, as taken by ``Affine2D.rotate``). *zdir* ("x", "y", or
    anything else for "z") gives the axis which is treated as the third
    dimension; the text lies in a plane of constant value along that axis.
    *flip* mirrors the text along its first in-plane axis. *usetex* is a
    boolean indicating whether the string should be run through a LaTeX
    subprocess or not. Any additional keyword arguments (e.g. ``ec``, ``fc``)
    are forwarded to `.PathPatch`.

    Unlike the Matplotlib gallery version this is adapted from, the glyphs are
    scaled by 1% of the current data span of each in-plane axis. The text size
    is therefore relative to the axis ranges, not absolute data units, so set
    the axis limits before calling.

    Note: zdir affects the interpretation of xyz.

    Returns the PathPatch that was added to *ax*.
    """
    x, y, z = xyz
    # Pick the two in-plane coordinates, the out-of-plane position, and the
    # data spans of the two in-plane axes (used to scale the glyphs).
    if zdir == "y":
        xy1, z1 = (x, z), y
        sx = np.diff(ax.get_xlim())[0]
        sy = np.diff(ax.get_zlim())[0]
    elif zdir == "x":
        xy1, z1 = (y, z), x
        sx = np.diff(ax.get_ylim())[0]
        sy = np.diff(ax.get_zlim())[0]
    else:
        xy1, z1 = (x, y), z
        sx = np.diff(ax.get_xlim())[0]
        sy = np.diff(ax.get_ylim())[0]

    if flip:
        # A negative horizontal scale mirrors the text.
        sx = -sx

    text_path = TextPath((0, 0), s, size=size, usetex=usetex)
    # Scale to 1% of the axis spans, rotate, then move to the in-plane position.
    trans = Affine2D().scale(.01*sx, .01*sy).rotate(angle).translate(xy1[0], xy1[1])

    p1 = PathPatch(trans.transform_path(text_path), **kwargs)
    ax.add_patch(p1)
    # Lift the 2D patch into 3D at position z1 along the zdir axis.
    art3d.pathpatch_2d_to_3d(p1, z=z1, zdir=zdir)
    return p1

In [None]:
# 3D Wigner-function plot: colour-mapped surface plus a translucent wireframe
# over the (X, P) grid, filled contours on the floor of the box, and
# hand-placed tick labels drawn with text3d (the default ticks are hidden).
fig, ax = plt.subplots(figsize=(8,8),
    subplot_kw={"projection": "3d", 'proj_type':'persp', 
                                   'elev':13, 'azim':0, 'focal_length':.2})
# Fix the limits first: text3d scales its glyphs by the current axis spans,
# so these must be set before any text3d call below.
ax.set_xlim(-4,4)
ax.set_ylim(-4,4)
ax.set_zlim(-.34,.34)

# Hide the default tick labels, tick marks and axis lines ('#0000' is fully
# transparent); custom labels are drawn with text3d further down.
ax.set_xticklabels([])
ax.xaxis.set_tick_params(color='#0000')
ax.xaxis.line.set_color('#0000')
ax.set_yticklabels([])
ax.yaxis.set_tick_params(color='#0000')
ax.yaxis.line.set_color('#0000')
ax.set_zticklabels([])
ax.zaxis.set_tick_params(color='#0000')
ax.zaxis.line.set_color('#0000')

# plot_trisurf takes flat point lists rather than 2D grids.
x, p = X.flatten(), P.flatten()
w = wig.flatten()

from matplotlib.colors import LinearSegmentedColormap
# Colour range is ±vlim, so W = 0 maps to the middle of the colormap.
vlim = .3
# Diverging colormap: blue for negative W, red for positive W, fading to
# nearly transparent white around zero (8-digit hex codes carry an alpha).
c2 = [(0.,'#0047b2'), (0.18,'#0c50b7'), (0.26,'#2765c2'), (0.34,'#5889d3dd'), \
      (0.42,'#a2bee888'), (0.49,'#ffffff22'), (0.51,'#ffffff22'), (0.6,'#e8a2a288'), \
      (0.7,'#d35858dd'), (0.8,'#c22727'), (0.9,'#b70c0c'), (1.,'#b20000')]
cmwig1 = LinearSegmentedColormap.from_list('cmwig1',c2) 

# Filled contours projected onto the floor of the box (offset equals the
# lower z limit), drawn with a high zorder.
cont = ax.contourf(X, P, wig, levels=np.arange(-.34,.34,.02), cmap=cmwig1, vmin=-vlim, vmax=vlim,
                   zdir='z', offset=-.34, alpha=.8, zorder=99)

# Colour-mapped surface with the same colour normalisation as the contours.
surf = ax.plot_trisurf(x, p, w, cmap=cmwig1, vmin=-vlim, vmax=vlim, 
                   linewidth=0, antialiased=True)

# Sparse, translucent grey mesh (every 4th row/column) over the same data.
wire = ax.plot_wireframe(X, P, wig, rstride=4, cstride=4, color='#555',
                   linewidth=1, alpha=.3)

# All labels below use zdir='x', i.e. they lie in planes of constant x.
# The offsets (-4.2, 1.05*..., -.12, -.02) appear to be hand-tuned for this view.
# z-axis value labels, placed at x = 3.5 along the y = -4.2 edge.
for zval in [-.3,-.2,-.1,0,.1,.2,.3]:
    text3d(ax, (3.5, -4.2, zval+.00), f'{zval:.2f}', zdir='x', size=4.5, flip=False,
       ec='none', fc='k')
    
# Value labels for the P (y) direction, placed at x = 4 just below z = 0.
for yval in [-4,-2,0,2,4]:
    text3d(ax, (4, 1.05*yval-.12, -.02), f'{yval}', zdir='x', size=5, flip=False,
       ec='none', fc='k')

# Value labels for the X (x) direction, placed along y ≈ 4.08.
for xval in [-4,-2,0,2,4]:
    text3d(ax, (xval, 1.05*4-.12, -.02), f'{xval}', zdir='x', size=5, flip=False,
       ec='none', fc='k')

# Axis titles: 'p' for the y direction, 'x' for the x direction, W(x,p) on top of the z labels.
text3d(ax, (5, -.12, -.02), 'p', zdir='x', size=5, ec='none', fc='k')
text3d(ax, (0, 4.5, -.02), 'x', zdir='x', size=5, ec='none', fc='k')
text3d(ax, (3.5, -4.3, .35), 'W(x,p)', zdir='x', size=4.5, ec='none', fc='k')

## Widgets

Using [ipywidgets](https://ipywidgets.readthedocs.io/en/latest/index.html), it is really easy to add interactivity to your plots or other notebook outputs, as illustrated here with a small [SymPy](https://www.sympy.org/en/index.html) symbolic calculation ([source](https://github.com/jupyter-widgets/ipywidgets/blob/c579fcd1265af77a0d793aa47a7b9a401d952550/docs/source/examples/Factoring.ipynb)):

In [None]:
from sympy import Symbol, Eq, factor

x = Symbol('x')

def factorit(n=4):
    """Return the SymPy equation x**n - 1 == factor(x**n - 1)."""
    polynomial = x**n - 1
    return Eq(polynomial, factor(polynomial))

factorit(12)

In [None]:
# Integer slider (initial value 4, range 2..40), reused as a widget for interact below.
slider = widgets.IntSlider(value=4, min=2, max=40)
slider

In [None]:
interact(n=slider)(factorit);  # bind the existing slider to factorit's `n`

In [None]:
interact(n=(2, 40, 2))(factorit);  # (min, max, step) tuple: interact builds the slider itself

In [None]:
# solve_qed reads sz, sm, sp, a, ad and N as globals. N was reset to 30 for
# the Wigner-function example above, so rebuild the operators to match it.
sz = q.tensor(q.sigmaz(), q.identity(N))
sm = q.tensor(q.sigmam(), q.identity(N))
sp = q.tensor(q.sigmap(), q.identity(N))
a = q.tensor(q.identity(2), q.destroy(N))
ad = q.tensor(q.identity(2), q.create(N))
pi2 = 2*np.pi

# Passing continuous_update=False as a plain interact() keyword never reaches
# the sliders: interact only uses keywords that match parameters of the
# function and silently drops the rest. Build the sliders explicitly so each
# (slow) mesolve run starts only when a slider is released. Explicit widgets
# ignore the function defaults, so their initial values repeat them.
@interact(om_a=widgets.FloatSlider(value=1.0, min=0.9, max=1.1, step=.01, continuous_update=False),
          om_c=widgets.FloatSlider(value=1.0, min=0.9, max=1.1, step=.01, continuous_update=False),
          Om=widgets.FloatSlider(value=.1, min=0., max=.5, step=.01, continuous_update=False),
          gamma=widgets.FloatSlider(value=0.0, min=0., max=.1, step=.002, continuous_update=False),
          max_t=widgets.IntSlider(value=100, min=10, max=500, step=10, continuous_update=False),
          plot_field=True)
def plot_qed(om_a=1, om_c=1, Om=.1, gamma=0, max_t=100, plot_field=True):
    """Plot the atomic population (and optionally the cavity photon number) vs time.

    Frequencies and the decay rate are in the cyclic units expected by
    solve_qed. The cavity starts in the vacuum (alpha = 0), and the evolution
    is sampled at 200 points on [0, max_t].
    """
    alpha = 0

    t = np.linspace(0, max_t, 200)
    # Let mesolve record the expectation values directly instead of storing
    # every state and evaluating it afterwards:
    # expect[0] = <sp*sm> (atom), expect[1] = <ad*a> (cavity photons).
    result = solve_qed(om_a, om_c, Om, gamma, alpha, t, return_states=False)
    fig, ax = plt.subplots()
    ax.plot(result.times, result.expect[0], label='atom population')
    if plot_field:
        ax.plot(result.times, result.expect[1], label='cavity photon number')
    ax.set_xlabel('t')
    ax.legend()

In [None]:
# Direct call without widgets: om_a = om_c (resonance), Om = 0.1, weak atomic decay.
plot_qed(om_a=1, om_c=1, Om=.1, gamma=.01)

In [None]:
# Wide frequency ranges (0..5) for exploring detuning between atom and cavity.
# continuous_update=False only takes effect on explicit widgets: as a plain
# interact() keyword it matches no parameter of plot_qed and is silently
# ignored. Explicit widgets ignore function defaults, so their values repeat
# them. The trailing semicolon suppresses the echo of the returned function.
interact(plot_qed,
         om_a=widgets.FloatSlider(value=1.0, min=0., max=5., step=.1, continuous_update=False),
         om_c=widgets.FloatSlider(value=1.0, min=0., max=5., step=.1, continuous_update=False),
         Om=widgets.FloatSlider(value=.1, min=0., max=.5, step=.01, continuous_update=False),
         gamma=widgets.FloatSlider(value=0.0, min=0., max=.1, step=.002, continuous_update=False),
         max_t=widgets.IntSlider(value=100, min=10, max=500, step=10, continuous_update=False),
         plot_field=True);

Next: 
* contour plot of Wigner
* imshow of rho
* animation of 3D Wigner
* exercise: combine widgets and qutip to show effect of squeezing, displacement, photon addition on vacuum


## Animated plot

In [None]:
# Free evolution of a single oscillator mode, H = a†a (unit frequency),
# starting from the coherent state |alpha = 2>. tlist covers two oscillation
# periods (0..4π) with numpy's default 50 samples. The stored states drive
# the Wigner-function animation below.
N = 30
H = q.num(N)
rho0 = q.coherent(N, 2)
tlist = np.linspace(0,4*np.pi)
sim_out = q.mesolve(H, rho0, tlist)

In [None]:
# Animate the Wigner function of every stored state, one filled-contour
# plot per frame.
xvec = np.linspace(-5,5,100)
ims = []
fig, ax = plt.subplots()
ax.set_aspect(1)

for s in sim_out.states:
    W = q.wigner(s, xvec, xvec)
    # Fixed colour limits of ±1/π (the largest magnitude of a Wigner function
    # in qutip's default convention) keep the colour scale identical across frames.
    im = ax.contourf(xvec, xvec, W,
                     vmin=-1/np.pi, vmax=1/np.pi, cmap=colormaps['RdBu_r'])
    ims.append([im])

ani = ArtistAnimation(fig, ims, interval=50, blit=True)
# Close the figure so the notebook does not also render it as a static image
# under the video; the animation keeps its own reference to the figure.
plt.close(fig)
HTML(ani.to_html5_video())

In [None]:
# Show the raw HTML string that HTML(...) rendered above: a <video> tag with the movie embedded.
ani.to_html5_video(embed_limit=None)

In [None]:
# Write the animation to disk; the movie writer is chosen from Matplotlib's rcParams.
ani.save(filename='out.mp4')