In [None]:
# default_exp core

# Sparse Core

> Basic functions for sparsifying dense modules & models.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import numpy as np
import torch
import torch.nn as nn

In [None]:
#export
from fastcore.all import * # L
from fastai.basics import * # flatten_model, etc...

## Sparsify Module

> For sparsifying a single module.

When a parameter and buffer in a module follow the naming convention: `{p_name}`, `{p_name}_mask`, respectively, the buffer is assumed to be the mask for the parameter. For example, masked Linear and ConvNd layers will typically have a parameter named `weight` and a buffer named `weight_mask`.

In [None]:
#export
@torch.no_grad()
def sparse_mask(sizes, sparsity):
    '''
    Returns a random boolean mask of shape `sizes`.

    `round((1 - sparsity) * n)` randomly chosen entries are True (kept connections),
    where `n` is the total number of entries; every other entry is False.
    '''
    numel = np.prod(sizes)
    n_keep = round((1 - sparsity) * numel)
    flat_mask = torch.zeros(numel, dtype=torch.bool)
    # the first `n_keep` indices of a random permutation are the positions to keep
    keep_ixs = torch.randperm(numel)[:n_keep]
    flat_mask[keep_ixs] = True
    return flat_mask.reshape(*sizes)

In [None]:
# 80% sparsity over 10*5 = 50 entries should leave round(0.2 * 50) = 10 ones.
# NOTE: `mask` is reused by the following cells.
mask = sparse_mask((10,5), 0.8)
test_eq(10, int(mask.sum()))

In [None]:
#export
def masked_params(module):
    '''
    Collects the (param, mask) tuples of `module`.

    A buffer is treated as the mask of a parameter when its name is the
    parameter's name followed by `_mask`. Parameters without such a buffer are skipped.
    '''
    masks_by_name = dict(module.named_buffers())
    pairs = []
    for p_name, param in module.named_parameters():
        mask_name = f'{p_name}_mask'
        if mask_name in masks_by_name:
            pairs.append((param, masks_by_name[mask_name]))
    return pairs

In [None]:
# Attach the mask from the previous cell as the `weight_mask` buffer of a Linear layer.
# NOTE: `m` is reused by the following cell.
m = nn.Linear(5,10)
m.register_buffer('weight_mask', mask)
param_mask_pairs = masked_params(m)
# `weight` is the only parameter with a matching `*_mask` buffer (there is no `bias_mask`)
test_eq(param_mask_pairs[0][0], m.weight)
test_eq(param_mask_pairs[0][1], m.weight_mask)

In [None]:
#export
@torch.no_grad()
def apply_masks(module, *args, inplace=True):
    '''
    Multiplies every masked parameter of `module` by its mask.

    Extra positional arguments are accepted and ignored, so the function can be
    registered directly as a module hook. With `inplace=False` the parameter's
    `.data` is rebound to a new tensor instead of being modified in place.
    '''
    for param, mask in masked_params(module):
        if not inplace:
            param.data = param.data.mul(mask)
            continue
        param.data.mul_(mask)

In [None]:
# After masking, only the 10 unmasked weights can be non-zero
# (this assumes no weight was initialised to exactly zero).
apply_masks(m)
test_eq(10, m.weight.abs().gt(0).sum())

In [None]:
#export
# module types that are sparsified by default
_sparseable_module_types = nn.Linear, nn.Conv2d

def is_sparseable_module(m, additional_types=[]):
    '''
    Whether `m` is an instance of one of the default sparseable module types
    or of any type listed in `additional_types`.
    '''
    allowed = {*_sparseable_module_types, *(additional_types or ())}
    return isinstance(m, tuple(allowed))

In [None]:
#export
def sparseable_modules(model, additional_types=[]):
    '''
    Returns the modules yielded by `flatten_model(model)` that pass
    `is_sparseable_module`, in the order they are yielded.
    '''
    found = []
    for m in flatten_model(model):
        if is_sparseable_module(m, additional_types):
            found.append(m)
    return found

