In [2]:
from models import MVCNN
from tools import ImgDataset
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

# Calculating Importance

## Initialization

In [3]:
# Pick the compute backend: Apple MPS first, then CUDA, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
# Restore the trained single-view network from its checkpoint and put it in inference mode.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint_path = '../../../MVCNN/MVCNN/model-00050.pth'
weights = torch.load(checkpoint_path, map_location=device)

model = MVCNN.SVCNN('svcnn')
model.load_state_dict(weights)
model = model.to(device)
model.eval()

SVCNN(
  (net_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  

In [5]:
# Multi-view test set with both augmentations switched off, test_mode on and shuffling off.
test_dataset = ImgDataset.MultiviewImgDataset(
    root_dir='../../../MVCNN/ModelNet40-12View/*/test',
    num_views=12,
    num_models=1000,
    test_mode=True,
    shuffle=False,
    scale_aug=False,
    rot_aug=False,
)

# One object (all of its views) per batch, served in dataset order.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=1,
    shuffle=False,
    pin_memory=True,
)

In [6]:
# Split the network into its two stages so they can be called separately:
# net_1 (feature extractor) and net_2 (classifier).
feature_extractor, classifier = model.net_1, model.net_2

# Both stages are only used for inference.
for stage in (feature_extractor, classifier):
    stage.eval()

num_views = 12  # Number of views per model

The views are ordered as follows:

0, 1, 10, 11, 2, 3, 4, 5, 6, 7, 8, 9

## Accuracy Of Each View

In [11]:
# One empty list per view, keyed 'view 0' .. 'view 11'; the evaluation loop appends
# a 1.0/0.0 hit flag per object to the matching list.
ranking = {f'view {view_idx}': [] for view_idx in range(12)}

In [12]:
!rm outputs

In [13]:
# Evaluate the single-view network on one view index at a time and print the
# per-view accuracy and mean batch loss over the whole test loader.
for view in [0, 1, 10, 11, 2, 3, 4, 5, 6, 7, 8, 9]:
    model.eval()
    samples_class = torch.zeros(33)  # objects seen per class
    wrong_class = torch.zeros(33)    # misclassified objects per class
    all_points = 0
    all_correct_points = 0
    all_loss = 0.0

    with torch.no_grad():
        pbar = tqdm(test_loader, desc="Validating", leave=False, ncols=80)
        for data in pbar:
            # Keep only the selected view: [N, V, C, H, W] -> [N, C, H, W]
            in_data = data[1][:, view, :, :, :].to(device)
            target = data[0].to(device)

            out_data = model(in_data)
            pred = torch.max(out_data, 1)[1]
            all_loss += torch.nn.functional.cross_entropy(out_data, target).item()
            results = pred == target

            # Per-class bookkeeping, one object at a time.
            for obj_pred, obj_target in zip(pred, target):
                label = obj_target.item()
                if obj_target != obj_pred:
                    wrong_class[label] += 1
                samples_class[label] += 1

            all_correct_points += results.long().sum().item()
            all_points += results.size(0)

    print(f"Validation accuracy for view {view}: ", all_correct_points / all_points)
    print("Validation loss: ", all_loss / len(test_loader))

                                                                                

Validation accuracy for view 0:  0.8338709677419355
Validation loss:  0.8813561098326828


                                                                                

Validation accuracy for view 1:  0.8575268817204301
Validation loss:  0.9082475790514812


                                                                                

Validation accuracy for view 10:  0.853763440860215
Validation loss:  0.8412012453337213


                                                                                

Validation accuracy for view 11:  0.8376344086021505
Validation loss:  0.8129057454942058


                                                                                

Validation accuracy for view 2:  0.8521505376344086
Validation loss:  0.7703247849596861


                                                                                

Validation accuracy for view 3:  0.8543010752688172
Validation loss:  0.7904298889504381


                                                                                

Validation accuracy for view 4:  0.8693548387096774
Validation loss:  0.7775200685571102


                                                                                

Validation accuracy for view 5:  0.8424731182795699
Validation loss:  0.7697400303302722


                                                                                

Validation accuracy for view 6:  0.85
Validation loss:  0.7882564354094462


                                                                                

Validation accuracy for view 7:  0.8225806451612904
Validation loss:  0.9045681064739782


                                                                                

Validation accuracy for view 8:  0.8016129032258065
Validation loss:  1.11109438548692


                                                                                

Validation accuracy for view 9:  0.8408602150537634
Validation loss:  0.8763477378425916




In [14]:
# For every test object, push each view through net_1 -> net_2 on its own, record a
# 1.0/0.0 hit in `ranking`, and append one line per evaluated view to the 'outputs' log.
#
# Changes from the previous version of this cell:
#   * the logged path is data_path[view], i.e. the path at the same position as the
#     tensor that was actually evaluated (it used the enumerate index before, which
#     pointed at a different image);
#   * the log file is opened once instead of once per view;
#   * a single no_grad context covers the whole loop.
# NOTE(review): the markdown above lists the view order as 0, 1, 10, 11, 2, ...; if that
# is the order of positions along dim 1, position `view` is not view number `view`.
# Confirm whether labels are meant to be positions or view numbers.
prev = torch.tensor([]).to(device)
with open("outputs", 'a') as f, torch.no_grad():
    for batch_idx, (target, data, data_path) in tqdm(enumerate(test_loader), total=len(test_loader)):
        data = data.to(device)
        target = target.to(device)
        for view in [0, 1, 10, 11, 2, 3, 4, 5, 6, 7, 8, 9]:
            # batch_size is 1, so dropping dim 0 leaves a single [C, H, W] image
            view_data = data[:, view, :, :, :].squeeze(0)

            # Stop processing this object as soon as a view is identical to the
            # previously evaluated one.
            if torch.equal(view_data, prev):
                break

            features = feature_extractor(view_data.unsqueeze(0))
            pooled = torch.max(features, dim=0)[0].unsqueeze(0)
            pooled = pooled.view(pooled.size(0), -1)
            output = classifier(pooled)

            _, pred = torch.max(output, 1)
            out = (pred == target).float().item()
            ranking[f'view {view}'].append(out)

            prev = view_data

            f.write(f"{data_path[view]} view {view}, {out}\n")

100%|██████████| 1860/1860 [02:26<00:00, 12.72it/s]


In [15]:
# Collapse each view's list of 0/1 hits into its mean accuracy (0.0 for an empty list).
# Entries that are no longer lists are kept as they are, so running this cell a second
# time does not raise (sum() over a float would).
ranking = {
    view: (sum(hits) / len(hits) if hits else 0.0) if isinstance(hits, list) else hits
    for view, hits in ranking.items()
}

In [16]:
# Display the mean accuracy stored for each view key.
ranking

{'view 0': 0.8338709677419355,
 'view 1': 0.8575268817204301,
 'view 2': 0.8521505376344086,
 'view 3': 0.8543010752688172,
 'view 4': 0.8693548387096774,
 'view 5': 0.8424731182795699,
 'view 6': 0.85,
 'view 7': 0.8225806451612904,
 'view 8': 0.8016129032258065,
 'view 9': 0.8408602150537634,
 'view 10': 0.853763440860215,
 'view 11': 0.8376344086021505}

In [17]:
def min_max_norm(arr):
    """Rescale values linearly so the smallest maps to 0 and the largest to 1.

    Args:
        arr: iterable of numbers (a one-shot iterable such as a generator is fine).

    Returns:
        A new list with (x - min) / (max - min) for every x. An empty input gives
        an empty list; when all values are equal every entry is 0.0 instead of
        dividing by zero.
    """
    values = list(arr)  # materialise once: min, max and the final pass all need the data
    if not values:
        return []
    min_val, max_val = min(values), max(values)
    span = max_val - min_val
    if span == 0:
        return [0.0 for _ in values]
    return [(x - min_val) / span for x in values]

def l1_norm(arr, scale=10):
    """Divide every value by the sum of absolute values, then multiply by `scale`.

    Args:
        arr: iterable of numbers (a one-shot iterable such as a generator is fine).
        scale: factor applied to each normalised value; the default of 10 keeps the
            behaviour this cell always had.

    Returns:
        A new list of x * scale / sum(|x|). When the sum of absolute values is 0
        the values are returned unscaled (as a new list).

    Note: a later cell in this notebook defines another `l1_norm` without the
    scale factor; whichever cell ran last is the one in effect.
    """
    values = list(arr)  # materialise once so the data can be traversed twice
    norm = sum(abs(x) for x in values)
    if norm == 0:
        return values
    return [x * scale / norm for x in values]

In [18]:
# Min-max rescale the per-view values held in `ranking` and display the result.
min_max_norm(list(ranking.values()))

[0.4761904761904757,
 0.8253968253968252,
 0.7460317460317459,
 0.7777777777777767,
 1.0,
 0.6031746031746036,
 0.7142857142857136,
 0.30952380952380987,
 0.0,
 0.5793650793650784,
 0.7698412698412694,
 0.5317460317460315]

In [19]:
# Apply whichever `l1_norm` is currently defined to the per-view values in `ranking`.
l1_norm(list(ranking.values()))

[0.824298469387755,
 0.8476828231292517,
 0.8423681972789114,
 0.8444940476190476,
 0.8593749999999999,
 0.8328018707482994,
 0.8402423469387754,
 0.8131377551020408,
 0.7924107142857142,
 0.8312074829931971,
 0.8439625850340134,
 0.8280187074829932]

## Accuracy While Removing Views

In [7]:
# Rebuild the test set for the view-removal experiment. Unlike the first dataset cell,
# neither num_models nor shuffle is passed here, so the dataset's own defaults apply.
# NOTE(review): confirm the default for shuffle keeps the sample order fixed.
test_dataset = ImgDataset.MultiviewImgDataset(
    root_dir='../../../MVCNN/ModelNet40-12View/*/test',
    num_views=12,
    test_mode=True,
    scale_aug=False,
    rot_aug=False,
)

# One object (all of its views) per batch, served in dataset order.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=1,
    shuffle=False,
    pin_memory=True,
)

