Merged
5 changes: 5 additions & 0 deletions README.md
@@ -138,3 +138,8 @@ if __name__ == "__main__":
2. Build the Bicycle, not the super car -- Develop value iteratively, instead of trying to ship everything at once.

3. Work backwards from use cases, and leave tests!

# Testing

Pytest is used for testing. For an example of how to run tests (and get logs), see:
`TORCHSTORE_LOG_LEVEL=DEBUG pytest -vs --log-cli-level=DEBUG tests/test_models.py::test_main`
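
For orientation, here is a minimal sketch of what an async round-trip test could look like with the public API touched in this PR (`ts.initialize`, `ts.put_state_dict`, `ts.get_state_dict`, `ts.shutdown`). The exact `get_state_dict` signature and the use of `pytest-asyncio` are assumptions, not part of this diff:

```python
import pytest
import torch
import torchstore as ts


@pytest.mark.asyncio  # assumes pytest-asyncio (or an equivalent plugin) drives async tests
async def test_roundtrip_state_dict():
    await ts.initialize(num_storage_volumes=1)
    try:
        state_dict = {"layer.weight": torch.randn(4, 4)}
        await ts.put_state_dict(state_dict, "v0")
        fetched = await ts.get_state_dict("v0")  # signature assumed to mirror put_state_dict
        torch.testing.assert_close(fetched["layer.weight"], state_dict["layer.weight"])
    finally:
        await ts.shutdown()
```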
20 changes: 11 additions & 9 deletions tests/test_models.py
@@ -11,14 +11,15 @@
from logging import getLogger

import pytest

import torch

import torchstore as ts
from monarch.actor import Actor, current_rank, endpoint
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torchstore.logging import init_logging
from torchstore.utils import spawn_actors
from torchstore.state_dict_utils import _state_dict_size

from transformers import AutoModelForCausalLM

@@ -70,6 +71,7 @@ def build_model(self):
model = AutoModelForCausalLM.from_pretrained(
TEST_MODEL, token=os.environ["HF_TOKEN"]
)
self.rlog(f"State dict size: {_state_dict_size(model.state_dict())}")
if self.world_size > 1:
self.initialize_distributed()
self.rlog("sharding")
@@ -83,6 +85,8 @@ def build_model(self):
return model, optimizer

def rlog(self, msg):
print(f"rank: {self.rank} {msg}")
self.logger.info(f"rank: {self.rank} {msg}")
logger.info(f"rank: {self.rank} {msg}")

@endpoint
@@ -99,7 +103,7 @@ async def do_push(self):
self.rlog("pushing state dict")
t = time.perf_counter()
await ts.put_state_dict(state_dict, "v0")
self.rlog(f"pushed state dict in {time.perf_counter()-t} seconds")
self.rlog(f"pushed state dict in {time.perf_counter() - t} seconds")

@endpoint
async def do_get(self):
@@ -138,6 +142,9 @@ async def test_resharding(strategy_params, use_rdma):
async def _do_test(put_mesh_shape, get_mesh_shape, strategy, use_rdma):
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

ts.init_logging()
logger.info(f"Testing with strategy: {strategy}")

put_world_size = math.prod(put_mesh_shape)
await ts.initialize(
num_storage_volumes=put_world_size if strategy is not None else 1,
@@ -163,18 +170,13 @@ async def _do_test(put_mesh_shape, get_mesh_shape, strategy, use_rdma):
file_store_name=os.path.join(tmpdir, "get_world"),
)

logger.info("pushing state dict")
t = time.perf_counter()
logger.info(f"do_push ")
await put_world.do_push.call()
logger.info(f"pushing state dict took: {time.perf_counter() - t} seconds")

logger.info("fetching state dict")
t = time.perf_counter()

await get_world.do_get.call()
logger.info(f"getting state dict took: {time.perf_counter()-t} seconds")
finally:
await ts.shutdown()


if __name__ == "__main__":
main([__file__])
2 changes: 1 addition & 1 deletion tests/utils.py
@@ -20,6 +20,6 @@ def transport_plus_strategy_params():
(2, ts.LocalRankStrategy()),
(1, None), # singleton
]
rdma_options = [False] # , True] broken on my build
rdma_options = [True, False]

return "strategy_params, use_rdma", list(product(strategies, rdma_options))
2 changes: 2 additions & 0 deletions torchstore/__init__.py
@@ -9,6 +9,7 @@

from torchstore.api import (
client,
reset_client,
exists,
get,
get_state_dict,
@@ -44,4 +45,5 @@
"SingletonStrategy",
"put_state_dict",
"get_state_dict",
"reset_client"
]
25 changes: 17 additions & 8 deletions torchstore/api.py
@@ -17,7 +17,7 @@


# I need to keep this somewhere, so here we go
DEFAULT_TORCHSTORE_NAME: str = "TorchStoreController"
DEFAULT_TORCHSTORE_NAME: str = "TorchStore"

# cache for local clients
_local_clent_map: Dict[str, LocalClient] = {}
@@ -27,6 +27,7 @@ async def initialize(
num_storage_volumes: int = 1,
strategy: Optional[TorchStoreStrategy] = None,
store_name: str = DEFAULT_TORCHSTORE_NAME,
mesh=None,
) -> None:
"""Initialize the TorchStore distributed storage system.

@@ -56,13 +57,10 @@
# TODO: monarch doesn't support nested actors yet, so we need to spawn storage volumes here
# ideally this is done in the controller.init
storage_volumes = await StorageVolume.spawn(
num_volumes=num_storage_volumes, id_func=strategy.get_volume_id
num_volumes=num_storage_volumes, mesh=mesh, id_func=strategy.get_volume_id
)

controller = await get_or_spawn_controller(
store_name,
Controller,
)
controller = await _controller(store_name)
await controller.init.call(
strategy=strategy,
num_storage_volumes=num_storage_volumes,
@@ -83,12 +81,23 @@ async def shutdown(store_name: str = DEFAULT_TORCHSTORE_NAME) -> None:
>>> await ts.shutdown() # Shutdown default store
>>> await ts.shutdown("my_custom_store")
"""
controller = await get_or_spawn_controller(store_name, Controller)
controller = await _controller(store_name)
await controller.teardown.call()
global _local_clent_map
_local_clent_map = {}


def reset_client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> None:
"""Reset the local client for a given store. Useful for refreshing client state after shutdown.
"""
global _local_clent_map
_local_clent_map.pop(store_name, None)

async def _controller(store_name: str = DEFAULT_TORCHSTORE_NAME) -> Controller:
"""Get a controller handle for interacting with the store."""
return await get_or_spawn_controller(store_name, Controller)


async def client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> LocalClient:
"""Get a local client handle for interacting with the store.

@@ -107,7 +116,7 @@ async def client(store_name: str = DEFAULT_TORCHSTORE_NAME) -> LocalClient:
if store_name in _local_clent_map:
return _local_clent_map[store_name]

controller = await get_or_spawn_controller(store_name, Controller)
controller = await _controller(store_name)
controller_strategy = await controller.get_controller_strategy.call_one()

local_client = LocalClient(
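
As a usage note, a minimal sketch (not part of this diff) of how `reset_client` can be combined with the existing helpers; the re-initialization flow below is an assumption based on the docstring:

```python
import torchstore as ts


async def restart_store(store_name: str = "TorchStore"):
    await ts.shutdown(store_name)   # tears down the controller and clears all cached clients
    ts.reset_client(store_name)     # explicitly drop this store's cached LocalClient, if any
    await ts.initialize(num_storage_volumes=1, store_name=store_name)
    return await ts.client(store_name)  # rebuilt from the controller's strategy on next use
```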
25 changes: 21 additions & 4 deletions torchstore/client.py
@@ -8,9 +8,11 @@
from typing import Any, Optional, Union

import torch
import time

from torchstore.controller import ObjectType

from torchstore.logging import LatencyTracker
from torchstore.transport import Pipe, Request
from torchstore.utils import assemble_global_tensor, get_local_tensor

@@ -32,8 +34,7 @@ def __init__(

@torch.no_grad
async def put(self, key: str, value: Union[torch.Tensor, Any]):
logger.debug(f"Putting {key}")

latency_tracker = LatencyTracker(f"put:{key}")
request = Request.from_any(value)
# for now, we only write to one storage volume.
# we probably don't need a remote call for this case since
@@ -44,11 +45,16 @@ async def put(self, key: str, value: Union[torch.Tensor, Any]):
pipe = Pipe(storage_volume)

await pipe.put_to_storage_volume(key, request)
await self._controller.notify_put.call(key, request, volume_id)
latency_tracker.track_step("put_to_storage_volume")

await self._controller.notify_put.call(key, request.meta_only(), volume_id)
latency_tracker.track_step("notify_put")
latency_tracker.track_e2e()


@torch.no_grad
async def get(self, key: str, inplace_tensor: Optional[torch.Tensor] = None):
logger.debug(f"Fetching {key}")
latency_tracker = LatencyTracker(f"get:{key}")
request = Request.from_any(inplace_tensor)
object_type = ObjectType.from_request(request)

@@ -64,6 +70,8 @@ async def get(self, key: str, inplace_tensor: Optional[torch.Tensor] = None):
# TODO: in the future, we could intelligently select the best storage volume
# but for now any should work.
fetched_tensor = await pipe.get_from_storage_volume(key, request)
latency_tracker.track_step("get_from_storage_volume")
latency_tracker.track_e2e()
return fetched_tensor if inplace_tensor is None else inplace_tensor

# else: this is the dtensor / tensor slice case
@@ -81,6 +89,8 @@ async def get(self, key: str, inplace_tensor: Optional[torch.Tensor] = None):
assert partial_results, "No partial results found"
assert request.tensor_slice is not None

latency_tracker.track_step("get_from_storage_volume")

# build the entire tensor.
# TODO: again, we should have better control over
# rebuilding only the portion I need, but this is a good start
@@ -114,12 +124,19 @@ async def get(self, key: str, inplace_tensor: Optional[torch.Tensor] = None):
request.tensor_slice.local_shape,
request.tensor_slice.offsets,
)

latency_tracker.track_step("assemble_tensor")
t = time.perf_counter()
# Pipe does not have support for inplace copies of fetched tensors yet,
# so we just copy
if inplace_tensor is not None:
assert request.tensor_val is not None
request.tensor_val.copy_(fetched_tensor)
latency_tracker.track_step("copy")
latency_tracker.track_e2e()
return inplace_tensor

latency_tracker.track_e2e()
return fetched_tensor

async def exists(self, key: str) -> bool:
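
To illustrate the instrumented paths above, a small sketch of a put/get round trip through the local client; the key name and tensor are illustrative only:

```python
import torch

import torchstore as ts


async def roundtrip_tensor():
    local_client = await ts.client()  # cached LocalClient for the default store
    value = torch.randn(8, 8)
    await local_client.put("demo/tensor", value)    # timed as put:demo/tensor by LatencyTracker
    fetched = await local_client.get("demo/tensor")  # timed as get:demo/tensor
    assert torch.equal(fetched, value)
    return fetched
```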
7 changes: 7 additions & 0 deletions torchstore/controller.py
@@ -137,6 +137,9 @@ def notify_put(self, key: str, request: Request, storage_volume_id: str) -> None
storage_volume_id (str): ID of the storage volume where the data was stored.
"""
self.assert_initialized()
assert request.tensor_val is None, (
f"request should not contain tensor data, as this will significantly increase e2e latency"
)

if key not in self.keys_to_storage_volumes:
self.keys_to_storage_volumes[key] = {}
@@ -158,3 +161,7 @@ def teardown(self) -> None:
self.strategy = None
self.storage_volumes = None
self.num_storage_volumes = None

@endpoint
def get_keys_to_storage_volumes(self) -> Dict[str, Dict[str, StorageInfo]]:
return self.keys_to_storage_volumes
21 changes: 21 additions & 0 deletions torchstore/logging.py
@@ -4,14 +4,35 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import time
import logging
import os
import sys


def init_logging():
log_level = os.environ.get("TORCHSTORE_LOG_LEVEL", "INFO").upper()

logging.root.setLevel(log_level)
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(log_level)

# Check if a StreamHandler to sys.stdout is already present
for handler in logging.root.handlers:
if isinstance(handler, logging.StreamHandler) and getattr(handler, 'stream', None) == sys.stdout:
# Already has a stdout handler, no need to add another
return
logging.root.addHandler(stdout_handler)

class LatencyTracker:
def __init__(self, name: str) -> None:
self.name = name
self.last_step = self.start_time = time.perf_counter()

def track_step(self, step_name: str) -> None:
now = time.perf_counter()
logging.debug(f"{self.name}:{step_name} took {now - self.last_step} seconds")
self.last_step = now

def track_e2e(self) -> None:
logging.debug(f"{self.name} took {time.perf_counter() - self.start_time} seconds")
14 changes: 12 additions & 2 deletions torchstore/state_dict_utils.py
@@ -55,8 +55,7 @@ async def get_state_dict(
else ({}, None)
)
if strict and user_mapping is not None:
pass
# assert user_mapping == fetched_mapping
assert user_mapping == fetched_mapping

fetched_state_dict = {}
for flattened_key in fetched_mapping.keys():
@@ -84,3 +83,14 @@ async def get_state_dict(
# fetched_state_dict = dict(zip(keys, results))

return unflatten_state_dict(fetched_state_dict, fetched_mapping)

def _state_dict_size(state_dict):
"""Returns the size of the state dict in MBs"""
size = 0
sd, _ = flatten_state_dict(state_dict)
for tensor in sd.values():
if not isinstance(tensor, torch.Tensor):
continue

size += tensor.numel() * tensor.element_size()
return size // (1024 * 1024)
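
For reference, a tiny example of the helper above; the model is a stand-in, and `_state_dict_size` reports whole megabytes via integer division:

```python
import torch

from torchstore.state_dict_utils import _state_dict_size

model = torch.nn.Linear(1024, 1024)  # ~4 MB of float32 parameters
print(f"state dict size: {_state_dict_size(model.state_dict())} MB")
```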
8 changes: 5 additions & 3 deletions torchstore/storage_volume.py
@@ -36,12 +36,14 @@ def __init__(

@classmethod
async def spawn(
cls, num_volumes: int, *init_args: Any, **init_kwargs: Any
cls, num_volumes: int, mesh, *init_args: Any, **init_kwargs: Any,
) -> "StorageVolume":
return await spawn_actors(
num_volumes, cls, cls.actor_name, *init_args, **init_kwargs
actors = await spawn_actors(
num_volumes, cls, cls.actor_name, mesh, *init_args, **init_kwargs
)

return actors

@endpoint
async def get_id(self) -> str:
return self.volume_id
19 changes: 19 additions & 0 deletions torchstore/strategy.py
@@ -32,6 +32,10 @@ def __init__(self):
self.storage_volumes = None
self.volume_id_to_coord = {}

def __str__(self) -> str:
storage_vol_len = len(self.storage_volumes) if self.storage_volumes is not None else 0
return f"{self.__class__.__name__}(storage_volume_len={storage_vol_len})"

@classmethod
def get_volume_id(cls):
"""Get the unique ID for this process's storage volume. Called by volume on init.
@@ -120,6 +124,21 @@ async def set_storage_volumes(self, storage_volumes):
await super().set_storage_volumes(storage_volumes)


class HostStrategy(TorchStoreStrategy):
"""Assumes one storage volume per host.

Each process uses 'HOSTNAME' to determine which storage volume to connect to.
"""
@classmethod
def get_volume_id(cls):
# Note: this should only be called at spawn, which makes this safe.
return os.environ["HOSTNAME"]

@classmethod
def get_client_id(cls):
return os.environ["HOSTNAME"]


class LocalRankStrategy(TorchStoreStrategy):
"""Strategy that maps storage volumes based on LOCAL_RANK environment variable.

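
Finally, a hedged sketch of selecting the new `HostStrategy` at initialization; sizing `num_storage_volumes` to the host count is an assumption about intended usage, not something this diff specifies:

```python
import torchstore as ts
from torchstore.strategy import HostStrategy


async def init_per_host_store(num_hosts: int):
    # HostStrategy keys both storage volumes and clients off the HOSTNAME
    # environment variable, so every process on a host shares that host's volume.
    await ts.initialize(
        num_storage_volumes=num_hosts,
        strategy=HostStrategy(),
    )
```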