# Memory Efficiency Comparison: pytorch-cka vs torch-cka

This notebook benchmarks GPU memory management between `pytorch-cka` (this library) and `torch-cka` using CIFAR-10 and ResNet-18 self-comparison.

**Key metrics:**
1. **Memory during computation** - How memory grows (or stays flat) as batches are processed
2. **Memory after computation** - Memory retained after computation completes (not deallocated)

**Why this matters:**
- `pytorch-cka`: Clears features after each batch → constant memory usage during computation
- `torch-cka`: Accumulates activations → memory grows with each batch and remains after computation

In [1]:
!pip install torch-cka pytorch-cka -q

In [2]:
import gc
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models

from cka import CKA as PytorchCKA
from torch_cka import CKA as TorchCKA

In [3]:
# Check CUDA availability
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required for GPU memory benchmarking. Please run on a GPU-enabled machine.")

device = torch.device("cuda")
print(f"Using device: {device}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB")

Using device: cuda
GPU: NVIDIA A100-SXM4-80GB
Total GPU memory: 79.32 GB


In [4]:
def measure_memory_retention(func):
    """Measure memory retained after computation (not deallocated).

    This captures the key difference: pytorch-cka deallocates memory
    after computation, while torch-cka retains it.
    """
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    mem_before = torch.cuda.memory_allocated()

    result = func()

    # Force cleanup attempt
    gc.collect()
    torch.cuda.empty_cache()

    mem_after = torch.cuda.memory_allocated()
    peak_mb = torch.cuda.max_memory_allocated() / (1024**2)
    retained_mb = (mem_after - mem_before) / (1024**2)

    return result, peak_mb, retained_mb

In [None]:
def measure_memory_at_batch_counts(cka_benchmark_func, model_fn, full_dataloader, batch_counts):
    """Measure retained memory after processing different numbers of batches.

    This shows how memory accumulates (or not) as more data is processed.

    Args:
        cka_benchmark_func: Function that runs CKA (benchmark_pytorch_cka or benchmark_torch_cka)
        model_fn: Function that creates a fresh model
        full_dataloader: DataLoader with all data
        batch_counts: List of batch counts to test [1, 2, 4, 8, ...]

    Returns:
        memory_by_batch_count: List of retained memory (MB) after each batch count
    """
    memory_by_batch_count = []
    batch_size = full_dataloader.batch_size

    for num_batches in batch_counts:
        # Create subset with exactly num_batches worth of data
        subset_size = min(num_batches * batch_size, len(full_dataloader.dataset))
        subset_indices = list(range(subset_size))
        subset_data = Subset(full_dataloader.dataset, subset_indices)
        subset_loader = DataLoader(subset_data, batch_size=batch_size, shuffle=False, num_workers=0)

        # Create fresh model
        model = model_fn().to(device)
        model.eval()

        # Measure memory after running CKA
        gc.collect()
        torch.cuda.empty_cache()
        baseline = torch.cuda.memory_allocated()

        cka_benchmark_func(model, subset_loader)

        # Measure retained memory (before explicit cleanup)
        retained_mb = (torch.cuda.memory_allocated() - baseline) / (1024**2)
        memory_by_batch_count.append(retained_mb)

        # Cleanup for next iteration
        del model
        gc.collect()
        torch.cuda.empty_cache()

    return memory_by_batch_count

In [5]:
def get_resnet18_cifar(num_classes=10):
    """ResNet-18 adapted for CIFAR-10 (smaller input size)."""
    model = models.resnet18(weights=None)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# CIFAR-10 transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Download CIFAR-10
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Use a subset for faster benchmarking (1000 samples)
subset_indices = list(range(1000))
subset = Subset(dataset, subset_indices)

print(f"Dataset size: {len(subset)} samples")

100%|██████████| 170M/170M [00:13<00:00, 12.3MB/s]


Dataset size: 1000 samples


In [None]:
def benchmark_pytorch_cka(model, dataloader):
    """Run pytorch-cka self-comparison."""
    with PytorchCKA(model, model, device=device) as cka:
        cka.compare(dataloader, progress=False)

def benchmark_torch_cka(model, dataloader):
    """Run torch-cka self-comparison."""
    cka = TorchCKA(
        model, model,
        model1_name="ResNet18",
        model2_name="ResNet18",
        device=device
    )
    with torch.no_grad():
        cka.compare(dataloader, dataloader)

