In [2]:
import torch
from tqdm import tqdm
# add path
import sys
sys.path.append('../../../MVCNN/')
from models import MVCNN
from tools import ImgDataset
import numpy as np
import torch.nn.functional as F
from typing import Optional

# Calculating Importance

## Initialization

In [2]:
# Pick the compute backend: Apple-silicon MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Build the single-view CNN and restore the trained checkpoint onto `device`.
model = MVCNN.SVCNN('mvcnn')
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code while being loaded.
# NOTE(review): assumes the .pth file holds a plain state_dict — confirm.
weights = torch.load(
    '../../../MVCNN/MVCNN/model-mvcnn-00050.pth',
    map_location=device,
    weights_only=True,
)
model.load_state_dict(weights)
model.to(device)
# Switch to evaluation mode; as the last expression this also displays the module.
model.eval()

SVCNN(
  (net_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  

In [4]:
# Split the network into its two sub-modules and put each in evaluation mode
# (Module.eval() returns the module itself, so the assignment keeps the same object).
feature_extractor = model.net_1.eval()  # convolutional feature extractor
classifier = model.net_2.eval()

# Dataset constants shared by the cells below.
num_views = 12  # Number of views per model
num_classes = 33

## Accuracy Calculator

In [5]:
# Notebook-wide lookup: position along the view axis -> semantic view label.
# The sequence is 0..11 sorted as strings ("0", "1", "10", "11", "2", ...);
# presumably this mirrors how the view image files sort on disk — TODO confirm.
view_order = [int(name) for name in sorted(str(v) for v in range(12))]

In [6]:
def validate_model(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    num_classes: int,
    device: torch.device,
    single_view: bool = False,
    view_idx: int = 0,
    drop_view_label: Optional[int] = None,
    view_labels: Optional[list] = None,
):
    """
    Evaluate ``model`` on multi-view batches and print/return its accuracy.

    Every batch from ``loader`` is read as ``data[0]`` = labels ``(N,)`` and
    ``data[1]`` = views ``(N, V, C, H, W)``.

    Modes
    - single_view=False: each remaining view is classified on its own and the
      per-object prediction is the majority vote over its views. Ties go to
      the lowest class index (first maximum of ``np.bincount``).
    - single_view=True: only ``views[:, view_idx]`` is classified.

    Args:
        model: network called on a ``(B, C, H, W)`` batch, returning logits.
        loader: iterable of batches as described above.
        num_classes: length of the per-class tally arrays.
        device: device the labels and views are moved to before the forward pass.
        single_view: select single-view mode instead of majority voting.
        view_idx: position along the view axis used in single-view mode. It is
            applied *after* any drop and is ignored when ``single_view`` is False.
        drop_view_label: the *label* of the view to drop (e.g. 5 → slice_idx 7
            with the notebook's ``view_order``); if None, no views are dropped.
        view_labels: optional label-per-position list used to translate
            ``drop_view_label`` into a slice index. Defaults to the
            notebook-level ``view_order``.

    Returns:
        ``(overall_acc, mean_cls_acc)``: fraction of correctly classified
        objects, and the mean of per-class accuracies over the classes that
        occurred at least once.

    Raises:
        AssertionError: if ``drop_view_label`` is not a known label, or a
            slice / view index falls outside the view axis.
        ValueError: if ``loader`` yields no samples.
    """

    # map from semantic label → tensor-slice index
    if drop_view_label is not None:
        order = view_order if view_labels is None else list(view_labels)
        assert drop_view_label in order, f"{drop_view_label=} not in view_order"
        drop_slice = order.index(drop_view_label)
    else:
        drop_slice = None

    model.eval()
    total_correct = total_samples = 0
    total_loss = 0.0
    wrong_per_class = np.zeros(num_classes, dtype=int)
    samples_per_class = np.zeros(num_classes, dtype=int)

    # The context manager closes the progress bar even if a batch raises.
    with tqdm(loader, desc="Validating", unit="batch", leave=False, dynamic_ncols=True) as pbar, torch.no_grad():
        for batch_i, data in enumerate(pbar):
            labels = data[0].to(device)
            views = data[1].to(device)              # (N, V, C, H, W)
            N, V, C, H, W = views.shape

            # drop the semantic view
            if drop_slice is not None:
                assert 0 <= drop_slice < V
                keep = [i for i in range(V) if i != drop_slice]
                views = views[:, keep]
                V -= 1

            if single_view:
                assert 0 <= view_idx < V
                inputs = views[:, view_idx]                # (N, C, H, W)
                targets = labels                           # (N,)
            else:
                inputs = views.reshape(-1, C, H, W)        # (N*V, C, H, W)
                targets = labels.repeat_interleave(V, 0)   # (N*V,)

            logits = model(inputs)
            total_loss += F.cross_entropy(logits, targets).item()

            # One device → host transfer per batch instead of one per sample.
            view_preds = logits.argmax(1).cpu().numpy()
            gts = labels.cpu().numpy()

            if single_view:
                final = view_preds
            else:
                # Majority vote across the views of each object.
                final = np.array([np.bincount(row).argmax() for row in view_preds.reshape(N, V)])

            # np.add.at accumulates correctly when a class repeats within the batch.
            np.add.at(samples_per_class, gts, 1)
            np.add.at(wrong_per_class, gts[final != gts], 1)

            batch_correct = (final == gts).sum()
            total_correct += batch_correct
            total_samples += N

            acc = batch_correct / N
            avg_loss = total_loss / (batch_i + 1)   # mean of per-batch losses
            pbar.set_postfix(acc=f"{acc:.4f}", loss=f"{avg_loss:.4f}")

    if total_samples == 0:
        raise ValueError("loader yielded no samples; cannot compute accuracy")

    overall_acc = total_correct / total_samples
    per_cls_acc = (samples_per_class - wrong_per_class) / np.maximum(samples_per_class, 1)
    # Classes that never occurred are left out of the mean.
    mean_cls_acc = per_cls_acc[samples_per_class > 0].mean()

    drop_msg = f", dropped view {drop_view_label}" if drop_view_label is not None else ""
    mode     = "single-view" if single_view else "full-mvcnn"
    print(f"\n[{mode}{drop_msg}] Overall Acc: {overall_acc:.4f}   Mean Class Acc: {mean_cls_acc:.4f}")

    return overall_acc, mean_cls_acc

## Dataset Initializer

### MVCNN Dataset

In [14]:
# Multi-view test set: each item groups all rendered views of one object.
test_dataset_mvcnn = ImgDataset.MultiviewImgDataset(
    root_dir='../../../MVCNN/ModelNet40-12View/*/test',
    scale_aug=False,
    rot_aug=False,
    test_mode=True,
    num_models=0,         # presumably "no cap on models per class" — TODO confirm in ImgDataset
    num_views=num_views,  # reuse the notebook constant instead of repeating 12
)
test_loader_mvcnn = torch.utils.data.DataLoader(
    test_dataset_mvcnn,
    batch_size=8,
    shuffle=False,        # keep a fixed evaluation order
    num_workers=4,
    # Pinned host memory only helps CUDA transfers; on other backends
    # DataLoader just emits a warning and ignores the flag.
    pin_memory=(device == 'cuda'),
)

### SVCNN Dataset

In [8]:
# Single-image test set built from the same directory tree as the multi-view set.
test_dataset_svcnn = ImgDataset.SingleImgDataset(
    root_dir='../../../MVCNN/ModelNet40-12View/*/test',
    scale_aug=False,
    rot_aug=False,
    test_mode=True,
    num_models=0,         # presumably "no cap on models per class" — TODO confirm in ImgDataset
)
test_loader_svcnn = torch.utils.data.DataLoader(
    test_dataset_svcnn,
    batch_size=32,
    shuffle=False,        # keep a fixed evaluation order
    num_workers=4,
    # Pinned host memory only helps CUDA transfers; on other backends
    # DataLoader just emits a warning and ignores the flag.
    pin_memory=(device == 'cuda'),
)

## Baseline Accuracy

In [9]:
# Baseline: majority vote over every view, nothing dropped.
# Batches go to the same `device` the model was moved to, rather than a
# hard-coded 'mps' that fails on machines without that backend.
validate_model(
    model=model,
    loader=test_loader_mvcnn,
    num_classes=num_classes,
    device=torch.device(device),
    single_view=False,
    drop_view_label=None,
)

                                                                                         


[full-mvcnn] Overall Acc: 0.8919   Mean Class Acc: 0.8494




(np.float64(0.8919354838709678), np.float64(0.8493939393939394))

## Accuracy Of Each View

In [17]:
# Evaluate every view on its own. `view_idx` is a position along the view axis
# of each batch, NOT a semantic view label: position i holds label view_order[i],
# so the printed lines appear in view_order order.
# Results are stored under the semantic label so they can be tabulated
# without copying numbers from the printed output.
single_view_results = {}
for view_idx in range(num_views):
    single_view_results[view_order[view_idx]] = validate_model(
        model=model,
        loader=test_loader_mvcnn,
        single_view=True,
        view_idx=view_idx,
        drop_view_label=None,
        device=torch.device(device),  # same device the model lives on
        num_classes=num_classes
    )

                                                                                         


[single-view] Overall Acc: 0.7505   Mean Class Acc: 0.6958


                                                                                         


[single-view] Overall Acc: 0.7984   Mean Class Acc: 0.7518


                                                                                         


[single-view] Overall Acc: 0.8371   Mean Class Acc: 0.7894


                                                                                         


[single-view] Overall Acc: 0.8312   Mean Class Acc: 0.7764


                                                                                         


[single-view] Overall Acc: 0.8000   Mean Class Acc: 0.7636


                                                                                         


[single-view] Overall Acc: 0.7333   Mean Class Acc: 0.7042


                                                                                         


[single-view] Overall Acc: 0.8070   Mean Class Acc: 0.7688


                                                                                         


[single-view] Overall Acc: 0.8086   Mean Class Acc: 0.7479


                                                                                         


[single-view] Overall Acc: 0.7844   Mean Class Acc: 0.7221


                                                                                         


[single-view] Overall Acc: 0.8323   Mean Class Acc: 0.7794


                                                                                         


[single-view] Overall Acc: 0.8403   Mean Class Acc: 0.7924


                                                                                         


[single-view] Overall Acc: 0.8011   Mean Class Acc: 0.7776




## Accuracy While Removing Views

#### Use Multi View Dataset

In [None]:
# Leave-one-view-out: drop each semantic view label in turn and take a majority
# vote over the remaining views. validate_model translates the label into a
# position on the view axis itself, so labels (not positions) are passed here.
# `view_idx` is not passed because it is only read in single-view mode.
drop_view_results = {}
for view_label in sorted(view_order):
    drop_view_results[view_label] = validate_model(
        model=model,
        loader=test_loader_mvcnn,
        single_view=False,
        drop_view_label=view_label,
        device=torch.device(device),  # same device the model lives on
        num_classes=num_classes
    )

                                                                                         


[full-mvcnn, dropped view 0] Overall Acc: 0.8919   Mean Class Acc: 0.8458


                                                                                         


[full-mvcnn, dropped view 1] Overall Acc: 0.8919   Mean Class Acc: 0.8506


                                                                                         


[full-mvcnn, dropped view 10] Overall Acc: 0.8855   Mean Class Acc: 0.8373


                                                                                         


[full-mvcnn, dropped view 11] Overall Acc: 0.8882   Mean Class Acc: 0.8412


                                                                                         


[full-mvcnn, dropped view 2] Overall Acc: 0.8887   Mean Class Acc: 0.8427


                                                                                         


[full-mvcnn, dropped view 3] Overall Acc: 0.8898   Mean Class Acc: 0.8433


                                                                                         


[full-mvcnn, dropped view 4] Overall Acc: 0.8919   Mean Class Acc: 0.8458


                                                                                         


[full-mvcnn, dropped view 5] Overall Acc: 0.8946   Mean Class Acc: 0.8497


                                                                                         


[full-mvcnn, dropped view 6] Overall Acc: 0.8930   Mean Class Acc: 0.8512


                                                                                         


[full-mvcnn, dropped view 7] Overall Acc: 0.8925   Mean Class Acc: 0.8497


                                                                                         


[full-mvcnn, dropped view 8] Overall Acc: 0.8909   Mean Class Acc: 0.8439


                                                                                         


[full-mvcnn, dropped view 9] Overall Acc: 0.8887   Mean Class Acc: 0.8403




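All accuracies in this table are the **Mean Class Acc** values printed above. **Delta Per View** = baseline − accuracy of the view alone. **Delta Drop View** = absolute value of (baseline − accuracy when the view is removed); only the magnitude is kept, and for views 1, 5, 6 and 7 accuracy actually rose slightly when the view was removed. **Importance** = Delta Drop View ÷ Delta Per View. **Importance Norm** = Importance ÷ sum of Importance over all views.

View	| Baseline Accuracy |	Accuracy of View	| Accuracy when View Removed |	Delta Per View	| Delta Drop View | 	Importance	| Importance Norm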
-------|-------------------|------------------|-----------------------------|------------------|-----------------|-------------|----------------
0	|84.94%	|69.58%	|84.58%	|0.1536000	|0.0036000	|0.0234375	|0.0308751
1	|84.94%	|75.18%	|85.06%	|0.0976000	|0.0012000	|0.0122951	|0.0161968
2|	84.94%	|76.36%	|84.27%	|0.0858000	|0.0067000	|0.0780886	|0.1028691
3|	84.94%	|70.42%	|84.33%	|0.1452000	|0.0061000	|0.0420110	|0.0553428
4|	84.94%	|76.88%	|84.58%	|0.0806000	|0.0036000	|0.0446650	|0.0588390
5|	84.94%	|74.79%	|84.97%	|0.1015000	|0.0003000	|0.0029557	|0.0038936
6|	84.94%	|72.21%	|85.12%	|0.1273000	|0.0018000	|0.0141398	|0.0186269
7|	84.94%	|77.94%	|84.97%	|0.0700000	|0.0003000	|0.0042857	|0.0056457
8|	84.94%	|79.24%	|84.39%	|0.0570000	|0.0055000	|0.0964912	|0.1271117
9|	84.94%	|77.76%	|84.03%	|0.0718000	|0.0091000	|0.1267409	|0.1669608
10|	84.94%	|78.94%	|83.73%	|0.0600000	|0.0121000	|0.2016667	|0.2656634
11|	84.94%	|77.64%	|84.12%	|0.0730000	|0.0082000	|0.1123288	|0.1479751

In [8]:
# Normalized per-view importance, copied from the "Importance Norm" column of
# the table above. Entry k belongs to semantic view label k (table row order).
# NOTE(review): loader batches order their view axis by `view_order`, not by
# label; reindex before applying these weights along that axis — confirm
# against downstream use.
# Created on the notebook's `device` instead of a hard-coded 'mps', so the cell
# also runs on CUDA / CPU machines.
I = torch.tensor([
    0.0308751,
    0.0161968,
    0.1028691,
    0.0553428,
    0.0588390,
    0.0038936,
    0.0186269,
    0.0056457,
    0.1271117,
    0.1669608,
    0.2656634,
    0.1479751
], device=device)

In [9]:
# Display the normalized importance vector as the cell output.
I

tensor([0.0309, 0.0162, 0.1029, 0.0553, 0.0588, 0.0039, 0.0186, 0.0056, 0.1271,
        0.1670, 0.2657, 0.1480], device='mps:0')