In [80]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [81]:
BATCH_SIZE = 100

# MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.13, 0.31)])

# train_dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)
# test_dataset = datasets.MNIST(root='data', train=False, transform=transform, download=True)

train_dataset = datasets.FashionMNIST(root='data', train=True, transform=transform, download=True)
test_dataset = datasets.FashionMNIST(root='data', train=False, transform=transform, download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [82]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Инициализация весов
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')  # He Initialization
        if m.bias is not None:
            m.bias.data.fill_(0.01)

# Генератор
class Generator(nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(g_input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 1024)
        self.bn2 = nn.BatchNorm1d(1024)
        self.fc3 = nn.Linear(1024, 1024)
        self.bn3 = nn.BatchNorm1d(1024)
        self.fc4 = nn.Linear(1024, g_output_dim)
        self.apply(init_weights)

    def forward(self, x):
        x = F.leaky_relu(self.bn1(self.fc1(x)), 0.2)
        x = F.leaky_relu(self.bn2(self.fc2(x)), 0.2)
        x = F.leaky_relu(self.bn3(self.fc3(x)), 0.2)
        return torch.tanh(self.fc4(x))

# Дискриминатор
class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 1)
        self.dropout = nn.Dropout(0.3)
        self.apply(init_weights)

    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        return torch.sigmoid(self.fc4(x))


In [83]:
# build network
Z_DIM = 50
mnist_dim = train_dataset.train_data.size(1) * train_dataset.train_data.size(2)

G = Generator(g_input_dim = Z_DIM, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)

In [84]:
G

Generator(
  (fc1): Linear(in_features=50, out_features=512, bias=True)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=512, out_features=1024, bias=True)
  (bn2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=1024, out_features=1024, bias=True)
  (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
)

In [85]:
D

Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)

In [86]:
# loss
criterion = nn.BCELoss()

# optimizer
G_lr = 0.0001
D_lr = 0.00002
G_optimizer = optim.Adam(G.parameters(), lr = G_lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr = D_lr, betas=(0.5, 0.999))

In [87]:
def D_train(x):
    D.zero_grad()

    # train discriminator on real
    x_real, y_real = x.view(-1, mnist_dim), torch.ones(BATCH_SIZE, 1)
    x_real, y_real = Variable(x_real.to(device)), Variable(y_real.to(device))

    D_output = D(x_real)
    D_real_loss = criterion(D_output, y_real)
    D_real_score = D_output

    # train discriminator on fake
    z = Variable(torch.randn(BATCH_SIZE, Z_DIM).to(device))
    x_fake, y_fake = G(z), Variable(torch.zeros(BATCH_SIZE, 1).to(device))

    D_output = D(x_fake)
    D_fake_loss = criterion(D_output, y_fake)
    D_fake_score = D_output

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return  D_loss.data.item()

In [88]:
def G_train(x):
    G.zero_grad()

    z = Variable(torch.randn(BATCH_SIZE, Z_DIM).to(device))
    y = Variable(torch.ones(BATCH_SIZE, 1).to(device))

    G_output = G(z)
    D_output = D(G_output)
    G_loss = criterion(D_output, y)

    # gradient backprop & optimize ONLY G's parameters
    G_loss.backward()
    G_optimizer.step()

    return G_loss.data.item()

In [89]:
test_z = Variable(torch.randn(BATCH_SIZE, Z_DIM).to(device))
def generate_test_image(epoch):
    with torch.no_grad():
        generated = G(test_z)
        save_image(generated.view(generated.size(0), 1, 28, 28)[0],
                   f'./output/sample_{epoch}.png')

In [90]:
n_epoch = 300
for epoch in range(1, n_epoch+1):
    D_losses, G_losses = [], []
    for batch_idx, (x, _) in enumerate(train_loader):
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))

    generate_test_image(epoch)

[1/300]: loss_d: 1.916, loss_g: 0.642
[2/300]: loss_d: 1.553, loss_g: 0.658
[3/300]: loss_d: 1.422, loss_g: 0.631
[4/300]: loss_d: 1.344, loss_g: 0.642
[5/300]: loss_d: 1.288, loss_g: 0.663
[6/300]: loss_d: 1.258, loss_g: 0.673
[7/300]: loss_d: 1.213, loss_g: 0.700
[8/300]: loss_d: 1.161, loss_g: 0.740
[9/300]: loss_d: 1.106, loss_g: 0.794
[10/300]: loss_d: 1.055, loss_g: 0.852
[11/300]: loss_d: 1.007, loss_g: 0.919
[12/300]: loss_d: 0.971, loss_g: 0.978
[13/300]: loss_d: 0.924, loss_g: 1.052
[14/300]: loss_d: 0.877, loss_g: 1.131
[15/300]: loss_d: 0.833, loss_g: 1.219
[16/300]: loss_d: 0.803, loss_g: 1.290
[17/300]: loss_d: 0.767, loss_g: 1.357
[18/300]: loss_d: 0.735, loss_g: 1.432
[19/300]: loss_d: 0.705, loss_g: 1.497
[20/300]: loss_d: 0.673, loss_g: 1.568
[21/300]: loss_d: 0.659, loss_g: 1.631
[22/300]: loss_d: 0.634, loss_g: 1.676
[23/300]: loss_d: 0.617, loss_g: 1.727
[24/300]: loss_d: 0.602, loss_g: 1.770
[25/300]: loss_d: 0.586, loss_g: 1.821
[26/300]: loss_d: 0.574, loss_g: 1

In [91]:
generated = G(test_z)

In [92]:
generated.shape

torch.Size([100, 784])

In [93]:
save_image(generated.view(generated.size(0), 1, 28, 28),
                   f'./output/all.png')