In [None]:
import torch
import torch.nn as nn

In [None]:
class Discriminator(nn.Module):
  """DCGAN-style discriminator that scores each image with a value in (0, 1).

  Five stride-2 convolutions reduce a 64 x 64 input to a single 1 x 1 output,
  which is then passed through a sigmoid.

  Args:
    channels_img: number of channels of the input images.
    features_d: base channel width; later layers use 2x, 4x and 8x this value.
  """

  def __init__(self, channels_img, features_d):
    super().__init__()
    # Channel widths after each of the four down-sampling convolutions.
    widths = [features_d * mult for mult in (1, 2, 4, 8)]

    # input: N x channels_img x 64 x 64
    # The first convolution has a bias and no BatchNorm.
    layers = [
        nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2),
    ]
    # Each block halves the spatial size: 32 -> 16 -> 8 -> 4.
    for c_in, c_out in zip(widths[:-1], widths[1:]):
      layers.append(self._block(c_in, c_out, 4, 2, 1))
    # A 4 x 4 kernel without padding collapses the 4 x 4 map to 1 x 1.
    layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
    layers.append(nn.Sigmoid())

    self.disc = nn.Sequential(*layers)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Build one Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage (conv has no bias)."""
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    norm = nn.BatchNorm2d(out_channels)
    act = nn.LeakyReLU(0.2)
    return nn.Sequential(conv, norm, act)

  def forward(self, x):
    """Run the image batch `x` through the network and return N x 1 x 1 x 1 scores."""
    return self.disc(x)

class Generator(nn.Module):
  """DCGAN-style generator that turns latent vectors into 64 x 64 images.

  Five transposed convolutions grow a 1 x 1 latent map to 64 x 64; the final
  Tanh keeps pixel values in [-1, 1].

  Args:
    z_dim: number of channels of the latent input (N x z_dim x 1 x 1).
    channels_img: number of channels of the generated images.
    features_g: base channel width; earlier layers use 16x, 8x, 4x and 2x this value.
  """

  def __init__(self, z_dim, channels_img, features_g):
    super().__init__()
    # (in_channels, out_channels, kernel_size, stride, padding) for each block.
    # input: N x z_dim x 1 x 1
    block_specs = [
        (z_dim, features_g * 16, 4, 1, 0),          # 4*4
        (features_g * 16, features_g * 8, 4, 2, 1),  # 8*8
        (features_g * 8, features_g * 4, 4, 2, 1),   # 16*16
        (features_g * 4, features_g * 2, 4, 2, 1),   # 32*32
    ]
    layers = [self._block(*spec) for spec in block_specs]
    # The output layer has a bias and no BatchNorm: 32*32 -> 64*64.
    layers.append(
        nn.ConvTranspose2d(
            features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
        )
    )
    layers.append(nn.Tanh())  # [-1, 1]

    self.gen = nn.Sequential(*layers)

  def _block(self, in_channels, out_channels, kernel_size, stride, padding):
    """Build one ConvTranspose2d -> BatchNorm2d -> ReLU stage (conv has no bias)."""
    upconv = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    norm = nn.BatchNorm2d(out_channels)
    act = nn.ReLU()
    return nn.Sequential(upconv, norm, act)

  def forward(self, x):
    """Map the latent batch `x` to a batch of generated images."""
    return self.gen(x)

def initialize_weights(model):
  """Apply DCGAN-style weight initialisation to `model` in place.

  Convolution and transposed-convolution weights are drawn from N(0, 0.02).
  BatchNorm scale parameters are drawn from N(1, 0.02) and their shifts are
  zeroed: a scale centred on 0 would shrink every normalised activation to
  roughly zero (with random sign) at the start of training.

  Args:
    model: an `nn.Module`; all of its sub-modules are visited.
  """
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
      # `weight`/`bias` are None when the layer was built with affine=False.
      nn.init.normal_(m.weight, mean=1.0, std=0.02)
      nn.init.zeros_(m.bias)

def test():
  """Smoke-test both networks: check output shapes for a small random batch.

  Prints "Success" when the discriminator returns N x 1 x 1 x 1 scores and the
  generator returns N x C x H x W images; raises AssertionError otherwise.
  """
  batch, img_channels, height, width = 8, 3, 64, 64
  latent_dim = 100
  base_features = 8

  # Discriminator: image batch in, one score per image out.
  images = torch.randn((batch, img_channels, height, width))
  critic = Discriminator(img_channels, base_features)
  initialize_weights(critic)
  expected_scores_shape = (batch, 1, 1, 1)
  assert critic(images).shape == expected_scores_shape

  # Generator: latent batch in, image batch of the same size as `images` out.
  generator = Generator(latent_dim, img_channels, base_features)
  initialize_weights(generator)
  latent = torch.randn((batch, latent_dim, 1, 1))
  expected_images_shape = (batch, img_channels, height, width)
  assert generator(latent).shape == expected_images_shape

  print("Success")

# Run the shape smoke test; it prints "Success" when both assertions pass.
test()

Success


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# --- Configuration ---------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42            # fixes weight init, noise sampling and shuffling order
LEARNING = 2e-4      # learning rate shared by both Adam optimizers
BATCH_SIZE = 128
IMAGE_SIZE = 64      # both networks are built for 64 x 64 images
CHANNELS_IMG = 1     # MNIST is single-channel
Z_DIM = 100          # number of channels of the latent noise vector
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64

torch.manual_seed(SEED)

# --- Data ------------------------------------------------------------------
# Named `mnist_transform` so the imported `transforms` module is not shadowed.
# Normalize with mean = std = 0.5 maps pixels to [-1, 1], matching the
# generator's Tanh output range.
mnist_transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ]
)

dataset = datasets.MNIST(
    root="dataset/", train=True, transform=mnist_transform, download=True
)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# --- Models, optimizers, loss ----------------------------------------------
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# --- Logging ---------------------------------------------------------------
# The same noise is reused at every logging step so generated samples are
# comparable across training.
fixed_noise = torch.randn(32, Z_DIM, 1, 1, device=device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

gen.train()
disc.train()

# --- Training loop ---------------------------------------------------------
for epoch in range(NUM_EPOCHS):
  for batch_idx, (real, _) in enumerate(loader):
    real = real.to(device)
    # Size the noise from the real batch: the last batch of an epoch can be
    # smaller than BATCH_SIZE.
    noise = torch.randn(real.shape[0], Z_DIM, 1, 1, device=device)
    # One generator forward pass per step, shared by both updates below.
    fake = gen(noise)

    # Train Discriminator: real images labelled 1, generated images labelled 0.
    disc_real = disc(real).reshape(-1)
    loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach() stops the discriminator loss from back-propagating into the
    # generator, so the generator graph stays intact for its own update.
    disc_fake = disc(fake.detach()).reshape(-1)
    loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    loss_disc = (loss_disc_real + loss_disc_fake) / 2
    disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    # Train Generator: it is rewarded when the (just updated) discriminator
    # labels its images as real.
    output = disc(fake).reshape(-1)
    loss_gen = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    if batch_idx % 100 == 0:
      # Adjacent f-strings keep source indentation out of the log line.
      print(
          f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
          f"Loss D: {loss_disc.item():.4f}, loss G: {loss_gen.item():.4f}"
      )
      with torch.no_grad():
        fake = gen(fixed_noise)
        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
        writer_real.add_image("Real", img_grid_real, global_step=step)
        writer_fake.add_image("Fake", img_grid_fake, global_step=step)
      step += 1

# Flush pending events to disk and release the log files.
writer_real.close()
writer_fake.close()



Epoch [0/5] Batch 0/469           Loss D: 0.6856, loss G: 0.7932
Epoch [0/5] Batch 100/469           Loss D: 0.0148, loss G: 4.1228
Epoch [0/5] Batch 200/469           Loss D: 0.5916, loss G: 0.8756
Epoch [0/5] Batch 300/469           Loss D: 0.6170, loss G: 0.9310
Epoch [0/5] Batch 400/469           Loss D: 0.5766, loss G: 1.0744
Epoch [1/5] Batch 0/469           Loss D: 0.5454, loss G: 0.9104
Epoch [1/5] Batch 100/469           Loss D: 0.6061, loss G: 0.8828
Epoch [1/5] Batch 200/469           Loss D: 0.6389, loss G: 1.6866
Epoch [1/5] Batch 300/469           Loss D: 0.5534, loss G: 0.9648
Epoch [1/5] Batch 400/469           Loss D: 0.5632, loss G: 1.0448
Epoch [2/5] Batch 0/469           Loss D: 0.6251, loss G: 0.6870
Epoch [2/5] Batch 100/469           Loss D: 0.5721, loss G: 1.2046
Epoch [2/5] Batch 200/469           Loss D: 0.6816, loss G: 2.5336
Epoch [2/5] Batch 300/469           Loss D: 0.4524, loss G: 1.4973
Epoch [2/5] Batch 400/469           Loss D: 0.4204, loss G: 0.8721
E

In [None]:
import torch
import torchvision.utils as vutils

# Set the Generator to evaluation mode (BatchNorm then uses its running statistics)
gen.eval()

# Create a new batch of noise vectors and generate images without tracking gradients
with torch.no_grad():
    noise = torch.randn(32, Z_DIM, 1, 1, device=device)
    fake_images = gen(noise)

    # The generator ends in Tanh, so map its [-1, 1] output onto [0, 1].
    # clamp() guards against tiny floating-point overshoot.
    fake_images = ((fake_images + 1) / 2).clamp(0, 1)

    # Create a grid of images. normalize=False because the images are already
    # in [0, 1]; min-max normalising again would undo the fixed mapping above
    # and make brightness depend on this batch's own min and max.
    img_grid = vutils.make_grid(fake_images, padding=2, normalize=False)

    # Save the generated images
    vutils.save_image(img_grid, "generated_images.png")
