In [None]:
import torch
import torch.nn as nn
from torchvision import transforms


In [None]:
class Discriminator(nn.Module):
  """Convolutional critic that maps an image batch to one raw score per image.

  This notebook trains the network with the Wasserstein objective
  (mean score on real images minus mean score on fake images, with weight
  clipping), so the output must be an unbounded real number. The final
  ``nn.Sigmoid`` that a classic DCGAN discriminator carries is therefore
  omitted: it would squash every score into (0, 1) and flatten the gradient
  the generator learns from. Sigmoid has no parameters and was the last
  layer, so the ``state_dict`` keys are unchanged.

  Args:
    channels_img: Number of channels of the input images.
    features_d: Base number of feature maps; deeper layers use 2x, 4x and 8x.

  Shape:
    A ``(N, channels_img, 64, 64)`` input yields a ``(N, 1, 1, 1)`` output.
  """

  def __init__(self, channels_img, features_d):
    super().__init__()
    self.disc = nn.Sequential(
        # Spatial sizes in the comments assume a 64 x 64 input.
        nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),  # 32 x 32
        nn.LeakyReLU(0.2),
        self._block(features_d, features_d*2, 4, 2, 1),    # 16 x 16
        self._block(features_d*2, features_d*4, 4, 2, 1),  # 8 x 8
        self._block(features_d*4, features_d*8, 4, 2, 1),  # 4 x 4
        # No activation after this layer: the critic returns a raw score.
        nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0),  # 1 x 1
    )

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage.

    The convolution has no bias because the following BatchNorm2d already
    provides a learnable per-channel shift.
    """
    return nn.Sequential(
      nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias=False,
      ),
      nn.BatchNorm2d(out_channels),
      nn.LeakyReLU(0.2)
    )

  def forward(self, x):
    """Score a batch of images; returns a tensor of shape (N, 1, 1, 1) for 64 x 64 inputs."""
    return self.disc(x)

In [None]:
class Generator(nn.Module):
  """Transposed-convolution generator: latent tensor in, image in [-1, 1] out.

  Args:
    z_dim: Number of channels of the latent input.
    channels_img: Number of channels of the generated images.
    features_g: Base number of feature maps; the stages use 16x, 8x, 4x, 2x.

  Shape:
    A ``(N, z_dim, 1, 1)`` input yields a ``(N, channels_img, 64, 64)`` output.
  """

  def __init__(self, z_dim, channels_img, features_g):
    super().__init__()
    # Channel widths of the four upsampling stages, widest first.
    widths = [features_g*mult for mult in (16, 8, 4, 2)]

    # First stage projects the 1 x 1 latent to a 4 x 4 map (stride 1, no padding).
    stages = [self._block(z_dim, widths[0], 4, 1, 0)]
    # Each later stage doubles the spatial size: 8 x 8, 16 x 16, 32 x 32.
    for wide, narrow in zip(widths, widths[1:]):
      stages.append(self._block(wide, narrow, 4, 2, 1))

    # Output layer: 32 x 32 -> 64 x 64, with Tanh bounding pixels to [-1, 1].
    to_image = nn.ConvTranspose2d(
        widths[-1], channels_img, kernel_size=4, stride=2, padding=1
    )
    self.net = nn.Sequential(*stages, to_image, nn.Tanh())

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Return a ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU stage."""
    upsample = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    normalise = nn.BatchNorm2d(out_channels)
    return nn.Sequential(upsample, normalise, nn.ReLU())

  def forward(self, x):
    """Generate images from the latent batch ``x``."""
    return self.net(x)

