Skip to content

Commit

Permalink
change default dtype to jax.numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Mar 28, 2024
1 parent 3a803df commit af859f4
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions braincore/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,37 +297,37 @@ def set_gpu_preallocation(mode: Union[float, bool]):
@functools.lru_cache()
def _get_uint(precision: int):
if precision == 64:
return np.uint64
return jnp.uint64
elif precision == 32:
return np.uint32
return jnp.uint32
elif precision == 16:
return np.uint16
return jnp.uint16
elif precision == 8:
return np.uint8
return jnp.uint8
else:
raise ValueError(f'Unsupported precision: {precision}')


@functools.lru_cache()
def _get_int(precision: int):
if precision == 64:
return np.int64
return jnp.int64
elif precision == 32:
return np.int32
return jnp.int32
elif precision == 16:
return np.int16
return jnp.int16
elif precision == 8:
return np.int8
return jnp.int8
else:
raise ValueError(f'Unsupported precision: {precision}')


@functools.lru_cache()
def _get_float(precision: int):
if precision == 64:
return np.float64
return jnp.float64
elif precision == 32:
return np.float32
return jnp.float32
elif precision == 16:
return jnp.bfloat16
# return np.float16
Expand All @@ -338,11 +338,11 @@ def _get_float(precision: int):
@functools.lru_cache()
def _get_complex(precision: int):
if precision == 64:
return np.complex128
return jnp.complex128
elif precision == 32:
return np.complex64
return jnp.complex64
elif precision == 16:
return np.complex32
return jnp.complex32
else:
raise ValueError(f'Unsupported precision: {precision}')

Expand Down

0 comments on commit af859f4

Please sign in to comment.