In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.autograd import Variable
from utils import create_dataloader

### Creating config object (argparse workaround)

In [2]:
class Config:
    """Empty attribute container standing in for an argparse namespace."""


config = Config()
# Fill in every setting at once; each keyword becomes an attribute on `config`.
vars(config).update(
    mnist_path=None,
    batch_size=16,
    num_workers=3,
    num_epochs=10,    # number of passes of the outer training loop
    noise_size=50,    # input width of the generator's first linear layer
    print_freq=100,   # log the losses every `print_freq` iterations
)

### Create dataloader

In [3]:
# Build the batch iterator from the settings in `config`, using the project
# helper `create_dataloader` imported from `utils`.
dataloader = create_dataloader(config)

In [4]:
# Length of the loader -- presumably the number of batches per epoch
# (TODO confirm against utils.create_dataloader).
len(dataloader)

3750

In [5]:
# Take the first (image, cat) pair the loader yields and stop right away, so
# the following cells can inspect a single batch.
# NOTE(review): `cat` is presumably the class label of each image -- confirm.
for image, cat in dataloader:
    break

In [6]:
# Shape of the batch captured in the previous cell.
image.size()

torch.Size([16, 1, 28, 28])

In [7]:
# Pixel count of one 28x28 image: the flattened length the generator emits
# and the discriminator consumes.
28*28

784

### Create generator and discriminator

In [8]:
class Generator(nn.Module):
    """Two-layer fully connected generator.

    Maps a batch of noise vectors to a batch of flattened images whose values
    lie in [0, 1] (the final layer is a Sigmoid).

    Args:
        noise_size: Length of the input noise vector. ``None`` (the default)
            falls back to the notebook-level ``config.noise_size``, which is
            what the original hard-coded behaviour used.
        hidden_size: Width of the hidden layer.
        image_size: Length of the flattened output image.
    """

    def __init__(self, noise_size=None, hidden_size=200, image_size=28*28):
        super(Generator, self).__init__()
        if noise_size is None:
            # Resolve the global lazily so callers that pass an explicit size
            # do not depend on the notebook's `config` object at all.
            noise_size = config.noise_size
        # Layer order and indices are unchanged, so state_dict keys
        # (model.0.*, model.2.*) stay compatible with saved checkpoints.
        self.model = nn.Sequential(
            nn.Linear(noise_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, image_size),
            nn.Sigmoid())

    def forward(self, x):
        """Return generated flattened images for noise batch ``x``."""
        return self.model(x)
    
class Discriminator(nn.Module):
    """Three-layer fully connected discriminator.

    Maps a batch of flattened images to one score per image in [0, 1]
    (the final layer is a Sigmoid), i.e. D(x) in the GAN objective.

    Args:
        image_size: Length of a flattened input image. Defaults to 28*28;
            pass a different value for other image sizes.
        hidden_size: Width of the first hidden layer.
        bottleneck_size: Width of the second hidden layer.
    """

    def __init__(self, image_size=28*28, hidden_size=200, bottleneck_size=50):
        super(Discriminator, self).__init__()
        # Layer order and indices are unchanged, so state_dict keys
        # (model.0.*, model.2.*, model.4.*) stay compatible with checkpoints.
        self.model = nn.Sequential(
            nn.Linear(image_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, bottleneck_size),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck_size, 1),
            nn.Sigmoid())

    def forward(self, x):
        """Return the discriminator score for each flattened image in ``x``."""
        return self.model(x)

In [9]:
# Instantiate both networks with their default sizes. The generator is built
# first, so parameter initialisation draws from the RNG in the same order.
generator, discriminator = Generator(), Discriminator()

### Create optimizers and loss

In [10]:
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate.
optim_G = optim.Adam(generator.parameters(), lr=1e-4)
optim_D = optim.Adam(discriminator.parameters(), lr=1e-4)

### Create necessary variables

In [11]:
# Pre-allocated buffers for the training loop.
# They are zero-initialised: `torch.FloatTensor(sizes)` returns uninitialised
# memory that may contain NaN/Inf if a buffer is read before being filled.
# NOTE: `input` shadows the Python builtin; the name is kept because other
# cells may refer to it.
input = Variable(torch.zeros(config.batch_size, 28*28))
noise = Variable(torch.zeros(config.batch_size, config.noise_size))
# Fixed standard-normal noise batch, sampled once so the same latent vectors
# can be reused (e.g. to compare generator samples across epochs).
fixed_noise = Variable(torch.randn(config.batch_size, config.noise_size))
label = Variable(torch.zeros(config.batch_size))
# Target values for the BCE criterion.
real_label = 1
fake_label = 0

