In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# Device (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load VGG11 with batch normalization and ImageNet weights.
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights=`
# enum; fall back to the legacy keyword on older releases that lack the enum.
_vgg11_bn_weights = getattr(models, "VGG11_BN_Weights", None)
if _vgg11_bn_weights is not None:
    model = models.vgg11_bn(weights=_vgg11_bn_weights.IMAGENET1K_V1)
else:
    model = models.vgg11_bn(pretrained=True)
model = model.to(device)
# Inference mode: BatchNorm uses its running statistics, dropout is disabled.
model.eval()

# Prepare CIFAR-10 test set. Images are upsampled to 224 (CIFAR images are
# square, so 224x224) and normalized with ImageNet statistics to match what
# the pretrained VGG was trained on.
SEED = 0  # controls which images the shuffled loader yields first

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet normalization
                         std=[0.229, 0.224, 0.225])
])
dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# A seeded generator makes the shuffled subset of images (and therefore the
# reported ranks) reproducible across Restart & Run All.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

# Walk model.features once and record each Conv2d (and its position) in the
# order it appears in the network.
conv_indices = []
conv_modules = []
for i, layer in enumerate(model.features):
    if isinstance(layer, nn.Conv2d):
        conv_indices.append(i)
        conv_modules.append(layer)

# Running rank totals: one list per conv layer, one float slot per output filter.
conv_ranks_sum = [[0.0] * conv.out_channels for conv in conv_modules]

# How many test images to aggregate over (roughly 100-200 suffices; we use 100)
# and how many have been processed so far.
num_images = 100
count = 0

# Push `num_images` images through the convolutional trunk and accumulate, for
# every filter of every Conv2d layer, the matrix rank of its output map.
with torch.no_grad():
    for images, _ in loader:
        # Trim the last batch so exactly `num_images` images are counted
        # (a no-op when the loader uses batch_size=1).
        images = images[: num_images - count].to(device)
        x = images
        conv_outputs = []
        # Forward through the feature layers, capturing conv outputs.
        # NOTE(review): these are raw Conv2d outputs (before BN/ReLU); HRank's
        # reference implementation ranks post-activation maps — confirm intent.
        for layer in model.features:
            x = layer(x)
            if isinstance(layer, nn.Conv2d):
                conv_outputs.append(x)
        for layer_idx, feature_map in enumerate(conv_outputs):
            # feature_map is (B, C, H, W). matrix_rank treats the last two dims
            # as the matrix and batches over (B, C) in a single call, giving a
            # (B, C) tensor of ranks; summing over B gives one total per filter.
            filter_ranks = torch.linalg.matrix_rank(feature_map.cpu()).sum(dim=0)
            for f, rank in enumerate(filter_ranks.tolist()):
                conv_ranks_sum[layer_idx][f] += rank
        count += images.shape[0]
        if count >= num_images:
            break

# For each conv layer, print every filter's mean rank over the processed images,
# then the filter(s) achieving the highest and the lowest mean rank.
for layer_idx, (conv, ranks) in enumerate(zip(conv_modules, conv_ranks_sum), start=1):
    in_ch, out_ch = conv.in_channels, conv.out_channels
    avg_ranks = [total / count for total in ranks]
    # Default-dtype tensor, so ties are detected at the tensor's precision.
    avg_ranks_tensor = torch.tensor(avg_ranks)
    max_val = avg_ranks_tensor.max().item()
    min_val = avg_ranks_tensor.min().item()
    max_indices = torch.nonzero(avg_ranks_tensor == max_val).flatten().tolist()
    min_indices = torch.nonzero(avg_ranks_tensor == min_val).flatten().tolist()

    report = [f"Layer {layer_idx} (Conv2d {in_ch}->{out_ch}):"]
    report.extend(f"  Filter {idx}: average rank {avg:.3f}" for idx, avg in enumerate(avg_ranks))
    report.append(f"  --> Highest avg rank {max_val:.3f} at filter(s) {max_indices}")
    report.append(f"  --> Lowest  avg rank {min_val:.3f} at filter(s) {min_indices}\n")
    print("\n".join(report))


Layer 1 (Conv2d 3->64):
  Filter 0: average rank 178.720
  Filter 1: average rank 173.460
  Filter 2: average rank 124.280
  Filter 3: average rank 127.980
  Filter 4: average rank 189.640
  Filter 5: average rank 99.580
  Filter 6: average rank 157.360
  Filter 7: average rank 93.300
  Filter 8: average rank 145.200
  Filter 9: average rank 94.200
  Filter 10: average rank 186.240
  Filter 11: average rank 177.620
  Filter 12: average rank 171.520
  Filter 13: average rank 189.420
  Filter 14: average rank 175.030
  Filter 15: average rank 183.750
  Filter 16: average rank 141.090
  Filter 17: average rank 169.850
  Filter 18: average rank 144.010
  Filter 19: average rank 141.170
  Filter 20: average rank 154.870
  Filter 21: average rank 187.820
  Filter 22: average rank 189.640
  Filter 23: average rank 147.800
  Filter 24: average rank 189.430
  Filter 25: average rank 107.600
  Filter 26: average rank 133.680
  Filter 27: average rank 154.340
  Filter 28: average rank 144.470
  F

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os

# Directory that receives the per-layer figures and the JSON dump.
OUTPUT_DIR = "hrank_output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Averages below divide by `count`; fail clearly if no images were processed.
if count == 0:
    raise RuntimeError("No images were processed; cannot compute average ranks.")

# layer name -> {"Filter_<i>": average rank rounded to 3 decimals}
results_json = {}

for layer_idx, conv in enumerate(conv_modules, start=1):
    in_ch = conv.in_channels
    out_ch = conv.out_channels
    ranks = conv_ranks_sum[layer_idx-1]
    avg_ranks = [r / count for r in ranks]

    # Save to results_json
    layer_name = f"Layer_{layer_idx}_Conv2d_{in_ch}to{out_ch}"
    results_json[layer_name] = {f"Filter_{i}": round(r, 3) for i, r in enumerate(avg_ranks)}

    # Heatmap: a single row with one cell per filter; width grows with filter count.
    fig, ax = plt.subplots(figsize=(max(6, len(avg_ranks) // 4), 2))
    sns.heatmap([avg_ranks], cmap="viridis", cbar=True, xticklabels=True,
                yticklabels=False, ax=ax)
    ax.set_title(f"Avg Rank Heatmap: {layer_name}")
    ax.set_xlabel("Filter Index")
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, f"{layer_name}_heatmap.png"))
    plt.close(fig)  # close this exact figure to avoid accumulating open figures

    # Histogram of the layer's average ranks.
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(avg_ranks, bins=10, color='skyblue', edgecolor='black')
    ax.set_title(f"Avg Rank Histogram: {layer_name}")
    ax.set_xlabel("Average Rank")
    ax.set_ylabel("Number of Filters")
    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, f"{layer_name}_histogram.png"))
    plt.close(fig)

# Save JSON file
with open(os.path.join(OUTPUT_DIR, "filter_ranks.json"), "w") as f:
    json.dump(results_json, f, indent=4)

print(f" All plots and JSON saved to '{OUTPUT_DIR}/'")

import shutil

# Bundle the hrank_output directory into hrank_output.zip; make_archive
# returns the path of the archive it wrote.
archive_path = shutil.make_archive("hrank_output", 'zip', "hrank_output")

# google.colab only exists inside Colab; elsewhere, report the archive path
# instead of failing so the notebook still runs top to bottom.
try:
    from google.colab import files
except ImportError:
    print(f"Archive written to {archive_path}")
else:
    files.download(archive_path)



✅ All plots and JSON saved to 'hrank_output/'


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>