# Compositional Generalization: HDC vs Transformers vs LLM

## Hypothesis

**Modern LLMs are poor at generalization because they lack structural compositionality.**

## How to Run

1. **Runtime → Run all** (or run cells one by one)
2. A report will be generated even if something fails
3. Download `experiment_report.json` and `experiment_log.txt` from the file browser (left panel)

---

*Part of the Resonance Protocol research: https://github.com/nick-yudin/resonance-protocol*

## Part 0: Setup & Logging Infrastructure

In [None]:
# ============================================================
# LOGGING INFRASTRUCTURE - Run this FIRST!
# ============================================================

import sys
import traceback
import json
from datetime import datetime
from pathlib import Path

class ExperimentLogger:
    """Crash-tolerant experiment logger.

    Every message is printed and appended to ``log_file``; the structured
    ``report`` dict is rewritten to ``report_file`` as JSON after each state
    change, so a partial report exists even if the run dies midway.

    Args:
        log_file: Path of the plain-text log (truncated on construction).
        report_file: Path of the JSON report (overwritten on every save).
    """

    def __init__(self, log_file='experiment_log.txt', report_file='experiment_report.json'):
        self.log_file = log_file
        self.report_file = report_file
        self.start_time = datetime.now()

        # Report structure (serialized as-is by save_report)
        self.report = {
            'experiment': 'Compositional Generalization Test',
            'start_time': self.start_time.isoformat(),
            'status': 'RUNNING',
            'current_step': 'initialization',
            'steps_completed': [],
            'errors': [],
            'results': {},
            'environment': {},
            'debug_info': []
        }

        # Clear previous logs. UTF-8 is explicit because messages contain
        # non-ASCII symbols the platform default encoding may not support.
        with open(self.log_file, 'w', encoding='utf-8'):
            pass

        self.log("="*60)
        self.log("EXPERIMENT STARTED")
        self.log(f"Time: {self.start_time}")
        self.log("="*60)
        self.save_report()

    def _write(self, message, level):
        """Print one formatted line and append it to the log file.

        Returns:
            The ``HH:MM:SS`` timestamp stamped on the line.
        """
        timestamp = datetime.now().strftime('%H:%M:%S')
        formatted = f"[{timestamp}] [{level}] {message}"
        print(formatted)

        with open(self.log_file, 'a', encoding='utf-8') as f:
            f.write(formatted + '\n')
        return timestamp

    def log(self, message, level='INFO'):
        """Log ``message`` at ``level`` to stdout and the log file.

        ERROR-level messages are also recorded in ``report['errors']`` and the
        report is saved immediately.
        """
        timestamp = self._write(message, level)
        if level == 'ERROR':
            self.report['errors'].append({'time': timestamp, 'message': message})
            self.save_report()

    def debug(self, key, value):
        """Store ``str(value)`` (first 500 chars) under ``key`` and log a short preview."""
        text = str(value)
        self.report['debug_info'].append({'key': key, 'value': text[:500]})
        # Only show an ellipsis when the preview was actually truncated.
        preview = text[:100] + ('...' if len(text) > 100 else '')
        self.log(f"DEBUG {key}: {preview}", level='DEBUG')
        self.save_report()

    def step(self, step_name):
        """Mark ``step_name`` as the step currently running."""
        self.report['current_step'] = step_name
        self.log(f"\n>>> STEP: {step_name}")
        self.save_report()

    def step_done(self, step_name):
        """Append ``step_name`` to the list of completed steps."""
        self.report['steps_completed'].append(step_name)
        self.log(f"<<< DONE: {step_name}")
        self.save_report()

    def result(self, key, value):
        """Store a named result in ``report['results']``."""
        self.report['results'][key] = value
        self.log(f"RESULT: {key} = {value}")
        self.save_report()

    def error(self, message, exception=None):
        """Record one error as a single ``report['errors']`` entry.

        Args:
            message: Human-readable description of the failure.
            exception: Optional exception; its traceback is logged and stored
                under the entry's ``'traceback'`` key.
        """
        timestamp = self._write(message, 'ERROR')
        entry = {'time': timestamp, 'message': message}
        if exception is not None:
            # Format from the exception object itself: traceback.format_exc()
            # only works while an exception is being handled.
            tb = ''.join(traceback.format_exception(
                type(exception), exception, exception.__traceback__))
            self._write(f"Traceback:\n{tb}", 'ERROR')
            entry['traceback'] = tb
        self.report['errors'].append(entry)
        self.save_report()

    def save_report(self):
        """Refresh timing fields and write the report to ``report_file`` as JSON."""
        now = datetime.now()
        self.report['last_updated'] = now.isoformat()
        self.report['duration_seconds'] = (now - self.start_time).total_seconds()

        with open(self.report_file, 'w', encoding='utf-8') as f:
            json.dump(self.report, f, indent=2, default=str)

    def finish(self, status='COMPLETED'):
        """Finalize the experiment with ``status`` and write a summary."""
        end_time = datetime.now()
        self.report['status'] = status
        self.report['end_time'] = end_time.isoformat()
        # Compute the duration now; report['duration_seconds'] is only as fresh
        # as the last save_report() call.
        duration = (end_time - self.start_time).total_seconds()
        self.log("\n" + "="*60)
        self.log(f"EXPERIMENT {status}")
        self.log(f"Duration: {duration:.1f} seconds")
        self.log(f"Steps completed: {len(self.report['steps_completed'])}")
        self.log(f"Errors: {len(self.report['errors'])}")
        self.log("="*60)
        self.save_report()

        print(f"\n📄 Log saved to: {self.log_file}")
        print(f"📊 Report saved to: {self.report_file}")

