In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from collections import defaultdict
import time

In [2]:
class BaseEmbeddingModel(nn.Module):
    """Mean-pooled embedding classifier.

    Token ids are embedded, averaged over dim 1, pushed through two
    ReLU-activated linear layers (each followed by dropout with p=0.2) and
    finally projected to ``vocab_size`` logits.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        hidden_dim = embedding_dim * 2
        # Sub-modules are created in a fixed order so that seeded parameter
        # initialisation stays reproducible.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.dropout = nn.Dropout(0.2)
        self.layer1 = nn.Linear(embedding_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, embedding_dim)
        self.fc = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Return logits over the vocabulary for a batch of token ids.

        Args:
            x: integer token ids; presumably shaped (batch, seq_len) -- the
                embeddings are averaged over dim 1.

        Returns:
            Tensor whose last dimension has size ``vocab_size``.
        """
        token_vectors = self.dropout(self.embedding(x))
        hidden = token_vectors.mean(dim=1)
        # Two identical stages: linear -> ReLU -> dropout.
        for layer in (self.layer1, self.layer2):
            hidden = self.dropout(F.relu(layer(hidden)))
        return self.fc(hidden)

    def get_similarity(self, x1, x2):
        """Cosine similarity between the mean (dropout-free) embeddings of x1 and x2."""
        first, second = (self.embedding(ids).mean(1) for ids in (x1, x2))
        return F.cosine_similarity(first, second)

class CrossAttentionModel(nn.Module):
    """Embedding classifier with a self-attention layer before mean pooling.

    Token ids are embedded, passed through multi-head self-attention with a
    residual connection and layer normalisation, averaged over dim 1, and then
    fed to the same two-layer MLP head used by the baseline model.

    Args:
        vocab_size: number of distinct token ids (also the number of output logits).
        embedding_dim: size of each token embedding.
        num_heads: number of attention heads (default 4). ``embedding_dim``
            must be divisible by it; ``nn.MultiheadAttention`` enforces this.
    """

    def __init__(self, vocab_size, embedding_dim, num_heads=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.dropout = nn.Dropout(0.2)
        self.attention = nn.MultiheadAttention(
            embedding_dim, num_heads=num_heads, batch_first=True, dropout=0.1
        )
        self.layer1 = nn.Linear(embedding_dim, embedding_dim * 2)
        self.layer2 = nn.Linear(embedding_dim * 2, embedding_dim)
        self.fc = nn.Linear(embedding_dim, vocab_size)
        self.layer_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Return vocabulary logits for a batch of token-id sequences.

        Args:
            x: integer token ids; with ``batch_first=True`` the attention layer
                expects the embedded input as (batch, seq_len, embedding_dim).

        Returns:
            Tensor of logits whose last dimension has size ``vocab_size``.
        """
        embedded = self.embedding(x)
        embedded = self.dropout(embedded)
        attended, _ = self.attention(embedded, embedded, embedded)
        attended = self.layer_norm(attended + embedded)  # Residual connection
        pooled = torch.mean(attended, dim=1)
        x = F.relu(self.layer1(pooled))
        x = self.dropout(x)
        x = F.relu(self.layer2(x))
        x = self.dropout(x)
        return self.fc(x)

    def get_similarity(self, x1, x2):
        """Scaled dot product between x1's embeddings and their attention over x2.

        ``x1`` embeddings act as queries and ``x2`` embeddings as keys/values.
        The attended output is multiplied element-wise with the ``x1``
        embeddings, summed over the embedding dimension and divided by
        ``sqrt(embedding_dim)``, giving one score per query position.

        Note: the attention layer's own dropout is active while the module is
        in training mode; call ``eval()`` first for deterministic scores.
        """
        e1, e2 = self.embedding(x1), self.embedding(x2)
        attn_output, _ = self.attention(e1, e2, e2)
        # A plain Python float avoids building an integer CPU tensor on every
        # call and relying on integer-to-float promotion inside torch.sqrt.
        scale = self.embedding.embedding_dim ** 0.5
        return torch.sum(attn_output * e1, dim=-1) / scale

def _evaluate(model, loader, device):
    """Return (mean of per-batch cross-entropy losses, accuracy) for `model` on `loader`."""
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.cross_entropy(output, target).item()
            correct += output.argmax(dim=1).eq(target).sum().item()
    return total_loss / len(loader), correct / len(loader.dataset)


def train_and_measure(model, train_loader, val_loader, test_loader, epochs=30):
    """Train `model` with early stopping and report validation/test metrics.

    Uses Adam (lr=0.001, weight_decay=1e-5), gradient-norm clipping at 1.0 and
    a ReduceLROnPlateau scheduler driven by the validation loss. Training
    stops once the validation loss has failed to improve for 5 consecutive
    epochs. The weights from the best validation epoch are restored before
    the test set is evaluated.

    Args:
        model: an ``nn.Module`` mapping a batch of inputs to class logits.
        train_loader, val_loader, test_loader: iterables of ``(data, target)``
            batches exposing ``len()`` and a ``.dataset`` attribute.
        epochs: maximum number of training epochs.

    Returns:
        dict with keys:
            'metrics'       -- per-epoch lists 'train_loss', 'val_loss', 'val_accuracy'
            'training_time' -- seconds spent in the train/validate loop
            'energy_proxy'  -- always 0; nothing accumulates it yet
            'test_loss'     -- mean per-batch cross-entropy on the test set
            'test_accuracy' -- fraction of correctly classified test samples
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
    metrics = defaultdict(list)

    # Batches are moved to wherever the model lives so that GPU models work too.
    device = next(model.parameters()).device

    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0
    max_patience = 5

    start_time = time.time()
    total_energy = 0  # placeholder: never accumulated anywhere in this function

    for epoch in range(epochs):
        # Training phase
        model.train()
        epoch_loss = 0.0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += loss.item()

        # Validation phase
        avg_val_loss, val_accuracy = _evaluate(model, val_loader, device)

        # Store metrics before the early-stopping check so that the epoch that
        # triggers the stop is still recorded.
        metrics['train_loss'].append(epoch_loss / len(train_loader))
        metrics['val_loss'].append(avg_val_loss)
        metrics['val_accuracy'].append(val_accuracy)

        print(f'Epoch {epoch+1}: Train Loss: {metrics["train_loss"][-1]:.4f}, '
              f'Val Loss: {metrics["val_loss"][-1]:.4f}, '
              f'Val Accuracy: {metrics["val_accuracy"][-1]:.4f}')

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live parameter tensors, so
            # each tensor must be cloned; a shallow dict copy would keep
            # tracking later optimizer updates.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= max_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        scheduler.step(avg_val_loss)

    # Measured here so that test-set evaluation is not counted as training time.
    training_time = time.time() - start_time

    # Load best model (if any epoch produced one) and evaluate on test set
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    test_loss, test_accuracy = _evaluate(model, test_loader, device)

    return {
        'metrics': metrics,
        'training_time': training_time,
        'energy_proxy': total_energy,
        'test_loss': test_loss,
        'test_accuracy': test_accuracy
    }

def _build_pattern_dataset(vocab_size, dataset_size, seq_len):
    """Generate synthetic (sequences, labels) tensors from three token patterns."""
    token_range = vocab_size // 4
    sequences = []
    labels = []

    for _ in range(dataset_size):
        pattern_type = np.random.random()

        if pattern_type < 0.33:
            # Pattern 1: sum of first and last token maps to label
            seq = torch.randint(0, token_range, (seq_len,))
            label = int((seq[0] + seq[-1]) % vocab_size)

        elif pattern_type < 0.66:
            # Pattern 2: sequence of similar tokens predicts next
            base_token = int(torch.randint(0, token_range, (1,)))
            noise = torch.randint(-2, 3, (seq_len,))
            seq = (base_token + noise) % vocab_size
            label = (base_token + 3) % vocab_size

        else:
            # Pattern 3: a chosen token is planted at several positions
            majority_token = int(torch.randint(0, token_range, (1,)))
            seq = torch.randint(0, token_range, (seq_len,))
            # randperm yields distinct positions, so the token is really
            # inserted 4 times (or seq_len times for very short sequences).
            positions = torch.randperm(seq_len)[:min(4, seq_len)]
            seq[positions] = majority_token
            label = majority_token

        sequences.append(seq)
        labels.append(label)

    # Labels are plain ints, so this always yields a 1-D int64 tensor.
    return torch.stack(sequences), torch.tensor(labels, dtype=torch.long)


def run_comparison(vocab_size=1000, embedding_dim=64, dataset_size=10000,
                   seq_len=10, batch_size=32, epochs=30, seed=None):
    """Train the baseline and cross-attention models on the same synthetic data.

    The data is split 70% / 15% / remainder into train / validation / test
    sets, in generation order; only the training loader is shuffled.

    Args:
        vocab_size: number of token ids / output classes (must be >= 4 because
            tokens are drawn from ``[0, vocab_size // 4)``).
        embedding_dim: embedding size passed to both models.
        dataset_size: number of synthetic sequences to generate.
        seq_len: length of every generated sequence (default 10).
        batch_size: batch size for all three loaders (default 32).
        epochs: maximum epochs passed to ``train_and_measure`` (default 30).
        seed: if given, seeds the global torch and numpy RNGs so that data
            generation and model initialisation are reproducible.

    Returns:
        dict with keys 'base_model' and 'cross_attention', each holding the
        result returned by ``train_and_measure`` for that model.

    Raises:
        ValueError: if ``vocab_size`` or ``seq_len`` is too small, or if
            ``dataset_size`` leaves any of the three splits empty.
    """
    if vocab_size < 4:
        raise ValueError("vocab_size must be at least 4")
    if seq_len < 1:
        raise ValueError("seq_len must be at least 1")

    # Split into train/validation/test
    train_size = int(0.7 * dataset_size)
    val_size = int(0.15 * dataset_size)
    test_size = dataset_size - train_size - val_size
    if min(train_size, val_size, test_size) < 1:
        raise ValueError(
            "dataset_size is too small to give non-empty train/validation/test splits"
        )

    if seed is not None:
        # Both RNGs are used: numpy picks the pattern, torch draws the tokens
        # and initialises the model weights.
        torch.manual_seed(seed)
        np.random.seed(seed)

    data, labels = _build_pattern_dataset(vocab_size, dataset_size, seq_len)
    val_end = train_size + val_size

    def make_loader(start, stop, shuffle=False):
        return DataLoader(
            list(zip(data[start:stop], labels[start:stop])),
            batch_size=batch_size,
            shuffle=shuffle
        )

    train_data = make_loader(0, train_size, shuffle=True)
    val_data = make_loader(train_size, val_end)
    test_data = make_loader(val_end, dataset_size)

    # Train both models with validation
    base_model = BaseEmbeddingModel(vocab_size=vocab_size, embedding_dim=embedding_dim)
    cross_attn_model = CrossAttentionModel(vocab_size=vocab_size, embedding_dim=embedding_dim)

    base_results = train_and_measure(base_model, train_data, val_data, test_data, epochs=epochs)
    cross_attn_results = train_and_measure(cross_attn_model, train_data, val_data, test_data, epochs=epochs)

    return {
        'base_model': base_results,
        'cross_attention': cross_attn_results
    }

def print_results(results, configuration):
    """Print a text report of timing, final losses and test accuracy per model.

    Args:
        results: dict with 'base_model' and 'cross_attention' entries, each
            containing 'training_time', 'test_loss', 'test_accuracy' and a
            'metrics' mapping with 'train_loss' and 'val_loss' lists.
        configuration: label shown in the report header.
    """
    print(f"\n=== Results for {configuration} ===")

    # (heading, key in `results`) for each model, in report order.
    report_sections = (
        ("Base Model", 'base_model'),
        ("Cross-Attention Model", 'cross_attention'),
    )
    for heading, model_key in report_sections:
        print(f"\n{heading}:")
        print(f"Training Time: {results[model_key]['training_time']:.2f} seconds")
        print(f"Final Training Loss: {results[model_key]['metrics']['train_loss'][-1]:.4f}")
        print(f"Final Validation Loss: {results[model_key]['metrics']['val_loss'][-1]:.4f}")
        print(f"Final Test Loss: {results[model_key]['test_loss']:.4f}")
        print(f"Final Test Accuracy: {results[model_key]['test_accuracy']:.4f}")

In [3]:
# Settings for this run. Note that vocab_size and dataset_size are smaller than
# run_comparison's defaults (1000 and 10000); embedding_dim equals the default.
enhanced_settings = {
    "vocab_size": 100,
    "embedding_dim": 64,
    "dataset_size": 5000,
}
results = run_comparison(**enhanced_settings)

print_results(results, "Enhanced Configuration")

Epoch 1: Train Loss: 3.7187, Val Loss: 3.0403, Val Accuracy: 0.1693
Epoch 2: Train Loss: 2.8727, Val Loss: 2.6366, Val Accuracy: 0.2680
Epoch 3: Train Loss: 2.6037, Val Loss: 2.4737, Val Accuracy: 0.3053
Epoch 4: Train Loss: 2.4515, Val Loss: 2.3371, Val Accuracy: 0.3653
Epoch 5: Train Loss: 2.3229, Val Loss: 2.2285, Val Accuracy: 0.4027
Epoch 6: Train Loss: 2.2405, Val Loss: 2.1598, Val Accuracy: 0.4320
Epoch 7: Train Loss: 2.1505, Val Loss: 2.0703, Val Accuracy: 0.4760
Epoch 8: Train Loss: 2.0874, Val Loss: 2.0234, Val Accuracy: 0.5107
Epoch 9: Train Loss: 2.0126, Val Loss: 1.9667, Val Accuracy: 0.5267
Epoch 10: Train Loss: 1.9619, Val Loss: 1.9221, Val Accuracy: 0.5507
Epoch 11: Train Loss: 1.9122, Val Loss: 1.8965, Val Accuracy: 0.5560
Epoch 12: Train Loss: 1.8954, Val Loss: 1.8527, Val Accuracy: 0.5773
Epoch 13: Train Loss: 1.8390, Val Loss: 1.8339, Val Accuracy: 0.5720
Epoch 14: Train Loss: 1.8052, Val Loss: 1.8117, Val Accuracy: 0.5733
Epoch 15: Train Loss: 1.7752, Val Loss: 1.8