Skip to content

Commit

Permalink
✨ Hook for pytorch lightning
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret committed Jul 4, 2023
1 parent b018d0a commit 00da16f
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 0 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Expand Up @@ -3,6 +3,12 @@
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## 1.1.0 (04.07.2023)

### Added

- Added support for pytorch lightning

## 1.0.4 (18.01.2023)

### Fixed
Expand Down
21 changes: 21 additions & 0 deletions README.md
Expand Up @@ -104,6 +104,27 @@ for batch_idx, (data, target) in enumerate(train_loader):
trigger_sync() # <-- New!
```

#### With pytorch lightning

```python
from wandb_osh.lightning_hooks import TriggerWandbSyncLightningCallback # <-- New!
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

logger = WandbLogger(
project="project",
group="group",
offline=True,
)

model = MyLightningModule()
trainer = Trainer(
logger=logger,
callbacks=[TriggerWandbSyncLightningCallback()] # <-- New!
)
trainer.fit(model, train_dataloader, val_dataloader)
```

#### With ray tune

> **Note**
Expand Down
33 changes: 33 additions & 0 deletions src/wandb_osh/lightning_hooks.py
@@ -0,0 +1,33 @@
from __future__ import annotations

from os import PathLike

from pytorch_lightning.callbacks import Callback

from wandb_osh.hooks import TriggerWandbSyncHook, _comm_default_dir


class TriggerWandbSyncLightningCallback(Callback):
def __init__(self, communication_dir: PathLike = _comm_default_dir):
"""Hook to be used when interfacing wandb with pytorch lightning.
Args:
communication_dir: Directory used for communication with wandb-osh.
Usage
.. code-block:: python
from wandb_osh.lightning_hooks import TriggerWandbSyncLightningCallback
trainer = Trainer(callbacks=[TriggerWandbSyncLightningCallback()])
"""
super().__init__()
self._hook = TriggerWandbSyncHook(communication_dir=communication_dir)

def on_validation_epoch_end(
self,
*args,
):
self._hook()
7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,7 @@
from __future__ import annotations

import logging

from wandb_osh.util.log import logger

logger.setLevel(logging.DEBUG)
12 changes: 12 additions & 0 deletions tests/test_lightning_hooks.py
@@ -0,0 +1,12 @@
from __future__ import annotations

import wandb

from wandb_osh.lightning_hooks import TriggerWandbSyncLightningCallback


def test_manual_trigger(tmp_path):
wandb.init(project="test", mode="offline", dir=tmp_path)
lh = TriggerWandbSyncLightningCallback(tmp_path)
lh.on_validation_epoch_end(None, None) # type: ignore
assert len([f for f in tmp_path.iterdir() if f.suffix == ".command"]) == 1

0 comments on commit 00da16f

Please sign in to comment.