# 🔬 Fair Compositional Generalization Test v3

## Changes from v2

- **More training data** (~75% of examples: 72 of 96 go to training)
- **Minimal holdout** (1 primitive, 1 modifier)
- **Train accuracy check** before testing extrapolation
- **Longer training** until convergence
- **Proper evaluation** with verbose output

## Fair Test Design

- Transformer gets enough data to actually learn
- We verify it learned (train acc > 90%) before testing generalization
- HDC still uses structural composition

---

*Resonance Protocol Research*

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import random
import json
from datetime import datetime
from tqdm.auto import tqdm
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Reproducibility: seed every random-number source the notebook uses
# (PyTorch, NumPy and the stdlib `random` module) with one shared value.
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)
del _seed_fn

# Pick the GPU when one is visible, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_ok else torch.device('cpu')
print(f"Device: {device}")
if _cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Run-level results container; later cells add metrics and a summary to it.
RESULTS = dict(
    start_time=datetime.now().isoformat(),
    device=str(device),
    levels={},
    summary={},
)


def log(msg):
    """Print *msg* prefixed with the current wall-clock time as ``[HH:MM:SS]``."""
    timestamp = datetime.now().strftime('%H:%M:%S')
    print(f"[{timestamp}] {msg}")

## Part 1: Language & Data

In [None]:
# ============================================================
# COMMAND LANGUAGE
# ============================================================

class CommandLanguage:
    """Ground-truth interpreter for the toy command language.

    A command is either one clause ``<primitive> [<modifier>]`` or exactly two
    such clauses joined by ``' and '``. A primitive maps to one action token,
    a trailing modifier repeats that token, and the two halves of an ``and``
    are concatenated. Anything that does not parse executes to ``'<e>'``.
    """

    def __init__(self):
        verbs = ('walk', 'run', 'jump', 'look', 'turn', 'spin', 'crawl', 'swim')
        # Each primitive's action token is the verb upper-cased.
        self.primitives = {verb: verb.upper() for verb in verbs}
        # Modifier phrase -> number of repetitions.
        self.modifiers = {'twice': 2, 'thrice': 3, 'four times': 4}

    def execute(self, command):
        """Return the space-separated action tokens for *command*, or ``'<e>'``."""
        text = command.strip().lower()

        if ' and ' in text:
            halves = text.split(' and ')
            if len(halves) == 2:
                actions = [self._execute_single(half.strip()) for half in halves]
                if all(actions):
                    return ' '.join(actions)

        # Not a valid conjunction: try to read the whole thing as one clause.
        return self._execute_single(text) or '<e>'

    def _execute_single(self, cmd):
        """Execute one clause; return None if it is not a known primitive."""
        times = 1
        for phrase, count in self.modifiers.items():
            suffix = ' ' + phrase
            if cmd.endswith(suffix):
                cmd = cmd[:-len(suffix)].strip()
                times = count
                break

        action = self.primitives.get(cmd)
        if action is None:
            return None
        return ' '.join([action] * times)

    def generate_all(self):
        """Enumerate all examples as ``(command, output, level)`` triples.

        Level 1 is every primitive alone, level 2 every primitive with every
        modifier, and level 3 every ordered pair of primitives joined by 'and'.
        """
        prims = list(self.primitives)
        single = [(p, self.execute(p), 1) for p in prims]
        modified = [
            (f"{p} {m}", self.execute(f"{p} {m}"), 2)
            for p in prims
            for m in self.modifiers
        ]
        paired = [
            (f"{a} and {b}", self.execute(f"{a} and {b}"), 3)
            for a in prims
            for b in prims
        ]
        return single + modified + paired

# Build the full example set once; every later split draws from it.
lang = CommandLanguage()
all_examples = lang.generate_all()

log(f"Total examples: {len(all_examples)}")
log(f"Primitives: {list(lang.primitives)}")
log(f"Modifiers: {list(lang.modifiers)}")