# Create global logger (truncates experiment_log.txt and writes an initial report)
logger = ExperimentLogger()

# Helper for safe execution
def safe_run(func, step_name):
    """Run ``func()`` as a named, logged experiment step.

    Marks the step as started and, on success, completed. Any exception is
    logged with its traceback via ``logger.error`` and swallowed, so later
    notebook cells still run and the report is still produced.

    Args:
        func: Zero-argument callable to execute.
        step_name: Human-readable step name recorded in the report.

    Returns:
        Whatever ``func`` returns, or ``None`` if it raised.
    """
    logger.step(step_name)
    try:
        result = func()
        logger.step_done(step_name)
        return result
    except Exception as e:
        # Deliberate best-effort boundary: record the failure and continue.
        logger.error(f"Failed at step '{step_name}': {str(e)}", exception=e)
        return None

print("✅ Logging infrastructure ready")
print("📄 Log file: experiment_log.txt")
print("📊 Report file: experiment_report.json")

In [None]:
# ============================================================
# ENVIRONMENT CHECK
# ============================================================

def check_environment():
    """Collect Python, platform, PyTorch and GPU details and log them.

    The collected dict is stored in ``logger.report['environment']`` and
    returned. Missing PyTorch is recorded rather than raised.
    """
    import platform

    info = {
        'python_version': platform.python_version(),
        'platform': platform.platform(),
    }

    # GPU / PyTorch details
    try:
        import torch
        info['torch_version'] = torch.__version__
        has_cuda = torch.cuda.is_available()
        info['cuda_available'] = has_cuda
        if has_cuda:
            info['gpu_name'] = torch.cuda.get_device_name(0)
            total_bytes = torch.cuda.get_device_properties(0).total_memory
            info['gpu_memory_gb'] = total_bytes / 1e9
    except ImportError:
        info['torch_version'] = 'NOT INSTALLED'
        info['cuda_available'] = False

    logger.report['environment'] = info

    for key in info:
        logger.log(f"  {key}: {info[key]}")

    return info

safe_run(check_environment, "Environment Check")

In [None]:
# ============================================================
# INSTALL DEPENDENCIES
# ============================================================

def install_deps():
    """Make sure every required package is importable, pip-installing the missing ones.

    A failing ``pip install`` raises ``subprocess.CalledProcessError``, which
    the surrounding ``safe_run`` records.
    """
    import subprocess

    required = ('torch', 'numpy', 'matplotlib', 'seaborn', 'pandas', 'tqdm')
    pip_install = [sys.executable, '-m', 'pip', 'install', '-q']

    for name in required:
        logger.log(f"Checking {name}...")
        try:
            __import__(name)
        except ImportError:
            logger.log(f"  Installing {name}...")
            subprocess.check_call(pip_install + [name])
        else:
            logger.log(f"  {name} already installed")

    logger.log("All dependencies ready")

safe_run(install_deps, "Install Dependencies")

In [None]:
# ============================================================
# IMPORTS
# ============================================================

def do_imports():
    """Import the scientific stack into notebook scope, seed RNGs, and pick a device.

    Because every name below is declared ``global``, the ``import`` statements
    bind them at module level directly; later cells use ``np``, ``torch``,
    ``nn``, ``device``, ``SEED`` etc. without re-importing.
    """
    global np, torch, nn, optim, Dataset, DataLoader
    global plt, sns, random, tqdm, defaultdict
    global device, SEED

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    import matplotlib.pyplot as plt
    import seaborn as sns
    from collections import defaultdict
    import random
    from tqdm.auto import tqdm

    # Reproducibility: seed every RNG the experiment draws from
    SEED = 42
    for seed_rng in (np.random.seed, torch.manual_seed, random.seed):
        seed_rng(SEED)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.log(f"Device: {device}")
    logger.result('device', str(device))

safe_run(do_imports, "Imports")

## Part 1: Dataset Generation

In [None]:
# ============================================================
# COMMAND LANGUAGE
# ============================================================

