In [1]:
import torch
from torchvision import transforms, models
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
from torchvision.datasets import ImageFolder
import tarfile
from tqdm.notebook import tqdm, trange
import scipy.special

# Define the transformations
transform = transforms.Compose([
    transforms.Resize(160),
    transforms.CenterCrop(160),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Paths for the Imagenette dataset
imagenette_root = "./data/imagenette2-160"
train_dir = os.path.join(imagenette_root, "train")
val_dir = os.path.join(imagenette_root, "val")

# download & extract if not already present
if not os.path.isdir(imagenette_root):
    os.makedirs(os.path.dirname(imagenette_root), exist_ok=True)
    url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"
    archive_path = os.path.join(os.path.dirname(imagenette_root), "imagenette2-160.tgz")
    if not os.path.exists(archive_path):
        print("Downloading Imagenette dataset...")
        import urllib.request
        urllib.request.urlretrieve(url, archive_path)
    print("Extracting Imagenette dataset...")
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(path=os.path.dirname(imagenette_root))
    print("Done.")

# Load the test dataset
test_dataset = ImageFolder(val_dir, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Function to center the kernel matrix
def centre_kernel(K):
    n = K.shape[0]
    unit = np.ones([n, n]) / n
    return K - unit.dot(K) - K.dot(unit) + unit.dot(K).dot(unit)

# Function to extract features from a model
def extract_features(model, dataloader, device):
    model.eval()
    features = []
    
    with torch.no_grad():
        for images, _ in tqdm(dataloader, desc="Extracting features"):
            images = torch.nn.functional.interpolate(
                images, size=(224,224), mode='bilinear', align_corners=False
            ).to(device)
            # Forward pass through the model
            feat = model(images)
            features.append(feat.cpu())
    
    return torch.cat(features)

# Function to create a normalized kernel matrix from features
def create_kernel_matrix(features):
    # Convert to numpy for easier manipulation
    features_np = features.numpy()
    
    # Normalize features
    norms = np.sqrt(np.sum(features_np**2, axis=1, keepdims=True))
    normed_features = features_np / norms
    
    # Create kernel matrix (dot product of normalized features)
    kernel_matrix = normed_features @ normed_features.T
    
    # Center the kernel
    centered_kernel = centre_kernel(kernel_matrix)
    
    return centered_kernel


Using device: cuda
Loading ResNet models...




Extracting features from ResNet models...


Extracting features:   0%|          | 0/62 [00:00<?, ?it/s]

Extracting features:   0%|          | 0/62 [00:00<?, ?it/s]

Extracting features:   0%|          | 0/62 [00:00<?, ?it/s]

Creating kernel matrices...
Sampling with temperature T=0.1


Sampling traces with T=0.1:   0%|          | 0/5000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [2]:
# Load the three ResNet models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print("Loading ResNet models...")
resnet18 = models.resnet18(pretrained=True)
resnet18.fc = torch.nn.Identity()  # Remove classification layer
resnet18 = resnet18.to(device)

resnet50 = models.resnet50(pretrained=True)
resnet50.fc = torch.nn.Identity()  # Remove classification layer
resnet50 = resnet50.to(device)

resnet152 = models.resnet152(pretrained=True)
resnet152.fc = torch.nn.Identity()  # Remove classification layer
resnet152 = resnet152.to(device)

# Extract features for each model
print("Extracting features from ResNet models...")
features_resnet18 = extract_features(resnet18, test_loader, device)
features_resnet50 = extract_features(resnet50, test_loader, device)
features_resnet152 = extract_features(resnet152, test_loader, device)

# Create kernel matrices
print("Creating kernel matrices...")
kernel_resnet18 = create_kernel_matrix(features_resnet18)
kernel_resnet50 = create_kernel_matrix(features_resnet50)
kernel_resnet152 = create_kernel_matrix(features_resnet152)

# Use ResNet50's kernel as the base for the graph sampling (M)
M = kernel_resnet50

Using device: cuda
Loading ResNet models...
Extracting features from ResNet models...


Extracting features:   0%|          | 0/62 [00:00<?, ?it/s]

Extracting features:   0%|          | 0/62 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [4]:

# Function to sample Tr(KG) where G is sampled using sigmoid(M/T)
def sample_tr_kg(K, M, T=1.0, num_samples=1000):
    n = K.shape[0]
    
    # Compute edge probabilities using sigmoid(M_ij / T)
    edge_probs = scipy.special.expit(M / T)  # sigmoid function
    
    # Set diagonal to 0 probability (no self-loops)
    np.fill_diagonal(edge_probs, 0)
    
    # Pre-compute K * edge_probs for weighted sampling
    K_weighted = K * edge_probs
    
    # Use vectorized operations for sampling
    # Generate all random samples at once
    random_values = np.random.random((num_samples, n, n))
    
    # Create a progress bar
    with trange(num_samples, desc=f"Sampling traces with T={T}") as pbar:
        # Preallocate array for traces
        traces = np.zeros(num_samples)
        
        # Process in batches to avoid memory issues
        batch_size = 50
        for i in range(0, num_samples, batch_size):
            end_idx = min(i + batch_size, num_samples)
            batch_count = end_idx - i
            
            # Generate batch of adjacency matrices
            G_batch = (random_values[i:end_idx] < edge_probs).astype(np.float32)
            
            # Compute traces for the batch
            traces[i:end_idx] = np.sum(K * G_batch, axis=(1, 2))
            
            pbar.update(batch_count)
    
    return traces


# Sample Tr(KG) for each kernel with fixed temperature
T = 0.1  # Temperature parameter
print(f"Sampling with temperature T={T}")
traces_resnet18 = sample_tr_kg(kernel_resnet18, M, T=T, num_samples=256)
traces_resnet50 = sample_tr_kg(kernel_resnet50, M, T=T, num_samples=256)
traces_resnet152 = sample_tr_kg(kernel_resnet152, M, T=T, num_samples=256)

# Create KDE plots
plt.figure(figsize=(12, 8))
sns.kdeplot(traces_resnet18, label="ResNet-18", fill=True, alpha=0.3)
sns.kdeplot(traces_resnet50, label="ResNet-50", fill=True, alpha=0.3)
sns.kdeplot(traces_resnet152, label="ResNet-152", fill=True, alpha=0.3)

plt.title(f"Distribution of Tr(KG) for Different ResNet Models (T={T})", fontsize=16)
plt.xlabel("Tr(KG)", fontsize=14)
plt.ylabel("Density", fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)

# Add statistics to the plot
stats_text = []
for name, traces in [("ResNet-18", traces_resnet18), 
                     ("ResNet-50", traces_resnet50), 
                     ("ResNet-152", traces_resnet152)]:
    mean = np.mean(traces)
    std = np.std(traces)
    stats_text.append(f"{name}: Mean = {mean:.2f}, Std = {std:.2f}")

plt.figtext(0.5, 0.01, "\n".join(stats_text), ha="center", fontsize=12, 
            bbox={"facecolor":"white", "alpha":0.8, "pad":5})

plt.tight_layout(rect=[0, 0.05, 1, 0.95])
plt.savefig(f"resnet_trace_distributions_T{T}.png", dpi=300)
plt.show()

# Create a DataFrame for easier analysis
trace_df = pd.DataFrame({
    'ResNet-18': traces_resnet18,
    'ResNet-50': traces_resnet50,
    'ResNet-152': traces_resnet152
})

# Summary statistics
print("Summary Statistics:")
print(trace_df.describe())

# Save the raw data
np.savez(f"resnet_trace_data_T{T}.npz", 
         resnet18=traces_resnet18, 
         resnet50=traces_resnet50, 
         resnet152=traces_resnet152)

Sampling with temperature T=0.1


MemoryError: Unable to allocate 29.4 GiB for an array with shape (256, 3925, 3925) and data type float64