In [None]:
# default_exp core

# Sparse Core

> Basic functions for sparsifying dense modules & models.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import numpy as np
import torch
import torch.nn as nn

In [None]:
#export
from fastcore.all import * # L, etc...
from fastai.basics import * # flatten_model, etc...
from fastai.callback.all import * # combine_scheds

## Sparsify Module

> For sparsifying a single module.

When a parameter and a buffer in a module follow the naming convention `{p_name}` and `{p_name}_mask`, respectively, the buffer is assumed to be the mask for the parameter. For example, masked Linear and ConvNd layers will typically have a parameter named `weight` and a buffer named `weight_mask`. A parameter may optionally also have a sparsity buffer (e.g. `weight_sparsity` for ConvNd), which is used by the `DynamicSparseTrainingCallback`.

In [None]:
#export
@torch.no_grad()
def sparse_mask(sizes, sparsity):
    '''Returns a random boolean mask of shape `sizes` in which a `sparsity` fraction of entries is False.'''
    numel = np.prod(sizes)
    # number of True entries, rounded to the nearest whole element
    n_active = round((1-sparsity) * numel)
    flat_mask = torch.zeros(numel, dtype=torch.bool)
    # a random permutation picks which positions stay active
    active_ixs = torch.randperm(numel)[:n_active]
    flat_mask[active_ixs] = True
    return flat_mask.reshape(*sizes)

def sparse_mask_like(param, sparsity):
    '''Returns a random mask with the shape of `param`, moved to the device of `param`.'''
    mask = sparse_mask(param.shape, sparsity)
    return mask.to(param.device)

In [None]:
# 80% sparsity over 10 * 5 = 50 entries should leave exactly 10 ones
mask = sparse_mask((10,5), 0.8)
test_eq(10, int(mask.sum()))

In [None]:
#export
def maybe_float(num):
    '''Returns `float(num)` when the conversion succeeds, otherwise returns `num` unchanged.'''
    try: return float(num)
    # `Exception` (not a bare `except`) so KeyboardInterrupt / SystemExit still propagate
    except Exception: return num
    
def sparse_params(module):
    '''Returns list of all (param, mask, sparsity) tuples in a module.

    A parameter named `{name}` is included when the module has a buffer named `{name}_mask`.
    `sparsity` comes from the optional `{name}_sparsity` buffer (converted with `maybe_float`)
    and is None when that buffer does not exist.

    Tuples are returned in `module.named_parameters()` order, so results are reproducible
    across calls and runs.
    '''
    buffer_d = dict(module.named_buffers())
    param_mask_sparsities, seen = [], set()
    for name, p in module.named_parameters():
        mask_name = f'{name}_mask'
        if mask_name not in buffer_d: continue
        mask = buffer_d[mask_name]
        # de-duplicate on object identity while preserving order
        key = (id(p), id(mask))
        if key in seen: continue
        seen.add(key)
        param_mask_sparsities.append((p, mask, maybe_float(buffer_d.get(f'{name}_sparsity'))))
    return param_mask_sparsities

In [None]:
# Attach masks to both `weight` and `bias`, but a sparsity buffer to `weight` only
s, m = 0.8, nn.Linear(5,10)
m.register_buffer('weight_mask', sparse_mask_like(m.weight, s))
m.register_buffer('weight_sparsity', tensor(s))
m.register_buffer('bias_mask', sparse_mask_like(m.bias, s))
param_mask_sparsity = sparse_params(m)
# both masked params are returned, with or without a sparsity buffer
test_eq(2, len(param_mask_sparsity))

In [None]:
#export
@torch.no_grad()
def apply_masks(module, *args, inplace=True):
    '''Multiplies every masked parameter of `module` by its mask.

    Extra positional `args` are accepted and ignored. With `inplace=False` the
    parameter data is replaced by a new masked tensor instead of being modified in place.
    '''
    for weights, keep, _ in sparse_params(module):
        if not inplace:
            weights.data = weights.data.mul(keep)
            continue
        weights.data.mul_(keep)

In [None]:
# `m.weight` has 50 entries and an 80%-sparse mask, so 10 weights are expected to stay non-zero
apply_masks(m)
test_eq(10, m.weight.abs().gt(0).sum())

In [None]:
#export
# Module types that are treated as sparseable by default.
_sparseable_module_types = (
    nn.Linear,
    nn.Conv1d, nn.Conv2d, nn.Conv3d,
    nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
    nn.MultiheadAttention,
    nn.RNN, nn.RNNCell, nn.GRU, nn.GRUCell, nn.LSTM, nn.LSTMCell,
)

