# **CS 1810 Homework 5**
---
To account for potential version issues, try the following in your terminal:

1. Create a new environment with `python3 -m venv venv`
2. Activate that environment with `source venv/bin/activate`
3. Make sure the interpreter in the top right corner of VSCode (or whatever you use to run your code) is `venv`.
4. If you get an "install kernel" prompt, click it.
5. Run `pip install -r requirements.txt`
6. Run the remainder of this notebook.

Note that this is not necessary but can help prevent any issues due to package versions.

**The following notebook is meant to help you work through Problems 2 and 3 on Homework 5. You are by no means required to use it, nor are you required to fill out/use any of the boilerplate code/functions. You are welcome to implement the functions however you wish.**


## Problem 2

#### Initialize data and parameters

Consider a specific example where we have $K = 3$ component Gamma distributions. Let's set the initial parameter values $\theta_k$ and $\beta_k$ as follows, for $k = 1, \dots, K$:
$$
\begin{align*}
  \theta_k &=  1/K, \\
  \beta_k & = k/K.
\end{align*}
$$

Note that we usually initialize $\theta$ and $\beta_k$ randomly. However, by fixing the initial $\theta$ and $\beta_k$, EM becomes deterministic, which makes debugging (and grading) easier.



In [1]:
import torch
import torch.distributions as ds
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [2]:
# Load in the data
# The loaded object is reshaped into a single column, i.e. shape (N, 1).
# NOTE(review): depending on the torch version / `weights_only` setting,
# torch.load may unpickle arbitrary objects -- only load a data.pt you trust.
x = torch.load('data.pt').reshape((-1, 1))



# NOTE(review): `theta` and `betas` are not defined anywhere above this cell,
# so the last two lines of the optional block below would fail if uncommented
# here -- confirm where they are meant to run.
# # uncomment to use numpy (optional)
# import numpy as np
# from scipy.stats import gamma
# x = x.numpy()
# theta = theta.numpy()
# betas = betas.numpy()

### Part 1

In [3]:
# YOUR CODE HERE
# Collapse the (N, 1) column produced at load time into a 1-D tensor; the EM
# functions below re-add the trailing axis themselves via `x.unsqueeze(1)`.
x = x.flatten()

### Part 2

#### **Todo:** implement the E-step

In [4]:
# Module-level constant passed as the first (concentration) argument of every
# `ds.Gamma(alpha, betas)` call below; it is never updated by EM.
alpha = 5.0

In [5]:
def e_step(theta, betas):
    """E-step: responsibilities of each mixture component for each data point.

    Reads the module-level data `x` and the module-level `alpha`.

    Args:
        theta: mixture weights, one entry per component.
        betas: second (rate) argument of `ds.Gamma`, one entry per component.

    Returns:
        Tensor `q` whose rows are normalised over dim=1 (each row sums to 1).
        Assumes `x` is 1-D of length N and `theta`/`betas` have length K, in
        which case `q` is (N, K) -- TODO confirm against the caller.
    """
    component = ds.Gamma(alpha, betas)
    # log p(x_n | z_n = k) + log theta_k, broadcast across components
    log_joint = component.log_prob(x.unsqueeze(1)) + torch.log(theta)
    # normalise in log space so tiny densities do not underflow
    log_norm = torch.logsumexp(log_joint, dim=1, keepdim=True)
    return torch.exp(log_joint - log_norm)

#### **Todo:** implement the M-step

In [6]:
def m_step(q):
    """M-step: re-estimate mixture weights and Gamma rates from responsibilities.

    Reads the module-level data `x` and the module-level `alpha`.

    Args:
        q: responsibilities as returned by `e_step`; dim 0 indexes data
            points and dim 1 indexes components.

    Returns:
        (theta_new, betas_new):
            theta_new[k] = sum_n q[n, k] / N
            betas_new[k] = alpha * sum_n q[n, k] / sum_n q[n, k] * x[n]
    """
    # total responsibility received by each component
    resp_mass = q.sum(dim=0)
    theta_new = resp_mass / q.shape[0]
    # responsibility-weighted sum of the observations, per component
    weighted_x = (q * x.unsqueeze(1)).sum(dim=0)
    betas_new = alpha * resp_mass / weighted_x
    return theta_new, betas_new

#### **Todo:** implement log likelihood

In [7]:
def log_px(x, theta, betas):
    """Per-point log density of the Gamma mixture.

    Reads the module-level `alpha` as the shared first argument of `ds.Gamma`.

    Args:
        x: tensor of evaluation points; it is flattened, so any shape works.
        theta: mixture weights, one entry per component.
        betas: second (rate) argument of `ds.Gamma`, one entry per component.

    Returns:
        1-D tensor with one log-density value per element of `x`.
    """
    # column of points so log_prob broadcasts against the per-component betas
    points = x.flatten().unsqueeze(1)
    per_component = ds.Gamma(alpha, betas).log_prob(points) + torch.log(theta)
    return torch.logsumexp(per_component, dim=1)

