# Task 7: Conditional GAN for Basic Shapes

This notebook implements a Conditional GAN (CGAN) that generates basic shapes (squares and circles) conditioned on a one-hot class label.


In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import os


## Model Architecture


In [None]:
class Generator(nn.Module):
    """Map a noise vector plus a one-hot class label to a single-channel image.

    Args:
        latent_dim: Length of the noise vector ``z``.
        num_classes: Length of the one-hot label vector.
        img_size: Side length of the square output image.
        hidden_dim: Width of the single hidden layer.

    The defaults (100, 2, 28, 128) reproduce the previously hard-coded
    architecture exactly, so existing ``state_dict`` checkpoints still load.
    """

    def __init__(self, latent_dim=100, num_classes=2, img_size=28, hidden_dim=128):
        super().__init__()
        self.img_size = img_size
        self.model = nn.Sequential(
            # Noise and condition are concatenated before the first layer.
            nn.Linear(latent_dim + num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, img_size * img_size),
            nn.Tanh(),  # squashes pixel values into [-1, 1]
        )

    def forward(self, z, labels):
        """Generate a batch of images.

        Args:
            z: Noise tensor of shape (batch, latent_dim).
            labels: Condition tensor of shape (batch, num_classes).

        Returns:
            Tensor of shape (batch, img_size, img_size) with values in [-1, 1].
        """
        x = torch.cat((z, labels), dim=1)
        return self.model(x).view(-1, self.img_size, self.img_size)