In [None]:
def test_model():
    '''Small fixture model: three Conv2d+ReLU stages, a Flatten, and a Linear head.'''
    return nn.Sequential(
        nn.Conv2d(3,32,3), nn.ReLU(), 
        nn.Conv2d(32,128,3), nn.ReLU(), 
        nn.Conv2d(128,512,3), nn.ReLU(), Flatten(),
        nn.Linear(512, 10))

# The fixture has 3 Conv2d + 1 Linear = 4 modules of the default sparseable types.
model = test_model()
s_mods = sparseable_modules(model)
test_eq(4, len(s_mods))

## Sparse Distributions

> For determining the layer-wise sparsity of a list of modules.

### Uniform Distribution

> All layers have the same percentage of connections removed.

In [None]:
#export
def uniform_sparsity(params, model_sparsity):
    '''Returns one sparsity per entry of `params`, each equal to `model_sparsity`.'''
    return [model_sparsity for _ in range(len(params))]

### First-Layer-Dense Uniform Distribution

> Uniform sparsity except for the first layer, which is dense.

In [None]:
#export
def first_layer_dense_uniform(params, model_sparsity):
    '''
    Returns one sparsity per entry of `params`: `0.` for the first entry (dense)
    and `model_sparsity` for every other entry.

    A sparsity of 0 means "keep every connection"; this matches `erdos_renyi_sparsity`,
    which reports dense layers as `0.`, and `sparsify_model`, which only masks layers
    whose sparsity is > 0. Returns an empty list when `params` is empty.
    '''
    if len(params) == 0: return []
    # the first layer keeps all of its connections, i.e. has zero sparsity
    return [0.] + [model_sparsity] * (len(params) - 1)

### Erdos-Renyi (Kernel) Distribution

> For a fixed overall sparsity, the Erdos-Renyi sparsity distribution allocates more connections to smaller layers and fewer to large layers when compared to a uniform sparsity distribution.

In [None]:
#export
# modified from https://github.com/google-research/rigl/blob/master/rigl/sparse_utils.py.
def erdos_renyi_sparsity(params, model_sparsity, include_kernel=True, erk_power_scale=1.0):
    """
    Returns a list of sparsities in the same order as params. Sparsities satisfy 
    the Erdos-Renyi(Kernel) distribution, where the model has the same total parameter 
    count as one with uniform sparsities, that is, satisfying the following equation:
    # eps * (p_1 * N_1 + p_2 * N_2) = (1 - model_sparsity) * (N_1 + N_2), for some float `eps`.
    
    Args:
    params: list of all sparseable parameters (each must provide `.shape` and `.numel()`)
    model_sparsity: target overall sparsity between 0 and 1
    include_kernel: if True, all dimensions (including kernel dimensions, e.g. for ConvNd 
        layers) are used in the scaling; if False, only the first two dimensions are used
    erk_power_scale: scale < 1 softens the erdos_renyi distribution (i.e. closer to uniform).
        Only applied when `include_kernel` is True.
    
    Returns a list of sparsities where values correspond to individual param sparsities.
    Layers that had to be made dense are reported with sparsity `0.`. If every layer ends 
    up dense (e.g. `model_sparsity == 0`), all sparsities are `0.`; an empty `params` 
    gives an empty list.
    """
    # Find the scaling factor `eps` shared by all layers that are not forced dense
    dense_layers = set()
    raw_density = {}   # unscaled per-layer density p_i (fraction of weights kept is eps * p_i)
    eps = 0.
    is_eps_valid = False
    while not is_eps_valid:
        # Start with all layers and try to find right eps. If any density exceeds 1, 
        # make that layer dense and repeat with the non-dense layers.
        #
        # E.g. where N_3, and N_4 are found to be dense:
        # eps * (p_1 * N_1 + p_2 * N_2) + (N_3 + N_4) =
        #    (1 - model_sparsity) * (N_1 + N_2 + N_3 + N_4)
        # eps * (p_1 * N_1 + p_2 * N_2) =
        #    (1 - model_sparsity) * (N_1 + N_2) - model_sparsity * (N_3 + N_4) <--- == rhs
        # eps = rhs / (\sum_i p_i * N_i) <--- == divisor
        # eps = rhs / divisor

        divisor = 0
        rhs = 0
        raw_density = {}
        for p in params:
            n_zeros = int(np.floor(model_sparsity * p.numel()))
            if p in dense_layers:
                # a dense layer keeps its zeros' worth of weights too; the other layers pay for them
                rhs -= n_zeros
                continue
            rhs += p.numel() - n_zeros
            if include_kernel:
                raw_density[p] = (np.sum(p.shape) / np.prod(p.shape))**erk_power_scale
            else:
                raw_density[p] = (np.sum(p.shape[:2]) / np.prod(p.shape[:2]))
            divisor += raw_density[p] * p.numel()

        # Nothing left to scale (no params, or every layer was made dense): 
        # stop here instead of dividing by zero.
        if not raw_density: break

        eps = rhs / divisor

        # If eps * raw_density[p] > 1, we add the param to the set of dense_layers
        max_density = np.max(list(raw_density.values()))
        if eps * max_density > 1:
            for p, p_raw_density in raw_density.items():
                if p_raw_density == max_density:
                    dense_layers.add(p)
        else:
            is_eps_valid = True

    # With the valid eps, we can set sparsities of the remaining layers
    sparsities = [0. if p in dense_layers else (1. - eps * raw_density[p]) for p in params]
    return sparsities

