41 changes: 30 additions & 11 deletions README.md
@@ -2,6 +2,18 @@

A storage solution for PyTorch tensors with distributed tensor support.

TorchStore provides a distributed, asynchronous tensor storage system built on top of
Monarch actors. It enables efficient storage and retrieval of PyTorch tensors across
multiple processes and nodes with support for various transport mechanisms including
RDMA when available.

Key Features:
- Distributed tensor storage with configurable storage strategies
- Asynchronous put/get operations for tensors and arbitrary objects
- Support for PyTorch state_dict serialization/deserialization
- Multiple transport backends (RDMA, regular TCP) for optimal performance
- Flexible storage volume management and sharding strategies
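
Below is a minimal sketch of how these pieces fit together. The `num_storage_volumes` argument and the `TORCHSTORE_RDMA_ENABLED` environment toggle mirror how this repository's tests configure the store; treat them as illustrative rather than a stable public interface.

```python
import asyncio
import os

import torch
import torchstore as ts


async def demo():
    # Opt out of RDMA for this sketch; the tests flip the same env var.
    os.environ["TORCHSTORE_RDMA_ENABLED"] = "0"

    # Bring up the store. A single storage volume is enough for a local demo.
    await ts.initialize(num_storage_volumes=1)

    # Asynchronous put/get of a plain tensor.
    await ts.put("weights", torch.randn(3, 4))
    weights = await ts.get("weights")
    print(weights.shape)

    await ts.shutdown()


if __name__ == "__main__":
    asyncio.run(demo())
```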

# Under Construction!

Nothing to see here yet, but check back soon.
@@ -46,7 +58,7 @@ pip install git+https://github.com/your-username/torchstore.git
Once installed, you can import it in your Python code:

```python
from torchstore import MultiProcessStore
import torchstore
```

Note: Setup currently assumes a working conda environment with both torch & monarch installed (automating this is a todo). For now, the fastest way to get set up is to follow [this](https://www.internalfb.com/wiki/Monarch/Monarch_xlformers_integration/Running_Monarch_on_Conda/#how-to-run-monarch) guide.
@@ -58,18 +70,18 @@ Pro tip: Install finetune conda & use the 'local' option for the latest packages
```python
import torch
import asyncio
from torchstore import MultiProcessStore
import torchstore as ts

async def main():

# Create a store instance
store = await MultiProcessStore.create_store()
store = await ts.initialize()

# Store a tensor
await store.put("my_tensor", torch.randn(3, 4))
await ts.put("my_tensor", torch.randn(3, 4))

# Retrieve a tensor
tensor = await store.get("my_tensor")
tensor = await ts.get("my_tensor")


if __name__ == "__main__":
@@ -80,7 +92,7 @@ if __name__ == "__main__":
### Resharding Support with DTensor

```python
from torchstore import MultiProcessStore
import torchstore as ts
from torch.distributed._tensor import distribute_tensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh

@@ -89,11 +101,8 @@ async def place_dtensor_in_store():
tensor = torch.arange(4)
dtensor = distribute_tensor(tensor, device_mesh, placements=[Shard(1)])

# Create a store instance
store = await MultiProcessStore.create_store()

# Store a tensor
await store.put("my_tensor", dtensor)
await ts.put("my_tensor", dtensor)


async def fetch_dtensor_from_store():
@@ -107,7 +116,17 @@ async def fetch_dtensor_from_store()
)

# This line copies the previously stored dtensor into local memory.
await store.get("my_tensor", dtensor)
await ts.get("my_tensor", dtensor)

def run_in_parallel(func):
# just for demonstrative purposes
return func

if __name__ == "__main__":
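    # Note: ts.initialize() and ts.shutdown() are coroutines; in a real program
    # they must be awaited (e.g. via asyncio.run), as in the Quick Start above.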
ts.initialize()
run_in_parallel(place_dtensor_in_store)
run_in_parallel(fetch_dtensor_from_store)
ts.shutdown()

# Check out tests/test_resharding.py for more end-to-end examples of resharding DTensors.
```
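
Whole model state dicts follow the same pattern. The sketch below uses the `ts.put_state_dict` / `ts.get_state_dict` helpers exercised in `example/torchstore_rl.py` and `tests/test_models.py`; the key name is arbitrary.

```python
import asyncio

import torch
import torchstore as ts


async def state_dict_roundtrip():
    await ts.initialize()

    model = torch.nn.Linear(4, 4)

    # Push the model's full state dict under a single key.
    await ts.put_state_dict(model.state_dict(), key="model_v0")

    # Later (possibly from another process), copy the stored values back into
    # the tensors referenced by an existing state dict.
    await ts.get_state_dict(key="model_v0", user_state_dict=model.state_dict())

    await ts.shutdown()


if __name__ == "__main__":
    asyncio.run(state_dict_roundtrip())
```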
15 changes: 6 additions & 9 deletions example/torchstore_rl.py
@@ -10,9 +10,8 @@
from typing import Tuple

import torch
import torchstore as ts
from monarch.actor import Actor, current_rank, endpoint, proc_mesh
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import get_state_dict, push_state_dict


# Run the example: python example/torchstore_rl.py
@@ -44,7 +43,7 @@ async def step(
self.optim.step()
print("[learner] weights: ", self.model.state_dict())
# Put weights in to torch.store
await push_state_dict(self.store, self.model.state_dict(), key="toy_app")
await ts.put_state_dict(self.model.state_dict(), key="toy_app")


class Generator(Actor):
@@ -61,9 +60,7 @@ async def update_weights(self):
)
)
# Fetch weights from torch.store
await get_state_dict(
self.store, key="toy_app", user_state_dict=self.model.state_dict()
)
await ts.get_state_dict(key="toy_app", user_state_dict=self.model.state_dict())
print(
"[generator {}] new weights: {}".format(self.index, self.model.state_dict())
)
@@ -89,10 +86,10 @@ async def main():
learner_mesh = await proc_mesh(gpus=num_learners)
gen_mesh = await proc_mesh(gpus=num_generators)

store = await MultiProcessStore.create_store()
await ts.initialize()

learner = await learner_mesh.spawn("learner", Learner, store)
generators = await gen_mesh.spawn("generator", Generator, store)
learner = await learner_mesh.spawn("learner", Learner)
generators = await gen_mesh.spawn("generator", Generator)

logits, reward = await generators.generate.call_one(
torch.randn(4, 4, device="cuda")
10 changes: 3 additions & 7 deletions pyproject.toml
@@ -10,24 +10,20 @@ readme = "README.md"
authors = [
{ name = "PyTorch Team", email = "packages@pytorch.org" },
]
requires-python = ">=3.6"
requires-python = ">=3.9"
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [] #todo: add toml

[project.urls]
"Homepage" = "https://github.com/your-username/torchstore"
"Bug Tracker" = "https://github.com/your-username/torchstore/issues"
"Homepage" = "https://github.com/meta-pytorch/torchstore"
"Bug Tracker" = "https://github.com/meta-pytorch/torchstore/issues"

[project.optional-dependencies]
dev = [
5 changes: 5 additions & 0 deletions tests/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
70 changes: 39 additions & 31 deletions tests/test_models.py
@@ -8,24 +8,22 @@
import os
import tempfile
import time
import unittest
from logging import getLogger

import pytest

import torch

import torchstore as ts
from monarch.actor import Actor, current_rank, endpoint
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

from torchstore import MultiProcessStore
from torchstore._state_dict_utils import get_state_dict, push_state_dict
from torchstore.logging import init_logging
from torchstore.utils import spawn_actors

from transformers import AutoModelForCausalLM

from .utils import main, transport_plus_strategy_params

logger = getLogger(__name__)

needs_cuda = pytest.mark.skipif(
@@ -40,14 +38,15 @@


class ModelTest(Actor):
def __init__(self, store, mesh_shape, file_store_name):
init_logging()
def __init__(self, mesh_shape, file_store_name):
ts.init_logging()
self.rank = current_rank().rank
self.store = store
self.mesh_shape = mesh_shape
self.world_size = math.prod(mesh_shape)
self.file_store_name = file_store_name

os.environ["LOCAL_RANK"] = str(self.rank)

def initialize_distributed(self):
self.rlog(f"Initialize process group using {self.file_store_name=} ")
torch.distributed.init_process_group(
@@ -94,8 +93,8 @@ async def do_push(self):

self.rlog("pushing state dict")
t = time.perf_counter()
await push_state_dict(self.store, state_dict, "v0")
self.rlog(f"pushed state dict in {time.perf_counter() - t} seconds")
await ts.put_state_dict(state_dict, "v0")
self.rlog(f"pushed state dict in {time.perf_counter()-t} seconds")

@endpoint
async def do_get(self):
@@ -109,34 +108,43 @@ async def do_get(self):
torch.distributed.barrier()
self.rlog("getting state dict")
t = time.perf_counter()
await get_state_dict(self.store, "v0", state_dict)
await ts.get_state_dict("v0", state_dict)
self.rlog(f"got state dict in {time.perf_counter() - t} seconds")


@needs_cuda
class TestHFModel(unittest.IsolatedAsyncioTestCase):
async def test_basic(self):
# FSDP
put_mesh_shape = (1,)
get_mesh_shape = (1,)
await self._do_test(put_mesh_shape, get_mesh_shape)
@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_basic(strategy_params, use_rdma):
# FSDP
put_mesh_shape = (1,)
get_mesh_shape = (1,)
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], use_rdma)

async def test_resharding(self):
# FSDP
put_mesh_shape = (4,)
get_mesh_shape = (8,)
await self._do_test(put_mesh_shape, get_mesh_shape)

async def _do_test(self, put_mesh_shape, get_mesh_shape):
with tempfile.TemporaryDirectory() as tmpdir:
store = await MultiProcessStore.create_store()
@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_resharding(strategy_params, use_rdma):
# FSDP
put_mesh_shape = (4,)
get_mesh_shape = (8,)
await _do_test(put_mesh_shape, get_mesh_shape, strategy_params[1], use_rdma)


async def _do_test(put_mesh_shape, get_mesh_shape, strategy, use_rdma):
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

put_world_size = math.prod(put_mesh_shape)
await ts.initialize(
num_storage_volumes=put_world_size if strategy is not None else 1,
strategy=strategy,
)
try:
with tempfile.TemporaryDirectory() as tmpdir:
put_world_size = math.prod(put_mesh_shape)
put_world = await spawn_actors(
put_world_size,
ModelTest,
"save_world",
store=store,
mesh_shape=put_mesh_shape,
file_store_name=os.path.join(tmpdir, "save_world"),
)
@@ -146,7 +154,6 @@ async def _do_test(self, put_mesh_shape, get_mesh_shape):
get_world_size,
ModelTest,
"get_world",
store=store,
mesh_shape=get_mesh_shape,
file_store_name=os.path.join(tmpdir, "get_world"),
)
@@ -159,9 +166,10 @@ async def _do_test(self, put_mesh_shape, get_mesh_shape):
logger.info("fetching state dict")
t = time.perf_counter()
await get_world.do_get.call()
logger.info(f"getting state dict took: {time.perf_counter() - t} seconds")
logger.info(f"getting state dict took: {time.perf_counter()-t} seconds")
finally:
await ts.shutdown()


if __name__ == "__main__":
init_logging()
unittest.main()
main([__file__])