In [1]:
%load_ext autoreload
%autoreload 2

# Export cells
!python notebook2script.py module_util.ipynb

Converted module_util.py to torchtrainer/module_util.py


In [2]:
# Utility Pytorch module functions
#export module_util.py
'''Utility functions and classes for working with Pytorch modules'''

from torch import nn
from functools import partial

class ActivationSampler(nn.Module):
    '''Generates a hook for sampling a layer activation. Can be used as
    
    sampler = ActivationSampler(layer_in_model)
    output = model(input)
    layer_activation = sampler()
    
    '''
    
    def __init__(self, model):
        super(ActivationSampler, self).__init__()
        self.model_name = model.__class__.__name__
        self.activation = None
        model.register_forward_hook(self.get_hook())
        
    def forward(self, x=None):
        return self.activation
    
    def get_hook(self):
        def hook(model, input, output):
            self.activation = output
        return hook
    
    def extra_repr(self):
        return f'{self.model_name}'

class Lambda(nn.Module):
    '''Transforms function into a module'''
    
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x): return self.func(x)
    
class Hooks:
    '''Hooks for storing information about layers.
    
    The attribute `storage` will contain the layers information. It is a dict
    having layer names as keys and respective values generated by `func`.
    
    Parameters
    ----------
    module : torch.nn
        The module containing the layers. Only used for getting layer names
    layers : list
        List of torch.nn modules for storing the activations
    func : function
        Function to be registered as a hook. Must have signature func(storage, module, input, output) for
        forward hooks and ?? for backward hooks. `storage` is a dictionary used for storing layer information.    
    '''
    
    def __init__(self, module, layers, func, is_forward=True): 
        
        self.hooks = []
        
        storage = {}         # For storing information, each layer will be a key here
        layers_dict = {}     # Dict of layer names and actual layers
        # Obtain layer names for hashing. Is there a better way?
        for layer_name, layer in module.named_modules():
            if True in [True for l in layers if layer is l]:
                layers_dict[layer_name] = layer
                storage[layer_name] = {}
        
        self.layers_dict = layers_dict
        self.storage = storage
        
        if is_forward:
            self._register_forward_hooks(func)
        else:
            self._register_backward_hooks(func)
    
    def __del__(self): self.remove_hooks()
    
    def _register_forward_hooks(self, func):
        '''Register one hook for each layer.'''
            
        for layer_name, layer in self.layers_dict.items():
            hook_func = self._generate_hook(func, self.storage[layer_name])
            self.hooks.append(layer.register_forward_hook(hook_func))
        
    def _register_backward_hooks(self, func):
        '''Register one hook for each layer.'''
            
        for layer_name, layer in self.layers_dict.items():
            hook_func = self._generate_hook(func, self.storage[layer_name])
            self.hooks.append(layer.register_backward_hook(hook_func))
        
    def _generate_hook(self, func, storage):
        '''Generate function to be used in module.register_forward_hook and module.register_backward_hook, fixing
        as a first argument to the function an empty dictionary.'''
        
        return partial(func, storage)
        
    def to_cpu(self):
        pass
    
    def remove_hooks(self):
        '''Remove hooks from the network.'''
        
        for hook in self.hooks:
            hook.remove()
            
def _calculate_stats(storage, model, input, output, store_act=True, store_weights=False):
    
    if store_act:
        if 'activation' not in storage:
            storage['activation'] = {}
        if 'mean' not in storage['activation']:
            storage['activation']['mean'] = []
        if 'std' not in storage['activation']:
            storage['activation']['std'] = []
        if 'hist' not in storage['activation']:
            storage['activation']['hist'] = []
        
        activation = output.detach()
        storage['activation']['mean'].append(activation.mean().item())
        storage['activation']['std'].append(activation.std().item())
        storage['activation']['hist'].append(activation.cpu().histc(100,-10,10)) #histc isn't implemented on the GPU
                                            
    if store_weights:
        if 'weights' not in storage:
            storage['weights'] = {}
        if 'mean' not in storage['weights']:
            storage['weights']['mean'] = []
        if 'std' not in storage['weights']:
            storage['weights']['std'] = []
        
        try:
            weight = model.weight
        except Exception:
            raise AttributeError('Model does not have `weight` attribute')
        else:
            weight = weight.detach()
            storage['weights']['mean'].append(weight.mean().item())
            storage['weights']['std'].append(weight.std().item())
            storage['weights']['hist'].append(weight.cpu().histc(100,-10,10)) #histc isn't implemented on the GPU
            
