# Text-to-Image Generator (Expanded Vocabulary)

A text-to-image model with **150+ natural language prompts** generating **28×28 grayscale images**.

## Features
- Sizes: tiny, small, medium, large, big
- Positions: at top, at bottom, on the left, on the right
- Quantities: two, three (`pair of` and `few` are not yet part of the generated prompt vocabulary)
- Synonyms: ring=circle, box=square, point=dot

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import random

# Seed PyTorch, NumPy and the stdlib `random` module so runs are repeatable.
# The image generators below draw from `np.random`.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Use the GPU when CUDA reports one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

In [None]:
# ============================================================================
# PARAMETERIZED IMAGE GENERATORS (with bounds checking)
# ============================================================================

def get_size_params(size):
    """Return the ``(low, high)`` extent range for a size keyword.

    The shape generators feed the pair to ``np.random.randint(low, high)``,
    so ``high`` is exclusive there. Any keyword that is not one of
    'tiny', 'small', 'medium', 'large' or 'big' gets the 'medium' range.
    """
    ranges = dict(
        tiny=(2, 4),
        small=(3, 6),
        medium=(5, 9),
        large=(8, 12),
        big=(9, 13),
    )
    try:
        return ranges[size]
    except KeyError:
        # Unknown keyword: behave like 'medium'.
        return (5, 9)

def get_position_offset(position):
    """Return the ``(dx, dy)`` pixel shift for a position keyword.

    ``dx`` is added to the column and ``dy`` to the row by the shape
    generators, so a negative ``dy`` moves toward row 0. Unknown keywords
    yield ``(0, 0)``, the same as 'center'.
    """
    shifts = dict([
        ('center', (0, 0)),
        ('top', (0, -6)),
        ('bottom', (0, 6)),
        ('left', (-6, 0)),
        ('right', (6, 0)),
        ('top left', (-5, -5)),
        ('top right', (5, -5)),
        ('bottom left', (-5, 5)),
        ('bottom right', (5, 5)),
    ])
    try:
        return shifts[position]
    except KeyError:
        return (0, 0)

def clamp(val, lo, hi):
    """Limit ``val`` to the range ``[lo, hi]``.

    The upper bound is applied first and the lower bound second, so when the
    bounds are inverted (``lo > hi``) the result is ``lo``.
    """
    capped = val if val < hi else hi
    return capped if capped > lo else lo

