diff --git a/torch_em/trainer/default_trainer.py b/torch_em/trainer/default_trainer.py index 707eb663..4e82ce99 100644 --- a/torch_em/trainer/default_trainer.py +++ b/torch_em/trainer/default_trainer.py @@ -3,7 +3,7 @@ import time import warnings from importlib import import_module -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Union import numpy as np import torch @@ -38,6 +38,9 @@ def __init__( if name is None and not issubclass(logger, WandbLogger): raise TypeError("Name cannot be None if not using the WandbLogger") + if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]): + raise ValueError(f"{self.__class__} requires any dataloader to have 'shuffle' attribute.") + self._generate_name = name is None self.name = name self.train_loader = train_loader @@ -76,64 +79,132 @@ def iteration(self): def epoch(self): return self._epoch - @classmethod - def from_checkpoint(cls, checkpoint_folder, name="best", device=None): - save_path = os.path.join(checkpoint_folder, f"{name}.pt") - if not os.path.exists(save_path): - raise ValueError(f"Cannot find checkpoint {save_path}") - save_dict = torch.load(save_path, map_location=device) - - init_data = save_dict["init"] - model_p, model_m = init_data["model_class"].rsplit(".", 1) - model_class = getattr(import_module(model_p), model_m) - model = model_class(**init_data["model_kwargs"]) - - optimizer_p, optimizer_m = init_data["optimizer_class"].rsplit(".", 1) - optimizer_class = getattr(import_module(optimizer_p), optimizer_m) - optimizer = optimizer_class(model.parameters(), **init_data["optimizer_kwargs"]) - - def _init(name, optional=False, only_class=False): - this_cls = init_data.get(f"{name}_class", None) - if this_cls is None and optional: - return None - elif this_cls is None and not optional: - raise RuntimeError(f"Could not find init data for {name} in {save_path}") + class Deserializer: + """Determines how to deserialize the trainer kwargs from serialized 'init_data' + + Examples: + To extend the initialization process you can inherite from this Deserializer in an inherited Trainer class. + Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already. + + This example adds `the_answer` kwarg, which requires 'calculations' upon initialization: + >>> class MyTrainer(DefaultTrainer): + >>> def __init__(self, *args, the_answer: int, **kwargs): + >>> super().__init__(*args, **kwargs) + >>> self.the_answer = the_answer # this allows the default Serializer to save the new kwarg, + >>> # see DefaultTrainer.Serializer + >>> + >>> class Deserializer(DefaultTrainer.Deserializer): + >>> def load_the_answer(self): + >>> generic_answer = self.init_data["the_answer"] # default Deserializer would return this + >>> # (device dependent) special deserialization + >>> if self.device.type == "cpu": + >>> return generic_answer + 1 + >>> else: + >>> return generic_answer * 2 + """ + + def __init__(self, init_data: dict, save_path: str, device: Union[str, torch.device]): + self.init_data = init_data + self.save_path = save_path + self.device = torch.device(self.init_data["device"]) if device is None else torch.device(device) + + def __call__( + self, + kwarg_name: str, + *dynamic_args, + optional=False, + only_class=False, + dynamic_kwargs: Optional[Dict[str, Any]] = None, + ): + if kwarg_name == "device": + return self.device + elif kwarg_name.endswith("_loader"): + return self.load_data_loader(kwarg_name) + else: + load = getattr(self, f"load_{kwarg_name}", self.load_generic) + + return load( + kwarg_name, *dynamic_args, optional=optional, only_class=only_class, dynamic_kwargs=dynamic_kwargs + ) + + def load_data_loader(self, loader_name): + ds = self.init_data[loader_name.replace("_loader", "_dataset")] + loader_kwargs = self.init_data[f"{loader_name}_kwargs"] + loader = torch.utils.data.DataLoader(ds, **loader_kwargs) + # monkey patch shuffle loader_name to the loader + loader.shuffle = loader_kwargs.get("shuffle", False) + return loader + + def load_generic( + self, + kwarg_name: str, + *dynamic_args, + optional: bool, + only_class: bool, + dynamic_kwargs: Optional[Dict[str, Any]], + ): + if kwarg_name in self.init_data: + return self.init_data[kwarg_name] + + this_cls = self.init_data.get(f"{kwarg_name}_class", None) + if this_cls is None: + if optional: + return None + else: + raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}") + + assert isinstance(this_cls, str), this_cls + assert "." in this_cls, this_cls cls_p, cls_m = this_cls.rsplit(".", 1) this_cls = getattr(import_module(cls_p), cls_m) if only_class: return this_cls - kwargs = init_data[f"{name}_kwargs"] - if name == "lr_scheduler": - return this_cls(optimizer, **kwargs) else: - return this_cls(**kwargs) + return this_cls( + *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {}) + ) - def _init_loader(name): - ds = init_data[f"{name}_dataset"] - loader_kwargs = init_data[f"{name}_loader_kwargs"] - loader = torch.utils.data.DataLoader(ds, **loader_kwargs) - # monkey patch shuffle attribute to the loader - loader.shuffle = loader_kwargs.get("shuffle", False) - return loader + @staticmethod + def _get_save_dict(save_path, device): + if not os.path.exists(save_path): + raise ValueError(f"Cannot find checkpoint {save_path}") - device = torch.device(init_data["device"]) if device is None else torch.device(device) - trainer = cls( - name=os.path.split(checkpoint_folder)[1], - train_loader=_init_loader("train"), - val_loader=_init_loader("val"), + return torch.load(save_path, map_location=device) + + @staticmethod + def _get_trainer_kwargs(load: Deserializer): + model = load("model") + optimizer = load("optimizer", model.parameters()) + + kwargs = dict( + name=os.path.split(os.path.dirname(load.save_path))[1], model=model, - loss=_init("loss"), optimizer=optimizer, - metric=_init("metric"), - device=device, - lr_scheduler=_init("lr_scheduler", optional=True), - log_image_interval=init_data["log_image_interval"], - mixed_precision=init_data["mixed_precision"], - early_stopping=init_data["early_stopping"], - logger=_init("logger", only_class=True, optional=True), - logger_kwargs=init_data.get("logger_kwargs"), + lr_scheduler=load("lr_scheduler", optimizer, optional=True), + logger=load("logger", only_class=True, optional=True), + logger_kwargs=load("logger_kwargs", optional=True), ) + for kw_name in [ + "train_loader", + "val_loader", + "loss", + "metric", + "device", + "log_image_interval", + "mixed_precision", + "early_stopping", + ]: + kwargs[kw_name] = load(kw_name) + + return kwargs + @classmethod + def from_checkpoint(cls, checkpoint_folder, name="best", device=None): + save_path = os.path.join(checkpoint_folder, f"{name}.pt") + save_dict = cls._get_save_dict(save_path, device) + deserializer = cls.Deserializer(save_dict["init"], save_path, device) + trainer_kwargs = cls._get_trainer_kwargs(deserializer) + trainer = cls(**trainer_kwargs) trainer._initialize(0, save_dict) return trainer