Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions dvclive/catalyst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from catalyst.core.callback import Callback, CallbackOrder

import dvclive


class DvcLiveCallback(Callback):
def __init__(self, model_file=None):
super().__init__(order=CallbackOrder.external)
self.model_file = model_file

def on_epoch_end(self, runner) -> None:
step = runner.stage_epoch_step

for loader_key, per_loader_metrics in runner.epoch_metrics.items():
for key, value in per_loader_metrics.items():
key = key.replace("/", "_")
dvclive.log(f"{loader_key}/{key}", float(value), step)

if self.model_file:
checkpoint = runner.engine.pack_checkpoint(
model=runner.model,
criterion=runner.criterion,
optimizer=runner.optimizer,
scheduler=runner.scheduler,
)
runner.engine.save_checkpoint(checkpoint, self.model_file)
dvclive.next_step()
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ def run(self):
xgb = ["xgboost"]
lgbm = ["lightgbm"]
hugginface = ["transformers", "datasets"]
catalyst = ["catalyst"]

all_libs = mmcv + tf + xgb + lgbm + hugginface
all_libs = mmcv + tf + xgb + lgbm + hugginface + catalyst

tests_requires = [
"pylint==2.5.3",
Expand Down Expand Up @@ -75,6 +76,7 @@ def run(self):
"xgb": xgb,
"lgbm": lgbm,
"huggingface": hugginface,
"catalyst": catalyst,
},
keywords="data-science metrics machine-learning developer-tools ai",
python_requires=">=3.6",
Expand Down
102 changes: 102 additions & 0 deletions tests/test_catalyst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os

import pytest
from catalyst import dl
from catalyst.contrib.datasets import MNIST
from catalyst.data import ToTensor
from catalyst.utils.torch import get_available_engine
from torch import nn, optim
from torch.utils.data import DataLoader

import dvclive
from dvclive.catalyst import DvcLiveCallback

# pylint: disable=redefined-outer-name, unused-argument


@pytest.fixture
def loaders():
train_data = MNIST(
os.getcwd(), train=True, download=True, transform=ToTensor()
)
valid_data = MNIST(
os.getcwd(), train=False, download=True, transform=ToTensor()
)
return {
"train": DataLoader(train_data, batch_size=32),
"valid": DataLoader(valid_data, batch_size=32),
}


@pytest.fixture
def runner():
return dl.SupervisedRunner(
engine=get_available_engine(),
input_key="features",
output_key="logits",
target_key="targets",
loss_key="loss",
)


def test_catalyst_callback(tmp_dir, runner, loaders):
dvclive.init("dvc_logs")

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
loaders=loaders,
num_epochs=2,
callbacks=[
dl.AccuracyCallback(input_key="logits", target_key="targets"),
DvcLiveCallback(),
],
logdir="./logs",
valid_loader="valid",
valid_metric="loss",
minimize_valid_metric=True,
verbose=True,
load_best_on_end=True,
)

assert os.path.exists("dvc_logs")

train_path = tmp_dir / "dvc_logs/train"
valid_path = tmp_dir / "dvc_logs/valid"

assert train_path.is_dir()
assert valid_path.is_dir()
assert (train_path / "accuracy.tsv").exists()


def test_catalyst_model_file(tmp_dir, runner, loaders):
dvclive.init("dvc_logs")

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

runner.train(
model=model,
engine=runner.engine,
criterion=criterion,
optimizer=optimizer,
loaders=loaders,
num_epochs=2,
callbacks=[
dl.AccuracyCallback(input_key="logits", target_key="targets"),
DvcLiveCallback("model.pth"),
],
logdir="./logs",
valid_loader="valid",
valid_metric="loss",
minimize_valid_metric=True,
verbose=True,
load_best_on_end=True,
)
assert (tmp_dir / "model.pth").is_file()