In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

# ---- Hyperparameters ----
batch_size = 64         # images per minibatch
lr = 5e-5               # RMSprop learning rate, shared by generator and critic
z_dim = 100             # length of the generator's input noise vector
n_epochs = 100          # full passes over the training set
n_critic = 5            # critic updates per generator update
clip_value = 0.01       # critic parameters are clamped to [-clip_value, clip_value]
sample_interval = 100   # save a grid of generated images every this many batches

# Directory that receives the periodic sample grids written during training.
os.makedirs('images', exist_ok=True)

# CIFAR-10 preprocessing: resize to 32 px, convert to a tensor in [0, 1], then
# apply (x - 0.5) / 0.5 so pixels lie in [-1, 1], the same range as the
# generator's Tanh output. The single-element mean/std broadcast over all
# three RGB channels.
transform = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Training split only; downloaded into ./data on first use.
dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

# Generator
class Generator(nn.Module):
    """DCGAN-style generator mapping a noise vector to a 3x32x32 image.

    A linear layer projects ``z`` to a 128x8x8 feature map. Two stages of
    nearest-neighbour ``Upsample`` plus a 3x3 convolution enlarge it from
    8x8 to 16x16 and then to 32x32. A final ``Tanh`` bounds pixel values to
    [-1, 1], matching real images normalised with mean 0.5 and std 0.5.

    Args:
        z_dim: Length of the input noise vector.
    """

    def __init__(self, z_dim):
        super(Generator, self).__init__()
        self.init_size = 8  # Spatial size of the projected feature map before upsampling.
        self.l1 = nn.Sequential(nn.Linear(z_dim, 128 * self.init_size ** 2))

        # NOTE: BatchNorm2d's second positional argument is ``eps``, not
        # ``momentum``. Writing e.g. BatchNorm2d(128, 0.8) sets eps=0.8, which
        # swamps the batch variance in the normalisation, so defaults are used.
        # The layer order is kept fixed so state_dict keys (conv_blocks.<idx>.*)
        # stay stable across versions.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map a ``(batch, z_dim)`` noise tensor to ``(batch, 3, 32, 32)`` images in [-1, 1]."""
        out = self.l1(z)
        # Reshape the flat projection into a (batch, 128, 8, 8) feature map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        return self.conv_blocks(out)

# Critic (Discriminator)
class Critic(nn.Module):
    """WGAN critic: scores each image in a 3-channel batch with one unbounded value.

    Four 3x3, stride-2 convolutions halve the spatial size in turn
    (32 -> 16 -> 8 -> 4 -> 2 for 32x32 inputs). The resulting 512x2x2
    features are flattened and reduced to a single score by a linear layer.
    The final linear layer's input size is fixed, so the critic is sized for
    32x32 images.
    """

    def __init__(self):
        super().__init__()
        # First stage: conv + LeakyReLU, without normalisation.
        layers = [
            nn.Conv2d(3, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining stages insert BatchNorm between the conv and the activation.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Head: flatten the 512x2x2 map and project to a scalar score.
        layers += [nn.Flatten(), nn.Linear(512 * 2 * 2, 1)]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of critic scores for ``img``."""
        return self.model(img)

# Initialize generator and critic on the GPU when one is available, otherwise
# fall back to the CPU instead of failing on a hard-coded .cuda() call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator(z_dim).to(device)
critic = Critic().to(device)

# RMSprop (PyTorch default: no momentum) for both networks, as used in the
# WGAN paper; both share the same learning rate.
optimizer_G = optim.RMSprop(generator.parameters(), lr=lr)
optimizer_C = optim.RMSprop(critic.parameters(), lr=lr)

# Training loop
# Batches and noise are created on whatever device the generator lives on, so
# the loop works for both CPU and CUDA placement.
device = next(generator.parameters()).device

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        imgs = imgs.to(device)

        # ---- Critic update: minimise -(mean C(real) - mean C(fake)) ----
        optimizer_C.zero_grad()
        z = torch.randn(imgs.size(0), z_dim, device=device)
        # Detached so the critic loss does not backpropagate into the generator.
        generated_imgs = generator(z).detach()
        loss_C = -(critic(imgs).mean() - critic(generated_imgs).mean())
        loss_C.backward()
        optimizer_C.step()

        # Weight clipping (WGAN) is applied to every critic parameter,
        # including BatchNorm affine weights. It is done in place under no_grad
        # rather than through the discouraged ``.data`` attribute.
        with torch.no_grad():
            for p in critic.parameters():
                p.clamp_(-clip_value, clip_value)

        # ---- Generator update, once every n_critic batches ----
        if i % n_critic == 0:
            optimizer_G.zero_grad()
            # Reuses this step's noise, now with gradients flowing to the generator.
            # The backward pass also fills critic grads; they are cleared by
            # optimizer_C.zero_grad() at the start of the next critic update.
            generated_imgs = generator(z)
            loss_G = -critic(generated_imgs).mean()
            loss_G.backward()
            optimizer_G.step()

        # Periodically save a 5x5 grid of the most recent generated images.
        if i % sample_interval == 0:
            save_image(generated_imgs.detach()[:25], f"images/{epoch}_{i}.png", nrow=5, normalize=True)

    # Reported losses come from the epoch's last batch. loss_G is always bound
    # because batch i == 0 triggers a generator update.
    print(f"[Epoch {epoch}/{n_epochs}] [Critic Loss: {loss_C.item()}] [Generator Loss: {loss_G.item()}]")

# Persist the trained generator's parameters (state_dict only) to
# generator.pth; the critic is not saved.
generator_weights = generator.state_dict()
torch.save(generator_weights, "generator.pth")
