In [1]:
import os
import torch
import numpy as np

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt


In [2]:
available = torch.cuda.is_available()
curr_device = torch.cuda.current_device()
device = torch.device("cuda:0" if available else "cpu")
device_count = torch.cuda.device_count() 
device_name =  torch.cuda.get_device_name(0)

print(f'Cuda available: {available}')
print(f'Current device: {curr_device}')
print(f'Device: {device}')
print(f'Device count: {device_count}')
print(f'Device name: {device_name}')

#device = torch.device("cpu")

Cuda available: True
Current device: 0
Device: cuda:0
Device count: 1
Device name: GeForce GTX 1070


In [3]:
# (Adapted) Code from PyTorch's Resnet impl: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer = None,
        use_final_relu = True,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        
        # BK: Force last layers back to inplanes
        self.conv3 = conv1x1(width, inplanes)
        self.bn3 = norm_layer(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        
        self.use_final_relu = use_final_relu

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        #out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        #out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        #out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        if self.use_final_relu:
            out = self.relu(out)

        return out

In [4]:
class Reshape(nn.Module):
    def __init__(self, shape):
        super(Reshape, self).__init__()
        self.shape = tuple([-1] + list(shape))
        
    def forward(self, x):
        return torch.reshape(x, self.shape)

def dense_backbone(shape, network_width):
    input_width = shape[0] * shape[1] * shape[2]
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(input_width, network_width),
        nn.ReLU(),
        nn.Linear(network_width, input_width),
        Reshape(shape)
    )

def bottleneck_backbone(planes):
    return nn.Sequential(
        Bottleneck(1, planes),
        Bottleneck(1, planes),
    )

mask = {}
mask_device = {}
def checkerboard_mask(shape, to_device=True):
    global mask, mask_device
    if shape not in mask:
        mask[shape] = 1 - np.indices(shape).sum(axis=0) % 2
        mask[shape] = torch.Tensor(mask[shape])
        
    if to_device and shape not in mask_device:
        mask_device[shape] = mask[shape].to(device)
        
    return mask_device[shape] if to_device else mask[shape]

In [5]:
class NormalizingFlowMNist(nn.Module):
    EPSILON = 1e-7
    
    def __init__(self, num_coupling, planes):
        super(NormalizingFlowMNist, self).__init__()
        self.num_coupling = num_coupling
        self.shape = (1, 28, 28)
        
        self.planes = planes
        self.s = nn.ModuleList([bottleneck_backbone(planes) 
                                for x in range(num_coupling)])
        self.t = nn.ModuleList([bottleneck_backbone(planes)
                                for x in range(num_coupling)])
        
        # Learnable scaling parameters for outputs of S
        self.s_scale = nn.ParameterList([torch.nn.Parameter(torch.randn(self.shape)) 
                                         for x in range(num_coupling)])
        for i in range(num_coupling):
            self.s_scale[i].requires_grad = True

    def forward(self, x):
        if model.training:
            s_vals = []
            for i in range(self.num_coupling):
                mask = checkerboard_mask(self.shape)
                mask = mask if i % 2 == 0 else (1 - mask)
               
                t = self.t[i](mask * x)
                s = (self.s_scale[i]) * torch.tanh(self.s[i](mask * x))
                y = mask * x + (1 - mask) * (x * torch.exp(s) + t)
                s_vals.append(mask * s)
                
                x = y

            # Return outputs and vars needed for determinant
            return y, torch.cat(s_vals)
        else:
            y = x
            for i in reversed(range(self.num_coupling)):
                mask = checkerboard_mask(self.shape)
                mask = mask if i % 2 == 0 else (1 - mask)
               
                t = self.t[i](mask * y)
                s = (mask * self.s_scale[i]) * torch.tanh(self.s[i](mask * y))
                x = mask * y + (1 - mask)((y - t) * torch.exp(-s + EPSILON))
                
                y = x
                
            return x

In [6]:
def loss_fn(y, s, batch_size):
    # -log(zero-mean gaussian) + log determinant
    # -log p_x = log(pz(f(x))) + log(det(\partial f/\partial x))
    # -log p_x = 0.5 * y**2 + s1 + s2 + ...
    logpx = -torch.sum(0.5 * y**2)
    det = torch.sum(s)
    ret = -(logpx + det)
    return torch.div(ret, batch_size)

In [13]:
def train_loop(dataloader, model, loss_fn, optimizer, report_iters=10, num_pixels=28*28):
    size = len(dataloader)
    prev = []
    for batch, (X, _) in enumerate(dataloader):
        # Transfer to GPU
        X = X.to(device)
        
        # Compute prediction and loss
        y, s = model(X)
        loss = loss_fn(y, s, batch_size)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        
        prev = [(name, x, x.grad) for name, x in model.named_parameters(recurse=True)]
        optimizer.step()

        if batch % report_iters == 0:
            loss, current = loss.item(), batch
            print(f"loss: {loss:.2f}; {loss / num_pixels / np.log(2):>.2f}  [{current:>5d}/{size:>5d}]")