# Count examples per level, keeping levels in first-seen order.
_example_levels = [lvl for _cmd, _out, lvl in all_examples]
level_counts = {lvl: _example_levels.count(lvl) for lvl in dict.fromkeys(_example_levels)}
log(f"By level: {level_counts}")

In [None]:
# ============================================================
# MINIMAL HOLDOUT SPLIT
# ============================================================

# Holdout: only ONE primitive and ONE modifier
HOLDOUT_PRIMITIVE = 'swim'  # Only 1 of 8
HOLDOUT_MODIFIER = 'four times'  # Only 1 of 3
# The one primitive that is shown together with HOLDOUT_MODIFIER during training.
MODIFIER_ANCHOR_PRIMITIVE = 'walk'


def _parse_clauses(command):
    """Split *command* into one ``(primitive, modifier)`` pair per ``' and '`` clause.

    The first word of a clause is its primitive and the remaining words
    (possibly empty) its modifier, e.g. ``'walk four times and run'`` ->
    ``[('walk', 'four times'), ('run', '')]``. Matching on these parsed parts
    avoids the false positives of raw substring tests such as
    ``'walk' in cmd``, which also match inside longer words.
    """
    pairs = []
    for clause in command.lower().split(' and '):
        primitive, _, modifier = clause.strip().partition(' ')
        pairs.append((primitive, modifier.strip()))
    return pairs


train_data = []
test_extrapolation = []

for cmd, out, level in all_examples:
    clauses = _parse_clauses(cmd)
    has_holdout_prim = any(prim == HOLDOUT_PRIMITIVE for prim, _ in clauses)
    has_holdout_mod = any(mod == HOLDOUT_MODIFIER for _, mod in clauses)

    # Put primitive alone in train (so model knows SWIM → SWIM)
    if clauses == [(HOLDOUT_PRIMITIVE, '')]:
        train_data.append((cmd, out, level))
    # Put modifier with ONE primitive in train (so model knows "four times" = repeat 4x)
    elif clauses == [(MODIFIER_ANCHOR_PRIMITIVE, HOLDOUT_MODIFIER)]:
        train_data.append((cmd, out, level))
    # Extrapolation: every other combination containing a holdout element
    elif has_holdout_prim or has_holdout_mod:
        test_extrapolation.append((cmd, out, level))
    # Regular training
    else:
        train_data.append((cmd, out, level))

log(f"\nSplit:")
log(f"  Train: {len(train_data)} examples")
log(f"  Test (extrapolation): {len(test_extrapolation)} examples")
log(f"  Train ratio: {len(train_data)/len(all_examples):.1%}")

# Show which training examples expose the holdout elements
log(f"\nTraining includes:")
train_with_holdout = [
    e for e in train_data
    if any(prim == HOLDOUT_PRIMITIVE or mod == HOLDOUT_MODIFIER
           for prim, mod in _parse_clauses(e[0]))
]
for cmd, out, _ in train_with_holdout:
    log(f"  '{cmd}' → '{out}'")

log(f"\nTest extrapolation examples (sample):")
for cmd, out, _ in test_extrapolation[:8]:
    log(f"  '{cmd}' → '{out}'")

## Part 2: Vocabulary & Dataset

In [None]:
# ============================================================
# VOCABULARY
# ============================================================

