diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 38a89dc62def..ac5dadb3cd27 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -16,7 +16,7 @@ from collections import defaultdict, deque, namedtuple import itertools as it import operator as op -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Type, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Type, Tuple from warnings import warn from absl import logging