Skip to content

Commit

Permalink
Remove some dead code (followup to #16159)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 28, 2023
1 parent ae9160a commit 96e18d5
Showing 1 changed file with 1 addition and 13 deletions.
14 changes: 1 addition & 13 deletions jax/_src/interpreters/xla.py
Expand Up @@ -23,7 +23,7 @@
import operator
import re
from typing import (Any, Callable, Dict, Optional, Protocol,
Sequence, Set, Type, Tuple, Union, TYPE_CHECKING)
Sequence, Set, Type, Tuple, Union)

import numpy as np

Expand Down Expand Up @@ -330,9 +330,6 @@ def jaxpr_collectives(jaxpr):

### xla_call underlying jit

# TODO(yashkatariya): Remove after 1 month from March 23, 2023.
xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')


def xla_call_partial_eval_update_params(
params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
Expand Down Expand Up @@ -445,12 +442,3 @@ def __missing__(self, key):

backend_specific_translations: Dict[str, _TranslationRuleAdapter]
backend_specific_translations = _BackendSpecificTranslationsAdapter()


if TYPE_CHECKING:
DeviceArray = Any
else:
class DeviceArray(object):
def __init__(self):
raise RuntimeError("DeviceArray is a backward compatibility shim "
"and cannot be instantiated.")

0 comments on commit 96e18d5

Please sign in to comment.