# Minimal Text-to-Image Generator

A lightweight, educational prompt → image model that generates **28×28 grayscale images** from short text prompts.

## What This Notebook Demonstrates

This notebook builds a complete text-to-image generation pipeline **from scratch**:

1. **Synthetic Dataset**: Procedurally generated `(text_prompt, image)` pairs
2. **Text Encoder**: Character-level embeddings with mean pooling
3. **Image Generator**: MLP decoder that maps latent + noise → image
4. **Stochastic Generation**: Same prompt produces different outputs each time

## Hard Constraints (Intentional Limitations)

- ✅ **PyTorch** only
- ✅ **No pretrained models** - everything learned from scratch
- ✅ **No diffusion** - direct latent → image mapping
- ✅ **No transformers** - simple MLP architecture
- ✅ **No CLIP** - custom text encoder
- ✅ **No external datasets** - fully synthetic
- ✅ **CPU-friendly** - runs on any machine
- ✅ **28×28 grayscale** - compact and fast

## Architecture Overview

```
┌─────────────────────────────────────────────────────────────────┐
│                         TEXT ENCODER                            │
│  "circle" → [c,i,r,c,l,e] → Embeddings → Mean Pool → MLP → z   │
│                                                    (64 dims)    │
└─────────────────────────────────────────────────────┬───────────┘
                                                      │
                                                      ▼
┌─────────────────────────────────────────────────────────────────┐
│                       IMAGE GENERATOR                           │
│  [z (64) | noise (32)] → MLP → Sigmoid → 28×28 grayscale       │
└─────────────────────────────────────────────────────────────────┘
```
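
As a rough sanity check on model size, the parameter count can be worked out by hand. The sketch below assumes the layer widths used by the code later in this notebook (a 28-token vocabulary, 32-dim embeddings, hidden widths of 128, 256 and 512); the helper names are illustrative only.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def batchnorm_params(n_features):
    """Learnable scale and shift; running statistics are buffers, not parameters."""
    return 2 * n_features


vocab_size, embed_dim, latent_dim, noise_dim = 28, 32, 64, 32

text_encoder_params = (
    vocab_size * embed_dim               # embedding table
    + linear_params(embed_dim, 128)
    + linear_params(128, latent_dim)
)
generator_params = (
    linear_params(latent_dim + noise_dim, 256) + batchnorm_params(256)
    + linear_params(256, 512) + batchnorm_params(512)
    + linear_params(512, 28 * 28)
)
total_params_estimate = text_encoder_params + generator_params

# Prints: 13,376 + 560,144 = 573,520
print(f"{text_encoder_params:,} + {generator_params:,} = {total_params_estimate:,}")
```

Almost all of the parameters sit in the generator, and most of those in its final 512 → 784 layer.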

---
## 1. Setup & Imports

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import random

# Reproducibility: fixed seeds make a full rerun repeatable; within a run, each
# generation call still draws fresh noise, so repeated calls give different outputs
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Device selection - CPU by design for accessibility
DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")

---
## 2. Synthetic Dataset Generation

We generate `(text_prompt, image)` pairs procedurally using NumPy.

### Design Philosophy
- Each prompt maps to a **distribution** of images, not a single template
- Randomness in position, size, angle ensures variety
- Mostly binary images (0 or 1) for simplicity, stored as float32 for training; the "dense center" blob is the one prompt with continuous values
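
The first two points can be illustrated with a stripped-down generator. This is a sketch only: `toy_horizontal_line` is a hypothetical helper, simpler than the generators defined in this notebook.

```python
import numpy as np


def toy_horizontal_line(rng):
    """Hypothetical helper: one horizontal line at a random row of a 28x28 image."""
    img = np.zeros((28, 28), dtype=np.float32)
    y = rng.integers(5, 23)  # random row, kept away from the edges
    img[y, :] = 1.0
    return img


rng = np.random.default_rng(0)
samples = np.stack([toy_horizontal_line(rng) for _ in range(200)])

# Every sample is a valid "horizontal line", but the lit row differs between calls.
rows_used = sorted({int(s.sum(axis=1).argmax()) for s in samples})
print(f"{len(rows_used)} distinct row positions across {len(samples)} samples")

# Averaging the samples gives a soft band, not a line: the prompt describes
# a distribution of images rather than one template.
mean_img = samples.mean(axis=0)
print(f"Mean image value range: [{mean_img.min():.2f}, {mean_img.max():.2f}]")
```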

In [None]:
# ============================================================================
# IMAGE GENERATION FUNCTIONS
# Each function creates a 28x28 numpy array with values in [0, 1]
# Randomness is built-in so each call produces a unique image
# ============================================================================

def generate_horizontal_line():
    """Horizontal line at random y-position with slight thickness variation."""
    img = np.zeros((28, 28), dtype=np.float32)
    y = np.random.randint(5, 23)  # Keep away from edges
    thickness = np.random.randint(1, 3)  # 1-2 pixels thick
    x_start = np.random.randint(0, 5)
    x_end = np.random.randint(23, 28)
    for dy in range(thickness):
        if 0 <= y + dy < 28:
            img[y + dy, x_start:x_end] = 1.0
    return img


def generate_vertical_line():
    """Vertical line at random x-position."""
    img = np.zeros((28, 28), dtype=np.float32)
    x = np.random.randint(5, 23)
    thickness = np.random.randint(1, 3)
    y_start = np.random.randint(0, 5)
    y_end = np.random.randint(23, 28)
    for dx in range(thickness):
        if 0 <= x + dx < 28:
            img[y_start:y_end, x + dx] = 1.0
    return img


def generate_two_vertical_lines():
    """Two vertical lines with random spacing."""
    img = np.zeros((28, 28), dtype=np.float32)
    # First line on left half, second on right half
    x1 = np.random.randint(4, 10)
    x2 = np.random.randint(18, 24)
    y_start = np.random.randint(2, 6)
    y_end = np.random.randint(22, 26)
    img[y_start:y_end, x1] = 1.0
    img[y_start:y_end, x2] = 1.0
    return img


def generate_diagonal_line():
    """Diagonal line, either top-left to bottom-right or top-right to bottom-left."""
    img = np.zeros((28, 28), dtype=np.float32)
    direction = np.random.choice(['tlbr', 'trbl'])  # Top-left to bottom-right or reverse
    offset = np.random.randint(-3, 4)  # Shift the line slightly
    
    for i in range(28):
        if direction == 'tlbr':
            x, y = i, i + offset
        else:
            x, y = i, 27 - i + offset
        if 0 <= x < 28 and 0 <= y < 28:
            img[y, x] = 1.0
            # Add slight thickness
            if y + 1 < 28:
                img[y + 1, x] = 1.0
    return img


def generate_cross():
    """A + shape near the center, with random arm length and slight position jitter."""
    img = np.zeros((28, 28), dtype=np.float32)
    cx = 14 + np.random.randint(-3, 4)  # Center x with jitter
    cy = 14 + np.random.randint(-3, 4)  # Center y with jitter
    arm_length = np.random.randint(6, 10)
    
    # Horizontal arm
    for dx in range(-arm_length, arm_length + 1):
        if 0 <= cx + dx < 28:
            img[cy, cx + dx] = 1.0
    
    # Vertical arm
    for dy in range(-arm_length, arm_length + 1):
        if 0 <= cy + dy < 28:
            img[cy + dy, cx] = 1.0
    
    return img


def generate_circle():
    """Circle with random radius and center position."""
    img = np.zeros((28, 28), dtype=np.float32)
    cx = 14 + np.random.randint(-4, 5)
    cy = 14 + np.random.randint(-4, 5)
    radius = np.random.randint(5, 11)
    
    # Draw circle using parametric equation
    for theta in np.linspace(0, 2 * np.pi, 100):
        x = int(cx + radius * np.cos(theta))
        y = int(cy + radius * np.sin(theta))
        if 0 <= x < 28 and 0 <= y < 28:
            img[y, x] = 1.0
    
    return img


def generate_square():
    """Square outline with random size and position."""
    img = np.zeros((28, 28), dtype=np.float32)
    size = np.random.randint(8, 16)
    # Random top-left corner, ensuring square fits
    x1 = np.random.randint(2, 28 - size - 2)
    y1 = np.random.randint(2, 28 - size - 2)
    x2, y2 = x1 + size, y1 + size
    
    # Draw four sides
    img[y1, x1:x2] = 1.0  # Top
    img[y2, x1:x2+1] = 1.0  # Bottom
    img[y1:y2, x1] = 1.0  # Left
    img[y1:y2+1, x2] = 1.0  # Right
    
    return img


def generate_dense_center():
    """Gaussian-like blob in the center with random spread."""
    img = np.zeros((28, 28), dtype=np.float32)
    cx, cy = 14 + np.random.randint(-2, 3), 14 + np.random.randint(-2, 3)
    sigma = np.random.uniform(3, 6)
    
    for y in range(28):
        for x in range(28):
            dist_sq = (x - cx) ** 2 + (y - cy) ** 2
            img[y, x] = np.exp(-dist_sq / (2 * sigma ** 2))
    
    # Normalize to [0, 1]
    img = img / img.max()
    return img


def generate_sparse_dots():
    """5-15 randomly placed dots (positions may coincide, so fewer pixels can be lit)."""
    img = np.zeros((28, 28), dtype=np.float32)
    n_dots = np.random.randint(5, 16)
    
    for _ in range(n_dots):
        x = np.random.randint(0, 28)
        y = np.random.randint(0, 28)
        img[y, x] = 1.0
    
    return img


# ============================================================================
# PROMPT TO GENERATOR MAPPING
# This dictionary maps text prompts to their image generation functions
# ============================================================================

PROMPT_GENERATORS = {
    "horizontal line": generate_horizontal_line,
    "vertical line": generate_vertical_line,
    "two vertical lines": generate_two_vertical_lines,
    "diagonal line": generate_diagonal_line,
    "cross": generate_cross,
    "circle": generate_circle,
    "square": generate_square,
    "dense center": generate_dense_center,
    "sparse dots": generate_sparse_dots,
}

ALL_PROMPTS = list(PROMPT_GENERATORS.keys())
print(f"Available prompts ({len(ALL_PROMPTS)}): {ALL_PROMPTS}")

In [None]:
# ============================================================================
# VISUALIZE SAMPLE IMAGES FROM EACH PROMPT
# Shows the diversity of images generated for each prompt category
# ============================================================================

fig, axes = plt.subplots(len(ALL_PROMPTS), 5, figsize=(10, 2 * len(ALL_PROMPTS)))
fig.suptitle('Sample Images for Each Prompt (5 variations each)', fontsize=14, fontweight='bold')

for row, prompt in enumerate(ALL_PROMPTS):
    for col in range(5):
        img = PROMPT_GENERATORS[prompt]()
        axes[row, col].imshow(img, cmap='gray', vmin=0, vmax=1)
        axes[row, col].axis('off')
        if col == 0:
            axes[row, col].set_title(prompt, fontsize=10, loc='left')

plt.tight_layout()
plt.show()

In [None]:
# ============================================================================
# PYTORCH DATASET CLASS
# Wraps our generators in a PyTorch Dataset for easy batching
# ============================================================================

class SyntheticTextImageDataset(Dataset):
    """
    Generates (prompt, image) pairs on-the-fly.
    
    Each __getitem__ call creates a fresh random image for the selected prompt,
    ensuring the model sees diverse examples during training.
    """
    
    def __init__(self, num_samples_per_prompt=1000):
        """
        Args:
            num_samples_per_prompt: How many times each prompt appears in one epoch
        """
        self.prompts = ALL_PROMPTS
        self.num_samples_per_prompt = num_samples_per_prompt
        
        # Pre-compute the prompt for each index for deterministic iteration order
        self.index_to_prompt = []
        for prompt in self.prompts:
            self.index_to_prompt.extend([prompt] * num_samples_per_prompt)
    
    def __len__(self):
        return len(self.index_to_prompt)
    
    def __getitem__(self, idx):
        prompt = self.index_to_prompt[idx]
        
        # Generate a fresh random image for this prompt
        img = PROMPT_GENERATORS[prompt]()
        
        # Convert to tensor: shape (1, 28, 28) for channel dimension
        img_tensor = torch.from_numpy(img).unsqueeze(0)
        
        return prompt, img_tensor


# Test the dataset
test_dataset = SyntheticTextImageDataset(num_samples_per_prompt=100)
print(f"Dataset size: {len(test_dataset)} samples")
print(f"Sample: prompt='{test_dataset[0][0]}', image shape={test_dataset[0][1].shape}")

---
## 3. Text Encoder

We need to convert text prompts into fixed-size vectors that the image generator can understand.

### Design Choices
- **Character-level tokenization**: Simple and works well for short prompts
- **Learned embeddings**: Each character gets a learnable vector
- **Mean pooling**: Average all character embeddings to get a fixed-size vector
- **MLP projection**: Project pooled embedding to the final latent dimension
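
Masked mean pooling can be sketched in isolation as follows. The token ids are made up for illustration, and the embedding is randomly initialized rather than trained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding = nn.Embedding(num_embeddings=28, embedding_dim=4, padding_idx=0)


def masked_mean_pool(token_ids):
    """Average embeddings over non-PAD positions (PAD id assumed to be 0)."""
    embeds = embedding(token_ids)                    # (batch, seq_len, embed_dim)
    mask = (token_ids != 0).float().unsqueeze(-1)    # (batch, seq_len, 1)
    counts = mask.sum(dim=1).clamp(min=1)            # avoid dividing by zero
    return (embeds * mask).sum(dim=1) / counts       # (batch, embed_dim)


short = masked_mean_pool(torch.tensor([[5, 9, 3]]))
padded = masked_mean_pool(torch.tensor([[5, 9, 3, 0, 0, 0]]))
reordered = masked_mean_pool(torch.tensor([[3, 5, 9, 0, 0, 0]]))

print(torch.allclose(short, padded))      # padding does not change the result
print(torch.allclose(short, reordered))   # neither does character order
```

The second check shows a trade-off: mean pooling treats a prompt as a bag of characters, so anagrams map to the same latent vector.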

In [None]:
# ============================================================================
# CHARACTER TOKENIZER
# Maps characters to integer indices for embedding lookup
# ============================================================================

class CharacterTokenizer:
    """
    Simple character-level tokenizer.
    
    Vocabulary: a-z, space, and <PAD> token
    All text is lowercased and unknown characters are ignored.
    """
    
    def __init__(self):
        # Build vocabulary: PAD=0, space=1, a-z=2-27
        self.char_to_idx = {'<PAD>': 0, ' ': 1}
        for i, char in enumerate('abcdefghijklmnopqrstuvwxyz'):
            self.char_to_idx[char] = i + 2
        
        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
        self.vocab_size = len(self.char_to_idx)
        self.pad_idx = 0
    
    def encode(self, text, max_length=32):
        """
        Convert text to list of token indices.
        
        Args:
            text: Input string
            max_length: Maximum sequence length (pad/truncate to this)
        
        Returns:
            List of integer token indices
        """
        text = text.lower()
        tokens = []
        
        for char in text:
            if char in self.char_to_idx:
                tokens.append(self.char_to_idx[char])
            # Unknown characters are silently ignored
        
        # Truncate if too long
        tokens = tokens[:max_length]
        
        # Pad if too short
        while len(tokens) < max_length:
            tokens.append(self.pad_idx)
        
        return tokens
    
    def decode(self, tokens):
        """Convert token indices back to string."""
        chars = [self.idx_to_char.get(t, '') for t in tokens if t != self.pad_idx]
        return ''.join(chars)


# Test the tokenizer
tokenizer = CharacterTokenizer()
print(f"Vocabulary size: {tokenizer.vocab_size}")
print(f"\nExample encoding:")
test_text = "circle"
encoded = tokenizer.encode(test_text, max_length=16)
print(f"  '{test_text}' → {encoded[:10]}... (truncated)")
print(f"  Decoded: '{tokenizer.decode(encoded)}'")

In [None]:
# ============================================================================
# TEXT ENCODER NEURAL NETWORK
# Converts text prompts into fixed-size latent vectors
# ============================================================================

class TextEncoder(nn.Module):
    """
    Encodes text prompts into latent vectors.
    
    Architecture:
        1. Embedding layer: char indices → 32-dim vectors
        2. Mean pooling: average over sequence length
        3. MLP: project to final latent dimension
    """
    
    def __init__(self, vocab_size, embed_dim=32, latent_dim=64, max_length=32):
        super().__init__()
        
        self.max_length = max_length
        
        # Learnable character embeddings
        # Each character in the vocabulary gets an embed_dim-dimensional vector
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=0  # PAD tokens get zero vectors
        )
        
        # Project pooled embeddings to latent space
        # This MLP learns to extract meaningful features from the averaged embeddings
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim)
        )
    
    def forward(self, token_ids):
        """
        Args:
            token_ids: (batch_size, seq_length) integer tensor
        
        Returns:
            (batch_size, latent_dim) latent vectors
        """
        # Get embeddings: (batch, seq_len, embed_dim)
        embeds = self.embedding(token_ids)
        
        # Create mask for non-padding tokens
        # This ensures we only average over actual characters, not PAD tokens
        mask = (token_ids != 0).float().unsqueeze(-1)  # (batch, seq_len, 1)
        
        # Mean pooling: sum embeddings / count of non-pad tokens
        # The count is clamped to at least 1 to avoid division by zero on all-PAD input
        masked_embeds = embeds * mask
        summed = masked_embeds.sum(dim=1)  # (batch, embed_dim)
        counts = mask.sum(dim=1).clamp(min=1)  # (batch, 1)
        pooled = summed / counts  # (batch, embed_dim)
        
        # Project to latent space
        latent = self.projection(pooled)  # (batch, latent_dim)
        
        return latent


# Test the text encoder
text_encoder = TextEncoder(vocab_size=tokenizer.vocab_size)
print(f"TextEncoder architecture:")
print(text_encoder)

# Test forward pass
test_tokens = torch.tensor([tokenizer.encode("circle")])
test_latent = text_encoder(test_tokens)
print(f"\nInput shape: {test_tokens.shape}")
print(f"Output latent shape: {test_latent.shape}")

---
## 4. Image Generator

The generator takes a text latent vector + random noise and produces a 28×28 image.

### Why Add Noise?
Without noise, the model would learn a deterministic mapping: same prompt → same image.
By concatenating random noise with the latent vector, we enable **stochastic generation**:
the same prompt can produce different valid images depending on the noise sample.
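
The concatenation step can be sketched on its own. The decoder here is a hypothetical, untrained single linear layer standing in for the real generator, so the outputs are meaningless; the point is only that the same latent with different noise yields different outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, noise_dim = 64, 32

# Stand-in decoder (untrained, for illustration only).
toy_decoder = nn.Sequential(nn.Linear(latent_dim + noise_dim, 28 * 28), nn.Sigmoid())

latent = torch.randn(1, latent_dim).repeat(4, 1)   # one prompt latent, repeated 4 times
noise = torch.randn(4, noise_dim)                  # a different noise vector per copy

with torch.no_grad():
    combined = torch.cat([latent, noise], dim=1)   # (4, 96)
    images = toy_decoder(combined).view(4, 1, 28, 28)

print(combined.shape, images.shape)
print("Outputs differ:", not torch.allclose(images[0], images[1]))
```

Whether a trained generator keeps using the noise this way depends on the training objective.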

In [None]:
# ============================================================================
# IMAGE GENERATOR NEURAL NETWORK
# Maps latent vectors + noise to 28x28 grayscale images
# ============================================================================

class ImageGenerator(nn.Module):
    """
    MLP-based image generator.
    
    Architecture:
        Input: latent (64) + noise (32) = 96 dims
        → Linear(96, 256) + BatchNorm + ReLU
        → Linear(256, 512) + BatchNorm + ReLU
        → Linear(512, 784) + Sigmoid
        → Reshape to (1, 28, 28)
    
    The sigmoid ensures output values are in [0, 1] for valid grayscale pixels.
    """
    
    def __init__(self, latent_dim=64, noise_dim=32):
        super().__init__()
        
        self.noise_dim = noise_dim
        input_dim = latent_dim + noise_dim
        
        # MLP decoder
        # BatchNorm helps with training stability
        self.decoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            
            nn.Linear(512, 784),  # 28 * 28 = 784
            nn.Sigmoid()  # Output in [0, 1]
        )
    
    def forward(self, latent, noise=None):
        """
        Args:
            latent: (batch_size, latent_dim) text encoding
            noise: (batch_size, noise_dim) random noise, or None to sample fresh
        
        Returns:
            (batch_size, 1, 28, 28) generated images
        """
        batch_size = latent.shape[0]
        device = latent.device
        
        # Sample noise if not provided
        if noise is None:
            noise = torch.randn(batch_size, self.noise_dim, device=device)
        
        # Concatenate latent and noise
        combined = torch.cat([latent, noise], dim=1)
        
        # Generate flattened image
        flat_img = self.decoder(combined)  # (batch, 784)
        
        # Reshape to image
        img = flat_img.view(batch_size, 1, 28, 28)
        
        return img


# Test the generator
generator = ImageGenerator()
print(f"ImageGenerator architecture:")
print(generator)

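# Test forward pass
# BatchNorm1d cannot compute batch statistics from a single sample in training
# mode, so switch to eval mode for this one-sample check
generator.eval()
with torch.no_grad():
    test_output = generator(test_latent)
print(f"\nInput latent shape: {test_latent.shape}")
print(f"Output image shape: {test_output.shape}")
print(f"Output value range: [{test_output.min():.3f}, {test_output.max():.3f}]")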

In [None]:
# ============================================================================
# COMPLETE TEXT-TO-IMAGE MODEL
# Combines text encoder and image generator into one module
# ============================================================================

class TextToImageModel(nn.Module):
    """
    End-to-end text-to-image generation model.
    
    Pipeline:
        text prompt → tokenize → encode → concatenate with noise → generate image
    """
    
    def __init__(self, vocab_size, embed_dim=32, latent_dim=64, noise_dim=32, max_length=32):
        super().__init__()
        
        self.max_length = max_length
        self.noise_dim = noise_dim
        
        self.text_encoder = TextEncoder(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            latent_dim=latent_dim,
            max_length=max_length
        )
        
        self.image_generator = ImageGenerator(
            latent_dim=latent_dim,
            noise_dim=noise_dim
        )
    
    def forward(self, token_ids, noise=None):
        """
        Args:
            token_ids: (batch_size, seq_length) tokenized prompts
            noise: optional noise tensor
        
        Returns:
            (batch_size, 1, 28, 28) generated images
        """
        latent = self.text_encoder(token_ids)
        image = self.image_generator(latent, noise)
        return image
    
    def generate(self, prompt, tokenizer, num_samples=1, device='cpu'):
        """
        Convenience method for inference.
        
        Args:
            prompt: string prompt
            tokenizer: CharacterTokenizer instance
            num_samples: number of images to generate
            device: torch device
        
        Returns:
            (num_samples, 1, 28, 28) generated images
        """
        self.eval()
        with torch.no_grad():
            # Tokenize and repeat for batch
            tokens = tokenizer.encode(prompt, max_length=self.max_length)
            token_ids = torch.tensor([tokens] * num_samples, device=device)
            
            # Generate with fresh noise for each sample
            images = self.forward(token_ids)
        
        return images


# Create the full model
model = TextToImageModel(vocab_size=tokenizer.vocab_size)
model = model.to(DEVICE)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

---
## 5. Training Loop

We train the model end-to-end to minimize the difference between generated and target images.

### Training Strategy
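- **Loss**: MSE (Mean Squared Error) - a simple per-pixel regression loss; it pulls outputs toward the average image for each prompt, so expect soft, blurry shapes
- **Optimizer**: Adam with learning rate 1e-3 (the PyTorch default)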
- **Epochs**: 100 (adjustable based on convergence)
- **Batch size**: 64 (balance between speed and gradient quality)
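
One consequence of the MSE choice is worth seeing in numbers: for a fixed prompt, the single prediction that minimizes the expected squared error is the pixel-wise mean of the target images. A minimal NumPy sketch, assuming a hypothetical toy target distribution (1-D "images" with one lit pixel at a random position):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy targets: 1-D "images" of 8 pixels with exactly one lit pixel each.
targets = np.zeros((1000, 8), dtype=np.float32)
targets[np.arange(1000), rng.integers(0, 8, size=1000)] = 1.0


def mse(prediction):
    """Average squared error of one fixed prediction against all targets."""
    return float(((targets - prediction) ** 2).mean())


mean_prediction = targets.mean(axis=0)   # blurry: roughly 1/8 everywhere
sharp_prediction = targets[0].copy()     # one valid, sharp sample

print(f"MSE of mean prediction:  {mse(mean_prediction):.4f}")
print(f"MSE of sharp prediction: {mse(sharp_prediction):.4f}")
```

Since the noise vector is sampled independently of the target, it carries no information about which target was drawn, so this pull toward the mean also limits how much a trained generator's outputs vary with the noise.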

In [None]:
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================

# Hyperparameters
BATCH_SIZE = 64
NUM_EPOCHS = 100
LEARNING_RATE = 1e-3
SAMPLES_PER_PROMPT = 500  # How many times each prompt appears per epoch

# Create dataset and dataloader
train_dataset = SyntheticTextImageDataset(num_samples_per_prompt=SAMPLES_PER_PROMPT)

# Custom collate function to batch prompts properly
def collate_fn(batch):
    prompts, images = zip(*batch)
    
    # Tokenize all prompts
    token_ids = torch.tensor([tokenizer.encode(p) for p in prompts])
    
    # Stack images
    images = torch.stack(images)
    
    return token_ids, images

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0  # Keep 0 for compatibility
)

print(f"Training samples per epoch: {len(train_dataset)}")
print(f"Batches per epoch: {len(train_loader)}")

In [None]:
# ============================================================================
# TRAINING LOOP
# ============================================================================

# Initialize optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

# Training history for plotting
losses = []

print("Starting training...\n")
print(f"{'Epoch':>6} | {'Loss':>10} | {'Progress':>10}")
print("-" * 35)

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0.0
    
    for batch_idx, (token_ids, target_images) in enumerate(train_loader):
        # Move to device
        token_ids = token_ids.to(DEVICE)
        target_images = target_images.to(DEVICE)
        
        # Forward pass
        # The model samples fresh noise internally
        generated_images = model(token_ids)
        
        # Compute loss
        loss = criterion(generated_images, target_images)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    # Record average epoch loss
    avg_loss = epoch_loss / len(train_loader)
    losses.append(avg_loss)
    
    # Print progress every 10 epochs
    if (epoch + 1) % 10 == 0 or epoch == 0:
        progress = (epoch + 1) / NUM_EPOCHS * 100
        print(f"{epoch + 1:>6} | {avg_loss:>10.6f} | {progress:>9.1f}%")

print("\nTraining complete!")

In [None]:
# ============================================================================
# PLOT TRAINING LOSS
# ============================================================================

plt.figure(figsize=(10, 4))
plt.plot(losses, 'b-', linewidth=2)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('MSE Loss', fontsize=12)
plt.title('Training Loss Over Time', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Final loss: {losses[-1]:.6f}")
print(f"Best loss: {min(losses):.6f} (epoch {losses.index(min(losses)) + 1})")

---
## 6. Sampling & Visualization

Now let's see what our model has learned! We'll:
1. Generate single images from prompts
2. Show multiple generations from the same prompt (demonstrating stochasticity)
3. Create a grid comparing all prompts

In [None]:
# ============================================================================
# HELPER FUNCTION FOR VISUALIZATION
# ============================================================================

def visualize_generations(model, tokenizer, prompt, num_samples=8, figsize=(12, 2)):
    """
    Generate and display multiple images from the same prompt.
    
    This demonstrates the stochastic nature of our generator:
    same prompt + different noise = different but valid images.
    """
    images = model.generate(prompt, tokenizer, num_samples=num_samples, device=DEVICE)
    
    # squeeze=False keeps axes 2-D, so num_samples=1 also works
    fig, axes = plt.subplots(1, num_samples, figsize=figsize, squeeze=False)
    fig.suptitle(f'Prompt: "{prompt}"', fontsize=14, fontweight='bold')
    
    for i, ax in enumerate(axes[0]):
        img = images[i, 0].cpu().numpy()
        ax.imshow(img, cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
        ax.set_title(f'#{i+1}', fontsize=10)
    
    plt.tight_layout()
    plt.show()

In [None]:
# ============================================================================
# DEMO 1: SINGLE PROMPT → MULTIPLE GENERATIONS
# Shows that the same prompt produces different outputs each time
# ============================================================================

print("Demonstrating stochastic generation:")
print("Same prompt → Different outputs (due to random noise)\n")

for prompt in ["circle", "horizontal line", "cross", "square"]:
    visualize_generations(model, tokenizer, prompt, num_samples=8)

In [None]:
# ============================================================================
# DEMO 2: GRID OF ALL PROMPTS
# Compare generations across all prompt categories
# ============================================================================

num_samples_per_prompt = 4

fig, axes = plt.subplots(
    len(ALL_PROMPTS), 
    num_samples_per_prompt, 
    figsize=(num_samples_per_prompt * 2, len(ALL_PROMPTS) * 2)
)
fig.suptitle('Generated Images for All Prompts', fontsize=16, fontweight='bold', y=1.02)

for row, prompt in enumerate(ALL_PROMPTS):
    images = model.generate(prompt, tokenizer, num_samples=num_samples_per_prompt, device=DEVICE)
    
    for col in range(num_samples_per_prompt):
        img = images[col, 0].cpu().numpy()
        axes[row, col].imshow(img, cmap='gray', vmin=0, vmax=1)
        axes[row, col].axis('off')
        
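        if col == 0:
            # axis('off') hides axis labels, so draw the prompt as plain text instead
            axes[row, col].text(-0.1, 0.5, prompt, transform=axes[row, col].transAxes,
                                fontsize=10, ha='right', va='center')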

# Adjust layout to show labels
plt.subplots_adjust(left=0.15, wspace=0.05, hspace=0.1)
plt.show()

In [None]:
# ============================================================================
# DEMO 3: SIDE-BY-SIDE COMPARISON WITH GROUND TRUTH
# Compare what the model generates vs. the target distribution
# ============================================================================

fig, axes = plt.subplots(len(ALL_PROMPTS), 6, figsize=(12, len(ALL_PROMPTS) * 1.5))
fig.suptitle('Ground Truth (left 3) vs Generated (right 3)', fontsize=14, fontweight='bold', y=1.02)

for row, prompt in enumerate(ALL_PROMPTS):
    # Ground truth examples
    for col in range(3):
        gt_img = PROMPT_GENERATORS[prompt]()
        axes[row, col].imshow(gt_img, cmap='gray', vmin=0, vmax=1)
        axes[row, col].axis('off')
        if col == 0:
            axes[row, col].set_title('GT', fontsize=8)
    
    # Generated examples
    generated = model.generate(prompt, tokenizer, num_samples=3, device=DEVICE)
    for col in range(3):
        img = generated[col, 0].cpu().numpy()
        axes[row, col + 3].imshow(img, cmap='gray', vmin=0, vmax=1)
        axes[row, col + 3].axis('off')
        if col == 0:
            axes[row, col + 3].set_title('Gen', fontsize=8)
    
    # Add prompt label on the left (axis('off') hides ylabels, so use plain text)
    axes[row, 0].text(-0.1, 0.5, prompt, transform=axes[row, 0].transAxes,
                      fontsize=9, ha='right', va='center')

plt.subplots_adjust(left=0.18, wspace=0.05, hspace=0.15)
plt.show()

---
## 7. Limitations & Extensions

### What This Model CAN Do ✅

- Generate recognizable shapes from text prompts
- Produce outputs that vary with the noise sample (stochastic generation), although MSE training limits how much
- Run quickly on CPU
- Serve as an educational example of text-to-image architecture

### What This Model CANNOT Do ❌

- Generalize to unseen prompts (it only knows the 9 trained prompts)
- Generate complex or realistic images
- Understand compositional prompts like "red circle on the left"
- Produce high-resolution outputs

### Why These Limitations Exist

1. **Small prompt set**: Only 9 prompts means limited semantic understanding
2. **Character-level encoding with mean pooling**: Discards character order, so it can't capture word meanings
3. **MLP decoder**: Limited spatial awareness (no convolutions)
4. **28×28 resolution**: Inherently low detail
5. **No attention mechanism**: Can't focus on relevant parts of the prompt

### Possible Extensions (For Further Learning)

| Extension | Difficulty | Description |
|-----------|------------|-------------|
| CNN Decoder | Easy | Replace MLP with transposed convolutions for better spatial structure |
| More Prompts | Easy | Add "triangle", "rectangle", "three dots", etc. |
| Word Embeddings | Medium | Use word-level instead of character-level tokenization |
| Larger Images | Medium | Scale to 64×64 or 128×128 |
| Conditioning Injection | Medium | Inject latent at multiple decoder layers |
| Attention Mechanism | Hard | Add cross-attention between text and image features |
| Diffusion | Hard | Replace direct generation with iterative denoising |
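
As an illustration of the first extension, a transposed-convolution decoder might look like the sketch below. `ConvImageGenerator` and its channel widths are assumptions made for this example, not part of the notebook; it keeps the same interface as the MLP generator (latent plus optional noise in, `(batch, 1, 28, 28)` out).

```python
import torch
import torch.nn as nn


class ConvImageGenerator(nn.Module):
    """Hypothetical alternative to the MLP decoder (a sketch, not tuned for this task)."""

    def __init__(self, latent_dim=64, noise_dim=32):
        super().__init__()
        self.noise_dim = noise_dim
        # Project to a small 7x7 feature map, then upsample twice: 7 -> 14 -> 28.
        self.fc = nn.Linear(latent_dim + noise_dim, 64 * 7 * 7)
        self.upsample = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),   # 28x28
            nn.Sigmoid(),  # pixel values in [0, 1]
        )

    def forward(self, latent, noise=None):
        # Sample fresh noise when none is supplied, as the MLP generator does.
        if noise is None:
            noise = torch.randn(latent.shape[0], self.noise_dim, device=latent.device)
        x = self.fc(torch.cat([latent, noise], dim=1))
        return self.upsample(x.view(-1, 64, 7, 7))


conv_generator = ConvImageGenerator()
with torch.no_grad():
    sample = conv_generator(torch.randn(2, 64))
print(sample.shape)
```

With kernel size 4, stride 2 and padding 1, each transposed convolution exactly doubles the spatial size, which is why a 7×7 starting map reaches 28×28 in two steps.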

In [None]:
# ============================================================================
# INTERACTIVE DEMO: TRY YOUR OWN PROMPTS
# Note: Only trained prompts will work well!
# ============================================================================

def interactive_generate(prompt, num_samples=4):
    """
    Generate images from a custom prompt.
    
    Warning: The model only understands the 9 trained prompts.
    Unseen prompts will produce unpredictable results!
    """
    print(f"Generating {num_samples} images for: '{prompt}'")
    
    if prompt not in ALL_PROMPTS:
        print(f"⚠️  Warning: '{prompt}' was not in training data!")
        print(f"   Trained prompts: {ALL_PROMPTS}")
    
    visualize_generations(model, tokenizer, prompt, num_samples=num_samples)


# Try some prompts
interactive_generate("circle")
interactive_generate("sparse dots")

# This won't work well - unseen prompt!
interactive_generate("triangle")

In [None]:
# ============================================================================
# SAVE THE MODEL (OPTIONAL)
# ============================================================================

# Uncomment to save:
# torch.save({
#     'model_state_dict': model.state_dict(),
#     'tokenizer_vocab': tokenizer.char_to_idx,
#     'prompts': ALL_PROMPTS,
# }, 'text_to_image_model.pt')
# print("Model saved to 'text_to_image_model.pt'")

print("\n" + "="*60)
print("NOTEBOOK COMPLETE!")
print("="*60)
print(f"\nSummary:")
print(f"  - Trained on {len(ALL_PROMPTS)} prompt categories")
print(f"  - Model size: {total_params:,} parameters")
print(f"  - Final training loss: {losses[-1]:.6f}")
print(f"\nThe model can generate 28×28 grayscale images from text prompts!")