# Text-to-Image Generator (Expanded Vocabulary)

A text-to-image model with **150+ natural language prompts** generating **28×28 grayscale images**.

## Expanded Vocabulary

This version supports natural language variations:
- **Sizes**: "small circle", "tiny dot", "big square", "large ring"
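- **Positions**: "circle at top", "vertical line on the left", "dot at bottom right"
- **Quantities**: "two circle", "three dots", "pair of dot" (shape names are not pluralized by the prompt builder; "dots" is its own scattered-dots shape)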
- **Synonyms**: "ring" = "circle", "box" = "square", "point" = "dot"

## Constraints
- ✅ PyTorch only, no pretrained models, no diffusion, no transformers, no CLIP
- ✅ Fully synthetic data, runs on CPU/GPU

---
## 1. Setup

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import random
from itertools import product

# Seed all three RNGs used in this notebook so each run starts from the same
# state: torch (model init / generator noise), numpy (the image generators)
# and random (prompt sampling).
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Auto-detect GPU
# DEVICE is used further down for the model and for every training batch.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

---
## 2. Parameterized Image Generators

In [None]:
# ============================================================================
# PARAMETERIZED IMAGE GENERATORS
# Each function accepts size, position, and count parameters
# ============================================================================

def get_size_params(size):
    """Map a size word to a (low, high) integer range.

    The generators feed the pair to ``np.random.randint``, so ``high`` is
    exclusive there. Any word that is not listed yields the fallback (5, 10).
    """
    fallback = (5, 10)
    ranges_by_word = dict(
        tiny=(2, 4),
        small=(3, 6),
        medium=(6, 10),
        large=(9, 13),
        big=(10, 14),
    )
    if size in ranges_by_word:
        return ranges_by_word[size]
    return fallback

def get_position_offset(position):
    """Return the (dx, dy) pixel shift from the canvas centre for a position name.

    Negative dx moves left, negative dy moves up (row 0 is the top of the
    image). Names that are not listed fall back to (0, 0), i.e. the centre.
    """
    offsets_by_name = {
        'center': (0, 0),
        'top': (0, -6),
        'bottom': (0, 6),
        'left': (-6, 0),
        'right': (6, 0),
        'top left': (-5, -5),
        'top right': (5, -5),
        'bottom left': (-5, 5),
        'bottom right': (5, 5),
        # Fixed shift towards the bottom-right; no randomisation happens here.
        'corner': (6, 6),
    }
    return offsets_by_name.get(position, (0, 0))


def generate_circle(size='medium', position='center', count=1):
    """Draw ``count`` circle outlines on a blank 28x28 float32 canvas.

    Each circle gets a random radius from the range for ``size`` and a centre
    at the ``position`` offset plus up to 2 pixels of jitter per axis. Multiple
    circles are spread horizontally in steps of 8 pixels. Points that fall
    outside the canvas are dropped.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    r_lo, r_hi = get_size_params(size)
    shift_x, shift_y = get_position_offset(position)
    angles = np.linspace(0, 2 * np.pi, 100)

    for k in range(count):
        spread = (k - count // 2) * 8 if count > 1 else 0
        centre_x = 14 + shift_x + spread + np.random.randint(-2, 3)
        centre_y = 14 + shift_y + np.random.randint(-2, 3)
        r = np.random.randint(r_lo, r_hi)

        # Walk the perimeter in 100 angular steps and light the nearest pixel.
        for angle in angles:
            px = int(centre_x + r * np.cos(angle))
            py = int(centre_y + r * np.sin(angle))
            if px < 0 or px >= 28 or py < 0 or py >= 28:
                continue
            canvas[py, px] = 1.0
    return canvas


def generate_filled_circle(size='medium', position='center', count=1):
    """Draw ``count`` solid disks on a blank 28x28 float32 canvas.

    Radius and centre are chosen as for the circle outline (size range,
    position offset, +/-2 pixel jitter, 8-pixel horizontal spread); every
    pixel whose squared distance to the centre is at most radius**2 is set
    to 1.0.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    r_lo, r_hi = get_size_params(size)
    shift_x, shift_y = get_position_offset(position)
    # Integer row/column coordinate grids, broadcast against each other below.
    rows, cols = np.ogrid[:28, :28]

    for k in range(count):
        spread = (k - count // 2) * 8 if count > 1 else 0
        centre_x = 14 + shift_x + spread + np.random.randint(-2, 3)
        centre_y = 14 + shift_y + np.random.randint(-2, 3)
        r = np.random.randint(r_lo, r_hi)

        inside = (cols - centre_x) ** 2 + (rows - centre_y) ** 2 <= r ** 2
        canvas[inside] = 1.0
    return canvas


def generate_square(size='medium', position='center', count=1):
    """Draw ``count`` square outlines on a blank 28x28 float32 canvas.

    Args:
        size: Size word; selects the random half-side range via
            ``get_size_params``.
        position: Position name; selects the centre offset via
            ``get_position_offset``.
        count: Number of squares, spread horizontally in steps of 10 pixels.

    Returns:
        A (28, 28) float32 array with outline pixels set to 1.0.

    A square that sticks out of the canvas has its corners clamped to the
    canvas edge. A square that lies completely outside the canvas (possible
    for count > 1 combined with a left/right position) is skipped.
    """
    img = np.zeros((28, 28), dtype=np.float32)
    min_s, max_s = get_size_params(size)
    dx, dy = get_position_offset(position)

    for i in range(count):
        offset_x = (i - count//2) * 10 if count > 1 else 0
        s = np.random.randint(min_s, max_s)
        cx = 14 + dx + offset_x + np.random.randint(-2, 3)
        cy = 14 + dy + np.random.randint(-2, 3)
        x1, y1 = max(0, cx - s), max(0, cy - s)
        x2, y2 = min(27, cx + s), min(27, cy + s)

        # If the clamped corners cross, nothing of the square is visible.
        # Drawing anyway would index past column 27 (IndexError) or use a
        # negative index, which numpy wraps around to the opposite edge.
        if x1 > x2 or y1 > y2:
            continue

        img[y1, x1:x2] = 1.0
        img[y2, x1:x2+1] = 1.0
        img[y1:y2, x1] = 1.0
        img[y1:y2+1, x2] = 1.0
    return img


def generate_filled_square(size='medium', position='center', count=1):
    """Draw ``count`` solid squares on a blank 28x28 float32 canvas.

    Args:
        size: Size word; selects the random half-side range via
            ``get_size_params``.
        position: Position name; selects the centre offset via
            ``get_position_offset``.
        count: Number of squares, spread horizontally in steps of 10 pixels.

    Returns:
        A (28, 28) float32 array with the squares' pixels set to 1.0.

    Squares are clipped to the canvas. A square that lies completely outside
    the canvas (possible for count > 1 combined with a left/right position)
    is skipped.
    """
    img = np.zeros((28, 28), dtype=np.float32)
    min_s, max_s = get_size_params(size)
    dx, dy = get_position_offset(position)

    for i in range(count):
        offset_x = (i - count//2) * 10 if count > 1 else 0
        s = np.random.randint(min_s, max_s)
        cx = 14 + dx + offset_x + np.random.randint(-2, 3)
        cy = 14 + dy + np.random.randint(-2, 3)
        x1, y1 = max(0, cx - s), max(0, cy - s)
        x2, y2 = min(27, cx + s), min(27, cy + s)

        # A negative x2/y2 would turn the slice end into a wrap-around index
        # (e.g. 0:-1) and flood most of the canvas, so skip invisible squares.
        if x1 > x2 or y1 > y2:
            continue

        img[y1:y2+1, x1:x2+1] = 1.0
    return img


def generate_horizontal_line(size='medium', position='center', count=1):
    """Draw ``count`` horizontal bars on a blank 28x28 float32 canvas.

    Only the vertical part of the position offset is used. Bars are spaced
    6 rows apart, jittered by up to 2 rows and kept within rows 1..26. They
    are 1 pixel thick for 'tiny'/'small' and 2 pixels thick otherwise, with
    randomly chosen start and end columns.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    _, shift_y = get_position_offset(position)
    bar_height = 2
    if size in ('tiny', 'small'):
        bar_height = 1

    for k in range(count):
        row = 14 + shift_y + (k - count // 2) * 6 + np.random.randint(-2, 3)
        row = min(26, max(1, row))
        col_start = np.random.randint(2, 6)
        col_stop = np.random.randint(22, 27)
        # row is at least 1 here, so only the lower edge needs clipping.
        canvas[row:min(row + bar_height, 28), col_start:col_stop] = 1.0
    return canvas


def generate_vertical_line(size='medium', position='center', count=1):
    """Draw ``count`` vertical bars on a blank 28x28 float32 canvas.

    Only the horizontal part of the position offset is used. Bars are spaced
    6 columns apart, jittered by up to 2 columns and kept within columns
    1..26. They are 1 pixel thick for 'tiny'/'small' and 2 pixels thick
    otherwise, with randomly chosen start and end rows.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    shift_x, _ = get_position_offset(position)
    bar_width = 2
    if size in ('tiny', 'small'):
        bar_width = 1

    for k in range(count):
        col = 14 + shift_x + (k - count // 2) * 6 + np.random.randint(-2, 3)
        col = min(26, max(1, col))
        row_start = np.random.randint(2, 6)
        row_stop = np.random.randint(22, 27)
        # col is at least 1 here, so only the right edge needs clipping.
        canvas[row_start:row_stop, col:min(col + bar_width, 28)] = 1.0
    return canvas


def generate_diagonal_line(size='medium', position='center', count=1):
    """Draw ``count`` two-pixel-thick diagonal lines on a 28x28 float32 canvas.

    Each line picks a random direction (descending or ascending), is shifted
    vertically by the position's dy plus a per-line offset of 5 rows and up
    to 2 rows of jitter, and is clipped to the canvas. ``size`` and the
    horizontal part of the position are not used.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    _, shift_y = get_position_offset(position)

    for k in range(count):
        slope = np.random.choice([1, -1])
        jitter = (k - count // 2) * 5 + np.random.randint(-2, 3)
        # Row of the line at column 0; 14 * (1 - slope) is 0 or 28.
        intercept = 14 * (1 - slope) + jitter + shift_y
        for col in range(28):
            row = col * slope + intercept
            if row < 0 or row >= 28:
                continue
            top = int(row)
            canvas[top, col] = 1.0
            if top + 1 < 28:
                canvas[top + 1, col] = 1.0
    return canvas


def generate_cross(size='medium', position='center', count=1):
    """Draw ``count`` plus-shaped crosses on a blank 28x28 float32 canvas.

    Args:
        size: Size word; selects the random arm-length range via
            ``get_size_params``.
        position: Position name; selects the centre offset via
            ``get_position_offset``.
        count: Number of crosses, spread horizontally in steps of 10 pixels.

    Returns:
        A (28, 28) float32 array with the cross pixels set to 1.0.

    Both coordinates of every pixel are bounds-checked, so a cross whose
    centre lands off-canvas (count > 1 with a left/right position) is clipped
    instead of raising IndexError or wrapping to the opposite edge.
    """
    img = np.zeros((28, 28), dtype=np.float32)
    min_a, max_a = get_size_params(size)
    dx, dy = get_position_offset(position)

    for i in range(count):
        offset_x = (i - count//2) * 10 if count > 1 else 0
        cx = 14 + dx + offset_x + np.random.randint(-2, 3)
        cy = 14 + dy + np.random.randint(-2, 3)
        arm = np.random.randint(min_a, max_a)

        # The horizontal arm lives on row cy and the vertical arm on column
        # cx; each arm is only drawable when its fixed coordinate is on-canvas.
        row_visible = 0 <= cy < 28
        col_visible = 0 <= cx < 28

        for d in range(-arm, arm + 1):
            if row_visible and 0 <= cx + d < 28:
                img[cy, cx + d] = 1.0
            if col_visible and 0 <= cy + d < 28:
                img[cy + d, cx] = 1.0
    return img


def generate_dot(size='medium', position='center', count=1):
    """Draw ``count`` small square dots on a blank 28x28 float32 canvas.

    Every dot is placed independently: the position offset plus up to 8
    pixels of jitter per axis. Dots are 1x1 for 'tiny'/'small' and 2x2
    otherwise; pixels outside the canvas are dropped.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    shift_x, shift_y = get_position_offset(position)
    side = 2
    if size in ('tiny', 'small'):
        side = 1

    for _ in range(count):
        left = 14 + shift_x + np.random.randint(-8, 9)
        top = 14 + shift_y + np.random.randint(-8, 9)
        for col in range(left, left + side):
            if col < 0 or col >= 28:
                continue
            for row in range(top, top + side):
                if 0 <= row < 28:
                    canvas[row, col] = 1.0
    return canvas


def generate_triangle(size='medium', position='center', count=1):
    """Draw ``count`` triangle outlines on a blank 28x28 float32 canvas.

    Each triangle has its apex ``s`` pixels above the centre and its base
    corners ``s`` pixels below and to either side, where ``s`` is drawn from
    the range for ``size``. Edges are rasterised with 50 interpolation steps
    and clipped to the canvas.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    s_lo, s_hi = get_size_params(size)
    shift_x, shift_y = get_position_offset(position)
    steps = np.linspace(0, 1, 50)

    for k in range(count):
        spread = (k - count // 2) * 10 if count > 1 else 0
        half = np.random.randint(s_lo, s_hi)
        centre_x = 14 + shift_x + spread + np.random.randint(-2, 3)
        centre_y = 14 + shift_y + np.random.randint(-2, 3)

        corners = [
            (centre_x, centre_y - half),         # apex
            (centre_x - half, centre_y + half),  # base, left
            (centre_x + half, centre_y + half),  # base, right
        ]

        # Pair every corner with the next one, wrapping back to the apex.
        for (x0, y0), (x1, y1) in zip(corners, corners[1:] + corners[:1]):
            for t in steps:
                px = int(x0 * (1-t) + x1 * t)
                py = int(y0 * (1-t) + y1 * t)
                if px < 0 or px >= 28 or py < 0 or py >= 28:
                    continue
                canvas[py, px] = 1.0
    return canvas


def generate_x_shape(size='medium', position='center', count=1):
    """Draw ``count`` X shapes (two crossing diagonals) on a 28x28 canvas.

    The half-length of each diagonal comes from the range for ``size``; the
    centre is the position offset plus up to 2 pixels of jitter, with
    multiple shapes spread horizontally in steps of 10 pixels. Pixels outside
    the canvas are dropped.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    s_lo, s_hi = get_size_params(size)
    shift_x, shift_y = get_position_offset(position)

    for k in range(count):
        spread = (k - count // 2) * 10 if count > 1 else 0
        half = np.random.randint(s_lo, s_hi)
        centre_x = 14 + shift_x + spread + np.random.randint(-2, 3)
        centre_y = 14 + shift_y + np.random.randint(-2, 3)

        for step in range(-half, half + 1):
            col = centre_x + step
            if col < 0 or col >= 28:
                continue
            # Same column, one pixel on each of the two diagonals.
            for row in (centre_y + step, centre_y - step):
                if 0 <= row < 28:
                    canvas[row, col] = 1.0
    return canvas


def generate_blob(size='medium', position='center', count=1):
    """Draw ``count`` Gaussian blobs on a blank 28x28 float32 canvas.

    One sigma is drawn per call (0.7x the size range) and shared by all blobs.
    Each pixel keeps the largest Gaussian value seen across the blobs, so
    overlapping blobs merge rather than add up.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    s_lo, s_hi = get_size_params(size)
    shift_x, shift_y = get_position_offset(position)
    sigma = np.random.uniform(s_lo * 0.7, s_hi * 0.7)

    for k in range(count):
        spread = (k - count // 2) * 10 if count > 1 else 0
        centre_x = 14 + shift_x + spread + np.random.randint(-2, 3)
        centre_y = 14 + shift_y + np.random.randint(-2, 3)

        for row in range(28):
            for col in range(28):
                dist_sq = (col - centre_x)**2 + (row - centre_y)**2
                glow = np.exp(-dist_sq / (2 * sigma**2))
                # Only overwrite when this blob is brighter at the pixel.
                if glow > canvas[row, col]:
                    canvas[row, col] = glow
    return canvas


def generate_random_dots(size='medium', position='center', count=1):
    """Scatter single-pixel dots on a blank 28x28 float32 canvas.

    The number of attempts is ``count`` times a random integer in 3..7. Each
    dot lands within 10 pixels of the position-shifted centre on both axes;
    attempts that fall outside the canvas are discarded. ``size`` is unused.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    attempts = count * np.random.randint(3, 8)
    shift_x, shift_y = get_position_offset(position)
    centre_x, centre_y = 14 + shift_x, 14 + shift_y

    for _ in range(attempts):
        col = centre_x + np.random.randint(-10, 11)
        row = centre_y + np.random.randint(-10, 11)
        if col < 0 or col >= 28 or row < 0 or row >= 28:
            continue
        canvas[row, col] = 1.0
    return canvas


def generate_grid(size='medium', position='center', count=1):
    """Draw a regular grid covering rows/columns 2..25 of a 28x28 canvas.

    Grid lines start at index 2 and repeat every 3 pixels for 'tiny'/'small'
    sizes, every 5 pixels otherwise. ``position`` and ``count`` are accepted
    for signature compatibility with the other generators but are not used.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    step = 3 if size in ('tiny', 'small') else 5
    line_indices = list(range(2, 26, step))

    # Paint all horizontal lines, then all vertical lines, in two assignments.
    canvas[line_indices, 2:26] = 1.0
    canvas[2:26, line_indices] = 1.0
    return canvas


def generate_border(size='medium', position='center', count=1):
    """Draw a rectangular frame spanning rows/columns 2..25 of a 28x28 canvas.

    The frame is 1 pixel thick for 'tiny'/'small' sizes and 2 pixels thick
    otherwise. ``position`` and ``count`` are accepted for signature
    compatibility with the other generators but are not used.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    frame_width = 1 if size in ('tiny', 'small') else 2

    # Fill the whole framed area, then carve the interior back out.
    canvas[2:26, 2:26] = 1.0
    canvas[2 + frame_width:26 - frame_width, 2 + frame_width:26 - frame_width] = 0.0
    return canvas


# Cell status message, printed once all generator functions above are defined.
print("✅ Parameterized generators defined")

---
## 3. Natural Language Prompt Templates

In [None]:
# ============================================================================
# NATURAL LANGUAGE PROMPT GENERATION
# Creates all combinations of shapes, sizes, positions, quantities
# ============================================================================

# Shape name → generator function mapping (with synonyms)
# Several names share one generator, so synonyms produce the same kind of
# image. Insertion order matters: ALL_PROMPTS is built by iterating this dict.
SHAPE_GENERATORS = {
    # Circle outlines ('oval' is drawn as a plain circle)
    'circle': generate_circle,
    'ring': generate_circle,
    'oval': generate_circle,
    # Solid disks
    'disk': generate_filled_circle,
    'filled circle': generate_filled_circle,
    # Square outlines ('rectangle' is drawn as a square)
    'square': generate_square,
    'box': generate_square,
    'rectangle': generate_square,
    # Solid squares
    'filled square': generate_filled_square,
    'solid square': generate_filled_square,
    'block': generate_filled_square,
    # Straight lines
    'horizontal line': generate_horizontal_line,
    'horizontal bar': generate_horizontal_line,
    'vertical line': generate_vertical_line,
    'vertical bar': generate_vertical_line,
    'diagonal line': generate_diagonal_line,
    'diagonal': generate_diagonal_line,
    'slash': generate_diagonal_line,
    # Plus-shaped cross
    'cross': generate_cross,
    'plus': generate_cross,
    'plus sign': generate_cross,
    # Small dot(s), each placed independently
    'dot': generate_dot,
    'point': generate_dot,
    'pixel': generate_dot,
    'triangle': generate_triangle,
    # Diagonal cross
    'x': generate_x_shape,
    'x shape': generate_x_shape,
    # Gaussian blob
    'blob': generate_blob,
    'gradient': generate_blob,
    'glow': generate_blob,
    # Scattered single pixels (note: 'dots' is a separate shape from 'dot')
    'dots': generate_random_dots,
    'scattered dots': generate_random_dots,
    'random dots': generate_random_dots,
    'noise': generate_random_dots,
    # Full-canvas patterns; these generators ignore position and count
    'grid': generate_grid,
    'mesh': generate_grid,
    'border': generate_border,
    'frame': generate_border,
    'outline': generate_border,
}

# Modifiers
# An empty string means "modifier omitted": build_prompt leaves it out of the
# prompt text, an omitted size is stored as 'medium', and an omitted position
# is parsed as 'center'.
SIZES = ['', 'tiny', 'small', 'medium', 'large', 'big']
# Position phrases exactly as they appear in prompts; parse_position turns
# them into the names understood by get_position_offset.
POSITIONS = ['', 'at top', 'at bottom', 'on the left', 'on the right', 'in the center',
             'at top left', 'at top right', 'at bottom left', 'at bottom right']
# Quantity word → number of shapes passed to the generator as `count`.
QUANTITIES = {
    '': 1, 'single': 1, 'one': 1, 'a': 1,
    'two': 2, 'pair of': 2, 'double': 2,
    'three': 3, 'triple': 3, 'few': 3,
}

def parse_position(pos_str):
    """Translate a prompt position phrase into a position name.

    For example 'on the left' becomes 'left'. The empty phrase and any
    phrase that is not recognised both resolve to 'center'.
    """
    phrase_to_name = dict([
        ('', 'center'),
        ('at top', 'top'),
        ('at bottom', 'bottom'),
        ('on the left', 'left'),
        ('on the right', 'right'),
        ('in the center', 'center'),
        ('at top left', 'top left'),
        ('at top right', 'top right'),
        ('at bottom left', 'bottom left'),
        ('at bottom right', 'bottom right'),
    ])
    if pos_str in phrase_to_name:
        return phrase_to_name[pos_str]
    return 'center'


def build_prompt(quantity_word, size, shape, position):
    """Join the prompt components into one space-separated string.

    Word order is quantity, size, shape, position. Empty quantity, size and
    position are left out; the shape is always included.
    """
    leading = [word for word in (quantity_word, size) if word]
    trailing = [position] if position else []
    return ' '.join(leading + [shape] + trailing)


# Build the prompt table: one entry per (shape, size, position, quantity)
# combination, keyed by the prompt text and holding the generator arguments.
ALL_PROMPTS = {}
for (shape_name, gen_func), size, position, (qty_word, qty_num) in product(
        SHAPE_GENERATORS.items(), SIZES, POSITIONS, QUANTITIES.items()):
    # Full-canvas patterns ignore the count, so plural quantities are skipped.
    if shape_name in ('grid', 'border', 'frame', 'mesh', 'outline') and qty_num > 1:
        continue

    prompt = build_prompt(qty_word, size, shape_name, position).strip()
    ALL_PROMPTS[prompt] = dict(
        generator=gen_func,
        size=size or 'medium',
        position=parse_position(position),
        count=qty_num,
    )

print(f"✅ Generated {len(ALL_PROMPTS)} unique prompts")
print(f"\nSample prompts:")
# Show 15 randomly drawn prompts, shortest first.
sample_prompts = random.sample(list(ALL_PROMPTS), 15)
for p in sorted(sample_prompts, key=len):
    print(f"  • {p}")

In [None]:
# ============================================================================
# VISUALIZE SAMPLE PROMPTS
# ============================================================================

# Twelve prompts for a 3x4 preview grid of generator output (ground truth,
# not model output).
# NOTE: 'two squares' is not a key in ALL_PROMPTS because shape names are never
# pluralized when prompts are built, so its panel shows a title but no image.
demo_prompts = [
    'circle', 'small circle', 'large circle at top',
    'square', 'two squares', 'tiny square on the left',
    'triangle', 'big triangle', 'three dots',
    'horizontal line', 'vertical line at bottom', 'cross',
]

fig, axes = plt.subplots(3, 4, figsize=(12, 9))
axes = axes.flatten()

for ax, prompt in zip(axes, demo_prompts):
    # Only prompts present in ALL_PROMPTS have a generator configuration.
    if prompt in ALL_PROMPTS:
        cfg = ALL_PROMPTS[prompt]
        img = cfg['generator'](cfg['size'], cfg['position'], cfg['count'])
        ax.imshow(img, cmap='gray', vmin=0, vmax=1)
    ax.set_title(prompt, fontsize=9)
    ax.axis('off')

plt.tight_layout()
plt.show()

---
## 4. Dataset & Tokenizer

In [None]:
# ============================================================================
# CHARACTER TOKENIZER
# ============================================================================

class CharacterTokenizer:
    """Character-level tokenizer over lowercase letters and the space.

    Index 0 is the padding token, 1 is the space and 2..27 are 'a'..'z'.
    """

    def __init__(self):
        self.char_to_idx = {'<PAD>': 0, ' ': 1}
        self.char_to_idx.update(
            (letter, index)
            for index, letter in enumerate('abcdefghijklmnopqrstuvwxyz', start=2)
        )
        self.idx_to_char = {index: symbol for symbol, index in self.char_to_idx.items()}
        self.vocab_size = len(self.char_to_idx)

    def encode(self, text, max_length=64):
        """Return exactly ``max_length`` token ids for ``text``.

        The text is lowercased, characters outside the vocabulary are dropped,
        and the result is truncated or right-padded with 0.
        """
        known = self.char_to_idx
        ids = [known[ch] for ch in text.lower() if ch in known][:max_length]
        padding = [0] * (max_length - len(ids))
        return ids + padding

# Shared tokenizer instance; the model constructor, collate_fn and the
# generation helpers further down all read this global.
tokenizer = CharacterTokenizer()
print(f"Vocabulary size: {tokenizer.vocab_size}")

In [None]:
# ============================================================================
# DATASET CLASS
# ============================================================================

class ExpandedTextImageDataset(Dataset):
    """Dataset of (prompt, image) pairs rendered on demand.

    Every prompt in ``prompts_dict`` appears ``samples_per_prompt`` times. The
    image is generated afresh on each access by calling the prompt's generator
    with its stored size, position and count, so the same index can return a
    different random variation each time.
    """

    def __init__(self, prompts_dict, samples_per_prompt=100):
        self.prompts = list(prompts_dict.keys())
        self.configs = prompts_dict
        self.samples_per_prompt = samples_per_prompt
        # Flat index → prompt lookup: each prompt repeated back-to-back.
        self.index_to_prompt = [
            text for text in self.prompts for _ in range(samples_per_prompt)
        ]

    def __len__(self):
        return len(self.index_to_prompt)

    def __getitem__(self, idx):
        text = self.index_to_prompt[idx]
        spec = self.configs[text]
        render = spec['generator']
        picture = render(spec['size'], spec['position'], spec['count'])
        # Add a leading channel axis to the 2-D array.
        return text, torch.from_numpy(picture).unsqueeze(0)

# 50 on-the-fly samples per prompt; the dataset length is 50 x len(ALL_PROMPTS).
train_dataset = ExpandedTextImageDataset(ALL_PROMPTS, samples_per_prompt=50)
print(f"Dataset size: {len(train_dataset)} samples")

---
## 5. Model Architecture

In [None]:
# ============================================================================
# MODEL (same architecture, larger capacity)
# ============================================================================

class TextEncoder(nn.Module):
    """Encode a padded token-id sequence as one latent vector.

    Token embeddings are averaged over the non-padding positions (id 0 is
    padding) and the mean is passed through a two-layer MLP. Because of the
    averaging, the encoding does not depend on character order.
    """

    def __init__(self, vocab_size, embed_dim=64, latent_dim=128, max_length=64):
        super().__init__()
        self.max_length = max_length
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        )

    def forward(self, token_ids):
        # 1.0 where a real token sits, 0.0 on padding; trailing axis broadcasts.
        keep = token_ids.ne(0).float().unsqueeze(-1)
        summed = (self.embedding(token_ids) * keep).sum(dim=1)
        # Clamp so an all-padding row divides by 1 instead of 0.
        n_tokens = keep.sum(dim=1).clamp(min=1)
        return self.projection(summed / n_tokens)


class ImageGenerator(nn.Module):
    """Decode a text latent plus a noise vector into a 1x28x28 image.

    The decoder is an MLP (latent+noise → 512 → 1024 → 784) with LayerNorm
    and ReLU between layers and a final Sigmoid, so pixels lie in [0, 1].
    """

    def __init__(self, latent_dim=128, noise_dim=64):
        super().__init__()
        self.noise_dim = noise_dim
        stages = [
            nn.Linear(latent_dim + noise_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, latent, noise=None):
        if noise is None:
            # Draw fresh Gaussian noise, one row per latent, on the same device.
            batch = latent.shape[0]
            noise = torch.randn(batch, self.noise_dim, device=latent.device)
        flat_pixels = self.decoder(torch.cat([latent, noise], dim=1))
        return flat_pixels.view(-1, 1, 28, 28)


class TextToImageModel(nn.Module):
    """Text encoder followed by the image generator.

    ``forward`` maps a batch of token ids to images; ``generate`` is a
    convenience wrapper that tokenizes one prompt and samples several images.
    """

    def __init__(self, vocab_size, embed_dim=64, latent_dim=128, noise_dim=64, max_length=64):
        super().__init__()
        self.max_length = max_length
        self.noise_dim = noise_dim
        self.text_encoder = TextEncoder(vocab_size, embed_dim, latent_dim, max_length)
        self.image_generator = ImageGenerator(latent_dim, noise_dim)

    def forward(self, token_ids, noise=None):
        return self.image_generator(self.text_encoder(token_ids), noise)

    def generate(self, prompt, tokenizer, num_samples=1, device='cpu'):
        """Sample ``num_samples`` images for ``prompt`` without tracking gradients.

        Switches the module to eval mode and leaves it there. ``device`` is
        where the token tensor is created and should match the model's device.
        """
        self.eval()
        encoded = tokenizer.encode(prompt, max_length=self.max_length)
        with torch.no_grad():
            # The same token row repeated; variation comes from the noise input.
            batch = torch.tensor([encoded] * num_samples, device=device)
            return self.forward(batch)

# Build the model with default layer sizes and move it to the chosen device.
model = TextToImageModel(vocab_size=tokenizer.vocab_size).to(DEVICE)
# Total number of parameter elements; reported again in the final summary cell.
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")

---
## 6. Training

In [None]:
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================

# Hyper-parameters read by the DataLoader and the training loop below.
BATCH_SIZE = 128
NUM_EPOCHS = 150
LEARNING_RATE = 1e-3  # learning rate passed to the Adam optimizer

def collate_fn(batch):
    """Turn a list of (prompt, image) samples into batched tensors.

    Prompts are encoded with the module-level ``tokenizer`` using its default
    maximum length, giving an integer id tensor; the image tensors are stacked
    along a new leading batch axis.
    """
    prompts, images = zip(*batch)
    encoded_rows = list(map(tokenizer.encode, prompts))
    return torch.tensor(encoded_rows), torch.stack(images)

# Shuffled loader over the synthetic dataset. Images are rendered inside
# Dataset.__getitem__; num_workers=0 keeps that work in the main process.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0
)

print(f"Batches per epoch: {len(train_loader)}")

In [None]:
# ============================================================================
# TRAINING LOOP
# ============================================================================

# Adam on all model parameters, trained with pixel-wise MSE against the
# synthetic target image.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()
losses = []  # mean training loss per epoch; plotted and summarized later

print("Starting training...\n")
for epoch in range(NUM_EPOCHS):
    # generate() switches the model to eval mode, so re-enable training mode.
    model.train()
    epoch_loss = 0.0
    
    for token_ids, target_images in train_loader:
        token_ids = token_ids.to(DEVICE)
        target_images = target_images.to(DEVICE)
        
        # noise is left as None, so the generator draws fresh noise per batch.
        generated = model(token_ids)
        loss = criterion(generated, target_images)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    # Average of the per-batch losses for this epoch.
    avg_loss = epoch_loss / len(train_loader)
    losses.append(avg_loss)
    
    # Log the first epoch and every tenth epoch after that.
    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:3d}/{NUM_EPOCHS} | Loss: {avg_loss:.6f}")

print("\n✅ Training complete!")

In [None]:
# Plot loss
# One point per epoch: the mean batch MSE collected in `losses` during training.
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
plt.show()

---
## 7. Interactive Generation

In [None]:
# ============================================================================
# GENERATE FROM ANY PROMPT
# ============================================================================

def generate_and_show(prompt, num_samples=4):
    """Generate images for a prompt with the trained model and display them.

    Prints the prompt and whether it is a key of ALL_PROMPTS (i.e. was part
    of the training data), then shows the samples side by side.

    Args:
        prompt: Text prompt to render.
        num_samples: Number of images to sample and plot in one row.
    """
    print(f"Prompt: '{prompt}'")

    if prompt in ALL_PROMPTS:
        print("  ✓ Known prompt")
    else:
        print("  ⚠ Novel prompt (may work if similar to training data)")

    images = model.generate(prompt, tokenizer, num_samples=num_samples, device=DEVICE)

    # squeeze=False keeps `axes` a 2-D array even for a single sample; without
    # it plt.subplots(1, 1) returns a bare Axes object that cannot be iterated.
    fig, axes = plt.subplots(1, num_samples, figsize=(3*num_samples, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(images[i, 0].cpu().numpy(), cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
    plt.tight_layout()
    plt.show()


# Spot-check a few prompts.
# NOTE: "two squares" is not a key in ALL_PROMPTS (shape names are never
# pluralized when prompts are built), so generate_and_show reports it as
# novel; the other three are training prompts.
generate_and_show("small circle")
generate_and_show("large triangle at top")
generate_and_show("two squares")
generate_and_show("vertical line on the left")

In [None]:
# ============================================================================
# TRY EXTRA PROMPTS (a mix of training prompts and unseen ones)
# ============================================================================

# Not every prompt in this list is unseen: "big blob at bottom", "tiny x" and
# "small grid" are size/shape/position combinations that exist as keys in
# ALL_PROMPTS. Only "three triangles" (a plural shape name) is absent.
# generate_and_show prints which case applies to each prompt.
print("Testing extra prompts (each is reported as known or novel below):\n")

novel_prompts = [
    "big blob at bottom",
    "tiny x",
    "three triangles",
    "small grid",
]

for p in novel_prompts:
    generate_and_show(p)

In [None]:
# ============================================================================
# COMPARISON GRID
# ============================================================================

# Twelve prompts, one model sample each, shown in a 4x3 grid (one row per
# shape family). In ALL_PROMPTS the keys "two dots" and "three dots" come from
# the 'dots' (scattered dots) shape with a quantity word, not from
# pluralizing 'dot'.
test_prompts = [
    "circle", "small circle", "large circle",
    "square", "small square", "large square",
    "triangle", "small triangle", "large triangle",
    "dot", "two dots", "three dots",
]

fig, axes = plt.subplots(4, 3, figsize=(9, 12))
axes = axes.flatten()

for ax, prompt in zip(axes, test_prompts):
    # generate() returns a (num_samples, 1, 28, 28) tensor; show the only sample.
    img = model.generate(prompt, tokenizer, num_samples=1, device=DEVICE)
    ax.imshow(img[0, 0].cpu().numpy(), cmap='gray', vmin=0, vmax=1)
    ax.set_title(prompt, fontsize=10)
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Final summary: prompt count, model size, last-epoch loss, and a reminder of
# the vocabulary the prompts were built from.
print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)
print(f"\nDataset: {len(ALL_PROMPTS)} unique prompts")
print(f"Model: {total_params:,} parameters")
print(f"Final loss: {losses[-1]:.6f}")
print("\nTry any prompt! The model understands:")
print("  - Shapes: circle, square, triangle, line, cross, dot, blob, grid, border, x")
print("  - Sizes: tiny, small, medium, large, big")
print("  - Positions: at top, at bottom, on the left, on the right, in the center")
print("  - Quantities: one, two, three, pair of, few")