diff --git a/torchrec/ir/serializer.py b/torchrec/ir/serializer.py index 71b66dae8..abc950013 100644 --- a/torchrec/ir/serializer.py +++ b/torchrec/ir/serializer.py @@ -9,7 +9,7 @@ import json import logging -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch @@ -69,8 +69,18 @@ def get_deserialized_device( return device -class JsonSerializerBase(SerializerInterface): +class JsonSerializer(SerializerInterface): + """ + Serializer for torch.export IR using json. + """ + + module_to_serializer_cls: Dict[str, Type["JsonSerializer"]] = {} _module_cls: Optional[Type[nn.Module]] = None + _children: Optional[List[str]] = None + + @classmethod + def children(cls, module: nn.Module) -> List[str]: + return [] if not cls._children else cls._children @classmethod def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]: @@ -81,6 +91,7 @@ def deserialize_from_dict( cls, metadata_dict: Dict[str, Any], device: Optional[torch.device] = None, + unflatten: Optional[nn.Module] = None, ) -> nn.Module: raise NotImplementedError() @@ -88,40 +99,59 @@ def deserialize_from_dict( def serialize( cls, module: nn.Module, - ) -> torch.Tensor: - if cls._module_cls is None: + ) -> Tuple[torch.Tensor, List[str]]: + typename = type(module).__name__ + serializer = cls.module_to_serializer_cls.get(typename) + if serializer is None: raise ValueError( - "Must assign a nn.Module to class static variable _module_cls" + f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}" ) - if not isinstance(module, cls._module_cls): + assert issubclass(serializer, JsonSerializer) + assert serializer._module_cls is not None + if not isinstance(module, serializer._module_cls): raise ValueError( - f"Expected module to be of type {cls._module_cls.__name__}, got {type(module)}" + f"Expected module to be of type {serializer._module_cls.__name__}, " + f"got {type(module)}" ) - metadata_dict = 
cls.serialize_to_dict(module) - return torch.frombuffer(json.dumps(metadata_dict).encode(), dtype=torch.uint8) + metadata_dict = serializer.serialize_to_dict(module) + raw_dict = {"typename": typename, "metadata_dict": metadata_dict} + serialized_tensor = torch.frombuffer( + json.dumps(raw_dict).encode(), dtype=torch.uint8 + ) + return serialized_tensor, serializer.children(module) @classmethod def deserialize( cls, input: torch.Tensor, - typename: str, device: Optional[torch.device] = None, + unflatten: Optional[nn.Module] = None, ) -> nn.Module: raw_bytes = input.numpy().tobytes() - metadata_dict = json.loads(raw_bytes.decode()) - module = cls.deserialize_from_dict(metadata_dict, device) - if cls._module_cls is None: + raw_dict = json.loads(raw_bytes.decode()) + typename = raw_dict["typename"] + if typename not in cls.module_to_serializer_cls: + raise ValueError( + f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}" + ) + serializer = cls.module_to_serializer_cls[typename] + assert issubclass(serializer, JsonSerializer) + module = serializer.deserialize_from_dict( + raw_dict["metadata_dict"], device, unflatten + ) + + if serializer._module_cls is None: raise ValueError( "Must assign a nn.Module to class static variable _module_cls" ) - if not isinstance(module, cls._module_cls): + if not isinstance(module, serializer._module_cls): raise ValueError( - f"Expected module to be of type {cls._module_cls.__name__}, got {type(module)}" + f"Expected module to be of type {serializer._module_cls.__name__}, got {type(module)}" ) return module -class EBCJsonSerializer(JsonSerializerBase): +class EBCJsonSerializer(JsonSerializer): _module_cls = EmbeddingBagCollection @classmethod @@ -148,6 +178,7 @@ def deserialize_from_dict( cls, metadata_dict: Dict[str, Any], device: Optional[torch.device] = None, + unflatten: Optional[nn.Module] = None, ) -> nn.Module: tables = [ EmbeddingBagConfigMetadata(**table_config) @@ -164,40 +195,4 @@ def 
deserialize_from_dict( ) -class JsonSerializer(SerializerInterface): - """ - Serializer for torch.export IR using json. - """ - - module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = { - "EmbeddingBagCollection": EBCJsonSerializer, - } - - @classmethod - def serialize( - cls, - module: nn.Module, - ) -> torch.Tensor: - typename = type(module).__name__ - if typename not in cls.module_to_serializer_cls: - raise ValueError( - f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}" - ) - - return cls.module_to_serializer_cls[typename].serialize(module) - - @classmethod - def deserialize( - cls, - input: torch.Tensor, - typename: str, - device: Optional[torch.device] = None, - ) -> nn.Module: - if typename not in cls.module_to_serializer_cls: - raise ValueError( - f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}" - ) - - return cls.module_to_serializer_cls[typename].deserialize( - input, typename, device - ) +JsonSerializer.module_to_serializer_cls["EmbeddingBagCollection"] = EBCJsonSerializer diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py index f27ae0d94..6f6b62cc3 100644 --- a/torchrec/ir/tests/test_serializer.py +++ b/torchrec/ir/tests/test_serializer.py @@ -11,7 +11,7 @@ import copy import unittest -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch import nn @@ -54,6 +54,41 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: return res +class CompoundModuleSerializer(JsonSerializer): + _module_cls = CompoundModule + + @classmethod + def children(cls, module: nn.Module) -> List[str]: + children = ["ebc", "list"] + if module.comp is not None: + children += ["comp"] + return children + + @classmethod + def serialize_to_dict( + cls, + module: nn.Module, + ) -> Dict[str, Any]: + return {} + + @classmethod + def 
deserialize_from_dict( + cls, + metadata_dict: Dict[str, Any], + device: Optional[torch.device] = None, + unflatten: Optional[nn.Module] = None, + ) -> nn.Module: + assert unflatten is not None + ebc = unflatten.ebc + comp = getattr(unflatten, "comp", None) + i = 0 + mlist = [] + while hasattr(unflatten.list, str(i)): + mlist.append(getattr(unflatten.list, str(i))) + i += 1 + return CompoundModule(ebc, comp, mlist) + + class TestJsonSerializer(unittest.TestCase): def generate_model(self) -> nn.Module: class Model(nn.Module): @@ -328,6 +363,9 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: eager_out = model(id_list_features) + JsonSerializer.module_to_serializer_cls["CompoundModule"] = ( + CompoundModuleSerializer + ) # Serialize model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer) ep = torch.export.export( @@ -346,6 +384,14 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]: # Deserialize deserialized_model = deserialize_embedding_modules(ep, JsonSerializer) + + # Check if Compound Module is deserialized correctly + self.assertIsInstance(deserialized_model.comp, CompoundModule) + self.assertIsInstance(deserialized_model.comp.comp, CompoundModule) + self.assertIsInstance(deserialized_model.comp.comp.comp, CompoundModule) + self.assertIsInstance(deserialized_model.comp.list[1], CompoundModule) + self.assertIsInstance(deserialized_model.comp.list[1].comp, CompoundModule) + deserialized_model.load_state_dict(model.state_dict()) # Run forward on deserialized model deserialized_out = deserialized_model(id_list_features) diff --git a/torchrec/ir/types.py b/torchrec/ir/types.py index 397a4aa98..15ef430f3 100644 --- a/torchrec/ir/types.py +++ b/torchrec/ir/types.py @@ -10,7 +10,7 @@ #!/usr/bin/env python3 import abc -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, List, Optional, Tuple import torch @@ -24,28 +24,25 @@ class SerializerInterface(abc.ABC): @classmethod @property - # 
pyre-ignore [3]: Returning `None` but type `Any` is specified. - def module_to_serializer_cls(cls) -> Dict[str, Type[Any]]: + def module_to_serializer_cls(cls) -> Dict[str, Any]: raise NotImplementedError @classmethod @abc.abstractmethod - # pyre-ignore [3]: Returning `None` but type `Any` is specified. def serialize( cls, module: nn.Module, - ) -> Any: + ) -> Tuple[torch.Tensor, List[str]]: # Take the eager embedding module and generate bytes in buffer - pass + raise NotImplementedError @classmethod @abc.abstractmethod def deserialize( cls, - # pyre-ignore [2]: Parameter `input` must have a type other than `Any`. - input: Any, - typename: str, + input: torch.Tensor, device: Optional[torch.device] = None, + unflatten: Optional[nn.Module] = None, ) -> nn.Module: # Take the bytes in the buffer and regenerate the eager embedding module - pass + raise NotImplementedError diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py index 952d702b7..f4df1cdff 100644 --- a/torchrec/ir/utils.py +++ b/torchrec/ir/utils.py @@ -27,8 +27,9 @@ def serialize_embedding_modules( - model: nn.Module, + module: nn.Module, serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS, + fqn: str = "", ) -> Tuple[nn.Module, List[str]]: """ Takes all the modules that are of type `serializer_cls` and serializes them @@ -37,13 +38,46 @@ def serialize_embedding_modules( Returns the modified module and the list of fqns that had the buffer added. 
""" preserve_fqns = [] - for fqn, module in model.named_modules(): - if type(module).__name__ in serializer_cls.module_to_serializer_cls: - serialized_module = serializer_cls.serialize(module) - module.register_buffer("ir_metadata", serialized_module, persistent=False) - preserve_fqns.append(fqn) - return model, preserve_fqns + # handle current module + if type(module).__name__ in serializer_cls.module_to_serializer_cls: + serialized_tensor, children = serializer_cls.serialize(module) + module.register_buffer("ir_metadata", serialized_tensor, persistent=False) + preserve_fqns.append(fqn) + else: + children = [child for child, _ in module.named_children()] + + # handle child modules + for child in children: + submodule = module.get_submodule(child) + child_fqn = f"{fqn}.{child}" if len(fqn) > 0 else child + preserve_fqns.extend( + serialize_embedding_modules(submodule, serializer_cls, child_fqn)[1] + ) + return module, preserve_fqns + + +def _deserialize_embedding_modules( + module: nn.Module, + serializer_cls: Type[SerializerInterface], + device: Optional[torch.device] = None, +) -> nn.Module: + """ + Deserializes `module` bottom-up: each child is deserialized and re-attached first. + If `module` has an "ir_metadata" buffer, it is then rebuilt via `serializer_cls.deserialize`. + returns: the rebuilt module, or `module` itself when it has no "ir_metadata" buffer + """ + + for child_fqn, child in module.named_children(): + child = _deserialize_embedding_modules( + module=child, serializer_cls=serializer_cls, device=device + ) + setattr(module, child_fqn, child) + + if "ir_metadata" in dict(module.named_buffers()): + serialized_tensor = module.get_buffer("ir_metadata") + module = serializer_cls.deserialize(serialized_tensor, device, module) + return module def deserialize_embedding_modules( @@ -59,39 +93,7 @@ def deserialize_embedding_modules( Returns the unflattened ExportedProgram with the deserialized modules. 
""" model = torch.export.unflatten(ep) - module_type_dict = {} - for node in ep.graph.nodes: - if "nn_module_stack" in node.meta: - for fqn, type_name in node.meta["nn_module_stack"].values(): - # Only get the module type name, not the full type name - module_type_dict[fqn] = type_name.split(".")[-1] - - fqn_to_new_module = {} - for fqn, module in model.named_modules(): - if "ir_metadata" in dict(module.named_buffers()): - serialized_module = dict(module.named_buffers())["ir_metadata"] - - if fqn not in module_type_dict: - raise RuntimeError( - f"Cannot find the type of module {fqn} in the exported program" - ) - - deserialized_module = serializer_cls.deserialize( - serialized_module, - module_type_dict[fqn], - device, - ) - fqn_to_new_module[fqn] = deserialized_module - - for fqn, new_module in fqn_to_new_module.items(): - # handle nested attribute like "x.y.z" - attrs = fqn.split(".") - parent = model - for a in attrs[:-1]: - parent = getattr(parent, a) - setattr(parent, attrs[-1], new_module) - - return model + return _deserialize_embedding_modules(model, serializer_cls, device) def _get_dim(x: Union[DIM, str, None], s: str, max: Optional[int] = None) -> DIM: