# Introduction

- Generative models are valued for their ability to learn from limited amounts of data and to generalize to varying tasks and contexts.
- In recent years, flow-based generative models have attracted particular interest over alternative, traditionally more prominent models because they allow exact inference of the latent variables and exact evaluation of the log-likelihood (Kingma & Dhariwal 2018).
- In this notebook, a relatively simple instance of Glow is implemented and trained on the CIFAR-10 dataset.

# Model Specification

### Normalizing Flow

In most flow-based models, the generative process for a high-dimensional random vector $\textbf{x}$ is defined as follows

\begin{equation}
\begin{aligned}
    \textbf{z} &\sim p_\theta(\textbf{z}), \\
    \textbf{x} &= \textbf{g}_\theta(\textbf{z}).
\end{aligned}
\end{equation}

Here $\textbf{z}$ and $p_\theta(\textbf{z})$ denote the latent variable and its (usually simple) density, and $\textbf{g}_\theta$ denotes an invertible function, so that the latent variable can be inferred as $\textbf{z} = \textbf{g}_\theta^{-1}(\textbf{x}) = \textbf{f}_\theta(\textbf{x})$. The function $\textbf{f}$ (and likewise $\textbf{g}$) is chosen so that it can be represented as a sequence of individually invertible transformations $\textbf{f} = \textbf{f}_1 \circ \textbf{f}_2 \circ \dots \circ \textbf{f}_K$, such that the relationship between $\textbf{x}$ and $\textbf{z}$ can be written as

\begin{equation}
    \textbf{x} \overset{f_1}{\leftrightarrow} \textbf{h}_1 \overset{f_2}{\leftrightarrow} \textbf{h}_2 \dots  \overset{f_K}{\leftrightarrow} \textbf{z}.
\end{equation}

This series of invertible transformations essentially maps a complex density $p(\textbf{x})$ to a simple, tractable density $p(\textbf{z})$.

For discrete data $\mathcal{D}$, such as 8-bit images, uniform noise is added to make the data continuous, and the log-likelihood objective for normalizing flow models is defined as minimizing

\begin{equation}
\begin{aligned}
    \mathcal{L}(\mathcal{D}) &\simeq \frac{1}{N} \sum_{i=1}^N -\log p_{\theta} (\hat{\textbf{x}}^{(i)}) + c, \\
    \hat{\textbf{x}}^{(i)} &= \textbf{x}^{(i)} + \textbf{u}, \quad \textbf{u} \sim \mathcal{U}(0, a), \\
    c &= -M \cdot \log a,
\end{aligned}
\end{equation}

i.e. the expected compression cost, where $M$ is the dimensionality of $\textbf{x}$ and $a$ is determined by its discretization level (here $a = 1/256$ for 8-bit images scaled to $[0, 1]$). By the change of variables formula, the probability density $p_\theta(\textbf{x})$ can be represented as

\begin{equation}
    p_\theta(\textbf{x}) = p_\theta(\textbf{z}) |\det(d\textbf{z}/d\textbf{x})|.
\end{equation}

By taking the logarithm of the density and representing the transformation $\textbf{x} \leftrightarrow \textbf{z}$ as a series of transformations, we obtain the log-likelihood used in the maximum likelihood objective:

\begin{equation}
\begin{aligned}
    \log p_\theta(\textbf{x}) &= \log p_\theta(\textbf{z}) + \log |\det(d\textbf{z}/d\textbf{x})| \\
    &= \log p_\theta(\textbf{z}) + \sum_{i=1}^K \log |\det(d\textbf{h}_i/d\textbf{h}_{i-1})|.
\end{aligned}
\end{equation}

Here $\textbf{h}_0 = \textbf{x}$ and $\textbf{h}_K = \textbf{z}$, and each term $\log |\det(d\textbf{h}_i/d\textbf{h}_{i-1})|$ measures how much a single transformation in the sequence expands or contracts volume.
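
As a sanity check of the decomposition above, the following minimal sketch pushes a scalar through a chain of three affine maps and compares the accumulated log-density with the closed-form Gaussian density. The coefficients in `steps` and the standard normal prior are illustrative assumptions, not values taken from the model in this notebook.

```python
import math


def standard_normal_logpdf(z):
    """Log-density of the assumed prior p(z) = N(0, 1)."""
    return -0.5 * (math.log(2.0 * math.pi) + z * z)


# Hypothetical flow with K = 3 scalar affine steps h_i = a_i * h_{i-1} + b_i.
steps = [(2.0, 1.0), (0.5, -3.0), (4.0, 0.25)]


def flow_log_prob(x):
    """log p(x) = log p(z) + sum_i log |d h_i / d h_{i-1}|."""
    h = x  # h_0 = x
    logdet = 0.0
    for a, b in steps:
        h = a * h + b
        logdet += math.log(abs(a))  # Jacobian of one affine step
    return standard_normal_logpdf(h) + logdet  # h is now z


# The composed map is z = A * x + B, hence x ~ N(-B / A, 1 / A^2).
A, B = 1.0, 0.0
for a, b in steps:
    A, B = a * A, a * B + b


def closed_form_log_prob(x):
    mean, std = -B / A, 1.0 / abs(A)
    return -0.5 * (math.log(2.0 * math.pi) + 2.0 * math.log(std)
                   + ((x - mean) / std) ** 2)


for x in (-1.0, 0.0, 2.0, 3.5):
    print(x, flow_log_prob(x), closed_form_log_prob(x))
```

Both columns should agree up to floating-point error, which is the same bookkeeping each layer in the model performs when it adds its own log-determinant to the running total.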


In [0]:
import time
from google.colab import files

In [2]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
!pip install gputil
!pip install psutil
!pip install humanize
import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()
# XXX: only one GPU on Colab and isn’t guaranteed
gpu = GPUs[0]
def printm():
 process = psutil.Process(os.getpid())
 print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
 print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
printm() 

Gen RAM Free: 12.8 GB  | Proc size: 156.3 MB
GPU RAM Free: 16280MB | Used: 0MB | Util   0% | Total 16280MB


In [0]:
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
cuda = torch.cuda.is_available()

In [0]:
BATCH_SIZE = 64
DIM_HIDDEN = 512
DEPTH = 32
LEVELS = 3

In [6]:
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.transforms import ToTensor
from functools import reduce


classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

n_bins = 2**8  # 8 bits

transform = transforms.Compose([transforms.ToTensor(), 
                         lambda x: x + torch.zeros_like(x).uniform_(0., 1./n_bins)])

# Load dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

    
train_loader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=10, drop_last=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=10, drop_last=True)
init_loader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR10(root='./data', train=True, 
    download=True, transform=transform), batch_size=BATCH_SIZE, shuffle=True, num_workers=1)



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Read the image dimensions from one sample; wrapping it in the deprecated
# torch.autograd.Variable is not needed for this
x, _ = trainset[0]

DIM_CHANNELS, IMG_HEIGHT, IMG_WIDTH = x.shape
SHAPE = BATCH_SIZE, DIM_CHANNELS, IMG_HEIGHT, IMG_WIDTH
SHAPE

(64, 3, 32, 32)

In [0]:
learn_top = True

# Utility Functions

In [0]:
# https://github.com/pclucas14/pytorch-glow/blob/master/utils.py
def flatten_sum(logps):
    while len(logps.size()) > 1: 
        logps = logps.sum(dim=-1)
    return logps

In [0]:
def onehot(y, num_classes):
    y_onehot = torch.zeros(y.size(0), num_classes).to(y.device)
    if len(y.size()) == 1:
        y_onehot = y_onehot.scatter_(1, y.unsqueeze(-1), 1)
    else:
        y_onehot = y_onehot.scatter_(1, y, 1)
    return y_onehot

In [0]:
def gaussian_diag(mean, logsd):
    class o(object):
        Log2PI = float(np.log(2 * np.pi))
        pass

        def logps(x):
            return  -0.5 * (o.Log2PI + 2. * logsd + (x - mean) ** 2 / torch.exp(2. * logsd))

        def sample(eps_std=1):
            eps = torch.normal(mean=torch.zeros_like(mean),
                           std=torch.ones_like(logsd) * eps_std)
            return mean + torch.exp(logsd) * eps

    o.logp = lambda x: flatten_sum(o.logps(x))
    return o


In [0]:
# Standard linear layer with zero initial weights and biases

class LinearZeros(nn.Linear):
    def __init__(self, in_channels, out_channels, logscale=3):
        super().__init__(in_channels, out_channels)
        self.logscale = logscale
        self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels)))
        # zero init
        self.weight.data.zero_()
        self.bias.data.zero_()

    def forward(self, x):
        out = super().forward(x)
        return out * torch.exp(self.logs * self.logscale)

In [0]:
# Standard convolutional layer with zero initial weights and biases

class ConvZeros(nn.Conv2d):
    def __init__(self, channels_in, channels_out, filter_size, stride=1, padding=1, logs_factor=3.):
        super().__init__(channels_in, channels_out, filter_size, stride=stride, padding=padding)
        self.register_parameter("logs", nn.Parameter(torch.zeros(channels_out, 1, 1)))
        self.logs_factor = logs_factor
        # Zero init
        self.weight.data.zero_()
        self.bias.data.zero_()

    def forward(self, x):
        out = super().forward(x)
        return out * torch.exp(self.logs * self.logs_factor)

# Glow

### Actnorm 

The actnorm layer performs an affine transformation on the activations using per-channel scale and bias parameters. Both parameters are initialized from the data so that the activations of the first batch have zero mean and unit variance per channel after the actnorm layer. After initialization, they are treated as regular trainable parameters.
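
The data-dependent initialization can be illustrated on a tiny hand-made batch. This is a minimal sketch in plain Python rather than PyTorch, under the assumption of a `[sample][channel][pixel]` list layout; the numbers in `batch` are made up for illustration.

```python
import math

# Toy batch with layout [sample][channel][pixel]: 2 samples, 2 channels, 4 pixels.
batch = [
    [[1.0, 2.0, 3.0, 4.0], [10.0, 10.5, 9.5, 10.0]],
    [[5.0, 6.0, 7.0, 8.0], [11.0, 9.0, 10.0, 10.0]],
]
n_channels = len(batch[0])
n_pixels = len(batch[0][0])  # plays the role of h * w


def channel_stats(data, c):
    """Mean and (biased) variance of channel c over all samples and pixels."""
    values = [v for sample in data for v in sample[c]]
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, var


# Data-dependent initialization: one (bias, log-scale) pair per channel.
bias, logs = [], []
for c in range(n_channels):
    mean, var = channel_stats(batch, c)
    bias.append(-mean)
    logs.append(math.log(1.0 / (math.sqrt(var) + 1e-6)))

# Forward pass: y = (x + b) * exp(logs), applied channel-wise.
output = [
    [[(v + bias[c]) * math.exp(logs[c]) for v in sample[c]]
     for c in range(n_channels)]
    for sample in batch
]

# Every pixel of a channel is scaled by the same factor, so the per-sample
# log-determinant is the number of pixels times the sum of the log-scales.
dlogdet = n_pixels * sum(logs)

stats_after = [channel_stats(output, c) for c in range(n_channels)]
print(stats_after)  # per-channel (mean, variance) after actnorm
print(dlogdet)
```

After initialization each channel has a mean close to zero and a variance close to one; the small deviation from one comes from the `1e-6` stabilizer.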

In [0]:
class ActnormLayer(nn.Module):
    def __init__(self, dim_channels, logs_factor=1., scale=1.):
        super(ActnormLayer, self).__init__()
        self.initialized = False
        self.logs_factor = logs_factor
        self.scale = scale
        self.register_parameter('b', nn.Parameter(torch.zeros(1, dim_channels, 1, 1)))
        self.register_parameter('logs', nn.Parameter(torch.zeros(1, dim_channels, 1, 1)))


    def init_params(self, x):

        with torch.no_grad():

            bias = -torch.mean(x, dim=[0, 2, 3], keepdim=True)
            vars = torch.mean((x + bias) ** 2, dim=[0, 2, 3], keepdim=True)
            logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) / self.logs_factor

            self.b.data.copy_(bias.data)
            self.logs.data.copy_(logs.data)


    def forward(self, x, reverse, objective):

        if not reverse:

            if not self.initialized: 
                self.initialized = True
                self.init_params(x)

            logs = self.logs * self.logs_factor
            b = self.b
            
            output = (x + b) * torch.exp(logs)
            dlogdet = torch.sum(logs) * x.size(2) * x.size(3) # n of pixels 

            return output.view(x.shape), objective + dlogdet

        else:

            logs = self.logs * self.logs_factor
            b = self.b
            output = x * torch.exp(-logs) - b
            dlogdet = torch.sum(logs) * x.size(2) * x.size(3) # n of pixels 

            return output.view(x.shape), objective - dlogdet

In [0]:
# Standard convolutional layer with actnorm

class ConvActNorm(nn.Module):
    def __init__(self, channels_in, channels_out, filter_size, stride=1, padding=None):
        super(ConvActNorm, self).__init__()
        padding = (filter_size - 1) // 2 or padding
        self.conv = nn.Conv2d(channels_in, channels_out, filter_size, padding=padding, bias=False)
        self.actnorm =  ActnormLayer(channels_out)

    def forward(self, x):
        x = self.conv(x)
        x = self.actnorm.forward(x, False, -1)[0]
        return x

### Invertible $1 \times 1$ convolution

As the name suggests, this layer performs an invertible $1 \times 1$ convolution with an equal number of input and output channels: for an input tensor of shape $c \times h \times w$, the weight matrix $\textbf{W}$ has size $c \times c$. The transformation is a learned generalization of a channel permutation, which compensates for the fact that each affine coupling layer leaves the first half of the channels untouched. The weight matrix is initialized as a random orthogonal matrix, so its initial log-determinant is zero.

The log-determinant of an invertible $1 \times 1$ convolution for an input tensor $\textbf{h}$ is defined as

\begin{equation}
    \log \left| \det\left(\frac{d \mathsf{conv2D}(\textbf{h}; \textbf{W})}{d \textbf{h}}\right) \right| = h \cdot w \cdot \log |\det(\textbf{W})|.
\end{equation}

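The computation time for this log-determinant can be decreased significantly for large values of $c$ by parameterizing the weight matrix $\textbf{W}$ through its LU decomposition. The implementation below computes the determinant directly, which is cheap for the small channel counts used here.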
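
To make the LU idea concrete, the sketch below computes $\log|\det(\textbf{W})|$ from the pivots of a Gaussian elimination with partial pivoting, i.e. the diagonal of the $U$ factor. It is written in plain Python for illustration only: the matrices `W` and `R` are made-up examples, and this is not the learned LU parameterization from the Glow paper, only the underlying identity.

```python
import math


def log_abs_det_lu(matrix):
    """log|det(M)| as the sum of log|pivot| over an LU-style elimination."""
    a = [row[:] for row in matrix]  # work on a copy
    n = len(a)
    total = 0.0
    for k in range(n):
        # Partial pivoting: a row swap flips the sign of det, not its magnitude.
        pivot = max(range(k, n), key=lambda r: abs(a[r][k]))
        if a[pivot][k] == 0.0:
            return float("-inf")  # singular matrix
        a[k], a[pivot] = a[pivot], a[k]
        total += math.log(abs(a[k][k]))
        for r in range(k + 1, n):
            factor = a[r][k] / a[k][k]
            for col in range(k, n):
                a[r][col] -= factor * a[k][col]
    return total


# Hypothetical 3 x 3 weight matrix with det(W) = 25.
W = [[2.0, 0.0, 1.0],
     [1.0, 3.0, 0.0],
     [0.0, 1.0, 4.0]]

# A rotation matrix has |det| = 1, so it contributes nothing to the objective.
theta = 0.3
R = [[math.cos(theta), -math.sin(theta), 0.0],
     [math.sin(theta), math.cos(theta), 0.0],
     [0.0, 0.0, 1.0]]

h, w = 16, 16  # spatial size of the input tensor
d_logdet = h * w * log_abs_det_lu(W)  # contribution of the layer per sample

print(log_abs_det_lu(W), math.log(25.0))
print(log_abs_det_lu(R))
print(d_logdet)
```

The second print shows why the orthogonal initialization is convenient: the layer starts out volume-preserving.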

In [0]:
class Invertible_1x1_Convolution(nn.Module):
    
    def __init__(self, dim_channels):
        super(Invertible_1x1_Convolution, self).__init__()
        
        self.kernel_shape = 1
        self.dim_channels = dim_channels
        
        # Init with random orthonormal weights
        self.register_parameter('w', torch.nn.Parameter(torch.from_numpy(np.linalg.qr(np.random.randn(
            *[self.dim_channels,self.dim_channels]))[0].astype('float32'))))
        
    def forward(self, z, reverse, logdet):
        
        shape = z.shape
        d_logdet = torch.log(torch.abs(torch.det(self.w))) * shape[2] * shape[3]  # pixels
        
        if not reverse:
            
            w = self.w
            kernel = torch.reshape(w, w.shape + torch.Size([1,1]))
            z = F.conv2d(z, kernel)
            
            return z, logdet + d_logdet

        else:

            w = torch.inverse(self.w)
            kernel = torch.reshape(w, w.shape + torch.Size([1,1]))
            z = F.conv2d(z, kernel)
            
            return z, logdet - d_logdet
              

### Affine coupling layer

The affine coupling layer first splits the input tensor along its channel dimension into two halves. It then applies an affine transformation to the second half ($\textbf{x}_b$) and concatenates the two halves back together. The scale and shift parameters of the affine transformation are produced by the function NN(), which takes the first half of the data ($\textbf{x}_a$) as input. The function NN() is essentially a shallow convolutional network whose last convolution is initialized with zeros.
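
The key property of the coupling layer is that it is trivially invertible even though NN() itself is never inverted. The following minimal sketch shows this on a flat list of four numbers; `toy_nn` is a hypothetical stand-in for NN(), and the shift-then-scale order with a `sigmoid(log_scale + 2)` scale mirrors the implementation in this notebook.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def toy_nn(x_a):
    """Hypothetical stand-in for NN(): any function of x_a is allowed."""
    shift = [math.sin(v) for v in x_a]
    log_scale = [0.5 * v for v in x_a]
    return shift, log_scale


def coupling_forward(x):
    half = len(x) // 2
    x_a, x_b = x[:half], x[half:]
    shift, log_scale = toy_nn(x_a)
    scale = [sigmoid(s + 2.0) for s in log_scale]
    y_b = [(b + t) * s for b, t, s in zip(x_b, shift, scale)]
    logdet = sum(math.log(s) for s in scale)
    return x_a + y_b, logdet  # x_a passes through unchanged


def coupling_inverse(y):
    half = len(y) // 2
    y_a, y_b = y[:half], y[half:]
    # y_a equals x_a, so the same shift and scale can be recomputed.
    shift, log_scale = toy_nn(y_a)
    scale = [sigmoid(s + 2.0) for s in log_scale]
    x_b = [b / s - t for b, t, s in zip(y_b, shift, scale)]
    logdet = -sum(math.log(s) for s in scale)
    return y_a + x_b, logdet


x = [0.5, -1.0, 2.0, 0.3]
y, logdet_fwd = coupling_forward(x)
x_rec, logdet_inv = coupling_inverse(y)

print(y)
print(x_rec)
print(logdet_fwd + logdet_inv)
```

Because the Jacobian is triangular, the log-determinant is just the sum of the log-scales, and the forward and inverse contributions cancel.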


In [0]:
class AffineCouplingLayer(nn.Module):
    
    def __init__(self, dim_channels, dim_hidden=DIM_HIDDEN):
        super(AffineCouplingLayer, self).__init__()
        
        self.dim_channels = dim_channels
        
        self.out_dim = dim_channels // 2
        self.in_dim = dim_channels - self.out_dim
        self.dim_hidden = dim_hidden
        
        self.kernel1 = 3
        self.kernel2 = 1
        self.kernel3 = 3
        
        self.nn = nn.Sequential(
            ConvActNorm(self.in_dim, self.dim_hidden, self.kernel1, padding=1),
            nn.ReLU(),
            ConvActNorm(self.dim_hidden, self.dim_hidden, self.kernel2, padding=0),
            nn.ReLU(),
            ConvZeros(self.dim_hidden, 2*self.out_dim, self.kernel3, padding=1)
        )
        
        
    def forward(self, z, reverse, logdet):
        
        z_a, z_b = torch.chunk(z, 2, dim=1)
        
        out = self.nn(z_a)
        
        shift, log_scale = out[:, 0::2], out[:, 1::2]
        scale = torch.sigmoid(log_scale + 2.)
        
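        if not reverse:
            # Out-of-place ops: z_b is a view of z, so in-place updates would
            # overwrite the caller's tensor and can break autograd
            z_b = (z_b + shift) * scale
            logdet = logdet + flatten_sum(torch.log(scale))
            
        else:
            z_b = z_b / scale - shift
            logdet = logdet - flatten_sum(torch.log(scale))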
            
          
        z = torch.cat([z_a, z_b], 1)
        
        return z, logdet
    


### Multi-scale architecture

The multi-scale architecture of the model is built from two nested units:

- **Step of flow**: actnorm $\rightarrow$ invertible $1 \times 1$ convolution $\rightarrow$ affine coupling
- **Level of flow**: squeeze $\rightarrow$ `DEPTH` steps of flow $\rightarrow$ split

The squeeze operation transforms a $c \times h \times w$ tensor into a tensor of shape $4c \times \frac{h}{2} \times \frac{w}{2}$. At each split layer, half of the channels are factored out and modelled with a Gaussian distribution (optionally storing the factored-out values so that they can be reused for reconstruction), while the other half is passed on to further layers, as described by Dinh et al. (2016). A complete Glow model is a stack of several levels of flow; the last level omits the split and feeds its output to the prior.
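
The shape bookkeeping of the multi-scale architecture can be traced without building the model. The helper `glow_latent_shapes` below is a hypothetical name introduced for this sketch; it assumes the same squeeze factor of 2 and the same rule of skipping the split on the last level as the `Glow` class in this notebook.

```python
def glow_latent_shapes(c, h, w, levels):
    """Shapes of all latent tensors produced for one c x h x w input."""
    latents = []
    for level in range(levels):
        # Squeeze: trade spatial resolution for channels.
        c, h, w = 4 * c, h // 2, w // 2
        # The flow steps of this level keep the shape unchanged.
        if level < levels - 1:
            # Split: half of the channels are factored out as a latent.
            c //= 2
            latents.append((c, h, w))
    # The output of the last level is modelled by the prior.
    latents.append((c, h, w))
    return latents


shapes = glow_latent_shapes(3, 32, 32, levels=3)
total_dims = sum(c * h * w for c, h, w in shapes)

print(shapes)      # [(6, 16, 16), (12, 8, 8), (48, 4, 4)]
print(total_dims)  # 3072
```

The latent dimensions add up to $3 \cdot 32 \cdot 32 = 3072$, the dimensionality of a CIFAR-10 image, as required for an invertible model.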

In [0]:
class FlowStep(nn.Module):
    
    def __init__(self, dim_channels):
        super(FlowStep, self).__init__()
        
        self.actnorm = ActnormLayer(dim_channels)
        self.invertible_convolution = Invertible_1x1_Convolution(dim_channels)
        self.affine_coupling = AffineCouplingLayer(dim_channels)
        
        
    def forward(self, z, reverse, logdet):
        
        if not reverse:
            
            #print('Initial: {}'.format(logdet[0]))
            z, logdet = self.actnorm(z, reverse, logdet)
            #print('Actnorm: {}'.format(logdet[0]))
            z, logdet = self.invertible_convolution(z, reverse, logdet)
            #print('Conv: {}'.format(logdet[0]))
            z, logdet = self.affine_coupling(z, reverse, logdet)
            #print('Affine: {}'.format(logdet[0]))
             
        else:
            
            z, logdet = self.affine_coupling(z, reverse, logdet)
            z, logdet = self.invertible_convolution(z, reverse, logdet)
            z, logdet = self.actnorm(z, reverse, logdet)
        
        return z, logdet
                      

In [0]:
class SqueezeLayer(nn.Module):
    
    def __init__(self, factor=2):
        super(SqueezeLayer, self).__init__()
        
        self.factor = factor

    def squeeze(self, x):
        batch_size, dim_channel, height, width = x.shape

        # https://github.com/chaiyujin/glow-pytorch/blob/master/glow/modules.py
        x = x.view(batch_size, dim_channel, height // self.factor, self.factor, width // self.factor, self.factor)
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
        x = x.view(batch_size, dim_channel * self.factor ** 2, height // self.factor, width // self.factor)
        
        return x

    def unsqueeze(self, x):
        batch_size, dim_channel, height, width = x.shape

        # https://github.com/chaiyujin/glow-pytorch/blob/master/glow/modules.py
        x = x.view(batch_size, dim_channel // self.factor ** 2, self.factor, self.factor, height, width)
        x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
        x = x.view(batch_size, dim_channel // self.factor ** 2, height * self.factor, width * self.factor)
        
        return x
        
    def forward(self, z, reverse, logdet):
        
        if not reverse:
            return self.squeeze(z), logdet
               
        else:
            return self.unsqueeze(z), logdet
                

In [0]:
class SplitLayer(nn.Module):
    
    def __init__(self, dim_channels):
        super(SplitLayer, self).__init__()
        self.dim_channels = dim_channels
        self.in_dim = dim_channels // 2
        
        self.conv = ConvZeros(self.in_dim, self.dim_channels, 3, padding=1)

    def prior(self, z):
        out = self.conv(z)
        shift, log_scale = out[:, 0::2], out[:, 1::2]
        return gaussian_diag(shift, log_scale)

    def forward(self, z, reverse, logdet, eps_std=1, use_stored_sample=False):
        
        if not reverse:
            
            z_a, z_b = torch.chunk(z, 2, dim=1)
            
            pz = self.prior(z_a)
            self.sample = z_b
            
            logdet += pz.logp(z_b)
            
            return z_a, logdet
        
        else:
            
            pz = self.prior(z)
            z_b = self.sample if use_stored_sample else pz.sample(eps_std)
            # Concatenate the input and sample
            z = torch.cat([z, z_b], dim=1)
            logdet -= pz.logp(z_b)
            
            return z, logdet
 

In [0]:
class Prior(nn.Module):
    
    def __init__(self, input_shape):
        super(Prior, self).__init__()
        self.input_shape = input_shape
        self.conv = ConvZeros(2*input_shape[1], 2*input_shape[1], 3, padding=1)
        self.project_y = LinearZeros(len(classes), 2 * self.input_shape[1])
        self.register_parameter("prior",
            nn.Parameter(torch.zeros([input_shape[0], 2*input_shape[1], input_shape[2], input_shape[3]])))

        
    def forward(self, z, reverse, logdet, y_onehot=None, eps_std=1):

        B, C = self.prior.size(0), self.prior.size(1)
        
        if z is not None:
            b, c, h, w = z.shape
        else:
            b, c, h, w = self.input_shape
        
        h_ = self.prior.detach().clone()

        if learn_top:
            h_ = self.conv(h_)

        if y_onehot is not None:
            y_proj = self.project_y(y_onehot).view(B, C, 1, 1)
            h_ += y_proj

        mean, logsd = torch.chunk(h_, 2, dim=1)
        pz = gaussian_diag(mean, logsd)

        if not reverse:
        
            logdet += pz.logp(z).cuda()
            return z, logdet

        else:
        
            if z is None:
                z = pz.sample(eps_std)
            logdet -= pz.logp(z).cuda()
            return z, logdet


In [0]:
class Glow(nn.Module):
    
    def __init__(self, input_shape=SHAPE, levels=LEVELS, depth=DEPTH):
        super(Glow, self).__init__()
        
        self.input_shape = input_shape
        self.levels = levels

        b, c, h, w = input_shape
        layers = []
        self.layer_names = []
        self.out_shapes = []
        
        for l in range(levels):
            layers += [SqueezeLayer()]
            self.layer_names += ['Squeeze']
            c, h, w = c * 4, h // 2, w // 2
            self.out_shapes.append([-1, c, h, w])
            
            for _ in range(depth):
                layers += [FlowStep(c)]
                self.layer_names += ['Flow']
                self.out_shapes.append([-1, c, h, w])
                
            if l < levels - 1:
                layers += [SplitLayer(c)]
                self.layer_names += ['Split']
                c //= 2
                self.out_shapes.append([-1, c, h, w])
        
        layers += [Prior([b,c,h,w])]
        self.layer_names += ['Prior']
        
        self.layers =  nn.ModuleList(layers)
        self.project_class = LinearZeros(self.out_shapes[-1][1], len(classes))

        
    def forward(self, x, reverse, logdet, y_onehot=None, eps_std=1, n_stored_samples=None):
        
        n_stored = n_stored_samples or 0
        stored_sample_pattern = torch.zeros(self.levels, dtype=torch.bool)
        stored_sample_pattern_r = torch.zeros(self.levels, dtype=torch.bool)
        for l in range(self.levels): 
            if l < n_stored:
                stored_sample_pattern[l] = True
                stored_sample_pattern_r[self.levels - l - 1] = True

        if not reverse:
            
            #it = time.time()
            l = 0  # Current level

            for i, layer in enumerate(self.layers):
                
                #t = time.time()
                
                if isinstance(layer, SplitLayer):
                    x, logdet = layer(x, reverse, logdet, eps_std, stored_sample_pattern[l])
                    l += 1
                elif isinstance(layer, Prior):
                    x, logdet = layer(x, reverse, logdet, y_onehot, eps_std)
                else:                
                    x, logdet = layer(x, reverse, logdet)

                #print("{}: {}".format(self.layer_names[i], time.time() - t))

            #print(time.time() - it)

        else:
            
            with torch.no_grad():

                l = 0  # Current level

                for i, layer in enumerate(reversed(self.layers)):
                    
                    if isinstance(layer, SplitLayer):
                        x, logdet = layer(x, reverse, logdet, eps_std, stored_sample_pattern_r[l])
                        l += 1
                    elif isinstance(layer, Prior):
                        x, logdet = layer(x, reverse, logdet, y_onehot, eps_std)
                    else:                
                        x, logdet = layer(x, reverse, logdet)

                
        return x, logdet
        
        
    def sample(self, y_onehot=None, eps_std=1):
        
        with torch.no_grad():
            
            samples = self.forward(None, True, torch.zeros(self.input_shape[0]).cuda(), y_onehot, eps_std)[0]
    
            return samples

      
        

# Training

In [0]:
model = Glow().cuda()

In [0]:
import torch.optim as optim
from torchvision import datasets, transforms, utils

# init learning rate
lr = 1e-3

optim = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=45, gamma=0.1)
BCE = nn.BCEWithLogitsLoss().cuda()


def objective(img_shape, logdet):
    l = (-logdet) / float(np.log(2.) * np.prod(img_shape[1:]))
    return torch.mean(l)

np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
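
The `objective` function above converts the accumulated log-likelihood from nats to bits per dimension, and the training loop adds the discretization term $-M \cdot \log(\text{n\_bins})$ to `logdet` before the forward pass. A minimal sketch of that arithmetic follows; the helper name `bits_per_dim` is hypothetical. It shows a useful reference point: a model that assigns a uniform density on $[0, 1)^M$ scores roughly 8 bits per dimension for 8-bit data, so a trained model should fall well below that.

```python
import math


def bits_per_dim(log_px, n_dims, n_bins=256):
    """Negative log-likelihood in bits per dimension.

    log_px is the continuous log-density (in nats) that the flow assigns to
    one dequantized image; the discretization term is added as in training.
    """
    logdet = log_px - math.log(n_bins) * n_dims
    return -logdet / (math.log(2.0) * n_dims)


M = 3 * 32 * 32  # dimensionality of a CIFAR-10 image

# A uniform density on [0, 1)^M has log p(x) = 0 everywhere.
uniform_bpd = bits_per_dim(0.0, M)
print(uniform_bpd)  # approximately 8

# Each additional nat of log-density per dimension saves 1 / log(2) bits.
improved_bpd = bits_per_dim(1.0 * M, M)
print(uniform_bpd - improved_bpd, 1.0 / math.log(2.0))
```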

In [0]:
# Data dependent init

with torch.no_grad():
    model.eval()
    for (img, y) in init_loader:
        img = img.cuda()
        y_onehot = onehot(y, len(classes)).cuda()
        logdet = torch.zeros_like(img[:, 0, 0, 0])
        _ = model(img, False, logdet, y_onehot)
        break

In [0]:
model = nn.DataParallel(model).cuda()

In [27]:
num_epochs = 2000
num_warmup_epochs = 20

for epoch in range(num_epochs):
    print('Epoch: {}'.format(epoch))
    num_batches = len(train_loader)
    avg_train_bits_x = 0.

    # --------------------------------------------------------------------------
    # Training
    # --------------------------------------------------------------------------

    for i, (img, y) in enumerate(train_loader):
        #if i > 10 : break
            
        t = time.time()
        img = img.cuda()
        y_onehot = onehot(y, len(classes)).cuda()
        logdet = torch.zeros_like(img[:, 0, 0, 0]).cuda()
        # discretizing cost 
        logdet += float(-np.log(n_bins) * np.prod(img.shape[1:]))
        # log_det_jacobian cost (and some prior from Split OP)
        z, logdet = model(img, False, logdet, y_onehot)
        loss_generative = objective(img.shape, logdet)

        y_logits = model.module.project_class(z.mean(2).mean(2))
        loss_classes = BCE(y_logits, y_onehot)

        loss = loss_generative #+ loss_classes * 0.5

        optim.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_value_(model.parameters(), 5)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 100)

        optim.step()

        #print('Optimization time: {}'.format(time.time() - ot))

        avg_train_bits_x += loss.item()

        # update learning rate
        new_lr = float(lr * min(1., (i + epoch * num_batches) / (num_warmup_epochs * num_batches)))
        for pg in optim.param_groups: pg['lr'] = new_lr

        if (i + 1) % 500 == 0: 
            print('avg train bits per pixel {:.4f}'.format(avg_train_bits_x / 500))
            avg_train_bits_x = 0.
            sample = model.module.sample(y_onehot)
            grid = utils.make_grid(sample)

            utils.save_image(grid, '/content/gdrive/My Drive/Colab Notebooks/glow_samples/samples6/cifar_Test_{}_{}.png'.format(epoch, i // 500))
            #files.download('cifar_Test_{}_{}.png'.format(epoch, i // 500))

        print('Batch: {}, iteration took {:.4f}, Loss: {}'.format(i, time.time() - t, loss))
    
    # --------------------------------------------------------------------------
    # Validation
    # --------------------------------------------------------------------------

    if (epoch + 1) % 1 == 0:
        model.eval()
        avg_test_bits_x = 0.
        with torch.no_grad():
            for i, (img, y) in enumerate(test_loader): 
                #if i > 10 : break
                img = img.cuda() 
                y_onehot = onehot(y, len(classes)).cuda()

                logdet = torch.zeros_like(img[:, 0, 0, 0]).cuda()
               
                # discretizing cost 
                logdet += float(-np.log(n_bins) * np.prod(img.shape[1:]))
                
                # log_det_jacobian cost (and some prior from Split OP)
                z, logdet = model(img, False, logdet, y_onehot)
                last_img = img

                loss_generative = objective(img.shape, logdet)

                y_logits = model.module.project_class(z.mean(2).mean(2))
                loss_classes = BCE(y_logits, y_onehot)

                loss = loss_generative #+ loss_classes * 0.5

                avg_test_bits_x += loss

            # i is the last zero-based batch index, so i + 1 batches were summed
            print('avg test bits per pixel {:.4f}'.format(avg_test_bits_x.item() / (i + 1)))

            # ------------------------------------------------------------------
            # Sampling
            # ------------------------------------------------------------------

            # Random sample

            sample = model.module.sample(y_onehot)
            grid = utils.make_grid(sample)

            utils.save_image(grid, '/content/gdrive/My Drive/Colab Notebooks/glow_samples/samples6/cifar_Test_{}__1.png'.format(epoch))
            #files.download('cifar_Test_{}__1.png'.format(epoch))

            # Random sample with more variation

            #sample = model.module.sample(y_onehot, eps_std=2.0)
            #grid = utils.make_grid(sample)

            #utils.save_image(grid, '/content/gdrive/My Drive/Colab Notebooks/glow_samples/samples5/cifar_Test_{}__2.png'.format(epoch))
            #files.download('cifar_Test_{}__2.png'.format(epoch))
            
            # ------------------------------------------------------------------
            # Interpolation
            # ------------------------------------------------------------------

            #center = z[torch.randint(0, z.shape[0] - 1, (1,)),:]

            # Randomize direction
            #dim_weights = torch.rand([z.shape[1], z.shape[2], z.shape[3]]).cuda()
            #dim_weights = dim_weights / torch.sum(dim_weights)

            #linspace = torch.linspace(-1000, 1000, z.shape[0])
            #interpolation = torch.empty(z.shape).cuda()
            #for i, l in enumerate(linspace):
            #    interpolation[i,:] = center + l * dim_weights

            #sample, _ = model.module(interpolation, True, logdet)
            #grid = utils.make_grid(sample)

            #utils.save_image(grid, '/content/gdrive/My Drive/Colab Notebooks/glow_samples/samples5/cifar_Test_{}__3.png'.format(epoch))
            #files.download('cifar_Test_{}__3.png'.format(epoch))

            # The upper bound of torch.randint is exclusive, so z.shape[0] covers every index
            c1 = z[torch.randint(0, z.shape[0], (1,)),:]
            c2 = z[torch.randint(0, z.shape[0], (1,)),:]

            diff = c1 - c2

            linspace = torch.linspace(0, 1, z.shape[0])
            interpolation = torch.empty(z.shape).cuda()
            for i, l in enumerate(linspace):
                interpolation[i,:] = c1 - l * diff

            sample, _ = model.module(interpolation, True, logdet)
            grid = utils.make_grid(sample)

            utils.save_image(grid, '/content/gdrive/My Drive/Colab Notebooks/glow_samples/samples6/cifar_Test_{}__4.png'.format(epoch))
            #files.download('cifar_Test_{}__3.png'.format(epoch))

            # ------------------------------------------------------------------
            # Reconstruction
            # ------------------------------------------------------------------


            # Store no samples

            n_stored_samples = 0

            x_hat = model.module(z, True, logdet, y_onehot, n_stored_samples=n_stored_samples)[0]
            grid = utils.make_grid(x_hat)

            utils.save_image(grid, '/content/gdrive/My Drive/Colab Notebooks/glow_samples/samples5/cifar_Test_Recon{}_0%.png'.format(epoch))
            #files.download('cifar_Test_Recon{}_0%.png'.format(epoch))

            # Store half of the samples

            #n_stored_samples = LEVELS // 2

            #x_hat = model.module(z, True, logdet, y_onehot, n_stored_samples=n_stored_samples)[0]
            #grid = utils.make_grid(x_hat)

            #utils.save_image(grid, '/content/gdrive/My Drive/Colab Notebooks/glow_samples/samples5/cifar_Test_Recon{}_50%.png'.format(epoch))
            #files.download('cifar_Test_Recon{}_50%.png'.format(epoch))

            # Store all samples

            #n_stored_samples = LEVELS

            #x_hat = model.module(z, True, logdet, y_onehot, n_stored_samples=n_stored_samples)[0]
            #grid = utils.make_grid(x_hat)

            #utils.save_image(grid, '/content/gdrive/My Drive/Colab Notebooks/glow_samples/samples5/cifar_Test_Recon{}_100%.png'.format(epoch))
            #files.download('cifar_Test_Recon{}_100%.png'.format(epoch))


            grid = utils.make_grid(last_img)

            utils.save_image(grid, '/content/gdrive/My Drive/Colab Notebooks/glow_samples/samples6/cifar_Test_Target.png')
            #files.download('cifar_Test_Target.png')


Epoch: 0
Batch: 0, iteration took 1.9196, Loss: 5.97113037109375
Batch: 1, iteration took 1.4945, Loss: 40.90465545654297
Batch: 2, iteration took 1.4613, Loss: 40.34504699707031
Batch: 3, iteration took 1.4671, Loss: 41.63421630859375
Batch: 4, iteration took 1.4730, Loss: 41.45762252807617
Batch: 5, iteration took 1.4951, Loss: 40.521121978759766
Batch: 6, iteration took 1.4863, Loss: 40.46720886230469
Batch: 7, iteration took 1.4964, Loss: 40.346168518066406


KeyboardInterrupt: ignored

## References

Kingma, D. P. & Dhariwal, P. (2018), ‘Glow: Generative flow with invertible 1x1 convolutions’, in ‘Advances in Neural Information Processing Systems’, pp. 10215–10224.

Dinh, L., Sohl-Dickstein, J. & Bengio, S. (2016), ‘Density estimation using Real NVP’, arXiv preprint arXiv:1605.08803.


