Skip to content

Commit

Permalink
Fix mypy error.
Browse files Browse the repository at this point in the history
  • Loading branch information
axch committed Jun 8, 2023
1 parent effaf67 commit 9e2b9e7
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions jax/_src/interpreters/batching.py
Expand Up @@ -15,8 +15,8 @@

import dataclasses
from functools import partial
from typing import (Any, Callable, Dict, Iterable, Optional, Sequence, Set,
Tuple, Type, Union)
from typing import (Any, Callable, Dict, Iterable, List, Optional,
Sequence, Set, Tuple, Type, Union)

import numpy as np

Expand Down Expand Up @@ -115,7 +115,7 @@ class RaggedAxis:
# For each axis, we store its index and the corresponding segment lengths.
# For example, the pile i:(Fin 3) => f32[lens1.i, 7, lens2.i]
# would be represented with ragged_axes = [(1, lens1), (3, lens2)]
ragged_axes: [(int, Array)]
ragged_axes: List[Tuple[int, Array]]

@property
def size(self):
Expand All @@ -134,7 +134,9 @@ def move_axis(ax):
new_ragged_axes = [(move_axis(ax), sizes) for ax, sizes in self.ragged_axes]
return RaggedAxis(dst, new_ragged_axes)

def make_batch_axis(ndim, stacked_axis, ragged_axes):
def make_batch_axis(
ndim: int, stacked_axis: int, ragged_axes: List[Tuple[int, Array]]
) -> Union[int, RaggedAxis]:
if ragged_axes:
canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes]
return RaggedAxis(canonicalize_axis(stacked_axis, ndim), canonical)
Expand Down Expand Up @@ -205,7 +207,9 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
elif type(x) is Pile:
if spec is not pile_axis:
raise TypeError("pile input without using pile_axis in_axes spec")
(d, ias), = ((i, sz) for i, sz in enumerate(x.aval.elt_ty.shape)
ias: IndexedAxisSize # Not present in the AxisSize union in core.py
(d, ias), = ((i, sz) # type: ignore
for i, sz in enumerate(x.aval.elt_ty.shape)
if type(sz) is IndexedAxisSize)
batch_axis = make_batch_axis(x.data.ndim, 0, [(d+1, ias.lengths)])
return BatchTracer(trace, x.data, batch_axis) # type: ignore
Expand Down

0 comments on commit 9e2b9e7

Please sign in to comment.