diff --git a/tests/test_fastai.py b/tests/test_fastai.py index c4d501cd..9e30f4d9 100644 --- a/tests/test_fastai.py +++ b/tests/test_fastai.py @@ -4,6 +4,7 @@ from fastai.tabular.all import ( Categorify, Normalize, + ProgressCallback, TabularDataLoaders, accuracy, tabular_learner, @@ -39,6 +40,7 @@ def data_loader(): def test_fastai_callback(tmp_dir, data_loader): learn = tabular_learner(data_loader, metrics=accuracy) + learn.remove_cb(ProgressCallback) learn.model_dir = os.path.abspath("./") learn.fit_one_cycle(2, cbs=[DvcLiveCallback("model")]) @@ -54,6 +56,7 @@ def test_fastai_callback(tmp_dir, data_loader): def test_fastai_model_file(tmp_dir, data_loader): learn = tabular_learner(data_loader, metrics=accuracy) + learn.remove_cb(ProgressCallback) learn.model_dir = os.path.abspath("./") learn.fit_one_cycle(2, cbs=[DvcLiveCallback("model")]) assert (tmp_dir / "model.pth").is_file()