In [3]:
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, random_split, Dataset


In [4]:
model = torch.nn.Linear(10, 1)

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

In [21]:
import torch
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset

class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return loss
    
    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 32))

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


In [22]:
model = BoringModel()

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name  | Type   | Params
---------------------------------
0 | layer | Linear | 66    
---------------------------------
66        Trainable params
0         Non-trainable params
66        Total params
0.000     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

- https://github.com/PyTorchLightning/pytorch-lightning/blob/d4d959b342586223fbced3266e0550462a85300c/pl_examples/domain_templates/computer_vision_fine_tuning.py

- https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/finetuning.html#BaseFinetuning


- https://github.com/PyTorchLightning/pytorch-lightning/issues/2006

- can't use `finetune_function` because we need acccess to trainer to set new scheduler

In [57]:
from torch.optim.lr_scheduler import OneCycleLR

def phase1(model, current_optimizers, finetuner):
    model.unfreeze()
    #sched = OneCycleLR()

def phase2(model, current_optimizers, finetuner):
    model.unfreeze()
    #sched = OneCycleLR()

    
class FinetuneScheduler(pl.callbacks.BaseFinetuning):

    def __init__(self, phases: list, train_bn: bool = False):
        self.phases = phases
        self.train_bn = train_bn
        self._current_epoch = -1
        self._configure_optimizers_original = None
            
    @property
    def max_epochs(self):
        durations = [phase.get("n_epochs") for phase in self.phases]
        if any([duration is None for duration in durations]):
            raise Exception(
                "Each phase should define the number of epochs it lasts with `n_epochs`")

        return sum(durations)
    
    def _configure_optimizers(self):
        if self._current_epoch == -1:
            return self._configure_optimizers_original()
        pass

    def freeze_before_training(self, pl_module: pl.LightningModule):
        self.freeze(modules=pl_module, train_bn=self.train_bn)
        self._configure_optimizers_original = pl_module.configure_optimizers
        raise Exception(self._configure_optimizers_original)
        #setattr(pl_module, "configure_optimizers", self._configure_optimizers)
        
    def on_train_epoch_start(self, trainer, pl_module):
        """Called when the epoch begins."""
        
        optimizers = trainer.train_loop.prepare_optimizers()
        print(trainer.optimizers)
        epoch_phases_counter = 0
        for phase in self.phases:
            epoch_phases_counter += phase["n_epochs"]
            if epoch_phases_counter == phase["n_epochs"]:
                phase_fn = phase["func"]
                
                new_optimizers = phase_fn(pl_module, optimizers, self)
                if new_optimizers is None:
                    # assume we don't want to change any of the optimizers or schedulers
                    pass
                


cb = FinetuneScheduler([
    {'func': phase1, 'n_epochs': 5},
    {'func': phase2, 'n_epochs': 5},
])

trainer = pl.Trainer(callbacks=[cb], max_epochs=cb.max_epochs)
trainer.fit(model)

GPU available: False, used: False
INFO:lightning:GPU available: False, used: False
TPU available: None, using: 0 TPU cores
INFO:lightning:TPU available: None, using: 0 TPU cores


Exception: <bound method FinetuneScheduler._configure_optimizers of <__main__.FinetuneScheduler object at 0x7f3e4e7ae290>>