Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unit_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ jobs:
run: python -m pip install -e ".[dev]"
- name: Run unit tests with coverage
# TorchStore RDMA will not run on CPU-only machines, resharding tests are too slow.
run: TORCHSTORE_RDMA_ENABLED=0 pytest tests/ --ignore=tests/test_resharding.py -k "not test_large_tensors and not test_put_dtensor_get_full_tensor" --cov=. --cov-report=xml --durations=20 -vv
run: TORCHSTORE_RDMA_ENABLED=0 pytest tests/ --ignore=tests/test_resharding.py -k "not test_large_tensors and not test_put_dtensor_get_full_tensor and not test_delete" --cov=. --cov-report=xml --durations=20 -vv
- name: Upload Coverage to Codecov
uses: codecov/codecov-action@v3
81 changes: 81 additions & 0 deletions tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,87 @@ async def exists(self, key):
await ts.shutdown()


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_delete(strategy_params, use_rdma):
    """Test the delete() API functionality.

    For each transport/strategy combination:
      1. Stores one tensor per rank, checks exists() for all keys, then
         deletes the keys one at a time, asserting after each deletion that
         the deleted key is gone while not-yet-deleted keys remain.
      2. Asserts that get() on an already-deleted key raises.
    """
    # NOTE(review): mutates process-wide env without restoring it; the other
    # tests in this file appear to follow the same pattern - confirm isolation.
    os.environ["TORCHSTORE_RDMA_ENABLED"] = "1" if use_rdma else "0"

    class DeleteTestActor(Actor):
        """Actor for testing delete functionality.

        Thin pass-throughs to the module-level torchstore API so each
        operation executes inside the actor's own process.
        """

        def __init__(self, world_size):
            init_logging()
            self.world_size = world_size
            self.rank = current_rank().rank
            # required by LocalRankStrategy
            os.environ["LOCAL_RANK"] = str(self.rank)

        @endpoint
        async def put(self, key, value):
            await ts.put(key, value)

        @endpoint
        async def delete(self, key):
            await ts.delete(key)

        @endpoint
        async def exists(self, key):
            return await ts.exists(key)

        @endpoint
        async def get(self, key):
            return await ts.get(key)

    volume_world_size, strategy = strategy_params
    await ts.initialize(num_storage_volumes=volume_world_size, strategy=strategy)

    # Spawn test actors
    actor_mesh = await spawn_actors(
        volume_world_size,
        DeleteTestActor,
        "delete_test_actors",
        world_size=volume_world_size,
    )

    try:
        # Test 1: Store tensors, verify they exist, then delete them.
        # Each rank writes its own key so deletions can be verified
        # independently.
        tensor = torch.tensor([1, 2, 3, 4, 5])
        for rank in range(volume_world_size):
            actor = actor_mesh.slice(gpus=rank)
            await actor.put.call(f"tensor_key_{rank}", tensor)

        # Verify all tensors exist
        for rank in range(volume_world_size):
            results = await actor_mesh.exists.call(f"tensor_key_{rank}")
            for _, exists_result in results:
                assert exists_result

        # Delete tensors one at a time and verify each deletion
        for rank in range(volume_world_size):
            actor = actor_mesh.slice(gpus=rank)
            await actor.delete.call(f"tensor_key_{rank}")

            # Verify this specific tensor no longer exists
            results = await actor_mesh.exists.call(f"tensor_key_{rank}")
            for _, exists_result in results:
                assert not exists_result

            # Verify other tensors still exist (if any remain)
            for other_rank in range(rank + 1, volume_world_size):
                results = await actor_mesh.exists.call(f"tensor_key_{other_rank}")
                for _, exists_result in results:
                    assert exists_result

        # Test 2: Try to get deleted tensor (should raise exception)
        # NOTE(review): Exception is very broad; if the store surfaces a
        # specific error type for missing keys, pin it here - confirm.
        with pytest.raises(Exception):
            await actor_mesh.get.call("tensor_key_0")

    finally:
        await actor_mesh._proc_mesh.stop()
        await ts.shutdown()


