# Prune Transformers

> How to prune a transformer with fasterai

In [None]:
#all_slow

> Note: This example code is taken from the fastai [docs](https://docs.fast.ai/tutorial.transformers.html)

In [None]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import fastcore

In [None]:
# Checkpoint name shared by the two `from_pretrained` calls below.
pretrained_weights = 'gpt2'
# Load the tokenizer and the language-model-head model from that same checkpoint name.
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_weights)
model = GPT2LMHeadModel.from_pretrained(pretrained_weights)

In [None]:
from fastai.text.all import *

In [None]:
# Resolve the WIKITEXT_TINY dataset folder with fastai's `untar_data`, then list its contents.
path = untar_data(URLs.WIKITEXT_TINY)
path.ls()

(#2) [Path('/home/HubensN/.fastai/data/wikitext-2/train.csv'),Path('/home/HubensN/.fastai/data/wikitext-2/test.csv')]

In [None]:
# Read both CSV files with `header=None`, so the text column is labelled 0.
df_train = pd.read_csv(path/'train.csv', header=None)
df_valid = pd.read_csv(path/'test.csv', header=None)
# Preview the first training rows.
df_train.head()

Unnamed: 0,0
0,"\n = 2013 – 14 York City F.C. season = \n \n The 2013 – 14 season was the <unk> season of competitive association football and 77th season in the Football League played by York City Football Club , a professional football club based in York , North Yorkshire , England . Their 17th @-@ place finish in 2012 – 13 meant it was their second consecutive season in League Two . The season ran from 1 July 2013 to 30 June 2014 . \n Nigel Worthington , starting his first full season as York manager , made eight permanent summer signings . By the turn of the year York were only above the relegation z..."
1,"\n = Big Boy ( song ) = \n \n "" Big Boy "" <unk> "" I 'm A Big Boy Now "" was the first single ever recorded by the Jackson 5 , which was released by Steeltown Records in January 1968 . The group played instruments on many of their Steeltown compositions , including "" Big Boy "" . The song was neither a critical nor commercial success , but the Jackson family were delighted with the outcome nonetheless . \n The Jackson 5 would release a second single with Steeltown Records before moving to Motown Records . The group 's recordings at Steeltown Records were thought to be lost , but they were re..."
2,"\n = The Remix ( Lady Gaga album ) = \n \n The Remix is a remix album by American recording artist Lady Gaga . Released in Japan on March 3 , 2010 , it contains remixes of the songs from her first studio album , The Fame ( 2008 ) , and her third extended play , The Fame Monster ( 2009 ) . A revised version of the track list was prepared for release in additional markets , beginning with Mexico on May 3 , 2010 . A number of recording artists have produced the songs , including Pet Shop Boys , Passion Pit and The Sound of Arrows . The remixed versions feature both uptempo and <unk> composit..."
3,"\n = New Year 's Eve ( Up All Night ) = \n \n "" New Year 's Eve "" is the twelfth episode of the first season of the American comedy television series Up All Night . The episode originally aired on NBC in the United States on January 12 , 2012 . It was written by Erica <unk> and was directed by Beth McCarthy @-@ Miller . The episode also featured a guest appearance from Jason Lee as Chris and Reagan 's neighbor and Ava 's boyfriend , Kevin . \n During Reagan ( Christina Applegate ) and Chris 's ( Will <unk> ) first New Year 's Eve game night , Reagan 's competitiveness comes out causing Ch..."
4,"\n = Geopyxis carbonaria = \n \n Geopyxis carbonaria is a species of fungus in the genus Geopyxis , family <unk> . First described to science in 1805 , and given its current name in 1889 , the species is commonly known as the charcoal loving elf @-@ cup , dwarf <unk> cup , <unk> <unk> cup , or pixie cup . The small , <unk> @-@ shaped fruitbodies of the fungus are reddish @-@ brown with a whitish fringe and measure up to 2 cm ( 0 @.@ 8 in ) across . They have a short , tapered stalk . Fruitbodies are commonly found on soil where brush has recently been burned , sometimes in great numbers ...."


In [None]:
# Join column 0 of the train and test frames into one array; the training texts come first.
all_texts = np.concatenate([df_train[0].values, df_valid[0].values])

In [None]:
class TransformersTokenizer(Transform):
    "Transform that turns raw text into a tensor of token ids and decodes ids back to text."
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encodes(self, x):
        # Text -> tokens -> vocabulary ids, wrapped in a tensor.
        token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(x))
        return tensor(token_ids)

    def decodes(self, x):
        # Ids are moved to CPU and converted to numpy before being handed to the tokenizer.
        text = self.tokenizer.decode(x.cpu().numpy())
        return TitledStr(text)

