refactor DefaultTrainer.from_checkpoint #53
Changes from 3 commits
@@ -3,7 +3,7 @@
 import time
 import warnings
 from importlib import import_module
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Union

 import numpy as np
 import torch
@@ -38,6 +38,9 @@ def __init__(
         if name is None and not issubclass(logger, WandbLogger):
             raise TypeError("Name cannot be None if not using the WandbLogger")

+        if not all(hasattr(loader, "shuffle") for loader in [train_loader, val_loader]):
+            raise ValueError(f"{self.__class__} requires all dataloaders to have a 'shuffle' attribute.")
+
         self._generate_name = name is None
         self.name = name
         self.train_loader = train_loader
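
Context for this new check (not part of the diff itself): `torch.utils.data.DataLoader` does not keep the `shuffle` flag as an attribute after construction — the flag is consumed when the sampler is chosen — which is why `load_loader` further down monkey-patches it back onto the loader. A minimal sketch illustrating the pattern:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.zeros(8, 1))  # stand-in dataset
    loader = DataLoader(ds, shuffle=True)
    print(hasattr(loader, "shuffle"))  # False: the flag only selects the sampler

    # the monkey patch used by load_loader below restores it:
    loader.shuffle = True
    print(hasattr(loader, "shuffle"))  # True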

@@ -76,64 +79,119 @@ def iteration(self):
     def epoch(self):
         return self._epoch

-    @classmethod
-    def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
-        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
-        if not os.path.exists(save_path):
-            raise ValueError(f"Cannot find checkpoint {save_path}")
-        save_dict = torch.load(save_path, map_location=device)
-
-        init_data = save_dict["init"]
-        model_p, model_m = init_data["model_class"].rsplit(".", 1)
-        model_class = getattr(import_module(model_p), model_m)
-        model = model_class(**init_data["model_kwargs"])
+    class Deserializer:
+        """
+        Determines how to deserialize the trainer kwargs from the serialized 'init_data'.
+
+        Examples:
+            To extend the initialization process you can inherit from this Deserializer in a derived Trainer class.
+            Note that `DefaultTrainer.Deserializer.load_generic()` covers most cases already.
+
+            This example adds the `the_answer` kwarg, which requires 'calculations' upon initialization:
+            >>> class MyTrainer(DefaultTrainer):
+            >>>     def __init__(self, *args, the_answer: int, **kwargs):
+            >>>         super().__init__(*args, **kwargs)
+            >>>         self.the_answer = the_answer
+            >>>
+            >>>     class Deserializer(DefaultTrainer.Deserializer):
+            >>>         def __call__(self, name, *args, **kwargs):
+            >>>             if name == "the_answer":
+            >>>                 return self.load_the_answer()
+            >>>             else:
+            >>>                 return super().__call__(name, *args, **kwargs)
+            >>>
+            >>>         def load_the_answer(self):
+            >>>             if self.device.type == "cpu":
+            >>>                 return ...  # complex 'calculation' to arrive at the answer
+            >>>             else:
+            >>>                 return 42
+        """
+
+        def __init__(self, init_data: dict, save_path: str, device: Optional[Union[str, torch.device]]):
+            self.init_data = init_data
+            self.save_path = save_path
+            self.device = torch.device(self.init_data["device"]) if device is None else torch.device(device)
+
+        def __call__(
+            self,
+            name: str,
+            *dynamic_args,
+            optional=False,
+            only_class=False,
+            dynamic_kwargs: Optional[Dict[str, Any]] = None,

[Review thread on `dynamic_kwargs: Optional[Dict[str, Any]] = None`]

Reviewer: Is there a particular reason you're not using [...]?

Author: yes, [...]

Reviewer: Maybe I am missing something, but I don't see why this is the case. As far as I can see this is perfectly valid Python:

    def my_func(
        kwarg_name, *args,
        optional=False, only_class=False,
        **kwargs
    ):
        print(optional)

    my_func("x", "y", optional=True, blub="xyz")

(prints True)

[Diff continues]
+        ):
+            if name.endswith("_loader"):
+                return self.load_loader(name)
+            if name == "device":
+                return self.device
+            else:
+                return self.load_generic(
+                    name, *dynamic_args, optional=optional, only_class=only_class, dynamic_kwargs=dynamic_kwargs
+                )

-        optimizer_p, optimizer_m = init_data["optimizer_class"].rsplit(".", 1)
-        optimizer_class = getattr(import_module(optimizer_p), optimizer_m)
-        optimizer = optimizer_class(model.parameters(), **init_data["optimizer_kwargs"])
+        def load_generic(
+            self, name: str, *dynamic_args, optional: bool, only_class: bool, dynamic_kwargs: Optional[Dict[str, Any]]
+        ):
+            if name in self.init_data:
+                return self.init_data[name]

-        def _init(name, optional=False, only_class=False):
-            this_cls = init_data.get(f"{name}_class", None)
+            this_cls = self.init_data.get(f"{name}_class", None)
             if this_cls is None and optional:
                 return None
             elif this_cls is None and not optional:
-                raise RuntimeError(f"Could not find init data for {name} in {save_path}")
+                raise RuntimeError(f"Could not find init data for {name} in {self.save_path}")

             cls_p, cls_m = this_cls.rsplit(".", 1)
             this_cls = getattr(import_module(cls_p), cls_m)
             if only_class:
                 return this_cls
-            kwargs = init_data[f"{name}_kwargs"]
-            if name == "lr_scheduler":
-                return this_cls(optimizer, **kwargs)
-            else:
-                return this_cls(**kwargs)
+            return this_cls(*dynamic_args, **self.init_data.get(f"{name}_kwargs", {}), **(dynamic_kwargs or {}))

-        def _init_loader(name):
-            ds = init_data[f"{name}_dataset"]
-            loader_kwargs = init_data[f"{name}_loader_kwargs"]
+        def load_loader(self, name):
+            ds = self.init_data[f"{name.replace('_loader', '')}_dataset"]
+            loader_kwargs = self.init_data[f"{name}_kwargs"]
             loader = torch.utils.data.DataLoader(ds, **loader_kwargs)
             # monkey patch shuffle attribute to the loader
             loader.shuffle = loader_kwargs.get("shuffle", False)
             return loader

-        device = torch.device(init_data["device"]) if device is None else torch.device(device)
-        trainer = cls(
-            name=os.path.split(checkpoint_folder)[1],
-            train_loader=_init_loader("train"),
-            val_loader=_init_loader("val"),
+    @staticmethod
+    def _get_save_dict(save_path, device):
+        if not os.path.exists(save_path):
+            raise ValueError(f"Cannot find checkpoint {save_path}")
+
+        return torch.load(save_path, map_location=device)
+
+    @staticmethod
+    def _get_trainer_kwargs(load: Deserializer):
+        model = load("model")
+        optimizer = load("optimizer", model.parameters())
+
+        return dict(
+            name=os.path.split(os.path.dirname(load.save_path))[1],
+            train_loader=load("train_loader"),
+            val_loader=load("val_loader"),
             model=model,
-            loss=_init("loss"),
+            loss=load("loss"),
             optimizer=optimizer,
-            metric=_init("metric"),
-            device=device,
-            lr_scheduler=_init("lr_scheduler", optional=True),
-            log_image_interval=init_data["log_image_interval"],
-            mixed_precision=init_data["mixed_precision"],
-            early_stopping=init_data["early_stopping"],
-            logger=_init("logger", only_class=True, optional=True),
-            logger_kwargs=init_data.get("logger_kwargs"),
+            metric=load("metric"),
+            device=load("device"),
+            lr_scheduler=load("lr_scheduler", optimizer, optional=True),
+            log_image_interval=load("log_image_interval"),
+            mixed_precision=load("mixed_precision"),
+            early_stopping=load("early_stopping"),
+            logger=load("logger", only_class=True, optional=True),
+            logger_kwargs=load("logger_kwargs", optional=True),
         )

+    @classmethod
+    def from_checkpoint(cls, checkpoint_folder, name="best", device=None):
+        save_path = os.path.join(checkpoint_folder, f"{name}.pt")
+        save_dict = cls._get_save_dict(save_path, device)
+        load = cls.Deserializer(save_dict["init"], save_path, device)
+        trainer_kwargs = cls._get_trainer_kwargs(load)
+        trainer = cls(**trainer_kwargs)
+        trainer._initialize(0, save_dict)
+        return trainer
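
The `rsplit(".", 1)` + `import_module` + `getattr` pattern that `load_generic` uses to turn a serialized dotted path back into a class is worth spelling out. A minimal standalone sketch, with `torch.optim.Adam` chosen purely as an example path:

    from importlib import import_module

    def resolve(dotted_path):
        # "torch.optim.Adam" -> module "torch.optim", attribute "Adam"
        module_name, attr_name = dotted_path.rsplit(".", 1)
        return getattr(import_module(module_name), attr_name)

    optimizer_class = resolve("torch.optim.Adam")
    print(optimizer_class)  # <class 'torch.optim.adam.Adam'>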
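
For reference, the refactor keeps the public entry point unchanged. A usage sketch with a hypothetical checkpoint layout `./checkpoints/my-model/best.pt` (the import path is an assumption based on the repository layout):

    from torch_em.trainer import DefaultTrainer  # assumed import path

    # hypothetical folder; from_checkpoint expects <checkpoint_folder>/<name>.pt
    trainer = DefaultTrainer.from_checkpoint("./checkpoints/my-model", name="best", device="cpu")
    # the trainer name is derived from the folder ("my-model"), and
    # _initialize(0, save_dict) restores the serialized training state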
[Review thread on the `name` parameter of `Deserializer.__call__`]

Reviewer: What exactly is `name` here? I thought it was the name of the trainer, i.e. the checkpoint name. But it looks like it's something else.

Reviewer: Ok, I understand now that this is the name of the kwarg. But this should be explained a bit better; it's somehow not immediately obvious. I think it would be best if you could just add the "normal usage" to the docstring, i.e. how the deserializer is normally called for a kwarg.

Author: Good idea. I suppose we can additionally rename `name` to `kwarg_name` or `attribute`?

Reviewer: Yes, let's go with `kwarg_name`.
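
Taken together, the requested docstring note plus the agreed rename might look something like the following sketch (illustrative wording, not the merged change):

    def __call__(
        self,
        kwarg_name: str,
        *dynamic_args,
        optional=False,
        only_class=False,
        dynamic_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Load the trainer kwarg `kwarg_name` from 'init_data'.

        Normal usage (inside `DefaultTrainer._get_trainer_kwargs`):
        `load("model")`, `load("optimizer", model.parameters())`,
        `load("lr_scheduler", optimizer, optional=True)`, ...

        '*_loader' kwargs are dispatched to `load_loader`, 'device' is resolved
        directly, and everything else goes through `load_generic`, which looks up
        `kwarg_name` itself or `f"{kwarg_name}_class"` / `f"{kwarg_name}_kwargs"`
        in 'init_data'.
        """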