In [8]:
# Glob patterns for the train and test splits of the 12-view dataset.
train_path, test_path = (
    '../../../MVCNN/ModelNet40-12View/*/train',
    '../../../MVCNN/ModelNet40-12View/*/test',
)

In [9]:
# Fresh counters for each of the 12 views (keys 0..11):
#   'samples'     - per-class tally of objects seen (33 slots)
#   'wrong'       - per-class tally of misclassified objects (33 slots)
#   'all_points'  - running count of evaluated items
#   'all_correct' - running count of correct items
per_view_importance = {
    view_idx: {
        'samples': torch.zeros(33),
        'wrong': torch.zeros(33),
        'all_points': 0,
        'all_correct': 0,
    }
    for view_idx in range(12)
}

In [None]:
# Baseline evaluation and leave-one-view-out evaluation in a single pass over the test set.
#
# Every view image is classified once. The baseline takes a per-object majority vote over
# all V predictions; the "view removed" variants take the same vote with one column of the
# [N, V] prediction table left out, so no additional forward passes are needed.
# NOTE(review): reusing the baseline predictions assumes the network scores each image
# independently of the rest of the batch in eval mode - confirm for net_2.
#
# The counters in `per_view_importance` are incremented, not reset: re-create that dict
# before re-running this cell.
model.eval()
samples_class_baseline = torch.zeros(33)
wrong_class_baseline = torch.zeros(33)
all_points = 0
all_correct_points_baseline = 0
all_loss = 0.0

with torch.no_grad():
    pbar = tqdm(test_loader, desc="Validating", leave=False, ncols=80)
    for data in pbar:
        N, V, C, H, W = data[1].size()

        # Flatten objects x views into one image batch; repeat each label once per view.
        in_data = data[1].view(-1, C, H, W).to(device)
        target = data[0].to(device).repeat_interleave(V)

        out_data = model(in_data)
        pred = torch.max(out_data, 1)[1]
        all_loss += torch.nn.functional.cross_entropy(out_data, target).item()
        results = pred == target

        # Predictions / hits as [N, V] tables on the CPU, where torch.mode is evaluated.
        pred_by_view = pred.view(N, V).to('cpu')
        hits_by_view = results.view(N, V).to('cpu')
        labels = data[0].to('cpu')

        # Baseline: majority vote over all V views of each object.
        for i in range(N):
            obj_pred = torch.mode(pred_by_view[i])[0]
            obj_target = labels[i]
            if obj_target != obj_pred:
                wrong_class_baseline[obj_target.item()] += 1
            samples_class_baseline[obj_target.item()] += 1
        all_correct_points_baseline += results.long().sum().item()
        all_points += results.size()[0]

        # Leave-one-view-out: same vote with column `view` removed.
        for view in [0, 1, 10, 11, 2, 3, 4, 5, 6, 7, 8, 9]:
            keep = [k for k in range(V) if k != view]
            if not keep:
                # nothing left to vote on (only possible when V == 1)
                continue
            kept_preds = pred_by_view[:, keep]
            for i in range(N):
                obj_pred = torch.mode(kept_preds[i])[0]
                obj_target = labels[i]
                if obj_target != obj_pred:
                    per_view_importance[view]['wrong'][obj_target.item()] += 1
                per_view_importance[view]['samples'][obj_target.item()] += 1
            # Image-level tallies over the remaining views.
            per_view_importance[view]['all_correct'] += int(hits_by_view[:, keep].sum())
            per_view_importance[view]['all_points'] += N * len(keep)