def log_likelihood(theta, betas):
    """Total log likelihood of the module-level data `x` under the mixture.

    Args:
        theta: mixture weights, one entry per component.
        betas: per-component rate parameters forwarded to `log_px`.

    Returns:
        Python float: the sum of `log_px` over all data points.
    """
    per_point = log_px(x, theta, betas)
    return per_point.sum().item()

#### **Todo:** implement EM algorithm

In [8]:
def run_em(theta, betas, iterations=1000, verbose=True):
    """Run a fixed number of EM iterations starting from the given parameters.

    Args:
        theta: initial mixture weights (cloned, so the caller's tensor is
            not modified).
        betas: initial per-component rates (cloned likewise).
        iterations: number of E-step / M-step rounds to perform.
        verbose: when true, print the log likelihood after every round.

    Returns:
        (theta, betas) after the final M-step.
    """
    theta, betas = theta.clone(), betas.clone()
    for i in range(iterations):
        theta, betas = m_step(e_step(theta, betas))
        if not verbose:
            continue
        # the likelihood is only evaluated when it is going to be printed
        ll = log_likelihood(theta, betas)
        print(f'iter {i:4d}, log-likelihood = {ll:.6f}')
    return theta, betas

### Part 3

In [9]:
def make_overlay_plot(theta, betas):
    """Overlay the fitted mixture density on a histogram of the data.

    Reads the module-level data `x`. Creates a new figure and leaves it as
    the current pyplot figure; nothing is returned.

    Args:
        theta: fitted mixture weights.
        betas: fitted per-component rates.
    """
    # evaluation grid from 0.01 up to the largest observation
    grid = torch.linspace(0.01, x.max(), 1000)
    density = log_px(grid.unsqueeze(-1), theta, betas).exp()
    # density = np.exp(log_px(grid.unsqueeze(-1), theta, betas))  # use this line for numpy
    ll = log_likelihood(theta, betas)

    fig, ax = plt.subplots(figsize=(5, 3))
    # leave headroom above the axes for the three-line suptitle
    fig.subplots_adjust(top=0.7)
    fig.suptitle(f'theta = {theta}\nbeta = {betas}\nlog likelihood = {ll:.3e}')

    # NOTE: `alpha=0.5` here is matplotlib's opacity, not the Gamma parameter
    ax.hist(x, bins=100, color='tomato', alpha=0.5, density=True, label='Dataset')
    ax.plot(grid, density, color='royalblue', label='Gamma mixture')

    ax.set_title(f'Dataset and Gamma mixture (K={len(theta)})')
    ax.set_ylabel('Density')
    ax.set_xlabel('Recovery time (hours)')
    ax.legend()

In [12]:
import os

alpha = 5.0

# savefig does not create missing directories, so make sure the output
# folder exists before the first figure is written.
os.makedirs('img_output', exist_ok=True)

# Fit and plot mixtures with K = 1, 2, 3, 4 components, using the fixed
# initialisation theta_k = 1/K and beta_k = k/K for k = 1..K.
for K in range(1, 5):
    theta0 = torch.ones(K) / K
    betas0 = (torch.arange(K) + 1) / K
    theta, betas = run_em(theta0, betas0, verbose=False)
    make_overlay_plot(theta, betas)
    # saves the current pyplot figure, i.e. the one make_overlay_plot just created
    plt.savefig(f'img_output/p2_3_{K}mixtures.pdf', bbox_inches='tight')

---
## Problem 3

#### Initialize data:

In [13]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

# NOTE: from here on `x` is rebound to the MNIST pixels, replacing any `x`
# defined earlier in the notebook.
mnist_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)  # download MNIST
N = 6000  # number of training images kept

x = mnist_trainset.data[:N]  # select N datapoints
x = x.flatten(1)             # flatten the images (every dimension after the first is merged)
x = x.float()                # convert pixels from uint8 to float
# x = x.numpy()              # uncomment to use numpy (optional)

100%|██████████| 9.91M/9.91M [00:00<00:00, 19.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.75MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.3MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 3.00MB/s]


#### **Todo:** implement PCA

*Hint: see `.linalg.svd()`*

In [14]:
def pca(x, n_comps=500):
    """Principal component analysis via SVD of the mean-centred data.

    Args:
        x: 2-D tensor with one sample per row.
        n_comps: maximum number of leading components to return.

    Returns:
        (eigvals, pcomps): the leading `n_comps` values of S**2 / (N - 1),
        where S are the singular values of the centred data and N is the
        number of rows, and the matching rows of Vh (one component per row).
    """
    centered = x - x.mean(dim=0)
    # only the singular values and right singular vectors are needed
    _, sing_vals, right_vecs = torch.linalg.svd(centered, full_matrices=False)
    variances = sing_vals ** 2 / (x.shape[0] - 1)
    return variances[:n_comps], right_vecs[:n_comps]

