In [1]:
# ================================================================
# VGG-16 on CIFAR-10  +  SVD-based Filter-Importance Ranking
# ================================================================
# ① install deps on Colab (uncomment the line below; not needed when running locally)
# !pip install --quiet torch torchvision tqdm

import os, math, json, random, pathlib, torch, torch.nn as nn, torch.optim as optim
import torchvision, torchvision.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm

# ----------------------------- CONFIG ---------------------------
NUM_EPOCHS        = 20
BATCH_SIZE        = 128
LR                = 0.1
DEVICE            = 'cuda' if torch.cuda.is_available() else 'cpu'
RANDOM_SEED       = 42

# output files
CKPT_PATH         = "model/vgg16_trained.pt"
SVD_JSON_PATH     = "filter_importance.json"

# reproducibility
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
# Create the checkpoint's directory from CKPT_PATH itself, so editing the path
# above never leaves torch.save() pointing at a folder that does not exist.
pathlib.Path(CKPT_PATH).parent.mkdir(parents=True, exist_ok=True)

# -------------------------- DATASET -----------------------------
# Per-channel normalisation shared by the train and test pipelines.
# NOTE(review): these std values are the ones commonly copied into CIFAR-10
# training scripts; the dataset's per-pixel std is usually quoted as roughly
# (0.247, 0.243, 0.261). Kept as-is so results match the recorded run — confirm
# before changing.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD  = (0.2023, 0.1994, 0.2010)

transform_train = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),   # pad-and-crop augmentation at the native 32x32 size
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
transform_test  = T.Compose([
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,  download=True, transform=transform_train)
testset  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Pinned host memory only helps host->GPU copies; skip it on CPU-only runs.
PIN_MEMORY = DEVICE == 'cuda'
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(testset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)

# --------------------------- MODEL ------------------------------
def make_vgg16(num_classes=10):
    """
    Build an untrained VGG-16 (with BatchNorm) for 32x32 CIFAR images.

    Starts from torchvision's ``vgg16_bn`` with random weights, then
      • swaps in a fresh first conv layer (3 -> 64 channels, 3x3, padding=1),
      • replaces the average pool with a 1x1 adaptive average pool,
      • replaces the classifier with a small 512 -> 512 -> 512 -> num_classes MLP.

    Args:
        num_classes: width of the final linear layer (10 for CIFAR-10).

    Returns:
        The modified torchvision VGG module.
    """
    try:
        # torchvision >= 0.13: `weights=None` replaces the deprecated `pretrained=False`
        vgg = torchvision.models.vgg16_bn(weights=None)
    except TypeError:
        # older torchvision that does not accept the `weights` keyword
        vgg = torchvision.models.vgg16_bn(pretrained=False)
    # NOTE(review): torchvision's vgg16_bn already starts with a 3->64, 3x3,
    # padding=1 conv, so this swap appears to only re-initialise that layer
    # (PyTorch default init instead of VGG's) — confirm that is intended.
    vgg.features[0] = nn.Conv2d(3, 64, kernel_size=3, padding=1)
    vgg.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    # Flatten yields 512 features because of the 1x1 pool above.
    vgg.classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(512, 512), nn.ReLU(True), nn.Dropout(),
        nn.Linear(512, 512), nn.ReLU(True), nn.Dropout(),
        nn.Linear(512, num_classes)
    )
    return vgg

# Instantiate the network, then place its parameters and buffers on DEVICE.
model = make_vgg16()
model = model.to(DEVICE)

# ---------------------- TRAIN / EVAL ----------------------------
def accuracy(net, loader, device=None):
    """
    Return the top-1 accuracy of *net* over *loader*, in percent.

    Args:
        net: classifier; ``net(x).argmax(1)`` is taken as the predicted class.
        loader: iterable of ``(inputs, labels)`` batches.
        device: where each batch is moved before the forward pass. Defaults to
            the device of the model's first parameter (``DEVICE`` if it has none).

    Returns:
        ``100 * correct / total`` as a float, or ``nan`` when *loader* yields
        no samples (instead of raising ZeroDivisionError).

    The model's train/eval mode is restored before returning, so calling this
    mid-training does not silently leave dropout/BatchNorm in eval mode.
    """
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else DEVICE

    was_training = net.training
    net.eval()   # disable dropout, use BatchNorm running statistics
    correct = total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                preds = net(x).argmax(1)
                correct += (preds == y).sum().item()
                total   += y.size(0)
    finally:
        net.train(was_training)

    if total == 0:
        return float("nan")
    return 100. * correct / total

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
# Divide the LR by 10 at 50 % and at 75 % of the run (milestones are in epochs).
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[NUM_EPOCHS//2, int(NUM_EPOCHS*0.75)], gamma=0.1)

print(f"⏳ Training VGG-16 for {NUM_EPOCHS} epochs …")
for epoch in range(NUM_EPOCHS):
    model.train()   # dropout on, BatchNorm uses batch statistics
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS}')
    for imgs, lbls in pbar:
        imgs, lbls = imgs.to(DEVICE), lbls.to(DEVICE)
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(imgs), lbls)
        loss.backward()
        optimizer.step()
        pbar.set_postfix({'loss': f'{loss.item():.3f}'})   # last-batch loss only
    scheduler.step()   # once per epoch, matching the epoch-based milestones

