In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from main import *

In [3]:
# Seed RNGs up front via the project's `set_seed` helper (defined in main.py — presumably
# covers torch/numpy/random; confirm there) so the evaluation below is reproducible.
set_seed(config["seed"])

In [4]:
# load test dataset
transform = get_transform()
test_dataset = MNIST(
    root="data/mnist/test", train=False, download=True, transform=transform
)

# data loader
# Pinned memory only speeds up host->GPU copies; without CUDA it merely triggers a warning.
test_loader = DataLoader(
    test_dataset,
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
    shuffle=False,
    pin_memory=torch.cuda.is_available(),
    drop_last=False,
)

# load model
# map_location lets a checkpoint that was saved from a GPU load onto whatever `device` is here.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model = CNN()
model.load_state_dict(torch.load(config["model_path"], map_location=device))
model.to(device)
# The network contains Dropout2d, which must be disabled for evaluation.
# eval() returns the module itself, so the cell still displays the model repr.
model.eval()

CNN(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu1): ReLU()
  (conv2): Conv2d(20, 40, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu2): ReLU()
  (dropout): Dropout2d(p=0.25, inplace=False)
  (fc1): Linear(in_features=640, out_features=400, bias=True)
  (relu3): ReLU()
  (fc2): Linear(in_features=400, out_features=10, bias=True)
)

In [32]:
class Pruner:
    """Collect the mean forward output of selected modules during one run.

    Despite the name, nothing is removed from the model here: ``prune`` attaches
    forward hooks to ``modules``, runs a caller-supplied function (e.g. an
    evaluation loop), and afterwards stores, per module, the mean of every output
    that module produced, averaged over the concatenated leading dimension.

    Attributes:
        model: the model the modules belong to (stored; not used by the hooks).
        modules: the modules whose forward outputs are recorded.
        activations: while ``prune`` runs, module -> list of captured outputs;
            after it returns, module -> mean output tensor.
        weights: module -> input tuple of that module's most recent forward call.
            NOTE(review): despite the name these are module *inputs*, not parameters.
        handles: hook handles currently registered by this object.
    """

    def __init__(self, model, modules):
        self.model = model
        self.modules = modules
        self.activations = {}
        self.weights = {}
        self.handles = []

    def register_hooks(self):
        """Attach this object as a forward hook on every module and start empty output lists."""
        for module in self.modules:
            print("registered hook for", module)
            self.handles.append(module.register_forward_hook(self))
            self.activations[module] = []

    def unregister_hooks(self):
        """Remove every hook this object registered; safe to call repeatedly."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def reset(self):
        """Remove any live hooks and discard all collected state."""
        self.unregister_hooks()
        self.activations = {}
        self.weights = {}

    def prune(self, fn, *args, **kwargs):
        """Run ``fn(*args, **kwargs)`` with hooks attached, then average the recorded outputs.

        Args:
            fn: callable that drives forward passes through the hooked modules.
            *args, **kwargs: forwarded to ``fn`` unchanged.

        Returns:
            Whatever ``fn`` returns.

        Hooks are removed even if ``fn`` raises, so an interrupted run never leaves
        them attached to the model. Modules that did not run during ``fn`` get no
        entry in ``self.activations``.
        """
        # Remove hooks left from an earlier interrupted call rather than orphaning
        # them by overwriting self.handles (orphans would keep firing and double-count).
        self.reset()
        self.register_hooks()
        try:
            outputs = fn(*args, **kwargs)
        finally:
            self.unregister_hooks()

        # Average over the concatenated leading dimension (the batch axis for the
        # usual (batch, ...) layout).
        self.activations = {
            module: torch.mean(torch.cat(tensors), dim=0)
            for module, tensors in self.activations.items()
            if tensors  # torch.cat raises on an empty list
        }
        return outputs

    def __call__(self, module, module_in, module_out):
        """Forward-hook callback: record ``module_out`` and this call's inputs."""
        # Detach so stored tensors do not keep every batch's autograd graph alive.
        if isinstance(module_out, torch.Tensor):
            module_out = module_out.detach()
        self.activations[module].append(module_out)
        self.weights[module] = tuple(
            x.detach() if isinstance(x, torch.Tensor) else x for x in module_in
        )

In [33]:
# Re-running this cell must not leave a previous Pruner's hooks attached to `model`.
# Detach them explicitly instead of swallowing every error with a bare `except`.
if isinstance(globals().get("pruner"), Pruner):
    print("reset")
    pruner.unregister_hooks()

# Record activations at every ReLU in the network.
modules_to_register = [module for module in model.modules() if isinstance(module, nn.ReLU)]

pruner = Pruner(model, modules_to_register)

reset


In [34]:
# Evaluate on the test set; the pruner's forward hooks capture the hooked ReLU outputs along the way.
# eval_fn returns (loss, accuracy), which prune() passes straight through.
criterion = nn.CrossEntropyLoss()
test_loss, test_acc = pruner.prune(eval_fn, model, test_loader, criterion)
print(f"Test Accuracy: {test_acc:.3f}", f"Test Loss: {test_loss:.3f}", sep="\n")

registered hook for ReLU()
registered hook for ReLU()
registered hook for ReLU()
Test Accuracy: 0.992
Test Loss: 0.070


In [36]:
# Sanity check: prune() should leave exactly one averaged tensor per hooked module.
assert len(pruner.activations) == len(pruner.modules), "some hooked modules recorded no activations"
print(len(pruner.activations))

3


In [38]:
# For each hooked module print its repr, the shape of its mean activation, and the first five values.
for module, activations in pruner.activations.items():
    print(
        module,
        tuple(activations.shape),
        activations.detach().flatten()[:5].cpu().numpy(),
        sep="\n",
        end="\n\n",
    )

ReLU()
(20, 12, 12)
[0.12976679 0.12985975 0.13066787 0.13741095 0.16576554]

ReLU()
(40, 4, 4)
[0.0539217  0.25208881 0.8704111  0.37289748 0.05550476]

ReLU()
(400,)
[1.3131658e-03 1.5103534e+00 2.3150957e-01 8.7541753e-01 1.2633779e+00]