In [7]:
# Batch sizes to test
batch_sizes = [64, 128, 256, 512]

# Results storage
pytorch_cka_peak = []
pytorch_cka_retained = []
torch_cka_peak = []
torch_cka_retained = []

print("Running memory benchmarks...\n")
print(f"{'Batch':<8} {'pytorch-cka':<24} {'torch-cka':<24}")
print(f"{'Size':<8} {'Peak (MB)':<12} {'Retained (MB)':<12} {'Peak (MB)':<12} {'Retained (MB)':<12}")
print("-" * 72)

for batch_size in batch_sizes:
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Create fresh model for each test
    model = get_resnet18_cifar().to(device)
    model.eval()

    # Benchmark pytorch-cka
    _, peak1, retained1 = measure_memory_retention(lambda: benchmark_pytorch_cka(model, dataloader))
    pytorch_cka_peak.append(peak1)
    pytorch_cka_retained.append(retained1)

    # Clear and recreate model
    del model
    gc.collect()
    torch.cuda.empty_cache()

    model = get_resnet18_cifar().to(device)
    model.eval()

    # Benchmark torch-cka
    _, peak2, retained2 = measure_memory_retention(lambda: benchmark_torch_cka(model, dataloader))
    torch_cka_peak.append(peak2)
    torch_cka_retained.append(retained2)

    print(f"{batch_size:<8} {peak1:<12.2f} {retained1:<12.2f} {peak2:<12.2f} {retained2:<12.2f}")

    # Cleanup
    del model
    gc.collect()
    torch.cuda.empty_cache()

Running memory benchmarks...

