refactor DefaultTrainer.from_checkpoint #53

Merged
167 changes: 119 additions & 48 deletions torch_em/trainer/default_trainer.py
@@ -3,7 +3,7 @@
import time
import warnings
from importlib import import_module
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union

import numpy as np
import torch
@@ -38,6 +38,9 @@ def __init__(
if name is None and not issubclass(logger, WandbLogger):
raise TypeError("Name cannot be None if not using the WandbLogger")

if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
raise ValueError(f"{self.__class__} requires any dataloader to have 'shuffle' attribute.")

self._generate_name = name is None
self.name = name
self.train_loader = train_loader
@@ -76,64 +79,132 @@ def iteration(self):
def epoch(self):
return self._epoch

@classmethod
def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
save_path = os.path.join(checkpoint_folder, f"{name}.pt")
if not os.path.exists(save_path):
raise ValueError(f"Cannot find checkpoint {save_path}")
save_dict = torch.load(save_path, map_location=device)

init_data = save_dict["init"]
model_p, model_m = init_data["model_class"].rsplit(".", 1)
model_class = getattr(import_module(model_p), model_m)
model = model_class(**init_data["model_kwargs"])

optimizer_p, optimizer_m = init_data["optimizer_class"].rsplit(".", 1)
optimizer_class = getattr(import_module(optimizer_p), optimizer_m)
optimizer = optimizer_class(model.parameters(), **init_data["optimizer_kwargs"])

def _init(name, optional=False, only_class=False):
this_cls = init_data.get(f"{name}_class", None)
if this_cls is None and optional:
return None
elif this_cls is None and not optional:
raise RuntimeError(f"Could not find init data for {name} in {save_path}")
class Deserializer:
"""Determines how to deserialize the trainer kwargs from serialized 'init_data'

Examples:
To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.

This example adds a `the_answer` kwarg, which requires 'calculations' upon initialization:
>>> class MyTrainer(DefaultTrainer):
>>> def __init__(self, *args, the_answer: int, **kwargs):
>>> super().__init__(*args, **kwargs)
>>> self.the_answer = the_answer # this allows the default Serializer to save the new kwarg,
>>> # see DefaultTrainer.Serializer
>>>
>>> class Deserializer(DefaultTrainer.Deserializer):
>>> def load_the_answer(self):
>>> generic_answer = self.init_data["the_answer"] # default Deserializer would return this
>>> # (device dependent) special deserialization
>>> if self.device.type == "cpu":
>>> return generic_answer + 1
>>> else:
>>> return generic_answer * 2
"""

def __init__(self, init_data: dict, save_path: str, device: Union[str, torch.device]):
self.init_data = init_data
self.save_path = save_path
self.device = torch.device(self.init_data["device"]) if device is None else torch.device(device)

def __call__(
self,
kwarg_name: str,
*dynamic_args,
optional=False,
only_class=False,
dynamic_kwargs: Optional[Dict[str, Any]] = None,
Owner:

Is there a particular reason you're not using **dynamic_kwargs here?

Contributor Author:

yes, with **dynamic_kwargs it would not be possible to also pass optional (or only_class) as a kwarg

Owner:

Maybe I am missing something, but I don't see why this is the case. As far as I can see this is perfectly valid Python:

def my_func(
    kwarg_name, *args,
    optional=False, only_class=False,
    **kwargs
):
    print(optional)

my_func("x", "y", optional=True, blub="xyz")

(prints True)

):
if kwarg_name == "device":
return self.device
elif kwarg_name.endswith("_loader"):
return self.load_data_loader(kwarg_name)
else:
load = getattr(self, f"load_{kwarg_name}", self.load_generic)

return load(
kwarg_name, *dynamic_args, optional=optional, only_class=only_class, dynamic_kwargs=dynamic_kwargs
)

def load_data_loader(self, loader_name):
ds = self.init_data[loader_name.replace("_loader", "_dataset")]
loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
# monkey patch the shuffle attribute onto the loader
loader.shuffle = loader_kwargs.get("shuffle", False)
return loader

def load_generic(
self,
kwarg_name: str,
*dynamic_args,
optional: bool,
only_class: bool,
dynamic_kwargs: Optional[Dict[str, Any]],
):
if kwarg_name in self.init_data:
return self.init_data[kwarg_name]

this_cls = self.init_data.get(f"{kwarg_name}_class", None)
if this_cls is None:
if optional:
return None
else:
raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")

assert isinstance(this_cls, str), this_cls
assert "." in this_cls, this_cls
cls_p, cls_m = this_cls.rsplit(".", 1)
this_cls = getattr(import_module(cls_p), cls_m)
if only_class:
return this_cls
kwargs = init_data[f"{name}_kwargs"]
if name == "lr_scheduler":
return this_cls(optimizer, **kwargs)
else:
return this_cls(**kwargs)
return this_cls(
*dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
)

def _init_loader(name):
ds = init_data[f"{name}_dataset"]
loader_kwargs = init_data[f"{name}_loader_kwargs"]
loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
# monkey patch shuffle attribute to the loader
loader.shuffle = loader_kwargs.get("shuffle", False)
return loader
@staticmethod
def _get_save_dict(save_path, device):
if not os.path.exists(save_path):
raise ValueError(f"Cannot find checkpoint {save_path}")

device = torch.device(init_data["device"]) if device is None else torch.device(device)
trainer = cls(
name=os.path.split(checkpoint_folder)[1],
train_loader=_init_loader("train"),
val_loader=_init_loader("val"),
return torch.load(save_path, map_location=device)

@staticmethod
def _get_trainer_kwargs(load: Deserializer):
model = load("model")
optimizer = load("optimizer", model.parameters())

kwargs = dict(
name=os.path.split(os.path.dirname(load.save_path))[1],
model=model,
loss=_init("loss"),
optimizer=optimizer,
metric=_init("metric"),
device=device,
lr_scheduler=_init("lr_scheduler", optional=True),
log_image_interval=init_data["log_image_interval"],
mixed_precision=init_data["mixed_precision"],
early_stopping=init_data["early_stopping"],
logger=_init("logger", only_class=True, optional=True),
logger_kwargs=init_data.get("logger_kwargs"),
lr_scheduler=load("lr_scheduler", optimizer, optional=True),
logger=load("logger", only_class=True, optional=True),
logger_kwargs=load("logger_kwargs", optional=True),
)
for kw_name in [
"train_loader",
"val_loader",
"loss",
"metric",
"device",
"log_image_interval",
"mixed_precision",
"early_stopping",
]:
kwargs[kw_name] = load(kw_name)

return kwargs

@classmethod
def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
save_path = os.path.join(checkpoint_folder, f"{name}.pt")
save_dict = cls._get_save_dict(save_path, device)
deserializer = cls.Deserializer(save_dict["init"], save_path, device)
trainer_kwargs = cls._get_trainer_kwargs(deserializer)
trainer = cls(**trainer_kwargs)
trainer._initialize(0, save_dict)
return trainer
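
For reference, a minimal sketch of how the refactored entry point might be called; the checkpoint folder path and the final fit call are assumed examples and not part of this diff:

from torch_em.trainer import DefaultTrainer

# restore model, optimizer, loaders etc. from the serialized 'init_data' in best.pt;
# device=None falls back to the device stored in the checkpoint
trainer = DefaultTrainer.from_checkpoint("./checkpoints/my-model", name="best", device=None)

# resume training from the restored state (iteration count is only illustrative)
trainer.fit(iterations=10000)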
