# Pruning

This tutorial will be based on *Torchtext + Padded + BiLSTM* under *Classification* folder.  

In [105]:
import torch, torchdata, torchtext
from torch import nn

import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

cpu


In [None]:
# torch.cuda.get_device_name(0)

In [106]:
torch.__version__

'2.1.2'

In [107]:
torchtext.__version__

'0.16.2'

## 1. ETL: Loading the dataset

In [108]:
#uncomment this if you are not using our department puffer
# import os
# os.environ['http_proxy']  = 'http://192.41.170.23:3128'
# os.environ['https_proxy'] = 'http://192.41.170.23:3128'

from torchtext.datasets import AG_NEWS
train, test = AG_NEWS()

In [109]:
# Count the samples by streaming through the dataset once; building a list of
# every (label, text) pair just to take its length wastes memory.
train_size = sum(1 for _ in train)
too_much, train, valid = train.random_split(total_length=train_size, weights = {"too_much": 0.7, "smaller_train": 0.2, "valid": 0.1}, seed=999)

## 2. Preprocessing 

### Tokenizing

The first step is to decide which tokenizer we want to use, which determines how we split our sentences.

In [110]:
#pip install spacy
#python3 -m spacy download en_core_web_sm
from torchtext.data.utils import get_tokenizer
tokenizer = get_tokenizer('spacy', language='en_core_web_sm')
tokens = tokenizer("We are learning torchtext in AIT!")  #some test
tokens

['We', 'are', 'learning', 'torchtext', 'in', 'AIT', '!']

### Text to integers (numeral)

Next, we create a mapping (torchtext calls it a vocab) that turns these tokens into integers.  Here we use the built-in factory function <code>build_vocab_from_iterator</code>, which accepts an iterator that yields lists (or iterators) of tokens.

In [111]:
from torchtext.vocab import build_vocab_from_iterator
def yield_tokens(data_iter):
    """Lazily tokenize every sample of ``data_iter``.

    Each item of ``data_iter`` is unpacked as a two-element pair. The first
    element is ignored and the second is passed to the module-level
    ``tokenizer``; one tokenizer result is yielded per sample.
    """
    for _ignored, raw_text in data_iter:
        token_list = tokenizer(raw_text)
        yield token_list

vocab = build_vocab_from_iterator(yield_tokens(train), specials=['<unk>', '<pad>', '<bos>', '<eos>'])
vocab.set_default_index(vocab["<unk>"])

## 3. FastText Embeddings

In [112]:
from torchtext.vocab import FastText
fast_vectors = FastText(language='simple') #small for easy training

In [113]:
fast_embedding = fast_vectors.get_vecs_by_tokens(vocab.get_itos()).to(device)
# vocab.get_itos() returns a list of strings (tokens), where the token at the i'th position is what you get from doing vocab[token]
# get_vecs_by_tokens gets the pre-trained vector for each string when given a list of strings
# therefore pretrained_embedding is a fully "aligned" embedding matrix

In [114]:
fast_embedding.shape

torch.Size([52828, 300])

## 4. Preparing the dataloader

In torchtext, first thing before the batch iterator is to define how you want to process your text and label.  

In [115]:
text_pipeline  = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1 #turn {1, 2, 3, 4} to {0, 1, 2, 3} for pytorch training 

Next, let's make the batch iterator.  Here we create a function <code>collate_batch</code> (passed to the dataloader as <code>collate_fn</code>) that defines how we want to create our batch.  **We also return the length of each sequence, since packed padded sequences require this.**

In [116]:
from torch.utils.data   import DataLoader
from torch.nn.utils.rnn import pad_sequence

pad_idx = vocab['<pad>'] #++<----making sure our embedding layer ignores pad

def collate_batch(batch):
    """Turn a list of ``(label, text)`` samples into three batch tensors.

    Returns a tuple ``(labels, padded_text, lengths)``:

    * ``labels``      -- int64 tensor of ``label_pipeline`` outputs, one per sample
    * ``padded_text`` -- int64 tensor of ``text_pipeline`` outputs, padded with
      ``pad_idx`` and laid out batch-first
    * ``lengths``     -- int64 tensor with the unpadded length of every sequence
      (needed later for packed padded sequences)
    """
    labels, sequences = [], []
    for raw_label, raw_text in batch:
        labels.append(label_pipeline(raw_label))
        sequences.append(torch.tensor(text_pipeline(raw_text), dtype=torch.int64))

    # the true lengths must be recorded before padding hides them
    lengths = [seq.size(0) for seq in sequences]

    label_tensor = torch.tensor(labels, dtype=torch.int64)
    padded_text = pad_sequence(sequences, padding_value=pad_idx, batch_first=True)
    length_tensor = torch.tensor(lengths, dtype=torch.int64)
    return label_tensor, padded_text, length_tensor

