# Task 9: Attention GAN

This notebook implements a conditional GAN whose generator uses a self-attention mechanism for improved image generation.


In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import os


## Self-Attention Module


In [None]:
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Each spatial location of a ``(batch, channels, height, width)`` input is
    treated as one token whose feature vector is its channel slice. Tokens
    attend to each other with unscaled dot-product attention, and the result
    is added back to the input through a learnable scalar gate ``gamma``.

    Args:
        in_dim: Number of channels of the input feature map. The query, key
            and value projections all map ``in_dim -> in_dim``.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.query = nn.Linear(in_dim, in_dim)
        self.key = nn.Linear(in_dim, in_dim)
        self.value = nn.Linear(in_dim, in_dim)
        # gamma starts at zero, so a freshly built module returns its input
        # unchanged; the attention branch is blended in as gamma is learned.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply gated self-attention.

        Args:
            x: 4-D tensor ``(batch, channels, height, width)`` with
                ``channels == in_dim``. It does not need to be contiguous.

        Returns:
            Tensor with the same shape as ``x``: ``gamma * attention(x) + x``.
        """
        batch_size, channels, height, width = x.shape

        # flatten() copies only when it has to, so inputs whose spatial dims
        # are not contiguous (e.g. a transposed map) no longer raise the
        # RuntimeError that ``view`` would.
        tokens = x.flatten(2).transpose(1, 2)  # (batch, height*width, channels)

        q = self.query(tokens)
        k = self.key(tokens)
        v = self.value(tokens)

        # Row i holds the weights token i assigns to every token (rows sum to 1).
        attn = torch.softmax(q @ k.transpose(-2, -1), dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(
            batch_size, channels, height, width
        )

        return self.gamma * out + x


## Generator with Attention


In [None]:
class Generator(nn.Module):
    """Conditional generator mapping (noise, label) to a square image.

    Pipeline: concatenate noise and label -> linear -> ReLU -> self-attention
    -> linear -> tanh, reshaped to ``(batch, img_size, img_size)``.

    The hidden vector is split into ``hidden_dim // attn_channels`` tokens of
    ``attn_channels`` features each before attention. Feeding it as a single
    1x1 "feature map" would give attention exactly one token, making the
    softmax identically 1 and reducing the layer to a plain linear map.

    Args:
        latent_dim: Size of the noise vector ``z``.
        label_dim: Size of the conditioning label vector.
        hidden_dim: Width of the hidden layer.
        img_size: Side length of the generated image.
        attn_channels: Feature size of each attention token; must divide
            ``hidden_dim``.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``attn_channels``.
    """

    def __init__(self, latent_dim=100, label_dim=2, hidden_dim=128,
                 img_size=28, attn_channels=8):
        super().__init__()
        if hidden_dim % attn_channels != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"attn_channels ({attn_channels})"
            )
        self.hidden_dim = hidden_dim
        self.img_size = img_size
        self.attn_channels = attn_channels
        self.attn_tokens = hidden_dim // attn_channels

        self.fc1 = nn.Linear(latent_dim + label_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.attention = SelfAttention(attn_channels)
        self.fc2 = nn.Linear(hidden_dim, img_size * img_size)
        self.tanh = nn.Tanh()

    def forward(self, z, labels):
        """Generate images.

        Args:
            z: Noise tensor of shape ``(batch, latent_dim)``.
            labels: Label tensor of shape ``(batch, label_dim)``.

        Returns:
            Tensor of shape ``(batch, img_size, img_size)`` with values in
            ``[-1, 1]`` (tanh output).
        """
        x = torch.cat((z, labels), dim=1)
        x = self.relu(self.fc1(x))

        batch_size = x.size(0)
        # Lay the hidden vector out as a (channels, tokens, 1) feature map so
        # the attention layer mixes information across several tokens.
        x = x.reshape(batch_size, self.attn_channels, self.attn_tokens, 1)
        x = self.attention(x)
        x = x.reshape(batch_size, self.hidden_dim)

        x = self.fc2(x)
        return self.tanh(x).view(batch_size, self.img_size, self.img_size)


## Discriminator


In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring an (image, label) pair.

    The image is flattened, concatenated with the label vector and passed
    through linear -> ReLU -> linear -> sigmoid.

    Args:
        img_size: Side length of the square input image.
        label_dim: Size of the conditioning label vector.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, img_size=28, label_dim=2, hidden_dim=128):
        super().__init__()
        self.img_pixels = img_size * img_size
        self.model = nn.Sequential(
            nn.Linear(self.img_pixels + label_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        """Score a batch of images.

        Args:
            img: Tensor holding ``batch * img_size * img_size`` values, e.g.
                shape ``(batch, img_size, img_size)``. It does not need to be
                contiguous.
            labels: Label tensor of shape ``(batch, label_dim)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``
            (sigmoid output).
        """
        # reshape (not view) so non-contiguous images are accepted too.
        flat = img.reshape(-1, self.img_pixels)
        return self.model(torch.cat((flat, labels), dim=1))


## Initialize and Train


In [None]:
# Build the two networks. The generator is constructed first so the order in
# which random initial weights are drawn stays the same.
generator = Generator()
discriminator = Discriminator()

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate.
g_optimizer = torch.optim.Adam(params=generator.parameters(), lr=1e-3)
d_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=1e-3)

# One-hot conditioning vector for each shape class.
labels_dict = {
    shape_name: torch.tensor(one_hot, dtype=torch.float32)
    for shape_name, one_hot in (('square', [1, 0]), ('circle', [0, 1]))
}

print("Attention GAN models initialized!")


In [None]:
# Training loop
# Each epoch runs one discriminator step and one generator step per class
# label, with a batch of a single noise vector.
#
# NOTE(review): the discriminator is only ever shown generated images in this
# loop; there is no "real" batch with target 1. Add one once a dataset of
# real squares/circles is available, otherwise the discriminator just learns
# to output 0.
num_epochs = 100
g_losses = []
d_losses = []

for epoch in range(num_epochs):
    for label_name, label_tensor in labels_dict.items():
        label_tensor = label_tensor.unsqueeze(0)  # (2,) -> (1, 2) batch of one
        z = torch.randn(1, 100)

        # Train Discriminator
        d_optimizer.zero_grad()
        fake_img = generator(z, label_tensor)
        # detach() keeps this step from back-propagating into the generator.
        d_fake = discriminator(fake_img.detach(), label_tensor)
        # BCELoss needs the target to have exactly the prediction's shape
        # ((1, 1) here); a (1,)-shaped target is rejected by current PyTorch.
        d_loss = criterion(d_fake, torch.zeros_like(d_fake))
        d_loss.backward()
        d_optimizer.step()

        # Train Generator
        g_optimizer.zero_grad()
        g_fake = discriminator(fake_img, label_tensor)
        g_loss = criterion(g_fake, torch.ones_like(g_fake))
        # This also accumulates gradients in the discriminator; they are
        # cleared by d_optimizer.zero_grad() at the start of the next step.
        g_loss.backward()
        g_optimizer.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

    # Reports the losses of the last label processed in this epoch.
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}] - G_Loss: {g_loss.item():.4f}, D_Loss: {d_loss.item():.4f}")

print("Training completed!")


In [None]:
# Visualize losses
# Both loss histories are drawn on one axes; there is one point per
# (epoch, label) training step.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(g_losses, label='Generator Loss')
loss_ax.plot(d_losses, label='Discriminator Loss')
loss_ax.set_xlabel('Iteration')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Losses - Attention GAN')
loss_ax.legend()
loss_ax.grid(True)
plt.show()

# Generate samples
# One panel per class label, each drawn from a fresh noise vector.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for idx, (label_name, label_tensor) in enumerate(labels_dict.items()):
    panel = axes[idx]
    z = torch.randn(1, 100)
    label_tensor = label_tensor.unsqueeze(0)

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        sample = generator(z, label_tensor)
        generated_img = sample.detach().numpy()

    panel.imshow(generated_img[0], cmap='gray')
    panel.set_title(f"Generated {label_name.capitalize()} with Attention")
    panel.axis('off')

plt.tight_layout()
plt.show()
