# A3

In [1]:
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import math

In [2]:
def log1pexp(x):
    """Return log(1 + exp(x)) elementwise, with the same shape as x.

    The value is obtained as a logsumexp over the pair (x, 0) stacked along a
    new trailing dimension, rather than by exponentiating x directly.
    """
    zeros = torch.zeros_like(x)
    paired = torch.stack((x, zeros), dim=-1)
    return paired.logsumexp(dim=-1)

In [3]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
DEVICE

device(type='cuda', index=0)

## Binarized MNIST

In [4]:
# Per-image preprocessing: convert to a tensor, threshold at 0.5, flatten to a
# 1-D vector and store the resulting 0/1 pixels as int8.
_binarize = transforms.Lambda(lambda img: (img > 0.5).view(-1).to(torch.int8))
transform = transforms.Compose([transforms.ToTensor(), _binarize])

# MNIST training split, downloaded into the current directory if missing.
ds = torchvision.datasets.MNIST(
    root='.',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 256 examples.
data_loader = DataLoader(dataset=ds, batch_size=256, shuffle=True, pin_memory=True)


## Model Implementation

In [5]:
# Model dimensions used by the encoder/decoder networks below.
Dz = 2            # latent dimension
Dh = 500          # hidden-layer width
Ddata = 28 * 28   # number of pixels in a flattened image

In [6]:
def log_prior(z):
    """
    Log-density of z under a standard normal, summed over the last dimension.

    z: input tensor with shape (batch_size, Dz)

    Returns a tensor with shape (batch_size,).
    """
    # log(2*pi) is kept as a plain Python float so the result follows z's own
    # device and dtype; no per-call tensor is allocated and the function does
    # not depend on the notebook-level DEVICE global.
    log_two_pi = math.log(2 * math.pi)
    return torch.sum(-0.5 * (log_two_pi + z ** 2), dim=-1)  # shape: (batch_size,)

def bernoulli_log_density(x, logit_means):
    """
    Log-probability of binary data under independent Bernoullis given by logits,
    summed over the last dimension.

    x: Input tensor with shape (batch_size, 784), entries expected to be 0 or 1
    logit_means: Input tensor with shape (batch_size, 784)

    Returns a tensor with shape (batch_size,).
    """
    # Map {0, 1} to {-1, +1} so one expression covers both outcomes.
    signs = x * 2 - 1
    # log sigmoid(s * l) == -log(1 + exp(-s * l)); the library routine is
    # numerically stable and avoids stacking a temporary twice the input size.
    return torch.sum(nn.functional.logsigmoid(signs * logit_means), dim=-1)  # shape: (batch_size,)


# Decoder network: latent code (Dz) -> tanh hidden layer (Dh) -> Ddata outputs,
# which are passed to bernoulli_log_density as logits.
decoder = nn.Sequential(
    nn.Linear(in_features=Dz, out_features=Dh),
    nn.Tanh(),
    nn.Linear(in_features=Dh, out_features=Ddata),
)
decoder = decoder.to(DEVICE)


def log_likelihood(x,z):
    """Bernoulli log-density of x given the logits the decoder produces from z."""
    logit_means = decoder(z)
    return bernoulli_log_density(x, logit_means)


def joint_log_density(x,z):
    """Sum of log_likelihood(x, z) and log_prior(z)."""
    likelihood_term = log_likelihood(x, z)
    prior_term = log_prior(z)
    return likelihood_term + prior_term  # shape: (batch_size, )

## Amortized Approximate Inference with Learned Variational Distribution

In [7]:
def log_q(z, mu, log_sigma):
    """
    Log-density of z under a diagonal Gaussian with mean mu and standard
    deviation exp(log_sigma), summed over the last dimension.

    z: tensor with shape (batch_size, Dz)
    mu: tensor with shape (batch_size, Dz)
    log_sigma: tensor with shape (batch_size, Dz)

    Returns a tensor with shape (batch_size,).
    """
    # A plain Python float keeps the result on the inputs' device/dtype and
    # removes the dependency on the notebook-level DEVICE global.
    half_log_two_pi = 0.5 * math.log(2 * math.pi)
    standardized = (z - mu) / torch.exp(log_sigma)
    # Reduce over the last dimension (same as dim=1 for 2-D inputs), matching
    # the reduction used by log_prior.
    return torch.sum(
        -half_log_two_pi - log_sigma - 0.5 * standardized ** 2, dim=-1)  # shape: (batch_size,)

In [8]:
# Encoder network: flattened image (Ddata) -> tanh hidden layer (Dh) -> 2 * Dz
# outputs, which elbo() splits into mu and log_sigma.
encoder = nn.Sequential(
    nn.Linear(in_features=Ddata, out_features=Dh),
    nn.Tanh(),
    nn.Linear(in_features=Dh, out_features=2 * Dz),
)
encoder = encoder.to(DEVICE)


def elbo(x):
    """
    Single-sample estimate of the evidence lower bound, averaged over the batch.

    x: batch tensor passed to the encoder (first dimension is the batch)

    Returns a scalar tensor.
    """
    # The encoder output is split in half along the last dimension:
    # the first Dz entries are the mean, the rest the log standard deviation.
    enc_out = encoder(x)
    mu, log_sigma = enc_out[..., :Dz], enc_out[..., Dz:]

    # Draw z = mu + sigma * eps with eps ~ N(0, I), so z stays differentiable
    # with respect to mu and log_sigma.
    z = torch.randn_like(log_sigma) * torch.exp(log_sigma) + mu

    joint_ll = joint_log_density(x, z)
    log_q_z = log_q(z, mu, log_sigma)
    return torch.mean(joint_ll - log_q_z)  # scalar


def loss_fn(x):
    """Training loss for a batch x: the negated ELBO estimate."""
    negative_elbo = -elbo(x)
    return negative_elbo

## Optimize the model and amortized variational parameters

In [9]:
def train(enc, dec, data, loss_fn, n_epochs=10, print_every=500):
    """
    Jointly optimise the parameters of enc and dec with Adam (lr=1e-4).

    enc, dec: modules whose parameters are handed to the optimizer
    data: iterable yielding (X_batch, label) pairs; the labels are ignored
    loss_fn: callable taking a float batch moved to DEVICE and returning a scalar loss tensor
    n_epochs: number of passes over `data`
    print_every: the loss is printed whenever the iteration count is a multiple of this
    """
    params = list(enc.parameters()) + list(dec.parameters())
    optimizer = torch.optim.Adam(params, lr=1e-4)

    itrs = 0
    for epoch in range(n_epochs):
        print(f'Epoch {epoch}')
        # Iterate over the `data` argument rather than a notebook-level loader,
        # so the caller controls which batches are trained on.
        for X_batch, _ in data:
            optimizer.zero_grad()
            loss = loss_fn(X_batch.to(DEVICE).float())

            if itrs % print_every == 0:
                print(f'\tIteration {itrs}, Loss {loss.item()}')

            loss.backward()
            optimizer.step()
            itrs += 1

In [10]:
# Run training with the default number of epochs and logging interval.
train(enc=encoder, dec=decoder, data=data_loader, loss_fn=loss_fn)

Epoch 0
	Iteration 0, Loss 552.8909912109375
Epoch 1
Epoch 2
	Iteration 500, Loss 193.457763671875
Epoch 3
Epoch 4
	Iteration 1000, Loss 190.4912872314453
Epoch 5
Epoch 6
	Iteration 1500, Loss 185.35568237304688
Epoch 7
Epoch 8
	Iteration 2000, Loss 177.88592529296875
Epoch 9
