<a href="https://colab.research.google.com/github/harvard-ml-courses/a-cs281-demo/blob/master/03_KLandMLE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 03 - KL and MLE

In [0]:
# This line will be at the top of all notebooks to initialize course material.
# Quietly (-q) install or upgrade (-U) plotly and torch in the runtime.
!pip install -qU plotly torch
# Remove any previous `start` checkout, clone only the `demos2018` branch of the
# course demo repository into `start`, then copy its cs281.py helper module next
# to the notebook so that the later `import cs281` resolves.
!rm -fr start; git clone --single-branch -b demos2018 -q https://github.com/harvard-ml-courses/cs281-demos start; cp -f start/cs281.py cs281.py

tcmalloc: large alloc 1073750016 bytes == 0x5981c000 @  0x7f447104f1c4 0x46d6a4 0x5fcbcc 0x4c494d 0x54f3c4 0x553aaf 0x54e4c8 0x54f4f6 0x553aaf 0x54efc1 0x54f24d 0x553aaf 0x54efc1 0x54f24d 0x553aaf 0x54efc1 0x54f24d 0x551ee0 0x54e4c8 0x54f4f6 0x553aaf 0x54efc1 0x54f24d 0x551ee0 0x54efc1 0x54f24d 0x551ee0 0x54e4c8 0x54f4f6 0x553aaf 0x54e4c8


In [0]:
# Standard includes.
import torch
import torch.distributions as ds
from plotly.offline import iplot
import plotly.graph_objs as go
import cs281

## Entropy and KL

In [0]:
# Batching example. Entropy calculation.
# One Bernoulli object built from a vector of means represents a whole batch of
# distributions, so a single .entropy() call returns one entropy per mean.
# `steps` is passed explicitly: recent PyTorch releases require it (100 was the
# former implicit default).
mu = torch.linspace(0, 1, steps=100)
p_x = ds.Bernoulli(mu)
y = p_x.entropy()
# Plot entropy against the Bernoulli mean. The x-axis must be `mu`; the name
# `lambd` is only defined by a later cell and is not available here on a fresh
# kernel.
iplot(cs281.plot(mu, y))

In [0]:
# Entropy of a zero-mean Normal as a function of its scale parameter.
# `lambd` holds a batch of scales, so `p_x` is a batch of Normal distributions
# and .entropy() returns one value per scale.
# `steps` is passed explicitly: recent PyTorch releases require it (100 was the
# former implicit default).
lambd = torch.linspace(0.5, 5, steps=100)
p_x = ds.Normal(0, lambd)
y = p_x.entropy()
iplot(cs281.plot(lambd, y))

In [0]:
#@title KL Gaussians { run: "auto" }

# Reference distribution: Normal with mean 0 and scale 1.
p = ds.Normal(0, 1)

# Parameters of the second Gaussian; the trailing #@param annotations expose
# them as sliders in Colab.
mu = 2 #@param {type:"slider", min:-10, max:10, step:1}
sigma = 5 #@param {type:"slider", min:1, max:10, step:1}
q = ds.Normal(mu, sigma)
# Evaluation grid for the density curves. `steps` is passed explicitly because
# recent PyTorch releases require it (100 was the former implicit default).
x = torch.linspace(-10, 10, steps=100)
# log_prob(x).exp() gives the density values; both curves go on one figure.
iplot(cs281.plot(x, p.log_prob(x).exp()) + cs281.plot(x, q.log_prob(x).exp()))
# KL is not symmetric, so both directions are reported.
print("KL(p || q)", ds.kl_divergence(p, q).item())
print("KL(q || p)", ds.kl_divergence(q, p).item())

KL(p || q) 1.129437804222107
KL(q || p) 10.390562057495117


In [0]:
#@title KL Gaussian - Laplace { run: "auto" }

# Reference distribution: Normal with mean 0 and scale 1.
p = ds.Normal(0, 1)

# Parameters of the Laplace distribution; the trailing #@param annotations
# expose them as sliders in Colab.
location = 1 #@param {type:"slider", min:-10, max:10, step:1}
scale = 5 #@param {type:"slider", min:1, max:10, step:1}
q = ds.Laplace(location, scale)

# Evaluation grid for the density curves. `steps` is passed explicitly because
# recent PyTorch releases require it (100 was the former implicit default).
x = torch.linspace(-10, 10, steps=100)
iplot(cs281.plot(x, p.log_prob(x).exp()) + cs281.plot(x, q.log_prob(x).exp()))
# Only the Laplace-to-Normal direction is printed in this cell.
print("KL(q || p)", ds.kl_divergence(q, p).item())

KL(q || p) 23.11635398864746


In [0]:
# Look at the code implementing this.
# The original cell ended in a dangling `ds.kl.` (a tab-completion stub), which
# is a SyntaxError when the notebook is run top to bottom. Print the source of
# the Normal/Normal KL implementation from torch.distributions.kl instead.
import inspect

kl_source = inspect.getsource(ds.kl._kl_normal_normal)
print(kl_source)

