In [None]:
# default_exp patched

# Patched

> We patch some libraries and objects that don't belong to SAX. Don't worry, it's nothing substantial.

In [None]:
# hide
import matplotlib.pyplot as plt
from fastcore.test import test_eq
from pytest import approx, raises
import jax.numpy as jnp

# Point sys.stderr at the null device so that anything written to stderr
# for the remainder of this notebook session is discarded.
import os, sys; sys.stderr = open(os.devnull, "w")

In [None]:
# export
from __future__ import annotations

import re
from fastcore.basics import patch_to
from flax.core import FrozenDict
from jaxlib.xla_extension import DeviceArray

from sax.typing_ import is_complex_float, is_float
from textwrap import dedent

Patching `FrozenDict` to have the same repr as a normal dict:

In [None]:
# exporti
@patch_to(FrozenDict)
def __repr__(self):  # type: ignore
    """Render the mapping as ``ClassName({...})`` using the built-in dict formatting.

    Values that are themselves instances of this object's class are converted
    to plain ``dict`` objects (one level deep) before formatting; every other
    value is passed through untouched.
    """
    cls = self.__class__
    plain = {}
    for key, value in self.items():
        # Only direct children of the same class are unwrapped; anything
        # deeper is left to its own repr.
        plain[key] = dict(value) if isinstance(value, cls) else value
    return f"{cls.__name__}({dict.__repr__(plain)})"

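Patching `DeviceArray` to have less verbose reprs: 0-D float and complex arrays print as plain (rounded) numbers, and all other arrays print as `array(...)` without the `dtype` suffix: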

In [None]:
# exporti
@patch_to(DeviceArray)
def __repr__(self):  # type: ignore
    """Return a compact, human-friendly repr for a ``DeviceArray``.

    * 0-D arrays accepted by ``is_float`` are shown as a bare Python float.
    * 0-D arrays accepted by ``is_complex_float`` are shown as a bare Python
      complex number without the surrounding parentheses.
    * Everything else falls back to the parent class repr, renamed to
      ``array(...)`` and with the trailing ``dtype=...`` / ``weak_type=...``
      metadata removed.

    In the scalar branches a component is rounded to 5 decimals only when its
    magnitude exceeds 1e-4; smaller values are printed unrounded so they are
    not collapsed to zero.
    """
    if self.ndim == 0 and is_float(self):
        v = float(self)
        return repr(round(v, 5)) if abs(v) > 1e-4 else repr(v)
    elif self.ndim == 0 and is_complex_float(self):
        r, i = float(self.real), float(self.imag)
        r = round(r, 5) if abs(r) > 1e-4 else r
        i = round(i, 5) if abs(i) > 1e-4 else i
        s = repr(r + 1j * i)
        # Python wraps complex reprs with a non-zero real part in parentheses.
        if s[0] == "(" and s[-1] == ")":
            s = s[1:-1]
        return s
    else:
        # Name the patched class explicitly: super(self.__class__, self) would
        # resolve back to this very function (infinite recursion) whenever
        # `self` is an instance of a subclass.
        s = super(DeviceArray, self).__repr__()
        # The replacement is padded to the same width as "DeviceArray(" so the
        # continuation lines of multi-line arrays stay aligned; the common
        # padding is removed again by dedent() below.
        s = s.replace("DeviceArray(", "      array(")
        # Drop the metadata suffix starting at the first dtype/weak_type
        # keyword, even when it has been wrapped onto its own line.
        s = re.sub(r",\s*(?:dtype|weak_type)=.*", "", s, flags=re.DOTALL)
        s = dedent(s).rstrip()
        # Stripping the suffix also removed the closing parenthesis; restore it
        # only if it is actually missing so we never emit "))".
        return s if s.endswith(")") else s + ")"

In [None]:
# Display a 0-D array to show the patched DeviceArray repr.
jnp.array(3)

In [None]:
# Display a 1-D array to show the patched DeviceArray repr.
jnp.array([3, 4, 5])