class Discriminator(nn.Module):
    """Score how likely an (image, one-hot label) pair is to be real.

    Args:
        num_classes: Length of the one-hot label vector.
        img_size: Side length of the square input image.
        hidden_dim: Width of the single hidden layer.

    The defaults (2, 28, 128) reproduce the previously hard-coded
    architecture exactly, so existing ``state_dict`` checkpoints still load.
    """

    def __init__(self, num_classes=2, img_size=28, hidden_dim=128):
        super().__init__()
        self.img_size = img_size
        self.model = nn.Sequential(
            # Flattened image and condition are concatenated before the first layer.
            nn.Linear(img_size * img_size + num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # probability output, to be paired with nn.BCELoss
        )

    def forward(self, img, labels):
        """Classify a batch of images.

        Args:
            img: Image tensor with img_size * img_size elements per sample,
                e.g. shape (batch, img_size, img_size).
            labels: Condition tensor of shape (batch, num_classes).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed images, work.
        flat = img.reshape(-1, self.img_size * self.img_size)
        x = torch.cat((flat, labels), dim=1)
        return self.model(x)


## Initialize Models and Training Setup


In [None]:
# Build both networks and the shared adversarial loss.
generator = Generator()
discriminator = Discriminator()
criterion = nn.BCELoss()  # pairs with the discriminator's Sigmoid output

# Each network gets its own Adam optimizer over only its own parameters,
# created in the same order as the networks above.
g_optimizer, d_optimizer = (
    torch.optim.Adam(net.parameters(), lr=0.001)
    for net in (generator, discriminator)
)

# One-hot condition vectors, one per shape class.
labels_dict = dict(
    square=torch.tensor([1, 0], dtype=torch.float32),
    circle=torch.tensor([0, 1], dtype=torch.float32),
)


def _param_count(module):
    """Return the total number of scalar parameters in *module*."""
    return sum(p.numel() for p in module.parameters())


print("Models initialized successfully!")
print(f"Generator parameters: {_param_count(generator)}")
print(f"Discriminator parameters: {_param_count(discriminator)}")


## Training Loop


In [None]:
def make_real_shape(label_name, size=28):
    """Render one "real" training image for the given shape label.

    Draws a filled square or circle of random size, shifted slightly from
    the image centre. Pixels are +1 inside the shape and -1 outside, matching
    the [-1, 1] range of the generator's Tanh output.

    Args:
        label_name: 'square' or 'circle' (the keys of ``labels_dict``).
        size: Side length of the square output image.

    Returns:
        A float32 tensor of shape (size, size).

    Raises:
        ValueError: If ``label_name`` is not a known shape.
    """
    coords = torch.arange(size, dtype=torch.float32)
    rows = coords.view(-1, 1)  # (size, 1) broadcasts against (1, size)
    cols = coords.view(1, -1)

    # Half-width of the square / radius of the circle, in pixels.
    half_extent = torch.randint(size // 5, size // 3 + 1, (1,)).item()
    max_shift = max(1, size // 9)
    shift_r, shift_c = torch.randint(-max_shift, max_shift + 1, (2,)).tolist()
    centre_r = (size - 1) / 2 + shift_r
    centre_c = (size - 1) / 2 + shift_c

    if label_name == 'square':
        mask = ((rows - centre_r).abs() <= half_extent) & ((cols - centre_c).abs() <= half_extent)
    elif label_name == 'circle':
        mask = (rows - centre_r) ** 2 + (cols - centre_c) ** 2 <= half_extent ** 2
    else:
        raise ValueError(f"Unknown shape label: {label_name!r}")
    return mask.float() * 2 - 1


# Training parameters
num_epochs = 100
batch_size = 1

# Store losses for visualization (one entry per label per epoch)
g_losses = []
d_losses = []

# Training loop
for epoch in range(num_epochs):
    for label_name, label_tensor in labels_dict.items():
        # Repeat the one-hot condition for every sample so its batch dimension
        # matches z; a bare unsqueeze(0) only works when batch_size == 1.
        label_batch = label_tensor.unsqueeze(0).expand(batch_size, -1)
        real_targets = torch.ones(batch_size, 1)
        fake_targets = torch.zeros(batch_size, 1)

        # Real examples of this shape that the discriminator should accept.
        real_imgs = torch.stack([make_real_shape(label_name) for _ in range(batch_size)])

        z = torch.randn(batch_size, 100)
        fake_img = generator(z, label_batch)

        # Train Discriminator: real -> 1, fake -> 0. Without the real term the
        # adversarial game carries no information about what a shape looks like.
        d_optimizer.zero_grad()
        d_real = discriminator(real_imgs, label_batch)
        d_fake = discriminator(fake_img.detach(), label_batch)
        d_loss = criterion(d_real, real_targets) + criterion(d_fake, fake_targets)
        d_loss.backward()
        d_optimizer.step()

        # Train Generator: push the freshly updated discriminator to label fakes as real.
        g_optimizer.zero_grad()
        g_fake = discriminator(fake_img, label_batch)
        g_loss = criterion(g_fake, real_targets)
        g_loss.backward()
        g_optimizer.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}] - G_Loss: {g_loss.item():.4f}, D_Loss: {d_loss.item():.4f}")

print("Training completed!")


## Visualize Training Losses


In [None]:
# One figure with both loss curves, indexed by training iteration.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(g_losses, label='Generator Loss')
loss_ax.plot(d_losses, label='Discriminator Loss')
loss_ax.set_xlabel('Iteration')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Losses')
loss_ax.legend()
loss_ax.grid(True)
plt.show()


## Generate and Visualize Samples


In [None]:
# Generate samples for each label
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

for idx, (label_name, label_tensor) in enumerate(labels_dict.items()):
    z = torch.randn(1, 100)
    label_tensor = label_tensor.unsqueeze(0)
    
    with torch.no_grad():
        generated_img = generator(z, label_tensor).detach().numpy()
    
    axes[idx].imshow(generated_img[0], cmap='gray')
    axes[idx].set_title(f"Generated {label_name.capitalize()}")
    axes[idx].axis('off')

plt.tight_layout()
plt.show()


## Save Model


In [None]:
# Create saved_model directory
os.makedirs('saved_model', exist_ok=True)

# Save models
torch.save(generator.state_dict(), 'saved_model/generator.pth')
torch.save(discriminator.state_dict(), 'saved_model/discriminator.pth')
torch.save({
    'g_optimizer': g_optimizer.state_dict(),
    'd_optimizer': d_optimizer.state_dict(),
    'g_losses': g_losses,
    'd_losses': d_losses,
}, 'saved_model/training_state.pth')

print("Models saved successfully in saved_model/ directory!")