### Задание

1) Имплементируйте GAN из статьи

2) Попробуйте LSGAN https://arxiv.org/pdf/1611.04076v2.pdf

3) Попробуйте оба GAN на CelebA http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html

4) Напишите отчёт: что попробовали, какие результаты получили и как, на ваш взгляд, надо обучать GAN, чтобы добиться сходимости.

Обязательны графики.

### Main loop

In [12]:
# Standard GAN training: alternate one discriminator update and one generator
# update per batch, logging the three loss terms every `config.print_freq`
# iterations.
for epoch in range(config.num_epochs):
    for iteration, (images, cat) in enumerate(dataloader):
        # Size everything from the actual batch so a shorter final batch works.
        batch_size = images.size(0)
        # Flatten to (batch, pixels) for the fully connected discriminator.
        # NOTE(review): assumes the loader yields float images scaled to
        # [0, 1], matching the generator's Sigmoid output -- confirm in
        # utils.create_dataloader.
        real_images = images.view(batch_size, -1)
        # Targets have shape (batch, 1) to match the discriminator output;
        # float() keeps them float32 as BCELoss requires.
        real_targets = torch.full((batch_size, 1), float(real_label))
        fake_targets = torch.full((batch_size, 1), float(fake_label))

        #######
        # Discriminator stage: maximize log(D(x)) + log(1 - D(G(z)))
        #######
        discriminator.zero_grad()

        # real
        errD_x = criterion(discriminator(real_images), real_targets)
        errD_x.backward()

        # fake
        z = torch.randn(batch_size, config.noise_size)
        fake_images = generator(z)
        # detach(): the discriminator update must not backpropagate into G.
        errD_z = criterion(discriminator(fake_images.detach()), fake_targets)
        errD_z.backward()

        optim_D.step()

        #######
        # Generator stage: maximize log(D(G(z)))
        #######
        generator.zero_grad()

        # Non-saturating loss: label the fakes as real for the generator step.
        # Gradients this leaves on D are cleared by the next zero_grad().
        errG = criterion(discriminator(fake_images), real_targets)
        errG.backward()

        optim_G.step()

        if (iteration+1) % config.print_freq == 0:
            # .item() replaces .data[0], which fails on 0-dim loss tensors.
            print('Epoch:{} Iter: {} errD_x: {:.2f} errD_z: {:.2f} errG: {:.2f}'.format(epoch+1,
                                                                                            iteration+1,
                                                                                            errD_x.item(),
                                                                                            errD_z.item(),
                                                                                            errG.item()))

Epoch:1 Iter: 100 errD_x: 0.08 errD_z: 0.05 errG: 3.03
Epoch:1 Iter: 200 errD_x: 0.03 errD_z: 0.02 errG: 4.25
Epoch:1 Iter: 300 errD_x: 0.10 errD_z: 0.03 errG: 4.40
Epoch:1 Iter: 400 errD_x: 0.15 errD_z: 0.06 errG: 3.74
Epoch:1 Iter: 500 errD_x: 0.17 errD_z: 0.11 errG: 2.92
Epoch:1 Iter: 600 errD_x: 0.22 errD_z: 0.08 errG: 2.42
Epoch:1 Iter: 700 errD_x: 0.20 errD_z: 0.10 errG: 2.67
Epoch:1 Iter: 800 errD_x: 0.04 errD_z: 0.05 errG: 3.00
Epoch:1 Iter: 900 errD_x: 0.23 errD_z: 0.05 errG: 3.07
Epoch:1 Iter: 1000 errD_x: 0.04 errD_z: 0.06 errG: 2.64
Epoch:1 Iter: 1100 errD_x: 0.05 errD_z: 0.06 errG: 3.02


Process Process-5:
Process Process-4:
Process Process-6:
Traceback (most recent call last):
  File "/Users/mickle/anaconda3/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/Users/mickle/anaconda3/lib/python3.5/multiprocessing/synchronize.py", line 96, in __enter__
    return self._semlock.__enter__()
  File "/Users/mickle/anaconda3/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/mickle/anaconda3/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 28, in _worker_loop
    r = index_queue.get()
  File "/Users/mickle/anaconda3/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/Users/mickle/anaconda3/lib/python3.5/multiprocessing/queues.py", line 342, in get
    with self._rlock:
  File "/Users/mickle/anaconda3/lib/python3.5/multiprocessing/process.py", lin

KeyboardInterrupt: 