diff --git a/python/monarch/common/messages.py b/python/monarch/common/messages.py index f117e9bd3..5cfe66043 100644 --- a/python/monarch/common/messages.py +++ b/python/monarch/common/messages.py @@ -25,7 +25,6 @@ from monarch.common.function import ResolvableFromCloudpickle, ResolvableFunction from monarch.common.invocation import DeviceException, RemoteException from monarch.common.reference import Referenceable -from monarch.common.stream import StreamRef from monarch.common.tree import flattener from pyre_extensions import none_throws @@ -33,6 +32,8 @@ from .tensor_factory import TensorFactory if TYPE_CHECKING: + from monarch.common.stream import StreamRef + from .device_mesh import DeviceMesh, RemoteProcessGroup from .pipe import Pipe from .recording import Recording @@ -98,7 +99,7 @@ def to_rust_message(self) -> tensor_worker.WorkerMessage: class CreateStream(NamedTuple): - result: StreamRef + result: "StreamRef" default: bool def to_rust_message(self) -> tensor_worker.WorkerMessage: @@ -132,7 +133,7 @@ class CallFunction(NamedTuple): function: ResolvableFunction args: Tuple[object, ...] kwargs: Dict[str, object] - stream: StreamRef + stream: "StreamRef" device_mesh: DeviceMesh remote_process_groups: List[RemoteProcessGroup] @@ -199,7 +200,7 @@ def to_rust_message(self) -> tensor_worker.WorkerMessage: class RecordingResult(NamedTuple): input: Tensor | tensor_worker.Ref output_index: int - stream: StreamRef + stream: "StreamRef" def to_rust_message(self) -> tensor_worker.WorkerMessage: return tensor_worker.RecordingResult(