
Conversation

kaiyuan-li (Contributor):

  1. Added an optional tensor_slice_spec to the get() API so we can get a slice instead of the full tensor (usage sketch below).
  2. Tests added for cross-process get() with two actors, and for in-place get().
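A minimal usage sketch of the new parameter (to be run inside an async context; the TensorSlice import path and constructor arguments here are assumptions for illustration, not the confirmed API):

```python
import torch
import torchstore as ts
from torchstore import TensorSlice  # import path is an assumption

await ts.put("weights", torch.randn(8, 8))

# Fetch the full tensor, as before.
full = await ts.get("weights")

# Fetch only a 4x4 corner. The offsets/lengths arguments are
# illustrative; the real TensorSlice fields may differ.
spec = TensorSlice(offsets=(0, 0), lengths=(4, 4))
part = await ts.get("weights", tensor_slice_spec=spec)
```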

Next:

  1. fix dtensor-related test (test_sharding.py)
  2. add test for dtensor get (with slice)
  3. optimize dtensor slice get:
    a. avoid constructing the whole global tensor
    b. avoid querying volumes that don't contain the tensor slice of interest
    c. batch tensor slice queries per volume

@kaiyuan-li kaiyuan-li requested a review from LucasLLC September 10, 2025 15:11
@meta-cla meta-cla bot added the CLA Signed label Sep 10, 2025
# but for now any should work.
fetched_tensor = await pipe.get_from_storage_volume(key, request)

# If user requested a specific slice, extract it
Contributor:

curious if it's possible to throw an exception in the storage layer? Does the request already have this information?

Contributor Author:

I thought for now we'd assume the request doesn't know such info and blindly fetches the whole tensor; the slice is then extracted on the client side.
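A rough sketch of what that client-side extraction could look like (the extract_slice helper and the offsets/lengths fields are hypothetical, for illustration only):

```python
import torch

# Hypothetical helper: narrow the fetched full tensor down to the
# requested slice, one dimension at a time.
def extract_slice(full_tensor: torch.Tensor, offsets, lengths) -> torch.Tensor:
    out = full_tensor
    for dim, (start, length) in enumerate(zip(offsets, lengths)):
        out = out.narrow(dim, start, length)
    return out
```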


# Handle in-place operation for tensor slice
if inplace_tensor is not None:
    inplace_tensor.copy_(sliced_tensor)
Contributor:

nit: I forgot to add this in a previous PR, but it's slightly cleaner to return inplace_tensor.copy_(sliced_tensor)

Contributor Author:

Good to learn (I checked the PyTorch docs: copy_() returns self). Done.
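A quick demonstration of that behavior:

```python
import torch

dst = torch.zeros(2, 3)
src = torch.ones(2, 3)

# Tensor.copy_() mutates dst in place and returns dst itself,
# so the call can sit directly in a return statement.
result = dst.copy_(src)
assert result is dst
```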

self,
key: str,
inplace_tensor: Optional[torch.Tensor] = None,
tensor_slice_spec: Optional[TensorSlice] = None,
Contributor:

As written, we should assert that either inplace_tensor or tensor_slice_spec is None; both being set should raise.

Contributor:

Ah nvm, in this case we should assert inplace_tensor.shape == tensor_slice_spec.shape

Contributor Author:

Validation added to raise ValueError if both are specified and there's a mismatch between the shapes.
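Roughly, per the diff context quoted later in this thread (the error message wording here is illustrative):

```python
if tensor_slice_spec is not None and inplace_tensor is not None:
    if tensor_slice_spec.local_shape != inplace_tensor.shape:
        raise ValueError(
            f"inplace_tensor shape {tuple(inplace_tensor.shape)} does not match "
            f"tensor_slice_spec.local_shape {tensor_slice_spec.local_shape}"
        )
```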

logger.debug(f"Fetching {key}")
request = Request.from_any(inplace_tensor)

# When slicing, don't use inplace_tensor for the request because the transport
Contributor:

I'm not sure I understand this comment -- does it simplify here if we assume one is None?

Contributor Author:

I guess the change here is because this Request always fetches the whole tensor. Before, there was no slice option, so inplace_tensor was either None or the whole tensor. Now there is a slice option, but at least for now we always want the request to cover the whole tensor.

Another question I have: do we want the pipe to know that we are just fetching a slice? At the moment I'm assuming the pipe always fetches the whole tensor, which is why the slice buffer is not put into the Request. This is simple and clean but NOT efficient. Maybe we can improve efficiency later by making Request slice-aware. What do you think?

Contributor:

I think the Request actually already has the ability to fetch a single tensor slice... this is how the tensor_slice_request works below.



@pytest.mark.asyncio
async def test_tensor_slice_inplace():
Contributor:

can we add a test such that we call put on a dtensor, and then call get with no tensor slice and no dtensor? The result should be the entire tensor

Contributor Author:

Test added. Also factored DTensorActor out of test_sharding.py into utils.py for reuse in test_store.py.

@kaiyuan-li kaiyuan-li left a comment (Contributor Author)

@LucasLLC thanks for the review. Those are good comments that helped me get a better understanding of the architecture and workflow. Please take another look :)


@LucasLLC LucasLLC left a comment (Contributor)

From my understanding, the main difference between a dtensor and a tensor slice is that a tensor slice does not have a coordinate. IIUC this was a requirement at some point during put, but it should not be required during gets.

I'm curious whether, if we created the request from the tensor slice, the implementation would be simplified, since we could use the same code paths as we do for dtensor.

if slice_spec is None:
    return await ts.get(key)
else:
    return await ts.get(key, tensor_slice_spec=slice_spec)
Contributor:

In the future we'll want to raise here as well if the tensor slice does not exist or is invalid.
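A hypothetical sketch of such a validity check; the offsets/lengths/global_shape names are assumptions, not the actual TensorSlice API:

```python
def validate_slice(offsets, lengths, global_shape) -> None:
    # Reject slices that fall outside the stored tensor's bounds.
    for dim, (start, length, size) in enumerate(zip(offsets, lengths, global_shape)):
        if start < 0 or length <= 0 or start + length > size:
            raise ValueError(
                f"invalid slice on dim {dim}: start={start}, length={length}, dim size={size}"
            )
```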

volume_world_size,
TensorSliceGetActor,
"tensor_slice_get_actors",
world_size=volume_world_size,
Contributor:

Why spawn volume_world_size actors if we're slicing to single node below?

Contributor Author:

Makes sense. This was mostly copied from an existing test example, lol. It doesn't look necessary. Updated to a direct ts.get().

# Initialize TorchStore with 2 storage volumes and LocalRankStrategy
from torchstore.strategy import LocalRankStrategy

await ts.initialize(num_storage_volumes=2, strategy=LocalRankStrategy())
Contributor:

nit: ts.LocalRankStrategy instead of import

Contributor Author:

Updated.

DTensorActor,
"dtensor_get_mesh",
mesh_shape=(2,),
original_tensor=torch.zeros(4, 4).float(),
Contributor:

nit: -> torch.zeros_like(original_tensor)

Contributor Author:

Updated.



@pytest.mark.asyncio
async def test_dtensor_simple_put_get():
Contributor:

I'm confused by this test. Why is this needed outside of test_resharding? As a side note it also does not test correctness (would also suggest changing the placement dim).

Contributor Author:

I misunderstood your comment in the last iteration. Updated this test to put a dtensor and just fetch it with get('key'). Also verified that the result matches the original tensor.
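A hedged sketch of that round-trip check; put_mesh, do_put, and "test_key" follow the test snippets quoted in this conversation, and original_tensor stands for the tensor the actors shard and put:

```python
import pytest
import torch
import torchstore as ts

@pytest.mark.asyncio
async def test_dtensor_simple_put_get():
    await put_mesh.do_put.call()               # actors put their DTensor shards
    fetched_tensor = await ts.get("test_key")  # plain get: no slice, no dtensor
    assert torch.equal(fetched_tensor, original_tensor)
```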

request = Request.from_any(inplace_tensor)

if tensor_slice_spec is not None and inplace_tensor is not None:
    if tensor_slice_spec.local_shape != inplace_tensor.shape:
Contributor:

if inplace_tensor is a dtensor we should assert on offset as well?

Contributor Author:

Added validation that tensor_slice_spec must be None if inplace_tensor is a DTensor.
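A paraphrased sketch of that guard (the DTensor import path may vary by PyTorch version, and the message wording is illustrative):

```python
from torch.distributed.tensor import DTensor

if isinstance(inplace_tensor, DTensor) and tensor_slice_spec is not None:
    raise ValueError(
        "tensor_slice_spec must be None when inplace_tensor is a DTensor"
    )
```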


# multinode support here
volume_map = await self._controller.locate_volumes.call_one(key)

if object_type in (ObjectType.OBJECT, ObjectType.TENSOR):
Contributor:

I think this if statement should only apply to ObjectType.OBJECT. Objects are the only items that are allowed to live on a single storage volume (so far).

In the past, the same was true about requesting tensors, since we assumed only a "DTensor" would request a tensor from sharded storage. Since we are allowing users to request an arbitrary 'tensor slice' without a dtensor, this is no longer the case, meaning we should change this code to only account for ObjectType.OBJECT, and likely use the codepath below.

Contributor Author:

Done with a major update. Please take a look.

@kaiyuan-li (Contributor Author):

Please take another look :)

async def get(
    key: str,
    inplace_tensor: Optional[torch.Tensor] = None,
    tensor_slice_spec: Optional[TensorSlice] = None,
Contributor:

please add docstring explaining the combinations which are allowed. e.g. "tensor_slice_spec + dtensor is not allowed"
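A hedged sketch of what that docstring could cover, collecting the combinations established elsewhere in this review:

```python
from typing import Optional

import torch

async def get(
    key: str,
    inplace_tensor: Optional[torch.Tensor] = None,
    tensor_slice_spec: Optional["TensorSlice"] = None,
):
    """Fetch a tensor (or a slice of it) from the store.

    Allowed combinations:
      - neither optional argument: return the full stored tensor.
      - inplace_tensor only: copy the full tensor into inplace_tensor and return it.
      - tensor_slice_spec only: return just the requested slice.
      - both: copy the slice into inplace_tensor; requires
        tensor_slice_spec.local_shape == inplace_tensor.shape, else ValueError.
      - inplace_tensor being a DTensor together with tensor_slice_spec:
        not allowed, raises ValueError.
    """
    ...
```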

full_tensor = await self._get_distributed_whole_tensor(key)

if isinstance(inplace_tensor, DTensor):
    request = Request.from_any(inplace_tensor)
Contributor:

nit: from DTensor

try:
    # Store a test tensor
    test_tensor = torch.randn(100, 200)
    await ts.put("inplace_test", test_tensor)
Contributor:

This is a cool thing to add to readme / docs
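For the docs, a README-style sketch of the in-place pattern this test exercises (assuming the ts.put / ts.get API shown in this PR):

```python
import torch
import torchstore as ts

# Store a tensor under a key.
await ts.put("inplace_test", torch.randn(100, 200))

# Fetch it back into a preallocated buffer instead of allocating a new tensor.
buffer = torch.empty(100, 200)
await ts.get("inplace_test", inplace_tensor=buffer)
```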


await put_mesh.do_put.call()

fetched_tensor = await ts.get("test_key")
Contributor:

yay!

object_type = ObjectType.from_request(request)

# multinode support here
stored_object_type = await self._get_stored_object_type(key)
Contributor:

One additional control-path query, which is probably fine, but we do something similar in Pipe (get_meta).

@LucasLLC LucasLLC left a comment (Contributor)

Overall really happy with the functionality of this PR! Tysm for going through the motions.

Since this is not pressing, I would like to spend a little more time thinking through how we can best encapsulate the responsibilities of each object, for example potentially moving some of this logic out of the client into the pipe.

If this ends up being blocking we can revisit, but since we have time I'd rather see if we can make the right decisions upfront! Let's schedule some time this week to go over it.

@LucasLLC LucasLLC left a comment (Contributor)

approving since we need this fix

@LucasLLC LucasLLC merged commit 0f0e7d4 into main Sep 19, 2025
1 of 5 checks passed
@LucasLLC LucasLLC deleted the lky_tensor_slice branch September 19, 2025 15:35