class Vocabulary:
    """Bidirectional word <-> integer-id mapping.

    Ids 0, 1 and 2 are reserved for ``<PAD>``, ``<SOS>`` and ``<EOS>``; new
    words receive consecutive ids after those in insertion order.
    """

    def __init__(self):
        specials = ['<PAD>', '<SOS>', '<EOS>']
        self.word2idx = {word: i for i, word in enumerate(specials)}
        self.idx2word = {i: word for word, i in self.word2idx.items()}
        self.n_words = len(specials)

    def add_word(self, word):
        """Assign the next free id to *word* unless it is already known."""
        if word in self.word2idx:
            return
        new_id = self.n_words
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.n_words = new_id + 1

    def add_sentence(self, sentence):
        """Register every whitespace-separated word of *sentence*."""
        for token in sentence.split():
            self.add_word(token)

    def encode(self, sentence, add_sos=False, add_eos=True):
        """Map *sentence* to ids, optionally framed by ``<SOS>`` / ``<EOS>``.

        Unknown words are mapped to id 0 (the ``<PAD>`` id).
        """
        body = [self.word2idx.get(token, 0) for token in sentence.split()]
        prefix = [self.word2idx['<SOS>']] if add_sos else []
        suffix = [self.word2idx['<EOS>']] if add_eos else []
        return prefix + body + suffix

    def decode(self, indices):
        """Turn ids (ints or 0-d tensors) back into a sentence.

        Decoding stops at the first ``<EOS>``; special tokens and unknown ids
        are dropped from the output.
        """
        eos_id = self.word2idx['<EOS>']
        specials = ('<PAD>', '<SOS>', '<EOS>')
        kept = []
        for raw in indices:
            idx = raw.item() if isinstance(raw, torch.Tensor) else raw
            if idx == eos_id:
                break
            word = self.idx2word.get(idx, '')
            if word and word not in specials:
                kept.append(word)
        return ' '.join(kept)

# Both vocabularies are built from ALL examples (train + held-out), so every
# word that later cells encode already has its own id.
src_vocab = Vocabulary()
tgt_vocab = Vocabulary()
for command_text, action_text, _level in all_examples:
    src_vocab.add_sentence(command_text.lower())
    tgt_vocab.add_sentence(action_text)

log(f"Source vocab: {src_vocab.n_words} tokens")
log(f"Target vocab: {tgt_vocab.n_words} tokens")
log(f"Target tokens: {list(tgt_vocab.word2idx)}")

In [None]:
# ============================================================
# DATASET
# ============================================================

class CommandDataset(Dataset):
    """Fixed-length (source, target) id tensors for seq2seq training.

    Source ids end with ``<EOS>``; target ids are framed by ``<SOS>`` and
    ``<EOS>``. Both are truncated or right-padded with 0 to their max length.
    """

    def __init__(self, examples, src_vocab, tgt_vocab, max_src_len=12, max_tgt_len=10):
        # Keep only (command, output); the level is not needed for training.
        self.examples = [(command, target) for command, target, _level in examples]
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_src_len = max_src_len
        self.max_tgt_len = max_tgt_len

    def __len__(self):
        return len(self.examples)

    @staticmethod
    def _fit(tokens, length):
        """Truncate *tokens* to *length*, or right-pad them with 0 up to it."""
        return (tokens + [0] * length)[:length]

    def __getitem__(self, idx):
        command, target = self.examples[idx]
        src_ids = self._fit(
            self.src_vocab.encode(command.lower(), add_sos=False, add_eos=True),
            self.max_src_len,
        )
        tgt_ids = self._fit(
            self.tgt_vocab.encode(target, add_sos=True, add_eos=True),
            self.max_tgt_len,
        )
        return torch.tensor(src_ids), torch.tensor(tgt_ids)

train_dataset = CommandDataset(train_data, src_vocab, tgt_vocab)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

log(f"Train dataset: {len(train_dataset)} examples")
log(f"Train batches: {len(train_loader)}")

# Verify that target encoding round-trips through decode for a few examples
# (decode drops <SOS> and stops at <EOS>, so the text should match exactly).
log("\nVerify encode/decode:")
for cmd, out, _ in train_data[:3]:
    enc = tgt_vocab.encode(out, add_sos=True, add_eos=True)
    dec = tgt_vocab.decode(enc)
    log(f"  '{out}' → {enc} → '{dec}' | Match: {dec == out}")

## Part 3: Models

In [None]:
# ============================================================
# HDC MODEL
# ============================================================

