#### BASIC GAN IMPLEMENTATION

In [93]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from scipy.misc import imshow
import numpy as np
import matplotlib.pyplot as plt

# change this if needed
use_cuda = True

device = torch.device("cuda" if use_cuda else "cpu")

In [94]:
# torch.randn?

Discriminator:


In [95]:
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 1)

    def forward(self, x, verbose=False):
        if verbose: print("Discriminator")
        if verbose: print(x.shape)
        x = F.relu(self.conv1(x))
        if verbose: print(x.shape)
        x = F.max_pool2d(x, 2, 2)
        if verbose: print(x.shape)
        x = F.relu(self.conv2(x))
        if verbose: print(x.shape)
        x = F.max_pool2d(x, 2, 2)
        if verbose: print(x.shape)
        x = x.view(-1, 4*4*50)
        if verbose: print(x.shape)
        x = F.relu(self.fc1(x))
        if verbose: print(x.shape)
        x = self.fc2(x)
        if verbose: print(x.shape)
        return torch.sigmoid(x)
    
# Let's test your decoder
n_components = 2
discriminator = Discriminator()
y = discriminator(torch.randn(1, 1, 28, 28), verbose=True)

Discriminator
torch.Size([1, 1, 28, 28])
torch.Size([1, 20, 24, 24])
torch.Size([1, 20, 12, 12])
torch.Size([1, 50, 8, 8])
torch.Size([1, 50, 4, 4])
torch.Size([1, 800])
torch.Size([1, 500])
torch.Size([1, 1])


Generator:

In [96]:
class Generator(nn.Module):
    # YOUR CODE HERE
    def __init__(self, n_components=2):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(n_components, 250)
        self.fc2 = nn.Linear(250, 14*14*16)
        self.conv1 = nn.ConvTranspose2d(in_channels=16, out_channels=6, kernel_size=5, padding=2)
        self.conv2 = nn.ConvTranspose2d(in_channels=6, out_channels=1, kernel_size=5, padding=2, stride=2, output_padding=1)
        
    def forward(self, x, verbose=False):
        if verbose: print('Decoder')
        if verbose: print(x.shape)
        x = F.relu(self.fc1(x))
        if verbose: print(x.shape)
        x = F.relu(self.fc2(x))
        if verbose: print(x.shape)
        x = x.view((-1, 16, 14, 14))
        if verbose: print(x.shape)
        x = F.relu(self.conv1(x))
        if verbose: print(x.shape)
        x = self.conv2(x)
        if verbose: print(x.shape)
    
        return x
    
# Let's test your Generator
n_components = 2
decoder_test = Generator(n_components)
y = decoder_test(torch.randn(1, n_components), verbose=True)
assert y.shape == torch.Size([1, 1, 28, 28]), "Bad shape of y: y.shape={}".format(y.shape)
print("The shapes seem to be ok.")

Decoder
torch.Size([1, 2])
torch.Size([1, 250])
torch.Size([1, 3136])
torch.Size([1, 16, 14, 14])
torch.Size([1, 6, 14, 14])
torch.Size([1, 1, 28, 28])
The shapes seem to be ok.


In [97]:
# image_real = np.arange(900).reshape((30, 30))



# imshow(y.detach().numpy().reshape((28, 28)))

# ax.imshow()


In [98]:
image_neg = 100+image_real
fig.set_data(image_neg)
plt.show(block=False)

In [114]:
targets_real = torch.ones(64, 1).to(device)  # Targets for discriminator: real data
targets_fake = torch.zeros(64, 1).to(device)  # Targets for discriminator: fake data
criterion = nn.BCELoss()
latent_size = 2

def train(D, G, train_loader, epoch, optimizer_d, optimizer_g, debug=False):
    # Initialize train mode
    D.train()
    G.train()
    # train on batches
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device) #, target.to(device)
#         data shape:  torch.Size([64, 1, 28, 28])
        print(batch_idx)
        optimizer_d.zero_grad()
        optimizer_g.zero_grad()
        n_samples = len(data)
        if debug: print('data shape: ', data.shape)
        # update the discriminator
        for i in range(10):
            z_d = torch.randn(n_samples, latent_size).to(device)
            data_fake = G(z_d)
            # fake outputs
            outputs_fake = D(data_fake)
            # real outputs
            outputs_real = D(data)
            if debug: print('outputs fake:', outputs_fake.shape)
            if debug: print('outputs real_data', outputs_real.shape)
            outputs_real = D(data)
            d_loss_fake = criterion(outputs_fake, targets_fake)
            d_loss_real = criterion(outputs_real, targets_real)
            combined_d_loss = d_loss_fake + d_loss_real
            combined_d_loss.backward()
            optimizer_d.step()
        
        # Train the generator
        # Compute loss with fake data
        z = torch.randn(n_samples, latent_size).to(device)
        data_fake = G(z)
        outputs = D(data_fake)
        g_loss = criterion(outputs, targets_real)
        g_loss.backward()
        optimizer_g.step()
        
def main():
    seed = 1
    torch.manual_seed(1)
    
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=64, shuffle=True, **kwargs)
    
    D = Discriminator().to(device)
    G = Generator(2).to(device)
    optimizer_d = optim.Adam(D.parameters(), lr=0.0002)
    optimizer_g = optim.Adam(G.parameters(), lr=0.0002)
    for epoch in range(1, 1000 + 1):
        print(1)
        train(D, G, train_loader, epoch, optimizer_d, optimizer_g)
        # TODO visualize generated images by the Generator
        if epoch % 50 == 0:
            n_components = 2
            decoder_test = Generator(n_components)
            y = decoder_test(torch.randn(1, n_components), verbose=True)
            fig = plt.imshow(y.detach().numpy().reshape((28, 28)), cmap='gray')
            plt.show()
        print('epoch', epoch, 'finished')
    if (True):
        torch.save(model.state_dict(),"mnist_cnn.pt")
main()

1
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37


Process Process-16:
  File "/u/53/aftenim1/unix/.conda/envs/capsenv/lib/python3.6/selectors.py", line 376, in select
    fd_event_list = self._poll.poll(timeout)
Traceback (most recent call last):
  File "/u/53/aftenim1/unix/.conda/envs/capsenv/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/u/53/aftenim1/unix/.conda/envs/capsenv/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/u/53/aftenim1/unix/.conda/envs/capsenv/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/u/53/aftenim1/unix/.conda/envs/capsenv/lib/python3.6/multiprocessing/queues.py", line 104, in get
    if not self._poll(timeout):
  File "/u/53/aftenim1/unix/.conda/envs/capsenv/lib/python3.6/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/u/53/aftenim1/unix/.conda/envs/c

KeyboardInterrupt: 