In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms,datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Set random seed for reproducibility
torch.manual_seed(42)

# Hyperparameters
latent_dim = 100  # generator noise dimension (matches Generator's noise_dim default)
hidden_dim = 64  # NOTE(review): not referenced by the visible code
image_dim = 1  # MNIST: 1 channel (grayscale), matching the dataset and both networks
num_epochs = 150
batch_size = 128
lr = 0.0002  # Adam learning rate for both networks
beta1 = 0.5  # Adam beta1 for both networks

# Device configuration: prefer a GPU when one is visible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Generator
class Generator(nn.Module):
    """Generator that maps noise vectors to 28x28 images.

    Pipeline: Linear projection -> reshape to (128, 7, 7) -> two stride-2
    ConvTranspose2d upsamplings (7 -> 14 -> 28) -> size-preserving 3x3
    Conv2d down to ``out_channels`` -> Tanh, so outputs lie in [-1, 1].

    Args:
        noise_dim: Length of each input noise vector.
        out_channels: Number of channels in the generated image
            (1 for grayscale MNIST; default keeps the original behavior).
    """

    def __init__(self, noise_dim=100, out_channels=1):
        super().__init__()
        self.noise_dim = noise_dim

        # Project the noise to a flat 128*7*7 vector, reshaped in forward().
        self.fc = nn.Linear(noise_dim, 7 * 7 * 128)
        self.model = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # (128, 7, 7) -> (64, 14, 14)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # (64, 14, 14) -> (32, 28, 28)
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            # (32, 28, 28) -> (out_channels, 28, 28); padding=1 keeps 28x28
            nn.Conv2d(32, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),  # output in [-1, 1]
        )

    def forward(self, x):
        """Generate images from noise.

        Args:
            x: Noise tensor of shape (batch, noise_dim).

        Returns:
            Tensor of shape (batch, out_channels, 28, 28), values in [-1, 1].
        """
        x = self.fc(x)  # (batch, 128*7*7)
        x = x.view(-1, 128, 7, 7)  # (batch, 128, 7, 7)
        return self.model(x)

# Discriminator
class Discriminator(nn.Module):
    """Convolutional discriminator for 28x28 images.

    Produces one raw logit per image (no sigmoid is applied), so it must be
    paired with a logits-based loss such as ``nn.BCEWithLogitsLoss``.

    Args:
        in_channels: Number of channels in the input images
            (1 for grayscale MNIST; default keeps the original behavior).
    """

    def __init__(self, in_channels=1):
        super().__init__()
        self.model = nn.Sequential(
            # (batch, in_channels, 28, 28) -> (batch, 32, 14, 14)
            nn.Conv2d(in_channels, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            # (batch, 32, 14, 14) -> (batch, 64, 7, 7)
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            # (batch, 64*7*7) -> (batch, 1)
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 1),  # logit for real/fake
        )

    def forward(self, x):
        """Score images; returns logits of shape (batch, 1)."""
        return self.model(x)

# Initialize models (once; the hyperparameter cell controls the noise size)
generator = Generator(noise_dim=latent_dim).to(device)
discriminator = Discriminator().to(device)

# Optimizers use the lr / beta1 hyperparameters defined above
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# The discriminator ends in a plain nn.Linear (raw logits), so use the
# logits-based BCE rather than nn.BCELoss.
criterion = nn.BCEWithLogitsLoss()

# Data loading: ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps to [-1, 1],
# matching the generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# Fixed noise so the saved sample grids are comparable from epoch to epoch.
fixed_noise = torch.randn(16, latent_dim, device=device)

# Training loop
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):  # labels are unused
        real_images = real_images.to(device)
        # Separate name: the final batch can be smaller, and reassigning
        # `batch_size` would clobber the hyperparameter.
        cur_batch = real_images.size(0)

        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # Train Discriminator
        d_optimizer.zero_grad()
        real_output = discriminator(real_images)
        d_real_loss = criterion(real_output, real_labels)
        noise = torch.randn(cur_batch, latent_dim, device=device)
        fake_images = generator(noise)
        # detach(): the discriminator update must not backpropagate into G
        fake_output = discriminator(fake_images.detach())
        d_fake_loss = criterion(fake_output, fake_labels)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Train Generator: label its fakes as real so G learns to fool D
        g_optimizer.zero_grad()
        fake_output = discriminator(fake_images)
        g_loss = criterion(fake_output, real_labels)
        g_loss.backward()
        g_optimizer.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] "
                  f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    is_last_epoch = epoch == num_epochs - 1

    # Checkpoint every 5 epochs and always after the final epoch
    if epoch % 5 == 0 or is_last_epoch:
        torch.save(generator.state_dict(), f'generator_epoch_{epoch}.pth')
        torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch}.pth')

    # Save a 4x4 grid of generated samples
    if epoch % 10 == 0 or is_last_epoch:
        # eval(): in train mode BatchNorm would update its running statistics
        # from this small sampling batch even under no_grad().
        generator.eval()
        with torch.no_grad():
            samples = generator(fixed_noise)
        generator.train()

        samples = (samples * 0.5 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
        samples = samples.permute(0, 2, 3, 1).cpu().numpy()  # NCHW -> NHWC
        fig, axes = plt.subplots(4, 4, figsize=(8, 8))
        for idx, ax in enumerate(axes.flat):
            img = samples[idx]
            if img.shape[-1] == 1:
                # Single channel: drop the channel axis and render in grayscale
                ax.imshow(img[..., 0], cmap='gray', vmin=0, vmax=1)
            else:
                ax.imshow(img)
            ax.axis('off')
        fig.savefig(f'image_at_epoch_{epoch:04d}.png')
        plt.close(fig)

Epoch [0/150] Batch [0/469] D Loss: 1.3041, G Loss: 0.9106
Epoch [0/150] Batch [100/469] D Loss: 0.7454, G Loss: 1.4936
Epoch [0/150] Batch [200/469] D Loss: 1.0467, G Loss: 1.0422
Epoch [0/150] Batch [300/469] D Loss: 1.0409, G Loss: 1.0112
Epoch [0/150] Batch [400/469] D Loss: 0.8429, G Loss: 1.2907
Epoch [1/150] Batch [0/469] D Loss: 0.7475, G Loss: 1.7775
Epoch [1/150] Batch [100/469] D Loss: 0.7041, G Loss: 1.1943
Epoch [1/150] Batch [200/469] D Loss: 0.7691, G Loss: 1.2750
Epoch [1/150] Batch [300/469] D Loss: 0.8212, G Loss: 1.6103
Epoch [1/150] Batch [400/469] D Loss: 0.8165, G Loss: 1.3937
Epoch [2/150] Batch [0/469] D Loss: 0.8964, G Loss: 0.9299
Epoch [2/150] Batch [100/469] D Loss: 0.8305, G Loss: 1.2227
Epoch [2/150] Batch [200/469] D Loss: 0.9446, G Loss: 0.9779
Epoch [2/150] Batch [300/469] D Loss: 0.9003, G Loss: 1.3757
Epoch [2/150] Batch [400/469] D Loss: 0.9255, G Loss: 1.2432
Epoch [3/150] Batch [0/469] D Loss: 0.9979, G Loss: 0.8369
Epoch [3/150] Batch [100/469] D 

KeyboardInterrupt: 