def is_sparseable_module(m, additional_types=[]):
    '''Returns True if `m` is an instance of a default sparseable type or of one of `additional_types`.'''
    allowed_types = {*_sparseable_module_types, *additional_types}
    return isinstance(m, tuple(allowed_types))

In [None]:
#export

# TODO: flatten_model gets rid of nn.MultiheadAttention, which has its own parameter 'in_proj_weight',
#       which means sparsify_model doesn't sparsify this parameter
def sparseable_modules(model, additional_types=[]):
    '''Returns the modules of the flattened `model` that pass `is_sparseable_module`.'''
    flat_modules = L(flatten_model(model))
    return flat_modules.filter(lambda mod: is_sparseable_module(mod, additional_types=additional_types))

In [None]:
def test_model():
    '''Builds a small conv net: three `nn.Conv2d` layers (each followed by ReLU), pooling, flatten and one `nn.Linear`.'''
    return nn.Sequential(
        nn.Conv2d(3,32,3), nn.ReLU(), 
        nn.Conv2d(32,128,3), nn.ReLU(), 
        nn.Conv2d(128,512,3), nn.ReLU(), AdaptiveAvgPool(), Flatten(),
        nn.Linear(512, 10))

# the three Conv2d layers and the Linear layer are expected to be reported as sparseable
model = test_model()
s_mods = sparseable_modules(model)
test_eq(4, len(s_mods))

## Sparse Distributions

> For determining the layer-wise sparsity of a list of modules.

### Uniform Distribution

> All layers have the same percentage of connections removed.

In [None]:
#export
def uniform_sparsity(params, model_sparsity):
    '''Returns one sparsity per entry of `params`, each equal to `model_sparsity`.'''
    n_params = len(params)
    return [model_sparsity for _ in range(n_params)]

### First-Layer-Dense Uniform Distribution

> Uniform sparsity except for the first layer, which is dense.

In [None]:
#export
def first_layer_dense_uniform(params, model_sparsity):
    '''Returns one sparsity per entry of `params`: 0 for the first, `model_sparsity` for the rest.

    An empty `params` yields an empty list, so the result always has the same length as `params`.
    '''
    n_params = len(params)
    if n_params == 0: return []
    return [0.] + [model_sparsity] * (n_params - 1)

### Erdos-Renyi (Kernel) Distribution

> For a fixed overall sparsity, the Erdos-Renyi sparsity distribution allocates more connections to smaller layers and fewer to large layers when compared to a uniform sparsity distribution.

In [None]:
#export
# modified from https://github.com/google-research/rigl/blob/master/rigl/sparse_utils.py.
def erdos_renyi_sparsity(params, model_sparsity, include_kernel=True, erk_power_scale=1.0):
    """
    Returns a list of sparsities in the same order as params. Sparsities satisfy 
    the Erdos-Renyi(Kernel) distribution, where the model has the same total parameter count 
    as one with uniform sparsities, that is, satisfying the following equation:
    # eps * (p_1 * N_1 + p_2 * N_2) = (1 - model_sparsity) * (N_1 + N_2), for some float `eps`.
    
    Args:
    params: list of all sparseable parameters
    model_sparsity: target overall sparsity between 0 and 1
    include_kernel: if True, kernel dimensions are included in the scaling (e.g. for ConvNd layers)
    erk_power_scale: scale < 1 softens the erdos_renyi distribution (i.e. closer to uniform);
                     only applied when `include_kernel` is True
    
    Returns a list of Python floats where values correspond to individual param sparsities.
    An empty `params` returns an empty list.
    """
    if len(params) == 0: return []

    # Enforce custom sparsities, then find correct scaling factor, `eps` for remaining params
    dense_layers = set()
    raw_sparsity = {}
    eps = 0.
    is_eps_valid = False
    while not is_eps_valid:
        # Start with all layers and try to find right eps. If any sparsity exceeds 1, 
        # make that layer dense and repeat with the non-dense layers.
        #
        # E.g. where N_3, and N_4 are found to be dense:
        # eps * (p_1 * N_1 + p_2 * N_2) + (N_3 + N_4) =
        #    (1 - model_sparsity) * (N_1 + N_2 + N_3 + N_4)
        # eps * (p_1 * N_1 + p_2 * N_2) =
        #    (1 - model_sparsity) * (N_1 + N_2) - model_sparsity * (N_3 + N_4) <--- == rhs
        # eps = rhs / (\sum_i p_i * N_i) <--- == divisor
        # eps = rhs / divisor

        divisor = 0
        rhs = 0
        raw_sparsity = {}
        for p in params:
            n_zeros = int(np.floor(model_sparsity * p.numel()))
            if p in dense_layers:
                rhs -= n_zeros
            else:
                n_ones = p.numel() - n_zeros
                rhs += n_ones
                if include_kernel:
                    raw_sparsity[p] = (np.sum(p.shape) / np.prod(p.shape))**erk_power_scale
                else:
                    raw_sparsity[p] = (np.sum(p.shape[:2]) / np.prod(p.shape[:2]))
                divisor += raw_sparsity[p] * p.numel()

        # Every layer has been forced dense: nothing is left to scale, so stop here
        # instead of dividing by zero.
        if not raw_sparsity: break

        eps = rhs / divisor
        
        # If eps * raw_sparsity[p] > 1, we add the param to the set of dense_layers
        max_sparsity = np.max(list(raw_sparsity.values()))
        if eps * max_sparsity > 1:
            for p, p_raw_sparsity in raw_sparsity.items():
                if p_raw_sparsity == max_sparsity:
                    dense_layers.add(p)
        else:
            is_eps_valid = True

    # With the valid eps, we can set sparsities of the remaining layers
    sparsities = [0. if p in dense_layers else float(1. - eps * raw_sparsity[p]) for p in params]
    return sparsities

