# Prune Transformers

> How to prune a transformer with fasterai

In [None]:
#all_slow

> Note: This example code is taken from the fastai [docs](https://docs.fast.ai/tutorial.transformers.html)

In [None]:
#hide
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from transformers.modeling_utils import Conv1D
from fastai.text.all import *
import fastcore
from fasterai.sparse.all import *

In [None]:
pretrained_weights = 'gpt2'
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)
model = GPT2LMHeadModel.from_pretrained(pretrained_weights)

In [None]:
path = untar_data(URLs.WIKITEXT_TINY)

In [None]:
#hide
df_train = pd.read_csv(path/'train.csv', header=None)
df_valid = pd.read_csv(path/'test.csv', header=None)

In [None]:
#hide
all_texts = np.concatenate([df_train[0].values, df_valid[0].values])

In [None]:
#hide
class TransformersTokenizer(Transform):
    "Wrap `tokenizer` as a fastai `Transform`: text -> tensor of token ids, and back to text"
    def __init__(self, tokenizer): self.tokenizer = tokenizer

    def encodes(self, x):
        # Two steps: split the text into tokens, then map the tokens to their ids.
        ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(x))
        return tensor(ids)

    def decodes(self, x):
        # The ids are moved to the CPU and converted to numpy before decoding.
        text = self.tokenizer.decode(x.cpu().numpy())
        return TitledStr(text)

In [None]:
#hide
# `all_texts` holds the training texts followed by the validation texts, so the
# first split addresses the first len(df_train) items and the second split the
# remaining ones (assumes `range_of(df_train)` yields 0..len(df_train)-1 — TODO confirm).
splits = [range_of(df_train), list(range(len(df_train), len(all_texts)))]
# Apply the tokenizer transform to the raw texts; `dl_type` is set to `LMDataLoader`.
tls = TfmdLists(all_texts, TransformersTokenizer(tokenizer), splits=splits, dl_type=LMDataLoader)

Token indices sequence length is longer than the specified maximum sequence length for this model (4576 > 1024). Running this sequence through the model will result in indexing errors


In [None]:
#hide
# `bs` is passed as the batch size and `sl` as the sequence length below.
bs,sl = 4,256
dls = tls.dataloaders(bs=bs, seq_len=sl)

In [None]:
#hide
def tokenize(text):
    "Convert `text` to a tensor of token ids with the notebook-level `tokenizer`"
    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    return tensor(ids)

tokenized = [tokenize(t) for t in progress_bar(all_texts)]

In [None]:
#hide
class TransformersTokenizer(Transform):
    "Variant of the tokenizer `Transform` that accepts items which are already tensors"
    def __init__(self, tokenizer): self.tokenizer = tokenizer

    def encodes(self, x):
        # Tensors are returned untouched; anything else goes through the
        # notebook-level `tokenize` function.
        if isinstance(x, Tensor): return x
        return tokenize(x)

    def decodes(self, x):
        # The ids are moved to the CPU and converted to numpy before decoding.
        text = self.tokenizer.decode(x.cpu().numpy())
        return TitledStr(text)

In [None]:
#hide
tls = TfmdLists(tokenized, TransformersTokenizer(tokenizer), splits=splits, dl_type=LMDataLoader)
dls = tls.dataloaders(bs=bs, seq_len=sl)

In [None]:
#hide
class DropOutput(Callback):
    "Replace the learner's prediction with its first element"
    def after_pred(self):
        # Only element 0 of the prediction is kept.
        self.learn.pred = self.pred[0]

Let's create our fastai `Learner`.

In [None]:
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), cbs=[DropOutput], metrics=Perplexity())

And let's try to extend a given prompt with the pretrained model.

In [None]:
prompt = "\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn"

In [None]:
#hide
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]

In [None]:
preds = learn.model.generate(inp, max_length=40, num_beams=5, temperature=1.5)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [None]:
tokenizer.decode(preds[0].cpu().numpy())

'\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn on its head.\n\nA unicorn is a magical creature with a rainbow tail and a horn'

In [None]:
learn.validate()

(#2) [3.695716381072998,40.2744140625]

In [None]:
learn.fit_one_cycle(1, 1e-4)

epoch,train_loss,valid_loss,perplexity,time
0,3.139103,2.843017,17.167484,07:58


In [None]:
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]

preds = learn.model.generate(inp.cuda(), max_length=40, num_beams=5, temperature=1.5)

tokenizer.decode(preds[0].cpu().numpy())

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


'\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn on its head. It is a member of the <unk> <unk> <unk>'

## Make it sparse !

Let's now retrain our model, this time introducing sparsity.

In [None]:
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), cbs=[DropOutput], metrics=Perplexity())

Also, when working with text, fastai defines the number of processed batches differently, so we have to adjust our `SparsifyCallback` accordingly (luckily, fastai makes it available as the `n_batches` attribute).

In [None]:
@patch_to(SparsifyCallback)
def before_fit(self):
    "Create the `Sparsifier` and compute the schedule bounds from `dls.n_batches`"
    print(f'Pruning of {self.granularity} until a sparsity of {self.end_sparsity}%')
    # Without an explicit `end_epoch`, the schedule spans all `n_epoch` epochs.
    if self.end_epoch is None: self.end_epoch = self.n_epoch
    assert self.end_epoch <= self.n_epoch, 'Your end_epoch must be smaller than total number of epoch'

    # A model given to the callback takes precedence over the learner's model,
    # which allows pruning only a part of the whole model.
    target = self.model if self.model is not None else self.learn.model
    self.sparsifier = Sparsifier(target, self.granularity, self.method, self.criteria, self.layer_type)
    # Both bounds are expressed in iterations: epochs times `n_batches`.
    n_batches = self.dls.n_batches
    self.total_iters = self.end_epoch * n_batches
    self.start_iter = self.start_epoch * n_batches

