In [None]:
#| default_exp prune.pruner

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from torch_pruning.pruner import function

import numpy as np
import torch
import torch.nn as nn
import pickle
from itertools import cycle
from fastcore.basics import store_attr, listify, true
from fasterai.core.criteria import *
from fastai.vision.all import *

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
class Pruner():
    """Structured (channel-level) pruner built on a `torch_pruning` dependency graph.

    `context` selects how channels are ranked: `'global'` compares every group against one
    network-wide threshold (see `compute_threshold`); any other value ranks each group on its own.
    `criteria` is called as `criteria(module, granularity, squeeze=True)` and its result is
    accumulated per channel index. `sparsity` arguments are percentages (they are divided by 100).
    """
    def __init__(self, model, context, criteria, layer_type=[nn.Conv2d, nn.Linear, nn.LSTM], example_inputs=torch.randn(1,3,224,224), ignored_layers=None):
        store_attr()
        # `store_attr` stored the raw `ignored_layers` argument; replace it with the expanded lists.
        self.ignored_layers = []
        self.ignored_params = []
        if ignored_layers is not None:
            for layer in ignored_layers:
                if isinstance(layer, nn.Module):
                    # Ignore the module and every sub-module it contains.
                    self.ignored_layers.extend(list(layer.modules()))
                elif isinstance(layer, nn.Parameter):
                    self.ignored_params.append(layer)

        self.DG = tp.DependencyGraph()
        # Trace the graph with the example input moved to the model's device.
        model_device = next(model.parameters()).device
        self.DG.build_dependency(self.model, example_inputs=example_inputs.to(model_device), ignored_params=self.ignored_params)
        self._save_init_state()
        self._reset_threshold()
        self.init_num_groups = None

    def compute_threshold(self, sparsity):
        "Compute the importance of every group and derive the network-wide threshold for `sparsity` (%)."
        groups = self.DG.get_all_groups(root_module_types=self.layer_type, ignored_layers=self.ignored_layers)
        self.global_importance = {ix: self.group_importance(grp) for ix, grp in enumerate(groups)}

        global_imp = torch.cat(list(self.global_importance.values()), dim=0)

        # The number of scores seen on the first call is the reference for later calls.
        self.init_num_groups = self.init_num_groups or len(global_imp)
        # Number of scores to keep; cast to a plain int so `topk` never receives a numpy scalar.
        n_keep = int(np.clip(int((1-sparsity/100)*self.init_num_groups), 1, len(global_imp)))
        self.global_threshold = torch.topk(global_imp, n_keep)[0].min()

    def prune_group(self, group, ix, sparsity, round_to):
        "Select the channels to remove from `group` and prune them through the dependency graph."
        root_dep = group[0][0]
        module = root_dep.target.module
        pruning_fn = root_dep.handler
        pruning_idxs = self.prune_method(group, ix, sparsity, round_to)

        # Rebuild the group restricted to the selected indices, then apply it.
        group = self.DG.get_pruning_group(module, pruning_fn, pruning_idxs.tolist())
        group.prune()

    def prune_model(self, sparsity, round_to=None):
        "Prune every group of the model to `sparsity` (%), optionally keeping a multiple of `round_to` channels."
        if self.context=='global': self.compute_threshold(sparsity)
        for ix, group in enumerate(self.DG.get_all_groups(root_module_types=self.layer_type, ignored_layers=self.ignored_layers)):
            self.prune_group(group, ix, sparsity, round_to)

    def prune_method(self, group, ix, sparsity, round_to):
        """Return the indices of the channels of `group` to remove.

        The `n_keep` highest scores are kept; every channel whose score is strictly below the
        smallest kept score is returned. With `round_to`, `n_keep` is rounded by `_rounded_sparsity`.
        """
        if self.context=='global':
            imp = self.global_importance[ix]
            n_keep = max(1, int(imp.ge(self.global_threshold).sum()))
        else:
            imp = self.group_importance(group)

            root_dep = group[0].dep
            module = root_dep.target.module
            # Sparsity is relative to the channel count recorded at construction time.
            if self.DG.is_out_channel_pruning_fn(root_dep.handler):
                prunable_channels = module._init_out_channels
            else:
                prunable_channels = module._init_in_channels

            n_keep = max(1, int((1-sparsity/100)*prunable_channels))

        if round_to:
            n_keep = int(self._rounded_sparsity(torch.tensor(n_keep), round_to))
        # `_rounded_sparsity` never returns less than `round_to`, which can exceed the number of
        # channels in a small group; `topk` raises if asked for more elements than exist.
        n_keep = min(n_keep, imp.numel())
        threshold = torch.topk(imp, n_keep)[0].min()

        return imp.lt(threshold).nonzero().view(-1)

    def updated_sparsity(self, m, sparsity):
        "Currently a pass-through: returns `sparsity` unchanged (`m` is not used)."
        return sparsity

    def _save_init_state(self):
        "Record the initial in/out channel counts on every module that has a `weight` attribute."
        for m in self.model.modules():
            if hasattr(m, 'weight'):
                setattr(m, '_init_out_channels', self.DG.get_out_channels(m))
                setattr(m, '_init_in_channels', self.DG.get_in_channels(m))

    def _rounded_sparsity(self, n_to_prune, round_to):
        "Round `n_to_prune` down to a multiple of `round_to`, but never below `round_to`."
        return max(round_to*torch.floor(n_to_prune/round_to), round_to)

    def _reset_threshold(self):
        "Forget any previously computed global threshold."
        self.global_threshold=None

    def group_importance(self, group):
        """Average the criteria scores of the handled dependencies of `group`, per root channel.

        Only dependencies whose handler appears in `handler_map` contribute. Raises `IndexError`
        when none does (the score list is then empty). The result is moved to `default_device()`.
        """
        handler_map = {
            function.prune_conv_out_channels: 'filter',
            #function.prune_linear_out_channels: 'row',
            #function.prune_linear_in_channels: 'column',
            function.prune_conv_in_channels: 'shared_kernel',
            # Additional handlers can be added here
        }

        group_imp = []
        group_idxs = []

        for i, (dep, idxs) in enumerate(group):
            if dep.handler in handler_map:
                impo = self.criteria(dep.target.module, handler_map.get(dep.handler), squeeze=True)
                group_imp.append(impo)
                group_idxs.append(group[i].root_idxs)

        # Accumulate on the CPU; the move is done once instead of on every iteration.
        reduced_imp = torch.zeros_like(group_imp[0]).to('cpu')

        for imp, root_idxs in zip(group_imp, group_idxs):
            imp = imp.to('cpu')
            # Add each score onto the root channel it maps to.
            reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp)

        reduced_imp /= len(group_imp)

        return reduced_imp.to(default_device())