In [None]:
# The layer-wise Erdos-Renyi sparsities should leave roughly 10% of all weights non-zero overall
model = test_model()
s_params = L(sparseable_modules(model)).map(lambda m: m.weight)

sparsities = erdos_renyi_sparsity(s_params, 0.9)
# expected number of non-zero weights implied by the per-layer sparsities
n_nonzeros = sum([(1-s) * p.numel() for p, s in zip(s_params, sparsities)])
# tolerance of one weight per layer
test_close(n_nonzeros, 0.1 * sum([p.numel() for p in s_params]), eps=len(s_params))
# test_eq([0., 0., 0., 0.], sparsities) # TODO: calc sparsities by hand and compare

## Sparsify Model

> For sparsifying an entire model.

In [None]:
#export
@torch.no_grad()
def sparsify_model(model, model_sparsity, sparse_f=uniform_sparsity, enforce_mask=True):
    '''
    Adds a sparse mask for each sparseable-module weight in model and applies mask to weights.
    
    For every parameter whose name contains 'weight' (e.g. `weight`, `weight_ih_l0`), buffers
    named `{p_name}_mask` and `{p_name}_sparsity` are registered on the owning module.
    Parameters assigned a sparsity of 0 are left untouched.
    
    If `enforce_mask` is True, a forward_pre_hook will be registered to each module
    to apply the weight mask before every forward pass of the module.
    
    `model`: a module, or a `Learner` (in which case `model.model` is sparsified).
    `model_sparsity`: target overall sparsity, passed on to `sparse_f`.
    `sparse_f`: function `(params, model_sparsity) -> list of per-param sparsities`.
    Per RigL paper, `uniform_sparsity` has fewer FLOPs, `erdos_renyi_sparsity` 
    results in better model.
    
    Returns a fastai Hooks object. You can remove the hooks after training by calling hooks.remove().
    '''
    if isinstance(model, Learner): model = model.model
    modules = sparseable_modules(model)
    module_name_param = L([(m, p_name, p) for m in modules for p_name, p in m.named_parameters()
                         if 'weight' in p_name])
    params = module_name_param.itemgot(2)
    sparsities = sparse_f(params, model_sparsity)
    
    hooks = Hooks([], noop)
    hooked_modules = set()
    for (m, p_name, p), s in zip(module_name_param, sparsities):
        if s > 0:
            # mask the parameter that was actually selected, under its own name
            m.register_buffer(f'{p_name}_mask', sparse_mask_like(p, s))
            m.register_buffer(f'{p_name}_sparsity', tensor(s))
            apply_masks(m)
            # one hook per module is enough: `apply_masks` handles all masked params of `m`
            if enforce_mask and id(m) not in hooked_modules:
                h = m.register_forward_pre_hook(apply_masks)
                hooks.hooks.append(h)
                hooked_modules.add(id(m))
    
    return hooks

