In [65]:
import jax.numpy as jnp
import numpy as np

In [66]:
def data_processor(data):
    """Print `data` followed by its type, one per line (debug helper)."""
    rendered_value = format(data)
    rendered_type = format(type(data))
    print(rendered_value)
    print(rendered_type)


In [67]:
class NpToJaxDrainRing:
    """
    Fixed-capacity ring buffer of rows backed by a mutable NumPy array.

    Rows are appended with add()/extend(); once `capacity` rows are held,
    every new row overwrites the oldest one. drain() (or calling the
    instance) returns the buffered rows as a jnp.array ordered
    oldest→newest and resets the ring to empty.

    Parameters
    ----------
    capacity : int
        Maximum number of rows kept. Must be at least 1.
    cols : int
        Number of values per row.
    dtype : NumPy dtype
        Storage dtype of the backing array.

    Raises
    ------
    ValueError
        If `capacity` is smaller than 1.

    Notes
    -----
    len(), indexing, iteration and repr() inspect the contents without
    draining them. Indexing and iteration may hand out views of the
    internal storage, so copy the result if it must survive later add()
    calls.

    NOTE(review): JAX is expected to store float64 input as float32 unless
    the `jax_enable_x64` option is on — confirm if full precision matters.
    """
    def __init__(self, capacity=200, cols=4, dtype=np.float64):
        # A zero-capacity ring would only fail later, in add(), with an
        # unhelpful ZeroDivisionError from the modulo arithmetic.
        if capacity < 1:
            raise ValueError(f"capacity must be at least 1, got {capacity}")
        self.capacity = capacity
        self.cols = cols
        self._data = np.zeros((capacity, cols), dtype=dtype)
        self._start = 0      # index of oldest row
        self._size = 0       # number of valid rows

    def add(self, row):
        """
        Append one row; overwrite the oldest row if the ring is full.

        Raises ValueError when `row` does not have shape (cols,).
        """
        r = np.asarray(row)
        if r.shape != (self.cols,):
            raise ValueError(f"expected shape ({self.cols},), got {r.shape}")
        # Slot just past the newest row; when full this is the oldest row.
        end = (self._start + self._size) % self.capacity
        self._data[end] = r
        if self._size < self.capacity:
            self._size += 1
        else:
            # The oldest row was overwritten, so the window slides forward.
            self._start = (self._start + 1) % self.capacity

    def extend(self, rows):
        """Append every row of `rows` in order, using add() for each."""
        for r in rows:
            self.add(r)

    def _chron_view(self):
        """
        Return the valid rows ordered oldest→newest.

        The result is a view of the backing array when the rows are
        contiguous, and a freshly stacked array when they wrap around the
        end of the storage.
        """
        stop = self._start + self._size
        if stop <= self.capacity:
            # Contiguous; for an empty ring this is a zero-length slice.
            return self._data[self._start:stop]
        # Wrapped: rows up to the end of storage first, then the front rows.
        return np.vstack((self._data[self._start:], self._data[:stop - self.capacity]))

    def drain(self):
        """
        Return all rows as a jnp.array (oldest→newest) and reset buffer.
        Note: this converts/copies to device as needed.
        """
        # Copy first so the result can never alias the reusable storage.
        out_np = self._chron_view().copy()
        # reset
        self._start = 0
        self._size = 0
        # optional: zero out data to avoid holding old values
        # self._data[:] = 0
        return jnp.array(out_np)

    def __call__(self):
        """Shorthand for drain(): return the drained rows and reset."""
        # The result must be returned: drain() has already emptied the
        # ring, so dropping it here would lose the data for good.
        return self.drain()

    # Subscriptable / iterable (chronological)
    def __len__(self):
        """Number of rows currently buffered."""
        return self._size

    def __getitem__(self, idx):
        """Index or slice the buffered rows in oldest→newest order."""
        return self._chron_view()[idx]

    def __iter__(self):
        """Iterate over the buffered rows in oldest→newest order."""
        return iter(self._chron_view())

    def __repr__(self):
        return f"NpToJaxDrainRing(size={self._size}, cap={self.capacity}, data=\n{self._chron_view()}\n)"


In [68]:
# Ring with the constructor defaults: capacity=200 rows, cols=4, float64 storage.
data1 = NpToJaxDrainRing()

In [69]:
# Draining the still-empty ring yields an empty JAX array (printed as []).
data_processor(data1.drain())

[]
<class 'jaxlib._jax.ArrayImpl'>


In [70]:
# Append three rows; each must hold exactly `cols` (here 4) values.
data1.add([1,2,3,4])
data1.add([442,2432,42,4])
data1.add([9,2,42,4])


In [73]:
# Inspect the ring without draining it: repr, len() and indexing only read.
# The repeated prints give identical output and len() stays 3 throughout.
print(f"{data1=}")
print(f"{data1=}")
print(f"{len(data1)=}")
print(f"{data1[0]=}")  # oldest row, returned as a NumPy array
print(f"{len(data1)=}")


data1=NpToJaxDrainRing(size=3, cap=200, data=
[[1.000e+00 2.000e+00 3.000e+00 4.000e+00]
 [4.420e+02 2.432e+03 4.200e+01 4.000e+00]
 [9.000e+00 2.000e+00 4.200e+01 4.000e+00]]
)
data1=NpToJaxDrainRing(size=3, cap=200, data=
[[1.000e+00 2.000e+00 3.000e+00 4.000e+00]
 [4.420e+02 2.432e+03 4.200e+01 4.000e+00]
 [9.000e+00 2.000e+00 4.200e+01 4.000e+00]]
)
len(data1)=3
data1[0]=array([1., 2., 3., 4.])
len(data1)=3


In [62]:
# drain() hands back the buffered rows oldest→newest as a JAX array and
# resets the ring to empty.
data_processor(data1.drain())

[[1.000e+00 2.000e+00 3.000e+00 4.000e+00]
 [4.420e+02 2.432e+03 4.200e+01 4.000e+00]
 [9.000e+00 2.000e+00 4.200e+01 4.000e+00]]
<class 'jaxlib._jax.ArrayImpl'>


In [63]:
# The ring is empty at this point, so the slice is a (0, 4) NumPy array.
data1[0:3,:]

array([], shape=(0, 4), dtype=float64)