In [1]:
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils


In [2]:
# Configuration
# All names below are notebook-level globals that later cells read directly.
dataset_name = 'cifar10'  # Options: 'cifar10', 'mnist', 'fake'
dataroot = './data'  # Root directory passed to the torchvision dataset (downloaded there if missing)
batch_size = 64  # Mini-batch size for the DataLoader and for the fixed noise batch
image_size = 64  # Images are resized to this size before being fed to the networks
nz = 100  # Latent vector size
ngf = 64  # Generator feature map size
ndf = 64  # Discriminator feature map size
niter = 25  # Number of epochs
lr = 0.0002  # Learning rate
beta1 = 0.5  # Beta1 for Adam optimizer
ngpu = 1  # Number of GPUs to use (passed to the model constructors; device selection below does not consult it)
outf = './output'  # Folder to save outputs

# Set random seed
# A fresh seed is drawn on every run and printed, so a particular run can be
# repeated later by hard-coding the printed value here.
manual_seed = random.randint(1, 10000)
print("Random Seed:", manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)

# Device setup
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create output folder if it doesn't exist
os.makedirs(outf, exist_ok=True)

# Enable cuDNN auto-tuner
# NOTE(review): benchmark mode may pick different convolution algorithms from
# run to run, so results may not be bit-reproducible even with a fixed seed --
# confirm against the PyTorch reproducibility notes if exact repeatability matters.
cudnn.benchmark = True


Random Seed: 2232
Using device: cuda:0


In [3]:
def load_dataset(dataset_name, dataroot, image_size):
    """Build a shuffled DataLoader for one of the supported datasets.

    Args:
        dataset_name: One of 'cifar10', 'mnist' or 'fake'.
        dataroot: Root directory handed to the torchvision dataset; the data
            is downloaded there when it is not already present.
        image_size: Target size given to ``transforms.Resize`` (and the
            height/width of the synthetic images for 'fake').

    Returns:
        A ``(dataloader, nc)`` tuple where ``nc`` is the number of image
        channels (3 for 'cifar10' and 'fake', 1 for 'mnist').

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.

    Note:
        The batch size comes from the notebook-level ``batch_size`` variable,
        not from an argument.
    """

    def _resize_and_normalize(num_channels):
        # Resize, convert to a tensor, then normalize every channel with
        # mean 0.5 and std 0.5.
        half = (0.5,) * num_channels
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(half, half),
        ])

    # Reject unknown names before touching any dataset.
    if dataset_name not in ('cifar10', 'mnist', 'fake'):
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    if dataset_name == 'fake':
        # Synthetic data: only converted to tensors, no resize/normalize.
        nc = 3
        dataset = dset.FakeData(
            image_size=(3, image_size, image_size),
            transform=transforms.ToTensor()
        )
    else:
        if dataset_name == 'cifar10':
            nc, dataset_cls = 3, dset.CIFAR10
        else:  # 'mnist'
            nc, dataset_cls = 1, dset.MNIST
        dataset = dataset_cls(
            root=dataroot, download=True,
            transform=_resize_and_normalize(nc)
        )

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=2
    )
    return loader, nc


