In [2]:
 # ============================================================================
# GENERATE MORE DESCRIPTIONS USING OLLAMA
# ============================================================================
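# Processing order (every CIFAR-100 class is visited once per run):
#   1. CIFAR-100 words NOT in the checkpoint (missing entirely)
#   2. The remaining words, fewest existing descriptions first
#
# For each word, requests caption batches from Ollama until at least
# TARGET_PER_WORD (100) new descriptions are collected or 15 batches have
# been tried, then moves on to the next word.
# Saves the checkpoint after each word.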
# ============================================================================

import json
import os
import re

import requests
import torchvision

CHECKPOINT_FILE = "more_words.json"
OLLAMA_URL = "http://localhost:11434/api/generate"
TARGET_PER_WORD = 100  # Minimum number of new descriptions to aim for per word per run

# Load checkpoint; on the very first run the file does not exist yet, so
# start from an empty checkpoint instead of crashing.
try:
    with open(CHECKPOINT_FILE, 'r') as f:
        checkpoint = json.load(f)
except FileNotFoundError:
    print(f"{CHECKPOINT_FILE} not found - starting with an empty checkpoint")
    checkpoint = {}
checkpoint.setdefault('descriptions', {})

# Get all CIFAR-100 words (downloads the dataset once just to read class names)
cifar100 = torchvision.datasets.CIFAR100(root='./data', download=True)
all_cifar_words = set(cifar100.classes)

# Find words missing from checkpoint entirely
checkpoint_words = set(checkpoint['descriptions'])
missing_words = sorted(all_cifar_words - checkpoint_words)

# CIFAR words already in the checkpoint, fewest descriptions first.
# Ties are broken alphabetically so the order does not depend on set
# iteration order (which varies between runs for strings).
existing_counts = sorted(
    ((word, len(checkpoint['descriptions'][word]))
     for word in all_cifar_words & checkpoint_words),
    key=lambda item: (item[1], item[0]),
)

# Build priority queue: missing words first, then by count
words_to_process = missing_words + [w for w, _ in existing_counts]

print(f"Checkpoint has {len(checkpoint_words)} words")
print(f"Missing from checkpoint: {len(missing_words)} words: {missing_words}")
print(f"\nWill generate {TARGET_PER_WORD} descriptions per word")
print("="*60)

def clean_description(desc, word):
    """Normalize one raw LLM output line and validate it as a caption for ``word``.

    Cleaning steps: map typographic quotes to ASCII, lower-case, strip
    surrounding quotes and list numbering/bullets, drop markdown emphasis,
    strip possessive ``'s`` (pronoun contractions such as "it's" are kept),
    and join known multi-word class names with underscores.

    Args:
        desc: One line of raw model output.
        word: Target CIFAR-100 class name; multi-word classes use
            underscores (e.g. ``"pickup_truck"``).

    Returns:
        The cleaned caption, or ``None`` if ``word`` does not appear as a
        standalone token (so "bedside" does not count for "bed" and "gray"
        does not count for "ray") or if the caption is shorter than 10 or
        longer than 200 characters.
    """
    # LLMs often emit curly quotes/apostrophes; map them to ASCII so the
    # quote stripping and possessive handling below also apply to them.
    desc = desc.translate(str.maketrans({
        '\u2018': "'", '\u2019': "'", '\u201c': '"', '\u201d': '"',
    }))
    desc = desc.lower().strip().strip('"\'')
    desc = re.sub(r'^[\d\.\)\-\*\u2022]+\s*', '', desc)  # Remove numbering / bullets
    # Strip quotes again: numbering can sit outside the quotes ('1. "..."').
    desc = desc.strip().strip('"\'').strip()
    desc = desc.replace('*', '')  # Strip markdown bold/italic asterisks
    # Strip markdown underscores (_word_) only when the underscores are not
    # glued to other word characters, so class names like oak_tree survive.
    desc = re.sub(r'(?<!\w)_([a-z]+)_(?!\w)', r'\1', desc)
    # Strip possessives ("bear's" -> "bear"); the lookahead keeps common
    # pronoun contractions ("it's", "there's", ...), which are not possessives.
    desc = re.sub(
        r"\b(?!(?:it|that|there|here|what|who|he|she|let)'s\b)(\w+)'s\b",
        r'\1',
        desc,
    )

    # Normalize compound words. Matching is anchored at word boundaries so
    # e.g. "porcupine tree" is not rewritten into "porcu" + "pine_tree";
    # spaces or hyphens between the parts are accepted.
    compounds = [
        'aquarium_fish', 'lawn_mower', 'maple_tree', 'oak_tree', 'palm_tree',
        'pickup_truck', 'pine_tree', 'sweet_pepper', 'willow_tree',
    ]
    if '_' in word and word not in compounds:
        compounds.append(word)  # Also normalize any other multi-word target
    for compound in compounds:
        parts = compound.split('_')
        pattern = r'\b' + r'[\s\-]+'.join(map(re.escape, parts)) + r'\b'
        desc = re.sub(pattern, compound, desc)

    # Validate: the class name must be a standalone token, i.e. not glued to
    # letters, digits, underscores or hyphens ("bedside", "gray", "x-ray").
    if not re.search(rf'(?<![\w-]){re.escape(word)}(?![\w-])', desc):
        return None
    if not 10 <= len(desc) <= 200:
        return None
    return desc

# Generate descriptions
MAX_BATCHES_PER_WORD = 15  # Give up on a word after this many Ollama calls
OLLAMA_MODEL = "llama3.2"


def _save_checkpoint(data, path):
    """Atomically write ``data`` as JSON to ``path``.

    Writes a sibling temp file first and then renames it over ``path``, so an
    interruption mid-write cannot leave a truncated checkpoint behind.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(data, f)
    os.replace(tmp_path, path)


for i, word in enumerate(words_to_process):
    word_descriptions = checkpoint['descriptions'].setdefault(word, [])
    existing = len(word_descriptions)
    print(f"[{i+1}/{len(words_to_process)}] {word} (existing: {existing})...", end=" ", flush=True)

    # The prompt depends only on the word, so build it once per word.
    prompt = f"""Generate 50 unique short image captions that a HUMAN would write to describe a photograph containing "{word}".

Rules:
- 4-8 words each
- Use the exact word "{word}". KEEP UNDERSCORE IF PRESENT!
- The word needs to appear EXACTLY as mentioned in the description, without alternative (e.g use bed and not bedside)
- Write like a human describing what they SEE in a real photograph
- Varied contexts, colors, actions, settings
- Natural language, not robotic or repetitive
- Output ONLY the captions, one per line. No numbers, bullets, or explanations.
"""

    # Seed with the stored descriptions so re-runs never append duplicates;
    # new_list keeps generation order so the saved file is deterministic.
    seen = set(word_descriptions)
    new_list = []
    batch_num = 0

    # Each batch asks for ~50 captions, so the final count can overshoot
    # TARGET_PER_WORD; it can also fall short if MAX_BATCHES_PER_WORD is hit.
    while len(new_list) < TARGET_PER_WORD and batch_num < MAX_BATCHES_PER_WORD:
        batch_num += 1

        try:
            response = requests.post(OLLAMA_URL, json={
                "model": OLLAMA_MODEL,
                "prompt": prompt,
                "stream": False
            }, timeout=120)
            # Surface HTTP errors (e.g. unknown model) instead of silently
            # parsing an error body that has no 'response' key.
            response.raise_for_status()
            text = response.json().get('response', '')
        except (requests.RequestException, ValueError) as e:
            # Network/HTTP failure or non-JSON body: skip the rest of this word.
            print(f"Error: {e}", end=" ", flush=True)
            break

        for line in text.splitlines():
            cleaned = clean_description(line, word)
            if cleaned and cleaned not in seen:
                seen.add(cleaned)
                new_list.append(cleaned)

    # Append to checkpoint
    word_descriptions.extend(new_list)

    # Save checkpoint after each word
    _save_checkpoint(checkpoint, CHECKPOINT_FILE)

    print(f"✓ +{len(new_list)} new (total: {len(word_descriptions)})")

print(f"\n{'='*60}")
print("✓ Done! Checkpoint updated.")
print(f"  Words in checkpoint: {len(checkpoint['descriptions'])}")


100%|██████████| 169M/169M [00:17<00:00, 9.63MB/s] 


Checkpoint has 0 words
Missing from checkpoint: 100 words: ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'tro

In [4]:
# DESCRIPTIONS PER WORD
# Report how many stored descriptions each CIFAR-100 class has, plus how many
# classes have at least one.
import json
import torchvision

CHECKPOINT_FILE = "more_words.json"

with open(CHECKPOINT_FILE, 'r') as f:
    checkpoint = json.load(f)

cifar100 = torchvision.datasets.CIFAR100(root='./data', download=False)
all_cifar_words = sorted(cifar100.classes)

print(f"{'Word':<20} {'Count':>6}")
print("=" * 28)

descriptions = checkpoint['descriptions']
total = 0
covered = 0  # CIFAR classes with at least one description
for word in all_cifar_words:
    count = len(descriptions.get(word, []))
    total += count
    if count > 0:
        covered += 1
    print(f"{word:<20} {count:>6}")

print("=" * 28)
print(f"{'TOTAL':<20} {total:>6}")
# Count only CIFAR classes that actually have descriptions: the checkpoint may
# also hold non-CIFAR words or empty lists, and the class count is not
# hard-coded.
print(f"\nWords with descriptions: {covered}/{len(all_cifar_words)}")

Word                  Count
apple                   129
aquarium_fish           100
baby                    108
bear                    105
beaver                  102
bed                     131
bee                     147
beetle                  100
bicycle                 131
bottle                  111
bowl                    132
boy                     135
bridge                  113
bus                     113
butterfly               146
camel                   134
can                     149
castle                  133
caterpillar             125
cattle                  127
chair                   121
chimpanzee              139
clock                   112
cloud                   114
cockroach               147
couch                   104
crab                    135
crocodile               129
cup                     129
dinosaur                103
dolphin                 108
elephant                107
flatfish                120
forest                  107
fox                 