torch.save(model.state_dict(), CKPT_PATH)
test_acc = accuracy(model, test_loader)
print(f'\n✅ Finished. Test accuracy after {NUM_EPOCHS} epochs: **{test_acc:.2f} %** (checkpoint saved to {CKPT_PATH})')

# --------------- SVD-BASED FILTER IMPORTANCE (all filters) -------
@torch.no_grad()
def svd_rank_filters(net, json_path=SVD_JSON_PATH, preview_top_k=1):
    """
    Score every Conv2d filter in *net* by the nuclear norm of its weights.

    For each Conv2d layer, in ``net.modules()`` order:
      • view each filter as a matrix of shape (weight.size(1), k_h * k_w)
      • compute its singular values (one batched call per layer)
      • importance score = sum of singular values (nuclear norm)

    All filters are written to *json_path* as a list of
    ``{"layer", "filter_index", "svd_score"}`` records, where "layer" is the
    ordinal among Conv2d modules. The top *preview_top_k* filters of each
    layer are printed for a quick look.

    Args:
        net: model whose Conv2d layers are scored.
        json_path: destination of the full JSON list.
        preview_top_k: number of highest-scoring filters printed per layer.

    Returns:
        The list of records that was written to *json_path*.
    """
    conv_layers = [m for m in net.modules() if isinstance(m, nn.Conv2d)]
    importance_records = []

    print("\nFilter-importance preview (top-{} per layer):".format(preview_top_k))
    for layer_idx, layer in enumerate(conv_layers):
        # Copy the whole weight to CPU once per layer rather than once per filter.
        weight = layer.weight.detach().cpu()
        # (out_c, in_c', kH, kW) -> (out_c, in_c', kH*kW): one matrix per filter,
        # so a single batched svdvals call replaces the per-filter Python loop.
        per_filter = weight.flatten(start_dim=2)
        scores = torch.linalg.svdvals(per_filter).sum(dim=-1).tolist()

        importance_records.extend(
            {"layer": layer_idx, "filter_index": f_idx, "svd_score": score}
            for f_idx, score in enumerate(scores)
        )

        # console preview: highest nuclear norm first (stable for ties)
        ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
        for rank, (f_idx, score) in enumerate(ranked[:preview_top_k], start=1):
            print(f"  Layer {layer_idx:2d}  Rank {rank} ➜ filter {f_idx:3d}   Σσ = {score:.4f}")

    with open(json_path, "w", encoding="utf-8") as fp:
        json.dump(importance_records, fp, indent=2)
    print(f"\nFull JSON written to: {json_path}")
    return importance_records


# Score all conv filters of the trained model and dump them to SVD_JSON_PATH.
svd_rank_filters(net=model)



100%|██████████| 170M/170M [00:03<00:00, 43.9MB/s]


⏳ Training VGG-16 for 20 epochs …


Epoch 1/20: 100%|██████████| 391/391 [00:30<00:00, 12.80it/s, loss=1.846]
Epoch 2/20: 100%|██████████| 391/391 [00:26<00:00, 14.76it/s, loss=1.411]
Epoch 3/20: 100%|██████████| 391/391 [00:26<00:00, 14.73it/s, loss=1.485]
Epoch 4/20: 100%|██████████| 391/391 [00:26<00:00, 14.95it/s, loss=1.238]
Epoch 5/20: 100%|██████████| 391/391 [00:26<00:00, 14.99it/s, loss=1.071]
Epoch 6/20: 100%|██████████| 391/391 [00:26<00:00, 14.95it/s, loss=1.000]
Epoch 7/20: 100%|██████████| 391/391 [00:25<00:00, 15.28it/s, loss=1.085]
Epoch 8/20: 100%|██████████| 391/391 [00:25<00:00, 15.15it/s, loss=0.833]
Epoch 9/20: 100%|██████████| 391/391 [00:25<00:00, 15.19it/s, loss=0.734]
Epoch 10/20: 100%|██████████| 391/391 [00:26<00:00, 14.76it/s, loss=0.933]
Epoch 11/20: 100%|██████████| 391/391 [00:27<00:00, 14.37it/s, loss=0.348]
Epoch 12/20: 100%|██████████| 391/391 [00:26<00:00, 14.99it/s, loss=0.451]
Epoch 13/20: 100%|██████████| 391/391 [00:25<00:00, 15.05it/s, loss=0.379]
Epoch 14/20: 100%|██████████| 391/


