# DCGAN with FFT Training Example

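This notebook demonstrates how to use the improved DCGAN-with-FFT implementation in `src/`. It configures the run, builds and inspects the generator and discriminator, visualizes the FFT components of a sample training image, trains the model, and plots the resulting loss curves.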

In [None]:
import json
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import torch

# Add the project root (the parent of the current working directory) to the
# Python path so that the `src` package below is importable.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from src.config.config import config
from src.models.dcgan import create_models
from src.utils.visualization import plot_fft_components
from src.utils.data_loader import create_dataloader

## Configure the Model

First, let's configure the model, training, and data parameters. Update `categories` to match the categories in your dataset:

In [None]:
# Override the defaults from src.config for this run. `config` is a single
# module-level object, so these attribute assignments are visible to every
# module that imports it from src.config.config.
SEED = 42
torch.manual_seed(SEED)  # reproducible weight initialization and sampling

config.model.img_size = 64
# Keep the data pipeline's image size in sync with the model's; otherwise the
# loader could yield images of a different size than the networks expect.
config.data.img_size = config.model.img_size
config.model.latent_dim = 100
config.training.batch_size = 64
config.training.n_epochs = 100
config.data.categories = ['a']  # Update with your categories

print("Configuration:")
print(f"Image size: {config.model.img_size}")
print(f"Latent dimension: {config.model.latent_dim}")
print(f"Batch size: {config.training.batch_size}")
print(f"Number of epochs: {config.training.n_epochs}")

## Create and Visualize Models

Let's create the generator and discriminator models and visualize their architectures:

In [None]:
# Build both networks from the current configuration.
generator, discriminator = create_models(config)

# Print the layer-by-layer structure of each network under its own heading.
architectures = (
    ("Generator architecture:", generator),
    ("\nDiscriminator architecture:", discriminator),
)
for heading, network in architectures:
    print(heading)
    print(network)

## Load and Visualize Data

Let's load some sample data and visualize the FFT components:

In [None]:
# Build the dataloader. Use the model's image size (the value set in the
# configuration cell) so the loaded images match what the networks expect.
dataloader = create_dataloader(
    data_path=config.data.data_path,
    categories=config.data.categories,
    batch_size=config.training.batch_size,
    img_size=config.model.img_size,
)

# Fetch one batch. Dataset-backed loaders commonly yield (images, labels)
# pairs rather than a bare image tensor; accept either form so that
# `imgs[0]` is always the first image, not the whole image batch.
# NOTE(review): confirm the batch format produced by create_dataloader.
batch = next(iter(dataloader))
imgs = batch[0] if isinstance(batch, (list, tuple)) else batch

# Visualize FFT components for the first image in the batch
plot_fft_components(
    imgs[0],
    'fft_components.png'
)

## Train the Model

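Now, let's train the model. With `n_epochs = 100` this can take a long time; lower `config.training.n_epochs` in the configuration cell for a quick test run: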

In [None]:
from src.train import train

# Run the full training loop (long-running: config.training.n_epochs epochs).
# NOTE(review): train() is called with no arguments, so it does not receive the
# `generator` / `discriminator` built earlier in this notebook. Presumably it
# builds its own models and reads the shared `config` object overridden in the
# configuration cell; confirm against src/train.py.
# NOTE(review): presumably this is what writes logs/metrics.json, which the
# next cell reads; confirm the output path.
train()

## Visualize Results

Let's plot the discriminator and generator losses recorded during training:

In [None]:
# Load the per-epoch losses recorded during training. The path is relative to
# the kernel's current working directory.
# NOTE(review): assumes the file is a JSON object whose 'epoch', 'd_loss' and
# 'g_loss' entries are equal-length lists; confirm against the training code.
metrics_path = Path('logs') / 'metrics.json'
if not metrics_path.exists():
    raise FileNotFoundError(
        f"{metrics_path} not found; run the training cell first "
        "(or check where train() writes its metrics)."
    )
# Explicit encoding: JSON is UTF-8, and the platform default may not be.
with metrics_path.open('r', encoding='utf-8') as f:
    metrics = json.load(f)

# Plot both loss curves on a single explicitly created Axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(metrics['epoch'], metrics['d_loss'], label='Discriminator Loss')
ax.plot(metrics['epoch'], metrics['g_loss'], label='Generator Loss')
ax.set(title='Training Metrics', xlabel='Epoch', ylabel='Loss')
ax.legend()
ax.grid(True)
plt.show()