From 3537d8fca5b4115c7f992f8e326ed963828a491f Mon Sep 17 00:00:00 2001 From: Rolf Jagerman Date: Thu, 29 Feb 2024 11:22:04 -0800 Subject: [PATCH] Fix Array import in segment_utils PiperOrigin-RevId: 611536788 --- rax/_src/segment_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rax/_src/segment_utils.py b/rax/_src/segment_utils.py index 15e1d10..b228ef8 100644 --- a/rax/_src/segment_utils.py +++ b/rax/_src/segment_utils.py @@ -19,7 +19,9 @@ import jax import jax.numpy as jnp -from rax._src.types import Array +from rax._src import types + +Array = types.Array def same_segment_mask(segments: Array) -> Array: