Skip to content

Commit

Permalink
[dynamo] support dict.update(seq2) / OrderedDict.update(seq2) / `…
Browse files Browse the repository at this point in the history
…defaultdict.update(seq2)` (pytorch#115011)

Pull Request resolved: pytorch#115011
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#115010
  • Loading branch information
XuehaiPan authored and dmenig committed Dec 21, 2023
1 parent ab11d29 commit 455c576
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
19 changes: 19 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,9 @@ def fn(x):
dd["a"] = x + 1
dd[param] = 123
dd["c"] = x * 2
dd.update({"b": x * 3})
dd.update([["d", x - 2], ("e", x + 2)])
dd.update(zip("ab", [x + 3, x + 4]))
return dd["b"], dd

x = torch.randn(10, 10)
Expand All @@ -754,7 +757,10 @@ def fn(x):

self.assertTrue(same(ref[0], res[0]))
self.assertTrue(same(ref[1]["a"], res[1]["a"]))
self.assertTrue(same(ref[1]["b"], res[1]["b"]))
self.assertTrue(same(ref[1]["c"], res[1]["c"]))
self.assertTrue(same(ref[1]["d"], res[1]["d"]))
self.assertTrue(same(ref[1]["e"], res[1]["e"]))
self.assertTrue(same(ref[1][param], res[1][param]))

@make_test
Expand Down Expand Up @@ -811,6 +817,19 @@ def test_dict_fromkeys(x, y):
d3 = collections.OrderedDict.fromkeys(tuple(lst), value=y)
return d1["a"] * d2["b"] + d2["a"] + d1["b"] + d3["a"] + d3["b"] + 1

@make_test
def test_dict_update(x, y, z):
d = {"a": x, "b": y}
d.update({"a": y - 1})
d.update([("b", z + 1), ["c", z]])
d.update(zip("ab", [z + 3, y + 2]))

od = collections.OrderedDict(a=x * 3, b=y + 2)
od.update({"a": y + 5})
od.update([["b", z + 6], ("c", z - 7)])
od.update(zip("ab", [z - 3, x + 2]))
return d["a"] * od["a"] + od["c"] + d["b"] + od["b"] * d["c"]

@make_test
def test_min_max(a, b):
c = a + b
Expand Down
35 changes: 31 additions & 4 deletions torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,17 @@ def call_method(
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from . import ConstantVariable, TupleVariable
from . import (
ConstantVariable,
ListIteratorVariable,
ListVariable,
TupleVariable,
)

val = self.items

if name == "__getitem__":
assert len(args) == 1
return self.getitem_const(args[0])

elif name == "items":
Expand Down Expand Up @@ -139,10 +145,9 @@ def call_method(
)
elif (
name in ("pop", "get")
and args
and len(args) == 2
and ConstDictVariable.is_valid_key(args[0])
and ConstDictVariable.get_key(args[0]) not in self.items
and len(args) == 2
):
# missing item, return the default value
return args[1]
Expand All @@ -158,12 +163,34 @@ def call_method(
return result
elif (
name == "update"
and args
and len(args) == 1
and isinstance(args[0], ConstDictVariable)
and self.mutable_local
):
newval = dict(val)
newval.update(args[0].items)
newval.update(kwargs) # all keys in kwargs are valid (`str`s)
result = self.modifed(newval)
return tx.replace_all(self, result)
elif (
name == "update"
and len(args) == 1
and isinstance(
args[0],
(
ListVariable,
TupleVariable,
ListIteratorVariable,
),
)
and self.mutable_local
):
newval = dict(val)
for x in args[0].unpack_var_sequence(tx):
k, v = x.unpack_var_sequence(tx)
assert ConstDictVariable.is_valid_key(k)
newval[ConstDictVariable.get_key(k)] = v
newval.update(kwargs) # all keys in kwargs are valid (`str`s)
result = self.modifed(newval)
return tx.replace_all(self, result)
elif (
Expand Down

0 comments on commit 455c576

Please sign in to comment.