### WGAN

* Модифицируйте код ячеек ниже и реализуйте [Wasserstein GAN](https://arxiv.org/abs/1701.07875) с клиппингом весов. (10 баллов)

* Замените клиппинг весов на [штраф градентов](https://arxiv.org/pdf/1704.00028v3.pdf). (10 баллов)

* Добавьте лейблы в WGAN, тем самым решая задачу [условной генерации](https://arxiv.org/pdf/1411.1784.pdf). (30 баллов)

Добавьте в этот файл анализ полученных результатов с различными графиками обучения и визуализацию генерации. Сравните как работает клиппинг весов и штраф градиентов и попробуйте пронаблюдать какие недостатки имеет модель GAN.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision
import matplotlib.pyplot as plt
import numpy as np

from torch.autograd import Variable

### Простой конфиг (для хранения параметров, можете использовать и модифицировать)

In [2]:
class Config:
    pass

CONFIG = Config()
CONFIG.mnist_path = None
CONFIG.batch_size = 16
CONFIG.num_workers = 3
CONFIG.num_epochs = 10
CONFIG.noise_size = 50
CONFIG.print_freq = 500
CONFIG.learning_rate = 0.0002

### Создаем dataloader

In [3]:
dataset = torchvision.datasets.FashionMNIST("fashion_mnist", train=True, transform=torchvision.transforms.ToTensor(), download=True)

In [4]:
dataloader = DataLoader(dataset, batch_size=CONFIG.batch_size, shuffle=True)
len(dataloader)

3750

In [5]:
image, label = next(iter(dataloader))
image.size()

torch.Size([16, 1, 28, 28])

### Создаем модель GAN

In [6]:
class Generator(nn.Module):
    def __init__(self):
        # super(Generator, self).__init__()
        # self.model = nn.Sequential(
        #     # config.noise_size = 50
        #     nn.Linear(CONFIG.noise_size, 200),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(200, 28*28),
        #     nn.Sigmoid()
        # )

        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 784),
            nn.Tanh()
        )

    def forward(self, x):
        out = self.model(x)
        out = out.view(x.size(0), 1, 28, 28)
        return out


class Discriminator(nn.Module):
    def __init__(self):
        # super(Discriminator, self).__init__()
        # self.model = nn.Sequential(
        #     nn.Linear(28*28, 200),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(200, 50),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(50, 1),
        #     nn.Sigmoid())

        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        out = self.model(x)
        return out

In [7]:
generator = Generator()
discriminator = Discriminator()

In [8]:
generator

Generator(
  (model): Sequential(
    (0): Linear(in_features=100, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=512, out_features=784, bias=True)
    (5): Tanh()
  )
)

In [9]:
discriminator

Discriminator(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=256, out_features=1, bias=True)
  )
)

### Оптимизатор и функция потерь

In [10]:
# learning_rate = 0.0002
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# print(device)

# generator = generator.to(device)
# discriminator = discriminator.to(device)
# optim_G = optim.Adam(params=generator.parameters(), lr=learning_rate)
# optim_D = optim.Adam(params=discriminator.parameters(), lr=learning_rate)

# criterion = nn.BCELoss()

### Для оптимизации процесса обучения можно заранее определить переменные и заполнять их значения новыми данными

In [11]:
# # noise = Variable(torch.FloatTensor(config.batch_size, config.noise_size, device=device))
# # # fixed_noise = Variable(torch.FloatTensor(config.batch_size, config.noise_size, device=device).normal_(0, 1))
# # label = Variable(torch.FloatTensor(config.batch_size, device=device))

# noise_x = torch.empty(CONFIG.batch_size)
# print(noise_x)
# # noise must be 50 x 200
# # noise = torch.tensor(noise_x, dtype=torch.float, device=torch.device(device))

# # DEFAULT LINE
# noise = Variable(torch.FloatTensor(CONFIG.batch_size, CONFIG.noise_size, device=device))
# # fixed_noise = Variable(torch.FloatTensor(config.batch_size, config.noise_size, device=device).normal_(0, 1))

# noise_y = torch.empty(CONFIG.batch_size)
# # label = torch.tensor(noise_y, dtype=torch.float, device=torch.device(device))

# # DEFAULT LINE
# label = Variable(torch.FloatTensor(CONFIG.batch_size, device=device))


# real_label = 1
# fake_label = 0

### GAN обучение

In [12]:
# ERRD_x = np.zeros(CONFIG.num_epochs)
# ERRD_z = np.zeros(CONFIG.num_epochs)
# ERRG = np.zeros(CONFIG.num_epochs)
# N = len(dataloader)