def create_language():
    """Define the command language, enumerate all examples, and publish them.

    Publishes the globals ``lang`` (a ``CommandLanguage`` instance) and
    ``all_examples`` (list of ``(command, expected_output)`` pairs).
    """
    global lang, all_examples

    class CommandLanguage:
        """A simple compositional command language.

        Grammar: ``<primitive>``, ``<primitive> <modifier>`` or
        ``<primitive> and <primitive>``.
        """

        def __init__(self):
            # command word -> output action token
            self.primitives = {
                'walk': 'WALK',
                'run': 'RUN',
                'jump': 'JUMP',
                'look': 'LOOK',
                'turn': 'TURN'
            }

            # modifier word -> function expanding a primitive's output
            self.modifiers = {
                'twice': lambda x: f"{x} {x}",
                'thrice': lambda x: f"{x} {x} {x}",
            }

        def execute(self, command):
            """Return the ground-truth output for ``command``, or ``'<ERROR>'``."""
            # Match on whole words so e.g. 'walkthrice' is rejected instead of
            # being parsed as 'walk' + 'thrice'.
            words = command.strip().lower().split()

            if len(words) == 1 and words[0] in self.primitives:
                return self.primitives[words[0]]

            if (len(words) == 2 and words[0] in self.primitives
                    and words[1] in self.modifiers):
                return self.modifiers[words[1]](self.primitives[words[0]])

            if (len(words) == 3 and words[1] == 'and'
                    and words[0] in self.primitives and words[2] in self.primitives):
                return f"{self.primitives[words[0]]} {self.primitives[words[2]]}"

            return '<ERROR>'

        def generate_all_examples(self):
            """Enumerate primitives, primitive+modifier, and ordered pairs of distinct primitives."""
            examples = []

            for prim in self.primitives:
                examples.append((prim, self.execute(prim)))

            for prim in self.primitives:
                for mod in self.modifiers:
                    cmd = f"{prim} {mod}"
                    examples.append((cmd, self.execute(cmd)))

            for p1 in self.primitives:
                for p2 in self.primitives:
                    if p1 != p2:
                        cmd = f"{p1} and {p2}"
                        examples.append((cmd, self.execute(cmd)))

            return examples

    lang = CommandLanguage()
    all_examples = lang.generate_all_examples()

    logger.log(f"Total examples generated: {len(all_examples)}")
    logger.result('total_examples', len(all_examples))

    # Show samples
    logger.log("Sample examples:")
    for cmd, out in all_examples[:5]:
        logger.log(f"  '{cmd}' → '{out}'")

safe_run(create_language, "Create Language")

In [None]:
# ============================================================
# CREATE TRAIN/TEST SPLITS
# ============================================================

def create_splits_func():
    """Split ``all_examples`` into train and extrapolation-test sets.

    * Held-out primitives ('look', 'turn') are trained only on their own;
      every multi-word command containing one is a test example.
    * Commands using the held-out modifier ('thrice') with any other primitive
      go to train with probability 0.3 (so the modifier is observable at all)
      and to test otherwise.
    * Everything else is training data.

    Publishes the global ``splits`` = ``{'train': [...], 'test_extrapolation': [...]}``.
    """
    global splits

    holdout_primitives = {'look', 'turn'}
    holdout_modifiers = {'thrice'}
    holdout_mod_train_prob = 0.3

    # Private RNG seeded from SEED (the same seed do_imports gives the global
    # `random` stream), so re-running this cell always reproduces the same
    # split instead of depending on what other cells consumed.
    rng = random.Random(SEED)

    train = []
    test_extrapolation = []

    for cmd, out in all_examples:
        # Whole-word matching; substring tests misfire once a primitive name
        # is contained in another word.
        tokens = cmd.lower().split()
        has_holdout_prim = any(t in holdout_primitives for t in tokens)
        has_holdout_mod = any(t in holdout_modifiers for t in tokens)

        if has_holdout_prim and len(tokens) > 1:
            # Held-out primitive in any composition (including with the
            # held-out modifier): the primitive alone is known from training.
            test_extrapolation.append((cmd, out))
        elif has_holdout_mod:
            # Held-out modifier on a known primitive: some go to train (to
            # learn the modifier), the rest to test.
            if rng.random() < holdout_mod_train_prob:
                train.append((cmd, out))
            else:
                test_extrapolation.append((cmd, out))
        else:
            train.append((cmd, out))

    splits = {
        'train': train,
        'test_extrapolation': test_extrapolation
    }

    logger.log(f"Train examples: {len(train)}")
    logger.log(f"Test (extrapolation) examples: {len(test_extrapolation)}")

    # If a held-out modifier never reaches training, no model can learn it.
    trained_tokens = {t for cmd, _ in train for t in cmd.lower().split()}
    missing_mods = holdout_modifiers - trained_tokens
    if missing_mods:
        logger.log(f"No training example uses held-out modifier(s) {sorted(missing_mods)}; "
                   "their meaning cannot be learned from this split", level='WARNING')

    logger.result('train_size', len(train))
    logger.result('test_size', len(test_extrapolation))

    logger.log("\nTest examples (what we want to generalize to):")
    for cmd, out in test_extrapolation[:8]:
        logger.log(f"  '{cmd}' → '{out}'")

