# Set-Up for Analysis

In [1]:
import os
import dill
import copy

import torch
import numpy as np

from utils.load_data import *
from utils.analyse import *

from torch.utils.data import DataLoader

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (torch.device('cuda') alone breaks on CPU-only machines).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# make the output deterministic
SEED = 7
torch.manual_seed(SEED)  # seeds torch's RNG on all devices (CPU and CUDA)
np.random.seed(SEED)     # numpy is imported too; seed it in case helpers draw from it
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

  from .autonotebook import tqdm as notebook_tqdm


#### Initialize the Models

In [2]:
def are_lists_of_tensors_equal(list1, list2):
    """Return True when two sequences hold pairwise-identical tensors.

    Sequences of different length are unequal. Otherwise the tensors are
    compared position by position with ``torch.equal`` (same shape and same
    elements). On the first mismatching pair the message "unequal tensors" is
    printed and False is returned.

    Args:
        list1: first sequence of tensors.
        list2: second sequence of tensors.

    Returns:
        bool: True if every corresponding pair is equal, else False.
    """
    # Different lengths can never match.
    if len(list1) != len(list2):
        return False

    for left, right in zip(list1, list2):
        if torch.equal(left, right):
            continue
        # Report the first mismatch, then stop comparing.
        print("unequal tensors")
        return False

    return True

#### Load the Dataset

In [3]:
# The checkpoint paths used later ('../results/...') are relative to the MAKE
# directory (presumably dataset() also reads relative to it). Only chdir when we
# are not already there, so re-running this cell does not try to enter MAKE/MAKE.
if os.path.basename(os.path.normpath(os.getcwd())) != 'MAKE':
    os.chdir('MAKE')
_, test_data = dataset()
g1, g1_names, num_g1_classes = g1_classes()
# Batched loader for most models; a batch_size=1 loader for TD_HBN (see the analysis loop).
_test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False)
_test_loader_td = DataLoader(dataset=test_data, batch_size=1, shuffle=False)

## Run Analysis
The following $\text{Threshold}$ hyperparameter values are tuned so that the outputs of the Super-HBN network match those of the Sem-HBN network from prior papers (13 exits at branches 1 and 2, and 14 exits at branch 3). The fine-tolerance values are varied for the analysis.

In [4]:
# Fine-tolerance values swept for SuperHBN. np.arange accumulates float error
# (it yields labels like 0.7000000000000001), so snap to the 0.02 grid's 2 decimals.
fine_tolerance = np.round(np.arange(0.0, 1.0, 0.02), 2)

# models dictionary: per model, the class to build, its checkpoint path (relative to
# the MAKE working directory), the analyser class, one threshold pair per branch to
# evaluate, and the fine-tolerance values to sweep ([None] means a single run).
# NOTE(review): the branch-3 threshold is [-1e-6, -1e-6] while branches 1-2 use +/-1e6;
# confirm -1e-6 (rather than -1e6) is intended.
models = {
    'AlexNet': {'model': AlexNet, 'filepath': '../results/models/AlexNet.pth', 'analyser': AnalyseAlexNet, 'threshold': [[None]], 'fine_tolerance': [None]},
    'BranchyAlexNet':
        {'model': BranchyAlexNet, 'filepath': '../results/models/Branchy-AlexNet.pth', 'analyser': AnalyseBranchyNet, 'threshold': [[1e6, -1e6], [-1e6, 1e6], [-1e-6, -1e-6]], 'fine_tolerance': [None]},
    'Sem_HBN':
        {'model': SemHBN, 'filepath': '../results/models/Sem-HBN.pth', 'analyser': AnalyseSemHBN, 'threshold': [[1e6, -1e6], [-1e6, 1e6], [-1e-6, -1e-6]], 'fine_tolerance': [None]},
    'TD_HBN':
        {'model': TD_HBN, 'filepath': '../results/models/TD-HBN.pth', 'analyser': AnalyseTDHBN, 'threshold': [[1e6, -1e6], [-1e6, 1e6], [-1e-6, -1e-6]], 'fine_tolerance': [None]},
    'SuperHBN':
        {'model': SuperHBN, 'filepath': '../results/models/Super-HBN.pth', 'analyser': AnalyseHBN, 'threshold': [[1e6, -1e6], [-1e6, 1e6], [-1e-6, -1e-6]], 'fine_tolerance': fine_tolerance},
}

