# Building Complex Computational Graphs

## Machine Learning Models

Machine learning models are often described as parameterized nonlinear mathematical functions that transform an input into a desired output. A less common view of ML models is the computational one, where a model is simply a sequence of computations performed on some data.

<img src="imgs/model_theoretical_perspective.png" style="float: left;" width="200px"/>
<img src="imgs/model_computational_perspective.png" width="200px"/>

Computations performed by a model are usually represented as a directed acyclic graph

<img src="https://miro.medium.com/max/2434/1*_rCyzi7fQzc_Q1gCqSLM1g.png" width="500px"/>

which can be quite complex for modern ML models. 
Data is usually represented as multidimensional arrays (tensors)

<img src="imgs/multidimensional_tensors.png" width="500px"/>

and it is quite common to work with 4D or 5D tensors.
The highly abstract nature of these data structures makes developing ML models demanding, because bugs in the code itself are hard to find.
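For concreteness, PyTorch conventionally lays out a batch of RGB images as a 4D tensor `(batch, channels, height, width)`. A batch of video clips adds a time axis and becomes a 5D tensor. The sizes below are arbitrary and chosen only for illustration:

```python
import torch

# A mini-batch of 16 RGB images of 64x64 pixels: (N, C, H, W)
images = torch.randn(16, 3, 64, 64)

# A mini-batch of 2 clips, each with 8 RGB frames of 64x64 pixels: (N, C, T, H, W),
# the layout expected by 3D layers such as nn.Conv3d
clips = torch.randn(2, 3, 8, 64, 64)

print(images.dim(), tuple(images.shape))  # 4 (16, 3, 64, 64)
print(clips.dim(), tuple(clips.shape))    # 5 (2, 3, 8, 64, 64)
```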

<img src="imgs/karpathy_tweet.png" width="500px"/>

## No program is bug-free

The standard research and development cycle 

<br>
<img src="imgs/research_cycle_ok.png" width="500px"/>
<br>

is implicitly dependent on the correctness of your code

<br>
<img src="imgs/research_cycle_bug.png" width="500px"/>
<br>

so instead of a single loop there are two separate ones
coupled through your code

<br>
<img src="imgs/research_cycle_chain.png" width="500px"/>
<br>

so being able to inspect and reason about the abstract data structures used to implement an ML model is essential.

## Best practices

1. Start from the simplest possible model
2. Train on a single mini-batch to ensure your code is correct and your model is sufficiently expressive
3. Ensure your input data is normalized and encoded appropriately
4. Confirm you are using the right loss for your problem
5. Check intermediate outputs and connections

<span style="color: #008cbb">
    <b>At each step visualize as much as you can!</b>
</span>
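Items 3 and 5 can also be checked programmatically, alongside any visualization. The sketch below uses only standard PyTorch (`register_forward_hook`). It flags inputs that fall outside an expected range and records any submodule whose output contains NaN or Inf. The helper names and the default `[-1, 1]` range are our own assumptions, not part of any library:

```python
import torch
import torch.nn as nn


def input_in_range(x: torch.Tensor, low: float = -1.0, high: float = 1.0) -> bool:
    """Hypothetical helper: True if every element of x lies in [low, high]."""
    return bool(x.min() >= low) and bool(x.max() <= high)


def watch_non_finite(model: nn.Module):
    """Attach forward hooks that record the names of submodules producing NaN/Inf.

    Returns the list that is filled during forward passes, plus the hook
    handles so the hooks can be removed afterwards with handle.remove().
    """
    offenders = []
    handles = []
    for name, module in model.named_modules():
        if name == "":  # skip the root module itself
            continue

        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                offenders.append(name)

        handles.append(module.register_forward_hook(hook))
    return offenders, handles


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
offenders, handles = watch_non_finite(model)

x = torch.rand(2, 4) * 2 - 1       # values in [-1, 1)
print(input_in_range(x))           # True
model(x)
print(offenders)                   # []

model(torch.full((2, 4), float("nan")))
print(offenders)                   # includes '0', the first Linear layer

for h in handles:
    h.remove()
```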

## Let's build a DCGAN

[Paper](https://arxiv.org/abs/1511.06434)

![DCGAN](imgs/dcgan.png)

GANs are challenging to train because the optimization procedure looks for a saddle point of a min-max objective rather than the minimum of a single loss.
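For reference, here is the two-player objective from the original GAN formulation (Goodfellow et al., 2014), which DCGAN keeps. The generator G minimizes V while the discriminator D maximizes it, so training seeks a saddle point of V:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```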

![offconvex min max problem](http://www.offconvex.org/assets/GDA_spiral_2.gif)
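The spiral in the animation is typical of simultaneous gradient descent-ascent. The minimal sketch below uses the bilinear toy objective f(x, y) = x·y, which is our own choice and not necessarily the function animated. Its only saddle point is the origin. x takes a descent step and y takes an ascent step at the same time, and the iterates circle the saddle point while drifting away from it:

```python
import math


def gda_bilinear(x: float, y: float, lr: float, steps: int):
    """Simultaneous gradient descent (on x) / ascent (on y) for f(x, y) = x * y."""
    radii = [math.hypot(x, y)]
    for _ in range(steps):
        grad_x, grad_y = y, x  # df/dx = y, df/dy = x
        # Both updates use the old values of x and y (simultaneous update)
        x, y = x - lr * grad_x, y + lr * grad_y
        radii.append(math.hypot(x, y))
    return x, y, radii


x, y, radii = gda_bilinear(1.0, 0.0, lr=0.1, steps=100)
# Each step scales the distance to the saddle point (0, 0) by sqrt(1 + lr**2),
# so the iterates spiral outward instead of converging.
print(round(radii[-1], 4))  # 1.6446
```

Shrinking the learning rate slows the drift but does not remove it. This is one reason why GAN training calls for careful balancing of the two players.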

## Getting the code

We will train a DCGAN to generate images from the aligned [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset.

You can find all code from the following session and the trained models at https://github.com/efemarai/build_dcgan.

## Efemarai
We will be using [Efemarai](https://efemarai.com) to visualize every step of the DCGAN implementation. [Efemarai](https://efemarai.com) is a platform for testing and debugging ML code. If you want to try it out, just sign up for the *free* Personal tier and start saving hours otherwise spent fighting elusive bugs.

## Let's get started

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as tf

import efemarai as ef
ef.notebook()
```

Do not forget to run the local `efemarai` daemon.

## Generator

The generator takes a random vector as input and should produce an image. For this we will use [transposed convolutions](https://towardsdatascience.com/transposed-convolution-demystified-84ca81b4baba), which can increase the spatial size of their input.

![Transposed Convolution](https://miro.medium.com/max/700/1*faRskFzI7GtvNCLNeCN8cg.png)
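You can predict spatial sizes before running anything. The PyTorch `nn.ConvTranspose2d` documentation gives the output height (and likewise the width) as H_out = (H_in − 1)·stride − 2·padding + dilation·(kernel_size − 1) + output_padding + 1. The small pure-Python sketch below applies that formula to the kernel, stride and padding settings of the generator's five layers:

```python
def conv_transpose_out(size: int, kernel: int, stride: int = 1, padding: int = 0,
                       dilation: int = 1, output_padding: int = 0) -> int:
    """Output size of a transposed convolution, per the nn.ConvTranspose2d docs."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


# (kernel, stride, padding) of the generator's five transposed convolutions
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the noise vector enters as a (nz, 1, 1) "image"
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [1, 4, 8, 16, 32, 64]
```

With kernel 4, stride 2 and padding 1, every layer after the first exactly doubles the spatial size.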

Let's see an example.

```python
nz = 16
noise = torch.randn(1, nz, 1, 1)

convT = nn.ConvTranspose2d(
    in_channels=nz,
    out_channels=64,
    kernel_size=4,
    stride=1,
    padding=0,
)

print(convT(noise))
```

```
tensor([[[[ 0.0817,  0.1159,  0.1244, -0.0026],
          [ 0.0735,  0.0646, -0.0672, -0.0087],
          [-0.0300,  0.2063, -0.0606, -0.0689],
          [ 0.0209, -0.0029, -0.0517, -0.0402]],

         [[ 0.1119, -0.0153,  0.0446,  0.0726],
          [-0.0798,  0.0302, -0.0007, -0.0745],
          [ 0.1081, -0.0931, -0.0362,  0.0653],
          [-0.0148, -0.0075, -0.0170,  0.0770]],

         [[-0.0115, -0.1256,  0.0818,  0.0549],
          [-0.0907, -0.1071,  0.0591, -0.0342],
          [ 0.0119,  0.0514,  0.0248, -0.0155],
          [-0.0887,  0.0594, -0.1385, -0.0716]],

         ...,

         [[ 0.0344,  0.1427, -0.0692, -0.0444],
          [-0.1189,  0.0543, -0.0404,  0.0038],
          [-0.0147, -0.0413, -0.0496, -0.0010],
          [-0.1101,  0.0212, -0.0489,  0.1232]],

         [[-0.0203, -0.0809,  0.0800, -0.0648],
          [ 0.0002, -0.2017, -0.0872,  0.0648],
          [-0.0676, -0.0369,  0.0148,  0.0029],
          [-0.0779, -0.0248,  0.0022,  0.0272]],

         [[ 0.0
```

When working with tensors, printing them rarely gives you useful information. How about something more visual?

```python
ef.print(convT(noise))
```

Efemarai's `print()` function automatically generates a 3D visualization of a tensor. In it you can inspect any element or check the value distribution with a few mouse clicks. Tensors with up to 6 dimensions are supported.

![Print Tensor](imgs/ef_print.png)
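When a visual tool is not at hand, a one-line numeric summary still tells you far more than printing every element. The minimal sketch below uses standard tensor methods. The `summarize` helper is our own and is not part of PyTorch or Efemarai:

```python
import torch
import torch.nn as nn


def summarize(t: torch.Tensor) -> str:
    """Hypothetical helper: shape, dtype, basic statistics and a NaN/Inf check."""
    x = t.detach().float()  # statistics are computed in float32
    return (f"shape={tuple(t.shape)} dtype={t.dtype} "
            f"min={x.min().item():.4f} max={x.max().item():.4f} "
            f"mean={x.mean().item():.4f} std={x.std().item():.4f} "
            f"finite={bool(torch.isfinite(x).all())}")


nz = 16
noise = torch.randn(1, nz, 1, 1)
convT = nn.ConvTranspose2d(in_channels=nz, out_channels=64, kernel_size=4, stride=1, padding=0)
out = convT(noise)
print(summarize(out))  # starts with: shape=(1, 64, 4, 4) dtype=torch.float32
```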

Seeing the resulting tensor from a computation is useful, but what's even more useful is to explore the computation itself. 

```python
with ef.scan():
    output = convT(noise)
```

![Graph Scan](imgs/ef_scan.png)
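Without a graph tool, there is a rough text-only alternative for a plain `nn.Sequential`. You push an input through the layers one at a time and record each output shape. The minimal sketch below only covers sequential models, where the graph is a simple chain. The small layer stack is an arbitrary example:

```python
import torch
import torch.nn as nn


def trace_shapes(model: nn.Sequential, x: torch.Tensor):
    """Run x through each layer of a Sequential model and record output shapes."""
    shapes = []
    with torch.no_grad():
        for layer in model:
            x = layer(x)
            shapes.append((layer.__class__.__name__, tuple(x.shape)))
    return shapes


model = nn.Sequential(
    nn.ConvTranspose2d(16, 32, 4, 1, 0, bias=False),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
    nn.Tanh(),
)
for name, shape in trace_shapes(model, torch.randn(2, 16, 1, 1)):
    print(f"{name:<16} {shape}")
```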

Now let's create our generator module, starting with the first layer of transformations

```python
class Generator(nn.Module):
    def __init__(self, ngf=64):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            # input is the noise vector z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(),
        )

    def forward(self, input):
        return self.network(input)
```

and explore what happens to the input noise as it passes through that layer

```python
# Use the default ngf=64; nz sets the input channels, not ngf
gen = Generator()
with ef.scan(wait=False):
    output = gen(noise)
```

The complete generator contains 5 transposed convolution layers and outputs an image of size `(3, 64, 64)`.

```python
class Generator(nn.Module):
    def __init__(self, ngf=64):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            # input is the noise vector z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. 3 x 64 x 64
        )

    def forward(self, input):
        return self.network(input)
```


Here is what the final computational graph of the generator looks like

```python
gen = Generator()
with ef.scan(wait=False):
    output = gen(noise)
```

![Generator Graph](imgs/generator_graph.png)

You can easily confirm that
* all the layers are connected as expected
* all computations go as expected - there are no NaNs or Infs
* the input vector is correctly transformed into a 3x64x64 image
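You can also write the same three checks as assertions that run after every change, as a complement to the visual inspection. This sketch rests on two assumptions. First, the layer stack is a compact restatement of the `Generator` above, so the snippet runs on its own. Second, `ngf` is reduced to keep the check fast:

```python
import torch
import torch.nn as nn


def make_generator(nz: int = 100, ngf: int = 64) -> nn.Sequential:
    """Same layer stack as the Generator class above, built in a loop."""
    chans = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]
    layers = []
    for i in range(4):
        stride, padding = (1, 0) if i == 0 else (2, 1)
        layers += [
            nn.ConvTranspose2d(chans[i], chans[i + 1], 4, stride, padding, bias=False),
            nn.BatchNorm2d(chans[i + 1]),
            nn.ReLU(),
        ]
    layers += [nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False), nn.Tanh()]
    return nn.Sequential(*layers)


nz = 100
gen = make_generator(nz=nz, ngf=8)
out = gen(torch.randn(4, nz, 1, 1))

# 1. Shape check: a proxy for the layers being wired together correctly
assert out.shape == (4, 3, 64, 64), out.shape
# 2. No NaNs or Infs
assert torch.isfinite(out).all()
# 3. Tanh keeps pixel values in [-1, 1]
assert out.min() >= -1 and out.max() <= 1
print("generator checks passed")
```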

### Overfit to a small batch

Load images from the CelebA dataset and create a small, fixed set of `(noise, image)` pairs. We then overfit to these pairs to make sure the generator is able to produce images at all.

```python
nz = 100

# ImageFolder expects the images inside at least one subdirectory of root, e.g. data/celeba/
dataset = datasets.ImageFolder(
    root="data",
    transform=tf.Compose([
        tf.Resize(64),
        tf.CenterCrop(64),
        tf.ToTensor(),
        tf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)


class GeneratorMiniBatch(torch.utils.data.Dataset):
    """A fixed set of (noise, image) pairs: the same noise always maps to the same image."""

    def __init__(self, size):
        self.size = size
        self.noise = torch.randn(size, nz, 1, 1)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.noise[idx], dataset[idx][0]


generator_minibatch = GeneratorMiniBatch(50)

dataloader = torch.utils.data.DataLoader(
    generator_minibatch, batch_size=10, shuffle=True,
)
```
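This normalization suits the generator. `Normalize(mean, std)` computes `(x - mean) / std` per channel. With mean and std both 0.5, the `[0, 1]` output of `ToTensor` becomes `[-1, 1]`, exactly the range of the generator's final `Tanh`. The sketch below repeats that arithmetic for one channel with plain tensor operations:

```python
import torch

pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])  # ToTensor output lies in [0, 1]
mean, std = 0.5, 0.5
normalized = (pixels - mean) / std
print(normalized.tolist())  # [-1.0, -0.5, 0.0, 1.0]
```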

Train in a loop over the mini-batch, scanning the execution every 50 iterations.

```python
generator = Generator()

optimizer = optim.Adam(
    generator.parameters(), lr=1e-3,
)

iteration = 0
for epoch in range(100):
    for noise, image in dataloader:
        optimizer.zero_grad()

        # Scan only every 50th iteration
        with ef.scan(
            iteration,
            enabled=iteration % 50 == 0,
            wait=False,
        ):
            output = generator(noise)
            # Mean squared error between generated and target images
            loss = (output - image).square().mean()
            loss.backward()

        optimizer.step()

        iteration += 1

    if epoch % 10 == 0:
        print(loss.item())
```

```
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
0.465360552072525
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
0.12741057574748993
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
0.04939030110836029
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
0.05862098187208176
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
0.03839258477091789
```
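You can also package the overfitting check as a reusable helper that runs without Efemarai. In the minimal sketch below, `overfit_check` and the toy regression problem are our own illustration. The idea is the same as above: a correct, sufficiently expressive model should drive the loss on one fixed batch close to zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def overfit_check(model, inputs, targets, steps=500, lr=1e-2):
    """Train on one fixed batch and return the (first, last) MSE loss values."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = F.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses[0], losses[-1]


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))
inputs, targets = torch.randn(16, 8), torch.randn(16, 1)
first, last = overfit_check(model, inputs, targets)
print(f"first loss {first:.4f}, last loss {last:.4f}")
```

If the last loss stays close to the first, suspect the wiring, the loss, or the data pipeline before blaming model capacity.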

## Discriminator

```python
class Discriminator(nn.Module):
    def __init__(self, ndf=64):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential(
            # input is 3 x 64 x 64
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=False),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=False),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=False),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=False),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        # Inputs are expected in [-1, 1]: normalized real images or Tanh outputs
        assert torch.all(-1.0 <= input) and torch.all(input <= 1.0)
        return self.network(input)
```
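You can double-check the `# state size` comments with the output-size formula from the `nn.Conv2d` documentation: H_out = floor((H_in + 2·padding − dilation·(kernel_size − 1) − 1) / stride) + 1. The pure-Python sketch below applies it to the discriminator's five convolutions. Together they shrink a 64x64 image to a single score per sample:

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output size of a convolution, per the nn.Conv2d docs."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# (kernel, stride, padding) of the discriminator's five convolutions
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [64, 32, 16, 8, 4, 1]
```

The final output therefore has shape `(N, 1, 1, 1)`, and `.view(-1)` in the training code flattens it to one probability per image.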

### Overfit to a small batch

```python
class DiscriminatorMiniBatch(torch.utils.data.Dataset):
    """Fixed fakes from the generator above (label 0) followed by real images (label 1)."""

    def __init__(self):
        self.size = 2 * len(generator_minibatch)
        self.fakes = generator(generator_minibatch.noise).detach()

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        if idx < len(self.fakes):
            return self.fakes[idx], torch.zeros(1, dtype=torch.float).squeeze()
        else:
            _, image = generator_minibatch[idx - len(self.fakes)]
            return image, torch.ones(1, dtype=torch.float).squeeze()


discriminator_minibatch = DiscriminatorMiniBatch()

dataloader = torch.utils.data.DataLoader(
    discriminator_minibatch, batch_size=10, shuffle=True,
)
```

```python
discriminator = Discriminator()

optimizer = optim.Adam(
    discriminator.parameters(), lr=1e-4,
)

iteration = 0
for epoch in range(50):
    for image, label in dataloader:
        optimizer.zero_grad()

        output = discriminator(image).view(-1)
        loss = F.binary_cross_entropy(output, label)
        loss.backward()

        optimizer.step()

        iteration += 1

    if epoch % 10 == 0:
        print(loss.item())
```

```
AssertionError:
```
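A bare `assert` reports that the range check failed, but not by how much, so the error above is hard to act on. The hypothetical helper below (`check_range` is our own name, not a library function) could replace the assertion in `Discriminator.forward`. It reports the actual values when the check fails:

```python
import torch


def check_range(x: torch.Tensor, low: float = -1.0, high: float = 1.0, name: str = "input") -> None:
    """Raise a ValueError describing the actual range if x leaves [low, high]."""
    lo, hi = x.min().item(), x.max().item()
    if lo < low or hi > high:
        raise ValueError(
            f"{name} expected in [{low}, {high}] but got min={lo:.4f}, max={hi:.4f}"
        )


check_range(torch.tensor([-1.0, 0.0, 1.0]))  # passes silently

try:
    check_range(torch.tensor([0.0, 2.0]))
except ValueError as err:
    print(err)  # input expected in [-1.0, 1.0] but got min=0.0000, max=2.0000
```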

## Train both networks

GAN training depends on keeping the two players of the min-max game, the discriminator and the generator, in balance.
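The training loop below expresses the min-max objective through `F.binary_cross_entropy`. For a prediction p and label y, this loss computes −[y·log p + (1 − y)·log(1 − p)], which reduces to −log p when the label is 1 and to −log(1 − p) when it is 0. Minimizing BCE on real batches (label 1) and fake batches (label 0) therefore maximizes log D(x) + log(1 − D(G(z))). Likewise, the generator's loss with flipped labels maximizes log D(G(z)). The pure-Python sketch below works through that arithmetic with made-up discriminator outputs:

```python
import math


def bce(p: float, y: float) -> float:
    """Binary cross-entropy for one prediction p in (0, 1) and label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


d_real = 0.9  # illustrative D(x) on a real image
d_fake = 0.2  # illustrative D(G(z)) on a generated image

loss_d = bce(d_real, 1.0) + bce(d_fake, 0.0)  # = -log D(x) - log(1 - D(G(z)))
loss_g = bce(d_fake, 1.0)                     # = -log D(G(z)), the non-saturating generator loss

print(round(loss_d, 4), round(loss_g, 4))  # 0.3285 1.6094
```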

```python
# Training loop
# Deregister Efemarai's NoNonZeroGradientsAssertion for this loop
ef.deregister_assertions(ef.assertions.NoNonZeroGradientsAssertion)
print("Starting Training Loop...")
discriminator = Discriminator()
generator = Generator()

# Same label convention as the discriminator mini-batch above: fakes are 0, reals are 1
fake_label = 0.0
real_label = 1.0

batch_size = 10

# drop_last=True keeps every batch at exactly batch_size images,
# matching the size of the label tensors created below
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

for epoch in range(5):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ef.inspect(data[0])
        with ef.scan(i):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            discriminator.zero_grad()
            # Format batch
            label = torch.full((batch_size,), real_label, dtype=torch.float)
            # Forward pass real batch through D
            output = discriminator(data[0]).view(-1)
            # Calculate loss on all-real batch
            errD_real = F.binary_cross_entropy(output, label)

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(batch_size, nz, 1, 1)
            # Generate fake image batch with G
            fake = generator(noise)
            label = torch.full((batch_size,), fake_label, dtype=torch.float)
            # Classify all-fake batch with D (detached, so G receives no gradients here)
            output = discriminator(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = F.binary_cross_entropy(output, label)

            errD = errD_real + errD_fake
            errD.backward()

        with ef.scan(i):
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            generator.zero_grad()

            # fake labels are real for generator cost
            label = torch.full((batch_size,), real_label, dtype=torch.float)
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = discriminator(fake).view(-1)
            # Calculate G's loss based on this output
            errG = F.binary_cross_entropy(output, label)
            # Calculate gradients for G
            errG.backward()

            # Update G
            optimizerG.step()

        # Output training stats
        if i % 100 == 0:
            print("[%d/5][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f"
                  % (epoch, i, len(dataloader), errD.item(), errG.item()))

print('Done')
```

## Let's see the training loss

![Training loss](models/training_loss.svg)

![Image Comparison - Real vs Fake](models/real_vs_gen_images.svg)