safe_run(create_splits_func, "Create Splits")

## Part 2: HDC (Hyperdimensional Computing)

In [None]:
# ============================================================
# HDC IMPLEMENTATION
# ============================================================

class HDCProcessor:
    """Bipolar hypervector encoder for commands of the toy language.

    Symbols get random {-1, +1} vectors of length ``dim`` on first use;
    binding is element-wise multiplication, bundling is the sign of the sum
    (plus tiny noise to break ties), similarity is cosine.
    """

    def __init__(self, dim=10000, seed=42):
        self.dim = dim
        self.rng = np.random.RandomState(seed)
        self.memory = {}  # symbol name -> hypervector
        # Role vectors that tag which structural slot a filler occupies
        self.roles = {
            'action': self._random_hv(),
            'modifier': self._random_hv(),
            'first': self._random_hv(),
            'second': self._random_hv(),
        }

    def _random_hv(self):
        return self.rng.choice([-1, 1], size=self.dim).astype(np.float32)

    def get_or_create(self, name):
        if name not in self.memory:
            self.memory[name] = self._random_hv()
        return self.memory[name]

    def bind(self, a, b):
        return a * b

    def bundle(self, *vectors):
        result = np.sum(vectors, axis=0)
        # Positions where the inputs cancel to 0 take a random sign from the noise.
        return np.sign(result + 0.001 * self.rng.randn(self.dim))

    def similarity(self, a, b):
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    def encode_command(self, command):
        """Encode ``command`` as role-bound (and, if compound, bundled) hypervectors."""
        command = command.strip().lower()

        modifier = None
        primitive = command

        for mod in ['twice', 'thrice']:
            if command.endswith(mod):
                modifier = mod
                primitive = command[:-len(mod)].strip()
                break

        if ' and ' in command:
            parts = command.split(' and ')
            if len(parts) == 2:
                p1_hv = self.get_or_create(parts[0].strip())
                p2_hv = self.get_or_create(parts[1].strip())
                return self.bundle(
                    self.bind(self.roles['first'], p1_hv),
                    self.bind(self.roles['second'], p2_hv)
                )

        prim_hv = self.get_or_create(primitive)

        if modifier is None:
            return self.bind(self.roles['action'], prim_hv)
        else:
            mod_hv = self.get_or_create(modifier)
            return self.bundle(
                self.bind(self.roles['action'], prim_hv),
                self.bind(self.roles['modifier'], mod_hv)
            )


class HDCModel:
    """Compositional predictor whose lexicon is learned from training data.

    ``train`` stores an HDC encoding of every example and extracts, from the
    training pairs alone:

    * ``lexicon`` - primitive word -> output token, from single-word commands;
    * ``repeats`` - modifier word -> repetition count, from ``"<prim> <mod>"``
      commands whose output is the primitive's token repeated k times.

    ``predict`` composes these structurally (``"<a> and <b>"`` is the two
    primitive outputs concatenated, mirroring the first/second role binding
    in the encoder). Commands it cannot compose fall back to nearest-neighbour
    search over the stored hypervectors. Nothing about the target language is
    hard-coded, so a primitive or modifier never seen in training cannot be
    predicted compositionally.
    """

    def __init__(self, dim=10000):
        self.hdc = HDCProcessor(dim=dim)
        self.examples = []   # (hypervector, command, output) per training example
        self.lexicon = {}    # primitive word -> output token
        self.repeats = {}    # modifier word -> repetition count

    def train(self, examples):
        """Store encodings of ``examples`` and re-derive the lexicon and modifiers."""
        for cmd, out in examples:
            cmd_hv = self.hdc.encode_command(cmd)
            self.examples.append((cmd_hv, cmd, out))
        self._learn_structure()

    def _learn_structure(self):
        """Derive ``lexicon`` and ``repeats`` from all stored training examples."""
        parsed = [(cmd.strip().lower().split(), out.split())
                  for _, cmd, out in self.examples]

        # A single-word command with a single-token output defines a primitive.
        for words, tokens in parsed:
            if len(words) == 1 and len(tokens) == 1:
                self.lexicon[words[0]] = tokens[0]

        # "<known primitive> <modifier>" whose output is that primitive's token
        # repeated k times defines the modifier as "repeat k times".
        for words, tokens in parsed:
            if len(words) != 2 or words[0] not in self.lexicon:
                continue
            prim_token = self.lexicon[words[0]]
            if tokens and all(t == prim_token for t in tokens):
                self.repeats.setdefault(words[1], len(tokens))

    def predict(self, command):
        """Predict the output for ``command``.

        Returns:
            ``(output, confidence)``: confidence is 1.0 for a compositional
            prediction, otherwise the cosine similarity of the nearest stored
            example (``(None, -1)`` when nothing has been trained).
        """
        words = command.strip().lower().split()

        if len(words) == 1 and words[0] in self.lexicon:
            return self.lexicon[words[0]], 1.0

        if len(words) == 2 and words[0] in self.lexicon and words[1] in self.repeats:
            return ' '.join([self.lexicon[words[0]]] * self.repeats[words[1]]), 1.0

        if (len(words) == 3 and words[1] == 'and'
                and words[0] in self.lexicon and words[2] in self.lexicon):
            return f"{self.lexicon[words[0]]} {self.lexicon[words[2]]}", 1.0

        # Fallback: nearest neighbour in hypervector space
        cmd_hv = self.hdc.encode_command(command.strip().lower())
        best_sim = -1
        best_out = None

        for train_cmd_hv, _, train_out in self.examples:
            sim = self.hdc.similarity(cmd_hv, train_cmd_hv)
            if sim > best_sim:
                best_sim = sim
                best_out = train_out

        return best_out, best_sim