In [None]:
# Uniform 90% sparsity: ~10% of the weights stay non-zero, both model-wide and in the first layer
model = test_model()
s_mods = sparseable_modules(model)
n_params = sum(m.weight.numel() for m in s_mods)
sparsify_model(model, 0.9, sparse_f=uniform_sparsity)
n_nonzeros = sum(m.weight.abs().gt(0).sum() for m in s_mods)
# increase `eps` to account for rounding to nearest whole weight
test_close(n_nonzeros, 0.1 * n_params, eps=len(s_mods))
p = s_mods[0].weight
test_close(p.abs().gt(0).sum(), 0.1 * p.numel(), eps=1)

# While the hooks are registered, a forward pass keeps the weights masked:
# 50 + 50 weights at 90% sparsity -> 10 non-zero weights
model = nn.Sequential(nn.Linear(1,50), nn.ReLU(), nn.Linear(50,1))
hooks = sparsify_model(model, 0.9)
model(torch.rand(10,1))
test_eq(10, sum([model[i].weight.abs().gt(0).sum() for i in (0,2)]))
# After removing the hooks, overwritten weights are no longer masked by a forward pass
hooks.remove()
for i in (0,2): model[i].weight.data = torch.ones_like(model[i].weight)
model(torch.rand(10,1))
test_eq(100, sum([model[i].weight.abs().gt(0).sum() for i in (0,2)]))

## Sparse Training

### Drop/Grow Heuristics

In [None]:
#export
def random_score(p, **kwargs):
    '''Scores every entry of `p` with a uniform random value; extra keyword arguments are ignored.'''
    scores = torch.rand_like(p)
    return scores

In [None]:
#export
def weight_magnitude(p, **kwargs):
    '''Scores every entry of `p` by its absolute value; extra keyword arguments are ignored.'''
    values = p.data
    return torch.abs(values)

In [None]:
#export
def gradient_magnitude(p, **kwargs):
    '''Scores every entry of `p` by the absolute value of its gradient `p.grad`; extra keyword arguments are ignored.'''
    grad = p.grad
    return grad.abs()

In [None]:
#export
def gradient_momentum(p, opt, **kwargs):
    '''Calculates the momentum of the gradient for a parameter `p` from the `opt` state.

    Reads `grad_avg` (required) and `sqr_avg` (optional) from `opt.state[p]`.
    Without `sqr_avg` the result is `grad_avg`; otherwise it is
    `grad_avg / sqrt(sqr_avg + eps)`, with `eps` taken from the first hyper-parameter
    group of `opt.state_dict()` and falling back to 1e-6 if it cannot be read.

    Raises an Exception when `grad_avg` is missing from the optimizer state.
    '''
    # NOTE(review): the returned score is signed; confirm whether callers expect a magnitude.
    state = opt.state[p]
    grad_avg = state.get('grad_avg')
    sqr_avg = state.get('sqr_avg')
    if grad_avg is None:
        raise Exception(f"Error: 'grad_avg' key not found in optimizer state. Tip: set the `mom` hyperparamter in the learner.")
    if sqr_avg is None:
        return grad_avg
    # best-effort lookup; `Exception` (not a bare `except`) lets KeyboardInterrupt / SystemExit through
    try: eps = opt.state_dict()['hypers'][0]['eps']
    except Exception: eps = 1e-6
    return grad_avg / (torch.sqrt(sqr_avg + eps))

In [None]:
#tests

### Dynamic Sparse Training Callback

In [None]:
#export
def top_k_mask(t, n_keep):
    '''Returns a boolean mask shaped like `t` with `n_keep` ones corresponding to the largest values in `t`.

    `n_keep` is clamped to `[0, t.numel()]`, so out-of-range requests yield an all-False
    or all-True mask instead of raising.
    '''
    n_total = t.numel()
    n_keep = max(0, min(int(n_keep), n_total))
    mask = torch.zeros(n_total, dtype=torch.bool, device=t.device)
    if n_keep > 0:
        # only the top `n_keep` indices are needed, not a full ordering
        _, top_ixs = torch.topk(t.flatten(), k=n_keep)
        mask[top_ixs] = True
    return mask.view(t.shape)

In [None]:
# `t` increases along the flattened order, so its 5 largest values are exactly the last row
t = torch.linspace(-0.9, 0.9, 20).reshape(4,5)
mask = top_k_mask(t, 5)
test_eq(0, mask[:3].sum())
test_eq(5, mask[3:].sum())