# for epoch in range(CONFIG.num_epochs):
#     for iteration, (images, cat) in enumerate(dataloader):
#         #######
#         # Discriminator stage: maximize log(D(x)) + log(1 - D(G(z)))
#         #######
#         discriminator.zero_grad()

#         # real
#         label.data.fill_(real_label)
#         input_data = images.view(images.shape[0], -1).to(device)
#         output = discriminator(input_data).view(-1)
#         errD_x = criterion(output, label)
#         ERRD_x[epoch] += errD_x.item()
#         errD_x.backward()

#         # fake
#         noise.data.normal_(0, 1)
#         fake = generator(noise)
#         label.data.fill_(fake_label)
#         output = discriminator(fake.detach()).view(-1)
#         errD_z = criterion(output, label)
#         ERRD_z[epoch] += errD_z.item()
#         errD_z.backward()

#         optim_D.step()

#         #######
#         # Generator stage: maximize log(D(G(x))
#         #######
#         generator.zero_grad()
#         label.data.fill_(real_label)
#         output = discriminator(fake).view(-1)
#         errG = criterion(output, label)
#         ERRG[epoch] += errG.item()
#         errG.backward()

#         optim_G.step()

#         if (iteration+1) % CONFIG.print_freq == 0:
#             print('Epoch:{} Iter: {} errD_x: {:.2f} errD_z: {:.2f} errG: {:.2f}'.format(epoch+1,
#                                                                                             iteration+1,
#                                                                                             errD_x.item(),
#                                                                                             errD_z.item(),
#                                                                                             errG.item()))

In [13]:
def train(discriminator, generator, dataloader, num_epochs, lr):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    discriminator.to(device)
    generator.to(device)

    criterion = nn.BCEWithLogitsLoss()

    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
    g_optimizer = optim.Adam(generator.parameters(), lr=lr)

    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # Train the discriminator
            discriminator.zero_grad()
            real_outputs = discriminator(real_images)
            d_loss_real = criterion(real_outputs, real_labels)

            z = torch.randn(batch_size, 100).to(device)
            fake_images = generator(z)
            fake_outputs = discriminator(fake_images.detach())
            d_loss_fake = criterion(fake_outputs, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()

            # Train the generator
            generator.zero_grad()
            z = torch.randn(batch_size, 100).to(device)
            fake_images = generator(z)
            fake_outputs = discriminator(fake_images)
            g_loss = criterion(fake_outputs, real_labels)

            g_loss.backward()
            g_optimizer.step()

            if (i+1) % 200 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}")

In [14]:
torch.manual_seed(42)

<torch._C.Generator at 0x7b141814d490>

In [None]:
train(discriminator, generator, dataloader, num_epochs=CONFIG.num_epochs, lr=CONFIG.learning_rate)

Epoch [1/10], Step [200/3750], d_loss: 0.1925, g_loss: 3.4698
Epoch [1/10], Step [400/3750], d_loss: 0.3738, g_loss: 3.3591
Epoch [1/10], Step [600/3750], d_loss: 1.0137, g_loss: 2.4955
Epoch [1/10], Step [800/3750], d_loss: 0.4893, g_loss: 3.2483
Epoch [1/10], Step [1000/3750], d_loss: 1.4077, g_loss: 1.2452
Epoch [1/10], Step [1200/3750], d_loss: 1.0172, g_loss: 1.1295
Epoch [1/10], Step [1400/3750], d_loss: 0.5861, g_loss: 2.1239
Epoch [1/10], Step [1600/3750], d_loss: 1.2635, g_loss: 1.7423
Epoch [1/10], Step [1800/3750], d_loss: 1.8341, g_loss: 0.8334
Epoch [1/10], Step [2000/3750], d_loss: 1.1345, g_loss: 1.0233
Epoch [1/10], Step [2200/3750], d_loss: 0.4739, g_loss: 2.0215
Epoch [1/10], Step [2400/3750], d_loss: 0.2000, g_loss: 3.0784
Epoch [1/10], Step [2600/3750], d_loss: 0.5896, g_loss: 3.7119
Epoch [1/10], Step [2800/3750], d_loss: 1.6589, g_loss: 1.0096
Epoch [1/10], Step [3000/3750], d_loss: 0.9741, g_loss: 2.2364


In [None]:
# noise.data.normal_(0, 1)
# fake = generator(noise)

# plt.figure(figsize=(6, 7))
# for i in range(16):
#     plt.subplot(4, 4, i + 1)
#     plt.imshow(fake[i].detach().numpy().reshape(28, 28), cmap=plt.cm.Greys_r)
#     plt.axis('off')