def create_hdc():
    """Build an ``HDCModel``, train it on ``splits['train']``, and publish it as ``hdc_model``."""
    global hdc_model

    hdc_model = HDCModel(dim=10000)
    hdc_model.train(splits['train'])

    logger.log(f"HDC model trained on {len(splits['train'])} examples")
    logger.log(f"  Learned primitives: {hdc_model.lexicon}")
    logger.log(f"  Learned modifier repeat counts: {hdc_model.repeats}")

safe_run(create_hdc, step_name="Create HDC Model")

In [None]:
# ============================================================
# TEST HDC
# ============================================================

def test_hdc():
    """Evaluate ``hdc_model`` on the extrapolation split, logging each prediction.

    Records ``hdc_accuracy``, ``hdc_correct`` and ``hdc_total`` results.

    Returns:
        list of dicts with ``command``, ``expected``, ``predicted``,
        ``confidence`` and ``correct`` per test example.
    """
    logger.log("\n--- HDC Predictions on EXTRAPOLATION set ---")

    test_set = splits['test_extrapolation']
    correct = 0
    predictions = []

    for cmd, expected in test_set:
        predicted, confidence = hdc_model.predict(cmd)
        is_correct = predicted == expected
        if is_correct:
            correct += 1

        status = "✓" if is_correct else "✗"
        logger.log(f"{status} '{cmd}' → '{predicted}' (expected: '{expected}')")

        predictions.append({
            'command': cmd,
            'expected': expected,
            'predicted': predicted,
            'confidence': float(confidence),
            'correct': is_correct
        })

    accuracy = correct / len(test_set) if test_set else 0

    logger.log(f"\nHDC Accuracy: {accuracy:.1%} ({correct}/{len(test_set)})")
    logger.result('hdc_accuracy', accuracy)
    logger.result('hdc_correct', correct)
    logger.result('hdc_total', len(test_set))

    return predictions

hdc_predictions = safe_run(test_hdc, "Test HDC Model")

## Part 3: Transformer

In [None]:
# ============================================================
# VOCABULARY & DATASET
# ============================================================

def create_vocab_dataset():
    """Build source/target vocabularies and the training DataLoader.

    Publishes the globals ``Vocabulary``, ``CommandDataset``, ``src_vocab``,
    ``tgt_vocab`` and ``train_loader``. Vocabularies are built from
    ``all_examples``; the loader iterates ``splits['train']`` only.
    """
    # Declared global so the class and variable definitions below land in
    # notebook scope directly.
    global src_vocab, tgt_vocab, train_loader, Vocabulary, CommandDataset

    class Vocabulary:
        """Word <-> index mapping with fixed special tokens at indices 0-3."""

        def __init__(self):
            specials = ['<PAD>', '<SOS>', '<EOS>', '<UNK>']
            self.word2idx = {word: i for i, word in enumerate(specials)}
            self.idx2word = dict(enumerate(specials))
            self.n_words = len(specials)

        def add_sentence(self, sentence):
            """Give the next free index to every unseen whitespace-separated token."""
            for token in sentence.split():
                if token in self.word2idx:
                    continue
                self.word2idx[token] = self.n_words
                self.idx2word[self.n_words] = token
                self.n_words += 1

        def encode(self, sentence, add_eos=True):
            """Map tokens to indices (unknown -> <UNK>), optionally appending <EOS>."""
            unk = self.word2idx['<UNK>']
            ids = [self.word2idx.get(token, unk) for token in sentence.split()]
            return ids + [self.word2idx['<EOS>']] if add_eos else ids

        def decode(self, indices):
            """Join tokens up to the first <EOS>, skipping <PAD> and <SOS>."""
            stop = self.word2idx['<EOS>']
            hidden = (self.word2idx['<PAD>'], self.word2idx['<SOS>'])
            tokens = []
            for i in indices:
                if i == stop:
                    break
                if i in hidden:
                    continue
                tokens.append(self.idx2word.get(i, '<UNK>'))
            return ' '.join(tokens)

    class CommandDataset(Dataset):
        """(command, output) pairs as fixed-length, zero-padded index tensors."""

        def __init__(self, examples, src_vocab, tgt_vocab, max_len=20):
            self.examples = examples
            self.src_vocab = src_vocab
            self.tgt_vocab = tgt_vocab
            self.max_len = max_len

        def __len__(self):
            return len(self.examples)

        def _pad(self, ids):
            # Truncate to max_len, then right-pad with <PAD> (index 0).
            return ids[:self.max_len] + [0] * (self.max_len - len(ids))

        def __getitem__(self, idx):
            command, output = self.examples[idx]
            src_ids = self._pad(self.src_vocab.encode(command.lower()))
            tgt_ids = self._pad(self.tgt_vocab.encode(output))
            return torch.tensor(src_ids), torch.tensor(tgt_ids)

    # Build vocabularies over every example
    src_vocab, tgt_vocab = Vocabulary(), Vocabulary()
    for command, output in all_examples:
        src_vocab.add_sentence(command.lower())
        tgt_vocab.add_sentence(output)

    logger.log(f"Source vocabulary: {src_vocab.n_words} words")
    logger.log(f"Target vocabulary: {tgt_vocab.n_words} words")

    # Training loader
    train_loader = DataLoader(
        CommandDataset(splits['train'], src_vocab, tgt_vocab),
        batch_size=16,
        shuffle=True,
    )

    logger.log(f"Training batches: {len(train_loader)}")

