95 changes: 88 additions & 7 deletions autoparallel/api.py
@@ -25,7 +25,6 @@
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor import DeviceMesh
from torch.export._trace import _restore_state_dict
from torch.export._unlift import _assign_attr
from torch.export.unflatten import _AttrKind
from torch.fx.experimental.symbolic_shapes import ShapeEnv

@@ -52,6 +51,66 @@
_APPLY_VIEW_MM_VIEW_PATTERN = False


def _assign_attr(
attr: Any,
target_module: torch.nn.Module,
fqn: str,
attr_kind: _AttrKind,
ref_module: Optional[torch.nn.Module] = None,
):
"""
Custom version of torch.export._unlift._assign_attr that preserves the original
module structure (e.g., nn.ModuleDict) from ref_module.

Args:
attr: The attribute to assign (parameter/buffer/module)
target_module: The module to assign the attribute to
fqn: Fully qualified name of the attribute (e.g., "layers.0.weight")
attr_kind: Type of attribute (PARAMETER, BUFFER, etc.)
ref_module: Reference module to check for original structure (optional)
"""
*prefix, field = fqn.split(".")

# Navigate to the parent module, creating submodules as needed
curr_mod = target_module
for i, attr_name in enumerate(prefix):
if not hasattr(curr_mod, attr_name):
# Check if we should create a ModuleDict or regular Module
if ref_module is not None:
# Navigate to the same location in ref_module
ref_curr_mod = ref_module
for ref_attr_name in prefix[:i]:
if hasattr(ref_curr_mod, ref_attr_name):
ref_curr_mod = getattr(ref_curr_mod, ref_attr_name)
else:
ref_curr_mod = None # type: ignore[assignment]
break

# Check if the next level should be a ModuleDict
if ref_curr_mod is not None and hasattr(ref_curr_mod, attr_name):
ref_submod = getattr(ref_curr_mod, attr_name)
if isinstance(ref_submod, torch.nn.ModuleDict):
setattr(curr_mod, attr_name, torch.nn.ModuleDict())
else:
setattr(curr_mod, attr_name, torch.nn.Module())
else:
setattr(curr_mod, attr_name, torch.nn.Module())
else:
setattr(curr_mod, attr_name, torch.nn.Module())
Comment on lines +91 to +99
@fmassa (Contributor) commented on Dec 1, 2025:

I wonder if we would want to keep the whole original class structure around (maybe with an nn.Module subclass indicating that the class has been AutoParallelized).
Something like

cls = type(ref_submod)
new_inst = ref_submod.__new__(cls)
new_inst.__dict__ = ref_submod.__dict__.copy()
setattr(curr_mod, attr_name, new_inst)

or, if we want a subclass,

cls = type(ref_submod)
new_cls = type(f"AutoP[{cls.__name__}]", (cls,), ref_submod.__dict__.copy())
new_inst = new_cls.__new__(new_cls)
new_inst.__dict__ = ref_submod.__dict__.copy()
setattr(curr_mod, attr_name, new_inst)

(but maybe we need to cache those new classes to avoid creating too many redundant classes?)

Member (Author) replied:

I can give it a try.
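
A minimal sketch of the class-caching idea raised above, assuming a module-level dict keyed by the original class is sufficient; the _AUTOP_CLS_CACHE name and _autoparallel_subclass_of helper are illustrative and not part of this PR:

import torch

_AUTOP_CLS_CACHE: dict[type, type] = {}

def _autoparallel_subclass_of(ref_submod: torch.nn.Module) -> torch.nn.Module:
    # Reuse one "AutoP[...]" subclass per original class so repeated calls
    # do not mint redundant classes.
    cls = type(ref_submod)
    new_cls = _AUTOP_CLS_CACHE.get(cls)
    if new_cls is None:
        new_cls = type(f"AutoP[{cls.__name__}]", (cls,), {})
        _AUTOP_CLS_CACHE[cls] = new_cls
    # Shallow-copy the instance state without running __init__, as in the
    # snippets in the comment above.
    new_inst = new_cls.__new__(new_cls)
    new_inst.__dict__ = ref_submod.__dict__.copy()
    return new_inst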


curr_mod = getattr(curr_mod, attr_name)

# Set the final attribute
if attr_kind == _AttrKind.PARAMETER:
assert isinstance(attr, torch.nn.Parameter)
curr_mod.register_parameter(field, attr)
elif attr_kind == _AttrKind.BUFFER:
assert isinstance(attr, torch.Tensor)
curr_mod.register_buffer(field, attr)
else:
setattr(curr_mod, field, attr)


def _get_decomp_table():
decomp_table = copy.copy(select_decomp_table())
# TODO: removing those as they cause missing DTensor propagation rules
@@ -549,11 +608,24 @@ def _register_params_and_init_weights(
# We construct an unflattened structure on parallel_mod,
# e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
# create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
# We pass self.model as reference to preserve the original module structure (e.g., nn.ModuleDict)
for k, v in sharded_param_dict.items():
_assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.PARAMETER)
_assign_attr(
v,
self.parallel_model,
k,
attr_kind=_AttrKind.PARAMETER,
ref_module=self.model,
)

for k, v in sharded_buffer_dict.items():
_assign_attr(v, self.parallel_model, k, attr_kind=_AttrKind.BUFFER)
_assign_attr(
v,
self.parallel_model,
k,
attr_kind=_AttrKind.BUFFER,
ref_module=self.model,
)

# Right now we require a convention that the user model provides an init_weights method,
# although we could snoop for other methods too.
@@ -620,9 +692,12 @@ def __init__(
sharded_param_dict: dict[str, torch.nn.Parameter],
sharded_buffer_dict: dict[str, torch.Tensor],
init_weights_model: torch.nn.Module,
ref_model: torch.nn.Module,
):
super().__init__()
self._register_params_and_buffers(sharded_param_dict, sharded_buffer_dict)
self._register_params_and_buffers(
sharded_param_dict, sharded_buffer_dict, ref_model
)

# Right now we require a convention that the user model provides an init_weights method,
# although we could snoop for other methods too.
@@ -638,16 +713,21 @@ def init_weights(_self, *args, **kwargs):
# but with our new DTensor sharded params attached to the user module.
self.init_weights = MethodType(init_weights, self)

def _register_params_and_buffers(self, sharded_param_dict, sharded_buffer_dict):
def _register_params_and_buffers(
self, sharded_param_dict, sharded_buffer_dict, ref_model
):

# We construct an unflattened structure on parallel_mod,
# e.g. _assign_attr(v, parallel_model, k="layers.0.weight") will literally
# create empty nn.Modules recursively and then stash 'v' so it shows up in the right spot
# We pass ref_model to preserve the original module structure (e.g., nn.ModuleDict)
for k, v in sharded_param_dict.items():
_assign_attr(v, self, k, attr_kind=_AttrKind.PARAMETER)
_assign_attr(
v, self, k, attr_kind=_AttrKind.PARAMETER, ref_module=ref_model
)

for k, v in sharded_buffer_dict.items():
_assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER)
_assign_attr(v, self, k, attr_kind=_AttrKind.BUFFER, ref_module=ref_model)

def forward(self, *args):
raise NotImplementedError("This is a placeholder for the pipeline model")
@@ -828,6 +908,7 @@ def apply_placement_pp(
sharded_param_dict,
sharded_buffer_dict,
self.init_weights_model,
self.model,
)
return {
"graph_callables": graph_modules,
59 changes: 59 additions & 0 deletions tests/test_api.py
@@ -305,3 +305,62 @@ def input_fn():
# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_1, 10), kwargs = {})
# %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 30), kwargs = {})
# return ((add, add_2), (tangents_1, None))


def test_moduledict_preservation(device_mesh_1d):
"""Test that nn.ModuleDict structure is preserved during _assign_attr."""
dim = 128

class Model(nn.Module):
def __init__(self, dim):
super().__init__()
# Create a ModuleDict to test preservation
self.layers = nn.ModuleDict(
{
"layer1": nn.Linear(dim, dim),
"layer2": nn.Linear(dim, dim),
}
)

def forward(self, x):
x = self.layers["layer1"](x)
x = self.layers["layer2"](x)
return x

def input_fn():
b = 512
inputs = (torch.rand(b, dim, device="cuda"),)
return inputs

with torch.device("meta"):
model = Model(dim)

# Verify original model has ModuleDict
assert isinstance(model.layers, nn.ModuleDict)

with AutoParallel(
model,
input_fn,
device_mesh_1d,
) as autop:
x_sharding = (Shard(0),)
autop.add_input_constraints([x_sharding])
sharding_placement = autop.optimize_placement()

# AutoParallel produces a module with meta-DTensor parameters that need to be initialized
parallel_mod = autop.apply_placement(sharding_placement)

# Verify that the parallel_mod preserves the ModuleDict structure
assert isinstance(
parallel_mod.layers, nn.ModuleDict
), f"Expected nn.ModuleDict but got {type(parallel_mod.layers)}"

# Verify that the ModuleDict contains the expected layers
assert "layer1" in parallel_mod.layers
assert "layer2" in parallel_mod.layers
assert isinstance(parallel_mod.layers["layer1"], nn.Module)
assert isinstance(parallel_mod.layers["layer2"], nn.Module)

# Verify parameters are accessible through the ModuleDict structure
assert hasattr(parallel_mod.layers["layer1"], "weight")
assert hasattr(parallel_mod.layers["layer2"], "weight")