In [3]:
from __future__ import annotations
import numpy as np
from typing import Tuple, List, Optional
import random
import math
import sys

from lib.helpers import prod, flatten, recursive_checkshape, recursive_checktype

In [1]:
class tarray:
    """Minimal NumPy-style n-dimensional array stored as a flat Python list.

    Elements (not raw bytes) are stored in ``data`` and ``itemsize`` is 1, so
    ``strides``, ``offset`` and ``nbytes`` are currently counted in elements.
    Only C (row-major) order is supported.
    """

    def __init__(self, shape, dtype=float, buffer=None, offset=0, strides=None, order='C'):
        """Create an array of ``shape`` over ``buffer`` or over a new zero-filled list.

        Args:
            shape: tuple of dimension lengths.
            dtype: element type tag; stored but not validated yet.
            buffer: existing flat sequence to wrap (not copied); its length must
                equal ``prod(shape)``. When omitted, a flat list of zeros is allocated.
            offset: starting offset into the buffer (must be an int).
            strides: tuple of per-dimension strides; C-order strides are derived
                from ``shape`` when omitted.
            order: memory layout; only ``'C'`` is accepted.

        Raises:
            ValueError: if ``buffer`` has the wrong length or shape/strides lengths differ.
            TypeError: if ``shape``/``strides`` are not tuples or ``offset`` is not an int.
            NotImplementedError: if ``order`` is not ``'C'``.
        """
        # fixed attributes -- set first, because the stride helpers read self.itemsize
        self.dtype = dtype
        self.itemsize: int = 1  # number of "bytes" in an item
        #? itemsize is one right now because we're storing elements not bytes. fp32 is 4 bytes, fp64 is 8 bytes, etc.
        self.offset = offset  # offset in bytes (== elements while itemsize is 1)
        self.order = order

        # shape-derived attributes; size must exist before nbytes is computed
        self.shape = shape
        self.size = math.prod(shape)  # size in elements
        self.nbytes = self.itemsize * self.size  # bytes per item * number of items
        self.strides = strides if strides else self._cstrides(shape)  # strides in bytes

        # boolean flags
        self.isContiguous = self._check_contiguous(buffer, shape, strides)
        # compare with None: truth-testing a NumPy buffer is ambiguous, and an empty list is still a buffer
        self.isView = buffer is not None

        # data
        if buffer is not None and len(buffer) != self.size:
            raise ValueError("Data size does not match shape")
        self.data = buffer if buffer is not None else [0] * self.size  # flat list, one slot per element

        # validate input arguments
        self._validate()

    # ! Methods called by the constructor ======================================
    def _check_contiguous(self, buffer, shape, strides):
        """Return True if ``strides`` describe a C-contiguous layout of ``shape``.

        Omitted strides mean the constructor derives C strides, so the array is
        contiguous. ``buffer`` is accepted for call compatibility but does not
        affect the answer: any buffer viewed through non-C strides is non-contiguous.
        """
        if not strides:
            return True
        # TODO will need to change this method to _cstridesbytes if we're storing bytes instead of elements
        return tuple(strides) == self._cstrides(shape)

    def _validate(self):
        """Validate the arguments passed to the constructor.

        Raises:
            NotImplementedError: if the order is not 'C'.
            ValueError: if shape and strides have different lengths.
            TypeError: if offset is not an int, or shape/strides are not tuples.
        """
        if self.order != 'C': raise NotImplementedError("Only C order is supported right now") # check order
        if len(self.shape) != len(self.strides): raise ValueError("Shape and strides must have same length") # check shape and strides
        if not isinstance(self.offset, int): raise TypeError("Offset must be an integer") # check offset
        if not isinstance(self.shape, tuple): raise TypeError("Shape must be a tuple") # check shape
        if not isinstance(self.strides, tuple): raise TypeError("Strides must be a tuple") # check strides
        # TODO check valid dtype goes here
        # TODO check strides match the buffer (in bounds)

    # ! Internal methods =======================================================
    def _cstrides_readable(self, shape: Tuple[int, ...]):
        r"""
        Returns a tuple of strides for a C-ordered array (in units of ``self.itemsize``).

        Loop-based spelling of the C-order stride formula
        $$s_k = \text{itemsize} \prod_{j=k+1}^{N-1} d_j$$
        where $d_j$ is the length of dimension $j$. Returns the same result as ``_cstrides``.
        """
        strides = [0 for _ in shape] # make an empty strides list of the same length as shape
        for i in range(len(shape)): # for each dim in the shape (0 to N-1) for len(shape) = N
            stride = self.itemsize # same unit as _cstrides (elements while itemsize is 1)
            for j in range(i+1, len(shape)): # accumulate the rest of the dims up to the N-1 dim
                stride *= shape[j] # multiply the stride by the dim
            strides[i] = stride # set the stride in the list
        return tuple(strides) # return as a tuple

    def _cstrides(self, shape):
        """Returns a tuple of strides for a C-ordered array (in elements)."""
        return tuple(self.itemsize * math.prod(shape[i + 1:]) for i in range(len(shape)))

    def _cstridesbytes(self, shape, bytes_per_item=8):
        """Returns a tuple of strides for a C-ordered array (in bytes).

        ``bytes_per_item`` defaults to 8 (a 64-bit float); pass 4 for 32-bit items.
        """
        return tuple(self.itemsize * bytes_per_item * math.prod(shape[i + 1:]) for i in range(len(shape)))

    # ! Public methods =========================================================
    def view(self):
        """Returns a view of the array (new tarray sharing the same data list)."""
        return tarray(self.shape, dtype=self.dtype, buffer=self.data, offset=self.offset, strides=self.strides, order=self.order)

    def reshape(self, shape: Tuple[int, ...]):
        """Returns a view of the array with the given shape.

        Args:
            shape: new shape as a tuple, list, or single int; must hold ``self.size`` elements.

        Raises:
            ValueError: if the new shape has a different number of elements.
            NotImplementedError: if the array is not contiguous (a view cannot be formed yet).
        """
        shape = (shape,) if isinstance(shape, int) else tuple(shape)
        if math.prod(shape) != self.size: raise ValueError("Cannot reshape to different size")
        if not self.isContiguous:
            # reinterpreting strided data with C strides would silently scramble elements
            raise NotImplementedError("reshape of a non-contiguous array is not supported yet")
        # strides omitted: the constructor derives C strides for the NEW shape (the old ones have the wrong length)
        return tarray(shape, dtype=self.dtype, buffer=self.data, offset=self.offset, order=self.order)

    def __getitem__(self, indices: Tuple[int,...]):
        """Element/slice access -- not implemented yet."""
        raise NotImplementedError("getitem is not implemented yet")

    def __repr__(self):
        return f"tarray(shape={self.shape}, dtype={self.dtype}, strides={self.strides}, order={self.order})"


