In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import os 
path = '/Volumes/Sid_Drive/mnist/'

if os.path.exists(path):
    prefix = path
else:
    prefix = ''

In [5]:
prefix

'/Volumes/Sid_Drive/mnist/'

In [6]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from sklearn.model_selection import train_test_split

torch.manual_seed(42)

# Hyperparameters
batch_size = 64
learning_rate = 0.001
epochs = 10

In [17]:
import torch
import numpy as np
from scipy.stats import entropy
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm 

class CrossDatasetAnalyzer:
    def __init__(self, dataset_names, max_depth=9, prefix=prefix):
        """
        Args:
            prefix: Path prefix for loading files
            dataset_names: List of dataset names to analyze
            max_depth: Maximum depth to analyze
        """
        self.prefix = prefix
        self.dataset_names = dataset_names
        self.max_depth = max_depth
    
    def load_depth_embeddings(self, depth, dataset_name):
        """Load embeddings for a specific depth and dataset"""
        path = f'{self.prefix}embeddings/mnist_encoder_{dataset_name}_depth_{depth}.pth'
        return torch.load(path)
    
    def analyze_activation_patterns(self, activations):
        """Analyze activation patterns"""
        metrics = {}
        
        # 1. Activation Statistics
        metrics['mean_activation'] = torch.mean(activations).item()
        metrics['activation_std'] = torch.std(activations).item()
        metrics['sparsity'] = (activations == 0).float().mean().item()
        
        # 2. Active Feature Count
        threshold = activations.mean() + activations.std()
        active_features = (activations > threshold).sum(dim=1)
        metrics['avg_active_features'] = active_features.float().mean().item()
        
        # 3. Feature Utilization
        feature_usage = (activations > threshold).float().mean(dim=0)
        metrics['feature_utilization'] = feature_usage.mean().item()
        metrics['feature_utilization_std'] = feature_usage.std().item()
        
        # 4. Activation Distribution
        normalized = torch.nn.functional.softmax(activations, dim=1)
        activation_entropy = entropy(normalized.numpy(), axis=1)
        metrics['activation_entropy'] = np.mean(activation_entropy)
        
        return metrics
    
    def compare_datasets(self):
        """Compare activation patterns across datasets and depths"""
        results = {}
        
        for dataset_name in tqdm(self.dataset_names, desc="Processing datasets"):
            depth_metrics = []
            
            for depth in tqdm(range(1, self.max_depth + 1), desc=f"Processing depth for {dataset_name}"):
                try:
                    # Load embeddings for this depth
                    activations = self.load_depth_embeddings(depth, dataset_name)
                    
                    # Analyze patterns
                    metrics = self.analyze_activation_patterns(activations)
                    depth_metrics.append(metrics)
                    
                except FileNotFoundError:
                    print(f"No embeddings found for {dataset_name} at depth {depth}")
                    break
                
            results[dataset_name] = depth_metrics
        
        return results
    
    def plot_metrics_across_depths(self, results):
        """Plot how metrics change across depths for each dataset"""
        metrics = list(next(iter(results.values()))[0].keys())
        
        for metric in metrics:
            plt.figure(figsize=(10, 6))
            for dataset_name in self.dataset_names:
                values = [m[metric] for m in results[dataset_name]]
                plt.plot(range(1, len(values) + 1), values, label=dataset_name)
            
            plt.xlabel('Depth')
            plt.ylabel(metric)
            plt.title(f'{metric} vs Depth')
            plt.legend()
            plt.grid(True)
            plt.savefig(f"plots/hypothesis/mnist_{metric}.png")
            plt.show()

In [18]:
analyzer = CrossDatasetAnalyzer(['MNIST', 'CIFAR100', 'EMNIST_letter', 'EMNIST'], max_depth=9, prefix=prefix)

In [19]:
results = analyzer.compare_datasets()

  return torch.load(path)
Processing depth for EMNIST_letter:   0%|          | 0/9 [00:00<?, ?it/s]
Processing datasets: 100%|██████████| 1/1 [00:00<00:00, 155.30it/s]

No embeddings found for EMNIST_letter at depth 1





In [21]:
analyzer.plot_metrics_across_depths(results)

IndexError: list index out of range