Skip to content

Commit

Permalink
Tracing callback hotfix (#1234)
Browse files Browse the repository at this point in the history
* fixed bug in TracingCallback

* removed unnecessary imports
  • Loading branch information
y-ksenia committed Jun 10, 2021
1 parent bc31973 commit 59300a1
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions catalyst/callbacks/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch

from catalyst.core import Callback, CallbackNode, CallbackOrder
from catalyst.utils.torch import any2device
from catalyst.utils.tracing import trace_model

if TYPE_CHECKING:
Expand Down Expand Up @@ -145,9 +144,9 @@ def on_stage_end(self, runner: "IRunner") -> None:
Args:
runner: runner for experiment
"""
model = runner.model
model = runner.engine.sync_device(runner.model)
batch = tuple(runner.batch[key] for key in self.input_key)
batch = any2device(batch, "cpu")
batch = runner.engine.sync_device(batch)
traced_model = trace_model(model=model, batch=batch, method_name=self.method_name)
torch.jit.save(traced_model, self.filename)

Expand Down

0 comments on commit 59300a1

Please sign in to comment.