In [None]:
# Index-based splits over `all_texts`: the first len(df_train) items are training, the remainder validation.
splits = [range_of(df_train), list(range(len(df_train), len(all_texts)))]
# Build the transformed lists with the tokenizer transform; `dl_type=LMDataLoader` is passed for the dataloaders.
tls = TfmdLists(all_texts, TransformersTokenizer(tokenizer), splits=splits, dl_type=LMDataLoader)

Token indices sequence length is longer than the specified maximum sequence length for this model (4576 > 1024). Running this sequence through the model will result in indexing errors


In [None]:
# Batch size and sequence length passed to the dataloaders.
bs,sl = 4,256
dls = tls.dataloaders(bs=bs, seq_len=sl)

In [None]:
def tokenize(text):
    "Convert `text` into a tensor of token ids using the notebook-level `tokenizer`."
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    return tensor(token_ids)

# Tokenize every text once up front, showing a progress bar while iterating.
tokenized = [tokenize(t) for t in progress_bar(all_texts)]

In [None]:
class TransformersTokenizer(Transform):
    "Transform that passes pre-tokenized tensors through and tokenizes raw text with its own tokenizer."
    def __init__(self, tokenizer): self.tokenizer = tokenizer
    def encodes(self, x):
        # Items that are already tensors of ids are returned untouched.
        if isinstance(x, Tensor): return x
        # Encode with the tokenizer held by this instance (not the module-level `tokenize`
        # helper), so `encodes` and `decodes` always go through the same tokenizer object.
        toks = self.tokenizer.tokenize(x)
        return tensor(self.tokenizer.convert_tokens_to_ids(toks))

    def decodes(self, x): return TitledStr(self.tokenizer.decode(x.cpu().numpy()))

In [None]:
# Rebuild the transformed lists from the pre-tokenized tensors (same splits), then the dataloaders.
tls = TfmdLists(tokenized, TransformersTokenizer(tokenizer), splits=splits, dl_type=LMDataLoader)
dls = tls.dataloaders(bs=bs, seq_len=sl)

In [None]:
class DropOutput(Callback):
    "Callback that keeps only the first element of the model output as the prediction."
    def after_pred(self):
        # Overwrite the learner's prediction with its element at index 0.
        first_output = self.pred[0]
        self.learn.pred = first_output

Let's create our fastai `Learner`.

In [None]:
# `DropOutput` is attached so `learn.pred` is reduced to `pred[0]` after each prediction;
# the loss is flattened cross-entropy and the reported metric is perplexity.
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), cbs=[DropOutput], metrics=Perplexity())

And let's try to extend a given prompt with the pretrained model.

In [None]:
# Prompt laid out like the dataset articles shown earlier: a title between `=` signs, then the text.
prompt = "\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn"

In [None]:
# Encode the prompt to token ids, then index with `[None]` to add a leading batch dimension.
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]

In [None]:
# Send the prompt to the device the model is on before generating, so this call does not
# depend on where the model happens to live at this point in the notebook.
preds = learn.model.generate(inp.to(learn.model.device), max_length=40, num_beams=5, temperature=1.5)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [None]:
# Decode the first returned sequence back into text.
tokenizer.decode(preds[0].cpu().numpy())

'\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn on its head.\n\nA unicorn is a magical creature with a rainbow tail and a horn'

In [None]:
from fasterai.sparse.all import *

As we only want to prune the `Conv1D` layers in the model, we need to change the `prune` and `mask_grad` methods of the `Sparsifier` class accordingly.

In [None]:
@patch_to(Sparsifier)
def prune(self, sparsity):
    "Compute and apply a pruning mask on every `Conv1D` module of the model, skipping all other layers."
    for module in self.model.modules():
        # Layers are selected by class name; anything that is not a `Conv1D` is left alone.
        if module.__class__.__name__ != 'Conv1D': continue
        scores = self.criteria(module, self.granularity)
        pruning_mask = self._compute_mask(self.model, scores, sparsity)
        module.register_buffer("_mask", pruning_mask)  # Keep the mask on the module as a buffer
        self._apply(module)
            
@patch_to(Sparsifier)
def mask_grad(self):
    "Multiply the weight gradient of every `Conv1D` module by its stored `_mask`, in place."
    for module in self.model.modules():
        if module.__class__.__name__ != 'Conv1D': continue
        mask = module._mask
        grad = module.weight.grad
        if grad is None: continue  # No gradient on this layer (e.g. it is frozen): nothing to mask
        grad.mul_(mask)

Also, when working with text, fastai defines the number of processed batches differently, so we have to adjust our `SparsifyCallback` accordingly (luckily, fastai makes it available as the `n_batches` attribute).

