From 6a8f2a7c10f377a79c09a7fb620f6f0563319eb9 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Wed, 6 Apr 2022 20:55:18 +0200
Subject: [PATCH 1/4] raise early on missing 'shuffle' attribute

---
 torch_em/trainer/default_trainer.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torch_em/trainer/default_trainer.py b/torch_em/trainer/default_trainer.py
index 707eb663..31d3578c 100644
--- a/torch_em/trainer/default_trainer.py
+++ b/torch_em/trainer/default_trainer.py
@@ -38,6 +38,9 @@ def __init__(
         if name is None and not issubclass(logger, WandbLogger):
             raise TypeError("Name cannot be None if not using the WandbLogger")
 
+        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
+            raise ValueError(f"{self.__class__} requires each data loader to have a 'shuffle' attribute.")
+
         self._generate_name = name is None
         self.name = name
         self.train_loader = train_loader
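Note on this check: a plain torch.utils.data.DataLoader does not carry a `shuffle`
attribute, so one has to be set on the loader before constructing the trainer,
mirroring the monkey patching that the checkpoint deserialization below does.
A minimal sketch, assuming a stand-in TensorDataset (names are illustrative, not
part of the patch):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.zeros(8, 1, 64, 64), torch.zeros(8, 1, 64, 64))
    train_loader = DataLoader(dataset, batch_size=2, shuffle=True)
    train_loader.shuffle = True  # DataLoader stores no such attribute itself; set it explicitly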
From f2d03ceca6d256f9ddd7001e06109ba0cd2dfaee Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Wed, 6 Apr 2022 21:55:35 +0200
Subject: [PATCH 2/4] refactor DefaultTrainer.from_checkpoint

---
 torch_em/trainer/default_trainer.py | 133 ++++++++++++++++++++--------
 1 file changed, 94 insertions(+), 39 deletions(-)

diff --git a/torch_em/trainer/default_trainer.py b/torch_em/trainer/default_trainer.py
index 31d3578c..30daeef5 100644
--- a/torch_em/trainer/default_trainer.py
+++ b/torch_em/trainer/default_trainer.py
@@ -3,7 +3,7 @@
 import time
 import warnings
 from importlib import import_module
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Union
 
 import numpy as np
 import torch
@@ -79,64 +79,119 @@ def iteration(self):
     def epoch(self):
         return self._epoch
 
-    @classmethod
-    def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
-        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
-        if not os.path.exists(save_path):
-            raise ValueError(f"Cannot find checkpoint {save_path}")
-        save_dict = torch.load(save_path, map_location=device)
-
-        init_data = save_dict["init"]
-        model_p, model_m = init_data["model_class"].rsplit(".", 1)
-        model_class = getattr(import_module(model_p), model_m)
-        model = model_class(**init_data["model_kwargs"])
+    class Deserializer:
+        """
+        Determines how to deserialize the trainer kwargs from serialized 'init_data'
+
+        Examples:
+            To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
+            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.
+
+            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization
+            >>> class MyTrainer(DefaultTrainer):
+            >>>     def __init__(self, *args, the_answer: int, **kwargs):
+            >>>         super().__init__(*args, **kwargs)
+            >>>         self.the_answer = the_answer
+            >>>
+            >>>     class Deserializer(DefaultTrainer.Deserializer):
+            >>>         def __call__(self, name, *args, **kwargs):
+            >>>             if name == "the_answer":
+            >>>                 return self.load_the_answer()
+            >>>             else:
+            >>>                 return super().__call__(name, *args, **kwargs)
+            >>>
+            >>>         def load_the_answer(self):
+            >>>             if self.device.type == "cpu":
+            >>>                 return ...  # complex 'calculation' to arrive at the answer
+            >>>             else:
+            >>>                 return 42
+        """
+
+        def __init__(self, init_data: dict, save_path: str, device: Union[str, torch.device]):
+            self.init_data = init_data
+            self.save_path = save_path
+            self.device = torch.device(self.init_data["device"]) if device is None else torch.device(device)
+
+        def __call__(
+            self,
+            name: str,
+            *dynamic_args,
+            optional=False,
+            only_class=False,
+            dynamic_kwargs: Optional[Dict[str, Any]] = None,
+        ):
+            if name.endswith("_loader"):
+                return self.load_loader(name)
+            if name == "device":
+                return self.device
+            else:
+                return self.load_generic(
+                    name, *dynamic_args, optional=optional, only_class=only_class, dynamic_kwargs=dynamic_kwargs
+                )
 
-        optimizer_p, optimizer_m = init_data["optimizer_class"].rsplit(".", 1)
-        optimizer_class = getattr(import_module(optimizer_p), optimizer_m)
-        optimizer = optimizer_class(model.parameters(), **init_data["optimizer_kwargs"])
+        def load_generic(
+            self, name: str, *dynamic_args, optional: bool, only_class: bool, dynamic_kwargs: Optional[Dict[str, Any]]
+        ):
+            if name in self.init_data:
+                return self.init_data[name]
 
-        def _init(name, optional=False, only_class=False):
-            this_cls = init_data.get(f"{name}_class", None)
+            this_cls = self.init_data.get(f"{name}_class", None)
             if this_cls is None and optional:
                 return None
             elif this_cls is None and not optional:
-                raise RuntimeError(f"Could not find init data for {name} in {save_path}")
+                raise RuntimeError(f"Could not find init data for {name} in {self.save_path}")
+
             cls_p, cls_m = this_cls.rsplit(".", 1)
             this_cls = getattr(import_module(cls_p), cls_m)
             if only_class:
                 return this_cls
-            kwargs = init_data[f"{name}_kwargs"]
-            if name == "lr_scheduler":
-                return this_cls(optimizer, **kwargs)
             else:
-                return this_cls(**kwargs)
+                return this_cls(*dynamic_args, **self.init_data.get(f"{name}_kwargs", {}), **(dynamic_kwargs or {}))
 
-        def _init_loader(name):
-            ds = init_data[f"{name}_dataset"]
-            loader_kwargs = init_data[f"{name}_loader_kwargs"]
+        def load_loader(self, name):
+            ds = self.init_data[f"{name}_dataset"]
+            loader_kwargs = self.init_data[f"{name}_loader_kwargs"]
             loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
             # monkey patch shuffle attribute to the loader
             loader.shuffle = loader_kwargs.get("shuffle", False)
             return loader
 
-        device = torch.device(init_data["device"]) if device is None else torch.device(device)
-        trainer = cls(
-            name=os.path.split(checkpoint_folder)[1],
-            train_loader=_init_loader("train"),
-            val_loader=_init_loader("val"),
+    @staticmethod
+    def _get_save_dict(save_path, device):
+        if not os.path.exists(save_path):
+            raise ValueError(f"Cannot find checkpoint {save_path}")
+
+        return torch.load(save_path, map_location=device)
+
+    @staticmethod
+    def _get_trainer_kwargs(load: Deserializer):
+        model = load("model")
+        optimizer = load("optimizer", model.parameters())
+
+        return dict(
+            name=os.path.split(os.path.dirname(load.save_path))[1],
+            train_loader=load("train_loader"),
+            val_loader=load("val_loader"),
             model=model,
-            loss=_init("loss"),
+            loss=load("loss"),
             optimizer=optimizer,
-            metric=_init("metric"),
-            device=device,
-            lr_scheduler=_init("lr_scheduler", optional=True),
-            log_image_interval=init_data["log_image_interval"],
-            mixed_precision=init_data["mixed_precision"],
-            early_stopping=init_data["early_stopping"],
-            logger=_init("logger", only_class=True, optional=True),
-            logger_kwargs=init_data.get("logger_kwargs"),
+            metric=load("metric"),
+            device=load("device"),
+            lr_scheduler=load("lr_scheduler", optimizer, optional=True),
+            log_image_interval=load("log_image_interval"),
+            mixed_precision=load("mixed_precision"),
+            early_stopping=load("early_stopping"),
+            logger=load("logger", only_class=True, optional=True),
+            logger_kwargs=load("logger_kwargs", optional=True),
        )
 
+    @classmethod
+    def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
+        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
+        save_dict = cls._get_save_dict(save_path, device)
+        load = cls.Deserializer(save_dict["init"], save_path, device)
+        trainer_kwargs = cls._get_trainer_kwargs(load)
+        trainer = cls(**trainer_kwargs)
         trainer._initialize(0, save_dict)
         return trainer
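The public entry point is unchanged by this refactor; loading a checkpoint still
looks like the following sketch (the folder "checkpoints/my-model" is illustrative
and assumed to contain a "best.pt" previously written by this trainer):

    from torch_em.trainer.default_trainer import DefaultTrainer

    trainer = DefaultTrainer.from_checkpoint("checkpoints/my-model", name="best", device="cpu")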
From 994b1b7fdce0a49bd0c104ab4ad0ccf9c653a749 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Thu, 7 Apr 2022 16:13:01 +0200
Subject: [PATCH 3/4] fix load_loader

---
 torch_em/trainer/default_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_em/trainer/default_trainer.py b/torch_em/trainer/default_trainer.py
index 30daeef5..ba6b372d 100644
--- a/torch_em/trainer/default_trainer.py
+++ b/torch_em/trainer/default_trainer.py
@@ -149,8 +149,8 @@ def load_generic(
                 return this_cls(*dynamic_args, **self.init_data.get(f"{name}_kwargs", {}), **(dynamic_kwargs or {}))
 
         def load_loader(self, name):
-            ds = self.init_data[f"{name}_dataset"]
-            loader_kwargs = self.init_data[f"{name}_loader_kwargs"]
+            ds = self.init_data[f"{name.replace('_loader', '')}_dataset"]
+            loader_kwargs = self.init_data[f"{name}_kwargs"]
             loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
             # monkey patch shuffle attribute to the loader
             loader.shuffle = loader_kwargs.get("shuffle", False)
             return loader
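Background for this fix: `_get_trainer_kwargs` now calls the Deserializer with the
kwarg names "train_loader"/"val_loader", while the serialized keys are
"train_dataset" and "train_loader_kwargs". The old lookups therefore resolved to
the non-existent keys "train_loader_dataset" and "train_loader_loader_kwargs". A
quick check of the corrected key mapping:

    name = "train_loader"
    assert f"{name.replace('_loader', '')}_dataset" == "train_dataset"
    assert f"{name}_kwargs" == "train_loader_kwargs"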
From ea5b9c46cc8dd230d9f721a4b8183f0ee5ee87b8 Mon Sep 17 00:00:00 2001
From: fynnbe
Date: Fri, 8 Apr 2022 20:59:44 +0200
Subject: [PATCH 4/4] improve Deserializer

---
 torch_em/trainer/default_trainer.py | 109 ++++++++++++++------------
 1 file changed, 61 insertions(+), 48 deletions(-)

diff --git a/torch_em/trainer/default_trainer.py b/torch_em/trainer/default_trainer.py
index ba6b372d..4e82ce99 100644
--- a/torch_em/trainer/default_trainer.py
+++ b/torch_em/trainer/default_trainer.py
@@ -80,31 +80,27 @@ def epoch(self):
         return self._epoch
 
     class Deserializer:
-        """
-        Determines how to deserialize the trainer kwargs from serialized 'init_data'
+        """Determines how to deserialize the trainer kwargs from serialized 'init_data'
 
         Examples:
             To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
             Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.
 
-            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization
+            This example adds `the_answer` kwarg, which requires 'calculations' upon initialization:
             >>> class MyTrainer(DefaultTrainer):
             >>>     def __init__(self, *args, the_answer: int, **kwargs):
             >>>         super().__init__(*args, **kwargs)
-            >>>         self.the_answer = the_answer
+            >>>         self.the_answer = the_answer  # this allows the default Serializer to save the new kwarg,
+            >>>                                       # see DefaultTrainer.Serializer
             >>>
             >>>     class Deserializer(DefaultTrainer.Deserializer):
-            >>>         def __call__(self, name, *args, **kwargs):
-            >>>             if name == "the_answer":
-            >>>                 return self.load_the_answer()
-            >>>             else:
-            >>>                 return super().__call__(name, *args, **kwargs)
-            >>>
-            >>>         def load_the_answer(self):
+            >>>         def load_the_answer(self, kwarg_name, *args, **kwargs):
+            >>>             generic_answer = self.init_data["the_answer"]  # default Deserializer would return this
+            >>>             # (device dependent) special deserialization
             >>>             if self.device.type == "cpu":
-            >>>                 return ...  # complex 'calculation' to arrive at the answer
+            >>>                 return generic_answer + 1
             >>>             else:
-            >>>                 return 42
+            >>>                 return generic_answer * 2
         """
 
         def __init__(self, init_data: dict, save_path: str, device: Union[str, torch.device]):
@@ -114,47 +110,59 @@ def __init__(self, init_data: dict, save_path: str, device: Union[str, torch.dev
 
         def __call__(
             self,
-            name: str,
+            kwarg_name: str,
             *dynamic_args,
             optional=False,
             only_class=False,
             dynamic_kwargs: Optional[Dict[str, Any]] = None,
         ):
-            if name.endswith("_loader"):
-                return self.load_loader(name)
-            if name == "device":
+            if kwarg_name == "device":
                 return self.device
+            elif kwarg_name.endswith("_loader"):
+                return self.load_data_loader(kwarg_name)
             else:
-                return self.load_generic(
-                    name, *dynamic_args, optional=optional, only_class=only_class, dynamic_kwargs=dynamic_kwargs
+                load = getattr(self, f"load_{kwarg_name}", self.load_generic)
+
+                return load(
+                    kwarg_name, *dynamic_args, optional=optional, only_class=only_class, dynamic_kwargs=dynamic_kwargs
                 )
 
+        def load_data_loader(self, loader_name):
+            ds = self.init_data[loader_name.replace("_loader", "_dataset")]
+            loader_kwargs = self.init_data[f"{loader_name}_kwargs"]
+            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
+            # monkey patch shuffle attribute to the loader
+            loader.shuffle = loader_kwargs.get("shuffle", False)
+            return loader
+
         def load_generic(
-            self, name: str, *dynamic_args, optional: bool, only_class: bool, dynamic_kwargs: Optional[Dict[str, Any]]
+            self,
+            kwarg_name: str,
+            *dynamic_args,
+            optional: bool,
+            only_class: bool,
+            dynamic_kwargs: Optional[Dict[str, Any]],
         ):
-            if name in self.init_data:
-                return self.init_data[name]
-
-            this_cls = self.init_data.get(f"{name}_class", None)
-            if this_cls is None and optional:
-                return None
-            elif this_cls is None and not optional:
-                raise RuntimeError(f"Could not find init data for {name} in {self.save_path}")
-
+            if kwarg_name in self.init_data:
+                return self.init_data[kwarg_name]
+
+            this_cls = self.init_data.get(f"{kwarg_name}_class", None)
+            if this_cls is None:
+                if optional:
+                    return None
+                else:
+                    raise RuntimeError(f"Could not find init data for {kwarg_name} in {self.save_path}")
+
+            assert isinstance(this_cls, str), this_cls
+            assert "." in this_cls, this_cls
             cls_p, cls_m = this_cls.rsplit(".", 1)
             this_cls = getattr(import_module(cls_p), cls_m)
             if only_class:
                 return this_cls
             else:
-                return this_cls(*dynamic_args, **self.init_data.get(f"{name}_kwargs", {}), **(dynamic_kwargs or {}))
-
-        def load_loader(self, name):
-            ds = self.init_data[f"{name.replace('_loader', '')}_dataset"]
-            loader_kwargs = self.init_data[f"{name}_kwargs"]
-            loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
-            # monkey patch shuffle attribute to the loader
-            loader.shuffle = loader_kwargs.get("shuffle", False)
-            return loader
+                return this_cls(
+                    *dynamic_args, **self.init_data.get(f"{kwarg_name}_kwargs", {}), **(dynamic_kwargs or {})
+                )
 
     @staticmethod
     def _get_save_dict(save_path, device):
@@ -168,29 +176,34 @@ def _get_trainer_kwargs(load: Deserializer):
         model = load("model")
         optimizer = load("optimizer", model.parameters())
 
-        return dict(
+        kwargs = dict(
             name=os.path.split(os.path.dirname(load.save_path))[1],
-            train_loader=load("train_loader"),
-            val_loader=load("val_loader"),
             model=model,
-            loss=load("loss"),
             optimizer=optimizer,
-            metric=load("metric"),
-            device=load("device"),
             lr_scheduler=load("lr_scheduler", optimizer, optional=True),
-            log_image_interval=load("log_image_interval"),
-            mixed_precision=load("mixed_precision"),
-            early_stopping=load("early_stopping"),
             logger=load("logger", only_class=True, optional=True),
             logger_kwargs=load("logger_kwargs", optional=True),
         )
+        for kw_name in [
+            "train_loader",
+            "val_loader",
+            "loss",
+            "metric",
+            "device",
+            "log_image_interval",
+            "mixed_precision",
+            "early_stopping",
+        ]:
+            kwargs[kw_name] = load(kw_name)
+
+        return kwargs
 
     @classmethod
     def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
         save_path = os.path.join(checkpoint_folder, f"{name}.pt")
         save_dict = cls._get_save_dict(save_path, device)
-        load = cls.Deserializer(save_dict["init"], save_path, device)
-        trainer_kwargs = cls._get_trainer_kwargs(load)
+        deserializer = cls.Deserializer(save_dict["init"], save_path, device)
+        trainer_kwargs = cls._get_trainer_kwargs(deserializer)
         trainer = cls(**trainer_kwargs)
         trainer._initialize(0, save_dict)
         return trainer
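With this version, `__call__` dispatches to a `load_<kwarg_name>` method whenever
the Deserializer (sub)class defines one, so custom kwargs no longer need a
`__call__` override. A hedged sketch mirroring the docstring example (`MyTrainer`
and `the_answer` are illustrative, not part of the library):

    from torch_em.trainer.default_trainer import DefaultTrainer

    class MyTrainer(DefaultTrainer):
        def __init__(self, *args, the_answer: int, **kwargs):
            super().__init__(*args, **kwargs)
            self.the_answer = the_answer

        class Deserializer(DefaultTrainer.Deserializer):
            # picked up by __call__ via getattr(self, f"load_{kwarg_name}", self.load_generic)
            def load_the_answer(self, kwarg_name, *args, **kwargs):
                return self.init_data[kwarg_name]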