Skip to content

Commit

Permalink
Track learning rate with callback. n_epochs redundant with the `epo…
Browse files Browse the repository at this point in the history
…chs` tracked by lightning.
  • Loading branch information
DomInvivo committed Sep 10, 2023
1 parent 4a9893f commit ef5db7f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
7 changes: 6 additions & 1 deletion graphium/config/_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

# Lightning
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import Logger, WandbLogger
from loguru import logger

Expand Down Expand Up @@ -415,6 +415,11 @@ def load_trainer(
if "model_checkpoint" in cfg_trainer.keys():
callbacks.append(ModelCheckpoint(**cfg_trainer["model_checkpoint"]))

if "learning_rate_monitor" in cfg_trainer.keys():
callbacks.append(LearningRateMonitor(**cfg_trainer["learning_rate_monitor"]))
else:
callbacks.append(LearningRateMonitor())

# Define the logger parameters
wandb_cfg = config["constants"].get("wandb")
if wandb_cfg is not None:
Expand Down
5 changes: 0 additions & 5 deletions graphium/trainer/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,11 +618,6 @@ def on_validation_epoch_end(self) -> None:
concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs)
concatenated_metrics_logs["val/mean_time"] = torch.tensor(self.mean_val_time_tracker.mean_value)
concatenated_metrics_logs["val/mean_tput"] = self.mean_val_tput_tracker.mean_value

if hasattr(self.optimizers(), "param_groups"):
lr = self.optimizers().param_groups[0]["lr"]
concatenated_metrics_logs["lr"] = torch.tensor(lr)
concatenated_metrics_logs["n_epochs"] = torch.tensor(self.current_epoch, dtype=torch.float32)
self.log_dict(concatenated_metrics_logs)

# Save yaml file with the per-task metrics summaries
Expand Down

0 comments on commit ef5db7f

Please sign in to comment.