# Module 8.1 – GANs for Data Augmentation in Semiconductor Manufacturing

This notebook demonstrates synthetic wafer map generation using Deep Convolutional GANs (DCGANs) for data augmentation in semiconductor manufacturing applications.

## Learning Objectives
- Understand GAN architecture and training dynamics
- Generate synthetic wafer maps and defect patterns
- Evaluate generated sample quality
- Apply GANs for data augmentation in imbalanced datasets

## Prerequisites
- Basic understanding of neural networks
- Familiarity with PyTorch tensors
- Knowledge of semiconductor manufacturing terminology

In [None]:
# Import required libraries
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.utils as vutils
from pathlib import Path
import json
import warnings
warnings.filterwarnings('ignore')

# Add the module path to sys.path
module_path = Path('.').resolve()
if str(module_path) not in sys.path:
    sys.path.append(str(module_path))

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"Device available: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

## 1. Understanding the Problem: Wafer Map Data Augmentation

Semiconductor wafer maps show spatial patterns of defects across the wafer surface. These patterns are crucial for:
- Process monitoring and control
- Root cause analysis of yield issues
- Predictive maintenance of equipment

However, rare defect patterns are difficult to model due to limited training examples. GANs can generate synthetic examples to balance datasets.
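
To make the balancing idea concrete, the sketch below counts how many synthetic samples each class would need to match the largest class. The class names and counts are made-up illustrative values, and `synthetic_needed` is a hypothetical helper that is not part of the pipeline.

```python
from collections import Counter


def synthetic_needed(labels):
    """Return, per class, how many extra samples would match the largest class."""
    counts = Counter(labels)
    target = max(counts.values())
    return {cls: target - n for cls, n in counts.items()}


# Hypothetical label list: one common pattern and two rare ones
labels = ["center"] * 500 + ["scratch"] * 40 + ["edge-ring"] * 15
deficits = synthetic_needed(labels)

for cls, deficit in sorted(deficits.items()):
    print(f"{cls}: generate {deficit} synthetic samples")
```

Fully balancing a dataset with synthetic data is an upper bound rather than a recommendation; Section 8 of this notebook suggests starting with a much smaller synthetic share.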

In [None]:
# Import our GAN pipeline
import importlib.util

# The pipeline script's filename is not a valid Python module name
# (it contains hyphens and a dot), so load it by file path instead.
try:
    spec = importlib.util.spec_from_file_location(
        "gans_pipeline", 
        "8.1-gans-data-augmentation-pipeline.py"
    )
    gans_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(gans_module)
    
    GANsPipeline = gans_module.GANsPipeline
    SyntheticWaferDataset = gans_module.SyntheticWaferDataset
    
    print("Successfully imported GAN pipeline components")
except Exception as e:
    print(f"Import error: {e}")
    print("Make sure you're running this notebook from the module-8 directory")

In [None]:
# Let's first examine synthetic wafer patterns to understand what we're trying to generate
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image

# Create a synthetic dataset for demonstration
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

synthetic_dataset = SyntheticWaferDataset(
    num_samples=16, 
    image_size=64, 
    transform=transform
)

# Create a dataloader
dataloader = DataLoader(synthetic_dataset, batch_size=16, shuffle=False)

# Get a batch of synthetic wafer patterns
real_batch = next(iter(dataloader))

# Visualize the synthetic wafer patterns
plt.figure(figsize=(12, 8))
plt.subplot(1, 2, 1)
plt.title("Synthetic Wafer Patterns (Training Data)")
grid = vutils.make_grid(real_batch, nrow=4, normalize=True, value_range=(-1, 1))
plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.axis('off')

# Show statistics
plt.subplot(1, 2, 2)
plt.title("Pixel Intensity Distribution")
pixel_values = real_batch.numpy().flatten()
plt.hist(pixel_values, bins=50, alpha=0.7, color='blue')
plt.xlabel('Pixel Intensity')
plt.ylabel('Frequency')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Batch shape: {real_batch.shape}")
print(f"Value range: [{real_batch.min():.3f}, {real_batch.max():.3f}]")
print(f"Mean: {real_batch.mean():.3f}, Std: {real_batch.std():.3f}")

## 2. GAN Architecture Overview

Our DCGAN consists of two neural networks competing against each other:

**Generator (G)**: Transforms random noise → realistic wafer patterns  
**Discriminator (D)**: Distinguishes real wafer patterns from generated ones
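
As a rough illustration of how a generator turns a 1x1 latent input into a 64x64 image, the sketch below applies the transposed-convolution output-size formula layer by layer. The layer settings are an assumption based on the common DCGAN layout (kernel 4, stride 2, padding 1 after an initial stride-1 layer); the actual `Generator` in the pipeline script may differ, so compare against the printed architecture.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# Assumed DCGAN-style layer settings: (kernel, stride, padding)
layers = [(4, 1, 0)] + [(4, 2, 1)] * 4

size = 1  # the latent vector enters as a 1x1 spatial map
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(" -> ".join(f"{s}x{s}" for s in sizes))
```

Each stride-2 layer doubles the spatial resolution, which is why image sizes that are powers of two are convenient for this family of architectures.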

### Training Process:
1. **D step**: Train discriminator to better identify real vs fake
2. **G step**: Train generator to better fool the discriminator
3. **Repeat**: Alternate the two steps until neither network can easily improve against the other (in practice, for a fixed number of epochs)
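
The alternating steps above correspond to the minimax objective of the original GAN formulation, shown here for reference. This notebook does not show which exact loss the pipeline implements, so treat it as the textbook form rather than a description of the pipeline code.

$$
\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

The D step increases $V$ by scoring real samples $x$ high and generated samples $G(z)$ low; the G step decreases it. Many implementations, including the standard DCGAN recipe, instead train G to maximize $\log D(G(z))$, which gives stronger gradients early in training when D rejects generated samples easily.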

Let's examine the network architectures:

In [None]:
# Create GAN pipeline and examine network architectures
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize the pipeline
pipeline = GANsPipeline(
    model_type='dcgan',
    image_size=64,
    latent_dim=100,
    batch_size=16,  # Small batch for notebook
    device=device
)

# Create the networks to examine their structure
generator = gans_module.Generator(latent_dim=100, image_size=64).to(device)
discriminator = gans_module.Discriminator(image_size=64).to(device)

print("=== Generator Architecture ===")
print(generator)
print(f"\nTotal Generator Parameters: {sum(p.numel() for p in generator.parameters()):,}")

print("\n=== Discriminator Architecture ===")
print(discriminator)
print(f"\nTotal Discriminator Parameters: {sum(p.numel() for p in discriminator.parameters()):,}")

In [None]:
# Demonstrate the data flow through the networks
print("=== Data Flow Demonstration ===")

# Generate random noise (latent vector)
batch_size = 4
latent_dim = 100
noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
print(f"Input noise shape: {noise.shape}")

# Forward pass through generator
with torch.no_grad():
    fake_images = generator(noise)
    print(f"Generated images shape: {fake_images.shape}")
    print(f"Generated image range: [{fake_images.min():.3f}, {fake_images.max():.3f}]")
    
    # Forward pass through discriminator
    fake_scores = discriminator(fake_images)
    print(f"Discriminator scores shape: {fake_scores.shape}")
    print(f"Discriminator scores (fake): {fake_scores.cpu().numpy()}")
    
    # Test with real images
    real_scores = discriminator(real_batch[:batch_size].to(device))
    print(f"Discriminator scores (real): {real_scores.cpu().numpy()}")

# Visualize untrained generator output
plt.figure(figsize=(12, 4))
plt.suptitle("Untrained Generator Output vs Real Data", fontsize=14)

# Untrained generator samples
plt.subplot(1, 2, 1)
plt.title("Untrained Generator")
fake_grid = vutils.make_grid(fake_images.cpu(), nrow=2, normalize=True, value_range=(-1, 1))
plt.imshow(np.transpose(fake_grid, (1, 2, 0)))
plt.axis('off')

# Real data
plt.subplot(1, 2, 2)
plt.title("Real Wafer Patterns")
real_grid = vutils.make_grid(real_batch[:batch_size], nrow=2, normalize=True, value_range=(-1, 1))
plt.imshow(np.transpose(real_grid, (1, 2, 0)))
plt.axis('off')

plt.tight_layout()
plt.show()

print("\nObservation: Untrained generator produces random noise-like patterns")
print("Goal: Train the generator to produce realistic wafer patterns")

## 3. Training the GAN

Now let's train our GAN for a few epochs to see how it learns to generate realistic wafer patterns. We'll use a short training run suitable for an interactive notebook.
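
When reading the loss values reported after training, it can help to know what an undecided discriminator would produce. The sketch below assumes binary cross-entropy losses, as in the standard DCGAN setup; the pipeline's actual loss function is not shown in this notebook, and `bce` is an illustrative helper.

```python
import math


def bce(prediction, target):
    """Binary cross-entropy for a single probability and a 0/1 target."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


# A discriminator that cannot tell real from fake outputs 0.5 for everything
d_loss_real = bce(0.5, 1)  # real sample labelled 1
d_loss_fake = bce(0.5, 0)  # generated sample labelled 0
d_loss = d_loss_real + d_loss_fake
g_loss = bce(0.5, 1)  # non-saturating generator loss: fakes labelled 1

print(f"D loss at chance level: {d_loss:.3f}")
print(f"G loss at chance level: {g_loss:.3f}")
```

Under these assumptions the chance-level values are 2 ln 2 (about 1.386) for a summed D loss and ln 2 (about 0.693) for the G loss; implementations that average the two D terms report half the D value. A D loss close to zero combined with a steadily growing G loss usually means the discriminator is overpowering the generator.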

In [None]:
# Train the GAN for a short demonstration
print("Starting GAN training...")
print("Note: This is a short training run for demonstration purposes")
print("For production use, train for 50-200 epochs")

# Train the pipeline
trained_pipeline = pipeline.fit(
    data_path=None,  # Use synthetic data
    epochs=5  # Short training for notebook demonstration
)

print("\nTraining completed!")
print(f"Final metadata: {trained_pipeline.metadata}")

## 4. Evaluating Generated Samples

Let's examine the quality of our trained generator and compare it with the untrained version.
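
Pixel-level statistics such as mean and standard deviation cannot reveal mode collapse, because a generator that repeats one plausible image can still match them. A simple complementary check is the average distance between generated samples. The sketch below uses plain Python lists and random stand-in data rather than real generator output; `mean_pairwise_distance` is a hypothetical helper, and for real tensors a vectorized version would be preferable.

```python
import math
import random
from itertools import combinations


def mean_pairwise_distance(samples):
    """Average Euclidean distance over all pairs of flattened samples."""
    pairs = list(combinations(samples, 2))
    return sum(math.dist(a, b) for a, b in pairs) / len(pairs)


random.seed(0)
n_pixels = 16  # tiny stand-in for a flattened 64x64 image

# Diverse set: every sample drawn independently
diverse = [[random.uniform(-1, 1) for _ in range(n_pixels)] for _ in range(8)]

# Collapsed set: one pattern repeated with tiny perturbations
base = diverse[0]
collapsed = [[p + random.gauss(0, 0.01) for p in base] for _ in range(8)]

print(f"diverse:   {mean_pairwise_distance(diverse):.3f}")
print(f"collapsed: {mean_pairwise_distance(collapsed):.3f}")
```

A generated set whose mean pairwise distance is far below that of the real training data is a hint that the generator covers only a few patterns.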

In [None]:
# Generate samples from the trained model
num_samples = 16
generated_samples = trained_pipeline.generate(num_samples)

print(f"Generated {num_samples} samples")
print(f"Sample shape: {generated_samples.shape}")
print(f"Sample range: [{generated_samples.min():.3f}, {generated_samples.max():.3f}]")

# Create comparison visualization
plt.figure(figsize=(15, 10))

# Trained generator samples
plt.subplot(2, 2, 1)
plt.title("Trained Generator Output", fontsize=12)
generated_grid = vutils.make_grid(generated_samples.cpu(), nrow=4, normalize=True, value_range=(-1, 1))
plt.imshow(np.transpose(generated_grid, (1, 2, 0)))
plt.axis('off')

# Real training data
plt.subplot(2, 2, 2)
plt.title("Real Training Data", fontsize=12)
real_grid = vutils.make_grid(real_batch[:num_samples], nrow=4, normalize=True, value_range=(-1, 1))
plt.imshow(np.transpose(real_grid, (1, 2, 0)))
plt.axis('off')

# Pixel intensity distributions
plt.subplot(2, 2, 3)
plt.title("Pixel Intensity Comparison", fontsize=12)
generated_pixels = generated_samples.cpu().numpy().flatten()
real_pixels = real_batch[:num_samples].numpy().flatten()
plt.hist(real_pixels, bins=50, alpha=0.7, label='Real Data', color='blue')
plt.hist(generated_pixels, bins=50, alpha=0.7, label='Generated', color='red')
plt.xlabel('Pixel Intensity')
plt.ylabel('Frequency')
plt.legend()
plt.grid(True, alpha=0.3)

# Training progress (if available)
plt.subplot(2, 2, 4)
plt.title("Training Metrics", fontsize=12)
# This would show loss curves in a full training run
plt.text(0.5, 0.5, f"Training completed\nEpochs: {trained_pipeline.metadata.epochs_trained}\nFinal D Loss: {trained_pipeline.metadata.final_d_loss:.3f}\nFinal G Loss: {trained_pipeline.metadata.final_g_loss:.3f}", 
         ha='center', va='center', transform=plt.gca().transAxes, fontsize=10)
plt.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Evaluate the model quantitatively
evaluation_results = trained_pipeline.evaluate()

print("=== Model Evaluation Results ===")
print(json.dumps(evaluation_results, indent=2))

# Analyze the metrics
metrics = evaluation_results['metrics']
# Use a distinct name so the imported `warnings` module is not shadowed
quality_warnings = evaluation_results['warnings']

print("\n=== Quality Assessment ===")
print(f"Pixel std (rough diversity proxy): {metrics['sample_std']:.3f}")
print(f"Sample mean: {metrics['sample_mean']:.3f}")
print(f"Sample range: [{metrics['sample_min']:.3f}, {metrics['sample_max']:.3f}]")

if quality_warnings:
    print("\n⚠️  Warnings:")
    for warning in quality_warnings:
        print(f"  - {warning}")
else:
    print("\n✅ No quality warnings detected")

print("\n📊 Note: For production use, consider training for more epochs to improve quality")

## 5. Exploring the Latent Space

One of the powerful aspects of GANs is the learned latent space. We can interpolate between points in latent space to generate smooth transitions between different wafer patterns.
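
Linear interpolation is the simplest choice, but with Gaussian latent vectors the points along a straight line have a smaller norm than typical samples: for two independent 100-dimensional draws, the endpoints have norm near 10 while the midpoint is near 7. Spherical interpolation (slerp) is a common alternative that avoids this dip. The sketch below shows both on plain Python lists; `lerp`, `slerp`, and `norm` are illustrative helpers, not part of the pipeline.

```python
import math


def norm(v):
    """Euclidean norm of a vector given as a list of floats."""
    return math.sqrt(sum(x * x for x in v))


def lerp(z1, z2, t):
    """Linear interpolation between two vectors."""
    return [(1 - t) * a + t * b for a, b in zip(z1, z2)]


def slerp(z1, z2, t):
    """Spherical interpolation between two non-zero vectors."""
    cos_omega = sum(a * b for a, b in zip(z1, z2)) / (norm(z1) * norm(z2))
    omega = math.acos(max(-1.0, min(1.0, cos_omega)))  # clamp against rounding
    sin_omega = math.sin(omega)
    if sin_omega < 1e-8:  # (anti)parallel vectors: fall back to linear
        return lerp(z1, z2, t)
    w1 = math.sin((1 - t) * omega) / sin_omega
    w2 = math.sin(t * omega) / sin_omega
    return [w1 * a + w2 * b for a, b in zip(z1, z2)]


z1 = [1.0, 0.0]
z2 = [0.0, 1.0]
mid_linear = lerp(z1, z2, 0.5)
mid_spherical = slerp(z1, z2, 0.5)

print(f"linear midpoint norm:    {norm(mid_linear):.3f}")
print(f"spherical midpoint norm: {norm(mid_spherical):.3f}")
```

For short demonstrations the visual difference is often small, so the linear version used in this notebook is a reasonable default.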

In [None]:
# Demonstrate latent space interpolation
def interpolate_latent(generator, z1, z2, steps=8):
    """Linearly interpolate between two latent vectors and decode each point."""
    # Inference mode: use running statistics if the generator has BatchNorm layers
    generator.eval()
    interpolated_images = []
    
    with torch.no_grad():
        for alpha in np.linspace(0, 1, steps):
            alpha = float(alpha)  # plain Python float avoids NumPy/PyTorch scalar mixing
            z_interp = (1 - alpha) * z1 + alpha * z2
            interpolated_images.append(generator(z_interp))
    
    return torch.cat(interpolated_images, dim=0)

# Create two random latent vectors
z1 = torch.randn(1, 100, 1, 1, device=device)
z2 = torch.randn(1, 100, 1, 1, device=device)

# Generate interpolation
interpolated = interpolate_latent(trained_pipeline.generator, z1, z2, steps=8)

# Visualize interpolation
plt.figure(figsize=(16, 4))
plt.suptitle("Latent Space Interpolation: Smooth Transitions Between Wafer Patterns", fontsize=14)

for i in range(8):
    plt.subplot(1, 8, i+1)
    image = interpolated[i].cpu().squeeze()
    plt.imshow(image, cmap='gray', vmin=-1, vmax=1)
    plt.title(f"Step {i+1}")
    plt.axis('off')

plt.tight_layout()
plt.show()

print("Observation: Smooth transitions between steps suggest a well-structured latent space")
print("Abrupt jumps or near-identical frames can indicate an undertrained or collapsed generator")

## 6. Practical Application: Data Augmentation

Let's demonstrate how to use our trained GAN for data augmentation in a classification scenario.
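
Before generating samples, it is worth deciding what share of the final dataset should be synthetic. The sketch below converts a target synthetic fraction into a sample count; `synthetic_count` is a hypothetical helper, and the 10% and 20% values mirror the starting range suggested in Section 8.

```python
def synthetic_count(n_real, synthetic_fraction):
    """Synthetic samples needed so they form `synthetic_fraction` of the final set."""
    if not 0 <= synthetic_fraction < 1:
        raise ValueError("synthetic_fraction must be in [0, 1)")
    # n_syn / (n_real + n_syn) = f  =>  n_syn = n_real * f / (1 - f)
    return round(n_real * synthetic_fraction / (1 - synthetic_fraction))


n_real = 100
for fraction in (0.10, 0.20):
    n_syn = synthetic_count(n_real, fraction)
    print(f"{fraction:.0%} synthetic: add {n_syn} samples (total {n_real + n_syn})")
```

For comparison, the demonstration cell in this section adds 200 generated samples to 100 real ones, which makes about two thirds of the combined set synthetic. That is fine for illustration but well above the suggested starting range.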

In [None]:
# Simulate a data augmentation scenario
print("=== Data Augmentation Demonstration ===")

# Original dataset size
original_data_size = 100
print(f"Original dataset size: {original_data_size} samples")

# Generate additional synthetic samples
augmentation_samples = 200
synthetic_data = trained_pipeline.generate(augmentation_samples)

print(f"Generated additional samples: {augmentation_samples}")
print(f"Augmented dataset size: {original_data_size + augmentation_samples} samples")
print(f"Data increase: {(augmentation_samples / original_data_size) * 100:.0f}%")

# Save a sample grid for inspection
sample_grid_path = "augmentation_samples.png"
trained_pipeline.save_sample_grid(
    sample_grid_path, 
    num_samples=64, 
    nrow=8
)
print(f"\nSample grid saved to: {sample_grid_path}")

# Show a subset of the generated augmentation data
plt.figure(figsize=(12, 6))
plt.title("Generated Data for Augmentation")
augmentation_grid = vutils.make_grid(
    synthetic_data[:32].cpu(), 
    nrow=8, 
    normalize=True, 
    value_range=(-1, 1)
)
plt.imshow(np.transpose(augmentation_grid, (1, 2, 0)))
plt.axis('off')
plt.show()

print("\n📈 Benefits of GAN-based augmentation:")
print("  • Increases dataset size without manual labeling")
print("  • Provides diverse patterns for better generalization")
print("  • Particularly valuable for rare defect patterns")
print("  • Can be conditioned on specific defect types (advanced GANs)")

## 7. Model Persistence and Deployment

Let's demonstrate how to save and load the trained model for production use.

In [None]:
# Save the trained model
model_path = Path("trained_wafer_gan.joblib")
trained_pipeline.save(model_path)
print(f"Model saved to: {model_path}")
print(f"Model file size: {model_path.stat().st_size / 1024 / 1024:.2f} MB")

# Load the model and verify it works
loaded_pipeline = GANsPipeline.load(model_path)
print("\nModel loaded successfully!")

# Verify the loaded model works
test_samples = loaded_pipeline.generate(4)
print(f"Generated {test_samples.shape[0]} test samples from loaded model")

# Compare metadata
print("\n=== Model Metadata ===")
if loaded_pipeline.metadata:
    metadata = loaded_pipeline.metadata
    print(f"Model type: {metadata.model_type}")
    print(f"Image size: {metadata.image_size}")
    print(f"Training epochs: {metadata.epochs_trained}")
    print(f"Training time: {metadata.training_time_seconds:.1f} seconds")
    print(f"Device used: {metadata.device}")
    print(f"Timestamp: {metadata.timestamp}")

# Visualize samples from loaded model
plt.figure(figsize=(8, 4))
plt.title("Samples from Loaded Model")
test_grid = vutils.make_grid(test_samples.cpu(), nrow=4, normalize=True, value_range=(-1, 1))
plt.imshow(np.transpose(test_grid, (1, 2, 0)))
plt.axis('off')
plt.show()

print("\n✅ Model persistence working correctly!")

## 8. Production Considerations and Next Steps

This notebook demonstrated the basics of GAN-based data augmentation for semiconductor manufacturing. Here are key considerations for production deployment:

In [None]:
# Performance and scaling analysis
print("=== Production Considerations ===")
print()

# Timing analysis
import time

use_cuda = torch.cuda.is_available()

# Time sample generation; synchronize so queued GPU work is included in the measurement
if use_cuda:
    torch.cuda.synchronize()
start_time = time.perf_counter()
batch_samples = trained_pipeline.generate(100)
if use_cuda:
    torch.cuda.synchronize()
generation_time = time.perf_counter() - start_time

print("⏱️  Performance Metrics:")
print(f"   Generation time for 100 samples: {generation_time:.2f} seconds")
print(f"   Samples per second: {100 / generation_time:.1f}")
if use_cuda:
    print(f"   GPU memory allocated: ~{torch.cuda.memory_allocated() / 1024 / 1024:.0f} MB")
else:
    print("   Memory usage: not measured in CPU mode")
print()

print("🚀 Scaling Recommendations:")
print("   • CPU training: Suitable for 32x32 images, prototyping")
print("   • GPU training: Recommended for 64x64+ images, production")
print("   • Batch generation: Process multiple samples efficiently")
print("   • Model optimization: Consider quantization for deployment")
print()

print("📊 Quality Improvements:")
print("   • Train for 50-200 epochs for production quality")
print("   • Use larger datasets when available")
print("   • Implement progressive growing for high resolution")
print("   • Consider WGAN-GP for improved stability")
print()

print("🔧 Integration Tips:")
print("   • Validate on held-out real data")
print("   • Monitor downstream model performance")
print("   • Start with 10-20% synthetic data in augmented datasets")
print("   • Implement quality control gates for generated samples")
print()

print("📚 Advanced Topics to Explore:")
print("   • Conditional GANs for specific defect types")
print("   • Progressive GANs for higher resolution")
print("   • StyleGAN for fine-grained control")
print("   • FID/KID metrics for quantitative evaluation")

## Summary

In this notebook, we successfully:

1. **Understood the problem**: Data scarcity in semiconductor defect pattern datasets
2. **Implemented DCGAN**: Built generator and discriminator networks for wafer map synthesis
3. **Trained the model**: Demonstrated adversarial training process
4. **Evaluated quality**: Used visual inspection and quantitative metrics
5. **Explored latent space**: Showed smooth interpolations between patterns
6. **Applied to augmentation**: Generated synthetic data for dataset expansion
7. **Demonstrated persistence**: Saved and loaded trained models

### Key Takeaways:
- GANs can generate realistic wafer patterns for data augmentation
- Even short training runs can show promising results, but a 5-epoch run is a demonstration, not a quality benchmark
- Proper evaluation is crucial for production deployment
- Model persistence enables reuse and deployment

### Next Steps:
- Train for more epochs with larger datasets
- Implement advanced GAN variants (WGAN-GP, Progressive GAN)
- Integrate with downstream classification pipelines
- Develop conditional generation for specific defect types

This foundation provides a solid starting point for GAN-based data augmentation in semiconductor manufacturing applications.

In [None]:
# List files generated in this session (delete them manually if no longer needed)
import os

print("Generated files in this session:")
for file in ["trained_wafer_gan.joblib", "augmentation_samples.png"]:
    if os.path.exists(file):
        size = os.path.getsize(file)
        print(f"  {file}: {size / 1024:.1f} KB")
    else:
        print(f"  {file}: Not found")

print("\n🎉 Module 8.1 GAN demonstration complete!")
print("Check the pipeline script and documentation for production usage.")