In [1]:
from dataclasses import dataclass
from functools import lru_cache
from typing import Tuple, Union
import numpy as np
from fvhoe.array_management import VARIABLE_IDX_MAP

In [2]:
@dataclass
class HydroState:
    """
    Slice builder for arrays whose axis 0 indexes hydrodynamic variables and
    passive scalars and whose axes 1, 2, 3 are x, y, z.

    Names accepted by __call__ and the axis-0 index each maps to:
        rho: 0 (density)
        vx/mx, vy/my, vz/mz: 1, 2, 3 (velocity or momentum components)
        v/m: [1, 2, 3] (all three components)
        P/E: 4 (pressure or total energy)
        <passive scalar names>: 5, 6, ... in the order they were given
        passive_scalars: [5, 6, ...] (only defined when passive scalars are given)
    args:
        passive_scalars (tuple) : passive scalar names. Duplicates are dropped and the
            first occurrence of each name fixes its index. A single str is treated
            as one name.
        ndim (int) : number of array dimensions, including the variable axis 0
    """

    passive_scalars: tuple = ()
    ndim: int = 4

    def __post_init__(self):
        self.variable_map = {
            "rho": 0,
            "vx": 1,
            "mx": 1,
            "vy": 2,
            "my": 2,
            "vz": 3,
            "mz": 3,
            "v": np.arange(1, 4),
            "m": np.arange(1, 4),
            "P": 4,
            "E": 4,
        }
        names = self.passive_scalars
        if isinstance(names, str):
            # without this, a single name would be split into its characters
            names = (names,)
        # Deduplicate while keeping the caller's order. Iterating a set would assign
        # indices in hash order, which for str names changes between sessions.
        self.passive_scalars = tuple(dict.fromkeys(names))
        for i, scalar in enumerate(self.passive_scalars, start=5):
            self.variable_map[scalar] = i
        self.includes_passives = bool(self.passive_scalars)
        if self.includes_passives:
            self.variable_map["passive_scalars"] = np.arange(
                5, 5 + len(self.passive_scalars)
            )

    def __hash__(self):
        # identity hash, so instances are usable as dict / cache keys even though
        # the dataclass-generated __eq__ compares fields
        return id(self)

    def __call__(
        self,
        var: Union[str, Tuple[str, ...]] = None,
        x: Tuple[int, int] = None,
        y: Tuple[int, int] = None,
        z: Tuple[int, int] = None,
        axis: int = None,
        cut: Tuple[int, int] = None,
        step: int = None,
    ) -> Union[Tuple[slice], slice]:
        """
        Build an index for the given variable(s) and coordinate ranges.
        args:
            var (str or tuple of str) : variable name or tuple of names. Group names
                ("v", "m", "passive_scalars") may appear inside a tuple and are
                expanded in place. If None, all variables are selected.
            x, y, z (Tuple[int, int]) : (start, stop) along axis 1, 2, 3. A start or
                stop of 0 is treated as open (None), so (0, 0) selects the whole axis.
                If None, the whole axis is selected.
            axis (int) : axis to cut, alternative to x, y, z
            cut (Tuple[int, int]) : (start, stop) along `axis`; requires `axis`
            step (int) : step for the slice along `axis`. Ignored if axis is None.
        returns:
            tuple of length ndim: an int, index array or slice on axis 0 and slices
            on the remaining axes. If ndim is 1, the axis-0 entry alone is returned.
        raises:
            ValueError : unknown variable, bad `var` type, axis out of range,
                malformed (start, stop), or `cut` given without `axis`
        """
        slices = [slice(None)] * self.ndim

        if var is not None:
            if isinstance(var, str):
                # retrieve single variable index
                if var not in self.variable_map:
                    raise ValueError(f"Variable '{var}' not found.")
                slices[0] = self.variable_map[var]
            elif isinstance(var, tuple):
                # retrieve multiple variable indices
                missing_vars = set(var) - set(self.variable_map.keys())
                if missing_vars:
                    raise ValueError(f"Variables not found: {missing_vars}")
                # np.atleast_1d flattens group entries (arrays) alongside scalar ones,
                # so ("rho", "v") -> [0, 1, 2, 3] instead of a ragged array
                slices[0] = np.array(
                    [i for name in var for i in np.atleast_1d(self.variable_map[name])],
                    dtype=int,
                )
            else:
                raise ValueError(f"Invalid type for var: {type(var)}")

        if cut is not None and axis is None:
            # otherwise the range check below compares None with an int (TypeError)
            raise ValueError("'cut' requires 'axis' to be specified.")

        for i, axis_slice in zip((1, 2, 3, axis), (x, y, z, cut)):
            if axis_slice is None:
                continue
            if i >= self.ndim:
                raise ValueError(
                    f"Invalid axis {i} for array with {self.ndim} dimensions."
                )
            if not isinstance(axis_slice, tuple):
                raise ValueError(
                    f"Expected a tuple (start, stop) for axis {i}, got {axis_slice} of type {type(axis_slice)}"
                )
            if len(axis_slice) != 2:
                raise ValueError(
                    f"Invalid tuple length for axis {i}: {len(axis_slice)}"
                )
            slices[i] = slice(
                axis_slice[0] or None,
                axis_slice[1] or None,
                step if i == axis else None,
            )

        if len(slices) == 1:
            return slices[0]
        return tuple(slices)

In [3]:
@dataclass
class CachedHydroState:
    """
    Memoized variant of HydroState: the same slice builder, with __call__ results
    cached by functools.lru_cache.

    Names accepted by __call__ and the axis-0 index each maps to:
        rho: 0 (density)
        vx/mx, vy/my, vz/mz: 1, 2, 3 (velocity or momentum components)
        v/m: [1, 2, 3] (all three components)
        P/E: 4 (pressure or total energy)
        <passive scalar names>: 5, 6, ... in the order they were given
        passive_scalars: [5, 6, ...] (only defined when passive scalars are given)
    args:
        passive_scalars (tuple) : passive scalar names. Duplicates are dropped and the
            first occurrence of each name fixes its index. A single str is treated
            as one name.
        ndim (int) : number of array dimensions, including the variable axis 0
    """

    passive_scalars: tuple = ()
    ndim: int = 4

    def __post_init__(self):
        self.variable_map = {
            "rho": 0,
            "vx": 1,
            "mx": 1,
            "vy": 2,
            "my": 2,
            "vz": 3,
            "mz": 3,
            "v": np.arange(1, 4),
            "m": np.arange(1, 4),
            "P": 4,
            "E": 4,
        }
        names = self.passive_scalars
        if isinstance(names, str):
            # without this, a single name would be split into its characters
            names = (names,)
        # Deduplicate while keeping the caller's order. Iterating a set would assign
        # indices in hash order, which for str names changes between sessions.
        self.passive_scalars = tuple(dict.fromkeys(names))
        for i, scalar in enumerate(self.passive_scalars, start=5):
            self.variable_map[scalar] = i
        self.includes_passives = bool(self.passive_scalars)
        if self.includes_passives:
            self.variable_map["passive_scalars"] = np.arange(
                5, 5 + len(self.passive_scalars)
            )

    def __hash__(self):
        # identity hash: lru_cache on __call__ uses `self` as part of its key
        return id(self)

    # NOTE(review): lru_cache on a method is one class-wide cache whose keys hold
    # strong references to `self`, so instances are never garbage-collected, and the
    # same result object (including any index array inside it) is returned to every
    # caller -- do not mutate returned values. All arguments must be hashable.
    @lru_cache(maxsize=None)
    def __call__(
        self,
        var: Union[str, Tuple[str, ...]] = None,
        x: Tuple[int, int] = None,
        y: Tuple[int, int] = None,
        z: Tuple[int, int] = None,
        axis: int = None,
        cut: Tuple[int, int] = None,
        step: int = None,
    ) -> Union[Tuple[slice], slice]:
        """
        Build (or fetch from cache) an index for the given variable(s) and ranges.
        args:
            var (str or tuple of str) : variable name or tuple of names. Group names
                ("v", "m", "passive_scalars") may appear inside a tuple and are
                expanded in place. If None, all variables are selected.
            x, y, z (Tuple[int, int]) : (start, stop) along axis 1, 2, 3. A start or
                stop of 0 is treated as open (None), so (0, 0) selects the whole axis.
                If None, the whole axis is selected.
            axis (int) : axis to cut, alternative to x, y, z
            cut (Tuple[int, int]) : (start, stop) along `axis`; requires `axis`
            step (int) : step for the slice along `axis`. Ignored if axis is None.
        returns:
            tuple of length ndim: an int, index array or slice on axis 0 and slices
            on the remaining axes. If ndim is 1, the axis-0 entry alone is returned.
        raises:
            ValueError : unknown variable, bad `var` type, axis out of range,
                malformed (start, stop), or `cut` given without `axis`
        """
        slices = [slice(None)] * self.ndim

        if var is not None:
            if isinstance(var, str):
                # retrieve single variable index
                if var not in self.variable_map:
                    raise ValueError(f"Variable '{var}' not found.")
                slices[0] = self.variable_map[var]
            elif isinstance(var, tuple):
                # retrieve multiple variable indices
                missing_vars = set(var) - set(self.variable_map.keys())
                if missing_vars:
                    raise ValueError(f"Variables not found: {missing_vars}")
                # np.atleast_1d flattens group entries (arrays) alongside scalar ones,
                # so ("rho", "v") -> [0, 1, 2, 3] instead of a ragged array
                slices[0] = np.array(
                    [i for name in var for i in np.atleast_1d(self.variable_map[name])],
                    dtype=int,
                )
            else:
                raise ValueError(f"Invalid type for var: {type(var)}")

        if cut is not None and axis is None:
            # otherwise the range check below compares None with an int (TypeError)
            raise ValueError("'cut' requires 'axis' to be specified.")

        for i, axis_slice in zip((1, 2, 3, axis), (x, y, z, cut)):
            if axis_slice is None:
                continue
            if i >= self.ndim:
                raise ValueError(
                    f"Invalid axis {i} for array with {self.ndim} dimensions."
                )
            if not isinstance(axis_slice, tuple):
                raise ValueError(
                    f"Expected a tuple (start, stop) for axis {i}, got {axis_slice} of type {type(axis_slice)}"
                )
            if len(axis_slice) != 2:
                raise ValueError(
                    f"Invalid tuple length for axis {i}: {len(axis_slice)}"
                )
            slices[i] = slice(
                axis_slice[0] or None,
                axis_slice[1] or None,
                step if i == axis else None,
            )

        if len(slices) == 1:
            return slices[0]
        return tuple(slices)

