# Section 5: Image Generation with Hugging Face Providers

## Objectives
- Generate images using the Hugging Face Inference API
- Compare `provider="auto"` vs explicit provider selection
- Measure and analyze latency differences
- Implement error handling and failover

## Requirements
- Python 3.10+
- CUDA 12.4+ (for local GPU acceleration, optional)
- PyTorch 2.6.0+
- Hugging Face account and API token

## Setup and Installation

In [None]:
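# Install required packages.
# Quote the version specifiers so the shell does not treat ">" as an output redirect.
# The `provider` argument used later needs a recent huggingface_hub; add -U to upgrade an older install.
!pip install -q "huggingface_hub>=0.20.0" "torch>=2.6.0" pillow matplotlib pandas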

In [None]:
# Verify PyTorch and CUDA installation
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Authentication Setup

**IMPORTANT:** Never hardcode your API token in notebooks!

Set your token as an environment variable:
```bash
export HF_TOKEN="your_token_here"
```

Or use the Hugging Face CLI:
```bash
huggingface-cli login
```

In [None]:
import os
from getpass import getpass

# Try to load token from environment
HF_TOKEN = os.getenv("HF_TOKEN")

# If not found, prompt securely (won't show in output)
if not HF_TOKEN:
    print("HF_TOKEN not found in environment.")
    HF_TOKEN = getpass("Enter your Hugging Face token: ")

# Verify token is loaded
assert HF_TOKEN, "Token must be provided"
print("✓ Token loaded successfully")
print(f"✓ Token length: {len(HF_TOKEN)} characters")

## Import Required Libraries

In [None]:
from huggingface_hub import InferenceClient
from PIL import Image
import matplotlib.pyplot as plt
import time
import pandas as pd
from typing import List, Dict, Tuple
from io import BytesIO
import warnings
warnings.filterwarnings('ignore')

print("✓ All libraries imported successfully")

## Initialize Inference Client

In [None]:
# Initialize the Hugging Face Inference Client
client = InferenceClient(token=HF_TOKEN)

# Model to use for image generation
MODEL = "stabilityai/stable-diffusion-2-1"

print("✓ Client initialized")
print(f"✓ Using model: {MODEL}")

## Exercise 1: Basic Image Generation

**What you'll practice:** Generate images from text prompts using the Hugging Face Inference API.

This exercise introduces you to the fundamentals of text-to-image generation:
- Setting up the inference client with authentication
- Creating a function to generate images from prompts
- Measuring generation latency
- Displaying generated images

You'll learn how to use the `InferenceClient` for image generation and understand the basic workflow that all subsequent exercises build upon.

In [None]:
def generate_image(prompt: str, provider: str = "auto") -> Tuple[Image.Image, float]:
    """
    Generate an image from a text prompt.

    Args:
        prompt: Text description of the image
        provider: Provider to use ("auto" or a specific provider name)

    Returns:
        Tuple of (PIL Image, generation time in seconds)
    """
    # The provider is chosen when the client is constructed,
    # so build a client for the requested provider.
    provider_client = InferenceClient(provider=provider, token=HF_TOKEN)

    start_time = time.perf_counter()

    # text_to_image returns a PIL Image directly, so no byte decoding is needed
    image = provider_client.text_to_image(prompt, model=MODEL)

    elapsed_time = time.perf_counter() - start_time

    return image, elapsed_time

In [None]:
# Generate a single image
prompt = "A serene mountain landscape at sunset, photorealistic, 4k"

print(f"Generating image for prompt: '{prompt}'")
image, gen_time = generate_image(prompt)

print(f"✓ Image generated in {gen_time:.2f} seconds")
print(f"✓ Image size: {image.size}")

# Display the image
plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.axis('off')
plt.title(f"Generated in {gen_time:.2f}s")
plt.show()

## Exercise 2: Provider Comparison

Compare image generation using `provider="auto"` vs explicit provider selection.

In [None]:
def benchmark_providers(
    prompts: List[str],
    providers: List[str],
    num_runs: int = 3
) -> pd.DataFrame:
    """
    Benchmark multiple providers with the same prompts.
    
    Args:
        prompts: List of text prompts
        providers: List of provider names to test
        num_runs: Number of times to run each prompt
    
    Returns:
        DataFrame with benchmark results
    """
    results = []
    
    for provider in providers:
        print(f"\nTesting provider: {provider}")
        print("-" * 50)
        
        for prompt_idx, prompt in enumerate(prompts, 1):
            for run in range(1, num_runs + 1):
                try:
                    print(f"  Prompt {prompt_idx}/{len(prompts)}, Run {run}/{num_runs}...", end=" ")
                    
                    image, gen_time = generate_image(prompt, provider)
                    
                    results.append({
                        'provider': provider,
                        'prompt': prompt[:50] + "...",
                        'run': run,
                        'latency_sec': gen_time,
                        'status': 'success'
                    })
                    
                    print(f"✓ {gen_time:.2f}s")
                    
                except Exception as e:
                    results.append({
                        'provider': provider,
                        'prompt': prompt[:50] + "...",
                        'run': run,
                        'latency_sec': None,
                        'status': f'failed: {type(e).__name__}'
                    })
                    
                    print(f"✗ {type(e).__name__}")
                
                # Small delay between requests
                time.sleep(1)
    
    return pd.DataFrame(results)

In [None]:
# Define test prompts
test_prompts = [
    "A peaceful garden with colorful flowers and butterflies",
    "A futuristic city skyline at night with neon lights",
    "A cozy cabin in a snowy forest during winter"
]

# Providers to test
# Note: Replace with actual provider names available in your account
providers_to_test = ["auto"]  # Add explicit providers if available

# Run benchmark
print("Starting provider benchmark...")
benchmark_df = benchmark_providers(
    prompts=test_prompts,
    providers=providers_to_test,
    num_runs=3
)

print("\n✓ Benchmark completed")

## Exercise 3: Analyze Results

**What you'll practice:** Analyze and visualize performance metrics from provider comparisons.

This exercise teaches you to:
- Process benchmark data into meaningful insights
- Create visualizations comparing provider performance
- Identify patterns in latency and reliability
- Make data-driven decisions about provider selection

You'll use pandas and matplotlib to transform raw benchmark data into actionable insights about which providers work best for your use case.
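
Mean and standard deviation can hide occasional slow requests, so it may help to look at the median and a high percentile as well. The following is a minimal sketch using only the standard library. The function name `summarize_latencies` and the sample values are illustrative assumptions, not measurements from this notebook.

```python
import math
import statistics
from typing import Dict, List


def summarize_latencies(latencies: List[float]) -> Dict[str, float]:
    """Return count, mean, median and nearest-rank p95 for latencies in seconds."""
    if not latencies:
        raise ValueError("need at least one latency value")

    ordered = sorted(latencies)

    # Nearest-rank percentile: the smallest value such that at least
    # 95% of the samples are less than or equal to it.
    rank = max(1, math.ceil(0.95 * len(ordered)))

    return {
        "count": len(ordered),
        "mean": statistics.fmean(ordered),
        "median": statistics.median(ordered),
        "p95": ordered[rank - 1],
    }


# Made-up sample values, for illustration only
sample = [2.1, 2.4, 2.2, 9.8, 2.3]
print(summarize_latencies(sample))
```

With the benchmark data, the input could be the successful rows' `latency_sec` column converted to a list, one provider at a time.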

In [None]:
# Display raw results
print("Raw Benchmark Results:")
print(benchmark_df.to_string())

In [None]:
# Calculate statistics by provider
stats = benchmark_df[benchmark_df['status'] == 'success'].groupby('provider').agg({
    'latency_sec': ['mean', 'std', 'min', 'max', 'count']
}).round(3)

# Calculate success rate (the mean of a boolean column is the fraction of successes)
success_rate = (
    benchmark_df['status'].eq('success')
    .groupby(benchmark_df['provider'])
    .mean()
    .mul(100)
    .round(1)
)

print("\nProvider Statistics:")
print("=" * 70)
print(stats)
print("\nSuccess Rate (%)")
print(success_rate)

In [None]:
# Visualize latency comparison
successful_results = benchmark_df[benchmark_df['status'] == 'success']

if len(successful_results) > 0:
    plt.figure(figsize=(12, 6))
    
    # Box plot
    plt.subplot(1, 2, 1)
    successful_results.boxplot(column='latency_sec', by='provider', ax=plt.gca())
    plt.title('Latency Distribution by Provider')
    plt.suptitle('')  # Remove default title
    plt.xlabel('Provider')
    plt.ylabel('Latency (seconds)')
    
    # Bar plot with error bars
    plt.subplot(1, 2, 2)
    provider_stats = successful_results.groupby('provider')['latency_sec'].agg(['mean', 'std'])
    provider_stats.plot(kind='bar', y='mean', yerr='std', ax=plt.gca(), legend=False)
    plt.title('Average Latency by Provider')
    plt.xlabel('Provider')
    plt.ylabel('Latency (seconds)')
    plt.xticks(rotation=45)
    
    plt.tight_layout()
    plt.show()
else:
    print("No successful results to visualize")

## Exercise 4: Implement Failover Strategy

**What you'll practice:** Build robust error handling with automatic failover between providers.

This exercise demonstrates production-ready patterns:
- Handling API failures gracefully
- Implementing retry logic with exponential backoff
- Automatically switching to backup providers
- Logging failures for monitoring

You'll learn how to make your image generation system resilient to provider outages and rate limits, ensuring high availability for production applications.
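
Retry with exponential backoff and failure logging can be sketched as a small wrapper around any callable. This is one possible shape under stated assumptions, not a definitive implementation: the helper name `retry_with_backoff`, the logger name, the default delays, and the 10% jitter are all hypothetical choices.

```python
import logging
import random
import time
from typing import Callable, TypeVar

T = TypeVar("T")

# Hypothetical logger name; use whatever your application already configures
logger = logging.getLogger("image_generation")


def retry_with_backoff(
    func: Callable[[], T],
    max_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 30.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call func, retrying on any exception with exponentially growing delays."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    for attempt in range(1, max_attempts + 1):
        try:
            return func()
        except Exception as exc:
            if attempt == max_attempts:
                logger.error("Giving up after %d attempts: %r", attempt, exc)
                raise

            # Delay doubles each attempt (base, 2*base, 4*base, ...), capped at max_delay
            delay = min(max_delay, base_delay * 2 ** (attempt - 1))
            # Add up to 10% random jitter so parallel callers do not retry in lockstep
            delay += random.uniform(0, delay * 0.1)

            logger.warning(
                "Attempt %d/%d failed (%s); retrying in %.2fs",
                attempt, max_attempts, type(exc).__name__, delay,
            )
            sleep(delay)

    # Unreachable: the loop either returns or re-raises on the last attempt
    raise RuntimeError("retry loop exited unexpectedly")
```

Inside a failover loop, each provider attempt could be wrapped as `retry_with_backoff(lambda: generate_image(prompt, provider))`, so that a provider is retried a few times before the next one is tried. The `sleep` parameter exists only so the delay can be replaced in tests.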

In [None]:
def generate_with_failover(
    prompt: str,
    providers: List[str],
    timeout: int = 60
) -> Tuple[Image.Image, str, float]:
    """
    Generate image with automatic failover across providers.
    
    Args:
        prompt: Text description
        providers: List of providers to try in order
        timeout: Intended per-provider timeout in seconds. Not enforced by this
            function; pass `timeout` to `InferenceClient` to apply it.
    
    Returns:
        Tuple of (image, provider_used, generation_time)
    """
    last_error = None
    
    for provider in providers:
        try:
            print(f"Trying provider: {provider}...", end=" ")
            image, gen_time = generate_image(prompt, provider)
            print(f"✓ Success ({gen_time:.2f}s)")
            return image, provider, gen_time
            
        except Exception as e:
            print(f"✗ Failed ({type(e).__name__})")
            last_error = e
            continue
    
    # Chain the last error so the original traceback is preserved
    raise RuntimeError(f"All providers failed. Last error: {last_error}") from last_error

In [None]:
# Test failover mechanism
prompt = "A majestic eagle soaring over mountains"
providers = ["auto"]  # Add more providers if available

print(f"Generating with failover for: '{prompt}'")
print(f"Provider chain: {providers}\n")

try:
    image, used_provider, gen_time = generate_with_failover(prompt, providers)
    
    print(f"\n✓ Successfully generated using: {used_provider}")
    print(f"✓ Generation time: {gen_time:.2f}s")
    
    # Display result
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.axis('off')
    plt.title(f"Provider: {used_provider} | Time: {gen_time:.2f}s")
    plt.show()
    
except Exception as e:
    print(f"\n✗ All providers failed: {e}")

## Exercise 5: Batch Generation with Progress Tracking

**What you'll practice:** Generate multiple images efficiently with progress tracking and error handling.

This exercise covers advanced patterns:
- Processing multiple prompts in batches
- Tracking progress with visual indicators
- Handling partial failures in batch operations
- Pacing requests to stay within provider rate limits

You'll build a loop that works through a list of prompts one at a time, reports progress as it goes, and records failures without aborting the rest of the batch.
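
If throughput matters more than simplicity, requests can be issued from a small thread pool instead of one at a time. The sketch below assumes a hypothetical `batch_generate_concurrent` helper that accepts any `generate` callable; `fake_generate` is a stand-in used here so the example runs without network access. Concurrent requests may hit provider rate limits sooner, so a small `max_workers` is a reasonable starting point.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Dict, List


def batch_generate_concurrent(
    prompts: List[str],
    generate: Callable[[str], Any],
    max_workers: int = 2,
) -> List[Dict[str, Any]]:
    """Run generate(prompt) for each prompt in a thread pool, keeping input order."""
    results: List[Dict[str, Any]] = [{} for _ in prompts]
    total = len(prompts)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Map each future back to the index of its prompt
        futures = {pool.submit(generate, p): i for i, p in enumerate(prompts)}

        # as_completed yields futures as they finish, which drives the progress counter
        for done, future in enumerate(as_completed(futures), 1):
            i = futures[future]
            try:
                results[i] = {
                    "prompt": prompts[i],
                    "status": "success",
                    "output": future.result(),
                }
            except Exception as exc:
                # A failed prompt is recorded and the rest of the batch continues
                results[i] = {
                    "prompt": prompts[i],
                    "status": "failed",
                    "error": str(exc),
                }
            print(f"[{done}/{total}] {results[i]['status']}: {prompts[i][:60]}")

    return results


def fake_generate(prompt: str) -> str:
    """Stand-in for a real image generation call (illustration only)."""
    if not prompt:
        raise ValueError("empty prompt")
    return f"<image for {prompt}>"


demo = batch_generate_concurrent(["a castle", "", "a reef"], fake_generate)
```

In the notebook, `generate` could be a lambda that calls `generate_image(prompt, provider)` and returns the image.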

In [None]:
def batch_generate(
    prompts: List[str],
    provider: str = "auto",
    save_images: bool = False
) -> List[Dict]:
    """
    Generate multiple images with progress tracking.
    
    Args:
        prompts: List of text prompts
        provider: Provider to use
        save_images: Whether to save images to disk
    
    Returns:
        List of result dictionaries
    """
    results = []
    total = len(prompts)
    
    for idx, prompt in enumerate(prompts, 1):
        print(f"\n[{idx}/{total}] Generating: '{prompt[:60]}...'")
        
        try:
            image, gen_time = generate_image(prompt, provider)
            
            result = {
                'prompt': prompt,
                'status': 'success',
                'latency': gen_time,
                'image': image
            }
            
            if save_images:
                filename = f"image_{idx:03d}.png"
                image.save(filename)
                result['filename'] = filename
                print(f"  ✓ Saved to {filename}")
            
            print(f"  ✓ Generated in {gen_time:.2f}s")
            
        except Exception as e:
            result = {
                'prompt': prompt,
                'status': 'failed',
                'error': str(e),
                'image': None
            }
            print(f"  ✗ Failed: {type(e).__name__}")
        
        results.append(result)
        time.sleep(1)  # Rate limiting
    
    return results

In [None]:
# Generate multiple images
batch_prompts = [
    "A steampunk robot in a Victorian workshop",
    "An underwater coral reef with tropical fish",
    "A medieval castle on a hilltop at dawn"
]

print("Starting batch generation...")
batch_results = batch_generate(batch_prompts, provider="auto", save_images=False)

# Display results in a grid
successful_images = [r for r in batch_results if r['status'] == 'success']

if successful_images:
    n_images = len(successful_images)
    cols = min(3, n_images)
    rows = (n_images + cols - 1) // cols
    
    # squeeze=False always returns a 2D array of axes, whatever the grid shape
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
    
    for idx, result in enumerate(successful_images):
        row = idx // cols
        col = idx % cols
        ax = axes[row][col]
        
        ax.imshow(result['image'])
        ax.axis('off')
        ax.set_title(f"{result['prompt'][:40]}...\n{result['latency']:.2f}s", fontsize=10)
    
    # Hide empty subplots
    for idx in range(n_images, rows * cols):
        row = idx // cols
        col = idx % cols
        axes[row][col].axis('off')
    
    plt.tight_layout()
    plt.show()
else:
    print("No successful images to display")

## Summary and Key Takeaways

### What You Learned
1. ✅ Authenticate with the Hugging Face API securely
2. ✅ Generate images using text-to-image models
3. ✅ Compare provider performance and latency
4. ✅ Implement failover strategies for reliability
5. ✅ Batch process multiple image generations

### Best Practices
- Always use environment variables for tokens
- Implement timeout and retry logic
- Monitor latency and success rates
- Use failover for production applications
- Add delays between requests to avoid rate limiting

### Next Steps
- Proceed to **Chat Inference Practice** notebook
- Experiment with different models and providers
- Implement more advanced error handling
- Create a production-ready image generation service

## Cleanup

In [None]:
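# Drop references to the token (security best practice).
# The client object keeps its own copy of the token, so remove it as well.
for name in ("HF_TOKEN", "client"):
    globals().pop(name, None)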

print("✓ Cleanup completed")