Create train, val, and test dataloaders

In [117]:
batch_size = 64

train_loader = DataLoader(train, batch_size=batch_size,
                              shuffle=True, collate_fn=collate_batch)
valid_loader = DataLoader(valid, batch_size=batch_size,
                              shuffle=False, collate_fn=collate_batch)
test_loader  = DataLoader(test, batch_size=batch_size,
                             shuffle=False, collate_fn=collate_batch)

## 5. Model and Evaluate

Here we will simply evaluate the model that we have saved before.  

In [118]:
import torch.nn as nn

class LSTM(nn.Module):
    """Embedding -> (bi)LSTM -> linear classifier over padded token-id batches.

    Args:
        input_dim: number of rows of the embedding table (vocabulary size).
        emb_dim: embedding dimension.
        hid_dim: LSTM hidden size per direction.
        output_dim: number of output logits.
        num_layers: number of stacked LSTM layers.
        bidirectional: whether the LSTM runs in both directions.
        dropout: dropout passed to ``nn.LSTM``.

    The embedding uses the module-level ``pad_idx`` as its ``padding_idx``.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, output_dim, num_layers, bidirectional, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(emb_dim,
                            hid_dim,
                            num_layers=num_layers,
                            bidirectional=bidirectional,
                            dropout=dropout,
                            batch_first=True)
        # The classifier consumes one final hidden state per direction, so its
        # input width must follow the `bidirectional` flag instead of being
        # hard-coded to two directions.
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hid_dim * num_directions, output_dim)

    def forward(self, text, text_lengths):
        """Return the logits for a padded batch.

        Args:
            text: batch-first tensor of token ids (the LSTM is built with
                ``batch_first=True``).
            text_lengths: unpadded length of every sequence in ``text``.

        Returns:
            The output of ``self.fc`` applied to the final hidden state(s).
        """
        embedded = self.embedding(text)
        # pack_padded_sequence needs the lengths on the CPU
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, text_lengths.to('cpu'), enforce_sorted=False, batch_first=True
        )
        # Only the final hidden state is used, so the per-step outputs are not
        # unpacked. With no initial state given, h0 and c0 default to zeros.
        _, (hn, _) = self.lstm(packed_embedded)

        if self.lstm.bidirectional:
            # last layer: forward direction is hn[-2], backward direction is hn[-1]
            hn = torch.cat((hn[-2, :, :], hn[-1, :, :]), dim=1)
        else:
            # single direction: the last entry is the top layer's final state
            hn = hn[-1, :, :]

        return self.fc(hn)

In [119]:
input_dim  = len(vocab)
hid_dim    = 256
emb_dim    = 300         #**<----change to 300
output_dim = 4 #four classes

#for biLSTM
num_layers = 2
bidirectional = True
dropout = 0.5

model = LSTM(input_dim, emb_dim, hid_dim, output_dim, num_layers, bidirectional, dropout).to(device)
model.embedding.weight.data = fast_embedding #**<------applied the fast text embedding as the initial weights

In [120]:
criterion = nn.CrossEntropyLoss() #combine softmax with cross entropy

In [121]:
def accuracy(preds, y):
    """Return the fraction of rows of ``preds`` whose arg-max equals ``y``.

    ``preds`` holds one row of scores per sample and ``y`` the matching class
    indices. The result is a zero-dimensional tensor.
    """
    predicted_classes = preds.detach().argmax(dim=1)
    n_correct = (predicted_classes == y).sum()
    return n_correct / len(y)

In [122]:
def evaluate(model, loader, criterion, loader_length=None):
    """Run ``model`` over ``loader`` without gradients and average the metrics.

    Args:
        model: module called as ``model(text, text_length)``.
        loader: iterable yielding ``(label, text, text_length)`` batches.
        criterion: loss called as ``criterion(predictions, label)``.
        loader_length: number to divide the summed per-batch loss and accuracy
            by. When omitted (``None``), the number of batches actually
            iterated is used, so the loader does not have to be counted in
            advance.

    Returns:
        ``(mean_loss, mean_accuracy)``: the per-batch loss and accuracy summed
        over the loader and divided by ``loader_length``.
    """
    epoch_loss = 0
    epoch_acc = 0
    batches_seen = 0
    model.eval()

    with torch.no_grad():
        for label, text, text_length in loader:
            label = label.to(device)  # (batch_size,)
            # the loaders in this file pad with batch_first=True,
            # so text is (batch_size, seq_len)
            text = text.to(device)

            predictions = model(text, text_length).squeeze(1)

            loss = criterion(predictions, label)
            acc = accuracy(predictions, label)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
            batches_seen += 1

    if loader_length is None:
        loader_length = batches_seen

    return epoch_loss / loader_length, epoch_acc / loader_length

In [123]:
# Count the batches by streaming through the loader once; collecting every
# collated batch into a list just to take its length keeps them all in memory.
test_loader_length = sum(1 for _ in test_loader)

In [124]:
save_path = f'models/{model.__class__.__name__}.pt'
model.load_state_dict(torch.load(save_path, map_location=torch.device('cpu')))
test_loss, test_acc = evaluate(model, test_loader, criterion, test_loader_length)

print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.251 | Test Acc: 91.54%


## 6. Pruning

Let's try pruning and see the effect on the accuracy.  We will use `torch.nn.utils.prune` to prune our neural network, and also learn how to extend it to implement our own custom pruning technique.

In [125]:
import torch.nn.utils.prune as prune

Let's inspect a layer

In [126]:
fc = model.fc

In [127]:
fc #512 * 4 = 2048 parameters

Linear(in_features=512, out_features=4, bias=True)

Let's see the weights.  Notice `weight` and `bias`.

In [128]:
print(list(fc.named_parameters()))

[('weight', Parameter containing:
tensor([[-0.1492, -0.0344,  0.0106,  ..., -0.0669, -0.0046, -0.0156],
        [ 0.0277,  0.0525, -0.0747,  ...,  0.0243,  0.1342, -0.1535],
        [ 0.0019,  0.1591, -0.0453,  ..., -0.0197,  0.0368, -0.0796],
        [-0.0286, -0.1525, -0.0484,  ..., -0.0525,  0.0902,  0.0980]],
       requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0069, -0.0114,  0.0139, -0.0159], requires_grad=True))]


In [129]:
# Read the shape straight from the tensor: converting to NumPy first is
# unnecessary and fails for parameters that live on a CUDA device.
for param_name, param in fc.named_parameters():
    print(param_name, ':', tuple(param.shape))

weight : (4, 512)
bias : (4,)


Let's also see the `buffers`.  A buffer is simply a tensor that is stored with the module but is not involved in gradient updates.  Buffers are useful for storing extra state alongside the parameters (much like a hidden field in HTML).

In [130]:
#nothing yet
print(list(fc.named_buffers()))

[]


To prune a module, there are many pruning techniques available in 
`torch.nn.utils.prune` (or [implement](#extending-torch-nn-utils-pruning-with-custom-pruning-functions)
your own by subclassing `BasePruningMethod`). 

### 6.1 Random Pruning

In this example, we will prune at random 95% of the connections in 
the parameter named `weight` in the `fc` layer.

In [131]:
prune.random_unstructured(fc, name="weight", amount=0.95)

Linear(in_features=512, out_features=4, bias=True)

Now let's check the `fc` layer.  Notice the weight is now replaced with `weight_orig`, which is basically the original one.

In [132]:
print(list(fc.named_parameters()))

[('bias', Parameter containing:
tensor([ 0.0069, -0.0114,  0.0139, -0.0159], requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[-0.1492, -0.0344,  0.0106,  ..., -0.0669, -0.0046, -0.0156],
        [ 0.0277,  0.0525, -0.0747,  ...,  0.0243,  0.1342, -0.1535],
        [ 0.0019,  0.1591, -0.0453,  ..., -0.0197,  0.0368, -0.0796],
        [-0.0286, -0.1525, -0.0484,  ..., -0.0525,  0.0902,  0.0980]],
       requires_grad=True))]


Now you may wonder how the pruning happens.  During the forward pass, PyTorch applies the `weight_mask` stored in `named_buffers` to `weight_orig`.   Notice how `named_buffers` now has a `weight_mask` made of 0s and 1s; this mask is multiplied with `weight_orig`, which zeroes out the pruned weights.

In [133]:
print(list(fc.named_buffers()))

[('weight_mask', tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.]]))]


The pruning techniques implemented in `torch.nn.utils.prune` compute the pruned version of the weight (by 
combining the mask with the original parameter) and store them in the attribute `weight`. Note, this is no longer a parameter of the `fc`, it is now simply an attribute.

In [134]:
print(fc.weight)

tensor([[-0.0000, -0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000, -0.0000,  ...,  0.0000,  0.0000, -0.0000],
        [ 0.0000,  0.0000, -0.0000,  ..., -0.0000,  0.0000, -0.0000],
        [-0.0000, -0.0000, -0.0484,  ..., -0.0000,  0.0000,  0.0000]],
       grad_fn=<MulBackward0>)


Last thing, you may wonder how you can be sure PyTorch really prunes the network during the forward pass.   Pruning is applied prior to each forward pass using PyTorch's `forward_pre_hooks`. Specifically, when the `fc` layer is pruned, it acquires a `forward_pre_hook` for each of its parameters that gets pruned. In this case, since we have so far only pruned the parameter named `weight`, only one hook will be present. **Please note that you don't have to do anything.  I am just explaining the internal mechanism PyTorch uses**.

In [135]:
print(fc._forward_pre_hooks)

OrderedDict([(11, <torch.nn.utils.prune.RandomUnstructured object at 0x176bab290>)])


Lastly, let's try classification using our randomly pruned version.

In [136]:
test_loss, test_acc = evaluate(model, test_loader, criterion, test_loader_length)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

Test Loss: 1.164 | Test Acc: 87.39%


Wow, pruning 95% of the `fc` weights at random still gives about 87% accuracy!  In case you want to make your model permanently pruned, just use the command `remove`.

In [137]:
#make it permanent....
prune.remove(fc, 'weight')
print(list(fc.named_parameters()))  #notice now the weight is the pruned version and weight_orig is gone.

[('bias', Parameter containing:
tensor([ 0.0069, -0.0114,  0.0139, -0.0159], requires_grad=True)), ('weight', Parameter containing:
tensor([[-0.0000, -0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000, -0.0000,  ...,  0.0000,  0.0000, -0.0000],
        [ 0.0000,  0.0000, -0.0000,  ..., -0.0000,  0.0000, -0.0000],
        [-0.0000, -0.0000, -0.0484,  ..., -0.0000,  0.0000,  0.0000]],
       requires_grad=True))]


### 6.2 Magnitude Pruning

We can also prune less randomly, based on the lowest magnitude, on the assumption that parameters with a small magnitude are less important.

In [138]:
#reset the pruning
def reset():
    """Return a freshly built, unpruned LSTM restored from ``save_path``.

    The network is rebuilt from the module-level hyperparameters and moved to
    ``device``. The checkpoint is read onto the CPU and copied into it.
    """
    fresh_model = LSTM(
        input_dim, emb_dim, hid_dim, output_dim, num_layers, bidirectional, dropout
    ).to(device)
    checkpoint = torch.load(save_path, map_location=torch.device('cpu'))
    fresh_model.load_state_dict(checkpoint)
    return fresh_model
    
model = reset()
fc = model.fc

In [139]:
#prune based on the lowest magnitude based on l1 norm
prune.l1_unstructured(fc, name='weight', amount=0.95)
print(fc.weight)

tensor([[-0.1492, -0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000, -0.0000,  ...,  0.0000,  0.0000, -0.1535],
        [ 0.0000,  0.1591, -0.0000,  ..., -0.0000,  0.0000, -0.0000],
        [-0.0000, -0.1525, -0.0000,  ..., -0.0000,  0.0000,  0.0000]],
       grad_fn=<MulBackward0>)


In case you want to prune based on L2 norm on specific dim, use `ln_structured`, where structured means on particular dimension, while `n` refers to the type of norm.  I have commented in case you wanna try.

In [None]:
# prune.ln_structured(fc, name="weight", amount=0.5, n=2, dim=0)
# print(fc.weight)

In [140]:
test_loss, test_acc = evaluate(model, test_loader, criterion, test_loader_length)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.542 | Test Acc: 91.22%


### 6.3 Pruning multiple parameters in a model 

By specifying the desired pruning technique and parameters, we can easily prune multiple tensors in a network.

In [141]:
model = reset()

for name, module in model.named_modules():
    # prune 99% of connections in all embedding layers 
    if isinstance(module, torch.nn.Embedding):
        prune.l1_unstructured(module, name='weight', amount=0.99)  #forward weight (you can check the name in named_parameters)
    # prune 50% of connections in all linear layers 
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.5)

print(dict(model.named_buffers()).keys())  # to verify that all masks exist

dict_keys(['embedding.weight_mask', 'fc.weight_mask'])


In [142]:
test_loss, test_acc = evaluate(model, test_loader, criterion, test_loader_length)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.820 | Test Acc: 69.60%


### 6.4 Global pruning

So far, we only looked at what is usually referred to as "local" pruning, i.e. the practice of pruning tensors in a model one by one, by  comparing the statistics (weight magnitude, activation, gradient, etc.) of  each entry exclusively to the other entries in that tensor. However, a  common and perhaps more powerful technique is to prune the model all at  once, by removing (for example) the lowest 20% of connections across the  whole model, instead of removing the lowest 20% of connections in each  layer. This is likely to result in different pruning percentages per layer.

Let's see how to do that using `global_unstructured` from `torch.nn.utils.prune`.

In [143]:
model = reset()

parameters_to_prune = (
    (model.embedding, 'weight'),
    (model.lstm, 'weight_ih_l0'),
    (model.lstm, 'weight_hh_l0'),
    (model.fc, 'weight'),
    (model.fc, 'bias'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.7,
)

In [144]:
# For every tensor handed to global pruning, report the share of entries that
# are exactly zero, then the share over all of those tensors taken together.
sparsity_report = (
    ("Sparsity in embedding.weight: {:.2f}%", model.embedding.weight),
    ("Sparsity in lstm.weight_ih_l0: {:.2f}%", model.lstm.weight_ih_l0),
    ("Sparsity in lstm.weight_hh_l0: {:.2f}%", model.lstm.weight_hh_l0),
    ("Sparsity in fc.weight: {:.2f}%", model.fc.weight),
    ("Sparsity in fc.bias: {:.2f}%", model.fc.bias),
)

zero_total = 0
entry_total = 0
for line_format, pruned_tensor in sparsity_report:
    zero_count = int(torch.sum(pruned_tensor == 0))
    entry_count = pruned_tensor.nelement()
    print(line_format.format(100. * float(zero_count) / float(entry_count)))
    zero_total += zero_count
    entry_total += entry_count

print("Global sparsity: {:.2f}%".format(100. * float(zero_total) / float(entry_total)))

Sparsity in embedding.weight: 69.48%
Sparsity in lstm.weight_ih_l0: 78.33%
Sparsity in lstm.weight_hh_l0: 91.77%
Sparsity in fc.weight: 65.43%
Sparsity in fc.bias: 100.00%
Global sparsity: 70.00%


In [145]:
test_loss, test_acc = evaluate(model, test_loader, criterion, test_loader_length)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.313 | Test Acc: 90.30%


Wow, pruning 70% of parameters...but still getting 90% of accuracy!

### 6.5 Custom pruning function

To implement your own pruning function, you can extend the `nn.utils.prune` module by subclassing the `BasePruningMethod`
base class, the same way all other pruning methods do. The base class implements the following methods for you: `__call__`, `apply_mask`,`apply`, `prune`, and `remove`. Beyond some special cases, you shouldn't have to reimplement these methods for your new pruning technique.

You will, however, have to implement `__init__` (the constructor), and `compute_mask` (the instructions on how to compute the mask
for the given tensor according to the logic of your pruning technique). In addition, you will have to specify which type of
pruning this technique implements (supported options are `global`, `structured`, and `unstructured`). This is needed to determine how to combine masks in the case in which pruning is applied iteratively. In other words, when pruning a pre-pruned parameter, the current pruning technique is expected to act on the unpruned portion of the parameter. Specifying the `PRUNING_TYPE` will enable the `PruningContainer` (which handles the iterative application of pruning masks) to correctly identify the slice of the parameter to prune.

Let's assume, for example, that you want to implement a pruning technique that prunes every other entry in a tensor (or -- if the
tensor has previously been pruned -- in the remaining unpruned portion of the tensor). This will be of `PRUNING_TYPE='unstructured'`
because it acts on individual connections in a layer and not on entire units/channels (`'structured'`), or across different parameters
(`'global'`).

In [146]:
class ExamplePruningMethod(prune.BasePruningMethod):
    """Pruning method that masks out every other entry of a tensor."""

    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        """Return a copy of ``default_mask`` with every even flat index zeroed.

        ``t`` (the tensor being pruned) is not consulted, and ``default_mask``
        itself is left untouched.
        """
        new_mask = default_mask.clone()
        flat_view = new_mask.view(-1)  # shares storage with new_mask
        flat_view[::2] = 0
        return new_mask

Now, to apply this to a parameter in an ``nn.Module``, you should also provide a simple function that instantiates the method and
applies it.

In [147]:
def prune_custom_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensor.
    Modifies module in place (and also returns the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    2) replacing the parameter `name` by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> prune_custom_unstructured(m, name='bias')
    """
    ExamplePruningMethod.apply(module, name)
    return module

In [148]:
model = reset()
prune_custom_unstructured(model.fc, name='bias')

print(model.fc.bias_mask)

tensor([0., 1., 0., 1.])


## Conclusion

In practice, we may want to perform *iterative pruning*, where we repeatedly prune the lowest-magnitude weights, perhaps after each fine-tuning epoch, until the target sparsity is reached.   You will be amazed how much can be pruned without much impact on performance.