In [1]:
import warnings
warnings.filterwarnings('ignore')


In [2]:
from fastai.vision import *
from fastai.core import *
from fastai.callbacks import *

import torch
import torch.nn as nn
import torch.nn.functional as F

class Sparsifier():
    def __init__(self, granularity, method, criteria):
        self.granularity = granularity
        self.method = method
        self.criteria = criteria
        
    def prune(self, model, sparsity):
        
        for k, m in enumerate(model.modules()):
            if isinstance(m, nn.Conv2d):
                if self.criteria == 'l1':
                    weight = self._l1_norm(m.weight)
                elif self.criteria == 'taylor':
                    weight = self._taylor_crit(m.weight)
                else: raise NameError('Invalid Criteria')
                
                mask = self._compute_mask(model, weight, sparsity)
                print(mask.shape)
                print(m.weight.shape)
                mask = make_broadcastable(mask, m.weight)
                m.register_buffer("_mask", mask) # Put the mask into a buffer
                self._apply(m)
            
        return model
    
    def _apply(self, module):
        '''
        Apply the mask and freeze the gradient so the corresponding weights are not updated anymore
        '''
        mask = getattr(module, "_mask")
        module.weight.data.mul_(mask)
        if module.weight.grad is not None: # In case some layers are freezed
            module.weight.grad.mul_(mask)

        if self.granularity == 'filter': # If we remove complete filters, we want to remove the bias as well
            if module.bias is not None:
                module.bias.data.mul_(mask.squeeze())
                if module.bias.grad is not None: # In case some layers are freezed
                    module.bias.grad.mul_(mask.squeeze())
    
    def _l1_norm(self, weight):

        if self.granularity == 'weight':
            w = weight.view(-1).abs().clone()
            
        elif self.granularity == 'vector':
            w = torch.norm(weight, p=1, dim=(3)).view(-1)/(weight.shape[3])

        elif self.granularity == 'kernel':
            w = torch.norm(weight, p=1, dim=(2,3)).view(-1)/(weight.shape[2]*weight.shape[3])
        
        elif self.granularity == 'filter':
            w = torch.norm(weight, p=1, dim=(1,2,3))/(weight.shape[1]*weight.shape[2]*weight.shape[3])

        else: raise NameError('Invalid Granularity') 
        
        return w
        
    def _taylor_crit(self, weight):
        if weight.grad is not None:
            if self.granularity == 'weight':
                w = (weight*weight.grad).data.pow(2).view(-1)

            elif self.granularity == 'vector':
                w = (weight*weight.grad).data.pow(2).sum(dim=(3)).view(-1).clone()

            elif self.granularity == 'kernel':
                w = (weight*weight.grad).data.pow(2).sum(dim=(2,3)).view(-1).clone()     
                
            elif self.granularity == 'filter':       
                w = (weight*weight.grad).data.pow(2).sum(dim=(1,2,3))

            else: raise NameError('Invalid Granularity') 

            return w

    
    def _compute_mask(self, model, weight, sparsity):
        '''
        Compute the binary masks
        '''
        if self.method == 'global':
            global_weight = []
            
            for k, m in enumerate(model.modules()):
                if isinstance(m, nn.Conv2d):
                    if self.criteria == 'l1':
                        w = self._l1_norm(m.weight)
                    elif self.criteria == 'taylor':
                        w = self._taylor_crit(m.weight)
                        
                    global_weight.append(w)

            global_weight = torch.cat(global_weight)
            threshold = torch.quantile(global_weight, sparsity/100) # Compute the threshold globally
            
        elif self.method == 'local': 
            threshold = torch.quantile(weight, sparsity/100)
            
        else: raise NameError('Invalid Method')
            
        # Make sure we don't remove every weight of a given layer
        if threshold > weight.max(): threshold = weight.max()

        mask = weight.ge(threshold).to(dtype=weight.dtype)

        return mask
        

class SparsifyCallback(LearnerCallback):
        
    def __init__(self, learn:Learner, sparsity, granularity, method, criteria, sched_func):
        super().__init__(learn)
        self.sparsity, self.granularity, self.method, self.criteria, self.sched_func = sparsity, granularity, method, criteria, sched_func
        self.sparsifier = Sparsifier(self.granularity, self.method, self.criteria)
        self.batches = math.floor(len(learn.data.train_ds)/learn.data.train_dl.batch_size)
    
    def on_train_begin(self, n_epochs:int, **kwargs):
        print(f'Pruning of {self.granularity} until a sparsity of {self.sparsity}%')
        self.total_iters = n_epochs * self.batches
        
    def on_epoch_end(self, epoch, **kwargs):
        print(f'Sparsity at epoch {epoch}: {self.current_sparsity:.2f}%')
        
    def on_batch_begin(self,iteration, **kwargs):
        self.set_sparsity(iteration)
        
    def on_step_end(self, iteration, **kwargs):
        self.sparsifier.prune(self.learn.model, self.current_sparsity)
        
    def set_sparsity(self, iteration):
        self.current_sparsity = self.sched_func(start=0.0001, end=self.sparsity, pct=(iteration+1)/self.total_iters)
    
    def on_train_end(self, **kwargs):
        print(f'Final Sparsity: {self.current_sparsity:.2f}')

