In [1]:
import torch
from tqdm import tqdm
# add path
import sys
sys.path.append('../../../MVCNN')
from models import MVCNN
from tools import ImgDataset
import numpy as np

# Calculating Importance

## Initialization

In [2]:
device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
model = MVCNN.SVCNN('mvcnn')
weights = torch.load('../../../MVCNN/MVCNN/model-mvcnn-00050.pth', map_location=device)
model.load_state_dict(weights)
model.to(device)
model.eval()

SVCNN(
  (net_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  

In [4]:
# Extract CNN feature extractor from the model
feature_extractor = model.net_1
classifier = model.net_2

feature_extractor.eval()
classifier.eval()

num_views = 12  # Number of views per model
num_classes = 33

## Accuracy Calculator

In [5]:
# global or at top of script
view_order = [0, 1, 10, 11, 2, 3, 4, 5, 6, 7, 8, 9]

def get_include_indices(drop_view: int):
    # find the position of the unwanted view in your ordering
    drop_idx = view_order.index(drop_view)
    # build a list of all other positions
    return [i for i in range(len(view_order)) if i != drop_idx]

In [6]:
def validate_model(model,
                   model_name,
                   test_loader,
                   single_view: bool = False,
                   view_idx: int = 0,
                   drop_view = None):
    # precompute which positions to keep
    include_idxs = None if drop_view is None else get_include_indices(drop_view)
    model.eval()
    all_correct, all_samples, all_loss = 0, 0, 0.0
    wrong_class   = np.zeros(num_classes, dtype=int)
    samples_class = np.zeros(num_classes, dtype=int)

    pbar = tqdm(test_loader, desc='Validation', unit='batch', dynamic_ncols=True)
    for batch_i, data in enumerate(pbar):
        
        labels, views = data[0].to(device), data[1].to(device)
        
        # — if we’re dropping one view, slice it out
        if include_idxs is not None:
            # views[:, include_idxs, ...] → (N, V-1, C, H, W)
            views = views[:, include_idxs, :, :, :]

        if model_name=='mvcnn' and not single_view:
            # full‐vote branch
            N, V, C, H, W = views.shape
            x   = views.reshape(-1, C, H, W)                     # (N*V, C, H, W)
            tgt = labels.repeat_interleave(V, dim=0)             # (N*V,)
            out = model(x)
            preds = out.argmax(1)
            all_loss += torch.nn.functional.cross_entropy(out, tgt).item()

            # majority vote per object
            batch_correct = 0
            for i in range(N):
                vp = preds[i*V:(i+1)*V].cpu()
                voted = torch.mode(vp)[0]
                if voted == labels[i].cpu():
                    batch_correct += 1
                else:
                    wrong_class[labels[i].item()] += 1
                samples_class[labels[i].item()] += 1

            all_correct += batch_correct
            all_samples += N
            acc = batch_correct / N

        else:
            # single‐view branch
            # view_idx is now relative to the *new* V (after dropping)
            x   = views[:, view_idx, ...]    # (N, C, H, W)
            tgt = labels
            out = model(x)
            preds = out.argmax(1)
            all_loss += torch.nn.functional.cross_entropy(out, tgt).item()

            matches = preds == tgt
            for i, correct in enumerate(matches):
                samples_class[tgt[i].item()] += 1
                if not correct:
                    wrong_class[tgt[i].item()] += 1

            batch_correct = matches.sum().item()
            all_correct += batch_correct
            all_samples += tgt.size(0)
            acc = batch_correct / tgt.size(0)

        pbar.set_postfix({
            'acc':  f'{acc:.4f}',
            'loss': f'{(all_loss/(batch_i+1)):.4f}'
        })

    overall_acc = all_correct / all_samples
    per_cls     = (samples_class - wrong_class) / np.maximum(samples_class, 1)
    mean_cls    = per_cls[samples_class>0].mean()
    
    pbar.close()
    
    if single_view:
        print(f'\nOverall Acc: {overall_acc:.4f}   Mean Class Acc: {mean_cls:.4f} for view {view_idx}')
    elif drop_view is not None:
        print(f'\nOverall Acc: {overall_acc:.4f}   Mean Class Acc: {mean_cls:.4f} when view {drop_view} is dropped')
    else:
        print(f'\nOverall Acc: {overall_acc:.4f}   Mean Class Acc: {mean_cls:.4f} when all views are considered')

## Dataset Initializer

### MVCNN Dataset

In [7]:
test_dataset_mvcnn = ImgDataset.MultiviewImgDataset(
    root_dir='../../../MVCNN/ModelNet40-12View/*/test',
    scale_aug=False,
    rot_aug=False,
    test_mode=True,
    num_models=0,
    num_views=12,
)
test_loader_mvcnn = torch.utils.data.DataLoader(
    test_dataset_mvcnn,
    batch_size=8,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

### SVCNN Dataset

In [8]:
test_dataset_svcnn = ImgDataset.SingleImgDataset(
    root_dir='../../../MVCNN/ModelNet40-12View/*/test',
    scale_aug=False,
    rot_aug=False,
    test_mode=True,
    num_models=0,
)
test_loader_svcnn = torch.utils.data.DataLoader(
    test_dataset_svcnn,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

## Baseline Accuracy

In [9]:
validate_model(model, 'mvcnn', test_loader_mvcnn, single_view=False, view_idx=0)

Validation: 100%|██████████| 233/233 [01:43<00:00,  2.25batch/s, acc=1.0000, loss=1.3856]


Overall Acc: 0.9177   Mean Class Acc: 0.8773 when all views are considered





## Accuracy Of Each View

In [10]:
for view_idx in range(12):
    validate_model(
        model,
        'mvcnn',
        test_loader_mvcnn,
        single_view=True,
        view_idx=view_idx,
        drop_view=None
    )

Validation: 100%|██████████| 233/233 [00:35<00:00,  6.60batch/s, acc=1.0000, loss=1.4149]



Overall Acc: 0.8742   Mean Class Acc: 0.8224 for view 0


Validation: 100%|██████████| 233/233 [00:34<00:00,  6.77batch/s, acc=1.0000, loss=1.2526]



Overall Acc: 0.8887   Mean Class Acc: 0.8427 for view 1


Validation: 100%|██████████| 233/233 [00:33<00:00,  6.91batch/s, acc=1.0000, loss=1.3136]



Overall Acc: 0.8909   Mean Class Acc: 0.8488 for view 2


Validation: 100%|██████████| 233/233 [00:33<00:00,  7.01batch/s, acc=1.0000, loss=1.2038]



Overall Acc: 0.8968   Mean Class Acc: 0.8485 for view 3


Validation: 100%|██████████| 233/233 [00:34<00:00,  6.81batch/s, acc=1.0000, loss=1.2649]



Overall Acc: 0.8914   Mean Class Acc: 0.8588 for view 4


Validation: 100%|██████████| 233/233 [00:34<00:00,  6.82batch/s, acc=0.7500, loss=1.5286]



Overall Acc: 0.8478   Mean Class Acc: 0.8076 for view 5


Validation: 100%|██████████| 233/233 [00:33<00:00,  6.89batch/s, acc=1.0000, loss=1.4634]



Overall Acc: 0.8823   Mean Class Acc: 0.8488 for view 6


Validation: 100%|██████████| 233/233 [00:33<00:00,  6.93batch/s, acc=1.0000, loss=1.4123]



Overall Acc: 0.8844   Mean Class Acc: 0.8342 for view 7


Validation: 100%|██████████| 233/233 [00:33<00:00,  6.92batch/s, acc=1.0000, loss=1.5260]



Overall Acc: 0.8688   Mean Class Acc: 0.8145 for view 8


Validation: 100%|██████████| 233/233 [00:35<00:00,  6.61batch/s, acc=1.0000, loss=1.3592]



Overall Acc: 0.8930   Mean Class Acc: 0.8452 for view 9


Validation: 100%|██████████| 233/233 [00:34<00:00,  6.79batch/s, acc=1.0000, loss=1.3186]



Overall Acc: 0.8914   Mean Class Acc: 0.8491 for view 10


Validation: 100%|██████████| 233/233 [00:34<00:00,  6.75batch/s, acc=1.0000, loss=1.5693]


Overall Acc: 0.8710   Mean Class Acc: 0.8242 for view 11





## Accuracy While Removing Views

#### Use Multi View Dataset

In [16]:
for view_idx in range(12):
    validate_model(
        model,
        'mvcnn',
        test_loader_mvcnn,
        single_view=False,   # full‐vote mode
        drop_view=view_idx
    )

Validation: 100%|██████████| 233/233 [01:39<00:00,  2.34batch/s, acc=1.0000, loss=1.3829]



Overall Acc: 0.9188   Mean Class Acc: 0.8767 when view 0 is dropped


Validation: 100%|██████████| 233/233 [01:37<00:00,  2.38batch/s, acc=1.0000, loss=1.3977]



Overall Acc: 0.9194   Mean Class Acc: 0.8806 when view 1 is dropped


Validation: 100%|██████████| 233/233 [01:39<00:00,  2.34batch/s, acc=1.0000, loss=1.3966]



Overall Acc: 0.9194   Mean Class Acc: 0.8794 when view 2 is dropped


Validation: 100%|██████████| 233/233 [01:38<00:00,  2.37batch/s, acc=1.0000, loss=1.3726]



Overall Acc: 0.9177   Mean Class Acc: 0.8773 when view 3 is dropped


Validation: 100%|██████████| 233/233 [01:37<00:00,  2.38batch/s, acc=1.0000, loss=1.3785]



Overall Acc: 0.9167   Mean Class Acc: 0.8767 when view 4 is dropped


Validation: 100%|██████████| 233/233 [01:38<00:00,  2.38batch/s, acc=1.0000, loss=1.3832]



Overall Acc: 0.9161   Mean Class Acc: 0.8764 when view 5 is dropped


Validation: 100%|██████████| 233/233 [01:37<00:00,  2.38batch/s, acc=1.0000, loss=1.3728]



Overall Acc: 0.9177   Mean Class Acc: 0.8761 when view 6 is dropped


Validation: 100%|██████████| 233/233 [01:40<00:00,  2.31batch/s, acc=1.0000, loss=1.3880]



Overall Acc: 0.9177   Mean Class Acc: 0.8785 when view 7 is dropped


Validation: 100%|██████████| 233/233 [01:41<00:00,  2.30batch/s, acc=1.0000, loss=1.3917]



Overall Acc: 0.9172   Mean Class Acc: 0.8758 when view 8 is dropped


Validation: 100%|██████████| 233/233 [01:38<00:00,  2.36batch/s, acc=1.0000, loss=1.3689]



Overall Acc: 0.9177   Mean Class Acc: 0.8773 when view 9 is dropped


Validation: 100%|██████████| 233/233 [01:37<00:00,  2.38batch/s, acc=1.0000, loss=1.3921]



Overall Acc: 0.9188   Mean Class Acc: 0.8791 when view 10 is dropped


Validation: 100%|██████████| 233/233 [01:39<00:00,  2.34batch/s, acc=1.0000, loss=1.4021]


Overall Acc: 0.9172   Mean Class Acc: 0.8770 when view 11 is dropped



