In [1]:
import os
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from tqdm import tqdm
from collections import defaultdict

In [2]:
# Magnitude cutoff below which an activation counts as "zero" in the sparsity
# helpers. It is captured as their default `threshold` when their `def`s run,
# so changing it only takes effect after those definitions are re-executed.
threshold = 0.001

def spatial_activation_sparsity(activation, threshold=threshold):
    """Fraction of individual elements of `activation` with magnitude below `threshold`.

    Every element of the tensor counts equally, whatever its shape.

    Args:
        activation: tensor of activations (any shape).
        threshold: magnitude cutoff for treating an element as zero.

    Returns:
        Python float in [0, 1]. An empty tensor raises ZeroDivisionError.
    """
    near_zero_mask = activation.abs().lt(threshold)
    near_zero_count = near_zero_mask.sum().item()
    return near_zero_count / near_zero_mask.numel()


def channel_activation_sparsity(activation, threshold=threshold):
    """Fraction of (sample, channel) entries whose spatially-averaged activation is near zero.

    The activation is averaged over every dimension after the channel axis
    (H and W for a ``[B, C, H, W]`` feature map, L for ``[B, C, L]``). This
    gives one value per sample and channel. The result is the share of those
    B*C values whose magnitude is below ``threshold``. A ``[B, C]`` tensor is
    used as-is.

    Note: the mean is signed, so a channel whose positive and negative values
    cancel out is also counted as sparse.

    Args:
        activation: tensor shaped ``[B, C, *spatial]`` (at least 2-D).
        threshold: magnitude cutoff applied to each per-channel mean.

    Returns:
        Python float in [0, 1]: a fraction over all B*C entries, not a count.

    Raises:
        ValueError: if ``activation`` has fewer than 2 dimensions.
    """
    if activation.dim() < 2:
        raise ValueError(
            f"expected activation shaped [B, C, ...], got shape {tuple(activation.shape)}"
        )
    spatial_dims = tuple(range(2, activation.dim()))
    # Skip the reduction for [B, C] input: in PyTorch an empty `dim` tuple has
    # historically meant "reduce over all dims", which would collapse B and C.
    per_channel_mean = activation.mean(dim=spatial_dims) if spatial_dims else activation
    return (per_channel_mean.abs() < threshold).float().mean().item()


def compute_spatial_activation_sparsity(spatial_maps, threshold=threshold):
    """Per-layer fraction of spatial-map entries with magnitude below `threshold`.

    Args:
        spatial_maps: mapping of layer name -> tensor (expected ``[B, H, W]``;
            every entry is averaged, so any shape works).
        threshold: magnitude cutoff for treating an entry as zero.

    Returns:
        dict of layer name -> Python float in [0, 1], in the mapping's order.
        An empty tensor yields NaN for its layer.
    """
    return {
        layer: (layer_map.abs() < threshold).float().mean().item()
        for layer, layer_map in spatial_maps.items()
    }


def compute_channel_activation_sparsity(channel_maps, threshold=threshold):
    """Per-layer fraction of channel-map entries with magnitude below `threshold`.

    Args:
        channel_maps: mapping of layer name -> tensor (expected ``[B, C]``;
            every entry is averaged, so any shape works).
        threshold: magnitude cutoff for treating an entry as zero.

    Returns:
        dict of layer name -> Python float in [0, 1], in the mapping's order.
        An empty tensor yields NaN for its layer.
    """
    result = {}
    for layer_name, layer_map in channel_maps.items():
        is_near_zero = layer_map.abs().lt(threshold)
        result[layer_name] = is_near_zero.float().mean().item()
    return result


In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [4]:
# Evaluation data: the `test` split under `dataset_path`, read with
# torchvision's ImageFolder.
dataset_path = "../aircraft_dataset"
batch_size = 16

# Standard ImageNet channel statistics. Presumably the same normalization the
# checkpoint loaded below was trained with -- confirm against the training code.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

test_path = os.path.join(dataset_path, "test")
test_ds = ImageFolder(root=test_path, transform=transform)
# shuffle=False iterates the dataset in order, so outputs/labels collected
# batch by batch line up with test_ds.samples.
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

num_classes = len(test_ds.classes)

print(num_classes, len(test_ds))

70 1000


In [6]:
from instrumented_models.resnet_instrumented import ResNet50_InternalRepresentation

model = ResNet50_InternalRepresentation(num_classes=num_classes, pretrained=False)

# The checkpoint is a bare state dict (a mapping of parameter name -> tensor).
# weights_only=True (PyTorch >= 1.13) restricts unpickling to tensors and plain
# containers, so a tampered .pth file cannot execute arbitrary code on load.
weights = torch.load("resnet_epoch_25.pth", map_location=device, weights_only=True)
# The instrumented wrapper's state-dict keys carry a "model." prefix that the
# checkpoint's keys lack; load_state_dict is strict by default, so they must match.
weights = {"model." + k: v for k, v in weights.items()}
model.load_state_dict(weights)

model.to(device)
model.eval()

print(model.__class__.__name__)

num_params = sum(p.numel() for p in model.parameters())
print(f"model parameters: {num_params}")

ResNet50_InternalRepresentation
model parameters: 23651462


In [7]:
# Run the model over the whole test set. For each batch, collect the logits,
# the labels and the per-layer activation maps the instrumented model exposes.
all_outputs = []
all_labels = []
spatial_chunks = defaultdict(list)
channel_chunks = defaultdict(list)

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="evaluating test set"):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)

        # NOTE(review): assumes get_activation_maps() returns fresh tensors for
        # this batch, not buffers the next forward overwrites -- confirm in
        # resnet_instrumented.
        spatial_maps, channel_maps = model.get_activation_maps()

        # Walk each dict over its own keys rather than assuming both share
        # exactly the same layer names (that would raise KeyError otherwise).
        for name, fmap in spatial_maps.items():
            spatial_chunks[name].append(fmap)
        for name, fmap in channel_maps.items():
            channel_chunks[name].append(fmap)

        all_outputs.append(outputs)
        all_labels.append(labels)

# Concatenate the per-batch chunks along dim 0 (the batch axis). Build plain
# dicts so a mistyped layer name raises KeyError instead of silently yielding
# an empty list.
all_spatial_maps = {name: torch.cat(chunks, dim=0) for name, chunks in spatial_chunks.items()}
all_channel_maps = {name: torch.cat(chunks, dim=0) for name, chunks in channel_chunks.items()}

all_outputs = torch.cat(all_outputs, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# Compute both spatial and channel sparsity
spatial_sparsity = compute_spatial_activation_sparsity(all_spatial_maps)
channel_sparsity = compute_channel_activation_sparsity(all_channel_maps)

for title, sparsity in (("Spatial", spatial_sparsity), ("Channel", channel_sparsity)):
    print(f"\n{title} Activation Sparsity:")
    for k, v in sparsity.items():
        print(f"{k}:\t{v:.4f}")


evaluating test set: 100%|██████████| 63/63 [00:54<00:00,  1.16it/s]


Spatial Activation Sparsity:
conv1:	0.0054
layer1:	0.0000
layer2:	0.0000
layer3:	0.0000
layer4:	0.3247

Channel Activation Sparsity:
conv1:	0.0077
layer1:	0.0009
layer2:	0.0013
layer3:	0.0002
layer4:	0.2191



