# Generating Artistic Abstract Images with GANs

This notebook explores the use of Generative Adversarial Networks (GANs) to create unique and artistic abstract images.
The project focuses on training a GAN using the [Abstract Art dataset](https://www.kaggle.com/datasets/greg115/abstract-art).

## Objectives
- Understand how GANs can generate visually diverse artistic outputs.
- Train a GAN using a dataset of abstract art images.
- Generate and evaluate the quality of artistic images produced by the model.

---
## Part 1: Dataset Preparation

In [None]:
# Import necessary libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

# Define dataset path
DATA_PATH = '../input/abstract-art/'  # Adjust as needed for Kaggle Notebook

# Image side length. This must match the networks defined below: the
# Generator upsamples 16 -> 32 -> 64, and the Discriminator's final Linear
# layer expects 128*16*16 features, i.e. a 64x64 input. At 128x128 the
# Discriminator fails with a shape mismatch on the first real batch.
IMG_SIZE = 64

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),  # Resize images to the size the GAN uses
    transforms.ToTensor(),  # Convert images to tensors in [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # Normalize pixel values to [-1, 1]
])

# Load dataset
dataset = datasets.ImageFolder(root=DATA_PATH, transform=transform)
# drop_last=True: a trailing batch of size 1 would make the Generator's
# BatchNorm1d raise in training mode (it needs more than one sample).
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

print(f'Dataset size: {len(dataset)} images')

## Part 2: Building the GAN

### Defining the Generator and Discriminator

In [None]:
# Generator: latent noise vector -> 3-channel image
class Generator(nn.Module):
    """DCGAN-style generator mapping noise vectors to RGB images in [-1, 1].

    A Linear layer projects ``z`` to a 128x16x16 feature map. Two stride-2
    transposed convolutions then upsample it 16 -> 32 -> 64, so a batch of
    shape ``(N, latent_dim)`` becomes images of shape ``(N, 3, 64, 64)``.
    The final Tanh bounds pixels to [-1, 1], the same range the real images
    are normalized to.

    Note: the BatchNorm1d after the projection needs a batch of more than
    one sample while the module is in training mode.

    Args:
        latent_dim: Length of each input noise vector.
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        base_channels, base_side = 128, 16
        projected = base_channels * base_side * base_side
        layers = [
            # Project the noise and reshape it into a small feature map.
            nn.Linear(latent_dim, projected),
            nn.BatchNorm1d(projected),
            nn.ReLU(),
            nn.Unflatten(1, (base_channels, base_side, base_side)),
            # 16x16 -> 32x32
            nn.ConvTranspose2d(base_channels, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # 32x32 -> 64x64, down to 3 output channels
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map noise ``z`` of shape ``(N, latent_dim)`` to generated images."""
        return self.model(z)

# Discriminator: image -> probability that the image is real
class Discriminator(nn.Module):
    """Binary real/fake classifier for 3-channel 64x64 images.

    Two stride-2 convolutions downsample 64 -> 32 -> 16. The flattened
    128x16x16 feature map feeds a single Linear unit, so the input side
    length must be 64; other sizes fail at the Linear layer. The output has
    shape ``(N, 1)`` and holds Sigmoid probabilities in [0, 1].
    """

    def __init__(self):
        super().__init__()
        feat_channels, feat_side = 128, 16
        conv_stack = [
            nn.Conv2d(3, 64, 4, stride=2, padding=1),  # 64x64 -> 32x32
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, feat_channels, 4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.BatchNorm2d(feat_channels),
            nn.LeakyReLU(0.2),
        ]
        head = [
            nn.Flatten(),
            nn.Linear(feat_channels * feat_side * feat_side, 1),
            nn.Sigmoid(),  # probabilities, as consumed by nn.BCELoss in training
        ]
        self.model = nn.Sequential(*conv_stack, *head)

    def forward(self, img):
        """Return the per-image probability of ``img`` being real."""
        return self.model(img)

### Training the GAN

An adversarial training loop alternately updates the discriminator (to tell real images from generated ones) and the generator (to fool the discriminator).

In [None]:
# Hyperparameters
latent_dim = 100
num_epochs = 100
lr = 0.0002

# Train on the GPU when one is available; 100 epochs on CPU is very slow.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models and optimizers
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
loss_fn = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Training Loop
for epoch in range(num_epochs):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # Train Discriminator: push D(real) -> 1 and D(fake) -> 0.
        # fake_imgs is detached here so this step does not backpropagate
        # into the generator.
        optimizer_D.zero_grad()
        real_loss = loss_fn(discriminator(real_imgs), real_labels)
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_imgs = generator(z)
        fake_loss = loss_fn(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: push D(G(z)) -> 1, i.e. label the fakes as real.
        # This backward also leaves gradients in the discriminator; they are
        # cleared by optimizer_D.zero_grad() before the next D step.
        optimizer_G.zero_grad()
        g_loss = loss_fn(discriminator(fake_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()

    print(f'Epoch {epoch+1}/{num_epochs} - D Loss: {d_loss.item()} - G Loss: {g_loss.item()}')

# Later cells sample noise on the CPU, so move the trained models back there.
generator.to("cpu")
discriminator.to("cpu")

### Generating Abstract Art Images
Visualizing the outputs from the trained generator.

This section will:

1.  Generate new images using the trained **Generator**.
2.  Display the generated images in a grid.
3.  Provide basic evaluation metrics and visualization.

### **Generating and Evaluating Artistic Abstract Images**

In [None]:
# Function to generate and visualize abstract images from the trained GAN
def generate_images(generator, latent_dim, num_images=16):
    """Sample ``num_images`` images from ``generator`` and plot them in a grid.

    The generator is switched to eval mode (BatchNorm uses its running
    statistics) and is left in eval mode afterwards. Noise is created on the
    same device as the generator's parameters.

    Args:
        generator: Trained generator. Assumes it maps ``(N, latent_dim)``
            noise to ``(N, C, H, W)`` images in [-1, 1] (e.g. via a final
            Tanh).
        latent_dim: Length of the noise vector the generator expects.
        num_images: Number of images to sample and show. The grid is sized
            to fit them: ceil(sqrt(n)) columns, so 16 gives a 4x4 grid.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    generator.eval()  # Set the generator to evaluation mode
    device = next(generator.parameters()).device
    with torch.no_grad():
        # Generate random noise vectors on the generator's device
        z = torch.randn(num_images, latent_dim, device=device)
        # Generate images and bring them to the CPU for plotting
        generated_imgs = generator(z).cpu()

    # Denormalize images (since we normalized them in range [-1, 1])
    generated_imgs = (generated_imgs + 1) / 2  # Scale back to [0,1]

    # Size the grid from num_images instead of hard-coding 4x4; each cell is
    # 2x2 inches, which reproduces the original 8x8 figure for 16 images.
    ncols = int(np.ceil(np.sqrt(num_images)))
    nrows = int(np.ceil(num_images / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        if i < num_images:
            # CHW -> HWC for imshow
            ax.imshow(np.transpose(generated_imgs[i].numpy(), (1, 2, 0)))
        ax.axis("off")  # also hides any unused trailing cells

    plt.show()

# Sample a fresh batch of 16 images from the trained generator and plot them
generate_images(generator, latent_dim=latent_dim, num_images=16)


### **Evaluation: How Good Are the Generated Images?**

Since evaluating abstract art is highly subjective, you can use:

1.  **Visual Diversity**: Look at the variations in generated images.
2.  **FID Score (Fréchet Inception Distance)**: Measures how close the distribution of generated images is to that of real images in Inception-v3 feature space; lower is better.
3.  **User Feedback**: Ask users if the generated images look like abstract art.

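For a basic FID Score evaluation (note: the cell below uses only 16 real and 16 generated images, far fewer than the thousands normally used for FID, so treat the score as a rough sanity check rather than a comparable benchmark):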

In [None]:
from torchvision.models import inception_v3
from scipy.linalg import sqrtm

# Function to compute the FID Score
def compute_fid(real_images, fake_images):
    """Compute the Frechet Inception Distance between two image batches.

    Each image is embedded with an ImageNet-pretrained Inception-v3 whose
    ``fc`` classification layer is replaced by ``Identity``, giving its
    pooled feature vector. With per-set feature means ``mu`` and covariances
    ``sigma``, the score is::

        ||mu_r - mu_f||^2 + Tr(sigma_r + sigma_f - 2 * sqrtm(sigma_r @ sigma_f))

    Lower is better. With N samples the feature covariance has rank at most
    N - 1, so small batches (like 16 images) give a biased, unstable score.

    Args:
        real_images: Real image tensor. Assumed to be ``(N, 3, H, W)`` in
            [-1, 1], which is this notebook's normalization. TODO confirm for
            other callers.
        fake_images: Generated image tensor in the same format.

    Returns:
        The FID score as a NumPy float.

    Raises:
        ValueError: If either set has fewer than 2 images, since a covariance
            cannot be estimated from one sample.
    """
    if real_images.size(0) < 2 or fake_images.size(0) < 2:
        raise ValueError("compute_fid needs at least 2 images per set to estimate a covariance")

    # NOTE(review): `pretrained=True` is the legacy torchvision API (newer
    # versions prefer `weights=`); kept for compatibility with older installs.
    # Inputs are passed in [-1, 1] with transform_input=False. Confirm this
    # matches the preprocessing expected by your torchvision version.
    model = inception_v3(pretrained=True, transform_input=False)
    model.fc = torch.nn.Identity()  # Remove the classification layer
    model.eval()

    def _embed(images):
        # The model lives on the CPU; accept images from any device/dtype.
        images = images.detach().cpu().float()
        # Inception-v3 is built for 299x299 inputs. Small images such as this
        # notebook's 64x64 ones shrink to nothing inside its stride-2 stages
        # and make the forward pass fail, so upsample first.
        images = torch.nn.functional.interpolate(
            images, size=(299, 299), mode="bilinear", align_corners=False
        )
        return model(images).numpy()

    # Convert images to InceptionV3 features
    with torch.no_grad():
        real_features = _embed(real_images)
        fake_features = _embed(fake_images)

    # Compute mean and covariance (rows are samples, columns are features)
    mu_real, sigma_real = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
    mu_fake, sigma_fake = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)

    # Compute FID score
    diff = mu_real - mu_fake
    covmean, _ = sqrtm(sigma_real.dot(sigma_fake), disp=False)

    if np.iscomplexobj(covmean):
        # Numerical error can introduce tiny imaginary parts; discard them.
        covmean = covmean.real

    fid_score = np.sum(diff**2) + np.trace(sigma_real + sigma_fake - 2 * covmean)
    return fid_score

# Select real images from dataset and generated images for FID evaluation.
# NOTE: 16 samples per set is far too few for a stable FID estimate; treat
# the result as a rough sanity check, not a comparable benchmark number.
num_fid_samples = 16
real_images, _ = next(iter(dataloader))  # Get a batch of real images
real_images = real_images[:num_fid_samples]  # Take a subset

# Sample fakes explicitly in eval mode (BatchNorm uses running stats and is
# not updated) and without building an autograd graph, on whatever device
# the generator currently lives on.
generator.eval()
gen_device = next(generator.parameters()).device
with torch.no_grad():
    z = torch.randn(num_fid_samples, latent_dim, device=gen_device)
    fake_images = generator(z).cpu()

# Compute FID Score
fid_score = compute_fid(real_images, fake_images)
print(f"FID Score: {fid_score:.2f}")
