Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions python/monarch/common/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
from monarch.common.function import ResolvableFromCloudpickle, ResolvableFunction
from monarch.common.invocation import DeviceException, RemoteException
from monarch.common.reference import Referenceable
from monarch.common.stream import StreamRef
from monarch.common.tree import flattener
from pyre_extensions import none_throws

from .shape import NDSlice
from .tensor_factory import TensorFactory

if TYPE_CHECKING:
from monarch.common.stream import StreamRef

from .device_mesh import DeviceMesh, RemoteProcessGroup
from .pipe import Pipe
from .recording import Recording
Expand Down Expand Up @@ -98,7 +99,7 @@ def to_rust_message(self) -> tensor_worker.WorkerMessage:


class CreateStream(NamedTuple):
result: StreamRef
result: "StreamRef"
default: bool

def to_rust_message(self) -> tensor_worker.WorkerMessage:
Expand Down Expand Up @@ -132,7 +133,7 @@ class CallFunction(NamedTuple):
function: ResolvableFunction
args: Tuple[object, ...]
kwargs: Dict[str, object]
stream: StreamRef
stream: "StreamRef"
device_mesh: DeviceMesh
remote_process_groups: List[RemoteProcessGroup]

Expand Down Expand Up @@ -199,7 +200,7 @@ def to_rust_message(self) -> tensor_worker.WorkerMessage:
class RecordingResult(NamedTuple):
input: Tensor | tensor_worker.Ref
output_index: int
stream: StreamRef
stream: "StreamRef"

def to_rust_message(self) -> tensor_worker.WorkerMessage:
return tensor_worker.RecordingResult(
Expand Down