safe_run(create_vocab_dataset, "Create Vocabulary & Dataset")

In [None]:
# ============================================================
# TRANSFORMER MODEL
# ============================================================

def create_transformer():
    """Define ``SmallTransformer``, instantiate it on ``device``, and publish globals.

    Publishes the globals ``SmallTransformer`` (the class) and ``transformer``
    (an instance sized to ``src_vocab``/``tgt_vocab``), and records the
    parameter count as ``transformer_params``.
    """
    global transformer, SmallTransformer

    class SmallTransformer(nn.Module):
        """Encoder-decoder transformer with learned positional embeddings.

        Token index 0 is treated as padding, 1 as <SOS> and 2 as <EOS>,
        matching the special tokens of ``Vocabulary``.
        """

        def __init__(self, src_vocab_size, tgt_vocab_size,
                     d_model=128, nhead=4, num_layers=2, max_len=20):
            super().__init__()

            self.d_model = d_model
            self.max_len = max_len

            self.src_embedding = nn.Embedding(src_vocab_size, d_model)
            self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
            # Learned positions shared by encoder and decoder; sequences must
            # not be longer than max_len.
            self.pos_encoding = nn.Embedding(max_len, d_model)

            self.transformer = nn.Transformer(
                d_model=d_model,
                nhead=nhead,
                num_encoder_layers=num_layers,
                num_decoder_layers=num_layers,
                dim_feedforward=d_model * 4,
                dropout=0.1,
                batch_first=True
            )

            self.fc_out = nn.Linear(d_model, tgt_vocab_size)

        def forward(self, src, tgt):
            """Return logits of shape (batch, tgt_len, tgt_vocab_size).

            ``src`` and ``tgt`` are (batch, length) index tensors; index 0 is
            masked out as padding.
            """
            batch_size = src.size(0)
            src_len = src.size(1)
            tgt_len = tgt.size(1)

            src_pos = torch.arange(src_len, device=src.device).unsqueeze(0).expand(batch_size, -1)
            tgt_pos = torch.arange(tgt_len, device=tgt.device).unsqueeze(0).expand(batch_size, -1)

            src_emb = self.src_embedding(src) + self.pos_encoding(src_pos)
            tgt_emb = self.tgt_embedding(tgt) + self.pos_encoding(tgt_pos)

            # Boolean causal mask (True = may not attend), consistent with the
            # boolean padding masks below.
            tgt_mask = torch.triu(
                torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=src.device),
                diagonal=1,
            )
            src_key_padding_mask = (src == 0)
            tgt_key_padding_mask = (tgt == 0)

            output = self.transformer(
                src_emb, tgt_emb,
                tgt_mask=tgt_mask,
                src_key_padding_mask=src_key_padding_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                # Without this, the decoder's cross-attention also attends to
                # encoder outputs at padded source positions.
                memory_key_padding_mask=src_key_padding_mask,
            )

            return self.fc_out(output)

        def generate(self, src, max_len=10):
            """Greedily decode up to ``max_len`` tokens for each row of ``src``.

            Starts from <SOS> (index 1) and stops early when the token just
            produced is <EOS> (index 2) for every row. Switches the module to
            eval mode. Returns the index tensor including the leading <SOS>.
            """
            self.eval()
            batch_size = src.size(0)

            tgt = torch.ones(batch_size, 1, dtype=torch.long, device=src.device)

            for _ in range(max_len):
                output = self.forward(src, tgt)
                next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
                tgt = torch.cat([tgt, next_token], dim=1)

                if (next_token == 2).all():
                    break

            return tgt

    transformer = SmallTransformer(
        src_vocab_size=src_vocab.n_words,
        tgt_vocab_size=tgt_vocab.n_words
    ).to(device)

    n_params = sum(p.numel() for p in transformer.parameters())
    logger.log(f"Model parameters: {n_params:,}")
    logger.result('transformer_params', n_params)

