
Added at_train_end logic to base pipeline #1932

Merged
17 commits merged into nerfstudio-project:main on May 23, 2023

Conversation

maturk
Collaborator

@maturk maturk commented May 16, 2023

Here I have added at_train_end() logic to the base pipeline. To accommodate different user needs, I added similar methods to base_datamanager and base_model, callable with **kwargs.

Here is a small example of how I am personally using it within my custom pipeline for color correction:

    def at_train_end(self) -> None:
        self.eval()
        camera_ray_bundle, batch = self.datamanager.at_train_end()
        preds, refs = self.model.at_train_end(camera_ray_bundle=camera_ray_bundle, batch=batch)
        cc_images = self.color_correct(img=preds.cpu(), ref=refs.cpu())
        save_image((cc_images * 255).astype(np.uint8), self.save_path, log=True)
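
For readers skimming the thread, the hook pattern being proposed can be sketched in miniature: a base class exposes a no-op at_train_end, and the trainer calls it exactly once after its loop. The Pipeline, ColorCorrectPipeline, and train names below are illustrative toys, not nerfstudio code:

```python
from typing import Any, Optional


class Pipeline:
    """Toy stand-in for the base pipeline: the hook is a no-op by default."""

    def at_train_end(self, **kwargs: Any) -> Optional[Any]:
        """Called once at the end of training; subclasses may override."""
        return None


class ColorCorrectPipeline(Pipeline):
    """User pipeline overriding the hook with post-training work."""

    def __init__(self) -> None:
        self.finished = False

    def at_train_end(self, **kwargs: Any) -> Optional[Any]:
        self.finished = True  # e.g. run color correction and save images here
        return "done"


def train(pipeline: Pipeline, num_steps: int) -> None:
    for _ in range(num_steps):
        pass  # ... one training iteration ...
    pipeline.at_train_end()  # invoked exactly once, after the loop


pipeline = ColorCorrectPipeline()
train(pipeline, num_steps=3)
print(pipeline.finished)  # → True
```

The review below turns on exactly this contract: if the base class advertises the hook, something must actually call it.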

@SauravMaheshkar SauravMaheshkar added the python (Pull requests that update Python code), quality of life, and enhancement (New feature or request) labels May 16, 2023
Contributor

@tancik tancik left a comment


We should either make sure that the datamanager and model train_end are called, or remove them and leave it to the user to add such a function and call it in the pipeline.

Comment on lines 289 to 291
def at_train_end(self, **kwargs: Any) -> Optional[Any]:  # pylint: disable=unused-argument,no-self-use
"""Called at end of training for optional datamanager outputs."""

Contributor

This is never called, I worry that users will override this function and expect it to be called.

Comment on lines 222 to 224

def at_train_end(self, **kwargs: Any) -> Optional[Any]:  # pylint: disable=unused-argument,no-self-use
"""Called at end of training for optional model outputs."""
Contributor

Also never called, same potential confusion as above.

@jkulhanek
Contributor

Wouldn’t it be better to implement this using a new callback type? That way no code would have to be added to the model and the datamanager. Also, shouldn’t it be “on_train_end” instead of at..?

@maturk
Collaborator Author

maturk commented May 17, 2023

Thanks for the feedback again. I have added a new callback type, OnTrainEndCallback, and removed the unnecessary on_train_end calls from the base model and base pipeline. This callback has access to the pipeline object. I also made the naming consistent: "on train end".

Here is an example of how I am using the new callback type within my custom pipeline similar to my previous implementation:

    def get_on_train_end_callbacks(self) -> List[OnTrainEndCallback]:
        on_train_end_callbacks = []

        def color_correct_images():
            self.eval()
            camera_ray_bundle, batch = self.datamanager.at_train_end()  # no longer in base datamanager, just a user-defined function
            preds, refs = self.model.at_train_end(camera_ray_bundle=camera_ray_bundle, batch=batch)  # no longer in base model, just a user-defined function
            preds = self.color_correct(img=preds.cpu(), ref=refs.cpu())
            save_image((preds * 255).astype(np.uint8), self.cc_save_paths, log=True)

        on_train_end_callbacks.append(OnTrainEndCallback(func=color_correct_images))
        return on_train_end_callbacks
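
The OnTrainEndCallback class itself is not shown in the thread; a plausible minimal version, assuming it simply wraps a function plus optional arguments (the class name is from the PR, but this body is a guess), might look like:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class OnTrainEndCallback:
    """Hypothetical sketch: holds a function and the arguments to call it with."""

    func: Callable[..., Any]
    args: List[Any] = field(default_factory=list)
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def run(self) -> Any:
        """Invoke the wrapped function; the trainer would call this once at train end."""
        return self.func(*self.args, **self.kwargs)


calls = []
cb = OnTrainEndCallback(func=lambda tag: calls.append(tag), args=["train-end"])
cb.run()
print(calls)  # → ['train-end']
```

As the next comments point out, a dedicated class like this duplicates machinery nerfstudio already has.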

@jkulhanek
Contributor

I am very sorry I didn't explain it clearly. I meant the already existing callback infrastructure, i.e. registering the callback here:

    class TrainingCallbackLocation(Enum):
        ...

@jkulhanek
Contributor

I think that would integrate in a more concise way with the rest of the code.
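
The mechanism jkulhanek is pointing to can be mirrored in a few lines: a location enum plus a callback that fires only when asked to run at one of its registered locations. This is a simplified toy of nerfstudio's TrainingCallback (the real class also handles update_every_num_iters, iters, args, and kwargs), using the ON_TRAIN_END location name as it first appears in this PR:

```python
from enum import Enum, auto
from inspect import signature
from typing import Callable, List


class TrainingCallbackLocation(Enum):
    BEFORE_TRAIN_ITERATION = auto()
    AFTER_TRAIN_ITERATION = auto()
    ON_TRAIN_END = auto()  # the new location this PR adds


class TrainingCallback:
    """Simplified mirror of nerfstudio's callback: runs only at its locations."""

    def __init__(self, where_to_run: List[TrainingCallbackLocation], func: Callable) -> None:
        # nerfstudio enforces this same contract on the callback function.
        assert "step" in signature(func).parameters, "'step' must be an argument of func"
        self.where_to_run = where_to_run
        self.func = func

    def run_callback_at_location(self, step: int, location: TrainingCallbackLocation) -> None:
        if location in self.where_to_run:
            self.func(step=step)


fired = []
cb = TrainingCallback([TrainingCallbackLocation.ON_TRAIN_END], func=lambda step: fired.append(step))
cb.run_callback_at_location(step=100, location=TrainingCallbackLocation.AFTER_TRAIN_ITERATION)
cb.run_callback_at_location(step=100, location=TrainingCallbackLocation.ON_TRAIN_END)
print(fired)  # → [100]
```

Because the trainer already sweeps all callbacks at each location, adding one more enum member is all the integration the feature needs.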

@maturk
Collaborator Author

maturk commented May 17, 2023

Whew, no problem :) Happy to learn, maybe the third time is the charm! I have integrated it into the existing callback locations. The only thing I am not sure about is the "step" variable when calling run_callback_at_location(step, location) at the end of training. There is no way to know whether step equals the maximum number of training iterations inside the TrainingCallback class, since this is not exposed. I have set step=None for now to signify that the callback is being called at train end.
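
Under the step=None convention described here, an end-of-training callback would treat None as a sentinel and ignore regular per-iteration steps. A sketch of that convention only (the reviewers push back on it below, preferring to keep step an int):

```python
from typing import List, Optional

log: List[str] = []


def on_train_end(step: Optional[int]) -> None:
    # step=None is the sentinel meaning "training has finished";
    # an int means this is an ordinary per-iteration invocation.
    if step is None:
        log.append("train-end work")


on_train_end(step=500)   # mid-training: does nothing
on_train_end(step=None)  # end of training: fires
print(log)  # → ['train-end work']
```

The cost of the sentinel is that every callback signature widens from int to Optional[int], which is exactly what the review comments below object to.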

@@ -44,6 +44,7 @@ class TrainingCallbackLocation(Enum):

 BEFORE_TRAIN_ITERATION = auto()
 AFTER_TRAIN_ITERATION = auto()
+ON_TRAIN_END = auto()
Contributor

Can we name it AFTER_TRAIN_END?

Contributor

Or perhaps "AFTER_TRAIN"?

Collaborator Author

I have now renamed it to AFTER_TRAIN.

@@ -297,6 +297,10 @@ def train(self) -> None:
 table.add_row("Checkpoint Directory", str(self.checkpoint_dir))
 CONSOLE.print(Panel(table, title="[bold][green]:tada: Training Finished :tada:[/bold]", expand=False))

+# on train end callbacks
+for callback in self.callbacks:
+    callback.run_callback_at_location(step=None, location=TrainingCallbackLocation.ON_TRAIN_END)
Contributor

Can you pass the actual step here?


-def run_callback_at_location(self, step: int, location: TrainingCallbackLocation) -> None:
+def run_callback_at_location(self, step: Union[int, None], location: TrainingCallbackLocation) -> None:
Contributor

Please keep the signature as int

self.where_to_run = where_to_run
self.update_every_num_iters = update_every_num_iters
self.iters = iters
self.func = func
self.args = args if args is not None else []
self.kwargs = kwargs if kwargs is not None else {}

-def run_callback(self, step: int) -> None:
+def run_callback(self, step: Union[int, None]) -> None:
Contributor

Please keep int here.

@@ -71,15 +72,15 @@ def __init__(
):
assert (
"step" in signature(func).parameters.keys()
-), f"'step: int' must be an argument in the callback function 'func': {func.__name__}"
+), f"'step: Union[int, None]' must be an argument in the callback function 'func': {func.__name__}"
Contributor

Same here, please keep int

Contributor

@tancik tancik left a comment

LGTM

@tancik tancik merged commit 33d95f3 into nerfstudio-project:main May 23, 2023
4 checks passed