In [2]:
from typing import List, Dict, Tuple

from lib.helpers import prod

In [4]:
class tarray:
    """A minimal strided n-dimensional array backed by a ``bytearray``.

    Only C (row-major) order and the ``"fp32"`` dtype are supported so far;
    indexing, reshaping and transposing are not implemented yet.
    """

    # Bytes per element for every supported dtype string.
    _ITEMSIZES = {"fp32": 4}

    def __init__(self, shape, dtype, buffer=None, offset=0, strides=None, order='C'):
        """
        Initialize an array.

        Parameters:
        shape (tuple of ints): The shape of the array; every dimension must be >= 0.
        dtype (str): The data type of the array elements (currently only "fp32").
        buffer (bytearray, optional): Existing storage to wrap instead of allocating.
            Its length must equal ``size * itemsize`` bytes; it is used without copying.
        offset (int, optional): Offset of array data in buffer, in bytes
            (currently only stored and type-checked).
        strides (tuple of ints, optional): Strides of data in memory, in bytes.
            Defaults to C-order strides for ``shape``.
        order ({'C', 'F'}, optional): Row-major (C-style) or column-major
            (Fortran-style) order. Only 'C' is supported right now.

        Raises:
        TypeError: if shape, offset, strides, buffer or order has the wrong type.
        ValueError: if a dimension is negative, the buffer length does not match
            the shape, or shape and strides differ in length.
        NotImplementedError: if dtype or order is not supported.
        """
        # Check the shape before anything derives sizes from it.
        self._validate_shape(shape)
        self._shape = shape
        self.dtype = dtype

        self.itemsize = self._get_itemsize(dtype)  # bytes per element
        self.nbytes = self.itemsize * self.size  # total bytes

        # The buffer is a bytearray, so its length is measured in bytes, not elements.
        # `is not None` ensures an empty bytearray is checked rather than silently
        # replaced by a fresh allocation.
        if buffer is not None and len(buffer) != self.nbytes:
            raise ValueError("Data size does not match shape")
        self.data = buffer if buffer is not None else self._allocate_buffer(self.size, self.itemsize)
        self.base = buffer  # the wrapped buffer, or None when this array allocated its own

        self.offset = offset  # offset in bytes
        self.order = order  # order of the array (C or F)
        self.strides = strides if strides is not None else self._cstrides(shape, self.itemsize)  # strides in bytes

        # Validate before the contiguity check so it only ever sees well-typed strides.
        self._validate()
        self.isContiguous = self._check_contiguous(buffer, shape, self.strides, self.itemsize)

    # ! Methods called by the constructor ======================================
    def _allocate_buffer(self, size: int, itemsize: int) -> bytearray:
        """Return a zero-filled bytearray large enough for ``size`` elements of ``itemsize`` bytes."""
        return bytearray(size * itemsize)

    def _get_itemsize(self, dtype) -> int:
        """
        Return the size in bytes of one element of ``dtype``.

        Raises:
        NotImplementedError: if ``dtype`` is not a key of ``_ITEMSIZES``.
        """
        if dtype in self._ITEMSIZES:
            return self._ITEMSIZES[dtype]
        raise NotImplementedError(f"Type {dtype} is not supported yet")

    def _check_contiguous(self, buffer, shape, strides, itemsize) -> bool:
        """
        Return True if ``strides`` describe a C-contiguous layout of ``shape``.

        Strides of length-1 dimensions are ignored because they never address
        a second element, and zero-size arrays are always contiguous. This
        follows the rule described for NumPy's ``ndarray.flags``. ``buffer``
        is kept for signature compatibility: contiguity depends only on the
        layout, not on who owns the memory.
        """
        if strides is None or 0 in shape:
            return True
        expected = self._cstrides(shape, itemsize)
        return all(dim == 1 or got == want for dim, got, want in zip(shape, strides, expected))

    def _validate(self):
        """
        Validate the constructor arguments not already checked by ``_validate_shape``.

        Raises:
        TypeError: if order, offset, shape, strides or data has the wrong type.
        NotImplementedError: if order is not 'C'.
        ValueError: if shape and strides differ in length.
        """
        # Type checks come first, so the comparisons after them cannot trip over a bad type.
        if not isinstance(self.order, str):
            raise TypeError("Order must be a string")
        if self.order != 'C':
            raise NotImplementedError("Only C order is supported right now")
        if not isinstance(self.offset, int):
            raise TypeError("Offset must be an integer")
        if not isinstance(self.shape, tuple):
            raise TypeError("Shape must be a tuple")
        if not isinstance(self.strides, tuple):
            raise TypeError("Strides must be a tuple")
        if len(self.shape) != len(self.strides):
            raise ValueError("Shape and strides must have same length")
        if not isinstance(self.data, bytearray):
            raise TypeError("Data must be a bytearray")
        # dtype is already validated by _get_itemsize, which raises for unsupported types.
        # TODO check strides/offset keep every element inside the buffer (in bounds)

    # ! Internal methods =======================================================
    def _cstrides(self, shape: Tuple[int, ...], itemsize: int) -> Tuple[int, ...]:
        r"""
        Return the strides (in bytes) of a C-ordered (row-major) array.

        The byte offset of element $(n_0, \dots, n_{N-1})$ is
        $$n^X = \sum_{i=0}^{N-1} n_i s^X_i$$
        where $s^X_i$ is the stride of dimension $i$. For C order,
        $$s^C_i = \text{itemsize} \cdot \prod_{j=i+1}^{N-1} d_j = \text{itemsize} \cdot d_{i+1} d_{i+2} \cdots d_{N-1},$$
        so the last axis has stride ``itemsize`` (an empty product; this
        assumes ``prod`` of an empty sequence is 1).
        """
        return tuple(itemsize * prod(shape[i + 1:]) for i in range(len(shape)))

    def _fstrides(self, shape: Tuple[int, ...], itemsize: int) -> Tuple[int, ...]:
        r"""
        Return the strides (in bytes) of an F-ordered (column-major) array.

        With the byte offset $n^X = \sum_{i=0}^{N-1} n_i s^X_i$, Fortran order uses
        $$s^F_i = \text{itemsize} \cdot \prod_{j=0}^{i-1} d_j = \text{itemsize} \cdot d_0 d_1 \cdots d_{i-1},$$
        so the first axis has stride ``itemsize``. ``shape[:i]`` is exactly
        $d_0, \dots, d_{i-1}$ because slice ends are exclusive.
        """
        return tuple(itemsize * prod(shape[:i]) for i in range(len(shape)))

    # ! Utility methods ========================================================
    def _validate_shape(self, shape: Tuple[int, ...]) -> None:
        """
        Check that ``shape`` is a tuple of non-negative ints.

        Zero-length dimensions are allowed; they produce a zero-size array
        backed by an empty buffer.

        Raises:
        TypeError: if ``shape`` is not a tuple or a dimension is not an int.
        ValueError: if any dimension is negative.
        """
        if not isinstance(shape, tuple):
            raise TypeError("Shape must be a tuple")
        if not all(isinstance(dim, int) for dim in shape):
            raise TypeError("All dimensions must be integers")
        if any(dim < 0 for dim in shape):
            raise ValueError("All dimensions must be non-negative")
        # TODO add -1 (inferred dimension) support

    # ! Public methods =========================================================
    def __getitem__(self, key):
        """
        Get an item or slice from the array (not implemented yet).

        Raises:
        NotImplementedError: always. Raising instead of implicitly returning
            None also keeps ``iter()``/``list()``/``in`` on a tarray from looping
            forever through the legacy ``__getitem__`` iteration protocol.
        """
        raise NotImplementedError("tarray.__getitem__ is not implemented yet")

    def __setitem__(self, key, value):
        """
        Set an item or slice in the array (not implemented yet).

        Raises:
        NotImplementedError: always, so writes are not silently dropped.
        """
        raise NotImplementedError("tarray.__setitem__ is not implemented yet")

    def reshape(self, new_shape):
        """
        Return an array containing the same data with a new shape (not implemented yet).

        Parameters:
        new_shape (tuple of ints): The new shape of the array.

        Raises:
        NotImplementedError: always.
        """
        raise NotImplementedError("tarray.reshape is not implemented yet")

    def transpose(self, *axes):
        """
        Return a view of the array with axes transposed (not implemented yet).

        Parameters:
        axes (sequence of ints, optional): By default, reverse the dimensions.

        Raises:
        NotImplementedError: always.
        """
        raise NotImplementedError("tarray.transpose is not implemented yet")

    @property
    def shape(self):
        """
        Tuple of array dimensions.
        """
        return self._shape

    @property
    def size(self):
        """
        Number of elements in the array (the product of ``shape``).
        """
        return prod(self.shape)

    @property
    def ndim(self):
        """
        Number of array dimensions.
        """
        return len(self.shape)

    def __str__(self) -> str:
        """
        Informal string representation of the array.
        """
        return f"tarray(shape={self.shape}, dtype={self.dtype}, strides={self.strides}, order={self.order})"

    def __repr__(self) -> str:
        """
        String representation of the array (same as ``__str__``).
        """
        return self.__str__()