In [1]:
import torch
from tqdm import tqdm
# add path
import sys
sys.path.append('../../../MVCNN')
from models import MVCNN
from tools import ImgDataset
import numpy as np

# Calculating Importance

## Initialization

In [2]:
# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Build the network and restore the trained parameters from the checkpoint.
model = MVCNN.SVCNN('mvcnn')
# weights_only=True restricts unpickling to tensors/plain containers, which is
# all a state dict needs, and avoids executing arbitrary pickled code.
weights = torch.load(
    '../../../MVCNN/MVCNN/model-mvcnn-00050.pth',
    map_location=device,
    weights_only=True,
)
model.load_state_dict(weights)
model.to(device)
# Last expression of the cell: switches to eval mode and displays the module.
model.eval()

SVCNN(
  (net_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  

In [4]:
# Keep handles on the two sub-networks of the model. `Module.eval()` returns
# the module itself, so each handle is put in eval mode as it is taken.
feature_extractor = model.net_1.eval()
classifier = model.net_2.eval()

# Number of views per model
num_views = 12

## Accuracy Calculator

In [5]:
def validate_model(model,
                   model_name,
                   test_loader,
                   single_view: bool = False,
                   view_idx: int = 0,
                   num_classes: int = 33):
    """Evaluate ``model`` on ``test_loader`` and print its accuracy.

    Two modes are supported:

    * Multi-view voting (``model_name == 'mvcnn'`` and ``single_view`` is
      False): every view of every object is classified independently and the
      object's prediction is the most frequent per-view prediction.
    * Single view (any other combination): only the view at position
      ``view_idx`` of each object is classified.

    Relies on the notebook-level names ``device`` and ``tqdm``.

    Args:
        model: Network mapping an image batch to per-class logits.
        model_name: Used in the progress-bar label; the value ``'mvcnn'``
            enables multi-view voting.
        test_loader: Iterable of batches where ``batch[0]`` holds the labels
            and ``batch[1]`` the images. Voting requires 5-D images
            ``(N, V, C, H, W)``.
        single_view: Force the single-view mode.
        view_idx: Position along dim 1 of the view used in single-view mode.
        num_classes: Size of the per-class counters (defaults to the
            previously hard-coded 33).

    Returns:
        None. The overall accuracy and the mean per-class accuracy (over
        classes that occur in the data) are printed.

    Raises:
        IndexError: If a label is >= ``num_classes``.
    """
    all_correct = 0
    all_samples = 0
    all_loss = 0.0

    # Per-class counters used for the mean class accuracy.
    wrong_class = np.zeros(num_classes, dtype=int)
    samples_class = np.zeros(num_classes, dtype=int)

    vote_over_views = (model_name == 'mvcnn' and not single_view)

    model.eval()
    pbar = tqdm(test_loader, desc=f'Val {model_name}'+(' (1 view)' if single_view else ''), unit='batch')
    # Evaluation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        for batch_i, data in enumerate(pbar):
            labels, views = data[0].to(device), data[1].to(device)
            n_objects = labels.size(0)

            if vote_over_views:
                # Flatten (N, V, C, H, W) -> (N*V, C, H, W); repeat each label V times.
                N, V, C, H, W = views.size()
                x = views.reshape(-1, C, H, W)
                tgt = labels.repeat_interleave(V, dim=0)
            else:
                # A 5-D batch carries a view axis at dim 1; a 4-D batch is already
                # one image per sample (presumably what a single-image loader
                # yields), so indexing dim 1 there would pick a colour channel.
                x = views[:, view_idx, ...] if views.dim() == 5 else views
                tgt = labels

            out = model(x)
            # Progress-bar loss: running mean of the per-batch mean losses.
            all_loss += torch.nn.functional.cross_entropy(out, tgt).item()

            # Move predictions/labels to the CPU once per batch instead of once
            # per object; torch.mode is then applied to all objects at once.
            preds = out.argmax(1).cpu()
            if vote_over_views:
                preds = torch.mode(preds.reshape(N, V), dim=1).values
            labels_cpu = labels.cpu()

            hits = (preds == labels_cpu).tolist()
            for cls, hit in zip(labels_cpu.tolist(), hits):
                samples_class[cls] += 1
                if not hit:
                    wrong_class[cls] += 1

            batch_correct = sum(hits)
            all_correct += batch_correct
            all_samples += n_objects
            acc = batch_correct / n_objects

            pbar.set_postfix({
                'acc':  f'{acc:.4f}',
                'loss': f'{(all_loss / (batch_i+1)):.4f}'
            })

    if all_samples == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print('\nNo samples were evaluated.')
        return

    overall_acc = all_correct / all_samples
    # Average accuracy over the classes that actually appear in the data.
    seen = samples_class > 0
    per_cls = (samples_class[seen] - wrong_class[seen]) / samples_class[seen]
    mean_cls = per_cls.mean()

    print(f'\nOverall Acc: {overall_acc:.4f}   Mean Class Acc: {mean_cls:.4f}')

## Dataset Initializer

### MVCNN Dataset

In [6]:
# Multi-view test set with augmentation disabled.
test_dataset_mvcnn = ImgDataset.MultiviewImgDataset(
    root_dir='../../../MVCNN/ModelNet40-12View/*/test',
    scale_aug=False,
    rot_aug=False,
    test_mode=True,
    num_models=0,  # NOTE(review): presumably 0 means "no cap" - confirm in ImgDataset
    num_views=num_views,  # reuse the notebook constant instead of repeating 12
)
test_loader_mvcnn = torch.utils.data.DataLoader(
    test_dataset_mvcnn,
    batch_size=8,
    shuffle=False,  # keep a deterministic order for evaluation
    num_workers=4,
    # Pinned host memory only helps CUDA transfers; on MPS/CPU it is
    # unsupported and merely produces a warning.
    pin_memory=(device == 'cuda'),
)

### SVCNN Dataset

In [7]:
# Single-image test set with augmentation disabled.
test_dataset_svcnn = ImgDataset.SingleImgDataset(
    root_dir='../../../MVCNN/ModelNet40-12View/*/test',
    scale_aug=False,
    rot_aug=False,
    test_mode=True,
    num_models=0,  # NOTE(review): presumably 0 means "no cap" - confirm in ImgDataset
)
test_loader_svcnn = torch.utils.data.DataLoader(
    test_dataset_svcnn,
    batch_size=32,
    shuffle=False,  # keep a deterministic order for evaluation
    num_workers=4,
    # Pinned host memory only helps CUDA transfers; on MPS/CPU it is
    # unsupported and merely produces a warning.
    pin_memory=(device == 'cuda'),
)

## Baseline Accuracy

In [8]:
# Baseline: multi-view voting (single_view and view_idx are left at their defaults).
validate_model(model, 'mvcnn', test_loader_mvcnn)

Val mvcnn: 100%|██████████| 233/233 [01:41<00:00,  2.30batch/s, acc=1.0000, loss=1.3893]


Overall Acc: 0.9177   Mean Class Acc: 0.8773





## Accuracy Of Each View

In [9]:
# Label printed for each view position: the numbers 0..11 in string sort order,
# i.e. [0, 1, 10, 11, 2, ..., 9].
# NOTE(review): presumably this mirrors the order in which the dataset lists
# the view files - confirm against ImgDataset.
views = sorted(range(12), key=str)

In [10]:
# Evaluate each view position on its own. `i` is the position inside the view
# tensor; `views[i]` is only the label shown for that position.
for i in range(num_views):
    view_label = views[i]
    print(f'\nValidating view {view_label}...')
    validate_model(
        model=model,
        model_name='mvcnn',
        test_loader=test_loader_mvcnn,
        single_view=True,
        view_idx=i,
    )


Validating view 0...


Val mvcnn (1 view): 100%|██████████| 233/233 [00:35<00:00,  6.64batch/s, acc=1.0000, loss=1.4149]



Overall Acc: 0.8742   Mean Class Acc: 0.8224

Validating view 1...


Val mvcnn (1 view): 100%|██████████| 233/233 [00:33<00:00,  6.97batch/s, acc=0.7500, loss=1.2530]



Overall Acc: 0.8887   Mean Class Acc: 0.8427

Validating view 10...


Val mvcnn (1 view): 100%|██████████| 233/233 [00:33<00:00,  6.99batch/s, acc=1.0000, loss=1.3136]



Overall Acc: 0.8909   Mean Class Acc: 0.8488

Validating view 11...


Val mvcnn (1 view): 100%|██████████| 233/233 [00:33<00:00,  7.02batch/s, acc=1.0000, loss=1.2038]



Overall Acc: 0.8968   Mean Class Acc: 0.8485

Validating view 2...


Val mvcnn (1 view): 100%|██████████| 233/233 [00:33<00:00,  7.00batch/s, acc=0.7500, loss=1.2764]



Overall Acc: 0.8914   Mean Class Acc: 0.8588

Validating view 3...


Val mvcnn (1 view): 100%|██████████| 233/233 [00:34<00:00,  6.83batch/s, acc=0.7500, loss=1.5373]



Overall Acc: 0.8478   Mean Class Acc: 0.8076

Validating view 4...


Val mvcnn (1 view): 100%|██████████| 233/233 [00:33<00:00,  6.90batch/s, acc=0.7500, loss=1.4712]



Overall Acc: 0.8823   Mean Class Acc: 0.8488

Validating view 5...


Val mvcnn (1 view): 100%|██████████| 233/233 [00:33<00:00,  6.89batch/s, acc=0.7500, loss=1.4172]



Overall Acc: 0.8844   Mean Class Acc: 0.8342

Validating view 6...


Val mvcnn (1 view): 100%|██████████| 233/233 [00:34<00:00,  6.78batch/s, acc=0.7500, loss=1.5293]



Overall Acc: 0.8688   Mean Class Acc: 0.8145

Validating view 7...


Val mvcnn (1 view): 100%|██████████| 233/233 [00:34<00:00,  6.69batch/s, acc=0.7500, loss=1.3619]



Overall Acc: 0.8930   Mean Class Acc: 0.8452

Validating view 8...


Val mvcnn (1 view): 100%|██████████| 233/233 [00:33<00:00,  6.91batch/s, acc=0.7500, loss=1.3240]



Overall Acc: 0.8914   Mean Class Acc: 0.8491

Validating view 9...


Val mvcnn (1 view): 100%|██████████| 233/233 [00:33<00:00,  6.89batch/s, acc=1.0000, loss=1.5693]


Overall Acc: 0.8710   Mean Class Acc: 0.8242





## Accuracy While Removing Views

#### Use Multi View Dataset