def test_loop(dataloader, model, loss_fn, num_pixels=28*28):
    size = len(dataloader)
    num_batches = len(dataloader)
    test_loss = 0

    with torch.no_grad():
        for X, _ in dataloader:
            X = X.to(device)
            y, s = model(X)
            test_loss += loss_fn(y, s, batch_size)

    test_loss /= num_batches
    print(f"Test Error: \n Avg loss: {test_loss:.2f}; {test_loss / num_pixels / np.log(2):.2f} \n")

# MNist Training

In [14]:
def pre_process(x):
    # Convert back to integer values
    x = x * 255.
    
    # Add random uniform [0, 1] noise to get a proper likelihood estimate
    # https://bjlkeng.github.io/posts/a-note-on-using-log-likelihood-for-generative-models/
    x = x + torch.rand(x.shape)
    
    # Apply transform to deal with boundary effects (see realNVP paper)
    x = torch.logit(0.05 + 0.90 * x / 256)
    
    return x

In [15]:
train_dataset = datasets.MNIST('data', train=True, download=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Lambda(pre_process)
                               ]))
test_dataset = datasets.MNIST('data', train=False, download=True,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  transforms.Lambda(pre_process),
                              ]))

In [16]:
learning_rate = 0.001
batch_size = 100
epochs = 1

model = NormalizingFlowMNist(3, 64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False) #shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size,  shuffle=False) #shuffle=True)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_loader, model, loss_fn, optimizer)
    test_loop(test_loader, model, loss_fn)

print("Done!")

Epoch 1
-------------------------------
loss: 2991.84; 5.51  [    0/  600]
loss: 555.46; 1.02  [   10/  600]
loss: 482.01; 0.89  [   20/  600]
loss: 369.45; 0.68  [   30/  600]
loss: 350.66; 0.65  [   40/  600]
loss: 349.23; 0.64  [   50/  600]
loss: 402.66; 0.74  [   60/  600]
loss: 288.61; 0.53  [   70/  600]
loss: 345.52; 0.64  [   80/  600]
loss: 265.58; 0.49  [   90/  600]
loss: 270.46; 0.50  [  100/  600]
loss: 280.42; 0.52  [  110/  600]
loss: 232.31; 0.43  [  120/  600]
loss: 269.43; 0.50  [  130/  600]
loss: 219.72; 0.40  [  140/  600]
loss: 143.88; 0.26  [  150/  600]
loss: 164.26; 0.30  [  160/  600]
loss: 110.23; 0.20  [  170/  600]
loss: 108.00; 0.20  [  180/  600]
loss: 76.59; 0.14  [  190/  600]
loss: 110.35; 0.20  [  200/  600]
loss: 57.21; 0.11  [  210/  600]
loss: 81.44; 0.15  [  220/  600]
loss: 147.58; 0.27  [  230/  600]
loss: 97.08; 0.18  [  240/  600]
loss: 41.22; 0.08  [  250/  600]
loss: 52.70; 0.10  [  260/  600]
loss: 19.86; 0.04  [  270/  600]
loss: 42.81; 0

# 2022-01-29

* Getting lots of NaNs -- debugged a bunch of things:
    * Removed Resnet
    * Removed exp()
    * Made forward pass a simple feedforward
* But it looks like issue is the data???    
    * The stupid paper said the transform should be `logit(alpha + (1-alpha)*x/256)`...
    * Data is originally in [0,1] (pytorch dataset)
    * Convert back to pixels multiply by 255
    * Add jitter to get upper bound on bits per pixel (see my post)
    * Range is now [0, 256]
    * Suggested alpha=0.05 (I had a bug and used 0.5)
    * But that gets you really close to 256 (jitter is always less than 1.0 though) e.g.i logit(0.05 + 0.95 * ~255.99/256) ~= \inf!
    * Instead, I used this `logit(alpha + (1-alpha - 0.05)*x/256)`, which is symmetrical...
    
NEXT STEPS:
* So things look good now, except that I get a negative loss, which shouldn't happen (after applying jitter)???
    * It's because I need a new uniform noise sample per EPOCH???
    * Or is it because I'm using continuous variables on the output?  So maybe I just need to measure this "loss" when I reverse the network?  
        * It's probably this... if it's a continuous output, the log density surely doesn't need to be positive (vs. if I were directly outputting pixel values).