In [None]:
#export
class DynamicSparseTrainingCallback(Callback):
    '''Periodically drops low-scoring connections and grows new ones in sparse modules during training.'''
    toward_end,run_after = True,GradientAccumulation
    
    def __init__(self, sparse_modules=None,
                 batches_per_update=None, initial_drop_grow_pct=0.3, stop_pct=0.75, 
                 keep_score_f=weight_magnitude, grow_score_f=gradient_magnitude):
        '''
        Args:
        sparse_modules: modules to rewire; None (default) uses `sparseable_modules(learn.model)` at fit time
        batches_per_update: # of batches per update, None (default) updates at end of each training epoch
        initial_drop_grow_pct: percentage of weights to change during each dynamic weight update
        stop_pct: stop dynamic weight updates after `stop_pct` of training
        keep_score_f: function scoring each weight, top n are kept and the rest are zeroed
        grow_score_f: function scoring each weight, top n excl. kept weights are unmasked and initialized to zero
        '''
        store_attr('initial_drop_grow_pct,stop_pct,keep_score_f,grow_score_f,batches_per_update')
        self.modules = sparse_modules
        
    def before_fit(self):
        self.modules = ifnone(self.modules, sparseable_modules(self.learn.model))
        self.batches_per_update = ifnone(self.batches_per_update, len(self.dls.train))
        # cosine-anneal the drop/grow fraction to 0 over the first `stop_pct` of training, then hold at 0
        self.drop_grow_pct_sched = combine_scheds(
            [self.stop_pct, 1-self.stop_pct],
            [SchedCos(self.initial_drop_grow_pct, 0.), SchedNo(0.,0.)]
        )
    
    def after_backward(self):
        self.step()
        if self.is_update_step:
            for m in self.modules:
                self.rewire_module(m)
            # skip the optimizer step for the batch on which connectivity changed
            raise CancelBatchException()
        
    def step(self):
        if not self.training:
            self.is_update_step = False
        else:
            step = self.epoch * self.n_iter + self.iter
            n_steps = self.n_epoch * self.n_iter
            pct_train = step / n_steps
            self.drop_grow_pct = self.drop_grow_pct_sched(pct_train)
            self.is_update_step = step > 0 and step % self.batches_per_update == 0 and self.drop_grow_pct > 0
            
    @torch.no_grad()
    def rewire_module(self, m):
        for param, mask, sparsity in sparse_params(m):
            # params without a recorded sparsity, and dense params, are never rewired
            if sparsity is None or sparsity <= 0: continue

            # Same rounding as `sparse_mask`; keep + grow always equals the target
            # number of active weights, so the layer density does not drift.
            n_ones = round((1 - sparsity) * param.numel())
            n_grow = int(n_ones * self.drop_grow_pct)
            n_keep = n_ones - n_grow

            # determine which weights to keep
            keep_score = self.keep_score_f(param, opt=self.learn.opt)
            keep_mask = top_k_mask(keep_score, n_keep)

            # determine which weights to grow
            grow_score = self.grow_score_f(param, opt=self.learn.opt)
            # give kept weights the lowest possible score so they are never chosen for growth,
            # even when `grow_score_f` returns negative values
            grow_score = grow_score.masked_fill(keep_mask, float('-inf'))
            grow_mask = top_k_mask(grow_score, n_grow)

            # update network connectivity
            mask.data = keep_mask | grow_mask

            # dropped weights are zeroed and newly grown weights start from zero
            param.data.mul_(keep_mask)
            
            # zero momentum for new connections
            self.reset_momentum(param, grow_mask & keep_mask.logical_not())

    @torch.no_grad()
    def reset_momentum(self, p, mask):
        '''Zeroes the optimizer averages of `p` at the positions where `mask` is True.'''
        state = self.opt.state[p]
        for key in ('grad_avg', 'sqr_avg'):
            if key in state: state[key].masked_fill_(mask, 0.)

    _docs = dict(before_fit="Schedule the number of connections to drop & grow per update.",
                 after_backward="Rewire the sparse modules on update steps and skip the gradient update for that batch.",
                 step="Update self.is_update_step and self.drop_grow_pct.",
                 rewire_module="Update step for one module.",
                 reset_momentum="Initialize momentum to zero for newly-added connections.")

In [None]:
# Render the signature and documentation of the callback
show_doc(DynamicSparseTrainingCallback)

