From 9e2b9e7812477e6bebc85be3abe19af9b73340e7 Mon Sep 17 00:00:00 2001 From: Alexey Radul Date: Thu, 8 Jun 2023 16:58:31 -0400 Subject: [PATCH] Fix mypy error. --- jax/_src/interpreters/batching.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 2c67230c324c..87f3b021ae40 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -15,8 +15,8 @@ import dataclasses from functools import partial -from typing import (Any, Callable, Dict, Iterable, Optional, Sequence, Set, - Tuple, Type, Union) +from typing import (Any, Callable, Dict, Iterable, List, Optional, + Sequence, Set, Tuple, Type, Union) import numpy as np @@ -115,7 +115,7 @@ class RaggedAxis: # For each axis, we store its index and the corresponding segment lengths. # For example, the pile i:(Fin 3) => f32[lens1.i, 7, lens2.i] # would be represented with ragged_axes = [(1, lens1), (3, lens2)] - ragged_axes: [(int, Array)] + ragged_axes: List[Tuple[int, Array]] @property def size(self): @@ -134,7 +134,9 @@ def move_axis(ax): new_ragged_axes = [(move_axis(ax), sizes) for ax, sizes in self.ragged_axes] return RaggedAxis(dst, new_ragged_axes) -def make_batch_axis(ndim, stacked_axis, ragged_axes): +def make_batch_axis( + ndim: int, stacked_axis: int, ragged_axes: List[Tuple[int, Array]] + ) -> Union[int, RaggedAxis]: if ragged_axes: canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes] return RaggedAxis(canonicalize_axis(stacked_axis, ndim), canonical) @@ -205,7 +207,9 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: elif type(x) is Pile: if spec is not pile_axis: raise TypeError("pile input without using pile_axis in_axes spec") - (d, ias), = ((i, sz) for i, sz in enumerate(x.aval.elt_ty.shape) + ias: IndexedAxisSize # Not present in the AxisSize union in core.py + (d, ias), = ((i, sz) # type: ignore + for i, sz in enumerate(x.aval.elt_ty.shape) if type(sz) is IndexedAxisSize) batch_axis = make_batch_axis(x.data.ndim, 0, [(d+1, ias.lengths)]) return BatchTracer(trace, x.data, batch_axis) # type: ignore