In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import IPython
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from scipy.stats import norm
from tqdm import trange, tqdm_notebook
import os.path as osp
import warnings
from collections import OrderedDict
warnings.filterwarnings('ignore')

In [None]:
import torch  # NOTE(review): duplicate of the import cell above; harmless

# The last expression is displayed: True when a CUDA device is visible.
torch.cuda.is_available()

# Part 1: KL Divergence

## Ground-Truth Distribution (Mixture of Gaussians)

In [None]:
from torch.distributions.normal import Normal

class MOG(nn.Module):
    """Univariate mixture of Gaussians, used here as the fixed ground-truth density.

    Args:
        weights: mixture weights, one per component (not checked to sum to 1).
        locs: component means.
        scales: component standard deviations.
        torch_device: device on which the tensors are created.
    """

    def __init__(self, weights, locs, scales, torch_device):
        super().__init__()
        # Buffers rather than Parameters: the ground truth is never trained,
        # but registering them lets ``.to(device)`` move them with the module.
        self.register_buffer('weights', torch.tensor(weights, device=torch_device))
        self.register_buffer('locs', torch.tensor(locs, device=torch_device))
        self.register_buffer('scales', torch.tensor(scales, device=torch_device))
        self.n_components = len(self.weights)

    def log_prob(self, x):
        """Log-density of every element of ``x`` (output has the shape of ``x``).

        Computed as logsumexp_k(log w_k + log N(x; mu_k, sigma_k)). Summing
        exponentiated densities and then taking the log underflows to ``-inf``
        for points far from every component. The logsumexp form stays finite.
        """
        # Trailing singleton axis broadcasts x against the (n_components,) params.
        component_log_probs = Normal(self.locs, self.scales).log_prob(x.unsqueeze(-1))
        return torch.logsumexp(component_log_probs + self.weights.log(), dim=-1)

In [None]:
xs = np.linspace(-3, 3, num=1000)
data_distribution = MOG(
    np.array([0.7, 0.3]),    # mixture weights
    np.array([-1, 1]),       # component means
    np.array([0.25, 0.25]),  # component standard deviations
    torch_device='cpu',
)

# Evaluate the ground-truth density on the plotting grid and draw it.
ys = data_distribution.log_prob(torch.tensor(xs)).exp().numpy()
plt.plot(xs, ys)
plt.show()

## Model Distribution (a Single Learnable Gaussian)

In [None]:
class Gaussian(nn.Module):
    """Univariate Gaussian with a learnable mean and a learnable log standard deviation.

    The mean starts at a draw from N(0, 1) and the log-std starts at 0
    (std = 1). Storing the scale in log space keeps it positive under
    unconstrained gradient updates.
    """

    def __init__(self):
        super().__init__()
        self.loc = nn.Parameter(torch.randn(1), requires_grad=True)
        self.log_scale = nn.Parameter(torch.zeros(1), requires_grad=True)

    def log_prob(self, x):
        """Log-density of ``x`` under N(loc, exp(log_scale) ** 2)."""
        std = self.log_scale.exp()
        gaussian = Normal(self.loc, std)
        return gaussian.log_prob(x)

    def nll(self, x):
        """Loss: mean negative log-likelihood of the samples ``x``."""
        log_likelihoods = self.log_prob(x)
        return -log_likelihoods.mean()

In [None]:
# An untrained model, plotted on the same grid as the ground truth.
model_distribution = Gaussian()
with torch.no_grad():
    y_gauss = model_distribution.log_prob(torch.tensor(xs)).exp().numpy()
plt.plot(xs, y_gauss)
plt.show()

## Kullback-Leibler Divergence

The Kullback-Leibler (KL) divergence measures how one probability distribution $P$ differs from a second, reference distribution $Q$. It is the expectation of the logarithmic difference between the probabilities $P$ and $Q$, where the expectation is taken with respect to $P$:
$$D_{KL}[P \| Q] = \mathbb{E}_{x \sim P}\left[\log P(x) - \log Q(x)\right]$$
For discrete probability distributions $P$ and $Q$ defined on the same probability space, the KL divergence is a sum:
$$D_{KL}[P \| Q] = \sum_{i} P(i) \log\frac{P(i)}{Q(i)}$$
For distributions $P$ and $Q$ of a continuous random variable, the KL divergence is defined as an integral:
$$D_{KL}[P \| Q] = \int P(x) \log\frac{P(x)}{Q(x)}\,dx$$
In the code below, this integral is approximated by a Riemann sum over a uniform grid of points.

In [None]:
def kl_divergence(p, q, interval_size=1):
    """Estimate D_KL[P || Q] from probabilities evaluated on a shared grid.

    Args:
        p: tensor of probabilities (or densities) of P at the grid points.
        q: tensor of probabilities (or densities) of Q at the same points.
        interval_size: grid spacing for a Riemann sum; use 1 for a discrete sum.

    Returns:
        Scalar tensor. Points with p == 0 contribute 0 (the usual 0 * log 0 = 0
        convention) instead of producing NaN.
    """
    # xlogy(a, b) = a * log(b), defined as 0 wherever a == 0.
    return (torch.xlogy(p, p) - torch.xlogy(p, q)).sum() * interval_size

def kl_divergence_with_logs(p, q, interval_size=1):
    """Estimate D_KL[P || Q] from log-probabilities evaluated on a shared grid.

    Args:
        p: tensor of log-probabilities (log-densities) of P at the grid points.
        q: tensor of log-probabilities of Q at the same points.
        interval_size: grid spacing for a Riemann sum; use 1 for a discrete sum.

    Returns:
        Scalar tensor. Points where log p == -inf (P has no mass) contribute 0;
        otherwise exp(-inf) * (-inf - q) would evaluate to NaN.
    """
    terms = p.exp() * (p - q)
    terms = terms.masked_fill(torch.isneginf(p), 0.0)
    return terms.sum() * interval_size

def forward_kl(model, data_dist, num_steps=100000, low=-5.0, high=5.0):
    """Forward KL divergence D_KL[data_dist || model] by numerical integration.

    Both log-densities are evaluated on ``num_steps`` evenly spaced points in
    [low, high] and combined with a Riemann sum. The result is therefore an
    approximation that ignores any mass outside that interval.

    Args:
        model: module exposing ``log_prob(x)``; the grid is created on the
            device of its first parameter (CPU if it has none).
        data_dist: target distribution exposing ``log_prob(x)``.
        num_steps: number of grid points (at least 2).
        low: lower integration bound.
        high: upper integration bound.

    Returns:
        Scalar tensor; differentiable w.r.t. ``model``'s parameters when
        ``model.log_prob`` is.

    Raises:
        ValueError: if ``num_steps`` < 2.
    """
    if num_steps < 2:
        raise ValueError(f'num_steps must be at least 2, got {num_steps}')
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    xs = torch.linspace(low, high, steps=num_steps, device=device)
    # linspace includes both endpoints, so neighbouring points are
    # (high - low) / (num_steps - 1) apart, not (high - low) / num_steps.
    interval_size = (high - low) / (num_steps - 1)
    return kl_divergence_with_logs(data_dist.log_prob(xs), model.log_prob(xs),
                                   interval_size=interval_size)

In [None]:
# `ptu` is never imported in this notebook. Instead, place the model on the
# device that holds the target mixture's tensors, so both densities can be
# evaluated on the same grid.
mle_model = Gaussian().to(data_distribution.locs.device)

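## Training
Train the model by minimizing the forward KL divergence $D_{KL}[P_{data} \| Q_\theta]$. This is equivalent to maximum-likelihood estimation (MLE), because $D_{KL}[P_{data} \| Q_\theta] = -H(P_{data}) - \mathbb{E}_{x \sim P_{data}}[\log Q_\theta(x)]$ and the entropy term $H(P_{data})$ does not depend on $\theta$.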

In [None]:
def train_epochs(model, data_distribution, train_args, loss_fn):
    """Fit ``model`` to ``data_distribution`` by minimising ``loss_fn`` with Adam.

    Each epoch is a single optimisation step on the full loss.

    Args:
        model: nn.Module whose parameters are optimised.
        data_distribution: target distribution, passed through to ``loss_fn``.
        train_args: dict with ``'epochs'`` (number of steps) and ``'lr'``
            (Adam learning rate).
        loss_fn: callable ``loss_fn(model, data_distribution)`` returning a
            scalar loss tensor, e.g. ``forward_kl``.

    Returns:
        list of float: the loss value computed at each step, before the update.
    """
    epochs, lr = train_args['epochs'], train_args['lr']
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []
    # Use the arguments, not the globals `mle_model` / `forward_kl`, so the
    # function works for any model and loss.
    model.train()
    for _ in tqdm_notebook(range(epochs), desc='Epoch', leave=False):
        loss = loss_fn(model, data_distribution)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
    return train_losses

In [None]:
lr = 3E-3
epochs = 1000

optimizer = optim.Adam(mle_model.parameters(), lr=lr)
mle_model.train()

# Record the loss per step instead of printing it. Printing 1000 tensors
# floods the cell output and buries the progress bar.
kl_history = []
for epoch in tqdm_notebook(range(epochs), desc='Epoch', leave=False):
    loss = forward_kl(mle_model, data_distribution)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    kl_history.append(loss.item())

# Final forward-KL value (the last expression is displayed).
kl_history[-1]

In [None]:
# `ptu` is never imported here, so convert with plain torch calls.
# The grid goes on the model's device; the result comes back as a CPU numpy array.
model_device = next(mle_model.parameters()).device
with torch.no_grad():
    plot_grid = torch.tensor(xs, dtype=torch.float32, device=model_device)
    learned_density = mle_model.log_prob(plot_grid).exp().cpu().numpy()

plt.plot(xs, ys, label='data')
plt.plot(xs, learned_density, c='g', linestyle='dashed', label='learned model')
plt.xlabel('x')
plt.ylabel('density')
plt.legend()
plt.show()

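# Part 2: A Simple Latent Variable Model (LVM)
In this part, we train a simple LVM with a latent component index $z \sim \text{Categorical}(\pi_\theta)$ over 3 components, and $x \mid z \sim \mathcal{N}(\mu_\theta(z), \operatorname{diag}(\sigma_\theta(z)^2))$. Here $\pi_\theta$ are learned mixture weights, and $\mu_\theta(z)$ and $\sigma_\theta(z)$ are the learned mean and standard deviation of component $z$, i.e. the mean of a Gaussian and its scale. We fit this LVM using maximum likelihood by marginalizing out $z$ exactly.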

In [None]:
def sample_blobs(n):
    """Draw ``n`` 2-D points from a fixed mixture of three axis-aligned Gaussian blobs.

    Each point picks one of the three blobs uniformly at random (via NumPy's
    global RNG), then adds per-axis Gaussian noise scaled by that blob's
    standard deviations.

    Returns:
        float32 array of shape (n, 2).
    """
    means = np.array([[5, 5], [-5, 5], [0, -5]])
    spreads = np.array([[1.0, 1.0], [0.2, 0.2], [3.0, 0.5]])
    component = np.random.randint(0, 3, size=(n,), dtype='int32')
    noise = np.random.randn(n, 2)
    points = means[component] + noise * spreads[component]
    return points.astype('float32')

Plot the data

In [None]:
def plot_scatter_2d(points, title='', labels=None):
    """Scatter-plot the first two columns of ``points`` in a new figure and show it.

    Args:
        points: 2-D array-like; column 0 is plotted on x, column 1 on y.
        title: figure title.
        labels: optional per-point values. When given, points are coloured by
            them through a red/blue/green/purple ListedColormap.
    """
    fig, ax = plt.subplots()
    px, py = points[:, 0], points[:, 1]
    if labels is None:
        ax.scatter(px, py)
    else:
        palette = matplotlib.colors.ListedColormap(['red', 'blue', 'green', 'purple'])
        ax.scatter(px, py, c=labels, cmap=palette)
    ax.set_title(title)
    plt.show()

In [None]:
# Independent train / test draws from the same three-blob distribution.
train_data = sample_blobs(n=10000)
test_data = sample_blobs(n=2500)
plot_scatter_2d(train_data, title='Train Data')

Wrap the NumPy arrays in PyTorch `DataLoader`s for mini-batch iteration (batch size 128).

In [None]:
device = torch.device('cpu')
# Only the training loader shuffles; the test loader keeps the array order.
train_loader = data.DataLoader(dataset=train_data, batch_size=128, shuffle=True)
test_loader = data.DataLoader(dataset=test_data, batch_size=128)

## Define the model

In [None]:
class SimpleLVM(nn.Module):
    """Mixture of ``n_mix`` diagonal Gaussians with learnable weights, means and stds.

    The latent z is a component index with probabilities softmax(pi_logits).
    Given z, x ~ N(mus[z], diag(exp(log_stds[z]) ** 2)). The likelihood
    marginalises z exactly with logsumexp.

    Args:
        n_mix: number of mixture components.
        dim: dimensionality of x (default 2, matching the blob data).
    """

    def __init__(self, n_mix, dim=2):
        super().__init__()

        self.n_mix = n_mix
        self.dim = dim
        self.pi_logits = nn.Parameter(torch.zeros(n_mix, dtype=torch.float32), requires_grad=True)
        self.mus = nn.Parameter(torch.randn(n_mix, dim, dtype=torch.float32), requires_grad=True)
        # Initial std is exp(-1) for every component and dimension.
        self.log_stds = nn.Parameter(-torch.ones(n_mix, dim, dtype=torch.float32), requires_grad=True)

    def loss(self, x):
        """Negative marginal log-likelihood, averaged over the batch ``x`` of shape (batch, dim).

        Returns:
            OrderedDict with a single entry ``loss`` (scalar tensor).
        """
        # Broadcast (batch, 1, dim) against (1, n_mix, dim) -> (batch, n_mix, dim):
        # the same per-component arithmetic as a loop over components.
        log_std = self.log_stds.unsqueeze(0)
        diff = x.unsqueeze(1) - self.mus.unsqueeze(0)
        log_probs = -0.5 * diff ** 2 * torch.exp(-2 * log_std)
        log_probs = log_probs - 0.5 * np.log(2 * np.pi) - log_std
        log_probs = log_probs.sum(-1)  # diagonal Gaussian: sum over dims -> (batch, n_mix)

        log_pi = F.log_softmax(self.pi_logits, dim=0)
        log_probs = log_probs + log_pi.unsqueeze(0)
        loss = -torch.logsumexp(log_probs, dim=1).mean()
        return OrderedDict(loss=loss)

    def sample(self, n):
        """Draw ``n`` points from the model.

        Returns:
            (x, labels): numpy arrays of shape (n, dim) and (n,). ``labels``
            holds the sampled component index of each point.
        """
        with torch.no_grad():
            probs = F.softmax(self.pi_logits, dim=0)
            labels = torch.multinomial(probs, n, replacement=True)
            mus, log_stds = self.mus[labels], self.log_stds[labels]
            # randn_like follows the parameters' shape, dtype and device,
            # instead of assuming 2-D data on the CPU.
            x = torch.randn_like(mus) * log_stds.exp() + mus
        return x.cpu().numpy(), labels.cpu().numpy()

In [None]:
# Import under a distinct name. Rebinding `train_epochs` would silently shadow
# the Part 1 `train_epochs(model, data_distribution, train_args, loss_fn)`,
# which has a different signature, and break re-runs of the Part 1 cells.
from utils import train_epochs as train_lvm_epochs

n_mix = 3
model = SimpleLVM(n_mix)

def fn(epoch):
    """Plot 10,000 samples from the current model, titled with ``epoch``.

    Passed to the trainer as a callback; presumably invoked every
    ``fn_every`` epochs (see utils.train_epochs).
    """
    x, labels = model.sample(10000)
    plot_scatter_2d(x, title=f'Epoch {epoch} Samples', labels=labels)

train_lvm_epochs(model, train_loader, test_loader, device, dict(epochs=10, lr=7e-2),
                 fn=fn, fn_every=2, quiet=True)

x, labels = model.sample(10000)
plot_scatter_2d(x, title='Final Samples', labels=labels)