Skip to content

Commit

Permalink
[feat] support namedtuple in container.py (#1069)
Browse files Browse the repository at this point in the history
  • Loading branch information
min-xu-ai committed Sep 13, 2022
1 parent 73bf596 commit eeb6684
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
15 changes: 12 additions & 3 deletions fairscale/internal/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union, cast

import numpy as np
import torch
Expand All @@ -14,7 +14,7 @@


def apply_to_type(
type_fn: Callable, fn: Callable, container: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set]
type_fn: Callable, fn: Callable, container: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set, NamedTuple]
) -> Any:
"""Recursively apply to all objects in different kinds of container types that matches a type function."""

Expand All @@ -34,7 +34,16 @@ def _apply(x: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set]) -> Any:
elif isinstance(x, list):
return [_apply(x) for x in x]
elif isinstance(x, tuple):
return tuple(_apply(x) for x in x)
f = getattr(x, "_fields", None)
if f is None:
return tuple(_apply(x) for x in x)
else:
assert isinstance(f, tuple), "This needs to be a namedtuple"
# convert the namedtuple to a dict and _apply().
x = cast(NamedTuple, x)
_dict: Dict[str, Any] = x._asdict()
_dict = {key: _apply(value) for key, value in _dict.items()}
return type(x)(**_dict) # make a copy of the namedtuple
elif isinstance(x, set):
return {_apply(x) for x in x}
else:
Expand Down
12 changes: 10 additions & 2 deletions tests/utils/test_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

""" Test utility classes from containers.py. """

from collections import OrderedDict
from collections import OrderedDict, namedtuple
import random

import pytest
Expand Down Expand Up @@ -42,13 +42,21 @@ def get_a_tensor():
return t

# create a mixed bag of data.
data = [1, "str"]
data = [1, "str"] # list
# dict
data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})
# set
data.insert(0, set(["x", get_a_tensor(), get_a_tensor()]))
# tuple
data.append(([1], get_a_tensor(), 1, [get_a_tensor()], set((1, 2))))
# OrderedDict
od = OrderedDict()
od["k"] = "value"
data.append(od)
# namedtuple
NT = namedtuple("NT", ["key1", "key2"])
nt = NT(key1=1, key2=get_a_tensor())
data.append(nt)

total = 0

Expand Down

0 comments on commit eeb6684

Please sign in to comment.