In [13]:
import torch
from tqdm import tqdm
# add path
import sys
sys.path.append('../../../MVCNN')
from models import MVCNN
from tools import ImgDataset
import numpy as np
import torch.nn.functional as F
from typing import Optional

# Calculating Importance

## Initialization

In [14]:
device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

In [15]:
model = MVCNN.SVCNN('mvcnn')
weights = torch.load('../../../MVCNN/MVCNN/model-mvcnn-00050.pth', map_location=device)
model.load_state_dict(weights)
model.to(device)
model.eval()

SVCNN(
  (net_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  

In [16]:
# Extract CNN feature extractor from the model
feature_extractor = model.net_1
classifier = model.net_2

feature_extractor.eval()
classifier.eval()

num_views = 12  # Number of views per model
num_classes = 33

## Accuracy Calculator

In [None]:
# global or at top of script
view_order = [0, 1, 10, 11, 2, 3, 4, 5, 6, 7, 8, 9]

In [28]:
def validate_model(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    num_classes: int,
    device: torch.device,
    single_view: bool = False,
    view_idx: int = 0,
    drop_view_label: Optional[int] = None,
):
    """
    - single_view: evaluate on only loader batch[:, view_idx]
    - drop_view_label: the *label* of the view to drop (e.g. 5 → slice_idx 7)
                       if None, no views are dropped.
    """

    # map from semantic label → tensor‐slice index
    if drop_view_label is not None:
        assert drop_view_label in view_order, f"{drop_view_label=} not in view_order"
        drop_slice = view_order.index(drop_view_label)
    else:
        drop_slice = None

    model.eval()
    total_correct = total_samples = 0
    total_loss = 0.0
    wrong_per_class = np.zeros(num_classes, dtype=int)
    samples_per_class = np.zeros(num_classes, dtype=int)

    pbar = tqdm(loader, desc="Validating", unit="batch", leave=False, dynamic_ncols=True)
    for batch_i, data in enumerate(pbar):
        labels = data[0].to(device)
        views  = data[1].to(device)              # (N, 12, C, H, W)
        N, V, C, H, W = views.shape

        # drop the semantic view
        if drop_slice is not None:
            assert 0 <= drop_slice < V
            keep = [i for i in range(V) if i != drop_slice]
            views = views[:, keep]
            V -= 1

        # single-view path
        if single_view:
            assert 0 <= view_idx < V
            x   = views[:, view_idx]           # (N, C, H, W)
            tgt = labels                       # (N,)
            with torch.no_grad():
                out   = model(x)
                loss  = F.cross_entropy(out, tgt).item()
                preds = out.argmax(1)
        # full-MVCNN path
        else:
            flat = views.reshape(-1, C, H, W)       # (N*V, C, H, W)
            tgt  = labels.repeat_interleave(V, 0)   # (N*V,)
            with torch.no_grad():
                out   = model(flat)
                loss  = F.cross_entropy(out, tgt).item()
                preds = out.argmax(1).cpu().numpy() # flatten

        total_loss += loss

        if single_view:
            correct_mask = (preds == tgt).cpu().numpy()
            batch_correct = correct_mask.sum()
            for i, ok in enumerate(correct_mask):
                cls = tgt[i].item()
                samples_per_class[cls] += 1
                if not ok:
                    wrong_per_class[cls] += 1
        else:
            preds = preds.reshape(N, V)    # (N, V)
            voted = np.array([np.bincount(preds[i]).argmax() for i in range(N)])
            gts   = labels.cpu().numpy()
            batch_correct = (voted == gts).sum()
            for i in range(N):
                cls = gts[i]
                samples_per_class[cls] += 1
                if voted[i] != cls:
                    wrong_per_class[cls] += 1

        total_correct += batch_correct
        total_samples += N

        acc = batch_correct / N
        avg_loss = total_loss / (batch_i + 1)
        pbar.set_postfix(acc=f"{acc:.4f}", loss=f"{avg_loss:.4f}")

    pbar.close()

    overall_acc = total_correct / total_samples
    per_cls_acc = (samples_per_class - wrong_per_class) / np.maximum(samples_per_class, 1)
    mean_cls_acc = per_cls_acc[samples_per_class > 0].mean()

    drop_msg = f", dropped view {drop_view_label}" if drop_view_label is not None else ""
    mode     = "single-view" if single_view else "full-mvcnn"
    print(f"\n[{mode}{drop_msg}] Overall Acc: {overall_acc:.4f}   Mean Class Acc: {mean_cls_acc:.4f}")

    return overall_acc, mean_cls_acc

## Dataset Initializer

### MVCNN Dataset

In [24]:
test_dataset_mvcnn = ImgDataset.MultiviewImgDataset(
    root_dir='../../../MVCNN/ModelNet40-12View/*/test',
    scale_aug=False,
    rot_aug=False,
    test_mode=True,
    num_models=0,
    num_views=12,
)
test_loader_mvcnn = torch.utils.data.DataLoader(
    test_dataset_mvcnn,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

### SVCNN Dataset

In [25]:
test_dataset_svcnn = ImgDataset.SingleImgDataset(
    root_dir='../../../MVCNN/ModelNet40-12View/*/test',
    scale_aug=False,
    rot_aug=False,
    test_mode=True,
    num_models=0,
)
test_loader_svcnn = torch.utils.data.DataLoader(
    test_dataset_svcnn,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

## Baseline Accuracy

In [27]:
validate_model(model = model, loader=test_loader_mvcnn, num_classes=num_classes, device=torch.device('mps'), single_view=False, drop_view_label=None)

                                                      

ValueError: too many values to unpack (expected 2)

## Accuracy Of Each View

In [22]:
for view_idx in range(12):
    validate_model(
        model,
        'mvcnn',
        test_loader_mvcnn,
        single_view=True,
        view_idx=view_idx,
        drop_view=None
    )

TypeError: validate_model() got an unexpected keyword argument 'drop_view'

## Accuracy While Removing Views

#### Use Multi View Dataset

In [None]:
for view_idx in range(12):
    validate_model(
        model,
        'mvcnn',
        test_loader_mvcnn,
        single_view=False,
        drop_view=view_idx
    )

Validation: 100%|██████████| 233/233 [01:39<00:00,  2.34batch/s, acc=1.0000, loss=1.3829]



Overall Acc: 0.9188   Mean Class Acc: 0.8767 when view 0 is dropped


Validation: 100%|██████████| 233/233 [01:37<00:00,  2.38batch/s, acc=1.0000, loss=1.3977]



Overall Acc: 0.9194   Mean Class Acc: 0.8806 when view 1 is dropped


Validation: 100%|██████████| 233/233 [01:39<00:00,  2.34batch/s, acc=1.0000, loss=1.3966]



Overall Acc: 0.9194   Mean Class Acc: 0.8794 when view 2 is dropped


Validation: 100%|██████████| 233/233 [01:38<00:00,  2.37batch/s, acc=1.0000, loss=1.3726]



Overall Acc: 0.9177   Mean Class Acc: 0.8773 when view 3 is dropped


Validation: 100%|██████████| 233/233 [01:37<00:00,  2.38batch/s, acc=1.0000, loss=1.3785]



Overall Acc: 0.9167   Mean Class Acc: 0.8767 when view 4 is dropped


Validation: 100%|██████████| 233/233 [01:38<00:00,  2.38batch/s, acc=1.0000, loss=1.3832]



Overall Acc: 0.9161   Mean Class Acc: 0.8764 when view 5 is dropped


Validation: 100%|██████████| 233/233 [01:37<00:00,  2.38batch/s, acc=1.0000, loss=1.3728]



Overall Acc: 0.9177   Mean Class Acc: 0.8761 when view 6 is dropped


Validation: 100%|██████████| 233/233 [01:40<00:00,  2.31batch/s, acc=1.0000, loss=1.3880]



Overall Acc: 0.9177   Mean Class Acc: 0.8785 when view 7 is dropped


Validation: 100%|██████████| 233/233 [01:41<00:00,  2.30batch/s, acc=1.0000, loss=1.3917]



Overall Acc: 0.9172   Mean Class Acc: 0.8758 when view 8 is dropped


Validation: 100%|██████████| 233/233 [01:38<00:00,  2.36batch/s, acc=1.0000, loss=1.3689]



Overall Acc: 0.9177   Mean Class Acc: 0.8773 when view 9 is dropped


Validation: 100%|██████████| 233/233 [01:37<00:00,  2.38batch/s, acc=1.0000, loss=1.3921]



Overall Acc: 0.9188   Mean Class Acc: 0.8791 when view 10 is dropped


Validation: 100%|██████████| 233/233 [01:39<00:00,  2.34batch/s, acc=1.0000, loss=1.4021]


Overall Acc: 0.9172   Mean Class Acc: 0.8770 when view 11 is dropped



