In [None]:
# export
from typing import List
from fastai2.basics import *
from fastai2.callback.hook import *
from fastai2.imports import *

In [None]:
# default_exp pytorch.model

# PyTorch Model
> Utilities for working with `torch.nn.Module`

## model.summary(sample_inputs)

In [None]:
# export
def layer_info(model: nn.Module, *sample_inputs):
    """Run one sample through `model` and record information about each layer.

    sample_inputs: sample_inputs of your model, only support batch first inputs.
        Only the first item of every input (`o[:1]`) is actually fed to the model.

    Returns `(sample_inputs, h.stored)`: the inputs exactly as passed in, and what the
    hooks stored for the layers of `flatten_model(model)`. Each hook records
    `(class name, *total_params(layer), output shapes)`.

    The forward pass runs in eval mode. The `training` flag of every module is
    restored afterwards, so calling this on a model that is being trained does not
    silently leave it in eval mode.
    """
    def _track(m, i, o):
        return (m.__class__.__name__,)+total_params(m)+(apply(lambda x:x.shape, o),)
    layers = list(flatten_model(model))
    # `eval()` below flips every module; remember each one's mode so it can be put back exactly
    previous_modes = [(mod, mod.training) for mod in model.modules()]
    try:
        with Hooks(layers, _track) as h:
            _ = model.eval()(*apply(lambda o:o[:1], sample_inputs))
            return sample_inputs,h.stored
    finally:
        for mod,was_training in previous_modes: mod.training = was_training

In [None]:
# Check `layer_info` on a single-input Sequential model.
# Each expected tuple follows what `_track` builds: (class name, *total_params(m), output shape).
m = nn.Sequential(nn.Linear(1,50), nn.ReLU(), nn.BatchNorm1d(50), nn.Linear(50, 1))
sample_input = torch.randn((16, 1))
# Output shapes start with 1 because `layer_info` only feeds the first sample (`o[:1]`) to the model.
test_eq(layer_info(m, sample_input)[1], [
    ('Linear', 100, True, [1, 50]),
    ('ReLU', 0, False, [1, 50]),
    ('BatchNorm1d', 100, True, [1, 50]),
    ('Linear', 51, True, [1, 1])
])

In [None]:
# Test for a model with multiple inputs
class _2InpModel(Module):
    "Toy model: concatenates all its inputs on the last dim, then applies a small MLP"
    def __init__(self):
        super().__init__()
        mlp_layers = [nn.Linear(2,50), nn.ReLU(), nn.BatchNorm1d(50), nn.Linear(50, 1)]
        self.seq = nn.Sequential(*mlp_layers)
    def forward(self, *inps):
        # Join the separate inputs into a single feature tensor before the MLP
        return self.seq(torch.cat(inps, dim=-1))


# Check `layer_info` on a model whose forward takes several tensors.
m = _2InpModel()
# Two (16, 1) inputs; the model concatenates them on the last dim, so the first Linear sees 2 features.
sample_inputs = (torch.randn(16, 1), torch.randn(16, 1))
test_eq(layer_info(m, *sample_inputs)[1], [
    ('Linear', 150, True, [1, 50]),
    ('ReLU', 0, False, [1, 50]),
    ('BatchNorm1d', 100, True, [1, 50]),
    ('Linear', 51, True, [1, 1])
])

In [None]:
# export
def _print_shapes(o, bs):
    if isinstance(o, torch.Size): return ' x '.join([str(bs)] + [str(t) for t in o[1:]])
    else: return str([_print_shapes(x, bs) for x in o])

In [None]:
# export
@patch
def summary(self: nn.Module, *sample_inputs):
    ''' Build a printable per-layer summary table of the model
        sample_inputs: sample inputs of your model, only support batch first inputs
    '''
    sample_inputs,infos = layer_info(self, *sample_inputs)
    width,bs = 64,find_bs(sample_inputs)
    double_rule,single_rule = "=" * width + "\n","_" * width + "\n"
    inp_sz = _print_shapes(apply(lambda x:x.shape, sample_inputs), bs)
    # Header: model name with input shape, then the column titles between two rules
    pieces = [
        f"{self.__class__.__name__} (Input shape: {inp_sz})\n",
        double_rule,
        f"{'Layer (type)':<20} {'Output Shape':<20} {'Param #':<10} {'Trainable':<10}\n",
        double_rule,
    ]
    total,trainable = 0,0
    for info in infos:
        # Entries that are None, or that carry no output shape, are left out of the table and the totals
        if info is None: continue
        typ,n_params,trn,sz = info
        if sz is None: continue
        total += n_params
        if trn: trainable += n_params
        # The shape text is clipped to 19 characters so it fits its 20-wide column
        pieces.append(f"{typ:<20} {_print_shapes(sz, bs)[:19]:<20} {n_params:<10,} {str(trn):<10}\n")
        pieces.append(single_rule)
    pieces.append(f"\nTotal params: {total:,}\n")
    pieces.append(f"Total trainable params: {trainable:,}\n")
    pieces.append(f"Total non-trainable params: {total - trainable:,}\n\n")
    return PrettyString("".join(pieces))
    