In [None]:
# Check that the ERK sparsities keep (approximately) 10% of all sparseable weights.
model = test_model()
s_params = L(sparseable_modules(model)).map(lambda m: m.weight)

sparsities = erdos_renyi_sparsity(s_params, 0.9)
n_nonzeros = sum([(1-s) * p.numel() for p, s in zip(s_params, sparsities)])
# tolerance of one weight per layer, since zero counts are floored per layer
test_close(n_nonzeros, 0.1 * sum([p.numel() for p in s_params]), eps=len(s_params))
# test_eq([0., 0., 0., 0.], sparsities) # TODO: calc sparsities by hand and compare

## Sparsify Model

> For sparsifying an entire model.

In [None]:
#export
@torch.no_grad()
def sparsify_model(model, model_sparsity, sparse_init_f=uniform_sparsity, enforce_mask=True):
    '''
    Registers a random `weight_mask` buffer on each sparseable module of `model`
    and zeroes the masked-out weights.

    `sparse_init_f`: called as `sparse_init_f(weights, model_sparsity)` to obtain one
    sparsity per sparseable module. Per the RigL paper, `uniform_sparsity` has fewer 
    FLOPs, `erdos_renyi_sparsity` results in a better model.
    
    `enforce_mask`: if True, `apply_masks` is registered as a forward_pre_hook on each
    masked module so the mask is re-applied before every forward pass of that module.
    
    Modules whose sparsity is not > 0 are left untouched (no buffer, no hook).
    
    Returns the list of hook handles, or None if no hook was registered.
    '''
    target_modules = L(model.modules()).filter(is_sparseable_module)
    target_weights = target_modules.map(lambda m: m.weight)
    layer_sparsities = sparse_init_f(target_weights, model_sparsity)
    handles = []
    for module, layer_sparsity in zip(target_modules, layer_sparsities):
        if not layer_sparsity > 0: continue   # dense layer: nothing to mask
        weight = module.weight
        module.register_buffer('weight_mask', sparse_mask(weight.shape, layer_sparsity).to(weight.device))
        apply_masks(module)
        if enforce_mask:
            handles.append(module.register_forward_pre_hook(apply_masks))
    return handles if handles else None

In [None]:
# Sparsify the fixture model uniformly to 90% and count the surviving weights.
model = test_model()
s_mods = sparseable_modules(model)
n_params = sum(m.weight.numel() for m in s_mods)
sparsify_model(model, 0.9, sparse_init_f=uniform_sparsity)
n_nonzeros = sum(m.weight.abs().gt(0).sum() for m in s_mods)
# increase `eps` to account for rounding to nearest whole weight
test_close(n_nonzeros, 0.1 * n_params, eps=len(s_mods))
# with uniform sparsity each individual layer should also be ~10% dense
p = s_mods[0].weight
test_close(p.abs().gt(0).sum(), 0.1 * p.numel(), eps=1)

