add optional tensor_slice_spec arg to get() API #32
Conversation
torchstore/client.py
Outdated
    # but for now any should work.
    fetched_tensor = await pipe.get_from_storage_volume(key, request)

    # If user requested a specific slice, extract it
Curious if it's possible to throw an exception in the storage layer? Does the request already have this information?
I thought for now we assume the request doesn't know such info and blindly fetches the whole tensor.
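For context, the client-side extraction this comment describes (fetch the whole tensor, then cut out the requested slice) can be sketched as below. This is illustrative only; the `TensorSlice` stand-in with `offsets`/`local_shape` fields is an assumption about the spec's layout, not torchstore's actual class:

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class TensorSlice:
    # Assumed fields: where the slice starts and how big it is, per dimension.
    offsets: Tuple[int, ...]
    local_shape: Tuple[int, ...]

def slice_indices(spec: TensorSlice) -> Tuple[slice, ...]:
    # Build the per-dimension index expression used to extract the slice
    # from a fully fetched tensor, e.g. fetched_tensor[slice_indices(spec)].
    return tuple(slice(o, o + n) for o, n in zip(spec.offsets, spec.local_shape))

# 1-D demonstration without torch: take 4 elements starting at offset 3.
spec = TensorSlice(offsets=(3,), local_shape=(4,))
data = list(range(10))
print(data[slice_indices(spec)[0]])  # [3, 4, 5, 6]
```

The same index tuple applies unchanged to a torch tensor, since `tensor[tuple_of_slices]` follows standard indexing semantics.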
torchstore/client.py
Outdated
    # Handle in-place operation for tensor slice
    if inplace_tensor is not None:
        inplace_tensor.copy_(sliced_tensor)
nit: I forgot to add this in a previous PR, but it's slightly cleaner to return inplace_tensor.copy_(sliced_tensor)
Good to learn (and I checked the PyTorch docs: copy_() returns self). Done.
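The pattern being referenced here, in-place methods mutating and returning self so they can be returned directly, can be shown with a minimal stand-in class (not torchstore code; `Buf` is purely illustrative):

```python
class Buf:
    """Minimal stand-in mimicking torch's convention that in-place methods
    (trailing underscore) mutate the receiver and return self."""

    def __init__(self, data):
        self.data = list(data)

    def copy_(self, other: "Buf") -> "Buf":
        self.data[:] = other.data
        return self  # returning self enables `return dst.copy_(src)`

dst, src = Buf([0, 0]), Buf([1, 2])
result = dst.copy_(src)
print(result is dst, dst.data)  # True [1, 2]
```

Because the in-place call returns the mutated buffer itself, the two-line "copy then return" pattern collapses into a single return statement.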
torchstore/client.py
Outdated
    self,
    key: str,
    inplace_tensor: Optional[torch.Tensor] = None,
    tensor_slice_spec: Optional[TensorSlice] = None,
As written, we should assert that either inplace_tensor or tensor_slice_spec is None; both being set should raise.
Ah nvm, in this case we should assert inplace_tensor.shape == tensor_slice_spec.shape
Validation added to raise ValueError if both are specified and there's a mismatch between the shapes.
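A minimal sketch of that validation, assuming the shapes are available as tuples (function name and signature are illustrative, not torchstore's actual code):

```python
def validate_get_args(inplace_shape, slice_spec_shape):
    """If both inplace_tensor and tensor_slice_spec are given, their shapes
    must agree; otherwise get() cannot fill the buffer in place."""
    if inplace_shape is not None and slice_spec_shape is not None:
        if tuple(inplace_shape) != tuple(slice_spec_shape):
            raise ValueError(
                f"inplace_tensor shape {tuple(inplace_shape)} does not match "
                f"tensor_slice_spec.local_shape {tuple(slice_spec_shape)}"
            )

validate_get_args((4, 4), (4, 4))  # shapes agree: no error
validate_get_args((4, 4), None)    # only one provided: no error
try:
    validate_get_args((4, 4), (2, 4))
except ValueError as e:
    print("raised:", e)
```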
torchstore/client.py
Outdated
    logger.debug(f"Fetching {key}")
    request = Request.from_any(inplace_tensor)

    # When slicing, don't use inplace_tensor for the request because the transport
I'm not sure I understand this comment -- does it simplify here if we assume one is None?
I guess the change here is because this Request always fetches the whole tensor. Before, there was no slice option, so inplace_tensor was either None or the whole tensor. Now there is a slice option, but at least for now we always want the request to cover the whole tensor.
Another question I have: do we want the pipe to know whether we are just fetching a slice? At the moment I assume the pipe always fetches the whole tensor, which is why the slice buffer is not put into the Request. This is simple and clean but NOT efficient. Maybe we can improve efficiency later by making Request slice-aware as well. What do you think?
I think actually request already has the ability to fetch a single tensor slice... this is how the tensor_slice_request works below.
@pytest.mark.asyncio
async def test_tensor_slice_inplace():
Can we add a test where we call put on a dtensor, and then call get with no tensor slice and no dtensor? The result should be the entire tensor.
Test added. Also factored DTensorActor out of test_sharding.py into utils.py for reuse in test_store.py.
@LucasLLC thanks for the review. Those are good comments which helped me get a better understanding of the architecture and workflow. Please take another look :)
@pytest.mark.asyncio
async def test_tensor_slice_inplace():
From my understanding, the main difference between a dtensor and a tensor slice is that a tensor slice does not have a coordinate. iiuc this was a requirement at some point during put, but it should not be required during get.
I'm curious whether, if we create the request from the tensor slice, the implementation would be simplified, since we could use the same code paths as we do for dtensor?
tests/test_store.py
Outdated
if slice_spec is None:
    return await ts.get(key)
else:
    return await ts.get(key, tensor_slice_spec=slice_spec)
In the future we'll also want to raise if the tensor slice does not exist or is invalid.
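A hedged sketch of what that future check could look like, assuming the slice is described by per-dimension offsets and lengths checked against the stored tensor's global shape (all names here are illustrative):

```python
def validate_slice(offsets, local_shape, global_shape):
    """Raise if the requested slice falls outside the stored tensor."""
    if len(offsets) != len(global_shape) or len(local_shape) != len(global_shape):
        raise ValueError("slice rank does not match stored tensor rank")
    for o, n, g in zip(offsets, local_shape, global_shape):
        # Offsets must be non-negative, lengths positive, and the end in bounds.
        if o < 0 or n <= 0 or o + n > g:
            raise ValueError(f"slice [{o}:{o + n}] out of bounds for dim of size {g}")

validate_slice((2, 0), (2, 4), (4, 4))  # in bounds: no error
try:
    validate_slice((3, 0), (2, 4), (4, 4))  # 3 + 2 > 4 on dim 0
except ValueError as e:
    print("raised:", e)
```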
tests/test_store.py
Outdated
    volume_world_size,
    TensorSliceGetActor,
    "tensor_slice_get_actors",
    world_size=volume_world_size,
Why spawn volume_world_size actors if we're slicing to single node below?
Makes sense. This was mostly copied from an existing test example, lol. It doesn't look necessary. Updated to a direct ts.get().
tests/test_store.py
Outdated
# Initialize TorchStore with 2 storage volumes and LocalRankStrategy
from torchstore.strategy import LocalRankStrategy

await ts.initialize(num_storage_volumes=2, strategy=LocalRankStrategy())
nit: ts.LocalRankStrategy instead of import
Updated.
tests/test_store.py
Outdated
    DTensorActor,
    "dtensor_get_mesh",
    mesh_shape=(2,),
    original_tensor=torch.zeros(4, 4).float(),
nit: -> torch.zeros_like(original_tensor)
Updated.
tests/test_store.py
Outdated
@pytest.mark.asyncio
async def test_dtensor_simple_put_get():
I'm confused by this test. Why is this needed outside of test_resharding? As a side note, it also does not test correctness (I would also suggest changing the placement dim).
I misunderstood your comment in the last iteration. Updated this test to put a dtensor and just fetch it with get('key'). Also verified that the result matches the original tensor.
torchstore/client.py
Outdated
    request = Request.from_any(inplace_tensor)

    if tensor_slice_spec is not None and inplace_tensor is not None:
        if tensor_slice_spec.local_shape != inplace_tensor.shape:
If inplace_tensor is a dtensor, we should assert on the offset as well?
Added validation that tensor_slice should be None if inplace_tensor is a DTensor.
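That rule (a DTensor already encodes its own placement and offsets, so an explicit slice spec alongside it is ambiguous) can be sketched like this, with a stand-in `DTensor` class replacing torch's real one for illustration:

```python
class DTensor:
    """Stand-in for torch's distributed tensor type, used only for the isinstance check."""
    pass

def check_dtensor_slice(inplace_tensor, tensor_slice_spec):
    """Reject the ambiguous combination: a DTensor carries its own shard
    metadata, so an explicit tensor_slice_spec must not also be passed."""
    if isinstance(inplace_tensor, DTensor) and tensor_slice_spec is not None:
        raise ValueError("tensor_slice_spec must be None when inplace_tensor is a DTensor")

check_dtensor_slice(object(), {"offsets": (0,)})  # plain tensor + spec: allowed
check_dtensor_slice(DTensor(), None)              # DTensor alone: allowed
try:
    check_dtensor_slice(DTensor(), {"offsets": (0,)})
except ValueError as e:
    print("raised:", e)
```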
torchstore/client.py
Outdated
    # multinode support here
    volume_map = await self._controller.locate_volumes.call_one(key)

    if object_type in (ObjectType.OBJECT, ObjectType.TENSOR):
I think this if statement should only apply to ObjectType.OBJECT. Objects are the only items that are allowed to live on a single storage volume (so far).
In the past, the same was true of requesting tensors, since we assumed only a DTensor would request a tensor from sharded storage. Since we are now allowing users to request an arbitrary tensor slice without a dtensor, this is no longer the case, meaning we should change this code to only account for ObjectType.OBJECT, and likely use the codepath below.
Done with a major update. Please take a look.
Please take another look :)
async def get(
    key: str,
    inplace_tensor: Optional[torch.Tensor] = None,
    tensor_slice_spec: Optional[TensorSlice] = None,
Please add a docstring explaining the combinations which are allowed, e.g. "tensor_slice_spec + dtensor is not allowed".
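A hedged sketch of such a docstring, compiled from the rules discussed in this thread (exact wording and the simplified signature are illustrative, not the merged code):

```python
async def get(
    key: str,
    inplace_tensor=None,
    tensor_slice_spec=None,
):
    """Fetch the value stored under `key`.

    Allowed argument combinations:
      - neither inplace_tensor nor tensor_slice_spec: return the whole tensor.
      - inplace_tensor only: the result is copied into the buffer, which is returned.
      - tensor_slice_spec only: return just that slice of the stored tensor.
      - both: inplace_tensor.shape must equal tensor_slice_spec.local_shape,
        otherwise ValueError is raised.
      - inplace_tensor being a DTensor plus tensor_slice_spec: not allowed,
        since the DTensor already encodes its own placement.
    """
    ...  # implementation elided
```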
full_tensor = await self._get_distributed_whole_tensor(key)

if isinstance(inplace_tensor, DTensor):
    request = Request.from_any(inplace_tensor)
nit: from DTensor
try:
    # Store a test tensor
    test_tensor = torch.randn(100, 200)
    await ts.put("inplace_test", test_tensor)
This is a cool thing to add to readme / docs
await put_mesh.do_put.call()

fetched_tensor = await ts.get("test_key")
yay!
object_type = ObjectType.from_request(request)

# multinode support here
stored_object_type = await self._get_stored_object_type(key)
One additional control-path query, which is probably fine, but we do something similar in Pipe (get_meta).
Overall really happy with the functionality of this PR! Tysm for going through the motions.
Since this is not pressing, I would like to spend a little more time thinking through how we can best encapsulate the responsibilities of each object, for example potentially moving some of this logic out of the client into the pipe.
If this ends up being blocking we can revisit, but since we have time I'd rather see if we can make the right decisions upfront! Let's schedule some time this week to go over it.
approving since we need this fix
Summary: add an optional tensor_slice_spec to the get() API so we can get a slice instead of the full tensor. Tested get() with two actors and in-place get() (test_sharding.py).

Next:
a. avoid constructing the whole global tensor
b. avoid unnecessary volumes without the tensor slice of interest
c. batch tensor slice queries into volumes