Skip to content

Commit

Permalink
A change to type checking in checkpoint loading.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 465616127
  • Loading branch information
henrykmichalewski authored and Copybara-Service committed Aug 5, 2022
1 parent 03a1b45 commit 99e9784
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions trax/fastmath/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,16 @@ def tree_unflatten(flat, tree, copy_from_tree=None):
of tree but with leaves from flat, and the remaining elements of flat if
more were provided than the number of leaves of tree (useful for recursion).
"""
if copy_from_tree is not None and tree in copy_from_tree:
return tree, flat
if copy_from_tree is not None:
for el in copy_from_tree:
# Equality checks comparing a DeviceArray with other Python objects
# may legitimately raise a TypeError.
try:
if tree == el:
return tree, flat
except TypeError:
continue

if isinstance(tree, (list, tuple)):
new_tree, rest = [], flat
for t in tree:
Expand Down

0 comments on commit 99e9784

Please sign in to comment.