In [1]:
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Strided-convolution critic that maps an image batch to one raw score each.

    The network is a stem convolution followed by three normalised
    down-sampling stages and a final convolution with a single output
    channel. No activation is applied to the output, so scores are unbounded.

    Every kernel-4 / stride-2 / padding-1 convolution halves the spatial size
    and the final kernel-4 / padding-0 convolution collapses a 4x4 map to 1x1,
    so a 64x64 input yields an output of shape ``(N, 1, 1, 1)``.

    Args:
        channels_img: Number of channels in the input images.
        features_d: Channel width after the stem; each later stage doubles it.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        width = features_d

        # Stem: no normalisation on the very first layer.
        stem = [
            nn.Conv2d(channels_img, width, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]

        # Three down-sampling stages: width -> 2x -> 4x -> 8x features_d.
        stages = []
        for _ in range(3):
            stages.append(self._block(width, width * 2, 4, 2, 1))
            width = width * 2

        # Head: reduce to a single score channel.
        head = nn.Conv2d(width, 1, kernel_size=4, stride=2, padding=0)

        self.disc = nn.Sequential(*stem, *stages, head)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one Conv2d -> InstanceNorm2d(affine) -> LeakyReLU(0.2) stage."""
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,  # the affine norm that follows supplies its own shift
        )
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        activation = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Return the critic score tensor for the image batch ``x``."""
        scores = self.disc(x)
        return scores


class Generator(nn.Module):
    """Transposed-convolution generator that up-samples a latent tensor to an image.

    Four normalised up-sampling stages are followed by a final transposed
    convolution and ``Tanh``, so output values lie in ``[-1, 1]``.

    For a latent input with 1x1 spatial size the feature map grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64, giving an output of shape
    ``(N, channels_img, 64, 64)``.

    Args:
        channels_noise: Number of channels of the latent input.
        channels_img: Number of channels in the generated image.
        features_g: Base channel width; the first stage uses 16x this value
            and each later stage halves it.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()

        # (in_channels, out_channels, stride, padding) per up-sampling stage;
        # sizes in the comments assume a 1x1 latent input.
        stage_plan = [
            (channels_noise, features_g * 16, 1, 0),  # img: 4x4
            (features_g * 16, features_g * 8, 2, 1),  # img: 8x8
            (features_g * 8, features_g * 4, 2, 1),  # img: 16x16
            (features_g * 4, features_g * 2, 2, 1),  # img: 32x32
        ]
        stages = [
            self._block(c_in, c_out, 4, stride, padding)
            for c_in, c_out, stride, padding in stage_plan
        ]

        # Final up-sampling step straight to image channels (img: 64x64).
        to_image = nn.ConvTranspose2d(
            features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
        )

        self.net = nn.Sequential(*stages, to_image, nn.Tanh())

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,  # BatchNorm's shift parameter takes the place of a bias
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU()
        return nn.Sequential(deconv, norm, activation)

    def forward(self, x):
        """Return generated images for the latent batch ``x``."""
        images = self.net(x)
        return images


def initialize_weights(model):
    """Re-draw selected layer weights of ``model`` in place from N(0, 0.02).

    Affects the ``weight`` tensor of every ``nn.Conv2d``,
    ``nn.ConvTranspose2d`` and ``nn.BatchNorm2d`` sub-module (so the
    BatchNorm scale is also sampled around zero). Biases and all other layer
    types, including ``nn.InstanceNorm2d``, are left untouched.

    Args:
        model: Module whose sub-modules are visited via ``model.modules()``.
    """
    reinit_types = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    selected = (layer for layer in model.modules() if isinstance(layer, reinit_types))
    for layer in selected:
        nn.init.normal_(layer.weight.data, mean=0.0, std=0.02)

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

# Hyperparameters etc
device = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
LEARNING_RATE = 5e-5
BATCH_SIZE = 64
IMAGE_SIZE = 64
CHANNELS_IMG = 1
Z_DIM = 128
NUM_EPOCHS = 5
FEATURES_CRITIC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5  # critic updates per generator update
WEIGHT_CLIP = 0.01  # critic parameters are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]

# Seed the torch RNG so weight init, shuffling and noise are repeatable
# (bitwise reproducibility on GPU is still not guaranteed).
torch.manual_seed(SEED)

# Pre-processing: resize to IMAGE_SIZE, convert to tensor, then map [0, 1] -> [-1, 1]
# so real images share the value range of the generator's Tanh output.
# Bound to its own name so the imported `transforms` module is not shadowed.
mnist_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ]
)

dataset = datasets.MNIST(root="dataset/", transform=mnist_transform, download=True)

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

# optimizer
opt_gen = optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)

# for tensorboard plotting: one fixed latent batch, so logged samples are
# comparable from step to step
fixed_noise = torch.randn(32, Z_DIM, 1, 1, device=device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

gen.train()
critic.train()

for epoch in range(NUM_EPOCHS):
    for batch_idx, (data, _) in enumerate(tqdm(loader)):
        data = data.to(device)
        cur_batch_size = data.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1, device=device)
            fake = gen(noise)
            critic_real = critic(data).reshape(-1)
            # detach(): the critic update must not backpropagate into the
            # generator, which also removes the need for retain_graph=True.
            critic_fake = critic(fake.detach()).reshape(-1)
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

            # Clip critic weights in place after every update.
            with torch.no_grad():
                for p in critic.parameters():
                    p.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        # Re-scores the last (non-detached) fake batch with the updated critic,
        # so gradients flow back into the generator here.
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0 and batch_idx > 0:
            gen.eval()
            critic.eval()
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}"
            )

            with torch.no_grad():
                # Sample from the fixed latent batch rather than the last
                # training noise so the preview tracks the same inputs.
                fake_preview = gen(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(
                    data[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake_preview[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1
            gen.train()
            critic.train()

# Flush pending events and release the TensorBoard log files.
writer_real.close()
writer_fake.close()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 53197776.22it/s]


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1518771.47it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 13555043.03it/s]


Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2519244.75it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw




 11%|█         | 101/938 [00:38<06:23,  2.18it/s]

Epoch [0/5] Batch 100/938                   Loss D: -1.3647, loss G: 0.6610


 21%|██▏       | 201/938 [01:16<05:19,  2.30it/s]

Epoch [0/5] Batch 200/938                   Loss D: -1.5186, loss G: 0.7368


 32%|███▏      | 301/938 [01:54<04:36,  2.31it/s]

Epoch [0/5] Batch 300/938                   Loss D: -1.5354, loss G: 0.7415


 43%|████▎     | 401/938 [02:31<03:52,  2.30it/s]

Epoch [0/5] Batch 400/938                   Loss D: -1.5400, loss G: 0.7437


 53%|█████▎    | 501/938 [03:09<03:11,  2.28it/s]

Epoch [0/5] Batch 500/938                   Loss D: -1.5432, loss G: 0.7441


 64%|██████▍   | 601/938 [03:46<02:25,  2.31it/s]

Epoch [0/5] Batch 600/938                   Loss D: -1.5447, loss G: 0.7448


 75%|███████▍  | 701/938 [04:24<01:43,  2.30it/s]

Epoch [0/5] Batch 700/938                   Loss D: -1.3824, loss G: 0.6047


 85%|████████▌ | 801/938 [05:01<00:59,  2.31it/s]

Epoch [0/5] Batch 800/938                   Loss D: -1.5321, loss G: 0.7425


 96%|█████████▌| 901/938 [05:39<00:16,  2.30it/s]

Epoch [0/5] Batch 900/938                   Loss D: -1.4830, loss G: 0.7259


100%|██████████| 938/938 [05:53<00:00,  2.66it/s]
 11%|█         | 101/938 [00:37<06:02,  2.31it/s]

Epoch [1/5] Batch 100/938                   Loss D: -1.4134, loss G: 0.6896


 21%|██▏       | 201/938 [01:15<05:19,  2.31it/s]

Epoch [1/5] Batch 200/938                   Loss D: -1.4286, loss G: 0.6958


 32%|███▏      | 301/938 [01:52<04:35,  2.31it/s]

Epoch [1/5] Batch 300/938                   Loss D: -1.4252, loss G: 0.7108


 43%|████▎     | 401/938 [02:30<03:51,  2.32it/s]

Epoch [1/5] Batch 400/938                   Loss D: -1.4125, loss G: 0.7208


 53%|█████▎    | 501/938 [03:08<03:08,  2.31it/s]

Epoch [1/5] Batch 500/938                   Loss D: -1.4421, loss G: 0.6951


 64%|██████▍   | 601/938 [03:45<02:25,  2.31it/s]

Epoch [1/5] Batch 600/938                   Loss D: -1.4034, loss G: 0.6976


 75%|███████▍  | 701/938 [04:23<01:42,  2.32it/s]

Epoch [1/5] Batch 700/938                   Loss D: -1.3704, loss G: 0.6985


 85%|████████▌ | 801/938 [05:00<00:59,  2.32it/s]

Epoch [1/5] Batch 800/938                   Loss D: -1.3853, loss G: 0.6730


 96%|█████████▌| 901/938 [05:38<00:15,  2.32it/s]

Epoch [1/5] Batch 900/938                   Loss D: -0.9974, loss G: 0.6169


100%|██████████| 938/938 [05:51<00:00,  2.67it/s]
 11%|█         | 101/938 [00:37<05:59,  2.33it/s]

Epoch [2/5] Batch 100/938                   Loss D: -1.2985, loss G: 0.6840


 21%|██▏       | 201/938 [01:15<05:17,  2.32it/s]

Epoch [2/5] Batch 200/938                   Loss D: -1.1315, loss G: 0.5376


 32%|███▏      | 301/938 [01:52<04:33,  2.33it/s]

Epoch [2/5] Batch 300/938                   Loss D: -1.2801, loss G: 0.6413


 43%|████▎     | 401/938 [02:30<03:50,  2.33it/s]

Epoch [2/5] Batch 400/938                   Loss D: -1.1872, loss G: 0.6575


 53%|█████▎    | 501/938 [03:07<03:07,  2.34it/s]

Epoch [2/5] Batch 500/938                   Loss D: -1.1423, loss G: 0.4811


 64%|██████▍   | 601/938 [03:45<02:24,  2.33it/s]

Epoch [2/5] Batch 600/938                   Loss D: -1.2443, loss G: 0.6471


 75%|███████▍  | 701/938 [04:22<01:41,  2.33it/s]

Epoch [2/5] Batch 700/938                   Loss D: -1.2490, loss G: 0.5935


 85%|████████▌ | 801/938 [05:00<00:58,  2.33it/s]

Epoch [2/5] Batch 800/938                   Loss D: -1.1333, loss G: 0.4528


 96%|█████████▌| 901/938 [05:37<00:15,  2.34it/s]

Epoch [2/5] Batch 900/938                   Loss D: -1.1454, loss G: 0.4702


100%|██████████| 938/938 [05:51<00:00,  2.67it/s]
 11%|█         | 101/938 [00:37<05:57,  2.34it/s]

Epoch [3/5] Batch 100/938                   Loss D: -1.1730, loss G: 0.6198


 21%|██▏       | 201/938 [01:15<05:15,  2.33it/s]

Epoch [3/5] Batch 200/938                   Loss D: -1.1837, loss G: 0.5750


 32%|███▏      | 301/938 [01:53<04:32,  2.34it/s]

Epoch [3/5] Batch 300/938                   Loss D: -1.0097, loss G: 0.6747


 43%|████▎     | 401/938 [02:30<03:49,  2.34it/s]

Epoch [3/5] Batch 400/938                   Loss D: -1.1198, loss G: 0.4992


 53%|█████▎    | 501/938 [03:08<03:06,  2.34it/s]

Epoch [3/5] Batch 500/938                   Loss D: -1.2330, loss G: 0.6018


 64%|██████▍   | 601/938 [03:45<02:23,  2.34it/s]

Epoch [3/5] Batch 600/938                   Loss D: -1.0980, loss G: 0.3906


 75%|███████▍  | 701/938 [04:23<01:41,  2.34it/s]

Epoch [3/5] Batch 700/938                   Loss D: -1.1195, loss G: 0.5476


 85%|████████▌ | 801/938 [05:00<00:58,  2.34it/s]

Epoch [3/5] Batch 800/938                   Loss D: -1.1403, loss G: 0.6420


 96%|█████████▌| 901/938 [05:38<00:15,  2.34it/s]

Epoch [3/5] Batch 900/938                   Loss D: -1.1726, loss G: 0.5403


100%|██████████| 938/938 [05:52<00:00,  2.66it/s]
 11%|█         | 101/938 [00:37<05:56,  2.35it/s]

Epoch [4/5] Batch 100/938                   Loss D: -1.1208, loss G: 0.6198


 21%|██▏       | 201/938 [01:15<05:14,  2.34it/s]

Epoch [4/5] Batch 200/938                   Loss D: -1.1251, loss G: 0.5972


 32%|███▏      | 301/938 [01:53<04:31,  2.34it/s]

Epoch [4/5] Batch 300/938                   Loss D: -1.0786, loss G: 0.3942


 43%|████▎     | 401/938 [02:30<03:49,  2.34it/s]

Epoch [4/5] Batch 400/938                   Loss D: -1.0671, loss G: 0.6289


 53%|█████▎    | 501/938 [03:08<03:06,  2.34it/s]

Epoch [4/5] Batch 500/938                   Loss D: -1.1018, loss G: 0.6101


 64%|██████▍   | 601/938 [03:45<02:24,  2.33it/s]

Epoch [4/5] Batch 600/938                   Loss D: -1.0761, loss G: 0.5917


 75%|███████▍  | 701/938 [04:23<01:41,  2.35it/s]

Epoch [4/5] Batch 700/938                   Loss D: -1.0917, loss G: 0.5974


 85%|████████▌ | 801/938 [05:00<00:58,  2.34it/s]

Epoch [4/5] Batch 800/938                   Loss D: -1.1978, loss G: 0.6258


 96%|█████████▌| 901/938 [05:38<00:15,  2.34it/s]

Epoch [4/5] Batch 900/938                   Loss D: -1.0133, loss G: 0.3425


100%|██████████| 938/938 [05:52<00:00,  2.66it/s]