Batch    pytorch-cka              torch-cka               
Size     Peak (MB)    Retained (MB) Peak (MB)    Retained (MB)
------------------------------------------------------------------------


  warn(f"Both model have identical names - {self.model2_info['Name']}. " \
| Comparing features |: 100%|██████████| 16/16 [01:07<00:00,  4.24s/it]


64       367.57       9.12         387.35       197.08      


| Comparing features |: 100%|██████████| 8/8 [00:35<00:00,  4.41s/it]


128      673.32       0.00         719.83       489.71      


| Comparing features |: 100%|██████████| 4/4 [00:24<00:00,  6.19s/it]


256      1306.60      0.00         1384.71      1090.96     


| Comparing features |: 100%|██████████| 2/2 [00:19<00:00,  9.56s/it]


512      2664.05      0.00         2715.71      2289.22     


In [None]:
# Memory timeline: How memory grows (or stays flat) as more batches are processed
# This directly compares memory allocation DURING computation

print("Running memory timeline benchmark...")
print("(Measuring retained memory after processing 1, 2, 4, 8, 16 batches)\n")

timeline_batch_size = 64
timeline_dataloader = DataLoader(subset, batch_size=timeline_batch_size, shuffle=False, num_workers=0)
batch_counts = [1, 2, 4, 8, 16]

# Measure memory timeline for pytorch-cka
print("Testing pytorch-cka...")
pytorch_memory_timeline = measure_memory_at_batch_counts(
    benchmark_pytorch_cka, get_resnet18_cifar, timeline_dataloader, batch_counts
)

# Measure memory timeline for torch-cka
print("Testing torch-cka...")
torch_memory_timeline = measure_memory_at_batch_counts(
    benchmark_torch_cka, get_resnet18_cifar, timeline_dataloader, batch_counts
)

print("\nMemory Timeline Results:")
print(f"{'Batches':<10} {'pytorch-cka (MB)':<20} {'torch-cka (MB)':<20}")
print("-" * 50)
for i, num_batches in enumerate(batch_counts):
    print(f"{num_batches:<10} {pytorch_memory_timeline[i]:<20.2f} {torch_memory_timeline[i]:<20.2f}")

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

x = np.arange(len(batch_sizes))
width = 0.35

# Left: Memory timeline (during computation)
ax0 = axes[0]
ax0.plot(batch_counts, pytorch_memory_timeline, 'o-', label='pytorch-cka', color='#2ecc71', linewidth=2, markersize=8)
ax0.plot(batch_counts, torch_memory_timeline, 'o-', label='torch-cka', color='#e74c3c', linewidth=2, markersize=8)
ax0.set_xlabel('Number of Batches Processed')
ax0.set_ylabel('Memory Allocated (MB)')
ax0.set_title('Memory During Computation\n(grows vs stays flat)')
ax0.legend()
ax0.grid(alpha=0.3)
ax0.set_xticks(batch_counts)

# Middle: Peak memory comparison
ax1 = axes[1]
ax1.bar(x - width/2, pytorch_cka_peak, width, label='pytorch-cka', color='#2ecc71')
ax1.bar(x + width/2, torch_cka_peak, width, label='torch-cka', color='#e74c3c')
ax1.set_xlabel('Batch Size')
ax1.set_ylabel('Peak GPU Memory (MB)')
ax1.set_title('Peak Memory During Computation')
ax1.set_xticks(x)
ax1.set_xticklabels(batch_sizes)
ax1.legend()
ax1.grid(axis='y', alpha=0.3)

# Right: Retained memory comparison (after computation)
ax2 = axes[2]
ax2.bar(x - width/2, pytorch_cka_retained, width, label='pytorch-cka', color='#2ecc71')
ax2.bar(x + width/2, torch_cka_retained, width, label='torch-cka', color='#e74c3c')
ax2.set_xlabel('Batch Size')
ax2.set_ylabel('Retained GPU Memory (MB)')
ax2.set_title('Memory After Computation\n(not deallocated)')
ax2.set_xticks(x)
ax2.set_xticklabels(batch_sizes)
ax2.legend()
ax2.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
print("=" * 70)
print("SUMMARY: Memory Efficiency Results")
print("=" * 70)

avg_pytorch_retained = np.mean(pytorch_cka_retained)
avg_torch_retained = np.mean(torch_cka_retained)
avg_pytorch_peak = np.mean(pytorch_cka_peak)
avg_torch_peak = np.mean(torch_cka_peak)

print("\n1. MEMORY DURING COMPUTATION (as batches are processed)")
print("-" * 70)
print(f"   pytorch-cka: Memory stays flat (~{pytorch_memory_timeline[-1]:.1f} MB regardless of batch count)")
print(f"   torch-cka:   Memory grows with batches ({torch_memory_timeline[0]:.1f} MB -> {torch_memory_timeline[-1]:.1f} MB)")
growth_rate = (torch_memory_timeline[-1] - torch_memory_timeline[0]) / (batch_counts[-1] - batch_counts[0])
print(f"   torch-cka growth rate: ~{growth_rate:.1f} MB per batch")

print("\n2. MEMORY AFTER COMPUTATION (retained, not deallocated)")
print("-" * 70)
print(f"   pytorch-cka: {avg_pytorch_retained:.2f} MB (properly deallocates)")
print(f"   torch-cka:   {avg_torch_retained:.2f} MB (retains memory)")
if avg_pytorch_retained > 0:
    print(f"   Ratio: torch-cka retains {avg_torch_retained / max(avg_pytorch_retained, 0.01):.1f}x more memory")

print("\n3. PEAK MEMORY")
print("-" * 70)
print(f"   pytorch-cka: {avg_pytorch_peak:.2f} MB")
print(f"   torch-cka:   {avg_torch_peak:.2f} MB")

print("\n4. DETAILED RESULTS BY BATCH SIZE")
print("-" * 70)
print(f"{'Batch':<8} {'pytorch-cka':<24} {'torch-cka':<24}")
print(f"{'Size':<8} {'Peak':<12} {'Retained':<12} {'Peak':<12} {'Retained':<12}")
for i, bs in enumerate(batch_sizes):
    print(f"{bs:<8} {pytorch_cka_peak[i]:<12.2f} {pytorch_cka_retained[i]:<12.2f} "
          f"{torch_cka_peak[i]:<12.2f} {torch_cka_retained[i]:<12.2f}")

print("\n" + "=" * 70)
print("CONCLUSION:")
print("- DURING computation: pytorch-cka maintains constant memory,")
print("  while torch-cka memory grows linearly with the number of batches.")
print("- AFTER computation: pytorch-cka deallocates GPU memory,")
print("  while torch-cka retains activations in memory.")
print("=" * 70)