In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import transforms, datasets
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

from modules.householder import HH
from modules.couplings import NICE

In [3]:
train_dataset = datasets.MNIST("/storage/datasets", train=True, download=False, transform=transforms.ToTensor())
test_dataset = datasets.MNIST("/storage/datasets", train=False, download=False, transform=transforms.ToTensor())

In [4]:
IM_SIZE = 28 * 28
IM_SHAPE = (28, 28)
CPU_training = False
BATCH = 64
EPOCHS = 5

use_cuda =  torch.cuda.is_available() and not CPU_training
device = torch.device("cuda" if use_cuda else "cpu")

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH, 
                                           shuffle=True, num_workers=7, pin_memory=use_cuda, drop_last=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH, 
                                          shuffle=False, num_workers=7, pin_memory=use_cuda, drop_last=True)

In [54]:
def net_slim(size):
    return nn.Sequential(
        nn.Conv2d(size, 64, kernel_size=1, padding=0),
        nn.ReLU(True),
        nn.Conv2d(64, 64, kernel_size=3, padding=1),
        nn.ReLU(True),
        nn.Conv2d(64, size, kernel_size=1, padding=0),
    )

# def net_wide(size=784//2):
#     return nn.Sequential(
#         nn.Linear(size, size),
#         nn.ReLU(True),
#         nn.Linear(size, size)
#     )

In [55]:
class Downsample(nn.Module):
    


    def inverse(self, input):
        upscale_factor=2
        '''
        [:, C*r^2, H, W] -> [:, C, H*r, W*r]
        '''
        batch_size, in_channels, in_height, in_width = input.size()
        out_channels = in_channels // (upscale_factor**2)

        out_height = in_height * upscale_factor
        out_width = in_width * upscale_factor

        input_view = input.contiguous().view(batch_size, out_channels, upscale_factor,
                                             upscale_factor, in_height, in_width)

        output = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
        return output.view(batch_size, out_channels, out_height, out_width)


    def forward(self, input):
        downscale_factor=2
        '''
        [:, C, H*r, W*r] -> [:, C*r^2, H, W]
        '''
        batch_size, in_channels, in_height, in_width = input.size()
        out_channels = in_channels * (downscale_factor**2)

        out_height = in_height // downscale_factor
        out_width = in_width // downscale_factor

        input_view = input.contiguous().view(
            batch_size, in_channels, out_height, downscale_factor, out_width, downscale_factor
        )

        output = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()
        return output.view(batch_size, out_channels, out_height, out_width)

In [56]:
def calculate_bits_per_pixel(ll, size=784):
    return 8 - ll / np.log(2) / size

In [57]:
class Sequential_VP_Flow(nn.Module):
    
    def __init__(self, distribution, flow_steps):
        super().__init__()
        self.flow_steps = nn.Sequential(*flow_steps)
        self.distribution = distribution
        
    def forward(self, x):
        return self.flow_steps(x)
    
    def inverse(self, z):
        for m in reversed(self.flow_steps):
            z = m.inverse(z)
        return z
    
    def log_prob(self, x):
        z = self.forward(x)
        z = z.view(z.size(0), -1)
        ll = self.distribution.log_prob(z).mean()
        return ll
    
    def sample(self):
        z = self.distribution.sample()[None]
        x = self.inverse(z)
        return x

In [58]:
nn.Conv2d(2, 3, kernel_size=5).weight.shape

torch.Size([3, 2, 5, 5])

In [62]:
dist = torch.distributions.MultivariateNormal(torch.zeros(IM_SIZE).to(device), torch.eye(IM_SIZE).to(device))
flow = Sequential_VP_Flow(dist, 
                          flow_steps=[ # 1
                              Downsample(), # 4
                              NICE(net_slim(2)),
                              NICE(net_slim(2)),
                              Downsample(),
                              NICE(net_slim(8)),
                              NICE(net_slim(8)),
                          ])
flow.to(device)

Sequential_VP_Flow(
  (flow_steps): Sequential(
    (0): Downsample()
    (1): NICE(
      (net): Sequential(
        (0): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (2): NICE(
      (net): Sequential(
        (0): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (3): Downsample()
    (4): NICE(
      (net): Sequential(
        (0): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): Conv2d(6

In [63]:
optimizer = optim.Adam(flow.parameters(), 1e-3)

In [64]:
for i in range(EPOCHS):
    bits_total = 0
    for x, _ in train_loader:
        x = x.to(device)
        with torch.no_grad():
            x = (x * (256 - 1) + torch.randn_like(x).to(device)) / 256
        ll = flow.log_prob(x)
        bits = calculate_bits_per_pixel(ll)
        optimizer.zero_grad()
        bits.backward()
        optimizer.step()
        bits_total += bits
    print(bits_total.item() / len(train_loader))

9.333046374733192
9.328957735792423
9.328158351120598


KeyboardInterrupt: 