New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Compare treedefs by num_leaves
not traversal_
in tree_transpose
.
#3659
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would suggest that @hawkinsp take a look at the changes in pytree.cc
const auto& root = traversal_.back(); | ||
const auto& inner_root = inner.traversal_.back(); | ||
auto& out_root = out->traversal_.back(); | ||
out_root.num_nodes = (root.num_nodes - root.num_leaves) + |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can see that this was buggy before, but I am uneasy about this particular fix. First, this mutates in place the nodes. Won't this alter outer_treedef? I suspect that the rest of this module was written with the assumption that the Nodes are immutable. Second, this fix restores the invariant for num_nodes
and num_leaves
only for one node; shouldn't it be restored for all the interior nodes in outer_treedef?
I am not very familiar with the internals of this module. Is it true that num_nodes
and num_leaves
are needed only for the root of the treedef. Maybe then it makes sense to keep these fields in the treedef, instead of the Node? That would address both issues since now we do not need to mutate Nodes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Re mutating in place afaik out->traversal_.push_back(n)
in the loop above will make a copy of the node, so we're only mutating the copy.
Re restoring for all interior nodes, I think you're right. In the cases I was looking carefully at we have a very simple traversal (kLeaf, .., kLeaf, kCustom
) but for a deeply nested tree we would need to recompute the full traversal I think. I'll take a look later this week at it, but I'm surprised our tests pass as is?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC the num_nodes
and num_leaves
values are to allow Children
to reconstruct the tree structure. So you'd need to call Compose
and then Children
to see the bug.
jax/tree_util.py
Outdated
expected_treedef = outer_treedef.compose(inner_treedef) | ||
if treedef != expected_treedef: | ||
raise TypeError("Mismatch\n{}\n != \n{}".format(treedef, expected_treedef)) | ||
"""Produces a pytree with the same data but a different structure. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am proposing that we add some documentation to tree_transpose (I did not know what it did)
inner_size = inner_treedef.num_leaves | ||
outer_size = outer_treedef.num_leaves | ||
if treedef.num_leaves != (inner_size * outer_size): | ||
expected_treedef = outer_treedef.compose(inner_treedef) | ||
raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}") | ||
flat = iter(flat) | ||
lol = [[next(flat) for _ in range(inner_size)] for __ in range(outer_size)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was wondering why can't this fucntion be implemented as:
flat, treedef = tree_flatten(pytree_to_transpose)
new_treedef = inner_treedef.compose(outer_treedef)
return tree_unflatten(new_treedef, flat)
Perhaps I do not fully understand its semantics.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Perhaps the original tree_transpose
implementation predates the compose
operator on treedefs?)
I thought that might work, so I tried it but got a failure on APITest.test_jacobian_on_pytrees
. I believe the flat order must change.
In general for a `kCustom` node it is not guaranteed that `a.compose(b)` will have the same `traversal_` as some structure `c` (which is the composition of `a+b`). We have a real world example in deepmind/dm-haiku with our FlatMapping type and I've put a simpler example in `tree_util_tests.py`. Since this test seems largely to be for input validation I've changed this to compute the expected number of leaves (which is cheaper than using compose as the previous implementation did) which will catch common errors and is guaranteed to work for any well formed pytree (additionally I had to fix the leaf and node count for composed pytrees which were wrong at HEAD).
8c00cd8
to
08e459c
Compare
Hey both, sorry I have not find the time to address your comments. I think the PR leaves the codebase in a strictly better state (fixing one bug in tree_transpose) and will unblock a PR in Haiku that should improve performance for many users. Mind if we leave the lack of composition between Transpose and Children as a known issue for now? (I don't expect anyone to actually hit this with the current API). I have left a TODO and will try and find time to follow up. |
@tomhennigan That plan sounds fine to me, though I'd defer to any input @hawkinsp offers. |
In general for a
kCustom
node it is not guaranteed thata.compose(b)
willhave the same
traversal_
as some structurec
(which is the composition ofa+b
). We have a real world example in deepmind/dm-haiku with our FlatMappingtype and I've put a simpler example in
tree_util_tests.py
.Since this test seems largely to be for input validation I've changed this to
compute the expected number of leaves (which is cheaper than using compose as
the previous implementation did) which will catch common errors and is
guaranteed to work for any well formed pytree (additionally I had to fix the
leaf and node count for composed pytrees which were wrong at HEAD).