# Building Complex Computational Graphs

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as tf


import efemarai as ef
ef.notebook()

## Generator

Generator takes in a random vector and should generate an image. For this we will use [transposed convolutions](https://towardsdatascience.com/transposed-convolution-demystified-84ca81b4baba). 

![Transposed Convolution](https://miro.medium.com/max/700/1*faRskFzI7GtvNCLNeCN8cg.png)

Let's see an example.

In [23]:
nz = 16
noise = torch.randn(1, nz, 1, 1)

convT = nn.ConvTranspose2d(
    in_channels=nz, 
    out_channels=64, 
    kernel_size=4, 
    stride=1, 
    padding=0,
)

print(convT(noise))

tensor([[[[ 0.0070,  0.1147, -0.0082,  0.0899],
          [-0.0530, -0.0626,  0.0393,  0.1251],
          [ 0.1381,  0.0016, -0.0353, -0.0460],
          [-0.0163,  0.0628, -0.1509, -0.1345]],

         [[-0.0332, -0.0538, -0.0272, -0.0673],
          [ 0.0603,  0.0927,  0.0434, -0.0034],
          [-0.0275,  0.1481, -0.0456,  0.1091],
          [-0.0347, -0.0176,  0.0646, -0.0299]],

         [[-0.0167,  0.0346,  0.0582, -0.0888],
          [-0.0999,  0.0241,  0.0198,  0.0375],
          [ 0.0987,  0.0226, -0.0047,  0.0071],
          [-0.0548,  0.0366,  0.0696,  0.0116]],

         ...,

         [[ 0.1380, -0.1157,  0.0782,  0.0628],
          [-0.0769, -0.0766, -0.0798,  0.0277],
          [ 0.0638, -0.1095,  0.0535, -0.1067],
          [-0.0415, -0.0228, -0.0058, -0.0490]],

         [[-0.0453,  0.0977,  0.0230,  0.1073],
          [-0.0495,  0.0333, -0.0700,  0.0720],
          [-0.0706, -0.0074, -0.0627, -0.0353],
          [-0.0118,  0.0436, -0.0879,  0.0081]],

         [[ 0.0

When working with tensors printing tensors rarely gives you useful information. How about something more visual?

In [25]:
ef.print(convT(noise))

Using Efemarai's `print()` function automatically generates a 3D visualizaton of the tensors where you can easily inspect any element or check out the values distribution with a few mouse clicks. Up to 6D tensors are supported.

![Print Tensor](imgs/ef_print.png)

Seeing the resulting tensor from a computation is useful, but what's even more useful is to explore the computation itself. 

In [25]:
with ef.scan():
    output = convT(noise)

![Graph Scan](imgs/ef_scan.png)

Now lets create our generator module starting with the first layer of transformations

In [32]:
class Generator(nn.Module):
    def __init__(self, ngf=64):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            # input is going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(),
        )


    def forward(self, input):
        return self.network(input)

and explore what's happening with the input noise passing through it

In [30]:
gen = Generator(nz)
with ef.scan(wait=False):
    output = gen(noise)

The complete generator contains 5 transposed convolution layers and outputs an image of size `(3, 64, 64)`.

In [45]:
class Generator(nn.Module):
    def __init__(self, ngf=64):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            # input is going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )


    def forward(self, input):
        return self.network(input)


Here is what the final computational graph of the generator looks like

In [37]:
gen = Generator()
with ef.scan(wait=False):
    output = gen(noise)

![Generator Graph](imgs/generator_graph.png)

You can easily confirm that
* all the layers are connected as expected
* all computations go as expected - there are no NaNs or Infs
* the input vector is correctly transformed into a 3x64x64 image

### Overfit to a small batch

Load images from the CelebA dataset and create a small batch of `(noise, image)` that we are going to overfit to in order to make sure that our generator can generate images.

In [82]:
nz = 100

dataset = datasets.ImageFolder(
    root="data",
    transform=tf.Compose([
        tf.Resize(64),
        tf.CenterCrop(64),
        tf.ToTensor(),
        tf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))

class GeneratorMiniBatch(torch.utils.data.Dataset):
    def __init__(self, size): 
        self.size = size
        self.noise = torch.randn(size, nz, 1, 1)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.noise[idx], dataset[idx][0]

generator_minibatch = GeneratorMiniBatch(50)   

dataloader = torch.utils.data.DataLoader(
    generator_minibatch, batch_size=10, shuffle=True,
)

Loop over the minibatch to train and scan the execution periodically.

In [83]:
generator = Generator()

optimizer = optim.Adam(
    generator.parameters(), lr=1e-3, 
)

iteration = 0
for epoch in range(100):
    for noise, image in dataloader:
        optimizer.zero_grad()
        
        with ef.scan(
            iteration, 
            enabled=iteration % 50 == 0, 
            wait=False,
        ):
            output = generator(noise)
            loss = (output - image).square().mean()
            loss.backward()
        
        optimizer.step()
        
        iteration += 1
        
    if epoch % 10 == 0:
        print(loss.item())

Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
0.465360552072525
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
0.12741057574748993
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
0.04939030110836029
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
0.05862098187208176
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
0.03839258477091789
Tensor of shape (512, 256, 4, 4) cannot be visualized with the current GPU.
Tensor of shape (512, 256, 4, 4) cannot be visualized with the cur

## Discriminator

In [172]:
class Discriminator(nn.Module):
    def __init__(self, ndf=64):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=False),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=False),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=False),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=False),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        assert torch.all(0.0 < input) and torch.all(input < 1.0)
        return self.network(input)

### Overfit to a small batch

In [163]:
class DiscriminatorMiniBatch(torch.utils.data.Dataset): 
    def __init__(self):
        self.size = 2 * len(generator_minibatch)
        self.fakes = generator(generator_minibatch.noise).detach()
    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        if idx < len(self.fakes):
            return self.fakes[idx], torch.zeros(1, dtype=torch.float).squeeze()
        else:        
            _, image = generator_minibatch[idx - len(self.fakes)]
            return image, torch.ones(1, dtype=torch.float).squeeze()


discriminator_minibatch = DiscriminatorMiniBatch()   

dataloader = torch.utils.data.DataLoader(
    discriminator_minibatch, batch_size=10, shuffle=True,
)

In [171]:
discriminator = Discriminator()

optimizer = optim.Adam(
    discriminator.parameters(), lr=1e-4, 
)

iteration = 0
for epoch in range(50):
    for image, label in dataloader:
        optimizer.zero_grad()
        
        output = discriminator(image).view(-1)
        loss = F.binary_cross_entropy(output, label)
        loss.backward()

        optimizer.step()
        
        iteration += 1
        
    if epoch % 10 == 0:
        print(loss.item())

AssertionError: 