Let's define our `SparsifyCallback`. Let's say we want to make our model 30% sparse, by removing the highest-norm weights in each `Conv1D` layer.

In [None]:
sp_cb = SparsifyCallback(end_sparsity=30, granularity='weight', method='local', criteria=large_final, sched_func=sched_onecycle, layer_type=Conv1D)

We now only have to pass our callback to fastai

In [None]:
learn.fit_one_cycle(1, 1e-4, cbs=sp_cb)

Pruning of weight until a sparsity of [30]%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,perplexity,time
0,3.004998,2.860594,17.471899,12:16


Sparsity at the end of epoch 0: [30.0]%
Final Sparsity: [30.0]%
Sparsity in Conv1D 9: 30.00%
Sparsity in Conv1D 10: 30.00%
Sparsity in Conv1D 15: 30.00%
Sparsity in Conv1D 16: 30.00%
Sparsity in Conv1D 21: 30.00%
Sparsity in Conv1D 22: 30.00%
Sparsity in Conv1D 27: 30.00%
Sparsity in Conv1D 28: 30.00%
Sparsity in Conv1D 33: 30.00%
Sparsity in Conv1D 34: 30.00%
Sparsity in Conv1D 39: 30.00%
Sparsity in Conv1D 40: 30.00%
Sparsity in Conv1D 45: 30.00%
Sparsity in Conv1D 46: 30.00%
Sparsity in Conv1D 51: 30.00%
Sparsity in Conv1D 52: 30.00%
Sparsity in Conv1D 57: 30.00%
Sparsity in Conv1D 58: 30.00%
Sparsity in Conv1D 63: 30.00%
Sparsity in Conv1D 64: 30.00%
Sparsity in Conv1D 69: 30.00%
Sparsity in Conv1D 70: 30.00%
Sparsity in Conv1D 75: 30.00%
Sparsity in Conv1D 76: 30.00%
Sparsity in Conv1D 81: 30.00%
Sparsity in Conv1D 82: 30.00%
Sparsity in Conv1D 87: 30.00%
Sparsity in Conv1D 88: 30.00%
Sparsity in Conv1D 93: 30.00%
Sparsity in Conv1D 94: 30.00%
Sparsity in Conv1D 99: 30.00%
Sparsit

And we can check the prediction for the same prompt as before

In [None]:
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]

preds = learn.model.generate(inp.cuda(), max_length=40, num_beams=5, temperature=1.5)

tokenizer.decode(preds[0].cpu().numpy())

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


'\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn @-@ like head. The unicorn is a member of the <unk> <unk>'

In [None]:
#hide
def print_sparsity(model):
    "Print, for each `Conv1D` module in `model`, the percentage of its weights equal to zero"
    # Iterate over the `model` argument rather than the global `learn.model`, so
    # the function reports on the model it was actually given.
    for k,m in enumerate(model.modules()):
        # Modules are matched by class name; `k` is the module's position in `model.modules()`.
        if m.__class__.__name__ != 'Conv1D': continue
        n_zeros = (m.weight == 0).sum().item()
        print(f"Sparsity in {m.__class__.__name__} {k}: {100. * n_zeros / m.weight.nelement():.2f}%")

In [None]:
print_sparsity(learn.model)

Sparsity in Conv1D 9: 30.00%
Sparsity in Conv1D 10: 30.00%
Sparsity in Conv1D 15: 30.00%
Sparsity in Conv1D 16: 30.00%
Sparsity in Conv1D 21: 30.00%
Sparsity in Conv1D 22: 30.00%
Sparsity in Conv1D 27: 30.00%
Sparsity in Conv1D 28: 30.00%
Sparsity in Conv1D 33: 30.00%
Sparsity in Conv1D 34: 30.00%
Sparsity in Conv1D 39: 30.00%
Sparsity in Conv1D 40: 30.00%
Sparsity in Conv1D 45: 30.00%
Sparsity in Conv1D 46: 30.00%
Sparsity in Conv1D 51: 30.00%
Sparsity in Conv1D 52: 30.00%
Sparsity in Conv1D 57: 30.00%
Sparsity in Conv1D 58: 30.00%
Sparsity in Conv1D 63: 30.00%
Sparsity in Conv1D 64: 30.00%
Sparsity in Conv1D 69: 30.00%
Sparsity in Conv1D 70: 30.00%
Sparsity in Conv1D 75: 30.00%
Sparsity in Conv1D 76: 30.00%
Sparsity in Conv1D 81: 30.00%
Sparsity in Conv1D 82: 30.00%
Sparsity in Conv1D 87: 30.00%
Sparsity in Conv1D 88: 30.00%
Sparsity in Conv1D 93: 30.00%
Sparsity in Conv1D 94: 30.00%
Sparsity in Conv1D 99: 30.00%
Sparsity in Conv1D 100: 30.00%
Sparsity in Conv1D 105: 30.00%
Sparsity 

That's it! You now have a sparse Transformer that is as performant as the whole model. However, this model is currently not more efficient in terms of speed or storage. To get such a speed-up, I suggest you look at the [granularity](https://nathanhubens.github.io/fasterai/granularity.html) section.