✅ Finished. Test accuracy after 20 epochs: **88.09 %** (checkpoint saved to model/vgg16_trained.pt)

📊 Per-layer most-important filter (highest Σσ):
  Layer  0 ➜ filter  18  (Σσ = 3.5622)
  Layer  1 ➜ filter  61  (Σσ = 4.5687)
  Layer  2 ➜ filter 124  (Σσ = 4.0786)
  Layer  3 ➜ filter  14  (Σσ = 4.7922)
  Layer  4 ➜ filter  48  (Σσ = 4.2364)
  Layer  5 ➜ filter 224  (Σσ = 3.1613)
  Layer  6 ➜ filter  27  (Σσ = 3.3195)
  Layer  7 ➜ filter 195  (Σσ = 1.5316)
  Layer  8 ➜ filter 296  (Σσ = 1.2354)
  Layer  9 ➜ filter  49  (Σσ = 1.0492)
  Layer 10 ➜ filter 247  (Σσ = 0.9704)
  Layer 11 ➜ filter  80  (Σσ = 1.0230)
  Layer 12 ➜ filter 311  (Σσ = 1.7227)

📝 Full JSON written to: filter_importance.json


In [2]:
# --------------- SVD-BASED FILTER IMPORTANCE (all filters) -------
@torch.no_grad()
def svd_rank_filters(net, json_path=SVD_JSON_PATH, preview_top_k=1):
    """
    Score every Conv2d filter in *net* by the nuclear norm of its weights.

    For each Conv2d layer, in ``net.modules()`` order:
      • view each filter as a matrix of shape (weight.size(1), k_h * k_w)
      • compute its singular values (one batched call per layer)
      • importance score = sum of singular values (nuclear norm)

    All filters are written to *json_path* as a list of
    ``{"layer", "filter_index", "svd_score"}`` records, where "layer" is the
    ordinal among Conv2d modules. The top *preview_top_k* filters of each
    layer are printed for a quick look.

    Args:
        net: model whose Conv2d layers are scored.
        json_path: destination of the full JSON list.
        preview_top_k: number of highest-scoring filters printed per layer.

    Returns:
        The list of records that was written to *json_path*.
    """
    conv_layers = [m for m in net.modules() if isinstance(m, nn.Conv2d)]
    importance_records = []

    print("\nFilter-importance preview (top-{} per layer):".format(preview_top_k))
    for layer_idx, layer in enumerate(conv_layers):
        # Copy the whole weight to CPU once per layer rather than once per filter.
        weight = layer.weight.detach().cpu()
        # (out_c, in_c', kH, kW) -> (out_c, in_c', kH*kW): one matrix per filter,
        # so a single batched svdvals call replaces the per-filter Python loop.
        per_filter = weight.flatten(start_dim=2)
        scores = torch.linalg.svdvals(per_filter).sum(dim=-1).tolist()

        importance_records.extend(
            {"layer": layer_idx, "filter_index": f_idx, "svd_score": score}
            for f_idx, score in enumerate(scores)
        )

        # console preview: highest nuclear norm first (stable for ties)
        ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
        for rank, (f_idx, score) in enumerate(ranked[:preview_top_k], start=1):
            print(f"  Layer {layer_idx:2d}  Rank {rank} ➜ filter {f_idx:3d}   Σσ = {score:.4f}")

    with open(json_path, "w", encoding="utf-8") as fp:
        json.dump(importance_records, fp, indent=2)
    print(f"\nFull JSON written to: {json_path}")
    return importance_records


# Re-run the filter ranking with the (re)defined function on the trained model.
svd_rank_filters(net=model)



Filter-importance preview (top-1 per layer):
  Layer  0  Rank 1 ➜ filter  18   Σσ = 3.5622
  Layer  1  Rank 1 ➜ filter  61   Σσ = 4.5687
  Layer  2  Rank 1 ➜ filter 124   Σσ = 4.0786
  Layer  3  Rank 1 ➜ filter  14   Σσ = 4.7922
  Layer  4  Rank 1 ➜ filter  48   Σσ = 4.2364
  Layer  5  Rank 1 ➜ filter 224   Σσ = 3.1613
  Layer  6  Rank 1 ➜ filter  27   Σσ = 3.3195
  Layer  7  Rank 1 ➜ filter 195   Σσ = 1.5316
  Layer  8  Rank 1 ➜ filter 296   Σσ = 1.2354
  Layer  9  Rank 1 ➜ filter  49   Σσ = 1.0492
  Layer 10  Rank 1 ➜ filter 247   Σσ = 0.9704
  Layer 11  Rank 1 ➜ filter  80   Σσ = 1.0230
  Layer 12  Rank 1 ➜ filter 311   Σσ = 1.7227

Full JSON written to: filter_importance.json
