<a href="https://colab.research.google.com/github/borundev/pytorch_examples/blob/kaggle-dataset/GanLogic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GAN Logic

This notebook contains my notes on the logic of GAN training, stripped of all the complexity that actually makes a GAN work, so don't expect useful output.

In [2]:
!pip install -Uqqq pytorch_lightning
!pip install -Uqqq wandb


In [3]:
import torch
from torch import nn
from torch.optim import Adam
import numpy as np
import pytorch_lightning as pl
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [4]:
# A class that registers forward and backward hooks on a module and removes them
# when the object is destroyed. I use it to see when a module's forward and backward are called.
# Note: recent PyTorch versions deprecate register_backward_hook in favour of
# register_full_backward_hook; the outputs below were produced with register_backward_hook.

class Hook():
    def __init__(self,name,m):
        self.name=name
        self.m=m
        self.ref_fw=m.register_forward_hook(self.forward_hook)
        self.ref_bk=m.register_backward_hook(self.backward_hook)

    def remove(self):
        if hasattr(self,'ref_fw'):
            self.ref_fw.remove()
            self.ref_bk.remove()

    def forward_hook(self,*args,**kwargs):
        print('forward called on',self.name)

    def backward_hook(self,*args,**kwargs):
        print('backward called on',self.name)

    def __del__(self,*args,**kwargs):
        self.remove()

In [5]:
# Two helpers that count the non-zero entries in the grads of the generator's and the
# discriminator's parameters, i.e. whether a backward call has populated those grads.
# When a grad is None, `p.grad != 0` is a plain bool, torch.sum raises a TypeError
# and the helper falls back to tensor([0]).

def gen_params():
    r=torch.tensor([0])
    try:
        r=torch.tensor([torch.sum(p.grad != 0) for p in gen.parameters()]).sum()
    except TypeError:
        pass
    return r

def disc_params():
    r=torch.tensor([0])
    try:
        r=torch.tensor([torch.sum(p.grad != 0) for p in disc.parameters()]).sum()
    except TypeError:
        pass
    return r

In [8]:
# Callback whose hooks run right after the backward call(s) and right before zero_grad is called

class NewCallback(pl.callbacks.Callback):

    def on_after_backward(self, trainer, pl_module):
        print('after_backward generator',gen_params())
        print('after_backward discriminator',disc_params())
        print('-' * 20)

    def on_before_zero_grad(self, trainer, pl_module, optimizer):
        #print('before_zero_grad_generator', optimizer.name,gen_params())
        #print('before_zero_grad_discriminator', optimizer.name,disc_params())
        #print('*' * 20)
        return


In [9]:
class G(nn.Module):
    def __init__(self):
        super().__init__()
        self.W=torch.nn.Linear(1,1,bias=False)


    def forward(self,x):
        return self.W(x)

class D(nn.Module):
    def __init__(self):
        super().__init__()
        self.W=torch.nn.Linear(1,1,bias=False)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self,x):
        return self.sigmoid(self.W(x))

In [10]:
gen=G()
disc=D()

In [12]:
h_gen=Hook('gen',gen.W)
h_disc=Hook('disc',disc.W)

In [13]:
X=torch.rand((10,1))
y=torch.tensor(np.random.choice([0,1],10))
data=DataLoader(list(zip(X,y)),batch_size=2)

In [14]:
class Gan(pl.LightningModule):

    def __init__(self, gen, disc,
                 latent_dim: int = 1,
                 **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.gen = gen
        self.disc = disc

    def forward(self, x):
        return self.gen(x)

    def training_step(self, batch, batch_idx, optimizer_idx):

        if optimizer_idx == 0:
            r = self.training_step_generator(batch)
        elif optimizer_idx == 1:
            r = self.training_step_discriminator(batch)
        return r

    def training_step_generator(self, batch):

        xs, _ = batch
        batch_size = xs.shape[0]

        # sample the latent noise directly on the module's device
        z = torch.randn(batch_size, self.hparams.latent_dim, device=self.device)
        generated_xs = self(z)
        generated_y_score = self.disc(generated_xs)
        generated_y = torch.ones(batch_size, 1, device=self.device)
        g_loss = self.adversarial_loss(generated_y_score, generated_y)

        self.log('generator/g_loss', g_loss, prog_bar=True)

        return {'loss': g_loss}

    def training_step_discriminator(self, batch):

        xs, _ = batch
        batch_size = xs.shape[0]

        # sample the latent noise directly on the module's device
        z = torch.randn(batch_size, self.hparams.latent_dim, device=self.device)
        generated_xs = self(z)
        generated_y_score = self.disc(generated_xs)
        generated_y = torch.zeros(batch_size, 1, device=self.device)
        fake_loss = self.adversarial_loss(generated_y_score, generated_y)

        real_y_score = self.disc(xs)
        real_y = torch.ones(batch_size, 1, device=self.device)
        r_loss = self.adversarial_loss(real_y_score, real_y)

        d_loss = (r_loss + fake_loss) / 2.0

        self.log('discriminator/d_loss', d_loss, prog_bar=True)

        return {'loss': d_loss}

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.gen.parameters())
        opt_g.name = 'G'
        opt_d = torch.optim.Adam(self.disc.parameters())
        opt_d.name = 'D'
        return [opt_g, opt_d], []


I got around to looking at this because GANs train the generator and the discriminator in alternation. Let's look at the generator first:

1) We create some random noise and pass it through the generator, so the generator's parameters become leaf nodes of the autograd graph. We then pass the result through the discriminator, so the discriminator's parameters become leaf nodes of the same graph. When we call backward on the loss, both sets of leaf nodes get their grad tensors populated (unless their `requires_grad` is set to False). In this step the generator's optimizer uses the generator's grads to update its weights, and its `zero_grad` resets those grads for the next round. The discriminator's grads, however, are left holding non-zero values (unless `requires_grad` was set to False for them).
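A minimal sketch of step (1) outside Lightning, assuming plain PyTorch and two hypothetical one-weight modules `toy_gen` and `toy_disc` (named so they don't clobber the notebook's `gen` and `disc`). It shows that a single backward call populates both networks' grads, while the generator optimizer's `step`/`zero_grad` only touch the generator's parameters.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical 1-parameter stand-ins for the notebook's G and D
toy_gen = nn.Linear(1, 1, bias=False)
toy_disc = nn.Sequential(nn.Linear(1, 1, bias=False), nn.Sigmoid())
opt_g = torch.optim.Adam(toy_gen.parameters())

# Generator step: score the fakes with the discriminator and push them towards "real" (1)
z = torch.randn(4, 1)
g_loss = F.binary_cross_entropy(toy_disc(toy_gen(z)), torch.ones(4, 1))
g_loss.backward()

# One backward call populates the grads of BOTH networks
gen_has_grad = toy_gen.weight.grad is not None
disc_has_grad = toy_disc[0].weight.grad is not None

opt_g.step()       # updates only the generator's weight
opt_g.zero_grad()  # clears only the generator's grad

print(gen_has_grad, disc_has_grad)  # True True
print(toy_gen.weight.grad)          # None (a zero tensor on older PyTorch versions)
print(toy_disc[0].weight.grad)      # still populated: a leftover "dangling" grad
```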

Now let's look at the discriminator. 

2) We create some random noise and pass it through the generator, which makes the generator's parameters leaf nodes of the graph. We then pass the result through the discriminator, making the discriminator's parameters leaf nodes too. Separately, we pass real images through the discriminator, which adds a second branch to the graph with the discriminator's parameters as leaf nodes again. When we call backward on the combined loss of both branches, the grads of both the generator's and the discriminator's parameters get populated. If in step (1) we had neither set `requires_grad=False` for the discriminator nor reset its grads to zero, the new grads would accumulate on top of the leftover ones and the discriminator's optimizer would step with the wrong grads. The same applies to the generator's grads computed in this step: unless they are cleared, they accumulate into the next generator update.
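The accumulation problem in step (2) can be sketched the same way, again with hypothetical `toy_gen`/`toy_disc` modules and fixed noise `z` so every pass sees the same inputs. The sketch compares the discriminator's grad computed from a clean slate with the grad obtained when a generator step left its discriminator grads behind, and then shows freezing the discriminator during the generator step as one way to avoid the leftover.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
toy_gen = nn.Linear(1, 1, bias=False)
toy_disc = nn.Sequential(nn.Linear(1, 1, bias=False), nn.Sigmoid())
real = torch.rand(4, 1)   # stand-in for a batch of real data
z = torch.randn(4, 1)     # fixed noise so every pass sees the same inputs

def generator_loss():
    return F.binary_cross_entropy(toy_disc(toy_gen(z)), torch.ones(4, 1))

def discriminator_loss():
    fake_loss = F.binary_cross_entropy(toy_disc(toy_gen(z)), torch.zeros(4, 1))
    real_loss = F.binary_cross_entropy(toy_disc(real), torch.ones(4, 1))
    return (fake_loss + real_loss) / 2

def clear_grads():
    toy_gen.zero_grad(set_to_none=True)
    toy_disc.zero_grad(set_to_none=True)

# Reference: the discriminator's grad computed from a clean slate
discriminator_loss().backward()
clean = toy_disc[0].weight.grad.clone()
clear_grads()

# Careless version: the generator step leaves a grad on the discriminator ...
generator_loss().backward()
leftover = toy_disc[0].weight.grad.clone()
# ... and the discriminator step adds its own grad on top of it
discriminator_loss().backward()
polluted = toy_disc[0].weight.grad.clone()
print(torch.allclose(polluted, clean + leftover))  # True: the two grads were summed
clear_grads()

# One fix: freeze the discriminator during the generator step
toy_disc.requires_grad_(False)
generator_loss().backward()
toy_disc.requires_grad_(True)
print(toy_disc[0].weight.grad)  # None: nothing left over to accumulate
```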

We see this in action below.

In [15]:
batch=next(iter(data))

In [16]:
gan=Gan(gen,disc)

In [17]:
l=gan.training_step(batch,0,0)
l['loss'].backward()
gen_params(),disc_params()

forward called on gen
forward called on disc
backward called on disc
backward called on gen


(tensor(1), tensor(1))

We see a forward call on gen and disc, followed by backward calls on disc and gen, and both networks have their parameters' grads populated.

In [18]:
gan.zero_grad()

In [19]:
l=gan.training_step(batch,0,1)
l['loss'].backward()
gen_params(),disc_params()

forward called on gen
forward called on disc
forward called on disc
backward called on disc
backward called on disc
backward called on gen


(tensor(1), tensor(1))

We see a forward call on gen and two forward calls on disc, followed by the backward calls in reverse order. Both networks have their parameters' grads populated.
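The two forward calls on disc belong to two branches of one graph, so the discriminator's grad is just the average of what each branch contributes on its own. A small sketch checking this, assuming a hypothetical `toy_disc` module and random stand-ins for the fake and real batches:

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
toy_disc = nn.Sequential(nn.Linear(1, 1, bias=False), nn.Sigmoid())
fake = torch.rand(4, 1)  # stand-in for the generator's output
real = torch.rand(4, 1)  # stand-in for real data

def branch_grad(x, target):
    """Grad of the discriminator weight from a single branch, starting from a clean slate."""
    toy_disc.zero_grad(set_to_none=True)
    F.binary_cross_entropy(toy_disc(x), target).backward()
    return toy_disc[0].weight.grad.clone()

g_fake = branch_grad(fake, torch.zeros(4, 1))
g_real = branch_grad(real, torch.ones(4, 1))

# Both forward passes in one graph, then a single backward on the averaged loss
toy_disc.zero_grad(set_to_none=True)
d_loss = (F.binary_cross_entropy(toy_disc(fake), torch.zeros(4, 1))
          + F.binary_cross_entropy(toy_disc(real), torch.ones(4, 1))) / 2
d_loss.backward()
combined = toy_disc[0].weight.grad

print(torch.allclose(combined, (g_fake + g_real) / 2))  # True
```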

In [20]:
gan.zero_grad()

Now we run the PyTorch Lightning trainer.

In [21]:
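# Note: this notebook uses the pre-2.0 PyTorch Lightning API
# (gpus= and progress_bar_refresh_rate= in the Trainer, optimizer_idx in training_step).
trainer = pl.Trainer(gpus=0,
                     max_epochs=1,
                     limit_train_batches=1,
                     limit_val_batches=0,
                     progress_bar_refresh_rate=50,
                     callbacks=[NewCallback()]
                     )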
trainer.fit(gan,data)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name | Type | Params
------------------------------
0 | gen  | G    | 1     
1 | disc | D    | 1     
------------------------------
2         Trainable params
0         Non-trainable params
2         Total params



forward called on gen
forward called on disc
backward called on disc
backward called on gen
after_backward generator tensor(1)
after_backward discriminator tensor(0)
--------------------
forward called on gen
forward called on disc
forward called on disc
backward called on disc
backward called on disc
after_backward generator tensor(0)
after_backward discriminator tensor(1)
--------------------



1

We see that on the generator's `training_step` we have a forward call on gen and disc followed by backward calls in reverse order. However, only the generator's parameters have their grads populated.

On the discriminator's `training_step` we have a forward call on gen and two on disc but only two backward calls on disc and none on gen. Only the discriminator's parameters have their grad tensor populated.
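As far as I understand, PyTorch Lightning 1.x gets this behaviour from `LightningModule.toggle_optimizer`, which sets `requires_grad=False` on the parameters that don't belong to the optimizer currently being stepped. A rough sketch of the same idea as a context manager; `only_training` is a hypothetical helper, not a Lightning API, and `toy_gen`/`toy_disc` are hypothetical stand-ins:

```python
import contextlib

import torch
from torch import nn
import torch.nn.functional as F


@contextlib.contextmanager
def only_training(active, *others):
    """Hypothetical helper: freeze `others`, keep `active` trainable, restore flags on exit."""
    saved = [(p, p.requires_grad) for m in (active, *others) for p in m.parameters()]
    try:
        for m in others:
            m.requires_grad_(False)
        active.requires_grad_(True)
        yield
    finally:
        for p, flag in saved:
            p.requires_grad_(flag)


torch.manual_seed(0)
toy_gen = nn.Linear(1, 1, bias=False)
toy_disc = nn.Sequential(nn.Linear(1, 1, bias=False), nn.Sigmoid())
z = torch.randn(4, 1)

# Generator step: only the generator's parameters collect grads
with only_training(toy_gen, toy_disc):
    g_loss = F.binary_cross_entropy(toy_disc(toy_gen(z)), torch.ones(4, 1))
    g_loss.backward()

print(toy_gen.weight.grad is not None)   # True
print(toy_disc[0].weight.grad is None)   # True
print(toy_disc[0].weight.requires_grad)  # True again after the block
```

Restoring the saved flags in `finally` plays the role of Lightning's `untoggle_optimizer`, so the next step starts from the original `requires_grad` settings.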

The parameters' grads are populated correctly, but it is instructive to look at whose backward hooks are called. We reproduce this manually below.

In [22]:
gan.zero_grad()

In [25]:
o1,o2 = gan.optimizers()

In [26]:
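# Mimic Lightning's toggle_optimizer for the generator's optimizer:
# freeze every parameter, then unfreeze only the generator's
o1.param_groups[0]['params'][0].requires_grad_(False)
o2.param_groups[0]['params'][0].requires_grad_(False)
o1.param_groups[0]['params'][0].requires_grad_(True)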
l=gan.training_step(batch,0,0)
l['loss'].backward()
print(gen_params(),disc_params())
o1.zero_grad()

forward called on gen
forward called on disc
backward called on disc
backward called on gen
tensor(1) tensor(0)


In [27]:
print(gen_params(),disc_params())

tensor(0) tensor(0)


In [28]:
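# Mimic Lightning's toggle_optimizer for the discriminator's optimizer:
# freeze every parameter, then unfreeze only the discriminator's
o1.param_groups[0]['params'][0].requires_grad_(False)
o2.param_groups[0]['params'][0].requires_grad_(False)
o2.param_groups[0]['params'][0].requires_grad_(True)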
l=gan.training_step(batch,0,1)
l['loss'].backward()
print(gen_params(),disc_params())


forward called on gen
forward called on disc
forward called on disc
backward called on disc
backward called on disc
tensor(0) tensor(1)


So in the generator's forward and backward, even though we set the discriminator's `requires_grad=False`, its backward hook still fired, because the gradients had to flow through the discriminator to reach the generator's parameters.
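A minimal check of this, assuming a hypothetical frozen linear layer standing in for the discriminator and an input that requires grad standing in for the generator's output. It uses `register_full_backward_hook`, the non-deprecated counterpart of the `register_backward_hook` used in the `Hook` class:

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical frozen layer playing the role of the frozen discriminator
frozen = nn.Linear(1, 1, bias=False).requires_grad_(False)
calls = []
frozen.register_full_backward_hook(lambda module, grad_in, grad_out: calls.append('frozen'))

# An input that requires grad, playing the role of the generator's output
x = torch.randn(3, 1, requires_grad=True)
frozen(x).sum().backward()

print(calls)               # ['frozen']: the hook fires although the weight is frozen
print(frozen.weight.grad)  # None: no grad is accumulated into the frozen weight
print(x.grad)              # every row equals the frozen weight, since d(x * w)/dx = w
```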

During the discriminator's forward and backward, however, the generator's parameters had `requires_grad` set to False and the input noise does not require grad either, so autograd never recorded the generator's operation in the graph. No backward was called on the generator because it was not needed to compute anyone's gradients.
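A sketch of why the generator drops out of the graph, with hypothetical `toy_gen`/`toy_disc` modules. Freezing the generator (as above) means its output has no `grad_fn`. Detaching the fake batch, a common alternative that is not what this notebook does, cuts the graph in the same place while leaving the generator's parameters trainable:

```python
import torch
from torch import nn

torch.manual_seed(0)
toy_gen = nn.Linear(1, 1, bias=False)
toy_disc = nn.Linear(1, 1, bias=False)
z = torch.randn(4, 1)  # latent noise never requires grad

# Option 1: freeze the generator, as in the cell above
toy_gen.requires_grad_(False)
fake = toy_gen(z)
print(fake.requires_grad, fake.grad_fn)  # False None: no graph through the generator
toy_disc(fake).sum().backward()
toy_gen.requires_grad_(True)

# Option 2 (an alternative, not the notebook's approach): detach the fake batch
fake2 = toy_gen(z).detach()
toy_disc(fake2).sum().backward()

print(toy_gen.weight.grad)               # None: backprop never reached the generator
print(toy_disc.weight.grad is not None)  # True: the discriminator still gets its grads
```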