Skip to content

Commit

Permalink
Add section on GAN (#115)
Browse files Browse the repository at this point in the history
* Create gan.py

Add the Generator and Discriminator

* add train loop

some of this was taken from https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

* Update architectures.rst
  • Loading branch information
devksingh4 committed Dec 24, 2020
1 parent 64e3cff commit c1e22a0
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 4 deletions.
76 changes: 76 additions & 0 deletions code/gan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn

class Generator(nn.Module):
def __init__(self):
super()
self.net = nn.Sequential(
nn.ConvTranspose2d( 200, 32 * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(32 * 8),
nn.ReLU(),
nn.ConvTranspose2d(32 * 8, 32 * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(32 * 4),
nn.ReLU(),
nn.ConvTranspose2d( 32 * 4, 32 * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(32 * 2),
nn.ReLU(),
nn.ConvTranspose2d( 32 * 2, 32, 4, 2, 1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.ConvTranspose2d( 32, 1, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, tens):
return self.net(tens)

class Discriminator(nn.Module):
def __init__(self):
super()
self.net = nn.Sequential(
nn.Conv2d(1, 32, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2),
nn.Conv2d(32, 32 * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(32 * 2),
nn.LeakyReLU(0.2),
nn.Conv2d(32 * 2, 32 * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(32 * 4),
nn.LeakyReLU(0.2),
# state size. (32*4) x 8 x 8
nn.Conv2d(32 * 4, 32 * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(32 * 8),
nn.LeakyReLU(0.2),
# state size. (32*8) x 4 x 4
nn.Conv2d(32 * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)

def forward(self, tens):
return self.net(tens)

def train(netD, netG, loader, loss_func, optimizerD, optimizerG, num_epochs):
netD.train()
netG.train()
device = "cuda:0" if torch.cuda.is_available() else "cpu"
for epoch in range(num_epochs):
for i, data in enumerate(loader, 0):
netD.zero_grad()
realtens = data[0].to(device)
b_size = realtens.size(0)
label = torch.full((b_size,), 1, dtype=torch.float, device=device) # gen labels
output = netD(realtens)
errD_real = loss_func(output, label)
errD_real.backward() # backprop discriminator fake and real based on label
noise = torch.randn(b_size, 200, 1, 1, device=device)
fake = netG(noise)
label.fill_(0)
output = netD(fake.detach()).view(-1)
errD_fake = loss_func(output, label)
errD_fake.backward() # backprop discriminator fake and real based on label
errD = errD_real + errD_fake # discriminator error
optimizerD.step()
netG.zero_grad()
label.fill_(1)
output = netD(fake).view(-1)
errG = loss_func(output, label) # generator error
errG.backward()
optimizerG.step()
23 changes: 19 additions & 4 deletions docs/architectures.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,40 @@ An example implementation in PyTorch.

GAN
===
A Generative Adversarial Network (GAN) is a type of network which creates novel tensors (often images, voices, etc.). The generative portion of the architecture competes with the discriminator part of the architecture in a zero-sum game. The goal of the generative network is to create novel tensors which the adversarial network attempts to classify as real or fake. The goal of the generative network is generate tensors where the discriminator network determines that the tensor has a 50% chance of being fake and a 50% chance of being real.

TODO: Description of GAN use case and basic architecture. Figure from [3].
Figure from [3].

.. image:: images/gan.png
:align: center

.. rubric:: Model

TODO: An example implementation in PyTorch.
An example implementation in PyTorch.


.. rubric:: Generator

.. literalinclude:: ../code/gan.py
:pyobject: Generator

.. rubric:: Discriminator

.. literalinclude:: ../code/gan.py
:pyobject: Discriminator


.. rubric:: Training

TODO
.. literalinclude:: ../code/gan.py
:pyobject: train

.. rubric:: Further reading

- `Generative Adversarial Networks <http://guertl.me/post/162759264070/generative-adversarial-networks>`_
- `Deep Learning Book <http://www.deeplearningbook.org/contents/generative_models.html>`_

- `PyTorch DCGAN Example <https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html>`_
- `Original Paper <https://papers.nips.cc/paper/2014/file/5ca3e9b122f61f8f06494c97b1afccf3-Paper.pdf>`_

MLP
===
Expand Down

0 comments on commit c1e22a0

Please sign in to comment.