Skip to content

Commit

Permalink
Fix Array import in segment_utils
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 611536788
  • Loading branch information
rjagerman authored and Rax Developers committed Feb 29, 2024
1 parent 073798e commit 3537d8f
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion rax/_src/segment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import jax
import jax.numpy as jnp

from rax._src.types import Array
from rax._src import types

Array = types.Array


def same_segment_mask(segments: Array) -> Array:
Expand Down

0 comments on commit 3537d8f

Please sign in to comment.