# Unit-aware Computation with ``CustomArray``

[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/saiunit/blob/master/docs/advanced_tutorials/custom_array.ipynb)
[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/saiunit/blob/master/docs/advanced_tutorials/custom_array.ipynb)

## Introduction

The ``CustomArray`` class in saiunit is a base class for building your own array types that carry physical units and stay dimensionally consistent through arithmetic and ``saiunit.math`` calls. This tutorial shows how to subclass ``CustomArray`` so that unit handling happens automatically, which makes scientific code safer and easier to maintain.

### What Is Unit-aware Computation?

Unit-aware computation keeps physical quantities dimensionally correct across operations. Typical rules:
- Adding meters to meters results in meters
- Multiplying meters by meters results in square meters
- Dividing distance by time results in velocity
- Invalid operations (e.g., meters + seconds) are detected and raise errors
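To make these rules concrete, here is a minimal pure-Python sketch of dimension bookkeeping. ``ToyQuantity`` is a hypothetical teaching class, not part of saiunit. It tracks only the exponents of base dimensions and ignores scale factors such as cm vs m, which is enough to show why addition checks units while multiplication and division combine them.

```python
class ToyQuantity:
    """Hypothetical teaching model (not saiunit's implementation):
    a number plus a mapping from base dimension to exponent."""

    def __init__(self, value, dims):
        self.value = value
        # Drop zero exponents so that {"m": 0} and {} compare equal
        self.dims = {k: v for k, v in dims.items() if v != 0}

    def __add__(self, other):
        # Addition is only defined between identical dimensions
        if self.dims != other.dims:
            raise ValueError(f"units do not match: {self.dims} != {other.dims}")
        return ToyQuantity(self.value + other.value, self.dims)

    def _combine(self, other, sign):
        dims = dict(self.dims)
        for name, exp in other.dims.items():
            dims[name] = dims.get(name, 0) + sign * exp
        return dims

    def __mul__(self, other):
        # Multiplication adds exponents: m * m -> m^2
        return ToyQuantity(self.value * other.value, self._combine(other, +1))

    def __truediv__(self, other):
        # Division subtracts exponents: m / s -> m * s^-1
        return ToyQuantity(self.value / other.value, self._combine(other, -1))

    def __repr__(self):
        return f"ToyQuantity({self.value}, {self.dims})"


meters = ToyQuantity(3.0, {"m": 1})
seconds = ToyQuantity(2.0, {"s": 1})
print(meters + meters)    # ToyQuantity(6.0, {'m': 1})
print(meters * meters)    # ToyQuantity(9.0, {'m': 2})
print(meters / seconds)   # ToyQuantity(1.5, {'m': 1, 's': -1})
try:
    meters + seconds
except ValueError as e:
    print("error:", e)
```

A real library such as saiunit also tracks scale (so that meters and centimeters can be added after conversion), but the exponent arithmetic above is the core of the rules listed.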

### Why Use CustomArray?

- Dimensional safety: unit mismatches are caught at runtime instead of silently producing wrong numbers
- Automatic unit propagation through operators and ``saiunit.math`` functions
- Works with NumPy and JAX arrays (and supports PyTorch-like methods)
- Extensible: create domain-specific array types (physics, neuroscience, etc.)
- Minimal overhead compared to raw arrays

```python
# Imports
import saiunit as u
import brainstate

print("saiunit version:", getattr(u, '__version__', 'unknown'))
print('Sample units:', 'm, s, Hz, V, A, kg, N, Pa, J')
```

```text
saiunit version: 0.1.0
Sample units: m, s, Hz, V, A, kg, N, Pa, J
```


## CustomArray Architecture

``CustomArray`` is a base class. Any subclass that stores its data in a ``.value`` attribute gains Python operator support (``+``, ``*``, ``/``, ...), array attributes such as ``.shape``, and compatibility with the unit-aware functions in ``saiunit.math``.

Core requirements:
1. Inherit from ``u.CustomArray``
2. Store your underlying data (with units) in ``self.value``

Benefits:
- Separation of concerns: you focus on data/state, ``CustomArray`` handles array ops
- Unit propagation: math operations keep correct units
- Backend flexibility: ``self.value`` can be NumPy, JAX, or other array-likes
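The delegation idea can be sketched in a few lines of plain Python. ``ToyCustomArray`` and ``Box`` below are hypothetical names, and this is not saiunit's implementation, which covers far more operators and array methods. The sketch only shows how a base class can forward operators to whatever a subclass stores in ``self.value``.

```python
import operator


class ToyCustomArray:
    """Hypothetical sketch of a `.value`-delegating base class."""

    @staticmethod
    def _unwrap(x):
        # Other wrappers contribute their underlying data, not themselves
        return x.value if isinstance(x, ToyCustomArray) else x

    def _forward(self, other, op, reflected=False):
        lhs, rhs = self.value, self._unwrap(other)
        return op(rhs, lhs) if reflected else op(lhs, rhs)

    def __add__(self, other):
        return self._forward(other, operator.add)

    def __radd__(self, other):
        return self._forward(other, operator.add, reflected=True)

    def __mul__(self, other):
        return self._forward(other, operator.mul)

    def __rmul__(self, other):
        return self._forward(other, operator.mul, reflected=True)

    def __truediv__(self, other):
        return self._forward(other, operator.truediv)


class Box(ToyCustomArray):
    def __init__(self, value):
        self.value = value  # the only requirement placed on subclasses


a, b = Box(6.0), Box(3.0)
print(a + b)   # 9.0
print(a / b)   # 2.0
print(2 * a)   # 12.0 (handled by __rmul__)
```

In this sketch the results are the underlying type (here ``float``), not a new ``Box``. Because every operation is computed on ``self.value``, storing a ``Quantity`` there would make the unit rules of that ``Quantity`` apply automatically, which is the separation of concerns described above.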

```python
# A minimal, practical CustomArray
class MyArray(u.CustomArray):
    """Minimal unit-aware array: just store a `.value`."""
    def __init__(self, value):
        self.value = value  # typically a saiunit Quantity or plain array
    def __repr__(self):
        return f'MyArray({self.value})'

# Create an instance with units
length = MyArray([1, 2, 3] * u.meter)
length, length.shape, getattr(length.value, 'unit', 'unitless')
```

```text
(MyArray(ArrayImpl([1, 2, 3], dtype=int32) * meter), (3,), meter)
```

## Unit Propagation with Operators

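When ``.value`` is a ``Quantity``, Python operators on ``CustomArray`` instances propagate units: addition and subtraction require compatible units (converting between scales such as m and cm), while multiplication and division combine units.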

```python
# Compatible addition keeps units
length_cm = MyArray([100, 200, 300] * u.cmeter)
total_length = length + length_cm  # meters + centimeters -> meters
print('total_length:', total_length)
```

```text
total_length: ArrayImpl([2., 4., 6.], dtype=float32) * meter
```


```python
# Multiplication changes units (area)
area = length * length  # m * m -> m^2
print('area:', area)
```

```text
area: ArrayImpl([1, 4, 9], dtype=int32) * meter2
```


```python
# Division changes units (velocity)
time = MyArray([1, 2, 3] * u.second)
velocity = length / time  # m / s
print('velocity:', velocity)
```

```text
velocity: ArrayImpl([1., 1., 1.], dtype=float32) * meter / second
```


```python
# Incompatible addition raises an error
try:
    bad = length + time
except Exception as e:
    print('Expected error:', e)
```

```text
Expected error: Cannot calculate 
ArrayImpl([1, 2, 3], dtype=int32) * meter + ArrayImpl([1, 2, 3], dtype=int32) * second, because units do not match: m != s
```


## Using ``saiunit.math`` with CustomArray

The ``saiunit.math`` module mirrors the NumPy/JAX APIs and is unit-aware. Its functions accept ``CustomArray`` instances: saiunit unwraps ``.value`` internally and returns results with the correct units.

Categories (simplified):
- Keep-unit functions (e.g., ``mean``, ``sum``, ``concatenate``, ``stack``) return the same unit
- Change-unit functions (e.g., ``square``, ``sqrt``, ``multiply``, ``divide``, ``var``) transform units according to math rules
- Some functions require dimensionless inputs (e.g., ``exp``, ``log``, ``sin``)

```python
# Keep-unit examples
print('mean(length):', u.math.mean(length))
print('sum(length):', u.math.sum(length))

# Change-unit examples
print('square(length):', u.math.square(length))  # m^2
print('sqrt(square(length)):', u.math.sqrt(u.math.square(length)))  # dimensionally m (shown as meter2 ** 0.5)
print('var(length):', u.math.var(length))  # m^2

# Stacking arrays with compatible units (m and cm)
stacked = u.math.stack([length, length_cm])
print('stacked shape:', getattr(stacked, 'shape', None))

# Linear algebra with units
force = MyArray([10, 20, 30] * u.newton)
displacement = MyArray([0.5, 1.0, 1.5] * u.meter)
work = u.math.dot(force, displacement)  # N·m -> J (joule)
print('work (dot):', work)
```

```text
mean(length): 2. * meter
sum(length): 6 * meter
square(length): ArrayImpl([1, 4, 9], dtype=int32) * meter2
sqrt(square(length)): ArrayImpl([1., 2., 3.], dtype=float32) * meter2 ** 0.5
var(length): 0.6666667 * meter2
stacked shape: (2, 3)
work (dot): 70. * joule
```


## Converting Units for Display or Interop

Use ``Quantity.to_decimal(target_unit)`` to get the plain numeric values of a quantity expressed in a target unit, for display, logging, plotting, or passing to libraries that do not understand units.

```python
# Convert quantity values inside your CustomArray for display
meters = MyArray([1, 2, 3] * u.meter)
print('as meters:', meters.value.to_decimal(u.meter))
print('as centimeters:', meters.value.to_decimal(u.cmeter))
```

```text
as meters: [1 2 3]
as centimeters: [100. 200. 300.]
```


## Stateful and Learnable Arrays (BrainState)

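For stateful workflows, combine ``brainstate.State`` with ``CustomArray``. ``State`` already stores its data in ``.value``, which is exactly what ``CustomArray`` requires, so the subclass needs no body of its own. The result is a stateful, unit-aware array suitable for model parameters.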

```python
class StatefulArray(brainstate.State, u.CustomArray):
    pass

# Example: a parameter (here a time constant) with units
param = StatefulArray(0.1 * u.second)
print('stateful param:', param)
```

```text
stateful param: StatefulArray(
  value=~float32[] * second
)
```


## Robust Patterns and Error Handling

Tips:
- Document expected units for each array (e.g., meters for length)
- Validate inputs when building domain-specific types
- Catch and surface unit mismatch errors with clear messages
- Prefer ``saiunit.math`` over raw NumPy for unit-aware operations

## Summary

- Inherit from ``u.CustomArray`` and set ``self.value`` (often a ``Quantity``)
- Use operators and ``saiunit.math`` to get automatic unit propagation
- Convert units with ``Quantity.to_decimal`` for display or interop
- Combine with BrainState to build stateful, unit-aware components

With these patterns, you can build reliable, unit-safe computational workflows across NumPy and JAX backends.