## Fitting with SGD

In [0]:
# Install the colorlover package into the runtime, then import it.
# NOTE(review): this install/import sits mid-notebook rather than in the setup
# and imports cells at the top — consider moving it there.
!pip install colorlover
import colorlover as cl
# Look up the 9-colour sequential 'BuPu' scale; later cells interpolate it to
# colour successive optimisation steps.
bupu = cl.scales['9']['seq']['BuPu']



In [0]:
# Fit a Normal q(mu, sigma) to a fixed Laplace target p by minimising
# KL(p || q) with plain SGD, recording q's density before every update.
mu = torch.tensor([1.], requires_grad=True)
sigma = torch.tensor([2.], requires_grad=True)

opt = torch.optim.SGD([mu, sigma], lr=0.7)

p = ds.Laplace(0, 1)
plots = []

# Evaluation grid for the density curves. It is defined before the loop because
# the loop body reads `x`; previously the assignment came after the loop, so the
# cell only worked when an earlier cell had already left an `x` behind.
# `steps` is passed explicitly because recent PyTorch releases require it
# (100 was the former implicit default).
x = torch.linspace(-10, 10, steps=100)

# Run SGD
n_steps = 10
# One colour per step, interpolated from the BuPu scale.
bupu10 = cl.to_rgb(cl.interp(bupu, n_steps))
for i in range(n_steps):
    opt.zero_grad()
    # Rebuild q from the current parameters so the loss tracks their gradients.
    q = ds.Normal(mu, sigma)
    # Record the density of q *before* this step's parameter update.
    plots += cs281.plot(x, q.log_prob(x).exp(),
                        marker=dict(color=bupu10[i]))

    loss = ds.kl_divergence(p, q)
    loss.backward()
    opt.step()

# Target density first, then the sequence of fitted densities.
iplot(cs281.plot(x, p.log_prob(x).exp()) + plots)


In [0]:
# NOTE(review): constructs an empty plotly Histogram trace and displays its
# repr; nothing else in the visible notebook uses the result. Looks like a
# leftover scratch cell — confirm before removing.
go.Histogram()

Histogram()

In [0]:
# Synthetic two-class dataset: each class is a 2-D Gaussian with a diagonal
# covariance matrix. NOTE(review): no random seed is set, so the sampled points
# differ from run to run.
N_PER_CLASS = 500  # samples drawn from each class

class0 = ds.MultivariateNormal(torch.tensor([0., 0]),
                               torch.diag(torch.tensor([2., 1])))
class1 = ds.MultivariateNormal(torch.tensor([2., 2]),
                               torch.diag(torch.tensor([1., 2])))
x_0 = class0.sample(torch.Size([N_PER_CLASS]))
x_1 = class1.sample(torch.Size([N_PER_CLASS]))
# `steps` is passed explicitly because recent PyTorch releases require it
# (100 was the former implicit default).
x = torch.linspace(-10, 10, steps=100)
# One scatter trace per class: first coordinate on x, second on y.
base = [go.Scatter(x=x_0[:, 0].numpy(), y=x_0[:, 1].numpy(), mode="markers"),
        go.Scatter(x=x_1[:, 0].numpy(), y=x_1[:, 1].numpy(), mode="markers")
       ]
iplot(base)
# Stack both classes into one design matrix; labels are 0 for the class0 rows
# followed by 1 for the class1 rows, matching the concatenation order.
X = torch.cat([x_0, x_1])
y = torch.tensor([0] * N_PER_CLASS + [1] * N_PER_CLASS)

In [0]:
# Logistic regression fitted by full-batch SGD on the Bernoulli negative
# log-likelihood, recording the decision line before every update.
w = torch.tensor([-1., 1.], requires_grad=True)
b = torch.tensor([2.], requires_grad=True)
opt = torch.optim.SGD([w, b], lr=0.5)
plots = []

# Horizontal grid on which the decision line is drawn. `steps` is passed
# explicitly because recent PyTorch releases require it (100 was the former
# implicit default).
x = torch.linspace(-10, 10, steps=100)
n_epochs = 100
# One colour per epoch, interpolated from the BuPu scale.
bupu100 = cl.to_rgb(cl.interp(bupu, n_epochs))
for epoch in range(n_epochs):
    # Decision line w[0]*x1 + w[1]*x2 + b = 0, solved for x2, using the
    # parameters as they stand before this epoch's update.
    plots += cs281.plot(x, -(x * w[0] + b) / w[1],
                        marker=dict(color=bupu100[epoch]))

    # Main loop
    opt.zero_grad()
    # Predicted class-1 probability for every row of X.
    q = ds.Bernoulli(torch.sigmoid(w @ X.t() + b))
    # Mean negative log-likelihood of the labels under q.
    loss = -q.log_prob(y.float()).mean()
    loss.backward()
    opt.step()

# Overlay every 5th recorded entry of `plots` on the data scatter to keep the
# figure readable.
iplot(base + plots[::5])