add optional tensor_slice_spec arg to get() API #32
Merged
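In short, ts.get() gains an optional tensor_slice_spec argument (composing with the existing inplace_tensor argument) so callers can fetch only a sub-region of a stored tensor. Below is a minimal usage sketch distilled from the tests in this PR; the key name, shapes, and the async wrapper are illustrative, not part of the API.

import asyncio

import torch
import torchstore as ts
from torchstore.transport.pipe import TensorSlice


async def fetch_slice_example():
    await ts.initialize(num_storage_volumes=1)
    try:
        # Store a full tensor under a key (key name made up for this sketch).
        full = torch.randn(1000, 2000)
        await ts.put("example_tensor", full)

        # Describe the sub-region to fetch: offsets are the start indices,
        # local_shape is the extent, global_shape is the stored tensor's shape.
        spec = TensorSlice(
            offsets=(100, 200),
            coordinates=(),
            global_shape=(1000, 2000),
            local_shape=(50, 100),
            mesh_shape=(),
        )

        # Fetch just that region; equivalent to full[100:150, 200:300].
        sliced = await ts.get("example_tensor", tensor_slice_spec=spec)
        assert torch.equal(sliced, full[100:150, 200:300])

        # The spec also composes with a pre-allocated buffer for in-place retrieval.
        buffer = torch.empty(50, 100)
        await ts.get("example_tensor", inplace_tensor=buffer, tensor_slice_spec=spec)
    finally:
        await ts.shutdown()


if __name__ == "__main__":
    asyncio.run(fetch_slice_example())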
Commits (9):
80e38ba - get with tensor slice (kaiyuan-li)
f530c28 - address comments on client.py (kaiyuan-li)
2441fd0 - test dtensor put get (kaiyuan-li)
a9f0361 - fix dtensor get set test (kaiyuan-li)
3b748be - update (kaiyuan-li)
86808d2 - remove unused imports (kaiyuan-li)
75665af - Merge branch 'main' into lky_tensor_slice (LucasLLC)
fdb9f28 - fix merge conflicts (LucasLLC)
0f686be - Merge branch 'main' into lky_tensor_slice (LucasLLC)
@@ -16,10 +16,14 @@
import torchstore as ts

from monarch.actor import Actor, current_rank, endpoint

# DTensor imports for DTensor slice testing
from torch.distributed._tensor import Shard
from torchstore.logging import init_logging
from torchstore.transport.pipe import TensorSlice
from torchstore.utils import spawn_actors

-from .utils import main, transport_plus_strategy_params
+from .utils import DTensorActor, main, transport_plus_strategy_params

init_logging()
logger = getLogger(__name__)
@@ -216,6 +220,102 @@ async def exists(self, key):
        await ts.shutdown()


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_get_tensor_slice(strategy_params, use_rdma):
    """Test tensor slice API functionality"""
    os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

    class TensorSlicePutActor(Actor):
        """Actor for putting tensors."""

        def __init__(self, world_size):
            init_logging()
            self.world_size = world_size
            self.rank = current_rank().rank
            # required by LocalRankStrategy
            os.environ["LOCAL_RANK"] = str(self.rank)

        @endpoint
        async def put(self, key, tensor):
            await ts.put(key, tensor)

    volume_world_size, strategy = strategy_params
    await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy)

    # Spawn test actors - separate meshes for put and get to test cross-process communication
    put_actor_mesh = await spawn_actors(
        volume_world_size,
        TensorSlicePutActor,
        "tensor_slice_put_actors",
        world_size=volume_world_size,
    )

    try:
        test_tensor = torch.randn(1000, 2000)
        key = "test_tensor"

        # Store the tensor using put actor mesh
        put_actor = put_actor_mesh.slice(**{"hosts": 0, "gpus": 0})
        await put_actor.put.call(key, test_tensor)

        # Test full tensor retrieval using get actor mesh
        retrieved_tensor = await ts.get(key)
        assert torch.equal(test_tensor, retrieved_tensor)

        # Test slice retrieval using get actor mesh
        tensor_slice_spec = TensorSlice(
            offsets=(100, 200),
            coordinates=(),
            global_shape=(1000, 2000),
            local_shape=(50, 100),
            mesh_shape=(),
        )

        tensor_slice = await ts.get(key, tensor_slice_spec=tensor_slice_spec)
        expected_slice = test_tensor[100:150, 200:300]
        assert torch.equal(tensor_slice, expected_slice)
        assert tensor_slice.shape == (50, 100)

    finally:
        await put_actor_mesh._proc_mesh.stop()
        await ts.shutdown()


@pytest.mark.asyncio
async def test_tensor_slice_inplace():
    """Test tensor slice API with in-place operations"""
    await ts.initialize(num_storage_volumes=1)

    try:
        # Store a test tensor
        test_tensor = torch.randn(100, 200)
        await ts.put("inplace_test", test_tensor)

[Review comment] This is a cool thing to add to readme / docs

        # Test in-place retrieval with slice
        slice_spec = TensorSlice(
            offsets=(10, 20),
            coordinates=(),
            global_shape=(100, 200),
            local_shape=(30, 40),
            mesh_shape=(),
        )

        # Create pre-allocated buffer
        slice_buffer = torch.empty(30, 40)
        result = await ts.get(
            "inplace_test", inplace_tensor=slice_buffer, tensor_slice_spec=slice_spec
        )

        # Verify in-place operation
        assert result is slice_buffer
        expected_slice = test_tensor[10:40, 20:60]
        assert torch.equal(slice_buffer, expected_slice)

    finally:
        await ts.shutdown()


@pytest.mark.asyncio
async def test_large_tensors():
    """Test basic put/get functionality for large tensors"""
@@ -291,5 +391,39 @@ async def get(self):
    # TODO: assert equal tensors from put/get


@pytest.mark.asyncio
async def test_put_dtensor_get_full_tensor():
    """Test basic DTensor put/get functionality with separate put and get meshes using shared DTensorActor"""
    import tempfile

    await ts.initialize(num_storage_volumes=2, strategy=ts.LocalRankStrategy())

    original_tensor = torch.arange(16).reshape(4, 4).float()

    with tempfile.TemporaryDirectory() as filesystem_store_dir:
        try:
            put_mesh = await spawn_actors(
                2,
                DTensorActor,
                "dtensor_put_mesh",
                mesh_shape=(2,),
                original_tensor=original_tensor,
                placements=[Shard(0)],
                file_store_name=os.path.join(filesystem_store_dir, "put_test"),
                visible_devices="0,1",
            )

            await put_mesh.do_put.call()

            fetched_tensor = await ts.get("test_key")

[Review comment] yay!

            assert torch.equal(original_tensor, fetched_tensor)

        finally:
            # Clean up process groups
            await put_mesh.destroy_process_group.call()
            await put_mesh._proc_mesh.stop()
            await ts.shutdown()


if __name__ == "__main__":
    main(__file__)
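As a reading aid for the tests above: for a non-sharded tensor, the expected result of ts.get(key, tensor_slice_spec=spec) is simply the offsets/local_shape window of the stored tensor. The hypothetical helper below (not part of torchstore, and assuming TensorSlice exposes its offsets and local_shape fields as attributes) makes the expected values explicit.

import torch
from torchstore.transport.pipe import TensorSlice


def reference_slice(full_tensor: torch.Tensor, spec: TensorSlice) -> torch.Tensor:
    """Reference-only helper: compute the window the tests expect
    ts.get(key, tensor_slice_spec=spec) to return for a non-sharded tensor."""
    index = tuple(
        slice(offset, offset + size)
        for offset, size in zip(spec.offsets, spec.local_shape)
    )
    return full_tensor[index]


# For example, offsets=(100, 200) with local_shape=(50, 100) selects
# full_tensor[100:150, 200:300], matching the assertion in test_get_tensor_slice.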
[Review comment] Can we add a test where we call put on a DTensor, and then call get with no tensor slice and no DTensor? The result should be the entire tensor.
[Reply] Test added. Also factored DTensorActor out from test_sharding.py into utils.py for reuse in test_store.py.