In [4]:
def weights_init(m):
    """Apply DCGAN-style weight initialization to a single module.

    Intended to be used with ``model.apply(weights_init)``, which calls this
    function once for every sub-module.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights drawn from N(0, 0.02).
      Biases, if present, are left untouched.
    * ``nn.BatchNorm2d``: scale drawn from N(1, 0.02), shift set to zero.
    * Any other module type is left untouched.

    Args:
        m: The module to initialize in place.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight and bias set to None; skip
        # those instead of crashing inside nn.init.
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

class Generator(nn.Module):
    """DCGAN generator: maps a latent vector to an image with values in [-1, 1].

    The network is a stack of five transposed convolutions. The first turns a
    ``(nz, 1, 1)`` latent into a 4x4 feature map; each of the remaining four
    uses kernel 4 / stride 2 / padding 1 and doubles the spatial size, so the
    output is ``(nc, 64, 64)``. The final activation is ``Tanh``.

    Args:
        ngpu: Number of GPUs. Accepted for API compatibility; not used by
            this class.
        nz: Size of the latent vector (input channels). Defaults to the
            notebook-level ``nz``.
        ngf: Base number of generator feature maps. Defaults to the
            notebook-level ``ngf``.
        nc: Number of output image channels. Defaults to the notebook-level
            ``nc``.

    Raises:
        NameError: If a size is not passed and no notebook-level variable of
            the same name exists.
    """

    def __init__(self, ngpu, nz=None, ngf=None, nc=None):
        super(Generator, self).__init__()

        def _resolve(name, value):
            # Explicit arguments win; otherwise fall back to the module-level
            # variable of the same name so ``Generator(ngpu)`` keeps working.
            if value is not None:
                return value
            try:
                return globals()[name]
            except KeyError:
                raise NameError(f"name '{name}' is not defined") from None

        nz = _resolve('nz', nz)
        ngf = _resolve('ngf', ngf)
        nc = _resolve('nc', nc)

        self.main = nn.Sequential(
            # (nz) x 1 x 1 -> (ngf*8) x 4 x 4
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> (nc) x 64 x 64, squashed to [-1, 1]
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Generate images from a batch of latents shaped (N, nz, 1, 1)."""
        return self.main(input)

