In [None]:
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np

from mpl_toolkits.mplot3d.axes3d import get_test_data

In [51]:
# 2x2 matrix rho (Hermitian, unit trace, positive definite -- checked below)
# and the quantities derived from it that the later cells use.
rho = np.array([[0.45, 0.2 - 1j*0.3],
                [0.2 + 1j*0.3, 0.55]])

# np.linalg.eigh only reads one triangle of its input, so a non-Hermitian
# edit to rho would be silently ignored; check the assumptions explicitly.
assert np.allclose(rho, rho.conj().T), "rho must be Hermitian"
assert np.isclose(np.trace(rho), 1.0), "rho must have unit trace"

# eigh returns eigenvalues in ascending order (e_val[0] <= e_val[1]);
# column e_vec[:, k] is the normalized eigenvector for e_val[k].
e_val, e_vec = np.linalg.eigh(rho)

# Later cells invert rho, divide by the eigenvalues and by their difference.
assert np.all(e_val > 0), "rho must be positive definite"
assert not np.isclose(e_val[0], e_val[1]), "QQ formula needs distinct eigenvalues"

rho_inv = np.linalg.inv(rho)

In [60]:
# Normalization constant QQ that scales exp(-R/2) in the density plot.
# eigh returns eigenvalues in ascending order, so lam_hi > lam_lo and QQ > 0
# (assumes the two eigenvalues are distinct -- the formula divides by their gap).
# NOTE(review): with R = v^H rho^-1 v, v = (sqrt(1-p), sqrt(p) e^{i phi}),
# integrating exp(-R/2) over p in [0, 1], phi in [0, 2*pi) gives
# 4*pi*lam_lo*lam_hi/(lam_hi-lam_lo)*(...), i.e. twice this value -- confirm
# the intended measure / normalization.
lam_lo, lam_hi = e_val
QQ = 2 * np.pi * lam_lo * lam_hi / (lam_hi - lam_lo) * (np.exp(-0.5 / lam_hi) - np.exp(-0.5 / lam_lo))

In [48]:
# (p, phi) coordinates of the two eigenvectors of rho, in the parametrization
# v = (sqrt(1 - p), sqrt(p) * exp(1j*phi)) used by the density plot.
# Derived from e_vec instead of hard-coded, so they stay in sync if rho is
# edited (the previously hard-coded values were 0.568, 0.983 and 0.432, 4.124).
def _eigvec_to_p_phi(vec):
    """Return (p, phi) for a 2-component vector.

    p   = |v[1]|^2 / ||v||^2
    phi = arg(v[1]) - arg(v[0]), wrapped into [0, 2*pi)
    """
    vec = np.asarray(vec)
    p = np.abs(vec[1])**2 / np.sum(np.abs(vec)**2)
    # The relative phase is invariant to the arbitrary global phase eigh picks.
    phi = np.angle(vec[1] * np.conj(vec[0])) % (2*np.pi)
    return float(p), float(phi)

p_plus, phi_plus = _eigvec_to_p_phi(e_vec[:, 1])    # larger eigenvalue, e_val[1]
p_minus, phi_minus = _eigvec_to_p_phi(e_vec[:, 0])  # smaller eigenvalue, e_val[0]

In [49]:
# Grid indices of the eigenvector coordinates on the (phi, p) plotting grid.
# The axes are rebuilt here with the same ranges/steps as the surface plots
# (p in [0, 1), phi in [0, 2*pi), step 0.01) so this cell does not depend on
# X and Y defined in a later cell -- it now survives Restart & Run All.
# Keep these in sync with the grid in the plotting cell.
p_axis = np.arange(0, 1, 0.01)
phi_axis = np.arange(0, 2*np.pi, 0.01)

# Column indices (p direction) of the meshgrid arrays.
X_plus = np.argmin(np.abs(p_axis - p_plus))
X_minus = np.argmin(np.abs(p_axis - p_minus))

# Row indices (phi direction) of the meshgrid arrays.
Y_plus = np.argmin(np.abs(phi_axis - phi_plus))
Y_minus = np.argmin(np.abs(phi_axis - phi_minus))

In [78]:
# Shape of the (phi, p) meshgrid used by the surface plots: rows follow phi,
# columns follow p. Computed from the grid parameters rather than reading X,
# so the cell runs on a fresh kernel.
grid_shape = np.meshgrid(np.arange(0, 1, 0.01), np.arange(0, 2*np.pi, 0.01))[0].shape
grid_shape

(629, 100)

In [None]:
# Two side-by-side 3D panels over the same (p, phi) grid (figure is 18 x 7 in).
fig = plt.figure(figsize=(18, 7))

# Shared plotting grid: p in [0, 1) along columns, phi in [0, 2*pi) along rows.
X, Y = np.meshgrid(np.arange(0, 1, 0.01), np.arange(0, 2*np.pi, 0.01))

#===============
#  First subplot
#===============
# Zero everywhere except the two grid points nearest (p_minus, phi_minus) and
# (p_plus, phi_plus), which carry the eigenvalues e_val[0] and e_val[1].
ax = fig.add_subplot(1, 2, 1, projection='3d')

Z1 = np.zeros(X.shape)
Z1[Y_minus, X_minus] = e_val[0]
Z1[Y_plus, X_plus] = e_val[1]
surf2 = ax.plot_surface(X, Y, Z1, rstride=1, cstride=1, cmap=cm.coolwarm,
                        linewidth=0, antialiased=False)

ax.set_title(r"eigenvalues of $\rho$", fontsize=15)
ax.set_xlabel("p", fontsize=15)
# Raw string: "\p" is an invalid escape in a normal literal (SyntaxWarning).
ax.set_ylabel(r"$\phi$", fontsize=15)
ax.view_init(30, 20)

#===============
# Second subplot
#===============
ax = fig.add_subplot(1, 2, 2, projection='3d')

# R = v^H rho^-1 v for v = (sqrt(1-p), sqrt(p) e^{i phi}):
#   (1-p) rho_inv[0,0] + p rho_inv[1,1]
#   + sqrt(p(1-p)) (e^{-i phi} rho_inv[1,0] + e^{i phi} rho_inv[0,1])
# The p term previously used rho_inv[0,0] again, which collapsed the
# diagonal part to the constant rho_inv[0,0].
R = ((1 - X) * np.real(rho_inv[0, 0])
     + X * np.real(rho_inv[1, 1])
     + np.sqrt(X * (1 - X))
     * np.real(np.exp(-1j*Y) * rho_inv[1, 0] + np.exp(1j*Y) * rho_inv[0, 1]))
Z = np.exp(-0.5*R) / QQ
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)

ax.set_title(r"$e^{-R/2} / Q$", fontsize=15)
ax.set_xlabel("p", fontsize=15)
ax.set_ylabel(r"$\phi$", fontsize=15)
ax.view_init(30, 20)
fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10)

plt.show()