@pytest.mark.parametrize(*transport_plus_strategy_params())
@pytest.mark.asyncio
async def test_get_tensor_slice(strategy_params, use_rdma):
Expand Down
2 changes: 2 additions & 0 deletions torchstore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from torchstore.api import (
client,
delete,
exists,
get,
get_state_dict,
Expand Down Expand Up @@ -43,6 +44,7 @@
"init_logging",
"put",
"get",
"delete",
"keys",
"exists",
"client",
Expand Down
20 changes: 20 additions & 0 deletions torchstore/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,26 @@ async def get(
return await cl.get(key, inplace_tensor, tensor_slice_spec)


async def delete(
    key: str,
    *,
    store_name: str = DEFAULT_TORCHSTORE_NAME,
) -> None:
    """Remove the value stored under ``key`` from the distributed store.

    Args:
        key (str): Unique identifier of the value to delete.

    Keyword Args:
        store_name (str): Name of the store to use. Defaults to
            DEFAULT_TORCHSTORE_NAME.

    Example:
        >>> await delete("my_tensor")
    """
    # Resolve the client for the requested store, then delegate.
    store_client = await client(store_name=store_name)
    await store_client.delete(key)


async def keys(
prefix: str | None = None,
) -> List[str]:
Expand Down
30 changes: 30 additions & 0 deletions torchstore/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
from logging import getLogger
from typing import Any, Union

Expand Down Expand Up @@ -118,6 +119,35 @@ async def keys(self, prefix: str | None = None) -> list[str]:
# Keys are synced across all storage volumes, so we just call one.
return await self._controller.keys.call_one(prefix)

async def delete(self, key: str) -> None:
    """
    Delete a key from the distributed store.

    Fans the delete out concurrently to every storage volume the
    controller's index lists for ``key``, de-indexing each volume
    before the volume actually drops the data.

    Args:
        key (str): The key to delete.

    Returns:
        None

    Raises:
        KeyError: If the key does not exist in the store.
    """
    latency_tracker = LatencyTracker(f"delete:{key}")
    # Ask the controller which storage volumes currently hold this key.
    volume_map = await self._controller.locate_volumes.call_one(key)

    async def delete_from_volume(volume_id: str) -> None:
        volume = self.strategy.get_storage_volume(volume_id)
        # Notify should come before the actual delete, so that the controller
        # doesn't think the key is still in the store when delete is happening.
        await self._controller.notify_delete.call_one(key, volume_id)
        await volume.delete.call(key)

    # NOTE(review): a failure in one volume propagates out of gather and the
    # other deletions are not rolled back - confirm partial-delete semantics.
    await asyncio.gather(
        *[delete_from_volume(volume_id) for volume_id in volume_map]
    )

    latency_tracker.track_e2e()

async def exists(self, key: str) -> bool:
"""Check if a key exists in the distributed store.

Expand Down
19 changes: 19 additions & 0 deletions torchstore/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,24 @@ def keys(self, prefix=None) -> List[str]:
return list(self.keys_to_storage_volumes.keys())
return self.keys_to_storage_volumes.keys().filter_by_prefix(prefix)

@endpoint
def notify_delete(self, key: str, storage_volume_id: str) -> None:
    """
    Notify the controller that deletion of data is initiated in a storage volume.

    The client calls this before the storage volume actually deletes the
    data (see Client.delete), so the index stops reporting ``key`` for
    ``storage_volume_id`` while the delete is in flight.

    Args:
        key (str): The key being deleted.
        storage_volume_id (str): The volume the key is being deleted from.

    Raises:
        KeyError: If ``key`` is unknown to the controller, or is not
            recorded as residing in ``storage_volume_id``.
    """
    self.assert_initialized()
    if key not in self.keys_to_storage_volumes:
        raise KeyError(f"Unable to locate {key} in any storage volumes.")
    if storage_volume_id not in self.keys_to_storage_volumes[key]:
        raise KeyError(
            f"Unable to locate {key} in storage volume {storage_volume_id}."
        )
    # Drop the (key, volume) entry; once no volume holds the key, remove
    # the key from the index entirely.
    del self.keys_to_storage_volumes[key][storage_volume_id]
    if len(self.keys_to_storage_volumes[key]) == 0:
        del self.keys_to_storage_volumes[key]

def get_keys_to_storage_volumes(self) -> Mapping[str, Dict[str, StorageInfo]]:
    """Return the controller's key -> {volume_id: StorageInfo} index.

    NOTE(review): this returns the live internal mapping, not a copy;
    mutating it would corrupt the index - confirm intended.
    """
    return self.keys_to_storage_volumes
13 changes: 13 additions & 0 deletions torchstore/storage_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ async def get(
) -> TransportBuffer:
return await self.store.get(key, transport_buffer, request)

@endpoint
async def delete(self, key: str) -> None:
    # Pass-through: delegate deletion to this volume's storage backend.
    await self.store.delete(key)

@endpoint
async def get_meta(self, key: str) -> Union[Tuple[torch.Size, torch.dtype], str]:
    # Pass-through: per InMemoryStore.get_meta, tensor entries yield
    # (shape, dtype); the meaning of the str alternative isn't visible
    # here - NOTE(review): confirm against the backend implementations.
    return await self.store.get_meta(key)
Expand All @@ -88,6 +92,10 @@ async def get_meta(self, key: str) -> Union[Tuple[torch.Size, torch.dtype], str]
"""Get metadata about stored data."""
raise NotImplementedError()

async def delete(self, key: str) -> None:
    """Delete data from the storage backend.

    Raises:
        NotImplementedError: Always; concrete backends must override.
    """
    raise NotImplementedError()


class InMemoryStore(StorageImpl):
"""Local in memory storage."""
Expand Down Expand Up @@ -234,3 +242,8 @@ async def get_meta(self, key: str) -> Union[Tuple[torch.Size, torch.dtype], str]
return val["tensor"].shape, val["tensor"].dtype

raise RuntimeError(f"Unknown type for {key} type={type(val)}")

async def delete(self, key: str) -> None:
if key not in self.kv:
raise KeyError(f"Key '{key}' not found. {list(self.kv.keys())=}")
del self.kv[key]