In [None]:
# Render the documentation entry for `Pruner.prune_model`.
show_doc(Pruner.prune_model)

---

[source](https://github.com/nathanhubens/fasterai/tree/master/blob/master/fasterai/prune/pruner.py#L51){target="_blank" style="float:right; font-size:smaller"}

### Pruner.prune_model

>      Pruner.prune_model (sparsity, round_to=None)

Let's try the `Pruner`, first with a ResNet18 model and then with a VGG16 model.

In [None]:
# Scratch check: smallest of the top-k values of a random vector, with k = 30 rounded down
# to a multiple of 8 (and never below 8).
torch.topk(torch.randn(1000), int(max(8*torch.floor(torch.tensor(30)/8), 8)))[0].min()

tensor(2.0126)

In [None]:
# Same check without rounding (k = 30). A fresh random vector is drawn here, so the value is
# not directly comparable to the one from the rounded cell.
torch.topk(torch.randn(1000), 30)[0].min()

tensor(2.0065)

In [None]:
# The rounding expression on its own: 30 rounded down to a multiple of 8, floored at 8.
max(8*torch.floor(torch.tensor(30)/8), 8)

tensor(24.)

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU so the cell runs anywhere.
model = resnet18().to('cuda:0' if torch.cuda.is_available() else 'cpu')
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

The `Pruner` can either remove filters based on `local` criteria (i.e. each layer will be trimmed of the same % of filters)

In [None]:
# `'local'` context: every group is ranked on its own; `round_to=8` keeps a multiple of 8 channels.
pruner = Pruner(model, 'local', large_final, layer_type=[nn.Conv2d])
pruner.prune_model(3, round_to=8)
print(model)

# Build the dummy input on the device the model actually lives on instead of hard-coding 'cuda:0'.
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, torch.randn(1,3,224,224).to(next(model.parameters()).device)); pruned_macs, pruned_params