#### **Todo:** calculate cumulative fraction of variance

*Hint: see `.cumsum()`*

In [15]:
def calc_cfvs(eigvals):
    """Cumulative fraction of variance for each prefix of `eigvals`.

    Args:
        eigvals: 1-D tensor of eigenvalues.

    Returns:
        Tensor of the same length whose i-th entry is
        sum(eigvals[:i + 1]) / sum(eigvals). The denominator covers only the
        eigenvalues that were passed in.
    """
    total = eigvals.sum()
    return eigvals.cumsum(dim=0) / total


#### **Todo:** calculate mean squared L2 norm reconstruction losses

In [16]:
def calc_errs(x, pcomps, n_pcs=10):
    """Mean squared L2 reconstruction errors for two reconstructions of `x`.

    Args:
        x: 2-D data tensor, one sample per row.
        pcomps: principal components, one per row.
        n_pcs: number of leading components used for the second reconstruction.

    Returns:
        (err_mean, err_pcomp) as Python floats: the error when every sample
        is replaced by the data mean, and the error when each sample is
        reconstructed from the mean plus its projection onto the first
        `n_pcs` components.
    """
    mu = x.mean(dim=0)
    centered = x - mu                       # [N, D]

    # baseline: reconstruct every sample as the mean
    err_mean = (centered ** 2).sum(dim=1).mean().item()

    basis = pcomps[:n_pcs]                  # [n_pcs, D]
    scores = centered @ basis.T             # [N, n_pcs]
    recon = mu + scores @ basis             # [N, D]
    err_pcomp = ((x - recon) ** 2).sum(dim=1).mean().item()
    return err_mean, err_pcomp

#### Plot and print errors:

In [17]:
def plot_pic(pic, ax, title=''):
    """Draw a flattened picture as a 28x28 image on the given axes.

    Args:
        pic: array-like that can be reshaped to (28, 28).
        ax: matplotlib axes to draw on.
        title: optional title shown above the image.
    """
    image = pic.reshape(28, 28)
    ax.set_title(title)
    ax.imshow(image, cmap='binary')
    # hide ticks and frame
    ax.axis('off')

def make_plots(eigvals, cfvs, x_mean, pcomps):
    """Plot and save the PCA summary figures under `img_output/`.

    Writes three PDFs: the eigenvalue and CFV curves (`p3_cfvs.pdf`), the
    mean image (`p3_mean.pdf`), and the first ten principal components
    (`p3_pcomps.pdf`).

    Args:
        eigvals: eigenvalues to plot.
        cfvs: cumulative fractions of variance to plot.
        x_mean: mean image, reshapeable to 28x28.
        pcomps: principal components, one per row; at least ten are required.
    """
    import os  # local import keeps this function self-contained

    # savefig does not create missing directories, so ensure the folder exists
    os.makedirs('img_output', exist_ok=True)

    # plot eigenvals and cfvs
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))
    ax1.plot(eigvals, color='tomato')
    ax1.set_title('Eigenvalues')
    ax2.plot(cfvs, color='tomato')
    ax2.set_title('CFVs')
    fig.savefig('img_output/p3_cfvs.pdf')

    # plot mean
    fig, ax = plt.subplots(1, 1, figsize=(3, 3))
    plot_pic(x_mean, ax, title='Mean')
    fig.savefig('img_output/p3_mean.pdf')

    # plot top 10 pcomps (the 2x5 grid has exactly ten axes)
    fig, axes = plt.subplots(2, 5, figsize=(10, 4))
    for i, ax in enumerate(axes.flat):
        plot_pic(pcomps[i], ax, title=f'PC index {i}')
    fig.savefig('img_output/p3_pcomps.pdf')

In [18]:
# do PCA
# (uses pca's default n_comps=500, so only the leading 500 components are kept)
eigvals, pcomps = pca(x)

# calculate CFVs
# NOTE(review): `fcvs` looks like a transposition of `cfvs`; it is used
# consistently below, so it works as written.
fcvs = calc_cfvs(eigvals)

# print errors
# NOTE: the format spec `3e` is a minimum field width of 3 with the default
# 6-digit precision, not three decimals -- hence the 6-decimal values shown.
err_mean, err_pcomp = calc_errs(x, pcomps)
print(f'Reconstruction error (using mean): {err_mean:3e}')  # 3.436022e+06
print(f'Reconstruction error (using mean and top 10 pcomps): {err_pcomp:3e}')  # 1.731315e+06

# make plots
make_plots(eigvals, fcvs, x.mean(0), pcomps)


Reconstruction error (using mean): 3.436024e+06
Reconstruction error (using mean and top 10 pcomps): 1.731315e+06
