Added at_train_end logic to base pipeline #1932
We should either make sure that the datamanager and model train_end
are called, or remove them and leave it to the user to add such a function and call it in the pipeline.
def at_train_end(self, **kwargs: Any) -> Optional[Any]:  # pylint: disable=unused-argument disable=no-self-use
    """Called at end of training for optional datamanager outputs."""
This is never called. I worry that users will override this function and expect it to be called.
nerfstudio/models/base_model.py
Outdated
def at_train_end(self, **kwargs: Any) -> Optional[Any]:  # pylint: disable=unused-argument disable=no-self-use
    """Called at end of training for optional model outputs."""
Also never called, same potential confusion as above.
Wouldn’t it be better to implement this using a new callback type? That way no code would have to be added to the model and the datamanager. Also, shouldn’t it be “on_train_end” instead of “at_train_end”?
Thanks for the feedback again. I have added a new callback type OnTrainEndCallback. I removed unnecessary on_train_end calls from base model and base pipeline. This callback has access to the pipeline object. I also made naming consistent "on train end". Here is an example of how I am using the new callback type within my custom pipeline similar to my previous implementation:
I am very sorry I didn't explain it clearly. I meant the already existing callback infrastructure. ...registering the callback here: nerfstudio/nerfstudio/engine/callbacks.py Line 42 in 8b9574b
I think that would integrate in a more concise way with the rest of the code.
Whew, no problem :) Happy to learn, maybe the third time is the charm! I have now integrated it into the existing callback locations. The only thing I am not sure about is the "step" variable when calling run_callback_at_location(step, location) at the end of training. There is no way to know whether step == max-train-iterations inside the TrainingCallback class, since this is not exposed. I have set step=None for now to signify that the callback is being run at the end of training.
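For readers following the thread, the callback pattern under discussion can be sketched roughly as below. This is a simplified, self-contained stand-in and not nerfstudio's actual code: the class, enum, and method names mirror those quoted in this PR (using the final AFTER_TRAIN name), while the method bodies, the `save_color_correction` function, and the `results` list are illustrative assumptions.

```python
from enum import Enum, auto
from inspect import signature
from typing import Callable, Dict, List, Optional


class TrainingCallbackLocation(Enum):
    """Simplified stand-in for the enum discussed in this PR."""

    BEFORE_TRAIN_ITERATION = auto()
    AFTER_TRAIN_ITERATION = auto()
    AFTER_TRAIN = auto()  # the end-of-training location added by this PR


class TrainingCallback:
    """Simplified stand-in: runs `func` when one of its locations is reached."""

    def __init__(
        self,
        where_to_run: List[TrainingCallbackLocation],
        func: Callable,
        kwargs: Optional[Dict] = None,
    ) -> None:
        # Same idea as the assertion quoted in the PR: `func` must accept `step`.
        assert "step" in signature(func).parameters, "'step: int' must be an argument of func"
        self.where_to_run = where_to_run
        self.func = func
        self.kwargs = kwargs if kwargs is not None else {}

    def run_callback_at_location(self, step: int, location: TrainingCallbackLocation) -> None:
        # Only fire when the caller's location matches one we registered for.
        if location in self.where_to_run:
            self.func(step=step, **self.kwargs)


results: List[str] = []  # hypothetical sink so the effect is observable


def save_color_correction(step: int, name: str) -> None:
    """Hypothetical end-of-training hook; records what it was called with."""
    results.append(f"{name} finished at step {step}")


callback = TrainingCallback(
    where_to_run=[TrainingCallbackLocation.AFTER_TRAIN],
    func=save_color_correction,
    kwargs={"name": "color_correction"},
)

# Ignored: this callback is not registered for AFTER_TRAIN_ITERATION.
callback.run_callback_at_location(step=5, location=TrainingCallbackLocation.AFTER_TRAIN_ITERATION)
# Fires: the location matches.
callback.run_callback_at_location(step=10, location=TrainingCallbackLocation.AFTER_TRAIN)
```

Registering for a location this way keeps the end-of-training logic out of the model and datamanager base classes, which is the point raised earlier in the review.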
nerfstudio/engine/callbacks.py
Outdated
@@ -44,6 +44,7 @@ class TrainingCallbackLocation(Enum):

     BEFORE_TRAIN_ITERATION = auto()
     AFTER_TRAIN_ITERATION = auto()
+    ON_TRAIN_END = auto()
Can we name it AFTER_TRAIN_END?
Or perhaps "AFTER_TRAIN"?
I have now renamed it to AFTER_TRAIN.
nerfstudio/engine/trainer.py
Outdated
@@ -297,6 +297,10 @@ def train(self) -> None:
         table.add_row("Checkpoint Directory", str(self.checkpoint_dir))
         CONSOLE.print(Panel(table, title="[bold][green]:tada: Training Finished :tada:[/bold]", expand=False))

+        # on train end callbacks
+        for callback in self.callbacks:
+            callback.run_callback_at_location(step=None, location=TrainingCallbackLocation.ON_TRAIN_END)
Can you pass the actual step here?
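One way to satisfy this request is to keep the loop variable around after the loop finishes and hand it to the end-of-training callbacks. The sketch below is a toy loop, not the nerfstudio trainer; `train`, `max_num_iterations`, and `after_train_callbacks` are hypothetical names chosen for illustration.

```python
from typing import Callable, List


def train(max_num_iterations: int, after_train_callbacks: List[Callable[[int], None]]) -> int:
    """Toy loop: passes the last step that ran to each end-of-training hook."""
    step = 0  # defined up front so it exists even if the loop body never runs
    for step in range(max_num_iterations):
        pass  # a real trainer would run one training iteration here
    for callback in after_train_callbacks:
        callback(step)  # always an int, so callbacks never need to handle None
    return step


seen_steps: List[int] = []  # hypothetical sink recording what the hooks received
final_step = train(max_num_iterations=100, after_train_callbacks=[seen_steps.append])
```

Because the hooks receive an ordinary integer, the `step: int` annotations on the callback methods can stay unchanged.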
nerfstudio/engine/callbacks.py
Outdated
-    def run_callback_at_location(self, step: int, location: TrainingCallbackLocation) -> None:
+    def run_callback_at_location(self, step: Union[int, None], location: TrainingCallbackLocation) -> None:
Please keep the signature as int
nerfstudio/engine/callbacks.py
Outdated
         self.where_to_run = where_to_run
         self.update_every_num_iters = update_every_num_iters
         self.iters = iters
         self.func = func
         self.args = args if args is not None else []
         self.kwargs = kwargs if kwargs is not None else {}

-    def run_callback(self, step: int) -> None:
+    def run_callback(self, step: Union[int, None]) -> None:
Please keep int here.
nerfstudio/engine/callbacks.py
Outdated
@@ -71,15 +72,15 @@ def __init__(
     ):
         assert (
             "step" in signature(func).parameters.keys()
-        ), f"'step: int' must be an argument in the callback function 'func': {func.__name__}"
+        ), f"'step: Union[int, None]' must be an argument in the callback function 'func': {func.__name__}"
Same here, please keep int
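Note that the assertion in this hunk only checks that a parameter named `step` exists. It does not inspect the annotation, so the type mentioned in the message is purely informational. A minimal sketch of that check, using `inspect.signature` from the standard library; `accepts_step` and the three example callbacks are hypothetical names, not part of nerfstudio.

```python
from inspect import signature
from typing import Any, Callable


def accepts_step(func: Callable) -> bool:
    """Return True if `func` declares a parameter literally named `step`."""
    return "step" in signature(func).parameters


def good_callback(step: int) -> None:
    """Declares `step`, so it would pass the check."""


def bad_callback(iteration: int) -> None:
    """Uses a different parameter name, so it would fail the check."""


def kwargs_only_callback(**kwargs: Any) -> None:
    """Could receive step=... at call time, but has no parameter named `step`."""
```

A function that only takes `**kwargs` fails the check even though it could accept `step` as a keyword at call time, because the lookup is by declared parameter name.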
LGTM
Here I have added at_train_end() logic to the base pipeline. To accommodate different user needs, I added similar methods to base_datamanager and base_model, both of which accept **kwargs.
Here is a small example of how I am personally using it within my custom pipeline for color correction: