Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Returning callback results when calling pipelines' train method #71

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/pythae/models/vae/vae_model.py
Expand Up @@ -162,7 +162,7 @@ def get_nll(self, data, n_samples=1, batch_size=100):
log_q_z_given_x = -0.5 * (
log_var + (z - mu) ** 2 / torch.exp(log_var)
).sum(dim=-1)
log_p_z = -0.5 * (z ** 2).sum(dim=-1)
log_p_z = -0.5 * (z**2).sum(dim=-1)

recon_x = self.decoder(z)["reconstruction"]

Expand Down
2 changes: 2 additions & 0 deletions src/pythae/pipelines/training.py
Expand Up @@ -228,6 +228,8 @@ def __call__(
training_config=self.training_config,
callbacks=callbacks,
)
else:
raise ValueError("The provided training config is not supported.")

self.trainer = trainer

Expand Down
29 changes: 23 additions & 6 deletions src/pythae/trainers/training_callbacks.py
Expand Up @@ -111,6 +111,9 @@ def on_log(self, training_config: BaseTrainerConfig, logs, **kwargs):
Event called after logging the last logs.
"""

def __repr__(self) -> str:
    """Represent the callback by its concrete class name (e.g. for printing callback lists)."""
    return type(self).__name__


class CallbackHandler:
"""
Expand Down Expand Up @@ -225,6 +228,24 @@ def on_log(self, training_config: BaseTrainerConfig, logs, **kwargs):
)


class TrainHistoryCallback(MetricConsolePrinterCallback):
    """Callback accumulating per-epoch losses so they can be returned after training.

    After a run, ``self.history`` maps ``"train_loss"`` / ``"eval_loss"`` to one
    value per logged epoch. An entry is ``None`` when the corresponding key was
    absent from that epoch's ``logs`` (e.g. no eval set configured).
    """

    def __init__(self):
        # loss-series name -> list of per-epoch values
        self.history = {"train_loss": [], "eval_loss": []}
        super().__init__()

    def on_train_begin(self, training_config: BaseTrainerConfig, **kwargs):
        """Reset the history so a reused callback instance does not mix runs."""
        self.history = {"train_loss": [], "eval_loss": []}

    def on_log(self, training_config: BaseTrainerConfig, logs, **kwargs):
        """Record the epoch losses reported in ``logs``.

        Bug fix: the appends were previously gated on a ``logger`` being passed
        in ``kwargs`` (a gate copied from the printing parent class), so the
        history was silently never recorded when logging was disabled. Recording
        the history must not depend on the presence of a logger.
        """
        self.history["train_loss"].append(logs.get("train_epoch_loss", None))
        self.history["eval_loss"].append(logs.get("eval_epoch_loss", None))


class ProgressBarCallback(TrainingCallback):
"""
A :class:`TrainingCallback` printing the training progress bar.
Expand Down Expand Up @@ -581,12 +602,8 @@ def setup(
)
experiment.log_other("Created from", "pythae")

experiment.log_parameters(
training_config, prefix="training_config/"
)
experiment.log_parameters(
model_config, prefix="model_config/"
)
experiment.log_parameters(training_config, prefix="training_config/")
experiment.log_parameters(model_config, prefix="model_config/")

def on_train_begin(self, training_config: BaseTrainerConfig, **kwargs):
model_config = kwargs.pop("model_config", None)
Expand Down