72 changes: 72 additions & 0 deletions torchstore/_async_utils.py
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
from typing import Any, Awaitable, Callable, cast, Generic, TypeVar

T = TypeVar("T")


class OnceCell(Generic[T]):
"""Poor man's version of tokio::sync::OnceCell, except it's not threadsafe (maybe it is because of GIL?).
This is a cell that can be initialized exactly once."""

def __init__(self):
self._lock = asyncio.Lock()
self._value: T | None = None
self._initialized = False

    async def get_or_init(self, initializer: Callable[[], Awaitable[T]]) -> T:
if self._initialized:
return cast(T, self._value)

async with self._lock:
if not self._initialized:
self._value = await initializer()
self._initialized = True

return cast(T, self._value)

def get(self) -> T:
if not self._initialized:
raise ValueError("Value not initialized yet")
return cast(T, self._value)
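
A minimal usage sketch (illustrative names, not part of the diff): concurrent callers race to get_or_init, but the initializer runs exactly once.

import asyncio

async def make_value() -> str:
    # Stand-in for an expensive async initialization (e.g. opening a connection).
    return "ready"

async def main() -> None:
    cell: OnceCell[str] = OnceCell()
    first, second = await asyncio.gather(
        cell.get_or_init(make_value),
        cell.get_or_init(make_value),
    )
    # Both callers observe the same value, and get() now works synchronously.
    assert first == second == cell.get()

asyncio.run(main())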


class SequentialExecutor:
"""A simple executor that runs tasks sequentially in the current event loop.
This is mainly needed for RDMA operations, which will panic if concurrent requests are made (what the heck?).
"""

    def __init__(self):
        self._queue: asyncio.Queue = asyncio.Queue()
        self._worker_task: asyncio.Task | None = None

    async def start_worker(self):
        # Idempotent: keep a single worker so execution stays strictly sequential.
        if self._worker_task is None:
            self._worker_task = asyncio.create_task(self._worker())

async def _worker(self):
while True:
try:
func, args, kwargs, response = await self._queue.get()

if response.cancelled():
continue # Caller gave up

try:
result = await func(*args, **kwargs)
response.set_result(result)
except Exception as e:
response.set_exception(e)

            except Exception as outer_err:
                # Keep the worker loop alive: one failing task must not stall
                # every subsequent submission.
                print(f"[SequentialExecutor] Worker crashed: {outer_err}")

    async def submit(self, func: Callable, *args, **kwargs) -> Any:
        # Enqueue the call and wait for the worker to run it; returns func's result.
        fut = asyncio.get_running_loop().create_future()
        await self._queue.put((func, args, kwargs, fut))
        return await fut
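
A quick sketch of how the executor serializes work (do_transfer is a hypothetical stand-in for a non-reentrant RDMA call):

import asyncio

async def do_transfer(i: int) -> int:
    # Stand-in for an operation that must never run concurrently.
    return i * 2

async def main() -> None:
    executor = SequentialExecutor()
    await executor.start_worker()
    # Submissions may be made concurrently; the worker runs them one at a time.
    results = await asyncio.gather(
        *(executor.submit(do_transfer, i) for i in range(4))
    )
    assert results == [0, 2, 4, 6]

asyncio.run(main())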
40 changes: 22 additions & 18 deletions torchstore/api.py
@@ -7,10 +7,10 @@
from typing import Any, Dict, List, Optional, Union

import torch

from monarch.actor import get_or_spawn_controller

import torchstore.state_dict_utils
from torchstore._async_utils import OnceCell, SequentialExecutor
from torchstore.client import LocalClient
from torchstore.controller import Controller
from torchstore.storage_volume import StorageVolume
@@ -26,7 +26,7 @@
DEFAULT_TORCHSTORE_NAME: str = "TorchStore"

# cache for local clients
_local_clent_map: Dict[str, LocalClient] = {}
_local_client_map: Dict[str, OnceCell[LocalClient]] = {}


async def initialize(
@@ -94,14 +94,14 @@ async def shutdown(store_name: str = DEFAULT_TORCHSTORE_NAME) -> None:
"""
controller = await _controller(store_name)
await controller.teardown.call()
global _local_clent_map
_local_clent_map = {}
global _local_client_map
_local_client_map = {}


def reset_client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> None:
"""Reset the local client for a given store. Useful for refreshing client state after shutdown."""
global _local_clent_map
_local_clent_map.pop(store_name, None)
global _local_client_map
_local_client_map.pop(store_name, None)


async def _controller(store_name: str = DEFAULT_TORCHSTORE_NAME) -> Controller:
@@ -124,19 +124,23 @@ async def client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> LocalClient:
>>> store_client = await client()
>>> await store_client.put("my_key", tensor)
"""
if store_name in _local_clent_map:
return _local_clent_map[store_name]

controller = await _controller(store_name)
controller_strategy = await controller.get_controller_strategy.call_one()

local_client = LocalClient(
controller=controller,
strategy=controller_strategy,
)
_local_clent_map[store_name] = local_client
if store_name not in _local_client_map:
_local_client_map[store_name] = OnceCell()

async def initializer():
controller = await _controller(store_name)
controller_strategy = await controller.get_controller_strategy.call_one()

executor = SequentialExecutor()
await executor.start_worker()
local_client = LocalClient(
controller=controller,
strategy=controller_strategy,
rdma_executor=executor,
)
return local_client

return local_client
return await _local_client_map[store_name].get_or_init(initializer)
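
The net effect, sketched below (assumes the store was already initialized in this process; the import path is an assumption based on the file layout):

import asyncio

import torchstore.api as ts  # assumed import path

async def main() -> None:
    # initialize(...) is assumed to have been awaited already.
    a, b = await asyncio.gather(ts.client(), ts.client())
    assert a is b  # the OnceCell ran the initializer exactly once

asyncio.run(main())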


async def put(
34 changes: 30 additions & 4 deletions torchstore/client.py
@@ -11,6 +11,8 @@
import torch
from torch.distributed.tensor import DTensor

from torchstore._async_utils import SequentialExecutor

from torchstore.controller import ObjectType
from torchstore.logging import LatencyTracker
from torchstore.transport import Pipe, Request, TensorSlice
@@ -19,6 +21,19 @@
logger = getLogger(__name__)


def _limit_concurrency(method):
"""
Decorator to limit concurrency of async methods using the instance's semaphore.
Assumes the instance has a self._semaphore attribute (asyncio.Semaphore).
"""

async def wrapper(self, *args, **kwargs):
async with self._semaphore:
return await method(self, *args, **kwargs)

return wrapper
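
A toy sketch of the decorator in use (hypothetical class, not part of this PR): at most two fetch() bodies run at once.

import asyncio

class ThrottledFetcher:
    def __init__(self) -> None:
        self._semaphore = asyncio.Semaphore(2)

    @_limit_concurrency
    async def fetch(self, i: int) -> int:
        # The semaphore is held for the full duration of the call.
        await asyncio.sleep(0)
        return i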


class LocalClient:
"""This class represents the local store, which exists on every process. Remote storage
is handled by the client.
@@ -28,9 +43,14 @@ def __init__(
self,
controller,
strategy,
*,
rdma_executor: SequentialExecutor | None = None,
max_concurrent_requests: int = 4,
):
self._controller = controller
self.strategy = strategy
self.rdma_executor = rdma_executor
self._semaphore = asyncio.Semaphore(max_concurrent_requests)

async def _locate_volumes(self, key: str):
"""Helper method to call locate_volumes and convert any error to KeyError for missing keys."""
@@ -40,6 +60,7 @@ async def _locate_volumes(self, key: str):
raise KeyError(str(e)) from e

@torch.no_grad
@_limit_concurrency
async def put(self, key: str, value: Union[torch.Tensor, Any]):
latency_tracker = LatencyTracker(f"put:{key}")
request = Request.from_any(value)
@@ -51,14 +72,15 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]):

pipe = Pipe(storage_volume)

await pipe.put_to_storage_volume(key, request)
await pipe.put_to_storage_volume(key, request, executor=self.rdma_executor)
latency_tracker.track_step("put_to_storage_volume")

await self._controller.notify_put.call(key, request.meta_only(), volume_id)
latency_tracker.track_step("notify_put")
latency_tracker.track_e2e()

@torch.no_grad
@_limit_concurrency
async def get(
self,
key: str,
@@ -235,7 +257,9 @@ async def _get_object(self, key: str):
storage_volume = self.strategy.get_storage_volume(volume_id)
pipe = Pipe(storage_volume)
request = Request.from_any(None)
return await pipe.get_from_storage_volume(key, request)
return await pipe.get_from_storage_volume(
key, request, executor=self.rdma_executor
)

async def _get_tensor(self, key: str) -> torch.Tensor:
"""Fetches the tensor which is stored in one volume storage"""
@@ -248,7 +272,9 @@ async def _get_tensor(self, key: str) -> torch.Tensor:
# TODO: consolidate the logic here - None indicates it is an object request,
        # which is semantically inappropriate here.
request = Request.from_any(None)
return await pipe.get_from_storage_volume(key, request)
return await pipe.get_from_storage_volume(
key, request, executor=self.rdma_executor
)

async def _get_distributed_whole_tensor(self, key: str) -> torch.Tensor:
"""Fetches slices from all volume storages and stitch together to return the whole tensor"""
@@ -267,7 +293,7 @@ async def _get_distributed_whole_tensor(self, key: str) -> torch.Tensor:
tensor_slice_request = Request.from_tensor_slice(tensor_slice)

local_tensor = await pipe.get_from_storage_volume(
key, tensor_slice_request
key, tensor_slice_request, executor=self.rdma_executor
)
partial_results.append((local_tensor, tensor_slice))

54 changes: 44 additions & 10 deletions torchstore/storage_volume.py
@@ -11,8 +11,9 @@
import torch
from monarch.actor import Actor, endpoint

from torchstore.transport.buffers import TransportBuffer
from torchstore._async_utils import OnceCell, SequentialExecutor

from torchstore.transport.buffers import TransportBuffer
from torchstore.transport.pipe import Request, TensorSlice
from torchstore.utils import assemble_global_tensor, spawn_actors

@@ -33,6 +34,15 @@ def __init__(
) -> None:
self.store: StorageImpl = InMemoryStore()
self.volume_id: str = id_func()
self._executor = OnceCell[SequentialExecutor]()

async def get_executor(self) -> SequentialExecutor:
async def initializer() -> SequentialExecutor:
executor = SequentialExecutor()
await executor.start_worker()
return executor

return await self._executor.get_or_init(initializer=initializer)

@classmethod
async def spawn(
@@ -56,13 +66,17 @@ async def get_id(self) -> str:
async def put(
self, key: str, transport_buffer: TransportBuffer, request: Request
) -> None:
await self.store.put(key, transport_buffer, request)
await self.store.put(
key, transport_buffer, request, executor=await self.get_executor()
)

@endpoint
async def get(
self, key: str, transport_buffer: TransportBuffer, request: Request
) -> TransportBuffer:
return await self.store.get(key, transport_buffer, request)
return await self.store.get(
key, transport_buffer, request, executor=await self.get_executor()
)

@endpoint
async def get_meta(
Expand All @@ -81,13 +95,23 @@ class StorageImpl:
"""Abstract base class for storage implementations."""

async def put(
self, key: str, transport_buffer: TransportBuffer, request: Request
self,
key: str,
transport_buffer: TransportBuffer,
request: Request,
*,
        executor: SequentialExecutor | None = None,
) -> Optional[TransportBuffer]:
"""Store data in the storage backend."""
raise NotImplementedError()

async def get(
self, key: str, transport_buffer: TransportBuffer, request: Request
self,
key: str,
transport_buffer: TransportBuffer,
request: Request,
*,
        executor: SequentialExecutor | None = None,
) -> TransportBuffer:
"""Retrieve data from the storage backend."""
raise NotImplementedError()
@@ -112,7 +136,7 @@ def __init__(self) -> None:
def _build_full_tensor(self, key: str) -> None:
logger.debug(f"Building full tensor for {key}")
# we can also consider in the future not requiring the full tensor to be
# assembled, and instead only that the requested offsets are available
        # assembled, and instead only that the requested offsets are available
# this is a performance optimization, but could be tricky to implement.
assert self._has_full_tensor(key)

@@ -186,23 +210,33 @@ def _handle_dtensor(
}

async def put(
self, key: str, transport_buffer: TransportBuffer, request: Request
self,
key: str,
transport_buffer: TransportBuffer,
request: Request,
*,
        executor: SequentialExecutor | None = None,
) -> None:
if request.is_object:
self.kv[key] = {"obj": request.objects}
return

# since we pass tensor=None to the transport buffer,
# we allocate on the fly
tensor = await transport_buffer.read_into(tensor=None)
tensor = await transport_buffer.read_into(tensor=None, executor=executor)
if request.tensor_slice is not None:
self._handle_dtensor(key, request.tensor_slice, tensor)
return

self.kv[key] = tensor

async def get(
self, key: str, transport_buffer: TransportBuffer, request: Request
self,
key: str,
transport_buffer: TransportBuffer,
request: Request,
*,
        executor: SequentialExecutor | None = None,
) -> TransportBuffer:

if key not in self.kv:
@@ -216,7 +250,7 @@ async def get(
return transport_buffer

if request.tensor_slice is None:
await transport_buffer.write_from(self.kv[key])
await transport_buffer.write_from(self.kv[key], executor=executor)
return transport_buffer

# TODO: