In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from collections import defaultdict
import time

In [2]:
class BaseEmbeddingModel(nn.Module):
    """Embedding lookup followed by a linear projection and L2 normalisation.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of the embedding vectors and of the projected
            output (the linear layer maps ``embedding_dim -> embedding_dim``).
    """

    def __init__(self, vocab_size: int, embedding_dim: int) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unit-length projected embeddings for the indices in ``x``.

        The result has the shape of ``x`` plus a trailing ``embedding_dim``
        axis; every vector along that last axis is scaled to L2 norm 1.
        """
        embedded = self.embedding(x)
        return F.normalize(self.fc(embedded), p=2, dim=-1)

    def get_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return the cosine similarity between the encodings of ``x1`` and ``x2``.

        The similarity is reduced over the last (embedding) axis, so the
        result has the broadcast shape of ``x1`` and ``x2``.
        """
        # dim=-1 matches the axis that forward() normalises.  The default
        # dim=1 is only the embedding axis for 1-D index inputs; for
        # (batch, seq) inputs it would reduce over the sequence instead.
        return F.cosine_similarity(self.forward(x1), self.forward(x2), dim=-1)

class CrossAttentionModel(nn.Module):
    """Token embedding followed by 4-head attention and a linear projection.

    The attention layer is built with ``batch_first=True``, so batched inputs
    are presumably ``(batch, seq)`` index tensors -- confirm against callers.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of the embeddings, the attention layer and the
            output projection.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.attention = nn.MultiheadAttention(
            embedding_dim,
            num_heads=4,
            batch_first=True,
        )
        self.fc = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        """Encode ``x``: embed, attend, then project.

        Query, key and value are all the embedded input, i.e. the sequence
        attends to itself here; only ``get_similarity`` attends across two
        different inputs.
        """
        tokens = self.embedding(x)
        # MultiheadAttention returns (output, weights); only the output is used.
        context = self.attention(tokens, tokens, tokens)[0]
        return self.fc(context)

    def get_similarity(self, x1, x2):
        """Score ``x1`` against ``x2`` with an attention-based similarity.

        ``x1``'s embeddings act as queries over ``x2``'s embeddings (keys and
        values).  The attended output is dotted with the query embeddings
        along the last axis and divided by ``sqrt(embedding_dim)``.
        """
        queries = self.embedding(x1)
        memory = self.embedding(x2)
        attended = self.attention(queries, memory, memory)[0]
        scale = torch.tensor(self.embedding.embedding_dim).sqrt()
        dot_product = (attended * queries).sum(dim=-1)
        return dot_product / scale

def _read_gpu_temperature():
    """Return the GPU temperature reported by PyTorch, or 0 when unavailable."""
    if not torch.cuda.is_available():
        return 0
    # torch.cuda.temperature() is missing on older PyTorch releases, so it is
    # looked up defensively.
    read_temperature = getattr(torch.cuda, "temperature", None)
    if read_temperature is None:
        return 0
    try:
        return read_temperature()
    except (ImportError, RuntimeError):
        # The sensor query needs an optional vendor binding and a loadable
        # driver; the proxy is best-effort, so fall back instead of aborting.
        return 0


def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero."""
    return total / count if count else float("nan")


