In [2]:
import torch
from torch import nn

import math
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms


In [7]:
# Seed PyTorch's global RNG so weight initialisation and sampling are repeatable.
torch.manual_seed(111)

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [8]:
# Preprocessing applied to every image: convert to a tensor, then normalise the
# single channel with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)


In [12]:
!pip install  widgetsnbextension

Keyring is skipped due to an exception: org.freedesktop.DBus.Error.InvalidFileContent: D-Bus library appears to be incorrectly set up: see the manual page for dbus-uuidgen to correct this issue. (Failed to open "/var/lib/dbus/machine-id": No such file or directory; UUID file '/etc/machine-id' should contain a hex string of length 32, not length 0, with no other text)
Defaulting to user installation because normal site-packages is not writeable
Collecting widgetsnbextension
  Downloading widgetsnbextension-3.6.4-py2.py3-none-any.whl (1.6 MB)
     |████████████████████████████████| 1.6 MB 2.7 MB/s            


Installing collected packages: widgetsnbextension
Successfully installed widgetsnbextension-3.6.4


In [13]:
!pip install ipywidgets


Keyring is skipped due to an exception: org.freedesktop.DBus.Error.InvalidFileContent: D-Bus library appears to be incorrectly set up: see the manual page for dbus-uuidgen to correct this issue. (Failed to open "/var/lib/dbus/machine-id": No such file or directory; UUID file '/etc/machine-id' should contain a hex string of length 32, not length 0, with no other text)
Defaulting to user installation because normal site-packages is not writeable
Collecting ipywidgets
  Downloading ipywidgets-7.7.5-py2.py3-none-any.whl (123 kB)
     |████████████████████████████████| 123 kB 2.5 MB/s            
Collecting jupyterlab-widgets<3,>=1.0.0
  Downloading jupyterlab_widgets-1.1.4-py3-none-any.whl (246 kB)
     |████████████████████████████████| 246 kB 11.9 MB/s            


Installing collected packages: jupyterlab-widgets, ipywidgets
Successfully installed ipywidgets-7.7.5 jupyterlab-widgets-1.1.4


In [14]:
# MNIST training split, stored under the current directory (downloaded on
# first use) with the preprocessing pipeline applied to every image.
train_set = torchvision.datasets.MNIST(
    ".",
    train=True,
    transform=transform,
    download=True,
)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz



Exception ignored in: <bound method tqdm.__del__ of <tqdm.auto.tqdm object at 0x7fb241d566d8>>
Traceback (most recent call last):
  File "/home/marcel/.local/lib/python3.6/site-packages/tqdm/std.py", line 1145, in __del__
    self.close()
  File "/home/marcel/.local/lib/python3.6/site-packages/tqdm/notebook.py", line 283, in close
    self.disp(bar_style='danger', check_delay=False)
AttributeError: 'tqdm' object has no attribute 'disp'


ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

In [None]:
# Number of real images drawn per training step.
batch_size = 32

# Iterate over the training set in reshuffled mini-batches.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    shuffle=True,
    batch_size=batch_size,
)


In [None]:
# Preview: draw the first 16 images of one shuffled batch on a 4x4 grid.
real_samples, mnist_labels = next(iter(train_loader))

# Explicit figure/axes instead of the pyplot state machine; zip() stops after
# the 16 axes, so only the first 16 samples of the batch are shown.
fig, axes = plt.subplots(4, 4)
for ax, sample in zip(axes.flat, real_samples):
    ax.imshow(sample.reshape(28, 28), cmap="gray_r")
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

In [15]:
class Discriminator(nn.Module):
    """Fully connected classifier that scores each input sample with one value in (0, 1).

    The network is 784 -> 1024 -> 512 -> 256 -> 1. Every hidden layer is
    followed by ReLU and dropout (p=0.3); the final layer is followed by a
    sigmoid.
    """

    def __init__(self):
        super().__init__()
        widths = (784, 1024, 512, 256)

        # Hidden stages: Linear -> ReLU -> Dropout, built in order so that the
        # module indices inside the Sequential match the hand-written layout.
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.3)])

        # Output head: a single unit squashed by a sigmoid.
        stack.extend([nn.Linear(widths[-1], 1), nn.Sigmoid()])

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten each sample to 784 features and return a (batch, 1) score tensor."""
        flat = x.view(x.size(0), 784)
        return self.model(flat)

In [16]:
class Generator(nn.Module):
    """Fully connected network that turns 100-dimensional inputs into 1x28x28 outputs.

    The network is 100 -> 256 -> 512 -> 1024 -> 784 with ReLU after every
    hidden layer and tanh after the last one, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        widths = (100, 256, 512, 1024)

        # Hidden stages: Linear -> ReLU, built in order so that the module
        # indices inside the Sequential match the hand-written layout.
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])

        # Output head: 784 values per sample squashed by tanh.
        stack.extend([nn.Linear(widths[-1], 784), nn.Tanh()])

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run the network and reshape each 784-vector to (1, 28, 28)."""
        flat_images = self.model(x)
        return flat_images.view(x.size(0), 1, 28, 28)

In [17]:
# Instantiate both networks (discriminator first) and move their parameters
# to the selected device.
discriminator = Discriminator().to(device)
generator = Generator().to(device)

In [18]:
# Training hyperparameters.
lr = 1e-4
num_epochs = 50

# Binary cross-entropy on the discriminator's sigmoid outputs.
loss_function = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate.
optimizer_discriminator = torch.optim.Adam(params=discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(params=generator.parameters(), lr=lr)


In [None]:
# Adversarial training: for every mini-batch, first update the discriminator
# on real + generated samples, then update the generator so that its samples
# are scored as real by the discriminator.
num_batches = len(train_loader)

for epoch in range(num_epochs):

    # The MNIST class labels (mnist_labels) are not used by the GAN objective.
    for n, (real_samples, mnist_labels) in enumerate(train_loader):
        # Data for training the discriminator
        real_samples = real_samples.to(device=device)

        # Size everything from the batch actually received, so a final batch
        # shorter than `batch_size` does not cause a shape mismatch in the loss.
        current_batch_size = real_samples.size(0)

        # Target 1 for real samples, 0 for generated ones.
        real_samples_labels = torch.ones(
            (current_batch_size, 1), device=device
        )
        generated_samples_labels = torch.zeros(
            (current_batch_size, 1), device=device
        )

        # Noise is sampled on the CPU and then moved, so the seeded random
        # stream does not depend on which device is in use.
        latent_space_samples = torch.randn((current_batch_size, 100)).to(
            device=device
        )

        # Detach: the discriminator update needs no gradients w.r.t. the
        # generator, so do not backpropagate through it here.
        generated_samples = generator(latent_space_samples).detach()

        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels)
        )

        # Training the discriminator
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(
            output_discriminator, all_samples_labels
        )
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Data for training the generator
        latent_space_samples = torch.randn((current_batch_size, 100)).to(
            device=device
        )

        # Training the generator: its samples are labelled as real, so the
        # loss is low when the discriminator is fooled. Gradients that reach
        # the discriminator here are cleared by its zero_grad() on the next step.
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(
            output_discriminator_generated, real_samples_labels
        )
        loss_generator.backward()
        optimizer_generator.step()

        # Show loss once per epoch, on the last batch (the batch index is
        # compared with the number of batches, not with the batch size).
        if n == num_batches - 1:
            print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
            print(f"Epoch: {epoch} Loss G.: {loss_generator}")