In [None]:
@patch_to(SparsifyCallback)
def before_fit(self):
    "Announce the pruning target, build the `Sparsifier`, and derive the iteration counts from `dls.n_batches`."
    print(f'Pruning of {self.granularity} until a sparsity of {self.end_sparsity}%')
    self.sparsifier = Sparsifier(self.learn.model, self.granularity, self.method, self.criteria)

    # The per-epoch batch count is read from the `n_batches` attribute of the dataloaders.
    batches_per_epoch = self.dls.n_batches
    self.start_iter = self.start_epoch * batches_per_epoch
    self.total_iters = self.n_epoch * batches_per_epoch

In [None]:
# Baseline: evaluate the pretrained model on the validation split before any fine-tuning.
learn.validate()

(#2) [3.695716381072998,40.2744140625]

In [None]:
# Fine-tune for one epoch with 1e-4 as the learning-rate argument; no pruning callback is attached here.
learn.fit_one_cycle(1, 1e-4)

epoch,train_loss,valid_loss,perplexity,time
0,3.105212,2.847042,17.236725,07:44


In [None]:
# Encode the prompt and add a leading batch dimension.
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]

# Move the input to whatever device the model is on instead of assuming CUDA,
# so this cell also runs on CPU-only machines.
preds = learn.model.generate(inp.to(learn.model.device), max_length=40, num_beams=5, temperature=1.5)

tokenizer.decode(preds[0].cpu().numpy())

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


'\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn @-@ shaped head. It is the most common type of unicorn in the United States.'

In [None]:
# Pruning callback configured for a final sparsity of 30, 'weight' granularity and the 'local' method,
# with `large_final` as the criteria and `sched_agp` as the schedule function.
sp_cb = SparsifyCallback(end_sparsity=30, granularity='weight', method='local', criteria=large_final, sched_func=sched_agp)

# Same one-epoch call as the earlier fine-tuning run, now with the pruning callback attached.
learn.fit_one_cycle(1, 1e-4, cbs=sp_cb)

Pruning of weight until a sparsity of 30%


epoch,train_loss,valid_loss,perplexity,time
0,3.276928,2.875845,17.740412,17:07


Saving Weights at epoch 0
Sparsity at the end of epoch 0: 30.00%
Final Sparsity: 30.00


In [None]:
# Encode the prompt and add a leading batch dimension.
prompt_ids = tokenizer.encode(prompt)
inp = tensor(prompt_ids)[None]

# Move the input to whatever device the pruned model is on instead of assuming CUDA,
# so this cell also runs on CPU-only machines.
preds = learn.model.generate(inp.to(learn.model.device), max_length=40, num_beams=5, temperature=1.5)

tokenizer.decode(preds[0].cpu().numpy())

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


'\n = Unicorn = \n \n A unicorn is a magical creature with a rainbow tail and a horn on its head. The unicorn is a member of the <unk> <unk> family,'

In [None]:
# Report, for each `Conv1D` module, the percentage of weights that are exactly zero.
# `k` is the module's position in `learn.model.modules()`.
conv1d_layers = ((k, m) for k, m in enumerate(learn.model.modules()) if m.__class__.__name__ == 'Conv1D')
for k, m in conv1d_layers:
    print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")

Sparsity in Conv1D 9: 30.00%
Sparsity in Conv1D 10: 30.00%
Sparsity in Conv1D 15: 30.00%
Sparsity in Conv1D 16: 30.00%
Sparsity in Conv1D 21: 30.00%
Sparsity in Conv1D 22: 30.00%
Sparsity in Conv1D 27: 30.00%
Sparsity in Conv1D 28: 30.00%
Sparsity in Conv1D 33: 30.00%
Sparsity in Conv1D 34: 30.00%
Sparsity in Conv1D 39: 30.00%
Sparsity in Conv1D 40: 30.00%
Sparsity in Conv1D 45: 30.00%
Sparsity in Conv1D 46: 30.00%
Sparsity in Conv1D 51: 30.00%
Sparsity in Conv1D 52: 30.00%
Sparsity in Conv1D 57: 30.00%
Sparsity in Conv1D 58: 30.00%
Sparsity in Conv1D 63: 30.00%
Sparsity in Conv1D 64: 30.00%
Sparsity in Conv1D 69: 30.00%
Sparsity in Conv1D 70: 30.00%
Sparsity in Conv1D 75: 30.00%
Sparsity in Conv1D 76: 30.00%
Sparsity in Conv1D 81: 30.00%
Sparsity in Conv1D 82: 30.00%
Sparsity in Conv1D 87: 30.00%
Sparsity in Conv1D 88: 30.00%
Sparsity in Conv1D 93: 30.00%
Sparsity in Conv1D 94: 30.00%
Sparsity in Conv1D 99: 30.00%
Sparsity in Conv1D 100: 30.00%
Sparsity in Conv1D 105: 30.00%
Sparsity 