Skip to content

Commit

Permalink
ops/numpy.py: support key as list in GetItem
Browse files Browse the repository at this point in the history
When loading a model that contains GetItem nodes with multidimensional indices/slices as key, the key argument is loaded from JSON as a list, not a tuple (because JSON does not have the distinction).
So, treat the key list as equivalent to the key tuple. Copying is important, otherwise, the later pop() will remove the bound slice elements.
  • Loading branch information
tvogel committed Mar 12, 2024
1 parent 2c96829 commit 7e8532e
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions keras/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2695,6 +2695,8 @@ def compute_output_spec(self, x, key):
remaining_key = [key]
elif isinstance(key, tuple):
remaining_key = list(key)
elif isinstance(key, list):
remaining_key = key.copy()
else:
raise ValueError(
f"Unsupported key type for array slice. Recieved: `{key}`"
Expand Down

1 comment on commit 7e8532e

@tvogel
Copy link
Contributor Author

@tvogel tvogel commented on 7e8532e Mar 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please take a look, I had trouble loading a model that did things like this:

    kernel_output = keras.layers.Concatenate()(
      [ 
        kernel(input[:, 2*other:2*other+2] - input[:, 2*ref:2*ref+2]) 
        for ref in range(n)
        for other in range(n) 
        if other != ref 
      ])

Please sign in to comment.