
Commit

Merge pull request #53 from FynnBe/easier_trainer_inheritance
refactor DefaultTrainer.from_checkpoint
constantinpape committed Apr 8, 2022
2 parents 82d44dd + ea5b9c4 commit a7c436d
Showing 1 changed file with 119 additions and 48 deletions.
167 changes: 119 additions & 48 deletions torch_em/trainer/default_trainer.py
@@ -3,7 +3,7 @@
import time
import warnings
from importlib import import_module
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union

import numpy as np
import torch
@@ -39,6 +39,9 @@ def __init__(
        if name is None and not issubclass(logger, WandbLogger):
            raise TypeError("Name cannot be None if not using the WandbLogger")

        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
            raise ValueError(f"{self.__class__} requires each data loader to have a 'shuffle' attribute.")

        self._generate_name = name is None
        self.name = name
        self.id_ = id_ or name
@@ -78,64 +81,132 @@ def iteration(self):
    def epoch(self):
        return self._epoch

    @classmethod
    def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
        if not os.path.exists(save_path):
            raise ValueError(f"Cannot find checkpoint {save_path}")
        save_dict = torch.load(save_path, map_location=device)

        init_data = save_dict["init"]
        model_p, model_m = init_data["model_class"].rsplit(".", 1)
        model_class = getattr(import_module(model_p), model_m)
        model = model_class(**init_data["model_kwargs"])

        optimizer_p, optimizer_m = init_data["optimizer_class"].rsplit(".", 1)
        optimizer_class = getattr(import_module(optimizer_p), optimizer_m)
        optimizer = optimizer_class(model.parameters(), **init_data["optimizer_kwargs"])

        def _init(name, optional=False, only_class=False):
            this_cls = init_data.get(f"{name}_class", None)
            if this_cls is None and optional:
                return None
            elif this_cls is None and not optional:
                raise RuntimeError(f"Could not find init data for {name} in {save_path}")
    class Deserializer:
        """Determines how to deserialize the trainer kwargs from the serialized 'init_data'.
        Examples:
        To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
        Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.
        This example adds a `the_answer` kwarg, which requires 'calculations' upon initialization:
        >>> class MyTrainer(DefaultTrainer):
        >>>     def __init__(self, *args, the_answer: int, **kwargs):
        >>>         super().__init__(*args, **kwargs)
        >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
        >>>                                       # see DefaultTrainer.Serializer
        >>>
        >>>     class Deserializer(DefaultTrainer.Deserializer):
        >>>         def load_the_answer(self, kwarg_name, *dynamic_args, **kwargs):
        >>>             generic_answer = self.init_data[kwarg_name]  # the default Deserializer would return this
        >>>             # (device dependent) special deserialization
        >>>             if self.device.type == "cpu":
        >>>                 return generic_answer + 1
        >>>             else:
        >>>                 return generic_answer * 2
        """

        def __init__(self, init_data: dict, save_path: str, device: Optional[Union[str, torch.device]]):
            self.init_data = init_data
            self.save_path = save_path
            self.device = torch.device(self.init_data["device"]) if device is None else torch.device(device)

        def __call__(
            self,
            kwarg_name: str,
            *dynamic_args,
            optional=False,
            only_class=False,
            dynamic_kwargs: Optional[Dict[str, Any]] = None,
        ):
            if kwarg_name == "device":
                return self.device
            elif kwarg_name.endswith("_loader"):
                return self.load_data_loader(kwarg_name)
            else:
                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
                return load(
                    kwarg_name, *dynamic_args, optional=optional, only_class=only_class, dynamic_kwargs=dynamic_kwargs
                )

        def load_data_loader(self, loader_name):
            ds = self.init_data[loader_name.replace("_loader", "_dataset")]
            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
            # monkey patch the shuffle attribute onto the loader
            loader.shuffle = loader_kwargs.get("shuffle", False)
            return loader

        def load_generic(
            self,
            kwarg_name: str,
            *dynamic_args,
            optional: bool,
            only_class: bool,
            dynamic_kwargs: Optional[Dict[str, Any]],
        ):
            if kwarg_name in self.init_data:
                return self.init_data[kwarg_name]

            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
            if this_cls is None:
                if optional:
                    return None
                else:
                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")

            assert isinstance(this_cls, str), this_cls
            assert "." in this_cls, this_cls
            cls_p, cls_m = this_cls.rsplit(".", 1)
            this_cls = getattr(import_module(cls_p), cls_m)
            if only_class:
                return this_cls
            kwargs = init_data[f"{name}_kwargs"]
            if name == "lr_scheduler":
                return this_cls(optimizer, **kwargs)
            else:
                return this_cls(**kwargs)
            return this_cls(
                *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
            )

        def _init_loader(name):
            ds = init_data[f"{name}_dataset"]
            loader_kwargs = init_data[f"{name}_loader_kwargs"]
            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
            # monkey patch shuffle attribute to the loader
            loader.shuffle = loader_kwargs.get("shuffle", False)
            return loader
    @staticmethod
    def _get_save_dict(save_path, device):
        if not os.path.exists(save_path):
            raise ValueError(f"Cannot find checkpoint {save_path}")

        device = torch.device(init_data["device"]) if device is None else torch.device(device)
        trainer = cls(
            name=os.path.split(checkpoint_folder)[1],
            train_loader=_init_loader("train"),
            val_loader=_init_loader("val"),
        return torch.load(save_path, map_location=device)

    @staticmethod
    def _get_trainer_kwargs(load: Deserializer):
        model = load("model")
        optimizer = load("optimizer", model.parameters())

        kwargs = dict(
            name=os.path.split(os.path.dirname(load.save_path))[1],
            model=model,
            loss=_init("loss"),
            optimizer=optimizer,
            metric=_init("metric"),
            device=device,
            lr_scheduler=_init("lr_scheduler", optional=True),
            log_image_interval=init_data["log_image_interval"],
            mixed_precision=init_data["mixed_precision"],
            early_stopping=init_data["early_stopping"],
            logger=_init("logger", only_class=True, optional=True),
            logger_kwargs=init_data.get("logger_kwargs"),
            lr_scheduler=load("lr_scheduler", optimizer, optional=True),
            logger=load("logger", only_class=True, optional=True),
            logger_kwargs=load("logger_kwargs", optional=True),
        )
        for kw_name in [
            "train_loader",
            "val_loader",
            "loss",
            "metric",
            "device",
            "log_image_interval",
            "mixed_precision",
            "early_stopping",
        ]:
            kwargs[kw_name] = load(kw_name)

        return kwargs

    @classmethod
    def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
        save_dict = cls._get_save_dict(save_path, device)
        deserializer = cls.Deserializer(save_dict["init"], save_path, device)
        trainer_kwargs = cls._get_trainer_kwargs(deserializer)
        trainer = cls(**trainer_kwargs)
        trainer._initialize(0, save_dict)
        return trainer
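A minimal usage sketch of the refactored entry point; the checkpoint folder is hypothetical, and iteration/epoch are the restored-state properties shown in the diff:

from torch_em.trainer import DefaultTrainer

trainer = DefaultTrainer.from_checkpoint("./checkpoints/my-model", name="best", device="cpu")
print(trainer.iteration, trainer.epoch)  # training state restored via trainer._initialize(0, save_dict)

Because from_checkpoint resolves cls.Deserializer, a subclass such as the MyTrainer example in the docstring only needs to override its nested Deserializer:

trainer = MyTrainer.from_checkpoint("./checkpoints/my-model")  # dispatches to MyTrainer.Deserializer.load_the_answer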