def generate_circle(size='medium', position='center', count=1):
    """Draw ``count`` circle outlines on a blank 28x28 float32 image.

    Args:
        size: Size keyword resolved by ``get_size_params`` into a radius range.
        position: Position keyword resolved by ``get_position_offset``.
        count: Number of circles; copies are spread horizontally 8 px apart.

    Returns:
        A ``(28, 28)`` float32 array with outline pixels set to 1.0.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    r_low, r_high = get_size_params(size)
    shift_x, shift_y = get_position_offset(position)
    # 100 samples around the circumference; each is truncated to a pixel.
    angles = np.linspace(0, 2 * np.pi, 100)
    for k in range(count):
        # NOTE(review): for count == 2 this yields offsets -8 and 0, so the
        # pair is not centred on the anchor point.
        spread = (k - count // 2) * 8 if count > 1 else 0
        jitter_x = np.random.randint(-2, 3)
        center_x = clamp(14 + shift_x + spread + jitter_x, 2, 25)
        jitter_y = np.random.randint(-2, 3)
        center_y = clamp(14 + shift_y + jitter_y, 2, 25)
        r = np.random.randint(r_low, r_high)
        for angle in angles:
            col = int(center_x + r * np.cos(angle))
            row = int(center_y + r * np.sin(angle))
            if col < 0 or col >= 28 or row < 0 or row >= 28:
                continue  # sample falls outside the canvas
            canvas[row, col] = 1.0
    return canvas

def generate_filled_circle(size='medium', position='center', count=1):
    """Draw ``count`` solid disks on a blank 28x28 float32 image.

    Args:
        size: Size keyword resolved by ``get_size_params`` into a radius range.
        position: Position keyword resolved by ``get_position_offset``.
        count: Number of disks; copies are spread horizontally 8 px apart.

    Returns:
        A ``(28, 28)`` float32 array where every pixel within ``radius`` of a
        disk centre (Euclidean distance, inclusive) is 1.0.
    """
    img = np.zeros((28, 28), dtype=np.float32)
    min_r, max_r = get_size_params(size)
    dx, dy = get_position_offset(position)
    # Open coordinate grids: rows is (28, 1), cols is (1, 28). Broadcasting
    # them evaluates the distance test for all pixels at once instead of
    # looping over 784 pixels in Python for every disk.
    rows, cols = np.ogrid[:28, :28]
    for i in range(count):
        # NOTE(review): for count == 2 this yields offsets -8 and 0, so the
        # pair is not centred on the anchor point.
        offset_x = (i - count//2) * 8 if count > 1 else 0
        # Random draws stay in the original order (x jitter, y jitter, radius)
        # so seeded runs reproduce the same images.
        cx = clamp(14 + dx + offset_x + np.random.randint(-2, 3), 2, 25)
        cy = clamp(14 + dy + np.random.randint(-2, 3), 2, 25)
        radius = np.random.randint(min_r, max_r)
        inside = (cols - cx)**2 + (rows - cy)**2 <= radius**2
        img[inside] = 1.0
    return img

def generate_square(size='medium', position='center', count=1):
    """Draw ``count`` square outlines on a blank 28x28 float32 image.

    Args:
        size: Size keyword resolved by ``get_size_params``; the drawn value is
            the half side length.
        position: Position keyword resolved by ``get_position_offset``.
        count: Number of squares; copies are spread horizontally 10 px apart.

    Returns:
        A ``(28, 28)`` float32 array with the outline pixels set to 1.0.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    low, high = get_size_params(size)
    shift_x, shift_y = get_position_offset(position)
    for k in range(count):
        spread = (k - count // 2) * 10 if count > 1 else 0
        half = np.random.randint(low, high)
        # Clamp the centre so the whole square stays inside columns/rows 1..26.
        center_x = clamp(14 + shift_x + spread + np.random.randint(-2, 3), half + 1, 26 - half)
        center_y = clamp(14 + shift_y + np.random.randint(-2, 3), half + 1, 26 - half)
        left, right = center_x - half, center_x + half
        top, bottom = center_y - half, center_y + half
        # Four edges, corners included.
        canvas[top, left:right + 1] = 1.0
        canvas[bottom, left:right + 1] = 1.0
        canvas[top:bottom + 1, left] = 1.0
        canvas[top:bottom + 1, right] = 1.0
    return canvas

def generate_filled_square(size='medium', position='center', count=1):
    """Draw ``count`` solid squares on a blank 28x28 float32 image.

    Args:
        size: Size keyword resolved by ``get_size_params``; the drawn value is
            the half side length.
        position: Position keyword resolved by ``get_position_offset``.
        count: Number of squares; copies are spread horizontally 10 px apart.

    Returns:
        A ``(28, 28)`` float32 array with the square interiors set to 1.0.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    low, high = get_size_params(size)
    shift_x, shift_y = get_position_offset(position)
    for k in range(count):
        spread = (k - count // 2) * 10 if count > 1 else 0
        half = np.random.randint(low, high)
        # Clamp the centre so the whole square stays inside columns/rows 1..26.
        center_x = clamp(14 + shift_x + spread + np.random.randint(-2, 3), half + 1, 26 - half)
        center_y = clamp(14 + shift_y + np.random.randint(-2, 3), half + 1, 26 - half)
        canvas[center_y - half:center_y + half + 1,
               center_x - half:center_x + half + 1] = 1.0
    return canvas

def generate_horizontal_line(size='medium', position='center', count=1):
    """Draw ``count`` horizontal lines on a blank 28x28 float32 image.

    Args:
        size: 'tiny' or 'small' gives a 1-pixel-thick line, anything else 2.
        position: Position keyword; only its vertical component is used.
        count: Number of lines; copies are stacked 6 rows apart.

    Returns:
        A ``(28, 28)`` float32 array with line pixels set to 1.0.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    _, shift_y = get_position_offset(position)
    rows_thick = 1 if size in ('tiny', 'small') else 2
    for k in range(count):
        jitter = np.random.randint(-2, 3)
        row = clamp(14 + shift_y + (k - count // 2) * 6 + jitter, 2, 25)
        col_start = np.random.randint(2, 6)
        col_stop = np.random.randint(22, 27)
        # row is at most 25 and the thickness at most 2, so the slice never
        # leaves the canvas.
        canvas[row:row + rows_thick, col_start:col_stop] = 1.0
    return canvas

def generate_vertical_line(size='medium', position='center', count=1):
    """Draw ``count`` vertical lines on a blank 28x28 float32 image.

    Args:
        size: 'tiny' or 'small' gives a 1-pixel-thick line, anything else 2.
        position: Position keyword; only its horizontal component is used.
        count: Number of lines; copies are placed 6 columns apart.

    Returns:
        A ``(28, 28)`` float32 array with line pixels set to 1.0.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    shift_x, _ = get_position_offset(position)
    cols_thick = 1 if size in ('tiny', 'small') else 2
    for k in range(count):
        jitter = np.random.randint(-2, 3)
        col = clamp(14 + shift_x + (k - count // 2) * 6 + jitter, 2, 25)
        row_start = np.random.randint(2, 6)
        row_stop = np.random.randint(22, 27)
        # col is at most 25 and the thickness at most 2, so the slice never
        # leaves the canvas.
        canvas[row_start:row_stop, col:col + cols_thick] = 1.0
    return canvas

def generate_diagonal_line(size='medium', position='center', count=1):
    """Draw ``count`` full-width diagonal lines on a blank 28x28 float32 image.

    Each line picks a random direction (rising or falling) and a vertical
    offset. ``size`` and the horizontal part of ``position`` are not used.

    Args:
        size: Accepted for signature parity with the other generators.
        position: Position keyword; only its vertical component is used.
        count: Number of lines; copies are shifted 5 rows apart.

    Returns:
        A ``(28, 28)`` float32 array with line pixels set to 1.0.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    _, shift_y = get_position_offset(position)
    cols = np.arange(28)
    for k in range(count):
        slope = np.random.choice([1, -1])
        intercept = (k - count // 2) * 5 + np.random.randint(-2, 3)
        # slope == 1 gives row = col + ...; slope == -1 gives row = 28 - col + ...
        rows = cols * slope + 14 * (1 - slope) + intercept + shift_y
        visible = (rows >= 0) & (rows < 28)
        canvas[rows[visible], cols[visible]] = 1.0
    return canvas

def generate_cross(size='medium', position='center', count=1):
    """Draw ``count`` plus-shaped crosses on a blank 28x28 float32 image.

    Args:
        size: Size keyword resolved by ``get_size_params`` into an arm-length range.
        position: Position keyword resolved by ``get_position_offset``.
        count: Number of crosses; copies are spread horizontally 10 px apart.

    Returns:
        A ``(28, 28)`` float32 array with cross pixels set to 1.0; arms that
        would leave the canvas are clipped.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    arm_low, arm_high = get_size_params(size)
    shift_x, shift_y = get_position_offset(position)
    for k in range(count):
        spread = (k - count // 2) * 10 if count > 1 else 0
        center_x = clamp(14 + shift_x + spread + np.random.randint(-2, 3), 2, 25)
        center_y = clamp(14 + shift_y + np.random.randint(-2, 3), 2, 25)
        arm = np.random.randint(arm_low, arm_high)
        # Horizontal arm, clipped to columns 0..27.
        canvas[center_y, max(center_x - arm, 0):min(center_x + arm, 27) + 1] = 1.0
        # Vertical arm, clipped to rows 0..27.
        canvas[max(center_y - arm, 0):min(center_y + arm, 27) + 1, center_x] = 1.0
    return canvas

def generate_dot(size='medium', position='center', count=1):
    """Draw ``count`` small square dots on a blank 28x28 float32 image.

    Args:
        size: 'tiny' or 'small' gives 1x1 dots, anything else 2x2.
        position: Position keyword resolved by ``get_position_offset``.
        count: Number of dots; each is placed independently within +/-8 px
            of the anchor point.

    Returns:
        A ``(28, 28)`` float32 array with dot pixels set to 1.0.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    shift_x, shift_y = get_position_offset(position)
    side = 1 if size in ('tiny', 'small') else 2
    for _ in range(count):
        col = clamp(14 + shift_x + np.random.randint(-8, 9), 1, 26)
        row = clamp(14 + shift_y + np.random.randint(-8, 9), 1, 26)
        # col/row are at most 26 and side at most 2, so the patch is in bounds.
        canvas[row:row + side, col:col + side] = 1.0
    return canvas

def generate_triangle(size='medium', position='center', count=1):
    """Draw ``count`` triangle outlines on a blank 28x28 float32 image.

    The apex is on the lower row index and the base on the higher one.

    Args:
        size: Size keyword resolved by ``get_size_params``; the drawn value is
            half the base width and half the height.
        position: Position keyword resolved by ``get_position_offset``.
        count: Number of triangles; copies are spread horizontally 10 px apart.

    Returns:
        A ``(28, 28)`` float32 array with outline pixels set to 1.0.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    low, high = get_size_params(size)
    shift_x, shift_y = get_position_offset(position)
    # 50 interpolation steps along each edge.
    steps = np.linspace(0, 1, 50)
    for k in range(count):
        spread = (k - count // 2) * 10 if count > 1 else 0
        half = np.random.randint(low, high)
        center_x = clamp(14 + shift_x + spread + np.random.randint(-2, 3), half + 2, 25 - half)
        center_y = clamp(14 + shift_y + np.random.randint(-2, 3), half + 2, 25 - half)
        apex = (center_x, center_y - half)
        base_left = (center_x - half, center_y + half)
        base_right = (center_x + half, center_y + half)
        edges = ((apex, base_left), (base_left, base_right), (base_right, apex))
        for (x0, y0), (x1, y1) in edges:
            for t in steps:
                col = int(x0 * (1-t) + x1 * t)
                row = int(y0 * (1-t) + y1 * t)
                if 0 <= col < 28 and 0 <= row < 28:
                    canvas[row, col] = 1.0
    return canvas

def generate_x_shape(size='medium', position='center', count=1):
    """Draw ``count`` X shapes (two crossing diagonals) on a blank 28x28 image.

    Args:
        size: Size keyword resolved by ``get_size_params`` into a half-extent range.
        position: Position keyword resolved by ``get_position_offset``.
        count: Number of shapes; copies are spread horizontally 10 px apart.

    Returns:
        A ``(28, 28)`` float32 array with the X pixels set to 1.0.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    low, high = get_size_params(size)
    shift_x, shift_y = get_position_offset(position)
    for k in range(count):
        spread = (k - count // 2) * 10 if count > 1 else 0
        half = np.random.randint(low, high)
        center_x = clamp(14 + shift_x + spread + np.random.randint(-2, 3), half + 2, 25 - half)
        center_y = clamp(14 + shift_y + np.random.randint(-2, 3), half + 2, 25 - half)
        for step in range(-half, half + 1):
            col = center_x + step
            if not 0 <= col < 28:
                continue
            # One pixel on each diagonal for this column.
            for row in (center_y + step, center_y - step):
                if 0 <= row < 28:
                    canvas[row, col] = 1.0
    return canvas

def generate_blob(size='medium', position='center', count=1):
    """Draw ``count`` soft Gaussian blobs on a blank 28x28 float32 image.

    Unlike the other generators the output is not binary: each pixel holds
    ``exp(-d^2 / (2 * sigma^2))`` for the nearest-contributing blob, and
    overlapping blobs are combined with a per-pixel maximum.

    Args:
        size: Size keyword resolved by ``get_size_params``; sigma is drawn
            uniformly from 0.7x that range, once, and shared by all blobs.
        position: Position keyword resolved by ``get_position_offset``.
        count: Number of blobs; copies are spread horizontally 10 px apart.

    Returns:
        A ``(28, 28)`` float32 array with values in (0, 1].
    """
    img = np.zeros((28, 28), dtype=np.float32)
    min_s, max_s = get_size_params(size)
    dx, dy = get_position_offset(position)
    sigma = np.random.uniform(min_s * 0.7, max_s * 0.7)
    # Open coordinate grids: rows is (28, 1), cols is (1, 28). Broadcasting
    # them evaluates the falloff for the whole canvas in one call instead of
    # calling np.exp once per pixel from a Python loop.
    rows, cols = np.ogrid[:28, :28]
    for i in range(count):
        # NOTE(review): for count == 2 this yields offsets -10 and 0, so the
        # pair is not centred on the anchor point.
        offset_x = (i - count//2) * 10 if count > 1 else 0
        # Centres are not clamped; a blob centred off-canvas still contributes
        # its visible tail.
        cx = 14 + dx + offset_x + np.random.randint(-2, 3)
        cy = 14 + dy + np.random.randint(-2, 3)
        dist_sq = (cols - cx)**2 + (rows - cy)**2
        falloff = np.exp(-dist_sq / (2 * sigma**2))
        # Assigning through the slice casts the result back to float32.
        img[...] = np.maximum(img, falloff)
    return img

def generate_random_dots(size='medium', position='center', count=1):
    """Scatter single-pixel dots on a blank 28x28 float32 image.

    Args:
        size: Accepted for signature parity with the other generators; unused.
        position: Position keyword resolved by ``get_position_offset``.
        count: Multiplier on the number of dots; the total is ``count`` times
            a random integer from 3 to 7.

    Returns:
        A ``(28, 28)`` float32 array with dot pixels set to 1.0. Dots may
        coincide, so the lit-pixel total can be lower than the number drawn.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    total = count * np.random.randint(3, 8)
    shift_x, shift_y = get_position_offset(position)
    anchor_x = 14 + shift_x
    anchor_y = 14 + shift_y
    for _ in range(total):
        col = clamp(anchor_x + np.random.randint(-10, 11), 0, 27)
        row = clamp(anchor_y + np.random.randint(-10, 11), 0, 27)
        canvas[row, col] = 1.0
    return canvas

def generate_grid(size='medium', position='center', count=1):
    """Draw a fixed lattice of horizontal and vertical lines on a 28x28 image.

    Lines start at index 2 and repeat every ``spacing`` pixels below index 26;
    each spans indices 2..25. The output is deterministic.

    Args:
        size: 'tiny' or 'small' gives a spacing of 3, anything else 5.
        position: Accepted for signature parity; unused.
        count: Accepted for signature parity; unused.

    Returns:
        A ``(28, 28)`` float32 array with grid pixels set to 1.0.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    spacing = 3 if size in ('tiny', 'small') else 5
    marks = np.arange(2, 26, spacing)
    canvas[marks, 2:26] = 1.0  # horizontal lines
    canvas[2:26, marks] = 1.0  # vertical lines
    return canvas

def generate_border(size='medium', position='center', count=1):
    """Draw a rectangular frame around indices 2..25 of a 28x28 image.

    The output is deterministic.

    Args:
        size: 'tiny' or 'small' gives a 1-pixel frame, anything else 2 pixels.
        position: Accepted for signature parity; unused.
        count: Accepted for signature parity; unused.

    Returns:
        A ``(28, 28)`` float32 array with frame pixels set to 1.0.
    """
    canvas = np.zeros((28, 28), dtype=np.float32)
    thickness = 1 if size in ('tiny', 'small') else 2
    # Fill the outer 24x24 square, then hollow out everything inside the frame.
    canvas[2:26, 2:26] = 1.0
    canvas[2 + thickness:26 - thickness, 2 + thickness:26 - thickness] = 0.0
    return canvas

# Cell-completion marker printed once all generator functions are defined.
print("✅ Generators defined")

In [None]:
# ============================================================================
# PROMPT GENERATION
# ============================================================================

# Shape word -> generator function. Several words are synonyms that share a
# generator (e.g. 'circle'/'ring', 'square'/'box', 'dot'/'point').
SHAPE_GENERATORS = {
    'circle': generate_circle, 'ring': generate_circle,
    'disk': generate_filled_circle, 'filled circle': generate_filled_circle,
    'square': generate_square, 'box': generate_square,
    'filled square': generate_filled_square, 'block': generate_filled_square,
    'horizontal line': generate_horizontal_line,
    'vertical line': generate_vertical_line,
    'diagonal line': generate_diagonal_line, 'diagonal': generate_diagonal_line,
    'cross': generate_cross, 'plus': generate_cross,
    'dot': generate_dot, 'point': generate_dot,
    'triangle': generate_triangle,
    'x': generate_x_shape, 'x shape': generate_x_shape,
    'blob': generate_blob, 'glow': generate_blob,
    'dots': generate_random_dots, 'scattered dots': generate_random_dots,
    'grid': generate_grid, 'border': generate_border, 'frame': generate_border,
}

# Prompt fragments. The empty string means "fragment omitted"; it resolves to
# size 'medium', position 'center' and count 1 when prompts are built below.
SIZES = ['', 'tiny', 'small', 'medium', 'large', 'big']
POSITIONS = ['', 'at top', 'at bottom', 'on the left', 'on the right']
# Quantity word -> number of shape copies passed to the generator.
QUANTITIES = {'': 1, 'two': 2, 'three': 3}

def parse_position(pos_str):
    """Translate a prompt position phrase into a generator position keyword.

    An empty phrase, or any phrase that is not recognised, maps to 'center'.
    """
    phrase_to_keyword = {
        '': 'center',
        'at top': 'top',
        'at bottom': 'bottom',
        'on the left': 'left',
        'on the right': 'right',
    }
    try:
        return phrase_to_keyword[pos_str]
    except KeyError:
        return 'center'

# Build every prompt string as "[quantity] [size] shape [position]" and map it
# to the generator call that renders it.
ALL_PROMPTS = {}
for shape_name, gen_func in SHAPE_GENERATORS.items():
    for size in SIZES:
        for position in POSITIONS:
            for qty_word, qty_num in QUANTITIES.items():
                # Grid and border generators ignore `count`, so plural
                # prompts for them are skipped.
                if qty_num > 1 and shape_name in ['grid', 'border', 'frame']:
                    continue
                # Drop empty fragments so omitted words leave no double spaces.
                parts = [p for p in [qty_word, size, shape_name, position] if p]
                prompt = ' '.join(parts)
                ALL_PROMPTS[prompt] = {
                    'generator': gen_func,
                    'size': size if size else 'medium',  # omitted size -> medium
                    'position': parse_position(position),
                    'count': qty_num,
                }

print(f"✅ {len(ALL_PROMPTS)} prompts generated")

In [None]:
# Visual smoke test: render one sample per demo prompt straight from the
# procedural generators (no model involved).
demo = ['circle', 'small circle', 'large triangle', 'two dots', 'cross at top']
fig, axes = plt.subplots(1, len(demo), figsize=(12, 3))
for ax, p in zip(axes, demo):
    # Prompts missing from ALL_PROMPTS leave their panel empty but still titled.
    if p in ALL_PROMPTS:
        cfg = ALL_PROMPTS[p]
        img = cfg['generator'](cfg['size'], cfg['position'], cfg['count'])
        ax.imshow(img, cmap='gray')
    ax.set_title(p, fontsize=9)
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# ============================================================================
# TOKENIZER & DATASET
# ============================================================================

class CharacterTokenizer:
    """Character-level tokenizer over space and the lowercase letters a-z.

    Index 0 is reserved for padding, 1 is the space character and 2..27 are
    the letters 'a'..'z'.
    """

    def __init__(self):
        vocab = {'<PAD>': 0, ' ': 1}
        vocab.update(
            {letter: idx for idx, letter in enumerate('abcdefghijklmnopqrstuvwxyz', start=2)}
        )
        self.char_to_idx = vocab
        self.vocab_size = len(vocab)

    def encode(self, text, max_length=64):
        """Convert ``text`` into exactly ``max_length`` token ids.

        The text is lowercased, characters outside the vocabulary are
        dropped, and the result is right-padded with 0 or truncated.
        """
        ids = []
        for ch in text.lower():
            if ch in self.char_to_idx:
                ids.append(self.char_to_idx[ch])
        padded = ids + [0] * max_length
        return padded[:max_length]

class TextImageDataset(Dataset):
    """Dataset of (prompt, image) pairs rendered on demand.

    Every prompt appears ``samples_per_prompt`` times in the index. Images are
    not cached: each ``__getitem__`` call invokes the prompt's generator again.
    """

    def __init__(self, prompts_dict, samples_per_prompt=50):
        self.configs = prompts_dict
        self.prompts = list(prompts_dict.keys())
        # Flat index: each prompt repeated consecutively.
        expanded = []
        for prompt in self.prompts:
            expanded.extend([prompt] * samples_per_prompt)
        self.index_to_prompt = expanded

    def __len__(self):
        return len(self.index_to_prompt)

    def __getitem__(self, idx):
        """Return ``(prompt, image)`` with the image as a ``(1, H, W)`` tensor."""
        prompt = self.index_to_prompt[idx]
        spec = self.configs[prompt]
        render = spec['generator']
        pixels = render(spec['size'], spec['position'], spec['count'])
        # Add a leading channel dimension.
        return prompt, torch.from_numpy(pixels).unsqueeze(0)

# Module-level tokenizer (also used by collate_fn and show) and the training
# dataset: 50 on-demand samples for every prompt in ALL_PROMPTS.
tokenizer = CharacterTokenizer()
train_dataset = TextImageDataset(ALL_PROMPTS, samples_per_prompt=50)
print(f"Dataset: {len(train_dataset)} samples")

In [None]:
# ============================================================================
# MODEL
# ============================================================================

class TextEncoder(nn.Module):
    """Encode token ids into one latent vector per sequence.

    Tokens are embedded, averaged over the non-padding positions and passed
    through a two-layer MLP. Because of the averaging, the encoding does not
    depend on token order.
    """

    def __init__(self, vocab_size, embed_dim=64, latent_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        mlp_layers = [
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        ]
        self.projection = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Map ``(batch, seq)`` integer ids to ``(batch, latent_dim)`` latents."""
        token_vectors = self.embedding(x)
        # 1.0 where a real token is present, 0.0 on padding (id 0).
        keep = (x != 0).float().unsqueeze(-1)
        summed = (token_vectors * keep).sum(1)
        # An all-padding row would divide by zero; clamp the count to 1.
        n_tokens = keep.sum(1).clamp(min=1)
        return self.projection(summed / n_tokens)

class ImageGenerator(nn.Module):
    """MLP decoder from a latent vector plus noise to a 1x28x28 image.

    The final sigmoid keeps every output pixel in [0, 1].
    """

    def __init__(self, latent_dim=128, noise_dim=64):
        super().__init__()
        self.noise_dim = noise_dim
        stages = [
            nn.Linear(latent_dim + noise_dim, 512), nn.LayerNorm(512), nn.ReLU(),
            nn.Linear(512, 1024), nn.LayerNorm(1024), nn.ReLU(),
            nn.Linear(1024, 784), nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, latent, noise=None):
        """Decode ``latent`` into images of shape ``(batch, 1, 28, 28)``.

        When ``noise`` is omitted, standard-normal noise is sampled on the
        same device as ``latent``.
        """
        if noise is None:
            batch = latent.shape[0]
            noise = torch.randn(batch, self.noise_dim, device=latent.device)
        joined = torch.cat([latent, noise], 1)
        flat = self.decoder(joined)
        # 784 = 28 * 28 pixels per image.
        return flat.view(-1, 1, 28, 28)

class TextToImageModel(nn.Module):
    """Text-conditioned image model: a TextEncoder feeding an ImageGenerator."""

    def __init__(self, vocab_size):
        super().__init__()
        self.encoder = TextEncoder(vocab_size)
        self.generator = ImageGenerator()

    def forward(self, token_ids):
        """Encode ``token_ids`` and decode them to images.

        The generator is called without explicit noise, so it samples fresh
        noise on every call.
        """
        return self.generator(self.encoder(token_ids))

    def generate(self, prompt, tokenizer, n=1, device='cpu'):
        """Generate ``n`` images for a text prompt without tracking gradients.

        Args:
            prompt: Text to render.
            tokenizer: Object with an ``encode(text)`` method returning a
                list of token ids.
            n: Number of images; the prompt is repeated ``n`` times in the
                batch and each row gets its own noise sample.
            device: Device for the token tensor; it must match the device the
                model's parameters live on.

        Returns:
            The output of ``forward`` for a batch of ``n`` copies of the prompt.
        """
        # Remember the current mode so that sampling in the middle of training
        # does not silently leave the model in eval mode afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                tokens = torch.tensor([tokenizer.encode(prompt)] * n, device=device)
                return self.forward(tokens)
        finally:
            self.train(was_training)

# Build the model on the selected device and report its total parameter count.
model = TextToImageModel(tokenizer.vocab_size).to(DEVICE)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# ============================================================================
# TRAINING
# ============================================================================

# Training hyperparameters.
BATCH_SIZE = 128   # samples per optimisation step
NUM_EPOCHS = 150   # full passes over train_dataset

def collate_fn(batch):
    """Turn a list of ``(prompt, image)`` samples into batched tensors.

    Prompts are encoded with the module-level ``tokenizer``, so every row has
    the tokenizer's default length; images are stacked along a new batch
    dimension.

    Returns:
        ``(token_ids, images)`` as tensors.
    """
    prompts, images = zip(*batch)
    token_ids = torch.tensor(list(map(tokenizer.encode, prompts)))
    image_batch = torch.stack(images)
    return token_ids, image_batch

# Data pipeline, optimiser and loss. Images are rendered by the dataset on
# access, so every epoch sees freshly randomised samples for each prompt.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Pixel-wise mean squared error between generated and rendered images.
criterion = nn.MSELoss()
losses = []  # mean batch loss per epoch, plotted after training

print(f"Training {NUM_EPOCHS} epochs...")
for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0
    for tokens, images in train_loader:
        tokens, images = tokens.to(DEVICE), images.to(DEVICE)
        loss = criterion(model(tokens), images)
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Average of the per-batch losses for this epoch.
    losses.append(epoch_loss / len(train_loader))
    # Progress report every 10 epochs.
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:3d} | Loss: {losses[-1]:.6f}")

print("✅ Training complete!")

In [None]:
# Plot the mean training loss recorded for each epoch.
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

In [None]:
# ============================================================================
# GENERATE FROM ANY PROMPT
# ============================================================================

def show(prompt, n=4):
    """Generate ``n`` images for ``prompt`` with the trained model and plot them.

    Prints the prompt first, tagged with a check mark when it is one of the
    training prompts in ``ALL_PROMPTS`` and with "(novel)" otherwise.

    Args:
        prompt: Text prompt to render.
        n: Number of samples to draw and display side by side.
    """
    print(f"'{prompt}'" + (" ✓" if prompt in ALL_PROMPTS else " (novel)"))
    imgs = model.generate(prompt, tokenizer, n=n, device=DEVICE).cpu().numpy()
    # squeeze=False always returns a 2-D array of Axes. Without it,
    # plt.subplots(1, 1) returns a bare Axes object and iterating over it
    # fails when n == 1.
    fig, axes = plt.subplots(1, n, figsize=(3*n, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(imgs[i, 0], cmap='gray')
        ax.axis('off')
    plt.show()

# Sample a few prompts that are part of the training vocabulary.
show("small circle")
show("large triangle")
show("two dots")
show("cross at top")

In [None]:
# Try more prompts. The first two are combinations present in ALL_PROMPTS;
# "three squares" uses a plural shape word that is not, so show() tags it
# "(novel)".
show("big blob at bottom")
show("tiny x")
show("three squares")

In [None]:
# Final summary: prompt vocabulary size and the last epoch's mean loss.
print(f"\n{'='*50}")
print("DONE!")
print(f"{'='*50}")
print(f"Prompts: {len(ALL_PROMPTS)}")
print(f"Final loss: {losses[-1]:.6f}")
print("\nTry: show('your prompt here')")