In [None]:
# Demo of `summary` on a Sequential model with one frozen layer.
m = nn.Sequential(nn.Linear(1,50), nn.ReLU(), nn.BatchNorm1d(50), nn.Linear(50, 1))
# Freeze the first Linear so the table contains a layer with parameters that is not trainable
for p in m[0].parameters(): p.requires_grad_(False)
sample_input = torch.randn((16, 1))
m.summary(sample_input)

Sequential (Input shape: ['16 x 1'])
Layer (type)         Output Shape         Param #    Trainable 
Linear               16 x 50              100        False     
________________________________________________________________
ReLU                 16 x 50              0          False     
________________________________________________________________
BatchNorm1d          16 x 50              100        True      
________________________________________________________________
Linear               16 x 1               51         True      
________________________________________________________________

Total params: 251
Total trainable params: 151
Total non-trainable params: 100


In [None]:
# Test for a model with multiple outputs
class _2OutModel(nn.Module):
    def forward(self, x1):
        seq_len, bs, hid_size = 50, 16, 256
        num_layer = 1
        return torch.randn((seq_len, bs, hid_size)), torch.randn((num_layer, bs, hid_size))
# Known limitation: these outputs are not batch-first, yet `_print_shapes` always replaces dim 0 with the
# batch size, and `summary` clips the shape text to 19 characters, so the displayed shape is wrong and cut off.
m = _2OutModel()
m.summary(torch.randn((16, 1))) # Output Shape should be (50, 16, 256), (1, 16, 256)

_2OutModel (Input shape: ['16 x 1'])
Layer (type)         Output Shape         Param #    Trainable 
_2OutModel           ['16 x 16 x 256', '  0          False     
________________________________________________________________

Total params: 0
Total trainable params: 0
Total non-trainable params: 0


## freeze

In [None]:
# export
def check_requires_grad(layers: List[nn.Module], grad: bool):
    """Check whether `requires_grad` of every parameter in `layers` equals `grad`.

    layers: modules whose parameters are inspected
    grad: the expected value of `requires_grad`

    Returns True only when all parameters match `grad`. A mix of frozen and trainable
    parameters gives False for both values of `grad`. If `layers` holds no parameters
    at all, the check is vacuously True.
    """
    # Every parameter must match; "not all trainable" is not the same as "all frozen"
    return all(p.requires_grad == grad for layer in layers for p in layer.parameters())

In [None]:
# export
def set_requires_grad(layers: List[nn.Module], to: bool):
    "Call `requires_grad_(to)` on every parameter of every module in `layers`"
    # Collect the parameters of all layers first, then flip them one by one
    per_layer = [params(layer) for layer in layers]
    for p in (p for group in per_layer for p in group): p.requires_grad_(to)

In [None]:
# Round trip: freeze every parameter, verify, then unfreeze every parameter and verify again.
layers = [nn.Linear(1, 1), nn.BatchNorm1d(1)]
set_requires_grad(layers, False)
test_eq(check_requires_grad(layers, False), True)
set_requires_grad(layers, True)
test_eq(check_requires_grad(layers, True), True)

In [None]:
# export
def freeze_to(layers: List[nn.Module], n: int):
    ''' Freeze `layers[:n]` (requires_grad_ set to False) and unfreeze `layers[n:]` (requires_grad_ set to True) '''
    # Frozen part first, trainable part second
    for segment, trainable in ((layers[:n], False), (layers[n:], True)):
        set_requires_grad(segment, trainable)

In [None]:
# With n=1 only the first layer is frozen; the remaining layers stay trainable.
layers = [nn.Linear(1, 1), nn.Linear(1, 1), nn.BatchNorm1d(1)]
freeze_to(layers, 1)
test_eq(check_requires_grad(layers[:1], False), True)
test_eq(check_requires_grad(layers[1:], True), True)

In [None]:
class TestModel(nn.Module):
    "Small model: a `linears` body (Linear -> ReLU -> BatchNorm1d) followed by a `classifier` Linear with 10 outputs"
    def __init__(self):
        super().__init__()
        body = [nn.Linear(1,50), nn.ReLU(), nn.BatchNorm1d(50)]
        self.linears = nn.Sequential(*body)
        self.classifier = nn.Linear(50, 10)
    def forward(self, x):
        features = self.linears(x)
        return self.classifier(features)
# Demo: freeze everything except the last layer, then show the effect in the summary table.
m = TestModel()
layers = [*m.linears, m.classifier]
# n=-1 freezes `layers[:-1]` (the three `linears` modules) and leaves the classifier trainable
freeze_to(layers, -1)
sample_input = torch.randn((16, 1))
m.summary(sample_input)

TestModel (Input shape: ['16 x 1'])
Layer (type)         Output Shape         Param #    Trainable 
Linear               16 x 50              100        False     
________________________________________________________________
ReLU                 16 x 50              0          False     
________________________________________________________________
BatchNorm1d          16 x 50              100        False     
________________________________________________________________
Linear               16 x 10              510        True      
________________________________________________________________

Total params: 710
Total trainable params: 510
Total non-trainable params: 200


## Export -

In [None]:
# hide
# Run nbdev's notebook-to-script export; the recorded output below lists every notebook it converted.
from nbdev.export import notebook2script
notebook2script()

Converted 01_data.core.ipynb.
Converted 02_pytorch.transformer.ipynb.
Converted 03_pytorch.model.ipynb.
Converted 04_callback.optuna.ipynb.
Converted Untitled.ipynb.
Converted index.ipynb.
