# Training and trimming your own model

In this notebook, we detail how to use the learnable masks to trim models that are not listed in the paper or in this repository. 

In [116]:
from collections import OrderedDict
import os 
import random
import typing as tp

import torch 
import torch.nn as nn 

## Adding learnable masks within the model

Let us first define a toy network to illustrate the approach. We will define a small convolutional network, however the approach also works for other architectures (as illustrated in the paper).

In [117]:
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        channels = [16, 32, 64]
        out_dim = 10
        in_dim = 1
        network = []
        for c_in, c_out in zip([in_dim]+channels, channels):
            network.append(nn.Conv1d(c_in, c_out, kernel_size=5, stride=3))
            network.append(nn.BatchNorm1d(c_out))
            network.append(nn.ReLU())
        network.append(nn.Conv1d(channels[-1], out_dim, kernel_size=3, padding='same'))
        self.network = nn.Sequential(*network)
    
    def forward(self, x):
        return self.network(x).mean(-1)

model = ToyModel()

We can inspect the architecture of our network and see that it is composed of four 1-d convolutional layers, followed by batch normalization and ReLU activation

In [118]:
print(model)

ToyModel(
  (network): Sequential(
    (0): Conv1d(1, 16, kernel_size=(5,), stride=(3,))
    (1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv1d(16, 32, kernel_size=(5,), stride=(3,))
    (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv1d(32, 64, kernel_size=(5,), stride=(3,))
    (7): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv1d(64, 10, kernel_size=(3,), stride=(1,), padding=same)
  )
)


Lets also save the number of parameters of the model for later

In [119]:
untrimmed_param_count = sum([p.numel() for p in model.parameters()])

Lets say we want to attach learnable mask modules after each BatchNorm (here, we do not trim the last layer to keep the "embedding" dimension unchanged). We first need to define such mask modules, add them to the network, and add hooks to each BatchNorm to post-process their output by applying the mask

In [120]:
def _unsqueeze_as(x: torch.Tensor, target: torch.Tensor
                  ) -> torch.Tensor:
    """
    Add dimensions to x to match number of dimensions of target
    """
    for _ in range(target.dim()-x.dim()):
        x=x.unsqueeze(-1)
    return x

class MaskModule(nn.Module):
    def __init__(self, num_features: int, feature_dim: int=1)->None:
        """
        Learnable mask module. 
        
        args:
        num_features (int): Number of features (e.g. convolutional channels, attention heads, ...) of the input
        feature_dim (int): Feature dimension of the input
        """
        super().__init__()
        self.feature_dim=feature_dim
        self.num_features=num_features
        self.mask = nn.Parameter(torch.ones(1, num_features))
    
    @property
    def binary_mask(self):
        _bin_mask = torch.round(torch.sigmoid(self.mask)) # Quantize mask values
        bin_mask = self.mask + (_bin_mask-self.mask).detach() # Bypass rounding operator during backward
        return bin_mask
    
    @property
    def masked_indexes(self):
        return torch.where(self.binary_mask==0)[-1].tolist()

    @property
    def num_masked_features(self):
        return int(self.num_features-self.binary_mask.sum().item())
    
    def set_mask(self, indexes: tp.List[str], keep: bool=False
                 )-> None:
        if keep:
            self.mask.data = torch.zeros_like(self.mask.data)
            self.mask.data[:, indexes]=1.
        else:
            self.mask.data[:, indexes]=0.
    
    def reset_mask(self):
        self.mask.data = torch.ones_like(self.mask.data)
        
    def forward(self, x: torch.Tensor)->torch.Tensor:
        mask = _unsqueeze_as(self.binary_mask, x)
        mask = mask.transpose(1, self.feature_dim)
        return x*mask

We see that, with our architecture, BatchNorm layers will be registered correspond to 'network.1', 'network.4' and 'network.7', hence we append the mask modules to these layers. 

In [121]:
def mask_hook(model, input, output):
    """
    Hook to mask output of a layer
    """
    return model.mask_module(output)

for layer in [1, 4, 7]:
    num_features = model.network[layer].num_features
    model.network[layer].mask_module = MaskModule(num_features=num_features, feature_dim=1)
    model.network[layer].register_forward_hook(mask_hook)

### Creating the sparsity inducing loss

As explained in the paper, the model is guided towards sparsity by adding a loss that minimizes the mean of the parameters values over the masks (small values will then be zeroed during the sigmoid+round operation). Let us define such a loss 

In [122]:
def gather_masks(masked_model: nn.Module
                 ) -> tp.Dict[str, nn.Parameter]:
    masks = {}
    for n, m in masked_model.named_modules():
        if isinstance(m, MaskModule):
            masks[n] = m.mask
    return masks

class SparsityLoss(nn.Module):
    def __init__(self, 
                 target: float = 0.5, 
                 power: int = 2 
                 ) -> None:
        super().__init__()
        self.target = target 
        self.power = power
        
    def forward(self, masked_model: nn.Module)-> float:
        """
        Note that this loss is applied to the model itself, NOT its output
        """
        masks = gather_masks(masked_model)
        if not len(masks.keys()):
            return torch.tensor(0.)
        loss = 0
        for mask in masks.values():
            loss = loss+(torch.mean(torch.sigmoid(mask))/len(masks))
        loss = (loss-self.target)**(self.power)
        return loss

Now, let's say our objective is classification, and let's denote x the input of the network, y its prediction, and label the ground truth. To add the sparsity loss to this objective, we would do as follow.

In [123]:
x = torch.randn(4, 1, 44100)
y = model(x)
labels = torch.randint(0, 10, size=(4, ))
cross_entropy_fn = nn.CrossEntropyLoss()
sparsity_loss_fn = SparsityLoss()
sparsity_loss_weight = 100

cross_entropy_loss = cross_entropy_fn(y, labels) 
sparsity_loss = sparsity_loss_fn(model)
total_loss = cross_entropy_loss+sparsity_loss_weight*sparsity_loss

print(f'Cross-entropy loss : {cross_entropy_loss.item()}')
print(f'Sparsity loss : {sparsity_loss.item()}')
print(f'Total loss : {total_loss.item()}')

Cross-entropy loss : 2.3212904930114746
Sparsity loss : 0.0533880740404129
Total loss : 7.660098075866699


### Trimming the model

Once, the model is trained, you will need to convert the masks into trimming indexes, and remove the masked units accordingly. Even though many packages allow to this "automatically", we demonstrate here how to do it on your own, which should help you avoiding common problems that occur when trimming neural networks.

In [124]:
# We did not train the model in this notebook, hence the masks are initialized with ones. 
# Hence, we will just mask random features

for n, m in model.named_modules():
    if isinstance(m, MaskModule):
        m.reset_mask()
        masked_channels = random.choices(list(range(m.num_features)), k=10)
        m.set_mask(masked_channels)

Lets feed a random input to our model, this way we can check later that 'true' trimming did not alter the output of the model

In [125]:
x = torch.randn(4, 1, 44100)
y_untrimmed = model(x)

Let us say that, for the first mask, the 4-th feature is zeroed (after sigmoid+round). This means that we can remove the 4-th _output_ channel of the preceeding convolution, as well as the 4-th feature of the precedding batch norm. Note that we also have to remove the 4-th _input_ channel of the following convolution as well.

In [126]:
# First step : convert the mask into trimming indexes, then remove the mask and its associated hooks

for layer in [1, 4, 7]:
    trimmed_out_features = model.network[layer].mask_module.masked_indexes
    num_features = model.network[layer].mask_module.num_features
    kept_out_features = [feat for feat in range(num_features) if feat not in trimmed_out_features] # feature that will not be trimmed
    model.network[layer].kept_out_features = kept_out_features
    delattr(model.network[layer], 'mask_module') # Delete mask
    model.network[layer]._forward_hooks = OrderedDict() # remove hook
    
# Second step : associate these kept features to the previous and following convolution layers
for cur_bn, prev_conv, next_conv in zip([1, 4, 7], [0, 3, 6], [3, 6, 9]):
    model.network[prev_conv].kept_out_features = model.network[cur_bn].kept_out_features
    model.network[next_conv].kept_in_features = model.network[cur_bn].kept_out_features


We can now remove all units that should not be kept !

In [127]:
for n, m in model.named_modules():
    if hasattr(m, 'kept_in_features'):
        m.weight.data = m.weight.data[:, m.kept_in_features]
    if hasattr(m, 'kept_out_features'):
        m.weight.data = m.weight.data[m.kept_out_features]
        if hasattr(m, 'bias') and m.bias is not None:
            m.bias.data = m.bias.data[m.kept_out_features]
        if hasattr(m, 'running_mean') and m.running_mean is not None:
            m.running_mean = m.running_mean[m.kept_out_features]
        if hasattr(m, 'running_var') and m.running_var is not None:
            m.running_var = m.running_var[m.kept_out_features]

In [130]:
y_trimmed = model(x)
trimmed_param_count = sum([p.numel() for p in model.parameters()])

print(f'Same output for untrimmed and trimmed models : {torch.allclose(y_untrimmed, y_trimmed)}')
print('Number of parameters : ')
print(f'\t - untrimmed : {untrimmed_param_count}')
print(f'\t - trimmed : {trimmed_param_count}')

Same output for untrimmed and trimmed models : True
Number of parameters : 
	 - untrimmed : 15146
	 - trimmed : 8957


## Tips and advices

- If your goal when using this method is to make models smaller, be careful to attach the mask modules at places that are trimmable  
- Sometimes, adding a forward hook doesn't work, so you will have to manually re-write the entire forward method (for examples, see networks/mask_utils.py)
- This approach can lead to layer trimming in certain cases (notably for transformers, where bypassing an entire self-attention is sometimes beneficial). In this case, you will also have to modify the forward method of the trimmed layer (see trim/trim_musicfm.py for an example)