127 changes: 84 additions & 43 deletions README.md
@@ -14,6 +14,10 @@ Key Features:
- Multiple transport backends (RDMA, regular TCP) for optimal performance
- Flexible storage volume management and sharding strategies

Note: Although this may change in the future, TorchStore currently only supports multi-process/multi-node jobs launched with Monarch.
For more information on Monarch, see https://github.com/meta-pytorch/monarch?tab=readme-ov-file#monarch-


> ⚠️ **Early Development Warning** TorchStore is currently in an experimental
> stage. You should expect bugs, incomplete features, and APIs that may change
> in future versions. The project welcomes bugfixes, but to make sure things are
@@ -51,8 +55,13 @@ pip install -e .

# Install development dependencies
pip install -e '.[dev]'

# NOTE: It's common to run into libpytorch issues. A good workaround is to export:
# export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:${LD_LIBRARY_PATH:-}"
```



### Regular Installation

To install the package directly from the repository:
@@ -67,78 +76,110 @@ Once installed, you can import it in your Python code:
import torchstore
```

Note: Setup currently assumes you have a working conda environment with both torch and monarch already installed (removing this assumption is a todo).

## Usage

```python
import asyncio

import torch

from monarch.actor import Actor, current_rank, endpoint

import torchstore as ts
from torchstore.utils import spawn_actors


WORLD_SIZE = 4


# In Monarch, Actors are the way we represent multi-process/multi-node applications.
# For additional details, see:
# https://github.com/meta-pytorch/monarch?tab=readme-ov-file#monarch-
class ExampleActor(Actor):
    def __init__(self, world_size=WORLD_SIZE):
        self.rank = current_rank().rank
        self.world_size = world_size

    @endpoint
    async def store_tensor(self):
        t = torch.tensor([self.rank])
        await ts.put(f"{self.rank}_tensor", t)

    @endpoint
    async def print_tensor(self):
        other_rank = (self.rank + 1) % self.world_size
        t = await ts.get(f"{other_rank}_tensor")
        print(f"Rank=[{self.rank}] Fetched {t} from {other_rank=}")


async def main():
    # Create a store instance
    await ts.initialize()

    # Spawn WORLD_SIZE actor processes
    actors = await spawn_actors(WORLD_SIZE, ExampleActor, "example_actors")

    # Calls "store_tensor" on each actor instance
    await actors.store_tensor.call()
    await actors.print_tensor.call()


if __name__ == "__main__":
    asyncio.run(main())

# Expected output
# [0] [2] Rank=[2] Fetched tensor([3]) from other_rank=3
# [0] [0] Rank=[0] Fetched tensor([1]) from other_rank=1
# [0] [3] Rank=[3] Fetched tensor([0]) from other_rank=0
# [0] [1] Rank=[1] Fetched tensor([2]) from other_rank=2
```

### Resharding Support with DTensor

TorchStore makes it easy to fetch arbitrary slices of any Distributed Tensor (DTensor).
For a full DTensor example, see [examples/dtensor.py](https://github.com/meta-pytorch/torchstore/blob/main/example/dtensor.py).

```python
class DTensorActor(Actor):
    """
    Example pseudo-code for an Actor utilizing DTensor support.

    Full actor definition in [examples/dtensor.py](https://github.com/meta-pytorch/torchstore/blob/main/example/dtensor.py)
    """

    @endpoint
    async def do_put(self):
        # Typical DTensor boilerplate
        self.initialize_distributed()
        device_mesh = init_device_mesh("cpu", self.mesh_shape)
        tensor = self.original_tensor.to("cpu")
        dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)

        print(f"Calling put with {dtensor=}")
        # This will place only the local shard into TorchStore
        await ts.put(self.shared_key, dtensor)

    @endpoint
    async def do_get(self):
        # Typical DTensor boilerplate
        self.initialize_distributed()
        device_mesh = init_device_mesh("cpu", self.mesh_shape)
        tensor = self.original_tensor.to("cpu")
        dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)

        # TorchStore uses the metadata in the local dtensor to fetch only the
        # tensor data that belongs to the local shard.
        fetched_tensor = await ts.get(self.shared_key, dtensor)
        print(fetched_tensor)

# Check out tests/test_resharding.py for more end-to-end examples of resharding DTensors.
```
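
For context, here is a minimal driver sketch (not taken from the repository) showing how such an actor might be run, reusing the `spawn_actors` helper and `ts.initialize()` call from the Usage example above. The `DTensorActor` constructor arguments (`original_tensor`, `mesh_shape`, `placements`, `shared_key`) are assumed to be wired up as in `examples/dtensor.py`.

```python
# Hypothetical driver sketch for the pseudo-code above (not the actual example code).
# It assumes the spawn_actors helper from the Usage section and that DTensorActor's
# constructor arguments are supplied as in examples/dtensor.py.
import asyncio

import torchstore as ts
from torchstore.utils import spawn_actors


async def main():
    await ts.initialize()

    # One DTensorActor per rank; each rank owns one shard of the DTensor.
    actors = await spawn_actors(4, DTensorActor, "dtensor_actors")

    await actors.do_put.call()  # each rank stores only its local shard
    await actors.do_get.call()  # each rank fetches only the data it owns


if __name__ == "__main__":
    asyncio.run(main())
```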

## Testing

Pytest is used for testing. For an example of how to run tests (and get logs), see:

`TORCHSTORE_LOG_LEVEL=DEBUG pytest -vs --log-cli-level=DEBUG tests/test_models.py::test_basic`
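
As a rough illustration of what a test can look like, here is a hypothetical sketch (not from the repository's test suite) that reuses the `ExampleActor` from the Usage section; the real end-to-end tests live under `tests/`.

```python
# Hypothetical test sketch, assuming ExampleActor from the Usage section is importable.
import asyncio

import torchstore as ts
from torchstore.utils import spawn_actors


def test_store_and_fetch_roundtrip():
    async def _run():
        await ts.initialize()
        actors = await spawn_actors(4, ExampleActor, "test_actors")
        await actors.store_tensor.call()
        # Each rank should be able to fetch its neighbor's tensor without raising.
        await actors.print_tensor.call()

    asyncio.run(_run())
```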

## License

8 changes: 5 additions & 3 deletions pyproject.toml
@@ -4,23 +4,25 @@ build-backend = "setuptools.build_meta"

[project]
name = "torchstore"
version = "0.0.0dev"
description = "A storage solution for PyTorch tensors with distributed tensor support"
readme = "README.md"
authors = [
    { name = "PyTorch Team", email = "packages@pytorch.org" },
]
requires-python = ">=3.10"
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
    "pygtrie",
    "torch==2.9.0",
    "torchmonarch==0.1.0rc4"
]  # todo: add toml

[project.urls]