In [4]:
@lru_cache
def get_slice(
    ndim=4,
    var: Union[str, Tuple[str, ...]] = None,
    x: Tuple[int, int] = None,
    y: Tuple[int, int] = None,
    z: Tuple[int, int] = None,
    axis: int = None,
    cut: Tuple[int, int] = None,
    step: int = None,
) -> Union[Tuple[slice], slice]:
    """
    Build an index for the given variable(s) and coordinate ranges, resolving
    variable names through VARIABLE_IDX_MAP (imported from fvhoe.array_management).

    Results are memoized with functools.lru_cache (default maxsize), so every
    argument must be hashable and the returned object is shared between callers;
    do not mutate it.
    args:
        ndim (int) : number of array dimensions, including the variable axis 0
        var (str or tuple of str) : variable name or tuple of names. Entries of a
            tuple that map to index arrays are flattened in place. If None, all
            variables are selected.
        x, y, z (Tuple[int, int]) : (start, stop) along axis 1, 2, 3. A start or stop
            of 0 is treated as open (None), so (0, 0) selects the whole axis.
            If None, the whole axis is selected.
        axis (int) : axis to cut, alternative to x, y, z
        cut (Tuple[int, int]) : (start, stop) along `axis`; requires `axis`
        step (int) : step for the slice along `axis`. Ignored if axis is None.
    returns:
        tuple of length ndim: the axis-0 entry from VARIABLE_IDX_MAP (or an index
        array for a tuple `var`) followed by slices. If ndim is 1, the axis-0 entry
        alone is returned.
    raises:
        ValueError : unknown variable, bad `var` type, axis out of range,
            malformed (start, stop), or `cut` given without `axis`
    """
    slices = [slice(None)] * ndim

    if var is not None:
        if isinstance(var, str):
            # retrieve single variable index
            if var not in VARIABLE_IDX_MAP:
                raise ValueError(f"Variable '{var}' not found.")
            slices[0] = VARIABLE_IDX_MAP[var]
        elif isinstance(var, tuple):
            # retrieve multiple variable indices
            missing_vars = set(var) - set(VARIABLE_IDX_MAP.keys())
            if missing_vars:
                raise ValueError(f"Variables not found: {missing_vars}")
            # np.atleast_1d is a no-op for scalar indices and flattens any array-valued
            # (group) entries, avoiding a ragged array
            slices[0] = np.array(
                [i for name in var for i in np.atleast_1d(VARIABLE_IDX_MAP[name])],
                dtype=int,
            )
        else:
            raise ValueError(f"Invalid type for var: {type(var)}")

    if cut is not None and axis is None:
        # otherwise the range check below compares None with an int (TypeError)
        raise ValueError("'cut' requires 'axis' to be specified.")

    for i, axis_slice in zip((1, 2, 3, axis), (x, y, z, cut)):
        if axis_slice is None:
            continue
        if i >= ndim:
            raise ValueError(f"Invalid axis {i} for array with {ndim} dimensions.")
        if not isinstance(axis_slice, tuple):
            raise ValueError(
                f"Expected a tuple (start, stop) for axis {i}, got {axis_slice} of type {type(axis_slice)}"
            )
        if len(axis_slice) != 2:
            raise ValueError(
                f"Invalid tuple length for axis {i}: {len(axis_slice)}"
            )
        slices[i] = slice(
            axis_slice[0] or None,
            axis_slice[1] or None,
            step if i == axis else None,
        )

    if len(slices) == 1:
        return slices[0]
    return tuple(slices)

In [5]:
# Build both helpers from one tuple of passive scalar names so their outputs are comparable.
passive_scalar_names = ("scalar1", "scalar2")
hs = HydroState(passive_scalars=passive_scalar_names)
chs = CachedHydroState(passive_scalars=passive_scalar_names)

In [6]:
# Check that the cached and uncached helpers agree before benchmarking them.
# Calling chs("rho") here also populates CachedHydroState's lru_cache for this key.
assert chs("rho") == hs("rho"), "CachedHydroState and HydroState disagree"
hs("rho")

(0, slice(None, None, None), slice(None, None, None), slice(None, None, None))

In [7]:
%%timeit
# get_slice is a module-level function wrapped in functools.lru_cache, so after the
# first loop this times a cache hit keyed on the keyword argument var="rho".
get_slice(var="rho")

138 ns ± 6.96 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)


In [8]:
%%timeit
# HydroState.__call__ is not cached: every call rebuilds the slice list from variable_map.
hs("rho")

836 ns ± 57.1 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [9]:
%%timeit
# CachedHydroState.__call__ is wrapped in lru_cache, so after the first call this
# times the cache-hit path for the key (chs, "rho").
chs("rho")

185 ns ± 6.44 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
