In [1]:
#export
from operator import itemgetter

from loop.callbacks import Callback, Order
from loop.modules import set_trainable, freeze_all
from loop.training import write_output

In [2]:
#export
class GradualTraining(Callback):
    """Gradually un-freezes model layers.

    Helps to slowly 'warm-up' the topmost layers before fine-tuning the
    earlier ones.

    Parameters:
        steps: List of pairs (epoch, keys) describing at which epoch specific
            layers of the model become trainable. `keys` is either a single
            layer name (str) or a list of names; it is passed as a list to
            `set_trainable`. Several pairs may share the same epoch number,
            in which case all of them are applied, in the order given.
        start_frozen: If True, then all network layers are frozen (via
            `freeze_all`) before the training loop is started.
        verbose: If True, the callback prints layer names after each
            un-freezing (and a message when freezing everything).

    """
    order = Order.Schedule(10)

    def __init__(self, steps: list, start_frozen: bool=True, verbose: bool=False):
        # `sorted` is stable, so pairs sharing an epoch keep their given order.
        self.steps = sorted(steps, key=itemgetter(0))
        self.start_frozen = start_frozen
        self.verbose = verbose

    def training_started(self, **kwargs):
        """Optionally freezes the whole model before training begins."""
        if self.start_frozen:
            if self.verbose:
                write_output('Freezing all model layers\n')
            # NOTE(review): `self.group` is presumably attached by the
            # callback framework before training starts — confirm.
            freeze_all(self.group.model)

    def epoch_started(self, epoch, **kwargs):
        """Un-freezes every layer group scheduled for `epoch`.

        All steps whose epoch number equals `epoch` are applied. The previous
        implementation stopped after the first match, silently dropping later
        entries that shared the same epoch.
        """
        for epoch_no, keys in self.steps:
            if epoch_no != epoch:
                continue
            if isinstance(keys, str):
                keys = [keys]
            set_trainable(self.group.model, keys)
            if self.verbose:
                write_output(f'Un-freezing layer(s): {keys}\n')

In [3]:
from collections import OrderedDict

import numpy as np
from torch import nn
from torch.nn.functional import cross_entropy

from loop.modules import fc_network, fc, Flatten
from loop.training import Loop
from loop.testing import get_mnist

import torch

# Run configuration. Seeding torch's global RNG before any data loading or
# model construction makes weight initialization (and any torch-driven batch
# shuffling) reproducible on Restart & Run All.
SEED = 42
EPOCHS = 10
BATCH_SIZE = 100
torch.manual_seed(SEED)

trn_ds, val_ds = get_mnist(flat=False)

# Small CNN organised into named stages so that GradualTraining can address
# them by dotted path (e.g. 'features.block2.conv2').
net = nn.Sequential(OrderedDict([
    ('features', nn.Sequential(OrderedDict([
        ('block1', nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(1, 10, 3)),
            ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2d(10, 32, 3)),
            ('relu2', nn.ReLU())
        ]))),
        ('block2', nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(32, 32, 3)),
            ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2d(32, 32, 3))
        ]))),
        ('top', nn.Sequential(OrderedDict([
            ('pool', nn.AdaptiveAvgPool2d(1)),
            ('flat', Flatten()),
            ('fc1', nn.Linear(32, 16)),
            ('relu1', nn.ReLU()),
            ('fc2', nn.Linear(16, 10))
        ])))
    ])))
]))

# Un-freeze from the head towards the input: classifier head first, then the
# last conv, the rest of block2, and finally the earliest block.
steps = [
    (1, 'features.top'),
    (3, 'features.block2.conv2'),
    (5, 'features.block2'),
    (7, 'features.block1')
]
loop = Loop(net, cbs=[GradualTraining(steps, verbose=True)], loss_fn=cross_entropy)
loop.fit_datasets(trn_ds, val_ds, epochs=EPOCHS, batch_size=BATCH_SIZE)

Freezing all model layers
Un-freezing layer(s): ['features.top']
Epoch:    1 | train_loss=2.2104, valid_loss=2.1909
Epoch:    2 | train_loss=2.0967, valid_loss=2.0835
Un-freezing layer(s): ['features.block2.conv2']
Epoch:    3 | train_loss=2.0564, valid_loss=2.0325
Epoch:    4 | train_loss=2.0331, valid_loss=2.0114
Un-freezing layer(s): ['features.block2']
Epoch:    5 | train_loss=1.6722, valid_loss=1.5993
Epoch:    6 | train_loss=1.2514, valid_loss=1.1457
Un-freezing layer(s): ['features.block1']
Epoch:    7 | train_loss=0.8768, valid_loss=0.7239
Epoch:    8 | train_loss=0.5097, valid_loss=0.3482
Epoch:    9 | train_loss=0.3720, valid_loss=0.2508
Epoch:   10 | train_loss=0.2891, valid_loss=0.1749