In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastai.core import *
from fastai.callbacks import *

def make_broadcastable(input, target):
    target_shape = target.shape
    output_shape = [*target.shape]
    
    for i in range(len(target_shape)):
        input_size = np.prod(input.shape)
        target_size = np.prod(np.array(target_shape[:i+1]))
        if input_size >= target_size:
            output_shape[i]=target_shape[i]
        else:
            output_shape[i]=1
        
    new_input = input.reshape(*output_shape)        
    return new_input

def annealing_gradual(start:Number, end:Number, pct:float)->Number:
    "Gradually anneal from `start` to `end` as pct goes from 0.0 to 1.0."
    return end + start - end * (1 - pct)**3

In [4]:
path = untar_data(URLs.IMAGENETTE_160)

In [5]:
data = (ImageList.from_folder(path)
                .split_by_folder(train='train', valid='val')
                .label_from_folder()
                .transform(get_transforms(), size=64)
                .databunch(bs=64)
                .normalize(imagenet_stats))

In [9]:
cd fasterai

/space/storage/homes/nathan/Code/fasterai


In [10]:
from fasterai.sparsifier_test import *

In [11]:
from fasterai.distillation import *

In [56]:
learn = Learner(data, models.resnet18(), metrics=[accuracy])
learn.fit_one_cycle(3, 1e-3)
learn.validate()

In [57]:
learn_1 = Learner(data, models.resnet18(), metrics=[accuracy])

In [51]:
prune = SparsifyCallback(learn_1, sparsity=10, granularity='weight', method='local', criteria='l1', sched_func=annealing_linear, lth_reset=True)
KD = KnowledgeDistillation(learn_1, teacher=learn)

In [52]:
learn_1.fit_one_cycle(3, 1e-3, callbacks=[prune, KD])
learn_1.validate()

Pruning of weight until a sparsity of 10%


Saving Weights at epoch 0
Sparsity at the end of epoch 0: 3.33%
Sparsity at the end of epoch 1: 6.67%
Sparsity at the end of epoch 2: 10.00%
Final Sparsity: 10.00


In [53]:
learn_1 = Learner(data, models.resnet18(), metrics=[accuracy])

In [54]:
prune = SparsifyCallback(learn_1, sparsity=10, granularity='weight', method='local', criteria='l1', sched_func=annealing_linear)

In [55]:
learn_1.fit_one_cycle(3, 1e-3, callbacks=[prune])
learn_1.validate()

Pruning of weight until a sparsity of 10%


Saving Weights at epoch 0
Sparsity at the end of epoch 0: 3.33%
Sparsity at the end of epoch 1: 6.67%
Sparsity at the end of epoch 2: 10.00%
Final Sparsity: 10.00


In [20]:
learn_teacher = Learner(data, models.resnet18(pretrained=True), metrics=[accuracy])

