In [1]:
from models import MVCNN
from tools import ImgDataset
import torch
import numpy as np
import csv
from tqdm import tqdm

In [2]:
# Pick the fastest available backend: Apple MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Rebuild the single-view CNN and restore its trained weights from disk.
CHECKPOINT_PATH = '../test_results/MVCNN/trained-models/MVCNN/model-00030.pth'

model = MVCNN.SVCNN('svcnn')
weights = torch.load(CHECKPOINT_PATH, map_location=device)  # load tensors directly onto `device`
model.load_state_dict(weights)
# .to() and .eval() both return the module itself, so this chained call moves the
# model, switches it to inference mode, and ends the cell with the model so
# Jupyter renders its architecture.
model.to(device).eval()

SVCNN(
  (net_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  

In [4]:
# Test split of the multi-view renders (12 views per object, per the directory name).
# Order is kept fixed and augmentation flags are off so evaluation is repeatable.
TEST_ROOT_DIR = '../searching-algorithm/ModelNet40-12View/*/test'

test_dataset = ImgDataset.MultiviewImgDataset(
    root_dir=TEST_ROOT_DIR,
    scale_aug=False,
    rot_aug=False,
    test_mode=True,
    num_models=0,   # NOTE(review): presumably 0 = no cap on the number of models — confirm in ImgDataset
    num_views=12,
    shuffle=False,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,     # one object per batch, so a batch holds all views of that single model
    shuffle=False,
    num_workers=1,
    pin_memory=True,  # NOTE(review): pinning only speeds up host->CUDA copies; confirm it is wanted on mps/cpu
)

In [None]:
# Split the SVCNN into its two stages so views can be encoded separately and
# pooled before classification: net_1 is the convolutional trunk shown in the
# repr above; net_2 is applied below to the flattened, pooled features.
# .eval() returns the module itself, so each stage is put in inference mode as it is bound.
feature_extractor = model.net_1.eval()
classifier = model.net_2.eval()

num_views = 12  # views rendered per object

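The views arrive in the following order (position in the batch → rendered view number). This appears to be a lexicographic (string) sort of the view numbers rather than numeric order. As a result, the `view` index used in the loop below is a position in this list, not necessarily the view's own number:

0, 1, 10, 11, 2, 3, 4, 5, 6, 7, 8, 9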

In [None]:
# Accumulators for the leave-one-view-out experiment: per-view drop totals and
# the running count of correct all-view predictions.
ranking, sum_baseline = {}, 0

In [None]:
# Leave-one-view-out importance.
# For every test object, compare the max-pooled multi-view prediction using all
# views (the baseline) with the prediction after removing a single view.
# ranking[v] accumulates, over the test set, (baseline correct) - (correct without view v):
# positive => removing view v hurt accuracy, negative => removing it helped.
ranking = {}
sum_baseline = 0  # number of objects the all-view baseline classifies correctly


def _predict(view_feats):
    """Max-pool per-view CNN features across views (dim 0), classify, and return the predicted class index (shape [1])."""
    pooled = torch.max(view_feats, dim=0)[0].unsqueeze(0)  # [1, C, H, W]
    flattened = pooled.view(pooled.size(0), -1)            # [1, C*H*W]
    return torch.max(classifier(flattened), 1)[1]


# no_grad covers the feature extractor as well as the classifier, so no autograd
# graph is built for any of the CNN forward passes.
with torch.no_grad():
    progress = tqdm(test_loader, desc="Batches")
    for batch_idx, (target, data, _) in enumerate(progress):
        # squeeze(0) below assumes one object per batch (batch_size=1 in the DataLoader).
        assert data.size(0) == 1, "view-importance loop expects batch_size=1"
        data = data.to(device)        # [1, V, 3, H, W]
        target = target.to(device)
        all_views = data.squeeze(0)   # [V, 3, H, W]

        # Run the CNN once per object and reuse the per-view features for every
        # ablation. This is equivalent to re-running it on the remaining views as
        # long as feature_extractor processes each view independently (true for the
        # conv/ReLU/max-pool stack printed above, in eval mode).
        all_feats = feature_extractor(all_views)  # [V, C, H', W']

        baseline_acc = (_predict(all_feats) == target).float().item()
        sum_baseline += baseline_acc

        # Number of views is taken from the data rather than hard-coded.
        for view in range(all_feats.size(0)):
            remaining = torch.cat([all_feats[:view], all_feats[view + 1:]], dim=0)
            acc = (_predict(remaining) == target).float().item()
            ranking[view] = ranking.get(view, 0.0) + (baseline_acc - acc)

        # Running baseline accuracy, shown on the progress bar instead of printed lines.
        progress.set_postfix(baseline_acc=f"{sum_baseline / (batch_idx + 1):.4f}", refresh=False)

In [143]:
# Accumulated per-view drops from the leave-one-view-out loop, hard-coded here
# (presumably copied from an earlier full run so the expensive loop can be skipped).
# NOTE(review): this replaces whatever the loop above just computed.
# Interpretation: the more negative the value, the less important the view.
ranking = dict([
    (0, -5.0),
    (10, -7.0),
    (6, -12.0),
    (7, -14.0),
    (1, -19.0),
    (3, -19.0),   # ties with view 1; insertion order decides their relative rank after the stable sort
    (9, -21.0),
    (5, -23.0),
    (8, -25.0),
    (2, -27.0),
    (4, -30.0),
    (11, -33.0),
])

In [144]:
# Normalize the accumulated drops into importance scores in (0, 1].
# The reciprocal scheme below is only meaningful when every drop is strictly
# negative:
#   - a zero drop (or a zero maximum) divides by zero;
#   - a positive drop flips the sign and therefore the ordering;
#   - re-running this cell on already-normalized (positive) scores would corrupt them.
# Reject those inputs explicitly instead of producing a wrong ranking.
if not ranking or any(drop >= 0 for drop in ranking.values()):
    raise ValueError(
        "importance normalization expects a non-empty ranking of strictly negative "
        f"accumulated drops (run the drop computation first); got {ranking}"
    )

# `norm` is the least-negative drop. Each score is |norm| / |drop|, so that view
# scores 1.0 and more negative (less important) views score proportionally less.
norm = max(ranking.values())
# Build a new dict (same key order) rather than mutating the one being iterated.
ranking = {view: -1 / (abs(drop) / norm) for view, drop in ranking.items()}

In [145]:
# Reorder views by normalized score, highest (most important) first. Python's sort
# stays stable with reverse=True, so tied views keep their existing relative order.
ranking = {view: ranking[view] for view in sorted(ranking, key=ranking.get, reverse=True)}

In [146]:
# Display the normalized importance score per view index, sorted highest score first.
ranking

{0: 1.0,
 10: 0.7142857142857143,
 6: 0.4166666666666667,
 7: 0.35714285714285715,
 1: 0.2631578947368421,
 3: 0.2631578947368421,
 9: 0.23809523809523808,
 5: 0.2173913043478261,
 8: 0.2,
 2: 0.18518518518518517,
 4: 0.16666666666666666,
 11: 0.15151515151515152}