<h2 id="DynamicSparseTrainingCallback" class="doc_header"><code>class</code> <code>DynamicSparseTrainingCallback</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>DynamicSparseTrainingCallback</code>(**`sparse_modules`**=*`None`*, **`batches_per_update`**=*`None`*, **`initial_drop_grow_pct`**=*`0.3`*, **`stop_pct`**=*`0.75`*, **`keep_score_f`**=*`weight_magnitude`*, **`grow_score_f`**=*`gradient_magnitude`*) :: `Callback`

Basic class handling tweaks of the training loop by changing a `Learner` in various events

In [None]:
# Smoke test: train a small MLP on synthetic data with dynamic sparse training enabled
from fastai.test_utils import *
model = nn.Sequential(nn.Linear(1,32), nn.ReLU(), nn.Linear(32,32), nn.ReLU(), nn.Linear(32,1))
learn = synth_learner(data=synth_dbunch(bs=100), model=model)
# 80% sparsity everywhere except the first layer, which stays dense
sparse_hooks = sparsify_model(learn.model, 0.8, sparse_f=first_layer_dense_uniform)
# `batches_per_update=None` rewires once per epoch; rewiring stops after half of training
cbs = DynamicSparseTrainingCallback(batches_per_update=None, stop_pct=0.5, grow_score_f=gradient_momentum)
learn.fit(10, lr=1e-2, cbs=cbs)

epoch,train_loss,valid_loss,time
0,9.414032,5.081448,00:00
1,6.218855,3.628369,00:00
2,4.470164,0.579595,00:00
3,3.098428,0.138066,00:00
4,2.240101,0.059732,00:00
5,1.675639,0.043351,00:00
6,1.280303,0.026333,00:00
7,0.994448,0.025528,00:00
8,0.782729,0.021094,00:00
9,0.622418,0.020422,00:00


Now let's test a slightly more realistic use case: MNIST_TINY on ResNet18.

In [None]:
#slow
# End-to-end check: train a 95%-sparse (first layer dense) ResNet18 on MNIST_TINY, rewiring every 8 batches
from fastai.vision.all import *
dls = ImageDataLoaders.from_folder(untar_data(URLs.MNIST_TINY))
learn = cnn_learner(dls, resnet18, metrics=accuracy, pretrained=False)
sparse_hooks = sparsify_model(learn.model, 0.95, first_layer_dense_uniform)
cbs = DynamicSparseTrainingCallback(batches_per_update=8, stop_pct=0.5, grow_score_f=gradient_momentum)
learn.fit_one_cycle(5, 1e-2, cbs=cbs)

test_close(1, learn.final_record[-1], eps=0.01) # better than 99% accuracy

# every masked parameter should still be at its recorded target sparsity after training
for m in sparseable_modules(learn.model):
    for p, mask, s in sparse_params(m):
        n_alive = p.abs().gt(0).sum()
        n_total = p.numel()    
        test_close(s, 1 - n_alive / n_total, eps=0.01) # layer sparsity = target sparsity

epoch,train_loss,valid_loss,accuracy,time
0,0.397535,0.705997,0.459375,00:03


AssertionError: ==:
64
59

## Preset Definitions

### Sparse Evolutionary Training (SET)

In [None]:
#export
# Keyword arguments for `DynamicSparseTrainingCallback`: keep by weight magnitude, grow at random.
SET_presets = dict(keep_score_f=weight_magnitude, grow_score_f=random_score,
                   initial_drop_grow_pct=0.3, stop_pct=1.0)

### Sparse Networks From Scratch (SNFS)

In [None]:
#export
# Keyword arguments for `DynamicSparseTrainingCallback`: keep by weight magnitude, grow by gradient momentum.
SNFS_presets = dict(keep_score_f=weight_magnitude, grow_score_f=gradient_momentum,
                    initial_drop_grow_pct=0.3, stop_pct=1.0)

### Rigged Lottery (RigL)

In [None]:
#export
# Keyword arguments for `DynamicSparseTrainingCallback`: keep by weight magnitude, grow by gradient magnitude.
RigL_presets = dict(keep_score_f=weight_magnitude, grow_score_f=gradient_magnitude,
                    initial_drop_grow_pct=0.3, stop_pct=0.75)

# Export

In [None]:
#hide
# Run nbdev's `notebook2script` (presumably writes the `#export` cells to the library module — confirm against nbdev docs)
from nbdev.export import notebook2script
notebook2script()