for model_name, model_contents in models.items():
    # TD_HBN runs on the batch_size=1 loader; every other model uses the batch_size=64 one.
    test_loader = _test_loader_td if model_name == 'TD_HBN' else _test_loader

    print(f"Running {model_name} Model\nTesting Batch Size: = {test_loader.batch_size}\n" + "="*75)
    for idx, th in enumerate(model_contents['threshold']):
        # Models without branch thresholds ([[None]]) get one '<name>_results' entry;
        # branched models get one '<name>_results_branch<k>' entry per threshold.
        # Computed once per threshold (not inside the sweep) so it always refers to
        # THIS model, even if its fine-tolerance list is empty.
        result_key = model_name + (f'_results_branch{idx+1}' if th[0] is not None else '_results')

        # structure for results dictionary: a fresh dict per threshold, published in the
        # notebook namespace under `result_key` so later cells and dill.dump_session can
        # reach it by name
        results = {
            'test_accuracy': [],
            'hierarchical_accuracy': [],
            'specificity': [],
            'flops': [],
            'memory': [],
            'time_taken': [],
            'branch1_exits': [],
            'branch2_exits': [],
            'branch3_exits': [],
            'fine_exits': [],
            'coarse_exits': []
        }
        globals()[result_key] = results

        for ft in model_contents['fine_tolerance']:
            print(f"Branch: {idx+1}\t\tFine-Tolerance: {ft}\n" + "-"*75)

            # perform analysis; each run gets its own deep copy of the test loader
            analysis = model_contents['analyser'](model_class=model_contents['model'], filepath=model_contents['filepath'], device=device, coarse_converter=g1, threshold=th, fine_tolerance=ft,
                                                  test_loader=copy.deepcopy(test_loader))
            analysis_results = analysis.perform_analysis()

            # update results dictionary values. Assumes perform_analysis() returns the metrics
            # in the same order as the keys above (zip silently stops at the shorter one).
            for key, value in zip(results, analysis_results):
                results[key].append(value)

        # debugging
        formatted_results = "{\n" + "".join(f'  "{key}": {value},\n' for key, value in results.items()) + "}"
        print(f"\n{result_key}:\n{formatted_results}")
    print("="*75 + "\n")

Running AlexNet Model
Testing Batch Size: = 64
Branch: 1		Fine-Tolerance: None
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size

AlexNet_results:
{
  "test_accuracy": [48.97],
  "hierarchical_accuracy": [62.56],
  "specificity": [1.0],
  "flops": [126222336.0],
  "memory": [90.19248867034912],
  "time_taken": [22.2242221],
  "branch1_exits": [0],
  "branch2_exits": [0],
  "branch3_exits": [0],
  "fine_exits": [157],
  "coarse


TD_HBN_results_branch1:
{
  "test_accuracy": [57.05],
  "hierarchical_accuracy": [64.59],
  "specificity": [0.723],
  "flops": [31793152.0],
  "memory": [46.53574848175049],
  "time_taken": [154.343473],
  "branch1_exits": [10000],
  "branch2_exits": [0],
  "branch3_exits": [0],
  "fine_exits": [7230],
  "coarse_exits": [2770],
}
Branch: 2		Fine-Tolerance: None
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 10000
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size

TD_HBN_resul

Branch: 1		Fine-Tolerance: 0.18
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size
Branch: 1		Fine-Tolerance: 0.2
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops

Branch: 1		Fine-Tolerance: 0.44
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size
Branch: 1		Fine-Tolerance: 0.46
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_op

Branch: 1		Fine-Tolerance: 0.7000000000000001
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size
Branch: 1		Fine-Tolerance: 0.72
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Re

Branch: 1		Fine-Tolerance: 0.96
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size
Branch: 1		Fine-Tolerance: 0.98
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_op

Branch: 2		Fine-Tolerance: 0.06
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size
Branch: 2		Fine-Tolerance: 0.08
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_op

Branch: 2		Fine-Tolerance: 0.32
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size
Branch: 2		Fine-Tolerance: 0.34
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_op

Branch: 2		Fine-Tolerance: 0.58
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size
Branch: 2		Fine-Tolerance: 0.6
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops

Branch: 2		Fine-Tolerance: 0.84
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size
Branch: 2		Fine-Tolerance: 0.86
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_op

Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size
Branch: 3		Fine-Tolerance: 0.02
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.

Branch: 3		Fine-Tolerance: 0.26
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size
Branch: 3		Fine-Tolerance: 0.28
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_op

Branch: 3		Fine-Tolerance: 0.52
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size
Branch: 3		Fine-Tolerance: 0.54
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_op

Branch: 3		Fine-Tolerance: 0.78
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Getting memory size
Branch: 3		Fine-Tolerance: 0.8
---------------------------------------------------------------------------
Timer Active
# Testing Iterations: 157
Timer Ended
Getting FLOPs
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops

***
# Save Results
#### Save the Experimentation Results for Analysis

In [7]:
# Persist the whole notebook namespace (including every *_results* dict created by
# the analysis loop) so the results can be reloaded later for analysis.
dill.dump_session(filename='../results/experiment_results.db')