safe_run(create_transformer, "Create Transformer Model")

In [None]:
# ============================================================
# TRAIN TRANSFORMER
# ============================================================

def train_transformer_func():
    """Train ``transformer`` on ``train_loader`` with teacher forcing.

    Publishes per-epoch mean losses as the global ``losses`` and records the
    final one as ``transformer_final_loss``.
    """
    global losses

    optimizer = optim.Adam(transformer.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # index 0 is <PAD>

    epochs = 100
    losses = []

    logger.log(f"Training for {epochs} epochs...")

    for epoch in tqdm(range(epochs), desc="Training"):
        transformer.train()
        epoch_loss = 0

        for src, tgt in train_loader:
            src, tgt = src.to(device), tgt.to(device)

            # Teacher forcing. `tgt` holds the output tokens, then <EOS>, then
            # padding (CommandDataset adds no <SOS>). The decoder input is
            # `tgt` shifted right with <SOS> (index 1, the token generate()
            # starts from) prepended, so the logits at position i must predict
            # tgt[:, i]: <SOS> -> first token, ..., last token -> <EOS>.
            sos = torch.ones(tgt.size(0), 1, dtype=torch.long, device=device)
            tgt_input = torch.cat([sos, tgt[:, :-1]], dim=1)
            tgt_output = tgt

            optimizer.zero_grad()
            logits = transformer(src, tgt_input)

            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        losses.append(epoch_loss / len(train_loader))

        if (epoch + 1) % 20 == 0:
            logger.log(f"Epoch {epoch+1}: Loss = {losses[-1]:.4f}")

    logger.log(f"Final loss: {losses[-1]:.4f}")
    logger.result('transformer_final_loss', losses[-1])

safe_run(train_transformer_func, "Train Transformer")

In [None]:
# ============================================================
# TEST TRANSFORMER
# ============================================================

def test_transformer():
    """Evaluate ``transformer`` on a training sample and on the extrapolation split.

    Also publishes the global helper ``predict_transformer``. Records
    ``transformer_accuracy``, ``transformer_correct`` and ``transformer_total``.

    Returns:
        list of per-example dicts (command, expected, predicted, correct) for
        the extrapolation split.
    """
    global predict_transformer

    def predict_transformer(model, command, src_vocab, tgt_vocab):
        """Greedy-decode ``command`` with ``model`` and return the output string."""
        model.eval()

        # Pad to the model's positional-table length (20 by default, the
        # same length CommandDataset pads training sources to).
        max_len = getattr(model, 'max_len', 20)
        src = src_vocab.encode(command.lower(), add_eos=True)
        src = src[:max_len] + [0] * (max_len - len(src))
        src = torch.tensor([src], device=device)

        with torch.no_grad():
            output = model.generate(src, max_len=10)

        return tgt_vocab.decode(output[0].cpu().tolist())

    logger.log("\n--- Transformer on TRAINING data (sample) ---")
    train_sample = splits['train'][:10]
    train_correct = 0
    for cmd, expected in train_sample:
        predicted = predict_transformer(transformer, cmd, src_vocab, tgt_vocab)
        is_correct = predicted == expected
        if is_correct:
            train_correct += 1
        status = "✓" if is_correct else "✗"
        logger.log(f"{status} '{cmd}' → '{predicted}' (expected: '{expected}')")

    logger.log(f"Train sample accuracy: {train_correct}/{len(train_sample)}")

    logger.log("\n--- Transformer on EXTRAPOLATION data ---")
    test_set = splits['test_extrapolation']
    correct = 0
    predictions = []

    for cmd, expected in test_set:
        predicted = predict_transformer(transformer, cmd, src_vocab, tgt_vocab)
        is_correct = predicted == expected
        if is_correct:
            correct += 1

        status = "✓" if is_correct else "✗"
        logger.log(f"{status} '{cmd}' → '{predicted}' (expected: '{expected}')")

        predictions.append({
            'command': cmd,
            'expected': expected,
            'predicted': predicted,
            'correct': is_correct
        })

    accuracy = correct / len(test_set) if test_set else 0

    logger.log(f"\nTransformer Extrapolation Accuracy: {accuracy:.1%} ({correct}/{len(test_set)})")
    logger.result('transformer_accuracy', accuracy)
    logger.result('transformer_correct', correct)
    logger.result('transformer_total', len(test_set))

    return predictions

transformer_predictions = safe_run(test_transformer, "Test Transformer")

## Part 4: Results & Visualization

In [None]:
# ============================================================
# FINAL COMPARISON
# ============================================================

def final_comparison():
    """Log a side-by-side accuracy table and record whether HDC beat the Transformer.

    Records ``hypothesis_supported`` and ``accuracy_difference`` when both
    accuracies are numeric; missing results are shown as 'N/A'.
    """
    def is_number(x):
        # Accuracies are 0 (int) for an empty test set, float otherwise.
        return isinstance(x, (int, float)) and not isinstance(x, bool)

    def fmt(x):
        return f"{x:.1%}" if is_number(x) else f"{x}"

    logger.log("\n" + "="*60)
    logger.log("FINAL RESULTS COMPARISON")
    logger.log("="*60)

    hdc_acc = logger.report['results'].get('hdc_accuracy', 'N/A')
    trans_acc = logger.report['results'].get('transformer_accuracy', 'N/A')

    logger.log(f"\n{'Model':<20} {'Extrapolation Accuracy':<25}")
    logger.log("-"*45)
    logger.log(f"{'HDC':<20} {fmt(hdc_acc)}")
    logger.log(f"{'Transformer':<20} {fmt(trans_acc)}")
    logger.log("="*60)

    # Analysis
    if is_number(hdc_acc) and is_number(trans_acc):
        if hdc_acc > trans_acc:
            diff = hdc_acc - trans_acc
            logger.log(f"\n✓ HDC outperforms Transformer by {diff:.1%}")
            logger.log("  This supports the hypothesis: structural composition")
            logger.log("  enables better generalization than learned patterns.")
        elif trans_acc > hdc_acc:
            diff = trans_acc - hdc_acc
            logger.log(f"\n✗ Transformer outperforms HDC by {diff:.1%}")
            logger.log("  Hypothesis not supported in this experiment.")
        else:
            logger.log("\n= Both models perform equally")

        logger.result('hypothesis_supported', hdc_acc > trans_acc)
        logger.result('accuracy_difference', hdc_acc - trans_acc)

safe_run(final_comparison, "Final Comparison")

In [None]:
# ============================================================
# VISUALIZATION
# ============================================================

def create_visualization():
    """Plot per-model extrapolation accuracy and the Transformer loss curve.

    Saves the figure to ``experiment_results.png`` and shows it. A missing
    accuracy result is plotted as 0.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Bar chart
    models = ['HDC', 'Transformer']
    accuracies = [
        logger.report['results'].get('hdc_accuracy', 0),
        logger.report['results'].get('transformer_accuracy', 0)
    ]

    # Green above 80% accuracy, red otherwise
    colors = ['#27ae60' if a > 0.8 else '#e74c3c' for a in accuracies]

    axes[0].bar(models, accuracies, color=colors, alpha=0.8, edgecolor='black')
    axes[0].set_ylabel('Accuracy', fontsize=12)
    axes[0].set_title('Compositional Generalization\n(Extrapolation Test)', fontsize=14)
    axes[0].set_ylim(0, 1.1)
    axes[0].axhline(y=1.0, color='green', linestyle='--', alpha=0.5, label='Perfect')
    axes[0].legend()

    for i, acc in enumerate(accuracies):
        # Label every numeric bar (the default for a missing result is int 0).
        if isinstance(acc, (int, float)):
            axes[0].text(i, acc + 0.03, f'{acc:.0%}', ha='center', fontsize=14, fontweight='bold')

    axes[0].grid(True, alpha=0.3, axis='y')

    # Training loss (if available)
    if 'losses' in globals() and losses:
        axes[1].plot(losses, color='#3498db', linewidth=2)
        axes[1].set_xlabel('Epoch', fontsize=12)
        axes[1].set_ylabel('Loss', fontsize=12)
        axes[1].set_title('Transformer Training Loss', fontsize=14)
        axes[1].grid(True, alpha=0.3)
    else:
        axes[1].text(0.5, 0.5, 'Training loss\nnot available',
                     ha='center', va='center', fontsize=14)
        axes[1].set_title('Training Loss', fontsize=14)

    plt.tight_layout()
    plt.savefig('experiment_results.png', dpi=150, bbox_inches='tight')
    plt.show()

    logger.log("\n📊 Chart saved to: experiment_results.png")

safe_run(create_visualization, "Create Visualization")

In [None]:
# ============================================================
# FINALIZE EXPERIMENT
# ============================================================

# Close out the report, then tell the user which artifacts to download.
logger.finish('COMPLETED')

print("\n" + "="*60)
print("📥 DOWNLOAD THESE FILES:")
print("="*60)
print("1. experiment_log.txt      - Full execution log")
print("2. experiment_report.json  - Structured results")
print("3. experiment_results.png  - Visualization")
print("\nFind them in the file browser (folder icon on the left)")
print("="*60)

In [None]:
# ============================================================
# SHOW REPORT CONTENTS
# ============================================================

# Echo the final structured report inline (non-JSON values rendered via str).
print("\n📊 EXPERIMENT REPORT CONTENTS:")
print("="*60)
print(json.dumps(logger.report, indent=2, default=str))