In [5]:
class Discriminator(nn.Module):
    """DCGAN discriminator: scores each input image with a value in (0, 1).

    Five convolutions in sequence. The first four use kernel 4 / stride 2 /
    padding 1 and halve the spatial size each time; the last uses kernel 4 /
    stride 1 / no padding, collapsing a 4x4 map to a single value. For a
    64x64 input this yields one score per image.

    Args:
        ngpu: Number of GPUs. Accepted but not used by this class.
        ndf: Base number of discriminator feature maps.
        nc: Number of input image channels.
    """

    def __init__(self, ngpu, ndf, nc):
        super(Discriminator, self).__init__()
        negative_slope = 0.2

        # Stem: no batch norm on the first layer.  (ndf) x 32 x 32
        stack = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(negative_slope, inplace=True),
        ]

        # Three conv + batch-norm + LeakyReLU stages, each doubling the
        # channel count: (ndf*2) x 16 x 16, (ndf*4) x 8 x 8, (ndf*8) x 4 x 4
        for factor in (1, 2, 4):
            channels_in = ndf * factor
            channels_out = channels_in * 2
            stack.extend([
                nn.Conv2d(channels_in, channels_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(channels_out),
                nn.LeakyReLU(negative_slope, inplace=True),
            ])

        # Head: collapse to 1 x 1 x 1 and squash to a probability.
        stack.extend([
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])

        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Return a flat tensor of scores, one per input image."""
        scores = self.main(input)
        return scores.view(-1)


In [6]:
# Build the data pipeline first: `nc` (number of image channels) is returned
# here and is needed by both networks below.
dataloader, nc = load_dataset(dataset_name, dataroot, image_size)

# Generator(ngpu) picks up nz, ngf and nc from the notebook-level variables,
# so it must be constructed after the line above has set `nc`.
netG = Generator(ngpu).to(device)
netG.apply(weights_init)  # re-initialize every conv / batch-norm layer

netD = Discriminator(ngpu,ndf,nc).to(device)
netD.apply(weights_init)  # re-initialize every conv / batch-norm layer

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, sharing the learning rate and beta1 from
# the configuration cell.
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# A fixed batch of latent vectors, created once.
# NOTE(review): presumably meant for rendering comparable samples as training
# progresses; it is not used in the cells shown here -- confirm.
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)
# Scalar target values for real and generated images.
real_label = 1
fake_label = 0

Files already downloaded and verified


In [7]:
# Adversarial training: for every mini-batch, first update the discriminator
# on a real batch and a generated batch, then update the generator.
#
# The loop uses its own names (cur_batch_size, real_targets, fake_targets) so
# the notebook-level `batch_size`, `real_label` and `fake_label` keep their
# configured values; earlier cells can then be re-run safely after training.
for epoch in range(niter):
    for i, data in enumerate(dataloader):
        # ---- Update Discriminator: maximize log(D(x)) + log(1 - D(G(z))) ----
        netD.zero_grad()

        real_data = data[0].to(device)
        # The last batch of an epoch can be smaller than the configured batch
        # size, so size everything from the batch actually received.
        cur_batch_size = real_data.size(0)
        real_targets = torch.full((cur_batch_size,), 1., device=device)
        fake_targets = torch.full((cur_batch_size,), 0., device=device)

        # Forward pass real batch through D
        output = netD(real_data)  # Output shape: (cur_batch_size,)
        errD_real = criterion(output, real_targets)
        errD_real.backward()

        # Generate fake images
        noise = torch.randn(cur_batch_size, nz, 1, 1, device=device)
        fake_data = netG(noise)

        # Forward pass fake batch through D. detach() keeps this backward pass
        # from propagating gradients into the generator.
        output = netD(fake_data.detach())  # Output shape: (cur_batch_size,)
        errD_fake = criterion(output, fake_targets)
        errD_fake.backward()
        optimizerD.step()

        # ---- Update Generator: maximize log(D(G(z))) ----
        netG.zero_grad()
        # Score the same fakes again with the just-updated discriminator, this
        # time keeping the graph back to the generator.
        output = netD(fake_data)  # Output shape: (cur_batch_size,)
        # The generator is trained against the "real" targets.
        errG = criterion(output, real_targets)
        errG.backward()
        optimizerG.step()

        if i % 10 == 0:
            # Convert to Python floats for logging.
            loss_d = (errD_real + errD_fake).item()
            loss_g = errG.item()
            print(f'[{epoch}/{niter}][{i}/{len(dataloader)}] '
                  f'Loss_D: {loss_d:.4f} Loss_G: {loss_g:.4f}')


[0/25][0/782] Loss_D: 1.6777 Loss_G: 5.9508
[0/25][10/782] Loss_D: 0.5796 Loss_G: 7.0942
[0/25][20/782] Loss_D: 3.0634 Loss_G: 16.1639
[0/25][30/782] Loss_D: 0.2278 Loss_G: 13.7360
[0/25][40/782] Loss_D: 0.0799 Loss_G: 8.7227
[0/25][50/782] Loss_D: 0.0773 Loss_G: 17.6230
[0/25][60/782] Loss_D: 0.4639 Loss_G: 28.5448
[0/25][70/782] Loss_D: 0.1565 Loss_G: 7.7093
[0/25][80/782] Loss_D: 0.1199 Loss_G: 16.1576
[0/25][90/782] Loss_D: 0.0510 Loss_G: 13.3942
[0/25][100/782] Loss_D: 0.0976 Loss_G: 8.9791
[0/25][110/782] Loss_D: 0.0811 Loss_G: 8.7663
[0/25][120/782] Loss_D: 0.0776 Loss_G: 10.0992
[0/25][130/782] Loss_D: 0.0165 Loss_G: 24.6063
[0/25][140/782] Loss_D: 0.1858 Loss_G: 9.3323
[0/25][150/782] Loss_D: 0.7988 Loss_G: 11.8088
[0/25][160/782] Loss_D: 0.7648 Loss_G: 4.2410
[0/25][170/782] Loss_D: 0.4491 Loss_G: 6.6321
[0/25][180/782] Loss_D: 0.4890 Loss_G: 2.6555
[0/25][190/782] Loss_D: 0.7242 Loss_G: 3.4937
[0/25][200/782] Loss_D: 1.6807 Loss_G: 4.6685
[0/25][210/782] Loss_D: 0.2898 Loss_