## Sparse Training

### Drop/Grow Heuristics

In [None]:
#export
def weight_magnitude(p, *args):
    '''Scores each entry of `p` by the absolute value of its data; extra args are ignored.'''
    return torch.abs(p.data)

In [None]:
#export
def gradient_magnitude(p, *args):
    '''Scores each entry of `p` by the absolute value of its gradient; extra args are ignored.'''
    grad = p.grad
    return grad.abs()

In [None]:
#export
def gradient_momentum(p, opt, *args):
    '''
    Calculates the momentum of the gradient for a parameter `p` from the `opt` state.

    Reads `opt.state[p]['grad_avg']` and, if present, `opt.state[p]['sqr_avg']`:
    - without `sqr_avg`, returns `grad_avg` unchanged;
    - with `sqr_avg`, returns `grad_avg / sqrt(sqr_avg + eps)`, where `eps` is taken from
      `opt.state_dict()['hypers'][0]['eps']` and falls back to 1e-6 if it cannot be read.

    Raises an Exception if `grad_avg` is not in the optimizer state.
    '''
    state = opt.state[p]
    grad_avg = state['grad_avg'] if 'grad_avg' in state else None
    sqr_avg = state['sqr_avg'] if 'sqr_avg' in state else None
    if grad_avg is None:
        raise Exception(f"Error: 'grad_avg' key not found in optimizer state. Tip: set the `mom` hyperparamter in the learner.")
    if sqr_avg is None:
        return grad_avg
    # best effort: only the "eps is not available" failures fall back to the default
    try: eps = opt.state_dict()['hypers'][0]['eps']
    except (AttributeError, KeyError, IndexError, TypeError): eps = 1e-6
    return grad_avg / (torch.sqrt(sqr_avg + eps))

In [None]:
#tests

### Dynamic Sparse Training Callback

In [None]:
#export
def top_k_mask(t, n_keep):
    '''Returns a boolean mask shaped like `t` with `n_keep` ones at the positions of the largest values in `t`'''
    numel = t.numel()
    # flat indices of every element, ordered from the largest value to the smallest
    order = torch.topk(t.flatten(), k=numel).indices
    # mask in rank order: the first `n_keep` ranks are on, the remaining ones are off
    ranked = torch.cat([torch.ones(n_keep, dtype=torch.bool, device=t.device),
                        torch.zeros(numel - n_keep, dtype=torch.bool, device=t.device)])
    # move each rank's flag back to the position that element has in `t`
    flat_mask = torch.empty_like(ranked)
    flat_mask[order] = ranked
    return flat_mask.view(*t.shape)

In [None]:
# Values increase row by row, so the 5 largest of the 20 entries are exactly the last row.
t = torch.linspace(-0.9, 0.9, 20).reshape(4,5)
mask = top_k_mask(t, 5)
test_eq(0, mask[:3].sum())
test_eq(5, mask[3:].sum())

In [None]:
#export
class DynamicSparseTrainingCallback(Callback):
    '''
    Trains with sparse weights whose connectivity is periodically rewired.

    In `before_fit` every selected module with sparsity > 0 gets a random `weight_mask`
    buffer, and `apply_masks` is registered as a forward pre-hook on the selected modules.
    Every `batches_per_update` training batches (while the drop/grow schedule is > 0) the
    masks are rewired after the backward pass: the lowest-scoring active connections
    (per `keep_method`) are dropped and the same number of connections is grown at the
    highest-scoring positions (per `grow_method`). The optimizer step of that batch is skipped.

    Args:
    sparsity: target sparsity of each sparse layer (uniform distribution)
    modules: modules to sparsify; default: every flattened module that has parameters, 
        a `weight` attribute and a type that is not excluded
    sparsity_distribution: one of `_sparse_distributions` ('ERK' is not implemented yet)
    batches_per_update: batches between rewiring steps; default: one update per epoch
    initial_drop_grow_pct: fraction of active connections rewired at the start of training
    stop_pct: fraction of training after which rewiring stops (cosine decay to 0 until then)
    keep_method: `f(param, ...)` returning a score per weight; highest scores are kept
    grow_method: `f(param, opt, ...)` returning a score per weight; highest scores are grown
    exclude_modules: additional module types to leave dense (never mutated)
    first_layer_dense: if True, the first selected module is left dense
    '''
    toward_end = True # run after GradientAccumulation and any other cb that modifies the gradients
    _exclude_modules = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
    _sparse_distributions = ('uniform', 'ERK')
    
    def __init__(self, sparsity=0.9, modules=None, sparsity_distribution='uniform', 
                 batches_per_update=None, 
                 initial_drop_grow_pct=0.3, stop_pct=0.75, 
                 keep_method=weight_magnitude,
                 grow_method=gradient_magnitude,
                 exclude_modules=[], first_layer_dense=True):
        # NOTE: no learner is attached yet, so anything that needs `self.dls` (such as the 
        # default for `batches_per_update`) is resolved in `before_fit`.
        store_attr('sparsity,modules,sparsity_distribution,batches_per_update,initial_drop_grow_pct,stop_pct,keep_method,grow_method,first_layer_dense')
        self.exclude_modules = list(exclude_modules) + self._exclude_modules
        
    def before_fit(self):
        '''Selects the modules, creates their masks, builds the drop/grow schedule and adds the mask hooks.'''
        ### default: 1 update per epoch
        if self.batches_per_update is None: self.batches_per_update = len(self.dls.train)

        ### determine modules to sparsify
        is_sparse_module = lambda m: has_params(m) and hasattr(m, 'weight') and type(m) not in self.exclude_modules
        self.modules = ifnone(self.modules, [m for m in flatten_model(self.learn.model) if is_sparse_module(m)])

        ### determine initial sparsities per layer
        assert self.sparsity_distribution in self._sparse_distributions, f'Unknown sparsity distribution: {self.sparsity_distribution}. Options: {self._sparse_distributions}'
        if self.sparsity_distribution == 'uniform':
            self.S = [self.sparsity] * len(self.modules)
            if self.first_layer_dense and self.S:
                self.S[0] = 0
        elif self.sparsity_distribution == 'ERK':
            raise NotImplementedError()
        
        assert callable(self.keep_method), '`keep_method` must be callable'
        assert callable(self.grow_method), '`grow_method` must be callable'
        
        ### create masks and assign to each parameter
        # `sparse_mask` yields exactly round((1-s) * numel) ones, which is the count 
        # `compute_n_keep` / `compute_n_grow` assume.
        for m, s in zip(self.modules, self.S):
            if s > 0: m.register_buffer('weight_mask', sparse_mask(m.weight.shape, s).to(m.weight.device))
            
        ### schedule the decay percent (i.e. # of connections to drop/add per update)
        self.drop_grow_pct_sched = combine_scheds([self.stop_pct, 1-self.stop_pct], 
                                                  [SchedCos(self.initial_drop_grow_pct, 0.), SchedNo(0.,0.)])
        
        ### apply weight masks
        self.add_hooks()
        
    def after_fit(self):
        '''Removes the mask hooks.'''
        self.remove_hooks() # ensure hooks are removed (e.g. in case we cancelled the fit loop)
        
    def add_hooks(self):
        '''Registers `apply_masks` as a forward pre-hook on every selected module.'''
        self.remove_hooks() # to be certain that we never add them twice
        self.hooks = [m.register_forward_pre_hook(apply_masks) for m in self.modules]
        
    def remove_hooks(self):
        '''Removes the forward pre-hooks added by `add_hooks`, if any.'''
        for h in getattr(self, 'hooks', None) or []: h.remove()
        self.hooks = []
    
    def after_backward(self):
        '''On update steps, rewires every selected module and skips the optimizer step.'''
        if not self.is_update_step(): return
        # All gradients are populated once the backward pass has finished, so the
        # grow scores are computed from this batch's gradients.
        for m in self.modules: self.rewire_module(m)
        
        ### skip gradient update after changing network connectivity
        # Cancelling the batch also skips the optimizer's zero_grad, so clear the gradients 
        # here; otherwise they would be accumulated into the next batch's gradients.
        self.learn.opt.zero_grad()
        raise CancelBatchException()
        
    def is_update_step(self):
        '''Whether to modify network connectivity. Side effect: updates self.drop_grow_pct'''
        step = self.epoch * self.n_iter + self.iter
        n_steps = self.n_epoch * self.n_iter
        pct_train = step / n_steps
        self.drop_grow_pct = self.drop_grow_pct_sched(pct_train)
        return step > 0 and step % self.batches_per_update == 0 and self.drop_grow_pct > 0
            
    @torch.no_grad()
    def rewire_module(self, m, *args):
        '''Update step for one module: drops low-scoring connections of `m` and grows as many new ones.'''
        # sparsity assigned to `m` in `before_fit`; modules we do not manage count as dense
        s = next((s_i for m_i, s_i in zip(self.modules, self.S) if m_i is m), 0)
        if s <= 0: return # ignore fully dense layers
        
        param, mask = m.weight, m.weight_mask
        
        ### determine # of connections to keep, # to regrow
        n_keep = self.compute_n_keep(s, param, mask)
        n_grow = self.compute_n_grow(s, param, mask)
        
        ### determine weights to keep
        keep_score = self.keep_method(param)
        keep_mask = top_k_mask(keep_score, n_keep)
        
        ### determine weights to grow
        grow_score = self.grow_method(param, self.learn.opt)
        # kept weights get the lowest possible score so we never choose to grow them
        # (also correct for grow methods that return negative scores)
        grow_score = grow_score.masked_fill(keep_mask, float('-inf'))
        grow_mask = top_k_mask(grow_score, n_grow)
        
        ### update mask
        mask.data = keep_mask | grow_mask
        
        ### zero momentum for new connections
        self.reset_momentum(param, grow_mask & keep_mask.logical_not())

    def _n_active(self, s, p):
        '''Number of connections a layer with sparsity `s` holds (same rounding as `sparse_mask`).'''
        return round((1 - s) * p.numel())

    def compute_n_grow(self, s, p, mask):
        '''Number of connections to grow (== number dropped): `drop_grow_pct` of the active ones.'''
        return int(self._n_active(s, p) * self.drop_grow_pct)

    def compute_n_keep(self, s, p, mask):
        '''Number of active connections that survive the drop step.'''
        return self._n_active(s, p) - self.compute_n_grow(s, p, mask)

    @torch.no_grad()
    def reset_momentum(self, p, grow_mask):
        '''Initialize momentum to zero for newly-added connections (the True entries of `grow_mask`)'''
        state = self.opt.state[p]
        # multiply by the complement: zero the new connections, leave every other entry as is
        unchanged = grow_mask.logical_not()
        if 'grad_avg' in state: state['grad_avg'].mul_(unchanged)
        if 'sqr_avg' in state: state['sqr_avg'].mul_(unchanged)

#     _docs = dict(before_fit="Set counter to 0",
#                  after_backward="Skip weight update if we have not seen enough items")

## Presets

### Sparse Evolutionary Training (SET)

TODO

In [None]:
#export
# placeholder: no SET-specific callback arguments are defined yet
SET_kwargs = dict()

### Sparse Networks From Scratch (SNFS)

In [None]:
#export
# callback arguments for this preset: keep by weight magnitude, grow by gradient momentum,
# and keep rewiring for the whole of training (stop_pct=1.0)
STFS_kwargs = dict(
    keep_method=weight_magnitude,
    grow_method=gradient_momentum,
    batches_per_update=None,
    initial_drop_grow_pct=0.3,
    stop_pct=1.0,
)

### Rigged Lottery (RigL)

In [None]:
#export
# callback arguments for this preset: keep by weight magnitude, grow by gradient magnitude,
# and stop rewiring after 75% of training
RigL_kwargs = dict(
    keep_method=weight_magnitude,
    grow_method=gradient_magnitude,
    batches_per_update=None,
    initial_drop_grow_pct=0.3,
    stop_pct=0.75,
)

# Export

In [None]:
#hide
# run nbdev's notebook-to-script export; the cell output lists the converted notebooks
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted index.ipynb.