In [21]:
learn_teacher.fit_one_cycle(3, 1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,1.375059,0.860187,0.749809,00:07
1,0.650054,0.497418,0.841783,00:08
2,0.386505,0.392868,0.875414,00:08


In [45]:
KD = KnowledgeDistillation(learn_st, teacher=learn)

In [46]:
learn_st.fit_one_cycle(3, 1e-4, callbacks=[KD])


epoch,train_loss,valid_loss,accuracy,time
0,0.899284,0.897972,0.712357,00:09
1,0.87488,0.866431,0.717707,00:09
2,0.829745,0.840928,0.726369,00:10


In [47]:
learn_st.validate()

In [22]:
learn.fit_one_cycle(3, 1e-3, callbacks=[SparsifyCallback(learn, sparsity=50, granularity='weight', method='global', criteria='l1', sched_func=annealing_cos)])

Pruning of weight until a sparsity of 50%


torch.Size([9408])
torch.Size([64, 3, 7, 7])
torch.Size([36864])
torch.Size([64, 64, 3, 3])
torch.Size([36864])
torch.Size([64, 64, 3, 3])
torch.Size([36864])
torch.Size([64, 64, 3, 3])
torch.Size([36864])
torch.Size([64, 64, 3, 3])
torch.Size([73728])
torch.Size([128, 64, 3, 3])
torch.Size([147456])
torch.Size([128, 128, 3, 3])
torch.Size([8192])
torch.Size([128, 64, 1, 1])
torch.Size([147456])
torch.Size([128, 128, 3, 3])
torch.Size([147456])
torch.Size([128, 128, 3, 3])
torch.Size([294912])
torch.Size([256, 128, 3, 3])
torch.Size([589824])
torch.Size([256, 256, 3, 3])
torch.Size([32768])
torch.Size([256, 128, 1, 1])
torch.Size([589824])
torch.Size([256, 256, 3, 3])
torch.Size([589824])
torch.Size([256, 256, 3, 3])
torch.Size([1179648])
torch.Size([512, 256, 3, 3])
torch.Size([2359296])
torch.Size([512, 512, 3, 3])
torch.Size([131072])
torch.Size([512, 256, 1, 1])
torch.Size([2359296])
torch.Size([512, 512, 3, 3])
torch.Size([2359296])
torch.Size([512, 512, 3, 3])
torch.Size([9408])


KeyboardInterrupt: 

In [439]:
learn.validate()

[1.6817698, tensor(0.4283)]

In [450]:
learn.validate()

[2.055264, tensor(0.2713)]

## KD Pruning

In [None]:
class SparsifyCallback(LearnerCallback):
        
    def __init__(self, learn:Learner, sparsity, granularity, method, criteria, sched_func, start_epoch=0, lth_reset=False, rewind_epoch=0, reset_end=False):
        super().__init__(learn)
        self.sparsity, self.granularity, self.method, self.criteria, self.sched_func = sparsity, granularity, method, criteria, sched_func
        self.reset_end, self.rewind_epoch, self.start_epoch, self.lth_reset = reset_end, rewind_epoch, start_epoch, lth_reset
        self.sparsifier = Sparsifier(self.learn.model, self.granularity, self.method, self.criteria)
        self.batches = math.floor(len(learn.data.train_ds)/learn.data.train_dl.batch_size)
        self.current_sparsity, self.previous_sparsity = 0,0
        self.T, self.α = 20, 0.7

        assert self.start_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
    
    def on_train_begin(self, n_epochs:int, **kwargs):
        print(f'Pruning of {self.granularity} until a sparsity of {self.sparsity}%')
        self.total_iters = n_epochs * self.batches
        self.start_iter = self.start_epoch * self.batches
        
    def on_epoch_end(self, epoch, **kwargs):
        print(f'Sparsity at the end of epoch {epoch}: {self.current_sparsity:.2f}%')
    
    def on_epoch_begin(self, epoch, **kwargs):
        if epoch == self.rewind_epoch:
            print(f'Saving Weights at epoch {epoch}')
            self.sparsifier._save_weights()
        
    def on_batch_begin(self, iteration, epoch, **kwargs):
        if epoch>=self.start_epoch:
            self.set_sparsity(iteration)
            self.teacher = self.learn.deepcopy()
            self.sparsifier.prune(self.current_sparsity)
             

            if self.lth_reset and self.current_sparsity!=self.previous_sparsity: # If sparsity has changed, the network has been pruned
                    print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
                    self.sparsifier._reset_weights()

        self.previous_sparsity = self.current_sparsity
        
    def set_sparsity(self, iteration):
        self.current_sparsity = self.sched_func(start=0., end=self.sparsity, pct=(iteration-self.start_iter)/(self.total_iters-self.start_iter))
        
    def on_backward_begin(self, last_input, last_output, last_target, **kwargs):
        self.teacher.model.eval()
        teacher_output = self.teacher.model(last_input)
        new_loss = DistillationLoss(last_output, last_target, teacher_output, self.T, self.α)
        
        return {'last_loss': new_loss}
    
    def on_train_end(self, **kwargs):
        print(f'Final Sparsity: {self.current_sparsity:.2f}')
        if self.reset_end:
            self.sparsifier._reset_weights()

In [58]:
def DistillationLoss(y, labels, teacher_scores, T, alpha):
    return nn.KLDivLoss(reduction='batchmean')(F.log_softmax(y/T, dim=-1), F.softmax(teacher_scores/T, dim=-1)) * (T*T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)

In [None]:
    def __init__(self, learn:Learner, teacher:Learner, T:float=20., α:float=0.7):
        super().__init__(learn)
        self.teacher = teacher
        self.T, self.α = T, α
    
    def on_backward_begin(self, last_input, last_output, last_target, **kwargs):
        self.teacher.model.eval()
        teacher_output = self.teacher.model(last_input)
        new_loss = DistillationLoss(last_output, last_target, teacher_output, self.T, self.α)
        
        return {'last_loss': new_loss}



In [440]:
for k,m in enumerate(learn.model.modules()):
    if isinstance(m, nn.Conv2d):
        print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")

Sparsity in Conv2d 1: 96.88%
Sparsity in Conv2d 7: 0.00%
Sparsity in Conv2d 10: 0.00%
Sparsity in Conv2d 13: 0.00%
Sparsity in Conv2d 16: 0.00%
Sparsity in Conv2d 20: 0.00%
Sparsity in Conv2d 23: 0.00%
Sparsity in Conv2d 26: 0.00%
Sparsity in Conv2d 29: 0.00%
Sparsity in Conv2d 32: 0.00%
Sparsity in Conv2d 36: 32.03%
Sparsity in Conv2d 39: 24.61%
Sparsity in Conv2d 42: 0.00%
Sparsity in Conv2d 45: 30.08%
Sparsity in Conv2d 48: 26.56%
Sparsity in Conv2d 52: 99.80%
Sparsity in Conv2d 55: 99.80%
Sparsity in Conv2d 58: 0.00%
Sparsity in Conv2d 61: 99.80%
Sparsity in Conv2d 64: 99.80%


In [451]:
for k,m in enumerate(learn.model.modules()):
    if isinstance(m, nn.Conv2d):
        print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")

Sparsity in Conv2d 1: 98.44%
Sparsity in Conv2d 7: 98.44%
Sparsity in Conv2d 10: 98.44%
Sparsity in Conv2d 13: 98.44%
Sparsity in Conv2d 16: 98.44%
Sparsity in Conv2d 20: 99.22%
Sparsity in Conv2d 23: 96.88%
Sparsity in Conv2d 26: 99.22%
Sparsity in Conv2d 29: 95.31%
Sparsity in Conv2d 32: 96.09%
Sparsity in Conv2d 36: 99.61%
Sparsity in Conv2d 39: 0.00%
Sparsity in Conv2d 42: 99.61%
Sparsity in Conv2d 45: 0.00%
Sparsity in Conv2d 48: 0.00%
Sparsity in Conv2d 52: 84.18%
Sparsity in Conv2d 55: 0.00%
Sparsity in Conv2d 58: 99.80%
Sparsity in Conv2d 61: 0.00%
Sparsity in Conv2d 64: 0.00%


In [456]:
a = torch.randn(10,100,3,100)

In [457]:
torch.norm(a, p=2, dim=(1,2,3))

tensor([172.6954, 171.7417, 173.2661, 173.3760, 173.9668, 173.8984, 173.0813,
        173.5202, 173.2958, 172.3175])

In [459]:
(a).data.pow(2).sum(dim=(1,2,3))

tensor([29823.7070, 29495.2441, 30021.1211, 30059.1621, 30264.4707, 30240.7129,
        29957.1465, 30109.3848, 30031.4355, 29693.3477])

In [417]:
#(torch.norm(a, p=1, dim=(3))/(a.shape[3])).shape
(a.abs().sum(dim=(3)).view(-1).clone()).shape

torch.Size([3000])

In [353]:
torch.norm(a, p=1, dim=(1,2,3))/(a.shape[1]*a.shape[2]*a.shape[3])

tensor([0.8004, 0.8037, 0.7933, 0.8009, 0.7991, 0.7969, 0.7933, 0.8015, 0.7884,
        0.7880])

In [344]:
a.shape

torch.Size([10, 100, 11, 11])

In [None]:
torch.

In [345]:
torch.numel(a[0,:,:,:])

12100

In [268]:
(a.abs().sum(dim=(3)).view(-1)/a.shape[3]).mean()

#/(a.shape[2]*a.shape[3])

tensor(0.8076)

In [None]:
        if self.granularity == 'weight':
            w = weight.view(-1).abs().clone()
            
        elif self.granularity == 'vector':
            w = weight.abs().sum(dim=(3)).view(-1).clone()/(weight.shape[2]*weight.shape[3])
            #w = weight.abs().sum(dim=(3)).view(-1).clone()

        elif self.granularity == 'kernel':
            w = weight.abs().sum(dim=(2,3)).view(-1).clone()/weight.shape[2]  
        
        elif self.granularity == 'filter':
            w = weight.abs().sum(dim=(1,2,3)).clone()/weight.shape[1]