89 changes: 2 additions & 87 deletions tests/test_resharding.py
@@ -16,100 +16,15 @@

import torchstore as ts

from monarch.actor import Actor, current_rank, endpoint
from torch.distributed._tensor import distribute_tensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset
from torchstore.utils import get_local_tensor, spawn_actors

from .utils import main, transport_plus_strategy_params
from .utils import DTensorActor, main, transport_plus_strategy_params

logger = getLogger(__name__)


class DTensorActor(Actor):
"""Test class used to verify correctness of resharding across different shardings.
Currently only supports a single tensor
"""

shared_key = "test_key"

def __init__(
self,
mesh_shape,
original_tensor,
placements,
file_store_name,
visible_devices="0,1,2,3,4,5,6,7",
):
self.rank = current_rank().rank
self.mesh_shape = mesh_shape
self.world_size = math.prod(mesh_shape)
self.original_tensor = original_tensor
self.placements = placements
self.file_store_name = file_store_name

# torchstore will fail without this (see LocalRankStrategy)
os.environ["LOCAL_RANK"] = str(self.rank)

# this is only necessary for nccl, but we're not using it in this test.
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices

def rlog(self, msg):
# TODO: set to 'info' once this is fixed in monarch (which currently is hiding logs :/)
logger.info(f"rank: {self.rank} {msg}")

def initialize_distributed(self):
self.rlog(f"Initialize process group using {self.file_store_name=} ")
torch.distributed.init_process_group(
backend="gloo",
rank=self.rank,
world_size=self.world_size,
init_method=f"file://{self.file_store_name}",
)

        # this barrier is mostly here to make sure torch.distributed is working
        self.rlog("barrier")
torch.distributed.barrier()

@endpoint
async def do_put(self):
self.initialize_distributed()

self.rlog("Create device mesh")
device_mesh = init_device_mesh("cpu", self.mesh_shape)

self.rlog("distributing dtensor")
tensor = self.original_tensor.to("cpu")
dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)

self.rlog(f"calling put with {dtensor=}")
await ts.put(self.shared_key, dtensor)

@endpoint
async def do_get(self):
self.initialize_distributed()

self.rlog("Create device mesh")
# TODO: nccl is giving me a weird error on process group split for 2d mesh
device_mesh = init_device_mesh("cpu", self.mesh_shape)

self.rlog("distributing dtensor")
tensor = self.original_tensor.to("cpu")
dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)

self.rlog(f"calling get with {dtensor=}")
fetched_tensor = await ts.get(self.shared_key, dtensor)
self.rlog(f"after fetch: {dtensor=}")
assert torch.equal(dtensor, fetched_tensor)

return fetched_tensor, device_mesh.get_coordinate()

@endpoint
async def destroy_process_group(self):
torch.distributed.destroy_process_group()


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_1d_resharding(strategy_params, use_rdma):
136 changes: 135 additions & 1 deletion tests/test_store.py
@@ -16,10 +16,14 @@
import torchstore as ts

from monarch.actor import Actor, current_rank, endpoint

# DTensor imports for DTensor slice testing
from torch.distributed._tensor import Shard
from torchstore.logging import init_logging
from torchstore.transport.pipe import TensorSlice
from torchstore.utils import spawn_actors

from .utils import main, transport_plus_strategy_params
from .utils import DTensorActor, main, transport_plus_strategy_params

init_logging()
logger = getLogger(__name__)
@@ -216,6 +220,102 @@ async def exists(self, key):
await ts.shutdown()


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_get_tensor_slice(strategy_params, use_rdma):
"""Test tensor slice API functionality"""
os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

class TensorSlicePutActor(Actor):
"""Actor for putting tensors."""

def __init__(self, world_size):
init_logging()
self.world_size = world_size
self.rank = current_rank().rank
# required by LocalRankStrategy
os.environ["LOCAL_RANK"] = str(self.rank)

@endpoint
async def put(self, key, tensor):
await ts.put(key, tensor)

volume_world_size, strategy = strategy_params
await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy)

    # Spawn the put actors in their own mesh so the get below crosses process boundaries
put_actor_mesh = await spawn_actors(
volume_world_size,
TensorSlicePutActor,
"tensor_slice_put_actors",
world_size=volume_world_size,
)

try:
test_tensor = torch.randn(1000, 2000)
key = "test_tensor"

# Store the tensor using put actor mesh
put_actor = put_actor_mesh.slice(**{"hosts": 0, "gpus": 0})
await put_actor.put.call(key, test_tensor)

# Test full tensor retrieval using get actor mesh
retrieved_tensor = await ts.get(key)
assert torch.equal(test_tensor, retrieved_tensor)

# Test slice retrieval using get actor mesh
tensor_slice_spec = TensorSlice(
offsets=(100, 200),
coordinates=(),
global_shape=(1000, 2000),
local_shape=(50, 100),
mesh_shape=(),
)

tensor_slice = await ts.get(key, tensor_slice_spec=tensor_slice_spec)
expected_slice = test_tensor[100:150, 200:300]
assert torch.equal(tensor_slice, expected_slice)
assert tensor_slice.shape == (50, 100)

finally:
await put_actor_mesh._proc_mesh.stop()
await ts.shutdown()


@pytest.mark.asyncio
async def test_tensor_slice_inplace():
Review comment (Contributor): can we add a test such that we call put on a dtensor, and then call get with no tensor slice and no dtensor? The result should be the entire tensor.

Reply (Author): Test added. Also factored DTensorActor out from test_resharding.py into utils.py for reuse in test_store.py.

"""Test tensor slice API with in-place operations"""
await ts.initialize(num_storage_volumes=1)

try:
# Store a test tensor
test_tensor = torch.randn(100, 200)
await ts.put("inplace_test", test_tensor)
Review comment (Contributor): This is a cool thing to add to the readme / docs. (A doc-style sketch of the pattern follows this test.)

# Test in-place retrieval with slice
slice_spec = TensorSlice(
offsets=(10, 20),
coordinates=(),
global_shape=(100, 200),
local_shape=(30, 40),
mesh_shape=(),
)

# Create pre-allocated buffer
slice_buffer = torch.empty(30, 40)
result = await ts.get(
"inplace_test", inplace_tensor=slice_buffer, tensor_slice_spec=slice_spec
)

# Verify in-place operation
assert result is slice_buffer
expected_slice = test_tensor[10:40, 20:60]
assert torch.equal(slice_buffer, expected_slice)

finally:
await ts.shutdown()
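
Following up on the review note above, here is a minimal doc-style sketch of the in-place slice retrieval pattern exercised by test_tensor_slice_inplace. It assumes the store is already initialized and a (100, 200) tensor has already been stored; the key name "weights" and the wrapper function are illustrative only, not part of the torchstore API.

import torch
import torchstore as ts
from torchstore.transport.pipe import TensorSlice

async def fetch_window_inplace():
    # Describe the sub-region to fetch: a 30x40 window whose top-left corner
    # sits at offsets (10, 20) of a tensor with global shape (100, 200).
    slice_spec = TensorSlice(
        offsets=(10, 20),
        coordinates=(),
        global_shape=(100, 200),
        local_shape=(30, 40),
        mesh_shape=(),
    )

    # Pre-allocate the destination; ts.get fills it in place and returns the
    # same buffer object (the test above asserts `result is slice_buffer`).
    buffer = torch.empty(30, 40)
    result = await ts.get("weights", inplace_tensor=buffer, tensor_slice_spec=slice_spec)
    assert result is buffer
    return buffer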


@pytest.mark.asyncio
async def test_large_tensors():
"""Test basic put/get functionality for large tensors"""
@@ -291,5 +391,39 @@ async def get(self):
# TODO: assert equal tensors from put/get


@pytest.mark.asyncio
async def test_put_dtensor_get_full_tensor():
"""Test basic DTensor put/get functionality with separate put and get meshes using shared DTensorActor"""
import tempfile

await ts.initialize(num_storage_volumes=2, strategy=ts.LocalRankStrategy())

original_tensor = torch.arange(16).reshape(4, 4).float()

with tempfile.TemporaryDirectory() as filesystem_store_dir:
try:
put_mesh = await spawn_actors(
2,
DTensorActor,
"dtensor_put_mesh",
mesh_shape=(2,),
original_tensor=original_tensor,
placements=[Shard(0)],
file_store_name=os.path.join(filesystem_store_dir, "put_test"),
visible_devices="0,1",
)

await put_mesh.do_put.call()

fetched_tensor = await ts.get("test_key")
Review comment (Contributor): yay!

assert torch.equal(original_tensor, fetched_tensor)

finally:
# Clean up process groups
await put_mesh.destroy_process_group.call()
await put_mesh._proc_mesh.stop()
await ts.shutdown()


if __name__ == "__main__":
main(__file__)
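
The review thread above asked for a put-DTensor / get-full-tensor path, and test_put_dtensor_get_full_tensor covers it end to end with Monarch actors. As a rough sketch of just the API flow, stripped of the actor/mesh plumbing and assuming torch.distributed and torchstore are already initialized on each participating rank (the function name is illustrative):

import torch
import torchstore as ts
from torch.distributed._tensor import distribute_tensor, Shard
from torch.distributed.device_mesh import init_device_mesh

async def put_sharded_then_get_full(original: torch.Tensor):
    # Each rank shards the tensor over a 1-D CPU mesh and puts its shard
    # under a shared key.
    device_mesh = init_device_mesh("cpu", (2,))
    dtensor = distribute_tensor(original, device_mesh, placements=[Shard(0)])
    await ts.put("test_key", dtensor)

    # A plain get (no DTensor argument, no tensor_slice_spec) reassembles
    # and returns the entire tensor.
    fetched = await ts.get("test_key")
    assert torch.equal(original, fetched)
    return fetched
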
92 changes: 92 additions & 0 deletions tests/utils.py
@@ -4,10 +4,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import os
from itertools import product
from logging import getLogger

import pytest
import torch
import torchstore as ts
from monarch.actor import Actor, current_rank, endpoint
from torch.distributed._tensor import distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

logger = getLogger(__name__)

def main(file):
ts.init_logging()
@@ -24,3 +33,86 @@ def transport_plus_strategy_params():
rdma_options = [True, False]

return "strategy_params, use_rdma", list(product(strategies, rdma_options))


class DTensorActor(Actor):
"""Test class used to verify correctness of resharding across different shardings.
Currently only supports a single tensor
"""

shared_key = "test_key"

def __init__(
self,
mesh_shape,
original_tensor,
placements,
file_store_name,
visible_devices="0,1,2,3,4,5,6,7",
):
self.rank = current_rank().rank
self.mesh_shape = mesh_shape
self.world_size = math.prod(mesh_shape)
self.original_tensor = original_tensor
self.placements = placements
self.file_store_name = file_store_name

# torchstore will fail without this (see LocalRankStrategy)
os.environ["LOCAL_RANK"] = str(self.rank)

# this is only necessary for nccl, but we're not using it in this test.
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices

def rlog(self, msg):
# TODO: set to 'info' once this is fixed in monarch (which currently is hiding logs :/)
logger.info(f"rank: {self.rank} {msg}")

def initialize_distributed(self):
self.rlog(f"Initialize process group using {self.file_store_name=} ")
torch.distributed.init_process_group(
backend="gloo",
rank=self.rank,
world_size=self.world_size,
init_method=f"file://{self.file_store_name}",
)

        # this barrier is mostly here to make sure torch.distributed is working
        self.rlog("barrier")
torch.distributed.barrier()

@endpoint
async def do_put(self):
self.initialize_distributed()

self.rlog("Create device mesh")
device_mesh = init_device_mesh("cpu", self.mesh_shape)

self.rlog("distributing dtensor")
tensor = self.original_tensor.to("cpu")
dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)

self.rlog(f"calling put with {dtensor=}")
await ts.put(self.shared_key, dtensor)

@endpoint
async def do_get(self):
self.initialize_distributed()

self.rlog("Create device mesh")
# TODO: nccl is giving me a weird error on process group split for 2d mesh
device_mesh = init_device_mesh("cpu", self.mesh_shape)

self.rlog("distributing dtensor")
tensor = self.original_tensor.to("cpu")
dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)

self.rlog(f"calling get with {dtensor=}")
fetched_tensor = await ts.get(self.shared_key, dtensor)
self.rlog(f"after fetch: {dtensor=}")
assert torch.equal(dtensor, fetched_tensor)

return fetched_tensor, device_mesh.get_coordinate()

@endpoint
async def destroy_process_group(self):
torch.distributed.destroy_process_group()