def calculate_stats(store_act=True, store_weights=False):

    return partial(_calculate_stats, store_act=store_act, store_weights=store_weights)

In [4]:
# Functions for splitting a model into different groups, which can be frozen or receive distinct learning rates
#export module_util.py

bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def split_modules(model, modules_to_split):
    '''Split `model` layers into different groups. Useful for freezing part of the model
    or using different learning rates.'''
    
    module_groups = [[]]
    for module in model.modules():
        if module in modules_to_split:
            module_groups.append([])
        module_groups[-1].append(module)
    return module_groups

def define_opt_params(module_groups, lr=None, wd=None, debug=False):
    '''Define distinct learning rate and weight decay for parameters belonging
    to groupd modules in `module_groups`. '''
    
    num_groups = len(module_groups)
    if isinstance(lr, int): lr = [lr]*num_groups
    if isinstance(wd, int): wd = [wd]*num_groups
    
    opt_params = []
    for idx, group in enumerate(module_groups):
        group_params = {'params':[]}
        if lr is not None: group_params['lr'] = lr[idx]
        if wd is not None: group_params['wd'] = wd[idx]
        for module in group:
            pars = module.parameters(recurse=False)
            if debug: print(module.__class__)
            pars = list(filter(lambda p: p.requires_grad, pars))
            if len(pars)>0:
                group_params['params'] += pars
                if debug:
                    for p in pars:
                        print(p.shape)
        opt_params.append(group_params)
    return opt_params

def groups_requires_grad(module_groups, req_grad=True, keep_bn=False):
    '''Set requires_grad to `req_grad` for all parameters in `module_groups`.
    If `keep_bn` is True, batchnorm layers are not changed.'''
    
    for idx, group in enumerate(module_groups):
        for module in group:
            for p in module.parameters(recurse=False):
                if not keep_bn or not isinstance(module, bn_types): p.requires_grad=req_grad

def freeze_to(module_groups, group_idx=-1, keep_bn=False):
    '''Freeze model groups up to the group with index `group_idx`. If `group_idx` is None, 
    freezes the entire model. If `keep_bn` is True, batchnorm layers are not changed.'''
    
    num_groups = len(module_groups)
    slice_freeze = slice(0, group_idx)
    if group_idx is not None:
        slice_unfreeze = slice(group_idx, None)
    
    groups_requires_grad(module_groups[slice_freeze], False, keep_bn)

    if group_idx is not None:
        groups_requires_grad(module_groups[slice_unfreeze], True)
            
def unfreeze(module_groups):
    '''Unfreezes the entire model.'''
    
    groups_requires_grad(module_groups, True)

In [32]:
import torch
from pytorchlb.unet import UNet    # Fix this!

unet = UNet(3, 2)
hooks_act = Hooks(unet, [unet.l1_, unet.l2_, unet.l3_, unet.l4_], calculate_stats(True))
hooks_wei = Hooks(unet, [unet.l1_.dconv[0], unet.l2_.dconv[0], unet.l3_.dconv[0], unet.l4_.dconv[0]], 
                  calculate_stats(False, True))

xb = torch.randn([2, 3, 128, 128])
pred = unet(xb)

In [34]:
hooks_wei.storage

{'l1_.dconv.0': {'weights': {'mean': [0.0029068964067846537],
   'std': [0.2779395282268524]}},
 'l2_.dconv.0': {'weights': {'mean': [0.0003165746165905148],
   'std': [0.0591021291911602]}},
 'l3_.dconv.0': {'weights': {'mean': [-0.00012993672862648964],
   'std': [0.041721682995557785]}},
 'l4_.dconv.0': {'weights': {'mean': [3.7154190067667514e-05],
   'std': [0.02945738658308983]}}}

In [13]:
c=list(unet.l1_.dconv.children())[0]

ReLU(inplace=True)

In [60]:
list(unet.named_modules())[1][1] == unet.l1_

True

In [61]:
for m in unet:
    print(m)

TypeError: 'UNet' object is not iterable