In [None]:
def initialize_weights(model):
  """Apply DCGAN-style initialisation to ``model`` in place.

  * ``Conv2d`` / ``ConvTranspose2d`` weights are drawn from N(0, 0.02);
    their biases, if any, are left as constructed.
  * ``BatchNorm2d`` scale (weight) is drawn from N(1, 0.02) and its shift
    (bias) is set to 0. Centring the scale on 1 matters: a scale drawn
    around 0 would multiply every normalised activation by roughly zero.

  Args:
    model: Any ``nn.Module``; all of its sub-modules are visited.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
      # weight/bias are None when the layer was built with affine=False.
      if m.weight is not None:
        nn.init.normal_(m.weight, 1.0, 0.02)
      if m.bias is not None:
        nn.init.zeros_(m.bias)

def test():
  """Smoke-test both networks: check only the output shapes on random input."""
  batch, channels, height, width = 8, 3, 64, 64
  latent_dim = 100

  # Critic: a batch of images should collapse to one value per image.
  images = torch.randn((batch, channels, height, width))
  critic_net = Discriminator(channels, 8)
  initialize_weights(critic_net)
  scores = critic_net(images)
  assert scores.shape == (batch, 1, 1, 1), "Discriminator test failed"

  # Generator: a batch of latent tensors should expand to full-size images.
  generator_net = Generator(latent_dim, channels, 8)
  latent = torch.randn((batch, latent_dim, 1, 1))
  samples = generator_net(latent)
  assert samples.shape == (batch, channels, height, width), "Generator test failed"

  print("Success, tests passed!")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Hyperparameters etc.
SEED = 42
# Seed PyTorch's RNGs so weight init and noise sampling repeat on a fresh
# kernel. NOTE: some GPU kernels may still be nondeterministic.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 5e-5  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 64
IMAGE_SIZE = 64  # size passed to transforms.Resize below
CHANNELS_IMG = 1
Z_DIM = 100
NOISE_DIM = Z_DIM  # alias; the training loop reads NOISE_DIM
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5  # critic updates per generator update
WEIGHT_CLIP = 0.01  # critic parameters are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]

# Resize, convert to a tensor, then normalise each channel with mean 0.5 and
# std 0.5 so pixel values land in [-1, 1], the range of the generator's Tanh.
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ]
)

# If you train on MNIST, remember to set channels_img to 1
dataset = datasets.MNIST(
    root="dataset/", train=True, transform=transform, download=True
)

# comment mnist above and uncomment below if train on CelebA
# (pass the `transform` pipeline defined above, not the `transforms` module)
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Build both networks on the target device, then re-initialise their weights.
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(critic)

opt_gen = optim.RMSprop(gen.parameters(), lr=LEARNING_RATE)
opt_critic = optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)

# A fixed latent batch, reused for every TensorBoard snapshot so the logged
# images show how the same 32 samples evolve during training.
fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0  # global_step counter for the image logs



In [None]:
# WGAN training with weight clipping: for every batch the critic is updated
# CRITIC_ITERATIONS times, then the generator is updated once.
gen.train()
critic.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        # The last batch of an epoch can be smaller than BATCH_SIZE; draw as
        # many fake samples as there are real ones.
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # (implemented as minimising the negated difference)
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, NOISE_DIM, 1, 1).to(device)
            # The critic step must not touch the generator, so build the
            # fake batch without recording a graph.
            with torch.no_grad():
                fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
            )
            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()
            # Weight clipping after every critic update.
            for p in critic.parameters():
                p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # Train Generator: max E[critic(gen(z))]
        # A fresh, non-detached fake batch is required here so the loss can
        # backpropagate through the critic into the generator's parameters.
        # (backward() also fills the critic's grads; they are cleared by
        # critic.zero_grad() before the next critic step.)
        noise = torch.randn(cur_batch_size, NOISE_DIM, 1, 1).to(device)
        fake = gen(noise)
        output = critic(fake).reshape(-1)
        loss_gen = -torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()


        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                # Always render the same latent batch so snapshots are comparable.
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1


Epoch [0/5] Batch 0/938                   Loss D: -0.0271, loss G: -0.4964
Epoch [0/5] Batch 100/938                   Loss D: -0.3642, loss G: -0.3245
Epoch [0/5] Batch 200/938                   Loss D: -0.3671, loss G: -0.3232
Epoch [0/5] Batch 300/938                   Loss D: -0.3678, loss G: -0.3229
Epoch [0/5] Batch 400/938                   Loss D: -0.3681, loss G: -0.3227
Epoch [0/5] Batch 500/938                   Loss D: -0.3689, loss G: -0.3221
Epoch [0/5] Batch 600/938                   Loss D: -0.3696, loss G: -0.3213
Epoch [0/5] Batch 700/938                   Loss D: -0.3704, loss G: -0.3208
Epoch [0/5] Batch 800/938                   Loss D: -0.3706, loss G: -0.3207
Epoch [0/5] Batch 900/938                   Loss D: -0.3707, loss G: -0.3205
Epoch [1/5] Batch 0/938                   Loss D: -0.3710, loss G: -0.3204
Epoch [1/5] Batch 100/938                   Loss D: -0.3711, loss G: -0.3204
Epoch [1/5] Batch 200/938                   Loss D: -0.3711, loss G: -0.3202
Epo

In [None]:
# Persist the learned parameters of both networks. Only the state dicts are
# written, so the model classes must be rebuilt before loading these files.
torch.save(obj=gen.state_dict(), f='generator.pth')
torch.save(obj=critic.state_dict(), f='discriminator.pth')

print("Models saved successfully!")

Models saved successfully!