class HDCModel:
    """Rule-based "HDC" baseline that composes outputs structurally, with no training.

    NOTE(review): its primitive table, modifier table and parsing rules are the
    same as CommandLanguage's, so it reproduces the ground-truth labels exactly
    by construction. Its 100% accuracy reflects that, not anything learned.
    """

    def __init__(self):
        self.primitives = dict(
            walk='WALK', run='RUN', jump='JUMP',
            look='LOOK', turn='TURN', spin='SPIN',
            crawl='CRAWL', swim='SWIM',
        )
        self.modifiers = {'twice': 2, 'thrice': 3, 'four times': 4}

    def predict(self, command):
        """Return the action string for *command*, or ``'<e>'`` if it does not parse."""
        text = command.strip().lower()

        # Exactly one ' and ' => try to compose the two halves.
        pieces = text.split(' and ')
        if len(pieces) == 2:
            lhs, rhs = (self._predict_single(piece.strip()) for piece in pieces)
            if lhs and rhs:
                return f"{lhs} {rhs}"

        return self._predict_single(text) or '<e>'

    def _predict_single(self, cmd):
        """Resolve one ``<primitive> [<modifier>]`` clause; None if unknown."""
        base = cmd
        times = 1
        for word, n in self.modifiers.items():
            tail = f" {word}"
            if base.endswith(tail):
                base = base[:-len(tail)].strip()
                times = n
                break

        if base not in self.primitives:
            return None
        return ' '.join(self.primitives[base] for _ in range(times))

hdc = HDCModel()

# Verify HDC on a few hand-picked commands, including 'swim four times' and
# 'walk and swim', which contain held-out elements.
log("HDC verification:")
test_cmds = ['walk', 'swim', 'walk twice', 'swim four times', 'walk and swim']
for cmd in test_cmds:
    log(f"  '{cmd}' → '{hdc.predict(cmd)}'")

In [None]:
# ============================================================
# TRANSFORMER MODEL
# ============================================================

class Seq2SeqTransformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=128, nhead=4, 
                 num_layers=3, dim_ff=256, dropout=0.1, max_len=15):
        super().__init__()
        
        self.d_model = d_model
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True
        )
        
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.tgt_vocab_size = tgt_vocab_size
    
    def forward(self, src, tgt):
        B, src_len = src.shape
        _, tgt_len = tgt.shape
        
        src_pos = torch.arange(src_len, device=src.device).unsqueeze(0).expand(B, -1)
        tgt_pos = torch.arange(tgt_len, device=tgt.device).unsqueeze(0).expand(B, -1)
        
        src_emb = self.src_emb(src) + self.pos_emb(src_pos)
        tgt_emb = self.tgt_emb(tgt) + self.pos_emb(tgt_pos)
        
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len, device=src.device)
        src_pad = (src == 0)
        tgt_pad = (tgt == 0)
        
        out = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask,
                               src_key_padding_mask=src_pad, tgt_key_padding_mask=tgt_pad)
        
        return self.fc_out(out)
    
    def generate(self, src, max_len=10):
        self.eval()
        B = src.size(0)
        tgt = torch.ones(B, 1, dtype=torch.long, device=src.device)  # SOS
        
        for _ in range(max_len):
            with torch.no_grad():
                out = self.forward(src, tgt)
                next_tok = out[:, -1, :].argmax(dim=-1, keepdim=True)
                tgt = torch.cat([tgt, next_tok], dim=1)
                if (next_tok == 2).all():  # EOS
                    break
        return tgt

# Create model: a 3-layer encoder/decoder Transformer sized to both vocabularies.
_transformer_hparams = {'d_model': 128, 'nhead': 4, 'num_layers': 3, 'dim_ff': 256}
model = Seq2SeqTransformer(
    src_vocab_size=src_vocab.n_words,
    tgt_vocab_size=tgt_vocab.n_words,
    **_transformer_hparams,
).to(device)

n_params = sum(param.numel() for param in model.parameters())
log(f"Transformer: {n_params:,} parameters")

## Part 4: Training with Convergence Check

In [None]:
# ============================================================
# TRAINING WITH EARLY STOPPING
# ============================================================

def _model_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    first = next(iter(model.parameters()), None)
    return first.device if first is not None else torch.device('cpu')