ResNet(
  (conv1): Conv2d(3, 48, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

(1393209728.0, 10405416)

The `Pruner` can also remove filters based on `global` criteria (i.e. each layer will be trimmed of a different % of filters, but we specify the sparsity of the whole network)

In [None]:
# Instantiate a VGG16-BN model for the global pruning example below.
model = vgg16_bn()

In [None]:
# `'global'` context: one network-wide importance threshold decides which filters are removed.
pruner = Pruner(model, 'global', large_final, layer_type=[nn.Conv2d])
# The sparsity argument is a percentage.
pruner.prune_model(50)
print(model)

# Cost of the pruned network for one 1x3x224x224 input, displayed as (pruned_macs, pruned_params).
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, torch.randn(1,3,224,224)); pruned_macs, pruned_params

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

(9945391172.0, 126347419)

# New Pruner

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from torch_pruning.pruner import function

import numpy as np
import torch
import torch.nn as nn
import pickle
from itertools import cycle
from fastcore.basics import store_attr, listify, true
from fasterai.core.criteria import *
from fastai.vision.all import *

In [None]:
#| export
class Pruner():
    """Experimental wrapper around `torch_pruning`'s `MetaPruner`.

    `context` selects global pruning: booleans are used as-is, and the strings `'global'` /
    `'true'` (case-insensitive) mean global while any other string (e.g. `'local'`) means local.
    `criteria` is forwarded as the `importance` argument of `MetaPruner`.
    `prune_model`, `prune_model_old` and `prune_group` take `sparsity` as a percentage;
    `manual_prune` takes a ratio in [0, 1] or an explicit list of channel indices.
    """
    def __init__(self, model, context, criteria, example_inputs=torch.randn(1,3,224,224), ignored_layers=None):
        store_attr()

        # Trace on the model's device when the example input is a plain tensor.
        if torch.is_tensor(example_inputs):
            first_param = next(model.parameters(), None)
            if first_param is not None: example_inputs = example_inputs.to(first_param.device)

        self.pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
            model,
            example_inputs,
            importance=criteria,
            pruning_ratio=0.5, # not read by the methods of this class, which take their own sparsity
            ignored_layers=ignored_layers,
            global_pruning=self._is_global(context),
        )

    @staticmethod
    def _is_global(context):
        "Turn `context` into a real boolean; a non-empty string such as 'False' or 'local' is truthy."
        if isinstance(context, str): return context.strip().lower() in ('global', 'true')
        return bool(context)

    def prune_model_old(self, sparsity):
        "Prune group by group, using each group's own root module and pruning function."
        for ix, group in enumerate(self.pruner.DG.get_all_groups(root_module_types=self.pruner.root_module_types, ignored_layers=self.pruner.ignored_layers)):
            self.prune_group(group, ix, sparsity)

    def prune_model(self, sparsity):
        """Try to prune every module of the model to `sparsity` (%).

        Best effort: modules that cannot be pruned raise inside `manual_prune`; the error is
        printed and the loop moves on.
        """
        # NOTE(review): layers coupled in one dependency group may be pruned more than once by
        # this module-wise loop - confirm whether `prune_model_old` is the intended traversal.
        ratio = sparsity / 100 # `manual_prune` works with a ratio, callers pass a percentage
        for m in self.model.modules():
            try:
                self.manual_prune(m, ratio)
            except Exception as error: print(error)

    def prune_group(self, group, ix, sparsity):
        "Prune `sparsity` (%) of the channels of `group`'s root module (`ix` is unused)."
        root_dep = group[0][0]
        self.manual_prune(root_dep.target.module, sparsity / 100, pruning_fn=root_dep.handler)

    def manual_prune(self, layer, pruning_ratios_or_idxs, pruning_fn=None):
        """Prune `layer` together with the layers coupled to it in the dependency graph.

        `pruning_ratios_or_idxs` is either the fraction of channels to remove (0 to 1), in which
        case the least important channels are selected, or a list/tuple of channel indices.
        `pruning_fn` defaults to the handler of the first group of the graph.
        Raises `ValueError` for a ratio outside [0, 1] or when the graph has no group.
        """
        DG = self.pruner.DG
        if pruning_fn is None:
            # Only the first group is needed, so do not materialise all of them.
            first_group = next(iter(DG.get_all_groups(root_module_types=self.pruner.root_module_types, ignored_layers=self.pruner.ignored_layers)), None)
            if first_group is None: raise ValueError("The dependency graph contains no prunable group")
            pruning_fn = first_group[0][0].handler

        if isinstance(pruning_ratios_or_idxs, (list, tuple)):
            pruning_idxs = list(pruning_ratios_or_idxs)
        else:
            ratio = float(pruning_ratios_or_idxs)
            if not 0 <= ratio <= 1:
                raise ValueError(f"Pruning ratio must be in [0, 1], got {ratio}")
            if DG.is_out_channel_pruning_fn(pruning_fn):
                prunable_channels = DG.get_out_channels(layer)
            else:
                prunable_channels = DG.get_in_channels(layer)
            full_group = DG.get_pruning_group(layer, pruning_fn, list(range(prunable_channels)))
            imp = self.pruner.estimate_importance(full_group)
            # Ascending sort: the first `n_pruned` indices are the least important channels.
            n_pruned = int(prunable_channels * ratio)
            pruning_idxs = torch.argsort(imp)[:n_pruned].tolist()

        if not pruning_idxs: return # nothing to remove

        group = DG.get_pruning_group(layer, pruning_fn, pruning_idxs)
        group.prune()

