Empty file added tests/__init__.py
Empty file.
34 changes: 13 additions & 21 deletions tests/test_models.py
@@ -10,27 +10,23 @@
from monarch.actor import Actor, current_rank, endpoint
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

from transformers import AutoModelForCausalLM

from torchstore import MultiProcessStore
from torchstore._state_dict_utils import get_state_dict, push_state_dict
from torchstore.utils import spawn_actors
from torchstore.logging import init_logging
from torchstore.utils import spawn_actors

from transformers import AutoModelForCausalLM

logger = getLogger(__name__)


assert os.environ.get("HF_TOKEN", None) is not None, "HF_TOKEN must be set"
Comment from the PR author: note, you shouldn't need this if you hf auth login.
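For context, a minimal sketch of the two authentication paths this comment refers to (illustrative only, not part of this diff; it assumes standard transformers / Hugging Face Hub credential handling):

import os

from transformers import AutoModelForCausalLM

# Option 1 (what the test now relies on): a cached credential from a prior
# `hf auth login` / `huggingface-cli login`, so no HF_TOKEN env var is needed.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-1.7B")

# Option 2 (what the removed code did): pass the token explicitly from the
# environment, which is why the HF_TOKEN assert existed.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B", token=os.environ["HF_TOKEN"]
)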

TEST_MODEL = "Qwen/Qwen3-1.7B" # ~4GB
# TEST_MODEL = "meta-llama/Llama-3.1-8B" # ~ 16GB


class ModelTest(Actor):
def __init__(self, store, mesh_shape, file_store_name):
def __init__(self, mesh_shape, file_store_name):
init_logging()
self.rank = current_rank().rank
self.store = store
self.mesh_shape = mesh_shape
self.world_size = math.prod(mesh_shape)
self.file_store_name = file_store_name
@@ -50,9 +46,7 @@ def initialize_distributed(self):

def build_model(self):
self.rlog("building model")
model = AutoModelForCausalLM.from_pretrained(
TEST_MODEL, token=os.environ["HF_TOKEN"]
)
model = AutoModelForCausalLM.from_pretrained(TEST_MODEL)
if self.world_size > 1:
self.initialize_distributed()
self.rlog("sharding")
@@ -81,7 +75,7 @@ async def do_push(self):

self.rlog("pushing state dict")
t = time.perf_counter()
await push_state_dict(self.store, state_dict, "v0")
await push_state_dict(state_dict, "v0")
self.rlog(f"pushed state dict in {time.perf_counter()-t} seconds")

@endpoint
@@ -91,12 +85,12 @@ async def do_get(self):
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
}

if self.world_size > 1:
torch.distributed.barrier()
self.rlog("getting state dict")
t = time.perf_counter()
await get_state_dict(self.store, "v0", state_dict)
await get_state_dict("v0", state_dict)
self.rlog(f"got state dict in {time.perf_counter() - t} seconds")


@@ -115,39 +109,37 @@ async def test_resharding(self):

async def _do_test(self, put_mesh_shape, get_mesh_shape):
with tempfile.TemporaryDirectory() as tmpdir:
store = await MultiProcessStore.create_store()

put_world_size = math.prod(put_mesh_shape)
put_world = await spawn_actors(
put_world_size,
ModelTest,
"save_world",
store=store,
mesh_shape=put_mesh_shape,
file_store_name=os.path.join(tmpdir, "save_world"),
)
)

get_world_size = math.prod(get_mesh_shape)
get_world = await spawn_actors(
get_world_size,
ModelTest,
"get_world",
store=store,
mesh_shape=get_mesh_shape,
file_store_name=os.path.join(tmpdir, "get_world"),
)


logger.info("pushing state dict")
t = time.perf_counter()
await put_world.do_push.call()
logger.info(f"pushing state dict took: {time.perf_counter()-t} seconds")

logger.info("fetching state dict")
t = time.perf_counter()
await get_world.do_get.call()
logger.info(f"getting state dict took: {time.perf_counter()-t} seconds")

await put_world._proc_mesh.stop()
await get_world._proc_mesh.stop()


if __name__ == "__main__":
init_logging()
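Taken together, the test_models.py changes above drop the explicit MultiProcessStore handle in favor of torchstore's module-level API. A minimal sketch of the resulting usage pattern, inferred from the calls in this diff (signatures assumed from the test code, not verified against the library):

import torchstore as store
from torchstore._state_dict_utils import get_state_dict, push_state_dict

async def save_world(model, optimizer):
    # Each rank pushes its (possibly sharded) state dict under a shared key;
    # no store object is created or passed around anymore.
    state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
    await push_state_dict(state_dict, "v0")

async def load_world(model, optimizer):
    # A differently sharded world fetches the same key; tensors are resharded
    # into the locally constructed state dict.
    state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
    await get_state_dict("v0", state_dict)

async def roundtrip_tensor(dtensor):
    # Single tensors (including DTensors) go through the module-level put/get,
    # as exercised in tests/test_resharding.py later in this diff.
    await store.put("shared_key", dtensor)
    return await store.get("shared_key", dtensor)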
31 changes: 13 additions & 18 deletions tests/test_resharding.py
@@ -2,18 +2,16 @@
import os
import tempfile
import unittest
import logging
import sys
from logging import getLogger

import torch

import torchstore as store

from monarch.actor import Actor, current_rank, endpoint
from torch.distributed._tensor import distribute_tensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset

from torchstore import MultiProcessStore
from torchstore.utils import get_local_tensor, spawn_actors

logger = getLogger(__name__)
@@ -28,28 +26,25 @@ class DTensorActor(Actor):

def __init__(
self,
store,
mesh_shape,
original_tensor,
placements,
file_store_name,
visible_devices="0,1,2,3,4,5,6,7",
):
self.rank = current_rank().rank
self.store = store
self.mesh_shape = mesh_shape
self.world_size = math.prod(mesh_shape)
self.original_tensor = original_tensor
self.placements = placements
self.file_store_name = file_store_name


# this is only necessary for nccl, but we're not using it in this test.
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices

def rlog(self, msg):
# TODO: set to 'info' once this is fixed in monarch (which currently is hiding logs :/)
logger.info(f"rank: {self.rank} {msg}")
logger.warning(f"rank: {self.rank} {msg}")

def initialize_distributed(self):
self.rlog(f"Initialize process group using {self.file_store_name=} ")
@@ -76,7 +71,7 @@ async def do_put(self):
dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)

self.rlog(f"calling put with {dtensor=}")
await self.store.put(self.shared_key, dtensor)
await store.put(self.shared_key, dtensor)

@endpoint
async def do_get(self):
@@ -91,7 +86,7 @@ async def do_get(self):
dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)

self.rlog(f"calling get with {dtensor=}")
fetched_tensor = await self.store.get(self.shared_key, dtensor)
fetched_tensor = await store.get(self.shared_key, dtensor)
self.rlog(f"after fetch: {dtensor=}")
assert torch.equal(dtensor, fetched_tensor)

@@ -200,22 +195,22 @@ async def _test_resharding(
):
"""Given a "put" mesh shape and a "get" mesh shape.
1. Create separate worlds for each mesh shape, running on different devices /PGs.
2. Each rank in 'put' world will create a DTensor, and call self.store.put(key="test_key", value=dtensor)
3. Each rank in 'get' world will create a DTensor (with a different sharding, and seeded with torch.zeros), and call self.store.get(key="test_key", value=dtensor)
2. Each rank in 'put' world will create a DTensor, and call store.put(key="test_key", value=dtensor)
3. Each rank in 'get' world will create a DTensor (with a different sharding, and seeded with torch.zeros), and call store.get(key="test_key", value=dtensor)
4. The result of the above operation should be the original DTensor, but resharded between putter/getter worlds

Example:
#Our "put" world starts with something like this:
original_tensor = [0,1,2,3], world_size=4
dtensor = distribute_tensor(original_tensor)
# Rank0: dtensor._local_tensor == [0], Rank1: dtensor._local_tensor == [1], Rank2: dtensor._local_tensor == [2], ...
self.store.put("shared_key", dtensor)
store.put("shared_key", dtensor)

#Our "put" world starts with something like this:
original_Tensor = [0, 0, 0, 0], world_size=2
dtensor = distribute_tensor(original_tensor)
# Rank0: dtensor._local_tensor == [0,0], Rank1: dtensor._local_tensor == [0,0]
self.store.get("shared_key", dtensor)
store.get("shared_key", dtensor)

# Rank0: dtensor._local_tensor == [0,1], Rank1: dtensor._local_tensor == [2,3]
"""
@@ -238,7 +233,6 @@ async def _test_resharding(
original_tensor = torch.arange(8**2).reshape(
8, 8
) # 8x8 square, with ([[0...7],[8...15],[...]])
store = await MultiProcessStore.create_store()
with tempfile.TemporaryDirectory() as filesystem_store_dir:
# each actor mesh represents a group of processes.
# e.g., two different islands running spmd
@@ -251,7 +245,6 @@
"put_mesh",
original_tensor=original_tensor,
placements=put_placements,
store=store,
mesh_shape=put_mesh_shape,
file_store_name=os.path.join(filesystem_store_dir, "put_test"),
visible_devices=put_visible_devices,
@@ -267,10 +260,9 @@
DTensorActor,
"get_mesh",
original_tensor=torch.zeros(
8, 8, dtype=original_tensor.dtype
8, 8
), # these values get replaced with values from original_tensor after fetching
placements=get_placements,
store=store,
mesh_shape=get_mesh_shape,
file_store_name=os.path.join(filesystem_store_dir, "get_test"),
visible_devices=get_visible_devices,
@@ -289,6 +281,9 @@
await put_mesh.destroy_process_group.call()
await get_mesh.destroy_process_group.call()

await put_mesh._proc_mesh.stop()
await get_mesh._proc_mesh.stop()

def _assert_correct_sharded_tensor(
self, full_tensor, sharded_tensor, get_placements, coordinate
):
8 changes: 2 additions & 6 deletions tests/test_state_dict.py
@@ -20,7 +20,6 @@
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor

from torchstore import MultiProcessStore
from torchstore._state_dict_utils import get_state_dict, push_state_dict
from torchstore.utils import spawn_actors

@@ -160,7 +159,7 @@ async def test_state_dict(self):
class Trainer(Actor):
# Monarch RDMA does not work outside of an actor, so we need
# to wrap this test first
#TODO: assert this within rdma buffer
# TODO: assert this within rdma buffer
@endpoint
async def do_test(self, store):
model = CompositeParamModel()
@@ -182,10 +181,8 @@ async def do_test(self, store):
return state_dict, fetched_state_dict

trainer = await spawn_actors(1, Trainer, "trainer")
store = await MultiProcessStore.create_store()
state_dict, fetched_state_dict = await trainer.do_test.call_one(store)
self._assert_equal_state_dict(state_dict, fetched_state_dict)


async def test_dcp_sharding_parity(self):
for save_mesh_shape, get_mesh_shape in [
@@ -198,7 +195,6 @@ async def test_dcp_sharding_parity(self):
save_world_size = math.prod(save_mesh_shape)
get_world_size = math.prod(get_mesh_shape)

store = await MultiProcessStore.create_store()
with tempfile.TemporaryDirectory() as tmpdir:
dcp_checkpoint_fn = os.path.join(tmpdir, "dcp_checkpoint.pt")

@@ -232,7 +228,7 @@
except Exception as e:
raise AssertionError(
f"Assertion failed on rank {coord.rank} ({save_mesh_shape=} {get_mesh_shape=}): {e}"
) from e
) from e

def _assert_equal_state_dict(self, state_dict1, state_dict2):
flattened_state_dict_1, _ = flatten_state_dict(state_dict1)