Validating:   0%|                                      | 0/1860 [00:00<?, ?it/s]

In [8]:
# Display the accumulated per-view counters.
per_view_importance

{0: {'samples': tensor([100., 100.,  20., 100., 100.,  20., 100., 100.,  20.,  20.,  20.,  20.,
           20., 100., 100.,  20.,  20.,  20., 100.,  20., 100., 100.,  20., 100.,
           20.,  20.,  20.,  20., 100., 100., 100.,  20.,  20.]),
  'wrong': tensor([ 0.,  0.,  8.,  3.,  4.,  1.,  0.,  3.,  4.,  5.,  4.,  2., 16.,  5.,
           8.,  1.,  3.,  0.,  6.,  2., 11., 11., 11., 13.,  6.,  6.,  6.,  3.,
           2., 19., 17.,  3.,  7.]),
  'all_points': 20460,
  'all_correct': 17265},
 1: {'samples': tensor([100., 100.,  20., 100., 100.,  20., 100., 100.,  20.,  20.,  20.,  20.,
           20., 100., 100.,  20.,  20.,  20., 100.,  20., 100., 100.,  20., 100.,
           20.,  20.,  20.,  20., 100., 100., 100.,  20.,  20.]),
  'wrong': tensor([ 0.,  0.,  7.,  3.,  5.,  1.,  0.,  3.,  4.,  4.,  4.,  2., 16.,  4.,
           7.,  1.,  3.,  0.,  7.,  2., 10., 12., 12., 13.,  7.,  6.,  6.,  3.,
           1., 20., 19.,  3.,  6.]),
  'all_points': 20460,
  'all_correct': 17221},
 2: 

In [23]:
view_indices = list(range(12))  # adjust if you have a different number of views


def _object_vote_accuracy(desc, drop_view=None, leave=True):
    """Count objects classified correctly by a per-object majority vote.

    Each view image is classified on its own by `model`; an object counts as
    correct when the most frequent predicted class equals its label.

    Args:
        desc: label shown on the progress bar.
        drop_view: index along the view dimension to leave out of every object,
            or None to vote over all views.
        leave: forwarded to tqdm (keep the finished bar on screen or not).

    Returns:
        (correct, total) object counts over `test_loader`.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(test_loader, desc=desc, ncols=80, leave=leave):
            labels, imgs = data[0], data[1]          # [N], [N, V, C, H, W]
            N, V, C, H, W = imgs.shape

            if drop_view is None:
                votes_per_object = V
            else:
                kept = [i for i in range(V) if i != drop_view]
                imgs = imgs[:, kept, :, :, :].contiguous()
                votes_per_object = V - 1

            flat = imgs.view(-1, C, H, W).to(device)  # one row per remaining view image
            lbls = labels.to(device)
            pred = model(flat).argmax(dim=1)

            for i in range(N):
                block = pred[i*votes_per_object:(i+1)*votes_per_object]
                vote = torch.mode(block.to('cpu'))[0].to(device)
                if vote.item() == lbls[i].item():
                    correct += 1
                total += 1
    return correct, total


# 1) Baseline object-level accuracy
correct_baseline, total_objects = _object_vote_accuracy('Baseline Evaluation')
baseline_acc = correct_baseline / total_objects
print(f"\nBaseline object-level accuracy: {baseline_acc:.4f}\n")

# 2) Accuracy with each view removed
drop_acc = {}
for v in view_indices:
    correct, total = _object_vote_accuracy(f'Remove view {v:2d}', drop_view=v, leave=False)
    acc = correct / total
    drop_acc[v] = acc
    delta = baseline_acc - acc
    print(f"Without view {v:2d}: accuracy = {acc:.4f}, drop = {delta:+.4f}")

# 3) Optionally, summarize all drops
print("\nSummary of accuracy drops by view:")
for v in view_indices:
    print(f" View {v:2d}: Δ = {baseline_acc - drop_acc[v]:+.4f}")

Baseline Evaluation: 100%|██████████████████| 1860/1860 [02:09<00:00, 14.35it/s]



Baseline object-level accuracy: 0.9000



                                                                                

Without view  0: accuracy = 0.8978, drop = +0.0022


                                                                                

Without view  1: accuracy = 0.8973, drop = +0.0027


                                                                                

Without view  2: accuracy = 0.9005, drop = -0.0005


                                                                                

Without view  3: accuracy = 0.9000, drop = +0.0000


                                                                                

Without view  4: accuracy = 0.8995, drop = +0.0005


                                                                                

Without view  5: accuracy = 0.8989, drop = +0.0011


                                                                                

Without view  6: accuracy = 0.8995, drop = +0.0005


                                                                                

Without view  7: accuracy = 0.9016, drop = -0.0016


                                                                                

Without view  8: accuracy = 0.8984, drop = +0.0016


                                                                                

Without view  9: accuracy = 0.8989, drop = +0.0011


                                                                                

Without view 10: accuracy = 0.9000, drop = +0.0000


                                                                                

Without view 11: accuracy = 0.8978, drop = +0.0022

Summary of accuracy drops by view:
 View  0: Δ = +0.0022
 View  1: Δ = +0.0027
 View  2: Δ = -0.0005
 View  3: Δ = +0.0000
 View  4: Δ = +0.0005
 View  5: Δ = +0.0011
 View  6: Δ = +0.0005
 View  7: Δ = -0.0016
 View  8: Δ = +0.0016
 View  9: Δ = +0.0011
 View 10: Δ = +0.0000
 View 11: Δ = +0.0022




In [None]:
# Number of view entries being tracked.
len(per_view_importance)

In [None]:
# Ratio of the two baseline counters; both must still hold the values left by the
# baseline evaluation cell (other cells in this notebook also assign `all_points`).
all_correct_points_baseline/all_points

In [None]:
# Report all_correct / all_points as a percentage for every removed view.
for view, stats in per_view_importance.items():
    accuracy_pct = stats['all_correct']*100/stats['all_points']
    print(f"view {view} removed, accuracy now: {accuracy_pct:.4f}%")

In [None]:
# Object-level accuracy after removing each view.
# Each per_view_importance[view] is a dict of counters, so calling sum() on it iterates
# its string keys and raises TypeError. The accuracy is derived from the per-class
# tallies instead: (objects seen - objects wrong) / objects seen, or 0.0 with no samples.
# NOTE(review): assumes object-level accuracy is the quantity this cell was meant to show.
for view, stats in per_view_importance.items():
    seen = float(stats['samples'].sum())
    missed = float(stats['wrong'].sum())
    accuracy = (seen - missed) / seen if seen else 0.0
    print(view, ",", accuracy)

In [None]:
# sum() over a per-view dict of counters raises TypeError, so the number of correctly
# voted objects per view is computed from the 'samples' and 'wrong' tallies instead.
# NOTE(review): assumes "correct objects per view" is the intended input here.
correct_objects_per_view = [
    float((stats['samples'] - stats['wrong']).sum())
    for stats in per_view_importance.values()
]
min_max_norm(correct_objects_per_view)

In [None]:
# Smallest and largest object-level accuracy across the removed-view runs.
# sum()/len() over a per-view dict of counters raises TypeError, so the accuracy is
# computed from the tallies: correct objects / objects seen (the denominator is clamped
# to 1 so a view with no samples yields 0 rather than NaN).
# NOTE(review): assumes object-level accuracy is the intended quantity.
object_accuracy_per_view = [
    float((stats['samples'] - stats['wrong']).sum() / stats['samples'].sum().clamp(min=1))
    for stats in per_view_importance.values()
]
max_l1 = max(object_accuracy_per_view)
min_l1 = min(object_accuracy_per_view)
min_l1+max_l1

In [None]:
def l1_norm(arr):
    """Scale `arr` so that its absolute values sum to 1.

    Returns a new list of x / sum(|x|); when that sum is 0 the input object itself
    is returned unchanged.

    Note: this replaces the `l1_norm` defined earlier in the notebook, which
    additionally multiplied every value by 10.
    """
    norm = 0
    for x in arr:
        norm += abs(x)
    if norm == 0:
        return arr
    return [x / norm for x in arr]

In [None]:
# L1-normalise the object-level accuracy of every removed-view run.
# sum()/len() over a per-view dict of counters raises TypeError, so the accuracy is
# computed from the tallies: correct objects / objects seen (denominator clamped to 1).
# NOTE(review): assumes object-level accuracy is the intended quantity.
l1_norm([
    float((stats['samples'] - stats['wrong']).sum() / stats['samples'].sum().clamp(min=1))
    for stats in per_view_importance.values()
])

In [None]:
# Hard-coded per-view scores.
# NOTE(review): presumably pasted from the output of an earlier l1_norm run - verify.
pasted_scores = [
    0.08223510806536635,
    0.08328940432261465,
    0.08389185932675652,
    0.08328940432261465,
    0.08411777995330974,
    0.0835906318246856,
    0.0827622561939905,
    0.08291286994502597,
    0.08374124557572106,
    0.0833647111981324,
    0.08246102869191957,
    0.08434370057986294,
]

# Reverse the ordering of the scores: every x becomes (1 - min_l1 + max_l1) - x.
flip_offset = 1-min_l1+max_l1
norms = [flip_offset - x for x in pasted_scores]

In [None]:
# Display the transformed scores.
norms

In [None]:
def find_imp(arr):
    """Return mean(|x|) / x for every x in `arr`.

    Args:
        arr: iterable of numbers (a one-shot iterable such as a generator is fine).

    Returns:
        A new list with the mean absolute value divided by each element. An empty
        input gives an empty list; when the mean absolute value is 0 (all elements
        are zero) the values are returned unchanged.

    Raises:
        ZeroDivisionError: if an element is 0 while the mean absolute value is not.
    """
    values = list(arr)  # materialise once: both the mean and the final pass need the data
    if not values:
        return []
    mean = sum(abs(x) for x in values) / len(values)
    if mean == 0:
        return values
    return [mean / x for x in values]

In [None]:
# Apply find_imp to the transformed scores and display the result.
find_imp(norms)

In [None]:
# sum() over a per-view dict of counters raises TypeError, so the number of correctly
# voted objects per view is computed from the 'samples' and 'wrong' tallies instead.
# NOTE(review): assumes "correct objects per view" is the intended input here.
find_imp([
    float((stats['samples'] - stats['wrong']).sum())
    for stats in per_view_importance.values()
])

In [None]:
# Total of the find_imp scores over all views.
# sum() over a per-view dict of counters raises TypeError, so the number of correctly
# voted objects per view is computed from the 'samples' and 'wrong' tallies instead.
# NOTE(review): assumes "correct objects per view" is the intended input here.
sum(find_imp([
    float((stats['samples'] - stats['wrong']).sum())
    for stats in per_view_importance.values()
]))

In [None]:
# Plot the find_imp score of every view, in the key order of per_view_importance.
# sum() over a per-view dict of counters raises TypeError, so the number of correctly
# voted objects per view is computed from the 'samples' and 'wrong' tallies instead.
# NOTE(review): assumes "correct objects per view" is the intended input here.
view_scores = find_imp([
    float((stats['samples'] - stats['wrong']).sum())
    for stats in per_view_importance.values()
])
fig, ax = plt.subplots()
ax.plot(list(per_view_importance), view_scores, marker='o')
ax.set(xlabel='removed view', ylabel='mean / correct objects', title='find_imp score per removed view');

## Baseline Accuracy of Whole MVCNN Model

In [None]:
# Same test set configuration as the first dataset cell, but served 32 objects per batch.
test_dataset = ImgDataset.MultiviewImgDataset(
    root_dir='../../../MVCNN/ModelNet40-12View/*/test',
    num_views=12,
    num_models=1000,
    test_mode=True,
    shuffle=False,
    scale_aug=False,
    rot_aug=False,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,
    num_workers=1,
    shuffle=False,
    pin_memory=True,
)

In [None]:

"""Calculate validation accuracy"""
print("Calculating validation accuracy...")
model.eval()

samples_class = torch.zeros(33)  # objects seen per class
wrong_class = torch.zeros(33)    # objects whose majority vote was wrong, per class
all_points = 0
all_correct_points = 0
all_loss = 0.0

# Note: the accuracy printed below counts individual view images
# (all_correct_points / all_points); the majority-vote tallies in
# wrong_class / samples_class are collected but not printed.
with torch.no_grad():
    pbar = tqdm(test_loader, desc="Validating", leave=False, ncols=80)
    for data in pbar:
        labels, views = data[0], data[1]
        N, V, C, H, W = views.size()

        # Flatten objects x views into one image batch; repeat each label once per view.
        in_data = views.view(-1, C, H, W).to(device)
        target = labels.to(device).repeat_interleave(V)

        out_data = model(in_data)
        pred = torch.max(out_data, 1)[1]
        all_loss += torch.nn.functional.cross_entropy(out_data, target).item()
        results = pred==target

        # Majority vote over the V predictions of each object.
        for i in range(N):
            first, last = i*V, (i+1)*V
            obj_pred = torch.mode(pred[first:last].to('cpu'))[0].to(device)
            obj_target = target[first]
            label = obj_target.item()
            if obj_target != obj_pred:
                wrong_class[label] += 1
            samples_class[label] += 1

        all_correct_points += results.long().sum().item()
        all_points += results.size()[0]

print("Validation accuracy: ", all_correct_points/all_points)
print("Validation loss: ", all_loss/len(test_loader))