def array(obj, dtype=float, order='C'):
    """Build a tarray from ``obj``.

    An object that is already a tarray is returned unchanged (no copy; ``dtype``
    and ``order`` are ignored). Otherwise the shape comes from
    ``recursive_checkshape(obj)``, the data from ``flatten(obj)``, and the flat
    data must pass ``recursive_checktype(data, dtype)`` before it is wrapped.

    Raises:
        TypeError: if the flattened data fails the ``dtype`` check.
    """
    if isinstance(obj, tarray):
        return obj
    inferred_shape = tuple(recursive_checkshape(obj))
    flat_data = flatten(obj)
    if recursive_checktype(flat_data, dtype):
        return tarray(inferred_shape, dtype=dtype, buffer=flat_data, order=order)
    raise TypeError(f"Cannot convert {type(obj)} to {dtype}")

NameError: name 'Tuple' is not defined

In [5]:
# Seeded generator so the sample array (and every output below) reproduces on Restart & Run All.
rng = np.random.default_rng(0)
np_arr = rng.integers(0, 10, (3, 3, 2))
arr = array(np_arr.astype(float))

In [6]:
# tarray strides count elements; scale by the float64 itemsize (8 bytes) to compare with NumPy's byte strides.
print(tuple(s * np.dtype(float).itemsize for s in arr.strides))
print(np_arr.strides)

(48, 16, 8)
(48, 16, 8)


In [7]:
assert arr.shape == np_arr.shape, "tarray shape should match the source NumPy array"
arr.shape

(3, 3, 2)

In [9]:
# Element access is not implemented yet; show the expected error instead of leaving a failing cell.
try:
    arr[0, 0, 0]
except NotImplementedError as err:
    print(f"NotImplementedError: {err}")

NotImplementedError: getitem is not implemented yet

In [None]:
# Reshape both arrays to (9, 2) so their strides can be compared in the next cell.
np_arr_reshape = np.reshape(np_arr, (9, 2))
arr_reshape = arr.reshape((9, 2))

In [None]:
# Same comparison as before, now on the reshaped (9, 2) arrays.
print(tuple(s * np.dtype(float).itemsize for s in arr_reshape.strides))
print(np_arr_reshape.strides)

(16, 8)
(16, 8)


In [None]:
# Identity (not equality) proves reshape returned a view over the same buffer; `==` would also be True for a copy.
arr.data is arr_reshape.data

True

In [None]:
# Strides in the repr are counted in elements (itemsize == 1); multiply by 8 to compare with the NumPy byte strides printed above.
arr_reshape

tarray(shape=(9, 2), dtype=<class 'float'>, strides=(2, 1), order=C)

In [None]:
# Distinct name so `arr` (the 3-D array used by the cells above) is not clobbered if those cells are re-run.
arr_1d = array([1., 2., 3., 4., 5., 6.])

(8,)