def evaluate_accuracy(model, examples, src_vocab, tgt_vocab, max_samples=100,
                      max_src_len=12, max_gen_len=10):
    """Return exact-match accuracy of greedy decoding on the first examples.

    Args:
        model: object with ``generate(src, max_len=...)`` and ``parameters()``
            (the Seq2SeqTransformer); inputs are placed on its device.
        examples: sequence of ``(command, expected_output, level)`` triples.
        src_vocab / tgt_vocab: used to encode commands / decode predictions.
        max_samples: only the first ``max_samples`` examples are scored.
        max_src_len: source ids are truncated / zero-padded to this length
            (12 matches the CommandDataset default).
        max_gen_len: maximum number of tokens generated per example.

    Returns:
        Fraction of scored examples whose decoded prediction equals the
        expected string exactly; 0.0 when there is nothing to score.
    """
    model.eval()
    samples = examples[:max_samples]
    if not samples:
        return 0.0

    dev = _model_device(model)
    correct = 0
    for cmd, out, _ in samples:
        src = src_vocab.encode(cmd.lower())
        src = src[:max_src_len] + [0] * max(0, max_src_len - len(src))
        src_t = torch.tensor([src], device=dev)

        pred_tokens = model.generate(src_t, max_len=max_gen_len)
        if tgt_vocab.decode(pred_tokens[0]) == out:
            correct += 1

    return correct / len(samples)


