In [3]:
from __future__ import annotations
import numpy as np
from typing import Tuple, List, Optional
import random
import math
import sys

from lib.helpers import prod, flatten, recursive_checkshape, recursive_checktype

In [4]:
class tarray:
    def __init__(self, shape, dtype='fp32', buffer=None, offset=0, strides=None, order='C'):
        self.dtype = dtype
        self.shape = shape
        self.size = prod(shape) # number of elements
        self.ndim = len(shape)
        self.itemsize = self._get_itemsize(dtype) # bytes per element

        if buffer and len(buffer) != self.size: raise Exception("Data size does not match shape")
        self.data = buffer if buffer else self._allocate_buffer(self.size, self.itemsize)
        self.base = None if not buffer else buffer

        self.nbytes = self.itemsize * self.size # total bytes
        self.offset = offset # offset in bytes
        self.order = order # order of the array (C or F)
        self.strides = strides if strides else self._cstrides(shape, self.itemsize) # strides in bytes

        self.isContiguous = self._check_contiguous(buffer, shape, strides, self.itemsize) 
        self._validate()

    # ! Methods called by the constructor ======================================
    def _allocate_buffer(self, size, itemsize):
        """Allocates a buffer of the given shape and dtype"""
        return bytearray(size * itemsize)
    
    def _get_itemsize(self, dtype):
        map = {"fp32": 4}
        if dtype in map: return map[dtype]
        raise NotImplementedError(f"Type {dtype} is not supported yet")

    def _check_contiguous(self, buffer, shape, strides, itemsize):
        """Checks if the array is contiguous"""
        if not buffer or not strides or len(shape)==1: return True # if there's no buffer or strides, it's contiguous, or if it's 1D
        # TODO will need to change this method to _cstridesbytes if we're storing bytes instead of elements
        return self._cstrides(shape, itemsize) == strides # return if the strides are the same as the expected strides

    def _validate(self):
        """Validate the arguments passed to the constructor"""
        if self.order != 'C': raise NotImplementedError("Only C order is supported right now") # check ourder
        if len(self.shape) != len(self.strides): raise ValueError("Shape and strides must have same length") # check shape and strides
        if not isinstance(self.offset, int): raise TypeError("Offset must be an integer") # check offset
        if not isinstance(self.shape, tuple): raise TypeError("Shape must be a tuple") # check shape
        if not isinstance(self.strides, tuple): raise TypeError("Strides must be a tuple") # check strides
        if not isinstance(self.data, bytearray): raise TypeError("Data must be a bytearray")
        if not isinstance(self.order, str): raise TypeError("Order must be a string")
        # TODO check valid dtype goes here
        # TODO check strides match the buffer (in bounds)

    # ! Internal methods =======================================================
    def _cstrides_readable(self, shape: Tuple[int, ...], itemsize: int):
        """
        Returns a tuple of strides for a C-ordered array (in bytes)
        $$n^X = \sum^{N-1}_{i=1}n_i s^x_i$$ 
        """
        strides = [0 for _ in shape] # make an empty strides list of the same length as shape
        for i in range(len(shape)): # for each dim in the shape (0 to N-1) for len(shape) = N
            stride = itemsize # set the stride to the itemsize * 8 (bits)
            for j in range(i+1, len(shape)): # accumulate the rest of the dims up to the N-1 dim
                stride *= shape[j] # multiply the stride by the dim
            strides[i] = stride # set the stride in the list
        return tuple(strides) # return as a tuple
    
    def _cstrides(self, shape: Tuple[int, ...], itemsize: int):
        """
        Returns a tuple of strides for a C-ordered array (in bytes)
        $$n^X = \sum^{N-1}_{i=1}n_i s^x_i$$ 
        """
        return tuple([itemsize*prod(shape[i+1:]) for i in range(len(shape))])
    
    # ! Utility methods ========================================================
    def _validate_shape(self, shape: Tuple[int, ...]):
        """Validates the shape"""
        if not isinstance(shape, tuple): raise TypeError("Shape must be a tuple")
        if not all(isinstance(dim, int) for dim in shape): raise TypeError("All dimensions must be integers")
        if not all(dim > 0 for dim in shape): raise ValueError("All dimensions must be greater than 0") # TODO add 0 support
        # TODO add -1 support
    
    # ! Public methods =========================================================
    def view(self):
        """Returns a view of the array (new tarray with same data)"""
        return tarray(self.shape, dtype=self.dtype, buffer=self.data, offset=self.offset, strides=self.strides, order=self.order)
    
    def reshape(self, shape: Tuple[int, ...]):
        """Returns a view of the array with the given shape"""
        self._validate_shape(shape)
        if prod(shape) != self.size: raise ValueError("Cannot reshape to different size") 
        
        # calculate the new strides
        # case 1: simple reshape with no permutation (3, 4) -> (2, 6)
        if len(shape) == len(self.shape):
            new_strides = self.strides
        # case 2: adding or removing a dimension
        # this includes adding singleton dimensions or removing them (3, 4) -> (3, 1, 4) or (3, 1, 4) -> (3, 4)
            
        # case 3: permutation (3, 4) -> (4, 3) or (3, 4, 5) -> (4, 3, 5)
        
        # case 4: reshaping with -1 (12,) -> (-1, 3) which infers 4
            
        # case 5: permutation and singleton (3, 4) -> (4, 1, 3)


        return tarray(shape, dtype=self.dtype, buffer=self.data, offset=self.offset, strides=new_strides, order=self.order)

    def __getitem__(self, indices: Tuple[int,...]):
        raise NotImplementedError("getitem is not implemented yet")

    def __repr__(self):
        return f"tarray(shape={self.shape}, dtype={self.dtype}, strides={self.strides}, order={self.order})"


def array(obj, dtype, order='C'):
    """Returns a tarray from an object"""
    if isinstance(obj, tarray): return obj # 1) if it's already a tarray, return it
    shape = tuple(recursive_checkshape(obj)) # 2) get the shape
    data = flatten(obj) # 3) flatten the data
    # TODO handle dtype
    # if not recursive_checktype(data, float): raise TypeError(f"Cannot convert {type(obj)} to {dtype}") # 4) make sure all elements are of the same type
    return tarray(shape, dtype=dtype, buffer=data, order=order) # 5) return the tarray

In [5]:
np_arr = np.random.randint(0, 10, (3, 3, 2))
arr = array(np_arr.astype(float))

In [6]:
print(tuple(8*i for i in arr.strides))
print(np_arr.strides)

(48, 16, 8)
(48, 16, 8)


In [7]:
arr.shape

(3, 3, 2)

In [9]:
arr[(0,0,0)]

NotImplementedError: getitem is not implemented yet

In [None]:
np_arr_reshape = np_arr.reshape((9, 2))
arr_reshape = arr.reshape((9, 2))

In [None]:
print(tuple(8*i for i in arr_reshape.strides))
print(np_arr_reshape.strides)

(16, 8)
(16, 8)


In [None]:
arr.data == arr_reshape.data

True

In [None]:
arr_reshape

tarray(shape=(9, 2), dtype=<class 'float'>, strides=(2, 1), order=C)

In [None]:
arr = array([1.,2.,3.,4.,5.,6.])

(8,)