def train_and_measure(model, train_loader, test_loader, epochs=5):
    """Train ``model`` as a classifier and collect loss, accuracy and cost figures.

    The model is optimised with Adam (default hyper-parameters) on a
    cross-entropy loss and evaluated on ``test_loader`` after every epoch.

    Args:
        model: Module whose output is passed to ``F.cross_entropy`` together
            with the batch targets, and whose ``argmax(dim=1)`` is taken as
            the predicted class.
        train_loader: Iterable of ``(data, target)`` training batches.
        test_loader: Iterable of ``(data, target)`` evaluation batches.
        epochs: Number of passes over ``train_loader``.

    Returns:
        A dict with:
            ``'metrics'``: mapping with per-epoch lists under ``'train_loss'``,
                ``'test_loss'`` (both mean loss per batch) and ``'accuracy'``
                (fraction of evaluated samples predicted correctly).  An
                entry is NaN when the corresponding loader produced nothing.
            ``'training_time'``: elapsed seconds for the whole run,
                evaluation included.
            ``'energy_proxy'``: sum over training batches of the absolute GPU
                temperature change across the step; 0 without a readable GPU
                sensor.
    """
    optimizer = torch.optim.Adam(model.parameters())
    metrics = defaultdict(list)

    # perf_counter is monotonic, unlike time.time(), so the measured interval
    # cannot be distorted by wall-clock adjustments.
    start_time = time.perf_counter()
    total_energy = 0  # Proxy for energy consumption

    for _ in range(epochs):
        model.train()
        epoch_loss = 0.0
        train_batches = 0

        for data, target in train_loader:
            optimizer.zero_grad()

            # GPU temperature before the step (proxy for energy)
            start_temp = _read_gpu_temperature()

            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

            # Temperature change across the step is the per-batch proxy.
            end_temp = _read_gpu_temperature()
            total_energy += abs(end_temp - start_temp)

            epoch_loss += loss.item()
            train_batches += 1

        # Evaluation phase
        model.eval()
        test_loss = 0.0
        test_batches = 0
        correct = 0
        evaluated = 0

        with torch.no_grad():
            for data, target in test_loader:
                output = model(data)
                test_loss += F.cross_entropy(output, target).item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                # Count what was actually scored so accuracy stays correct
                # even when the loader drops or lacks samples.
                evaluated += target.numel()
                test_batches += 1

        metrics['train_loss'].append(_mean_or_nan(epoch_loss, train_batches))
        metrics['test_loss'].append(_mean_or_nan(test_loss, test_batches))
        metrics['accuracy'].append(_mean_or_nan(correct, evaluated))

    training_time = time.perf_counter() - start_time

    return {
        'metrics': metrics,
        'training_time': training_time,
        'energy_proxy': total_energy
    }

class _MeanPooledClassifier(nn.Module):
    """Average a per-token encoder's output over axis 1 to get one row per sample."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, x):
        # (batch, seq, features) -> (batch, features): the trainer treats
        # axis 1 of the model output as the class axis.
        return self.encoder(x).mean(dim=1)


# Usage example:
def run_comparison(vocab_size=1000, embedding_dim=64, dataset_size=10000):
    """Train both models on the same synthetic data and return their results.

    Random sequences of 10 token ids are paired with random class labels and
    split 80/20 into train and test sets.  Each model is wrapped so that its
    per-token output is mean-pooled into one ``embedding_dim``-wide score
    vector per sample, which is what ``train_and_measure`` feeds to the
    cross-entropy loss.

    Args:
        vocab_size: Exclusive upper bound of the generated token ids and the
            models' vocabulary size.
        embedding_dim: Model width; also the number of classes, because the
            pooled output has ``embedding_dim`` entries.
        dataset_size: Total number of generated samples.

    Returns:
        A dict with the ``train_and_measure`` result for each model under
        ``'base_model'`` and ``'cross_attention'``.
    """
    # Generate synthetic data
    data = torch.randint(0, vocab_size, (dataset_size, 10))
    # Labels must index the pooled model output, which has embedding_dim
    # entries; drawing them from [0, vocab_size) puts targets out of range
    # whenever vocab_size > embedding_dim.
    labels = torch.randint(0, embedding_dim, (dataset_size,))

    # Create dataloaders
    train_size = int(0.8 * len(data))
    train_set = torch.utils.data.TensorDataset(data[:train_size], labels[:train_size])
    test_set = torch.utils.data.TensorDataset(data[train_size:], labels[train_size:])
    train_data = DataLoader(train_set, batch_size=32, shuffle=True)
    test_data = DataLoader(test_set, batch_size=32)

    # Train both models
    base_model = _MeanPooledClassifier(BaseEmbeddingModel(vocab_size, embedding_dim))
    cross_attn_model = _MeanPooledClassifier(CrossAttentionModel(vocab_size, embedding_dim))

    base_results = train_and_measure(base_model, train_data, test_data)
    cross_attn_results = train_and_measure(cross_attn_model, train_data, test_data)

    return {
        'base_model': base_results,
        'cross_attention': cross_attn_results
    }