<a href="https://colab.research.google.com/github/harvard-visionlab/sroh/blob/main/2022/sroh_boostrap_lesioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Pruning Bootstrap

To establish a random baseline, we want to randomly shuffle our mask a large number of times (say 1000).

A naive approach would be to run `prune.random_unstructured(module, name="weight", amount=0.3)` 1000 times, with a full forward pass each time, but that would be slow!

The biggest waste would be passing images through the model backbone (`model.features`) 1000 times, since the backbone isn't changing.

So we could just "save" the output of the backbone, then pass it through the classifier 1000 times. This should be faster, but could still be slow.

More generally, we'll want to set up two functions: one to handle the forward pass up to our "pruning layer", whose outputs we store, and one to run multiple passes from the pruning layer forward with different random masks.

Below I implement a demo of this approach and use it to benchmark roughly how long it would take to run on the full validation set.
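
Before the full demo, the caching idea can be seen in isolation on a toy model. This is only a sketch: `backbone` and `head` are hypothetical stand-ins (not AlexNet), and a forward hook counts how often the backbone actually runs.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins: a fixed "backbone" and a "head" we would perturb.
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
head = nn.Linear(16, 4)

# Count how many times the backbone is called.
calls = {"backbone": 0}

def count_call(module, inputs, output):
    calls["backbone"] += 1

backbone.register_forward_hook(count_call)

x = torch.rand(32, 8)
num_samples = 100

backbone.eval()
head.eval()
with torch.no_grad():
    cached = backbone(x)                                   # one pass through the backbone
    outputs = [head(cached) for _ in range(num_samples)]   # many passes through the head

print(calls["backbone"], len(outputs))
```

The backbone runs once no matter how many times the head is evaluated, which is where the savings come from.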

In [1]:
import torch 
from torch.nn.utils import prune
from torch.nn.utils.prune import (
    BasePruningMethod, 
    _validate_pruning_amount_init,
    _validate_pruning_amount,
    _compute_nparams_toprune
)
from torchvision import models
from pdb import set_trace

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = models.alexnet(pretrained=True)
model.to(device)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth



AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

## lesion model.classifier[1]

Let's say we're lesioning units in `model.classifier[1]`.

We'll need two functions, `forward_before_prune` and `forward_after_prune`, which split our model in two such that:

```
out1 = forward_before_prune(model, image_batch)
out = forward_after_prune(model, out1)
```

will give the same output as:
```
out = model(image_batch)
```

So how do we do this splitting? We start by inspecting the forward pass of each module of the model.
```
model.forward??
```

and any submodules we need to subdivide further, e.g.,

```
model.classifier.forward??
```
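
The `??` suffix is IPython/Jupyter syntax. Outside a notebook, the standard library's `inspect.getsource` shows the same source. A minimal sketch, using a toy `nn.Sequential` as a hypothetical stand-in for `model.classifier`:

```python
import inspect
from torch import nn

# Hypothetical stand-in for model.classifier (which is also an nn.Sequential).
toy = nn.Sequential(nn.Linear(4, 2), nn.ReLU())

# Look up forward on the class so we get the plain function's source.
source = inspect.getsource(type(toy).forward)
print(source)
```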


In [None]:
model.forward??

In [None]:
model.classifier.forward??

In [3]:
# Within our custom functions, we refer to the model as "self"
# (so that our forward functions read like the actual forward function of the
# model that you see when running `model.forward??`).

# We need two functions that break the forward pass into two steps:
# everything that happens before our pruned module, and everything
# that happens after (including the pruned module itself).
def forward_before_prune(self, x):
  x = self.features(x)
  x = self.avgpool(x)
  x = torch.flatten(x, 1)
  x = self.classifier[0](x)

  return x

def forward_after_prune(self, x):
  
  x = self.classifier[1](x)
  x = self.classifier[2](x)
  x = self.classifier[3](x)
  x = self.classifier[4](x)
  x = self.classifier[5](x)
  x = self.classifier[6](x)

  return x  

In [4]:
# test to make sure your custom forward functions give 
# the same output as the original model

x = torch.rand(10,3,224,224).to(device)
model.eval()
with torch.no_grad():
  out1 = model(x)
  out_before = forward_before_prune(model, x)
  out2 = forward_after_prune(model, out_before)
torch.allclose(out1, out2)

True
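
Note that `forward_before_prune` includes the dropout layer `classifier[0]`, so the cached features are only reusable because the model is in eval mode, where dropout passes its input through unchanged. A small sketch of that behaviour on a standalone `nn.Dropout` (the tensor sizes are arbitrary):

```python
import torch
from torch import nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(4, 6)

# In eval mode dropout is the identity, so its output can be cached safely.
drop.eval()
same_in_eval = torch.equal(drop(x), x)

# In train mode each element is either zeroed or scaled by 1 / (1 - p),
# so the output differs from the input (and from call to call).
drop.train()
changed_in_train = not torch.equal(drop(x), x)

print(same_in_eval, changed_in_train)
```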

## Now run the pruning/lesioning bootstrap

In [5]:
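module = model.classifier[1]
# amount=0.0 prunes nothing; we only want PyTorch to attach `weight_orig`,
# `weight_mask` (all ones), and the hook that recomputes `weight` on each forward pass
prune.random_unstructured(module, name="weight", amount=0.0)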

Linear(in_features=9216, out_features=4096, bias=True)

In [6]:
mask = model.classifier[1].weight_mask
print(mask.sum() / mask.nelement())

tensor(1., device='cuda:0')
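
To see why an all-ones mask is a useful starting point, here is a sketch on a toy `nn.Linear` (a hypothetical stand-in for `model.classifier[1]`). After pruning, the layer holds a `weight_orig` parameter and a `weight_mask` buffer, and a forward pre-hook rebuilds `weight` from the two on every call, so editing the mask in place is enough to lesion the layer.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)  # toy layer standing in for model.classifier[1]
prune.random_unstructured(layer, name="weight", amount=0.0)

param_names = [name for name, _ in layer.named_parameters()]
buffer_names = [name for name, _ in layer.named_buffers()]
print(param_names, buffer_names)

x = torch.rand(2, 4)
with torch.no_grad():
    out_full = layer(x)            # mask is all ones: behaves like the original layer
    layer.weight_mask.zero_()      # lesion every weight, in place
    out_lesioned = layer(x)        # hook recomputes weight = weight_orig * weight_mask

# With every weight masked out, only the bias is left in the output.
print(torch.allclose(out_lesioned, layer.bias.expand(2, 3)))
```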


In [7]:
# our dummy batch to simulate how long things take for a batch of 256 images
imgs = torch.rand(256,3,224,224).to(device)

In [13]:
import time
import numpy as np
from fastprogress import master_bar, progress_bar 

start_time = time.time()
percentages = np.array([.05, .10, .15, .30, .50, .70, .90])
num_samples = 1000
model.eval()
with torch.no_grad():
  before_features = forward_before_prune(model, imgs)
  mask = module.weight_mask

  # this new mask will get "randomly filled" with ones/zeros 
  # on each pass through nested loops below
  new_mask = torch.empty(mask.nelement()).to(device)

  # big savings here, by going through each mask percent
  # and each number of samples without recomputing features
  # for this batch
  mb = master_bar(percentages)
  for pct_num,mask_pct in enumerate(mb):
    for sample_num in progress_bar(range(num_samples), parent=mb):
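      # seed from the (mask_pct index, sample number) pair so that every batch
      # gets the same mask for the same combination. The seed must be an integer:
      # `mask_pct + sample_num` would be truncated to `sample_num`, so all
      # percentages would share one seed per sample
      torch.manual_seed(pct_num * num_samples + sample_num)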

      # the _ operators happen "in place" (memory efficient)
      # draw uniform random numbers, then set entries greater than mask_pct to 1
      # and the rest to 0, so roughly a fraction mask_pct of the weights is zeroed
      new_mask.uniform_(0, 1).gt_(mask_pct)

      # copy (in place)   
      model.classifier[1].weight_mask.copy_(new_mask.view(mask.size()))

      # forward through the model
      out = forward_after_prune(model, before_features.clone()) 
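# CUDA kernels run asynchronously, so wait for them to finish before stopping the clock
if device == 'cuda':
  torch.cuda.synchronize()
duration = time.time() - start_time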
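
The mask-generation step in the loop can be checked on its own. The sketch below wraps it in a hypothetical helper, `make_mask` (not part of the notebook), and runs on the CPU: the same integer seed reproduces the same mask, a different seed gives a different one, and the fraction of ones is close to `1 - mask_pct`.

```python
import torch

def make_mask(seed, mask_pct, numel):
    """Hypothetical helper: seeded 0/1 mask with roughly mask_pct of entries zeroed."""
    torch.manual_seed(seed)
    # uniform_ fills in place; gt_ then overwrites each entry with 1.0 or 0.0
    return torch.empty(numel).uniform_(0, 1).gt_(mask_pct)

m1 = make_mask(seed=7, mask_pct=0.3, numel=100_000)
m2 = make_mask(seed=7, mask_pct=0.3, numel=100_000)
m3 = make_mask(seed=8, mask_pct=0.3, numel=100_000)

same = torch.equal(m1, m2)            # same seed, same mask
different = not torch.equal(m1, m3)   # different seed, different mask
kept = m1.mean().item()               # fraction of weights left intact

print(same, different, round(kept, 2))
```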

In [14]:
estimated_total_seconds = duration * 50000 / 256 # 50k validation images, divided by batch size 256
estimated_hrs = estimated_total_seconds / (3600) # 3600 seconds per hour
print(f"Estimated hours to run bootstrapped validation: {estimated_hrs:4.1f}hrs")

Estimated hours to run bootstrapped validation:  2.5hrs
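
To put that number in context, the arithmetic can be run backwards. This is a rough sketch that assumes the rounded 2.5 hrs printed above; the derived timings are therefore approximate.

```python
import math

num_images, batch_size = 50_000, 256
num_pcts, num_samples = 7, 1000
estimated_hrs = 2.5  # rounded value printed above

# The estimate above uses the fractional number of batches (50000 / 256).
num_batches_exact = num_images / batch_size
num_batches = math.ceil(num_batches_exact)   # whole batches actually needed

# Each batch is pushed through the pruned half once per (percentage, sample) pair.
passes_per_batch = num_pcts * num_samples
total_passes = passes_per_batch * num_batches

seconds_per_batch = estimated_hrs * 3600 / num_batches_exact
ms_per_pass = seconds_per_batch / passes_per_batch * 1000

print(f"{num_batches} batches, {total_passes:,} masked forward passes")
print(f"~{seconds_per_batch:.1f} s per batch, ~{ms_per_pass:.1f} ms per masked pass")
```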
