From 363ccc130f523d3d2c6acf8bb11e7ca580bdcfb1 Mon Sep 17 00:00:00 2001 From: Cedric Leonard <51703091+CedricLeon@users.noreply.github.com> Date: Thu, 21 Mar 2024 18:44:07 +0100 Subject: [PATCH] Update to Lightning new name (previously pytorch_lightning) (#97) --- src/wandb_osh/lightning_hooks.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/wandb_osh/lightning_hooks.py b/src/wandb_osh/lightning_hooks.py index b769a23..259c71d 100644 --- a/src/wandb_osh/lightning_hooks.py +++ b/src/wandb_osh/lightning_hooks.py @@ -2,18 +2,17 @@ from os import PathLike -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import Callback +import lightning.pytorch as pl from wandb_osh.hooks import TriggerWandbSyncHook, _comm_default_dir -class TriggerWandbSyncLightningCallback(Callback): +class TriggerWandbSyncLightningCallback(pl.Callback): def __init__( self, communication_dir: PathLike = _comm_default_dir, ): - """Hook to be used when interfacing wandb with pytorch lightning. + """Hook to be used when interfacing wandb with Lightning. Args: communication_dir: Directory used for communication with wandb-osh. @@ -32,8 +31,8 @@ def __init__( def on_validation_epoch_end( self, - trainer: Trainer, - pl_module: LightningModule, + trainer: pl.Trainer, + pl_module: pl.LightningModule, ) -> None: if trainer.sanity_checking: return