In [None]:
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# `context` is forwarded to MetaPruner's `global_pruning` flag, so pass a real boolean:
# the string "False" is truthy and would switch global pruning on.
pruner = Pruner(model, False, tp.importance.GroupNormImportance(p=2))

In [None]:
# Run the module-by-module pruning pass with a sparsity argument of 30.
pruner.prune_model(30)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
# Inspect the model held by the pruner after the call above.
pruner.model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 0. Baseline cost, measured before any pruning so the final "before -> after" summary is meaningful
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)

# 1. Importance criterion
imp = tp.importance.GroupNormImportance(p=2) # or GroupTaylorImportance(), GroupHessianImportance(), etc.

# 2. Initialize a pruner with the model and the importance criterion
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

In [None]:
# Call the pruner ten times in interactive mode, applying every yielded group by hand and
# printing the MAC count of the model after each call.
for i in range(10):
    for group in pruner.step(interactive=True): # Warning: groups must be handled sequentially. Do not keep them as a list.
        print(group) 
        # do whatever you like with the group
        dep, idxs = group[0] # first entry of the group: its dependency and the associated indices
        target_module = dep.target.module # get the root module
        pruning_fn = dep.handler # get the pruning function
        # The three lookups above are illustrative only; nothing below reads them.
        group.prune()
        # group.prune(idxs=[0, 2, 6]) # It is even possible to change the pruning behaviour with the idxs parameter
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(macs)


--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)) => prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)), len(idxs)=256
[1] prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)) => prune_out_channels on layer4.0.downsample.1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), len(idxs)=256
[2] prune_out_channels on layer4.0.downsample.1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_6(AddBackward0), len(idxs)=256
[3] prune_out_channels on _ElementWiseOp_6(AddBackward0) => prune_out_channels on layer4.0.bn2 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), len(idxs)=256
[4] prune_out

In [None]:
# The baseline measurement and the summary print are commented out in this cell.
#base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
# Non-interactive call: the return value is not iterated here.
# NOTE(review): presumably the pruner applies the groups itself in this mode - confirm.
pruner.step()
# Re-measure the model and print both counts.
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(macs)
print(nparams)
#print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")

487202536.0
3055880


In [None]:
# Display the model after the pruning steps above.
model

ResNet(
  (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
# One more interactive call: apply every yielded group by hand, then re-measure the model.
for group in pruner.step(interactive=True): # Warning: groups must be handled sequentially. Do not keep them as a list.
    print(group) 
    # do whatever you like with the group
    dep, idxs = group[0] # first entry of the group: its dependency and the associated indices
    target_module = dep.target.module # get the root module
    pruning_fn = dep.handler # get the pruning function
    # The three lookups above are illustrative only; nothing below reads them.
    group.prune()
    # group.prune(idxs=[0, 2, 6]) # It is even possible to change the pruning behaviour with the idxs parameter
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)


--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)) => prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)), len(idxs)=256
[1] prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)) => prune_out_channels on layer4.0.downsample.1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), len(idxs)=256
[2] prune_out_channels on layer4.0.downsample.1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_6(AddBackward0), len(idxs)=256
[3] prune_out_channels on _ElementWiseOp_6(AddBackward0) => prune_out_channels on layer4.0.bn2 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), len(idxs)=256
[4] prune_out

In [None]:
# Compare the baseline (base_macs / base_nparams) with the current counts. The baseline has to be
# measured before any pruning step for this comparison to be meaningful.
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")

MACs: 0.487202536 G -> 0.487202536 G, #Params: 3.05588 M -> 3.05588 M


In [None]:
 return scores[None].mean(dim=dim, keepdim=True).squeeze(0)

In [None]:
    def group_importance(self, group):
        """Average the criteria scores of the conv dependencies of `group`, per root channel.

        Dependencies whose handler is not a conv in/out-channel pruning function are skipped.
        The accumulation happens on the CPU and the result is moved to `default_device()`.
        """
        granularity_of = {
            function.prune_conv_out_channels: 'filter',
            function.prune_conv_in_channels: 'shared_kernel',
        }

        scores, targets = [], []
        for pos, (dep, _) in enumerate(group):
            granularity = granularity_of.get(dep.handler)
            if granularity is None:
                continue
            scores.append(self.criteria(dep.target.module, granularity, squeeze=True))
            targets.append(group[pos].root_idxs)

        # Start from zeros shaped like the first score tensor, held on the CPU.
        total = torch.zeros_like(scores[0]).to('cpu')
        for score, root_idxs in zip(scores, targets):
            score = score.to('cpu')
            # Add each score onto the root channel it maps to.
            total.scatter_add_(0, torch.tensor(root_idxs, device=score.device), score)

        total /= len(scores)
        return total.to(default_device())

In [None]:
def _normalize(self, group_importance, normalizer):
        if normalizer is None:
            return group_importance
        elif isinstance(normalizer, typing.Callable):
            return normalizer(group_importance)
        elif normalizer == "sum":
            return group_importance / group_importance.sum()
        elif normalizer == "standarization":
            return (group_importance - group_importance.min()) / (group_importance.max() - group_importance.min()+1e-8)
        elif normalizer == "mean":
            return group_importance / group_importance.mean()
        elif normalizer == "max":
            return group_importance / group_importance.max()

        else:
            raise NotImplementedError

    def _reduce(self, group_imp: typing.List[torch.Tensor], group_idxs: typing.List[typing.List[int]]):
        if len(group_imp) == 0: return group_imp
        if self.group_reduction == 'prod':
            reduced_imp = torch.ones_like(group_imp[0])
        elif self.group_reduction == 'max':
            reduced_imp = torch.ones_like(group_imp[0]) * -99999
        else:
            reduced_imp = torch.zeros_like(group_imp[0])

        for i, (imp, root_idxs) in enumerate(zip(group_imp, group_idxs)):
            imp = imp.to(reduced_imp.device)
            if self.group_reduction == "sum" or self.group_reduction == "mean":
                reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp) # accumulated importance
            elif self.group_reduction == 'first':
                if i == 0:
                    reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), imp)
            elif self.group_reduction == 'gate':
                if i == len(group_imp)-1:
                    reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), imp)
            elif self.group_reduction is None:
                reduced_imp = torch.stack(group_imp, dim=0) # no reduction
            else:
                raise NotImplementedError
        
        if self.group_reduction == "mean":
            reduced_imp /= len(group_imp)
        return reduced_imp


    group_imp = self._reduce(group_imp, group_idxs)
    group_imp = self._normalize(group_imp, self.normalizer)