def train_until_convergence(model, loader, train_data, src_vocab, tgt_vocab,
                            max_epochs=300, target_acc=0.95, patience=20,
                            eval_every=10):
    """Train with teacher forcing until train accuracy reaches ``target_acc``.

    Every ``eval_every`` epochs the exact-match accuracy on ``train_data`` is
    measured with ``evaluate_accuracy`` and recorded. Training stops when
    (a) that accuracy reaches ``target_acc``, (b) it has not improved for at
    least ``patience`` epochs (rounded up to whole evaluation intervals), or
    (c) ``max_epochs`` epochs have run.

    Returns:
        ``(history, epochs_run)`` where ``history`` is
        ``{'loss': [...], 'train_acc': [...]}`` with one entry per evaluation.

    Raises:
        ValueError: if ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    dev = _model_device(model)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.5)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 = padding id

    history = {'loss': [], 'train_acc': []}
    # -1 so the first evaluation always counts as an improvement (even at 0%)
    # and becomes the baseline, like Keras EarlyStopping's -inf start.
    best_acc = -1.0
    no_improve = 0
    # Non-improving evaluations tolerated: ceil(patience / eval_every), at least 1.
    # (Flooring stopped too early, e.g. patience=15 gave up after only 10 epochs.)
    max_no_improve = max(1, -(-patience // eval_every))
    epochs_run = 0

    log(f"Training until {target_acc:.0%} accuracy or {max_epochs} epochs...")

    for epoch in range(1, max_epochs + 1):
        epochs_run = epoch
        model.train()
        total_loss = 0.0
        n_batches = 0

        for src, tgt in loader:
            src, tgt = src.to(dev), tgt.to(dev)

            # Teacher forcing: feed <SOS> y1..y(n-1), predict y1..yn.
            tgt_in = tgt[:, :-1]
            tgt_out = tgt[:, 1:]

            optimizer.zero_grad()
            output = model(src, tgt_in)

            loss = criterion(output.reshape(-1, output.size(-1)), tgt_out.reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        avg_loss = total_loss / max(1, n_batches)
        scheduler.step(avg_loss)

        if epoch % eval_every != 0:
            continue

        train_acc = evaluate_accuracy(model, train_data, src_vocab, tgt_vocab)
        history['loss'].append(avg_loss)
        history['train_acc'].append(train_acc)

        log(f"Epoch {epoch}: loss={avg_loss:.4f}, train_acc={train_acc:.1%}")

        if train_acc >= target_acc:
            log(f"✓ Reached {target_acc:.0%} accuracy!")
            return history, epoch

        if train_acc > best_acc:
            best_acc = train_acc
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= max_no_improve:
                log(f"No improvement for {no_improve * eval_every} epochs, stopping.")
                break

    final_acc = evaluate_accuracy(model, train_data, src_vocab, tgt_vocab)
    log(f"Finished at epoch {epochs_run}, final train_acc={final_acc:.1%}")
    return history, epochs_run

# Train until the target accuracy is reached, or patience / the epoch budget runs out.
history, epochs_used = train_until_convergence(
    model=model,
    loader=train_loader,
    train_data=train_data,
    src_vocab=src_vocab,
    tgt_vocab=tgt_vocab,
    max_epochs=300,
    target_acc=0.95,
)

In [None]:
# ============================================================
# VERIFY TRAINING WORKED
# ============================================================

log("\n" + "="*60)
log("TRAINING VERIFICATION")
log("="*60)

# Check on training examples
log("\nSample predictions on TRAINING data:")
model.eval()

verify_sample = train_data[:20]
correct = 0
for cmd, out, _ in verify_sample:
    src = src_vocab.encode(cmd.lower())
    src = src[:12] + [0] * max(0, 12 - len(src))
    src_t = torch.tensor([src], device=device)

    pred_tokens = model.generate(src_t, max_len=10)
    pred = tgt_vocab.decode(pred_tokens[0])

    is_correct = pred == out
    if is_correct:
        correct += 1

    status = "✓" if is_correct else "✗"
    log(f"  {status} '{cmd}' → '{pred}' (expected: '{out}')")

# Divide by the actual sample size: train_data may hold fewer than 20 examples.
train_acc = correct / len(verify_sample) if verify_sample else 0.0
log(f"\nTrain accuracy (sample): {train_acc:.1%}")

# Full train accuracy
full_train_acc = evaluate_accuracy(model, train_data, src_vocab, tgt_vocab, max_samples=len(train_data))
log(f"Train accuracy (full): {full_train_acc:.1%}")

RESULTS['train_accuracy'] = full_train_acc

# Use the same 90% bar as the notebook intro ("train acc > 90%") and the
# hypothesis check in the summary (full_train_acc >= 0.9).
if full_train_acc < 0.9:
    log("\n⚠️ WARNING: Train accuracy < 90%. Model hasn't learned properly!")
    log("Results may not be meaningful.")
else:
    log("\n✓ Model has learned the training data well.")

## Part 5: Extrapolation Test

In [None]:
# ============================================================
# EXTRAPOLATION TEST
# ============================================================

log("\n" + "="*60)
log("EXTRAPOLATION TEST")
log("="*60)
log(f"Testing on {len(test_extrapolation)} held-out combinations")
log(f"Model knows: '{HOLDOUT_PRIMITIVE}' and 'walk {HOLDOUT_MODIFIER}'")
log(f"But has never seen other combinations with these.")

n_test = len(test_extrapolation)

# HDC
log("\n--- HDC ---")
hdc_correct = 0
hdc_results = []

for cmd, expected, level in test_extrapolation:
    pred = hdc.predict(cmd)
    is_correct = pred == expected
    if is_correct:
        hdc_correct += 1
    hdc_results.append((cmd, expected, pred, is_correct))

# Guard against an empty held-out set instead of dividing by zero.
hdc_acc = hdc_correct / n_test if n_test else 0.0
log(f"HDC Accuracy: {hdc_acc:.1%} ({hdc_correct}/{n_test})")

# Show HDC samples
log("Sample predictions:")
for cmd, expected, pred, correct in hdc_results[:5]:
    status = "✓" if correct else "✗"
    log(f"  {status} '{cmd}' → '{pred}' (expected: '{expected}')")

# Transformer: greedy decoding, same source encoding/padding as in training
log("\n--- Transformer ---")
model.eval()
trans_correct = 0
trans_results = []

for cmd, expected, level in test_extrapolation:
    src = src_vocab.encode(cmd.lower())
    src = src[:12] + [0] * max(0, 12 - len(src))
    src_t = torch.tensor([src], device=device)

    pred_tokens = model.generate(src_t, max_len=10)
    pred = tgt_vocab.decode(pred_tokens[0])

    is_correct = pred == expected
    if is_correct:
        trans_correct += 1
    trans_results.append((cmd, expected, pred, is_correct))

trans_acc = trans_correct / n_test if n_test else 0.0
log(f"Transformer Accuracy: {trans_acc:.1%} ({trans_correct}/{n_test})")

# Show Transformer samples
log("Sample predictions:")
for cmd, expected, pred, correct in trans_results[:10]:
    status = "✓" if correct else "✗"
    log(f"  {status} '{cmd}' → '{pred}' (expected: '{expected}')")

RESULTS['hdc_extrapolation'] = hdc_acc
RESULTS['transformer_extrapolation'] = trans_acc

In [None]:
# ============================================================
# ERROR ANALYSIS
# ============================================================

log("\n" + "="*60)
log("ERROR ANALYSIS")
log("="*60)

# Group the Transformer's wrong answers by which held-out element(s) they contain.
log("\nTransformer errors by type:")

errors_by_type = defaultdict(list)
for cmd, expected, pred, correct in trans_results:
    if correct:
        continue
    if HOLDOUT_PRIMITIVE in cmd and HOLDOUT_MODIFIER in cmd:
        errors_by_type['both_holdout'].append((cmd, expected, pred))
    elif HOLDOUT_PRIMITIVE in cmd:
        errors_by_type['holdout_primitive'].append((cmd, expected, pred))
    elif HOLDOUT_MODIFIER in cmd:
        errors_by_type['holdout_modifier'].append((cmd, expected, pred))

# Say so explicitly when there is nothing to report, rather than printing nothing.
if not errors_by_type:
    log("  (no errors)")

for err_type, errors in errors_by_type.items():
    log(f"\n{err_type}: {len(errors)} errors")
    for cmd, expected, pred in errors[:3]:
        log(f"  '{cmd}' → '{pred}' (expected: '{expected}')")

## Part 6: Results Summary

In [None]:
# ============================================================
# FINAL SUMMARY
# ============================================================

log("\n" + "="*60)
log("FINAL RESULTS")
log("="*60)

log(f"\nDataset:")
log(f"  Total examples: {len(all_examples)}")
log(f"  Train: {len(train_data)} ({len(train_data)/len(all_examples):.1%})")
log(f"  Test (extrapolation): {len(test_extrapolation)}")

log(f"\nHoldout elements:")
log(f"  Primitive: '{HOLDOUT_PRIMITIVE}'")
log(f"  Modifier: '{HOLDOUT_MODIFIER}'")

# The 'Train Acc' column is 18 wide so the 16-char '100% (by design)' label fits.
log(f"\nResults:")
log(f"  {'Model':<20} {'Train Acc':<18} {'Extrapolation Acc':<20}")
log(f"  {'-'*60}")
log(f"  {'HDC':<20} {'100% (by design)':<18} {hdc_acc:<20.1%}")
log(f"  {'Transformer':<20} {full_train_acc:<18.1%} {trans_acc:<20.1%}")

log(f"\n" + "="*60)
if hdc_acc > trans_acc:
    diff = hdc_acc - trans_acc
    log(f"HDC outperforms Transformer by {diff:.1%} on extrapolation")
    # Same 90% train-accuracy bar as RESULTS['summary']['hypothesis_supported'].
    if full_train_acc >= 0.9:
        log("\n✓ HYPOTHESIS SUPPORTED:")
        log("  Transformer learned the training data well but failed to generalize.")
        log("  HDC's structural composition enables perfect generalization.")
    else:
        log("\n⚠️ Transformer didn't fully learn training data.")
        log("  Results may not be conclusive.")
else:
    log(f"Transformer matches or exceeds HDC")
    log("Hypothesis not supported in this test.")

log("="*60)

In [None]:
# ============================================================
# VISUALIZATION
# ============================================================

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# 1. Train vs Extrapolation
ax = axes[0]
models = ['HDC', 'Transformer']
# Measure HDC on the training set instead of hard-coding 100%.
hdc_train_acc = (
    sum(hdc.predict(c) == o for c, o, _ in train_data) / len(train_data)
    if train_data else 0.0
)
train_accs = [hdc_train_acc, full_train_acc]
extrap_accs = [hdc_acc, trans_acc]

x = np.arange(len(models))
width = 0.35

ax.bar(x - width/2, train_accs, width, label='Train', color='#3498db', alpha=0.8)
ax.bar(x + width/2, extrap_accs, width, label='Extrapolation', color='#e74c3c', alpha=0.8)

ax.set_ylabel('Accuracy')
ax.set_title('Train vs Extrapolation Accuracy')
ax.set_xticks(x)
ax.set_xticklabels(models)
ax.legend()
ax.set_ylim(0, 1.1)
ax.axhline(y=1.0, color='green', linestyle='--', alpha=0.3)

for i, (t, e) in enumerate(zip(train_accs, extrap_accs)):
    ax.text(i - width/2, t + 0.02, f'{t:.0%}', ha='center', fontsize=10)
    ax.text(i + width/2, e + 0.02, f'{e:.0%}', ha='center', fontsize=10)

# 2. Generalization Gap (train minus extrapolation; green when below 0.1)
ax = axes[1]
gaps = [t - e for t, e in zip(train_accs, extrap_accs)]
colors = ['#27ae60' if g < 0.1 else '#e74c3c' for g in gaps]

ax.bar(models, gaps, color=colors, alpha=0.8, edgecolor='black')
ax.set_ylabel('Gap (Train - Extrapolation)')
ax.set_title('Generalization Gap\n(Lower = Better)')
ax.axhline(y=0, color='green', linestyle='--', alpha=0.5)

for i, g in enumerate(gaps):
    ax.text(i, g + 0.02, f'{g:.0%}', ha='center', fontsize=12, fontweight='bold')

# 3. Training curve: train_until_convergence records one point every 10 epochs
# by default, and 0.95 is the target_acc passed to it above.
ax = axes[2]
if history['train_acc']:
    epochs = range(10, 10 * len(history['train_acc']) + 1, 10)
    ax.plot(epochs, history['train_acc'], 'b-o', label='Train Accuracy')
    ax.axhline(y=0.95, color='green', linestyle='--', alpha=0.5, label='Target (95%)')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_title('Training Progress')
    ax.legend()
    ax.set_ylim(0, 1.05)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('fair_test_results.png', dpi=150, bbox_inches='tight')
plt.show()

log("\n📊 Saved: fair_test_results.png")

In [None]:
# ============================================================
# SAVE RESULTS
# ============================================================

RESULTS['end_time'] = datetime.now().isoformat()
RESULTS['summary'] = {
    'total_examples': len(all_examples),
    'train_size': len(train_data),
    'test_size': len(test_extrapolation),
    'holdout_primitive': HOLDOUT_PRIMITIVE,
    'holdout_modifier': HOLDOUT_MODIFIER,
    'transformer_params': n_params,
    'epochs_trained': epochs_used,
    'train_accuracy': full_train_acc,
    'hdc_extrapolation': hdc_acc,
    'transformer_extrapolation': trans_acc,
    'hypothesis_supported': hdc_acc > trans_acc and full_train_acc >= 0.9
}

# Open with an explicit encoding so the output does not depend on the platform
# default (json.dump escapes non-ASCII by default anyway).
with open('fair_test_results.json', 'w', encoding='utf-8') as f:
    json.dump(RESULTS, f, indent=2, default=str)

log("\n📄 Saved: fair_test_results.json")

print("\n" + "="*60)
print("FILES TO DOWNLOAD:")
print("="*60)
print("1. fair_test_results.json")
print("2. fair_test_results.png")
print("="*60)