In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models 
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
max_batches = 1 # if set to 0 block wont run so when make 1 have an correct answer (if not)
# DEVICE
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ------------------- 1. RECONSTRUCT THE MODEL ---------------------
model = models.vgg16_bn(pretrained=False)
model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
model.classifier = nn.Sequential(
    nn.Flatten(),
    nn.Linear(512, 512),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(512, 512),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(512, 10)
)
model.to(DEVICE)

# ------------------- 2. LOAD STATE DICT ---------------------
state_dict = torch.load("model/baseline_vgg16.pt", map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()

# ------------------- 3. LOAD TEST DATA ----------------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)

# ------------------- 4. EVALUATE ACCURACY ----------------------
def accuracy(model, loader):
    correct = total = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            if i >= max_batches:
                print(images.device)

                break  # Stop after 10 batches
            
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            preds = model(images).argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100. * correct / total

acc = accuracy(model, test_loader)
print(f"Accuracy: {acc:.2f}%")
print(next(model.parameters()).device)




  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified
cpu
Accuracy: 78.12%
cuda:0


In [2]:
# ---------- rank-all-conv-filters-layer_fmt.py ----------
import torch, json
from collections import defaultdict
from tqdm import tqdm

###############################################################################
# Config – tweak to taste
###############################################################################
BATCH_LIMIT   = 5                  # mini-batches to analyse
JSON_OUTFILE  = "filter_ranks.json"
DEVICE        = next(model.parameters()).device
LOADER        = test_loader        # or whichever DataLoader you want
###############################################################################

###############################################################################
# 1. Prepare accumulators and hooks
###############################################################################
layer_stats  = defaultdict(lambda: {"sum": None, "count": 0})
layer_order  = []                  # keeps the order convs are visited

def make_hook(layer_key: str):
    def hook(_module, _input, output):
        out = output.detach().cpu()
        N, C = out.shape[:2]

        # rank per (sample, filter)
        ranks = torch.empty(N, C, dtype=torch.float32)
        for n in range(N):
            for c in range(C):
                ranks[n, c] = torch.linalg.matrix_rank(out[n, c]).item()

        batch_sum = ranks.sum(dim=0)

        acc = layer_stats[layer_key]
        acc["sum"]   = batch_sum if acc["sum"] is None else acc["sum"] + batch_sum
        acc["count"] += N
    return hook

# register hooks on every Conv2d -- capture insertion order
handles = []
for name, m in model.named_modules():
    if isinstance(m, torch.nn.Conv2d):
        layer_order.append(name)               # remember vis-order
        handles.append(m.register_forward_hook(make_hook(name)))

###############################################################################
# 2. Push a few batches through the net
###############################################################################
model.eval()
with torch.no_grad():
    for i, (x, _) in enumerate(tqdm(LOADER, total=BATCH_LIMIT, desc="Analysing")):
        if i >= BATCH_LIMIT:
            break
        model(x.to(DEVICE, non_blocking=True))

###############################################################################
# 3. Build JSON in the format your reader wants
###############################################################################
json_dict = {}
for idx, key in enumerate(layer_order):
    stat = layer_stats[key]
    avg  = (stat["sum"] / stat["count"]).tolist()     # list of floats
    json_dict[f"layer_{idx}"] = {"avg_ranks": avg}    # <- EXACT field name

with open(JSON_OUTFILE, "w") as f:
    json.dump(json_dict, f, indent=2)
print(f"✓ JSON written to {JSON_OUTFILE}")

###############################################################################
# 4. Clean-up hooks
###############################################################################
for h in handles:
    h.remove()
# ------------------------------------------------


Analysing: 100%|██████████| 5/5 [02:10<00:00, 26.05s/it]

✓ JSON written to filter_ranks.json



