
refactor DefaultTrainer.from_checkpoint #53

Merged

Changes from 3 commits
136 changes: 97 additions & 39 deletions torch_em/trainer/default_trainer.py
@@ -3,7 +3,7 @@
import time
import warnings
from importlib import import_module
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union

import numpy as np
import torch
@@ -38,6 +38,9 @@ def __init__(
if name is None and not issubclass(logger, WandbLogger):
raise TypeError("Name cannot be None if not using the WandbLogger")

if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
raise ValueError(f"{self.__class__} requires any dataloader to have 'shuffle' attribute.")

self._generate_name = name is None
self.name = name
self.train_loader = train_loader
@@ -76,64 +79,119 @@ def iteration(self):
def epoch(self):
return self._epoch

@classmethod
def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
save_path = os.path.join(checkpoint_folder, f"{name}.pt")
if not os.path.exists(save_path):
raise ValueError(f"Cannot find checkpoint {save_path}")
save_dict = torch.load(save_path, map_location=device)

init_data = save_dict["init"]
model_p, model_m = init_data["model_class"].rsplit(".", 1)
model_class = getattr(import_module(model_p), model_m)
model = model_class(**init_data["model_kwargs"])
class Deserializer:
"""
Determines how to deserialize the trainer kwargs from serialized 'init_data'

Examples:
To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

This example adds the `the_answer` kwarg, which requires some 'calculation' upon initialization:
>>> class MyTrainer(DefaultTrainer):
>>> def __init__(self, *args, the_answer: int, **kwargs):
>>> super().__init__(*args, **kwargs)
>>> self.the_answer = the_answer
>>>
>>> class Deserializer(DefaultTrainer.Deserializer):
>>> def __call__(self, name, *args, **kwargs):
>>> if name == "the_answer":
>>> return self.load_the_answer()
>>> else:
>>> return super().__call__(name, *args, **kwargs)
>>>
>>> def load_the_answer(self):
>>> if self.device.type == "cpu":
>>> return ... # complex 'calculation' to arrive at the answer
>>> else:
>>> return 42
"""

def __init__(self, init_data: dict, save_path: str, device: Union[str, torch.device]):
self.init_data = init_data
self.save_path = save_path
self.device = torch.device(self.init_data["device"]) if device is None else torch.device(device)

def __call__(
self,
name: str,
Owner: What exactly is `name` here? I thought it was the name of the trainer, i.e. the checkpoint name. But it looks like it's something else.

Owner: Ok, I understand now that this is the name of the kwarg. But this should be explained a bit better; it's not immediately obvious. It would be best if you just add the "normal usage" to the docstring, i.e. how the deserializer is normally called for a kwarg.

Contributor (Author): Good idea. I suppose we can additionally rename `name` to `kwarg_name` or `attribute`?

Owner: Yes, let's go with `kwarg_name`.
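
A minimal sketch of the "normal usage" being requested, assuming the `init_data` layout used elsewhere in this diff (keys such as `model_class`, `model_kwargs`, `train_loader_kwargs`); the variable names are illustrative only:

# Illustrative only: the deserializer is called once per constructor kwarg,
# mirroring _get_trainer_kwargs further down in this file.
load = DefaultTrainer.Deserializer(save_dict["init"], save_path, device)
model = load("model")                              # resolves "model_class" / "model_kwargs"
optimizer = load("optimizer", model.parameters())  # extra positional args go to the constructor
train_loader = load("train_loader")                # rebuilt from the serialized dataset and loader kwargs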

*dynamic_args,
optional=False,
only_class=False,
dynamic_kwargs: Optional[Dict[str, Any]] = None,
Owner: Is there a particular reason you're not using **dynamic_kwargs here?

Contributor (Author): Yes, **dynamic_kwargs would not allow passing optional (or only_class) as a kwarg.

Owner: Maybe I am missing something, but I don't see why that is the case. As far as I can see this is perfectly valid Python:

def my_func(
    kwarg_name, *args,
    optional=False, only_class=False,
    **kwargs
):
    print(optional)

my_func("x", "y", optional=True, blub="xyz")

(prints True)

):
if name.endswith("_loader"):
return self.load_loader(name)
if name == "device":
return self.device
else:
return self.load_generic(
name, *dynamic_args, optional=optional, only_class=only_class, dynamic_kwargs=dynamic_kwargs
)

optimizer_p, optimizer_m = init_data["optimizer_class"].rsplit(".", 1)
optimizer_class = getattr(import_module(optimizer_p), optimizer_m)
optimizer = optimizer_class(model.parameters(), **init_data["optimizer_kwargs"])
def load_generic(
self, name: str, *dynamic_args, optional: bool, only_class: bool, dynamic_kwargs: Optional[Dict[str, Any]]
):
if name in self.init_data:
return self.init_data[name]

def _init(name, optional=False, only_class=False):
this_cls = init_data.get(f"{name}_class", None)
this_cls = self.init_data.get(f"{name}_class", None)
if this_cls is None and optional:
return None
elif this_cls is None and not optional:
raise RuntimeError(f"Could not find init data for {name} in {save_path}")
raise RuntimeError(f"Could not find init data for {name} in {self.save_path}")

cls_p, cls_m = this_cls.rsplit(".", 1)
this_cls = getattr(import_module(cls_p), cls_m)
if only_class:
return this_cls
kwargs = init_data[f"{name}_kwargs"]
if name == "lr_scheduler":
return this_cls(optimizer, **kwargs)
else:
return this_cls(**kwargs)
return this_cls(*dynamic_args, **self.init_data.get(f"{name}_kwargs", {}), **(dynamic_kwargs or {}))

def _init_loader(name):
ds = init_data[f"{name}_dataset"]
loader_kwargs = init_data[f"{name}_loader_kwargs"]
def load_loader(self, name):
ds = self.init_data[f"{name.replace('_loader', '')}_dataset"]
loader_kwargs = self.init_data[f"{name}_kwargs"]
loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
# monkey patch shuffle attribute to the loader
loader.shuffle = loader_kwargs.get("shuffle", False)
return loader

device = torch.device(init_data["device"]) if device is None else torch.device(device)
trainer = cls(
name=os.path.split(checkpoint_folder)[1],
train_loader=_init_loader("train"),
val_loader=_init_loader("val"),
@staticmethod
def _get_save_dict(save_path, device):
if not os.path.exists(save_path):
raise ValueError(f"Cannot find checkpoint {save_path}")

return torch.load(save_path, map_location=device)

@staticmethod
def _get_trainer_kwargs(load: Deserializer):
model = load("model")
optimizer = load("optimizer", model.parameters())

return dict(
name=os.path.split(os.path.dirname(load.save_path))[1],
train_loader=load("train_loader"),
val_loader=load("val_loader"),
model=model,
loss=_init("loss"),
loss=load("loss"),
optimizer=optimizer,
metric=_init("metric"),
device=device,
lr_scheduler=_init("lr_scheduler", optional=True),
log_image_interval=init_data["log_image_interval"],
mixed_precision=init_data["mixed_precision"],
early_stopping=init_data["early_stopping"],
logger=_init("logger", only_class=True, optional=True),
logger_kwargs=init_data.get("logger_kwargs"),
metric=load("metric"),
device=load("device"),
lr_scheduler=load("lr_scheduler", optimizer, optional=True),
log_image_interval=load("log_image_interval"),
mixed_precision=load("mixed_precision"),
early_stopping=load("early_stopping"),
logger=load("logger", only_class=True, optional=True),
logger_kwargs=load("logger_kwargs", optional=True),
)

@classmethod
def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
save_path = os.path.join(checkpoint_folder, f"{name}.pt")
save_dict = cls._get_save_dict(save_path, device)
load = cls.Deserializer(save_dict["init"], save_path, device)
trainer_kwargs = cls._get_trainer_kwargs(load)
trainer = cls(**trainer_kwargs)
trainer._initialize(0, save_dict)
return trainer
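
Downstream, the refactored entry point is used the same way as before. A minimal usage sketch, assuming a checkpoint folder written by DefaultTrainer that contains a best.pt file; the folder path and device here are assumptions, not part of the diff:

# Hypothetical usage sketch.
from torch_em.trainer.default_trainer import DefaultTrainer

# Loads './checkpoints/my-model/best.pt', rebuilds model, optimizer, loaders, etc.
# from save_dict["init"] via the Deserializer, and restores the training state.
trainer = DefaultTrainer.from_checkpoint("./checkpoints/my-model", name="best", device="cpu")

Trainer subclasses that add constructor kwargs only need to extend Deserializer, as shown in the class docstring above.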
