In [23]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import multivariate_normal
from matplotlib import cm

%matplotlib notebook

# Visualizing Multivariate Gaussians

A multivariate Gaussian (normal) distribution describes a $d$-dimensional random vector $x$ and is parameterized by a mean vector $\mu \in \mathbb{R}^d$ and a symmetric positive-definite covariance matrix $\Sigma \in \mathbb{R}^{d \times d}$. Its probability density function is defined as follows:

$$\mathcal{N}(x ; \mu, \Sigma) = \frac{1}{(2\pi)^{d/2}|\Sigma|^{1/2}} \exp\left(-\frac{1}{2}(x-\mu)^{T}\Sigma^{-1}(x-\mu)\right)$$

Here $|\Sigma|$ represents the determinant of the covariance matrix $\Sigma$.

Let's plot the density of a few bivariate ($d = 2$) Gaussians to see how $\mu$ and $\Sigma$ change its location and shape.

In [51]:
def multivariate_pdf_grid(mean, cov, xlim=(-3, 3), ylim=(-3, 3), num=50):
    """Evaluate a bivariate Gaussian density on a regular 2-D grid.

    Parameters
    ----------
    mean : array_like, length 2
        Mean vector of the Gaussian.
    cov : array_like, 2 x 2
        Covariance matrix of the Gaussian.
    xlim, ylim : (float, float)
        Start and end of the grid along the first / second coordinate
        (both endpoints included).
    num : int
        Number of grid points along each axis.

    Returns
    -------
    X, Y, Z : ndarray, each of shape (num, num)
        Grid coordinates as returned by ``np.meshgrid`` and the density
        evaluated at every grid point (``Z[i, j]`` is the density at
        ``(X[i, j], Y[i, j])``).
    """
    x = np.linspace(xlim[0], xlim[1], num)
    y = np.linspace(ylim[0], ylim[1], num)
    X, Y = np.meshgrid(x, y)
    # multivariate_normal.pdf expects one point per row, so flatten the grid
    # into a (num*num, 2) array and reshape the densities back afterwards.
    points = np.column_stack([X.ravel(), Y.ravel()])
    Z = np.reshape(multivariate_normal.pdf(points, mean=mean, cov=cov), X.shape)
    return X, Y, Z


def plot_multivariate(mean, cov, xlim=(-3, 3), ylim=(-3, 3), num=50):
    """Draw the density of a bivariate Gaussian as a 3-D surface in a new figure.

    Parameters
    ----------
    mean : array_like, length 2
        Mean vector of the Gaussian.
    cov : array_like, 2 x 2
        Covariance matrix of the Gaussian.
    xlim, ylim : (float, float)
        Plotted range along each axis; defaults match the original [-3, 3] view.
    num : int
        Grid points per axis (50 matches ``np.linspace``'s default).

    Nothing is returned; the figure is created with ``plt.figure()``.
    """
    # Use the `cov` argument -- never a notebook-global variable -- so the
    # function works regardless of what the caller names its covariance.
    X, Y, Z = multivariate_pdf_grid(mean, cov, xlim=xlim, ylim=ylim, num=num)

    fig = plt.figure()
    # Figure.gca() no longer accepts projection kwargs (removed in
    # Matplotlib 3.6); add_subplot is the supported way to get 3-D axes.
    ax = fig.add_subplot(projection='3d')
    ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, rstride=2, cstride=2)
    ax.set_xlabel('$x_1$')
    ax.set_ylabel('$x_2$')
    ax.set_zlabel('density')

In [52]:
# Standard bivariate normal: zero mean and identity covariance, i.e. two
# independent unit-variance coordinates. The surface is a circularly
# symmetric bell whose peak sits at the origin.
covariance = np.eye(2)
mean = [0, 0]

plot_multivariate(mean=mean, cov=covariance)

<IPython.core.display.Javascript object>

In [53]:
# Same identity covariance, but the mean moves to (2, 0): the bell keeps its
# shape and is simply translated along the first axis (the [-3, 3] view
# therefore cuts off part of its right flank).
covariance = np.eye(2)
mean = [2, 0]

plot_multivariate(mean=mean, cov=covariance)

<IPython.core.display.Javascript object>

In [54]:
# Unit variances with off-diagonal 0.8, i.e. correlation +0.8: the density
# is stretched along the diagonal x1 = x2 and squeezed across it.
covariance = [[1.0, 0.8],
              [0.8, 1.0]]
mean = [0, 0]

plot_multivariate(mean=mean, cov=covariance)

<IPython.core.display.Javascript object>

In [55]:
# Correlation -0.5: the density now tilts along the anti-diagonal x1 = -x2,
# and is less elongated than in the +0.8 case.
covariance = [[1.0, -0.5],
              [-0.5, 1.0]]
mean = [0, 0]

plot_multivariate(mean=mean, cov=covariance)

<IPython.core.display.Javascript object>