Skip to content

Commit

Permalink
fix a bug in the __getitem__ method
Browse files Browse the repository at this point in the history
  • Loading branch information
m0saan committed Jun 17, 2023
1 parent 063e9db commit 6f0cc05
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 142 deletions.
12 changes: 11 additions & 1 deletion minima/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,12 @@
'minima.operators.reshape': ('operators.html#reshape', 'minima/operators.py'),
'minima.operators.summation': ('operators.html#summation', 'minima/operators.py'),
'minima.operators.transpose': ('operators.html#transpose', 'minima/operators.py')},
'minima.optim': { 'minima.optim.Adam': ('optim.html#adam', 'minima/optim.py'),
'minima.optim': { 'minima.optim.AdaGrad': ('optim.html#adagrad', 'minima/optim.py'),
'minima.optim.AdaGrad.__init__': ('optim.html#adagrad.__init__', 'minima/optim.py'),
'minima.optim.AdaGrad._opt_step': ('optim.html#adagrad._opt_step', 'minima/optim.py'),
'minima.optim.AdaGrad._reg_step': ('optim.html#adagrad._reg_step', 'minima/optim.py'),
'minima.optim.AdaGrad.step': ('optim.html#adagrad.step', 'minima/optim.py'),
'minima.optim.Adam': ('optim.html#adam', 'minima/optim.py'),
'minima.optim.Adam.__init__': ('optim.html#adam.__init__', 'minima/optim.py'),
'minima.optim.Adam._opt_step': ('optim.html#adam._opt_step', 'minima/optim.py'),
'minima.optim.Adam._reg_step': ('optim.html#adam._reg_step', 'minima/optim.py'),
Expand All @@ -347,6 +352,11 @@
'minima.optim.Optimizer.__init__': ('optim.html#optimizer.__init__', 'minima/optim.py'),
'minima.optim.Optimizer.step': ('optim.html#optimizer.step', 'minima/optim.py'),
'minima.optim.Optimizer.zero_grad': ('optim.html#optimizer.zero_grad', 'minima/optim.py'),
'minima.optim.RMSProp': ('optim.html#rmsprop', 'minima/optim.py'),
'minima.optim.RMSProp.__init__': ('optim.html#rmsprop.__init__', 'minima/optim.py'),
'minima.optim.RMSProp._opt_step': ('optim.html#rmsprop._opt_step', 'minima/optim.py'),
'minima.optim.RMSProp._reg_step': ('optim.html#rmsprop._reg_step', 'minima/optim.py'),
'minima.optim.RMSProp.step': ('optim.html#rmsprop.step', 'minima/optim.py'),
'minima.optim.SGD': ('optim.html#sgd', 'minima/optim.py'),
'minima.optim.SGD.__init__': ('optim.html#sgd.__init__', 'minima/optim.py'),
'minima.optim.SGD._opt_step': ('optim.html#sgd._opt_step', 'minima/optim.py'),
Expand Down
8 changes: 7 additions & 1 deletion minima/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,13 @@ def __getitem__(self, idxs):
]
)
assert len(idxs) == self.ndim, "Need indexes equal to number of dimensions"
shape = tuple((idx.stop - idx.start) // idx.step for idx in idxs)

shape = []
for i in idxs:
d = i.stop - i.start
dim_size = d // i.step + d % i.step
shape.append(dim_size)

offset = sum(idx.start * stride for idx, stride in zip(idxs, self._strides))
strides = tuple(idx.step * stride for idx, stride in zip(idxs, self._strides)) # Corrected line -> haha was FUN!!
return NDArray.make(shape, strides=strides, device=self._device, handle=self._handle, offset=offset)
Expand Down
54 changes: 27 additions & 27 deletions minima/ndarray_backend_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
# %% ../nbs/07_ndarray_backend_numpy.ipynb 2
import numpy as np

# %% ../nbs/07_ndarray_backend_numpy.ipynb 7
# %% ../nbs/07_ndarray_backend_numpy.ipynb 3
__device_name__ = "numpy"
_datatype = np.float32
_datetype_size = np.dtype(_datatype).itemsize

# %% ../nbs/07_ndarray_backend_numpy.ipynb 8
# %% ../nbs/07_ndarray_backend_numpy.ipynb 4
class Array:
def __init__(self, size):
self.array = np.empty(size, dtype=np.float32)
Expand All @@ -26,7 +26,7 @@ def __repr__(self):
def size(self):
return self.array.size

# %% ../nbs/07_ndarray_backend_numpy.ipynb 10
# %% ../nbs/07_ndarray_backend_numpy.ipynb 6
def to_numpy(a, shape, strides, offset):
"""
Converts a contiguous 1D array into an N-dimensional array using numpy stride tricks.
Expand Down Expand Up @@ -64,7 +64,7 @@ def to_numpy(a, shape, strides, offset):
)


# %% ../nbs/07_ndarray_backend_numpy.ipynb 14
# %% ../nbs/07_ndarray_backend_numpy.ipynb 10
def from_numpy(a: np.ndarray, out) -> None:
"""
Assigns a flattened version of the input N-dimensional array to another array.
Expand Down Expand Up @@ -100,7 +100,7 @@ def from_numpy(a: np.ndarray, out) -> None:
out.array[:] = a.flatten()


# %% ../nbs/07_ndarray_backend_numpy.ipynb 20
# %% ../nbs/07_ndarray_backend_numpy.ipynb 16
def fill(out: Array, val) -> None:
"""
Fills an Array object with a specific value.
Expand Down Expand Up @@ -129,7 +129,7 @@ def fill(out: Array, val) -> None:
out.array.fill(val)


# %% ../nbs/07_ndarray_backend_numpy.ipynb 22
# %% ../nbs/07_ndarray_backend_numpy.ipynb 18
def compact(a, out: Array, shape, strides, offset):
"""
Transforms a 1D array into an N-dimensional array, flattens it, and assigns it to an Array object.
Expand Down Expand Up @@ -168,7 +168,7 @@ def compact(a, out: Array, shape, strides, offset):

out.array[:] = to_numpy(a, shape, strides, offset).flatten()

# %% ../nbs/07_ndarray_backend_numpy.ipynb 24
# %% ../nbs/07_ndarray_backend_numpy.ipynb 20
def ewise_setitem(a: Array, out: Array, shape, strides, offset):
"""
Modifies a section of an Array object to be equivalent to another reshaped array, on an element-wise basis.
Expand Down Expand Up @@ -209,7 +209,7 @@ def ewise_setitem(a: Array, out: Array, shape, strides, offset):

to_numpy(out, shape, strides, offset)[:] = a.array.reshape(shape)

# %% ../nbs/07_ndarray_backend_numpy.ipynb 32
# %% ../nbs/07_ndarray_backend_numpy.ipynb 28
def scalar_setitem(val, out: Array, shape, strides, offset):
"""
Fills a section of an Array object with a specific scalar value.
Expand Down Expand Up @@ -249,7 +249,7 @@ def scalar_setitem(val, out: Array, shape, strides, offset):
to_numpy(out, shape, strides, offset)[:] = val


# %% ../nbs/07_ndarray_backend_numpy.ipynb 34
# %% ../nbs/07_ndarray_backend_numpy.ipynb 30
def ewise_add(a: Array, b: Array, out: Array):
"""
Performs an element-wise addition of two Array objects and assigns the result to a third Array object.
Expand Down Expand Up @@ -282,7 +282,7 @@ def ewise_add(a: Array, b: Array, out: Array):
out.array[:] = a.array + b.array


# %% ../nbs/07_ndarray_backend_numpy.ipynb 36
# %% ../nbs/07_ndarray_backend_numpy.ipynb 32
def scalar_add(a: Array, val, out: Array):
"""
Adds a scalar value to an Array object and assigns the result to another Array object.
Expand Down Expand Up @@ -315,7 +315,7 @@ def scalar_add(a: Array, val, out: Array):
out.array[:] = a.array + val


# %% ../nbs/07_ndarray_backend_numpy.ipynb 38
# %% ../nbs/07_ndarray_backend_numpy.ipynb 34
def ewise_mul(a: Array, b: Array, out: Array):
"""
Performs an element-wise multiplication of two Array objects and assigns the result to a third Array object.
Expand Down Expand Up @@ -348,7 +348,7 @@ def ewise_mul(a: Array, b: Array, out: Array):
out.array[:] = a.array * b.array


# %% ../nbs/07_ndarray_backend_numpy.ipynb 40
# %% ../nbs/07_ndarray_backend_numpy.ipynb 36
def scalar_mul(a: Array, val, out: Array):
"""
Multiplies an Array object by a scalar value and assigns the result to another Array object.
Expand Down Expand Up @@ -381,7 +381,7 @@ def scalar_mul(a: Array, val, out: Array):
out.array[:] = a.array * val


# %% ../nbs/07_ndarray_backend_numpy.ipynb 42
# %% ../nbs/07_ndarray_backend_numpy.ipynb 38
def ewise_div(a: Array, b: Array, out: Array):
"""
Performs an element-wise division of two Array objects and assigns the result to a third Array object.
Expand Down Expand Up @@ -414,7 +414,7 @@ def ewise_div(a: Array, b: Array, out: Array):
out.array[:] = a.array / b.array


# %% ../nbs/07_ndarray_backend_numpy.ipynb 44
# %% ../nbs/07_ndarray_backend_numpy.ipynb 40
def scalar_div(a: Array, val, out: Array):
"""
Divides an Array object by a scalar value and assigns the result to another Array object.
Expand Down Expand Up @@ -447,7 +447,7 @@ def scalar_div(a: Array, val, out: Array):
out.array[:] = a.array / val


# %% ../nbs/07_ndarray_backend_numpy.ipynb 46
# %% ../nbs/07_ndarray_backend_numpy.ipynb 42
def scalar_power(a: Array, val, out: Array):
"""
Raises an Array object to the power of a scalar value and assigns the result to another Array object.
Expand Down Expand Up @@ -480,7 +480,7 @@ def scalar_power(a: Array, val, out: Array):
out.array[:] = a.array ** val


# %% ../nbs/07_ndarray_backend_numpy.ipynb 48
# %% ../nbs/07_ndarray_backend_numpy.ipynb 44
def ewise_maximum(a: Array, b: Array, out: Array):
"""
Computes the element-wise maximum of two Array objects and assigns the result to a third Array object.
Expand Down Expand Up @@ -513,7 +513,7 @@ def ewise_maximum(a: Array, b: Array, out: Array):
out.array[:] = np.maximum(a.array, b.array)


# %% ../nbs/07_ndarray_backend_numpy.ipynb 50
# %% ../nbs/07_ndarray_backend_numpy.ipynb 46
def scalar_maximum(a: Array, val, out: Array):
"""
Computes the maximum of an Array object and a scalar value, and assigns the result to another Array object.
Expand Down Expand Up @@ -545,7 +545,7 @@ def scalar_maximum(a: Array, val, out: Array):
"""
out.array[:] = np.maximum(a.array, val)

# %% ../nbs/07_ndarray_backend_numpy.ipynb 52
# %% ../nbs/07_ndarray_backend_numpy.ipynb 48
def ewise_eq(a: Array, b: Array, out: Array):
"""
Performs an element-wise comparison for equality between two Array objects and assigns the result to a third Array object.
Expand Down Expand Up @@ -578,7 +578,7 @@ def ewise_eq(a: Array, b: Array, out: Array):
out.array[:] = (a.array == b.array).astype(np.float32)


# %% ../nbs/07_ndarray_backend_numpy.ipynb 54
# %% ../nbs/07_ndarray_backend_numpy.ipynb 50
def scalar_eq(a: Array, val, out: Array):
"""
Compares an Array object with a scalar value for equality and assigns the result to another Array object.
Expand Down Expand Up @@ -611,7 +611,7 @@ def scalar_eq(a: Array, val, out: Array):
out.array[:] = (a.array == val).astype(np.float32)


# %% ../nbs/07_ndarray_backend_numpy.ipynb 56
# %% ../nbs/07_ndarray_backend_numpy.ipynb 52
def ewise_ge(a: Array, b: Array, out: Array):
"""
Performs an element-wise comparison to check if elements of one Array object are greater than or equal to those of another Array object. The result is assigned to a third Array object.
Expand Down Expand Up @@ -644,7 +644,7 @@ def ewise_ge(a: Array, b: Array, out: Array):
out.array[:] = (a.array >= b.array).astype(np.float32)


# %% ../nbs/07_ndarray_backend_numpy.ipynb 58
# %% ../nbs/07_ndarray_backend_numpy.ipynb 54
def scalar_ge(a: Array, val, out: Array):
"""
Compares an Array object with a scalar value to check if elements in the Array object are greater than or equal to the scalar. The result is assigned to another Array object.
Expand Down Expand Up @@ -677,7 +677,7 @@ def scalar_ge(a: Array, val, out: Array):
out.array[:] = (a.array >= val).astype(np.float32)


# %% ../nbs/07_ndarray_backend_numpy.ipynb 60
# %% ../nbs/07_ndarray_backend_numpy.ipynb 56
def ewise_log(a: Array, out: Array):
"""
Computes the natural logarithm of each element in an Array object and assigns the result to another Array object.
Expand Down Expand Up @@ -708,7 +708,7 @@ def ewise_log(a: Array, out: Array):
out.array[:] = np.log(a.array)


# %% ../nbs/07_ndarray_backend_numpy.ipynb 62
# %% ../nbs/07_ndarray_backend_numpy.ipynb 58
def ewise_exp(a: Array, out: Array):
"""
Computes the exponential of each element in an Array object and assigns the result to another Array object.
Expand Down Expand Up @@ -739,7 +739,7 @@ def ewise_exp(a: Array, out: Array):
out.array[:] = np.exp(a.array)


# %% ../nbs/07_ndarray_backend_numpy.ipynb 68
# %% ../nbs/07_ndarray_backend_numpy.ipynb 64
def ewise_tanh(a: Array, out: Array):
"""
Computes the hyperbolic tangent of each element in an Array object and assigns the result to another Array object.
Expand Down Expand Up @@ -770,7 +770,7 @@ def ewise_tanh(a: Array, out: Array):
out.array[:] = np.tanh(a.array)


# %% ../nbs/07_ndarray_backend_numpy.ipynb 73
# %% ../nbs/07_ndarray_backend_numpy.ipynb 69
def reduce_max(a: Array, out: Array, reduce_size: int):
"""
Computes the maximum of every `reduce_size` elements in an Array object and assigns the result to another Array object.
Expand Down Expand Up @@ -803,7 +803,7 @@ def reduce_max(a: Array, out: Array, reduce_size: int):
"""
out.array[:] = a.array[:].reshape(-1, reduce_size).max(axis=1)

# %% ../nbs/07_ndarray_backend_numpy.ipynb 77
# %% ../nbs/07_ndarray_backend_numpy.ipynb 73
def reduce_sum(a: Array, out: Array, reduce_size: int):
"""
Computes the sum of every `reduce_size` elements in an Array object and assigns the result to another Array object.
Expand Down Expand Up @@ -836,7 +836,7 @@ def reduce_sum(a: Array, out: Array, reduce_size: int):
"""
out.array[:] = a.array[:].reshape(-1, reduce_size).sum(axis=1)

# %% ../nbs/07_ndarray_backend_numpy.ipynb 82
# %% ../nbs/07_ndarray_backend_numpy.ipynb 78
def matmul(a: Array, b: Array, out: Array, m: int, n: int, p: int):
"""
Performs matrix multiplication between two Array objects and assigns the result to another Array object.
Expand Down
20 changes: 7 additions & 13 deletions nbs/06_ndarray.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,13 @@
" ]\n",
" )\n",
" assert len(idxs) == self.ndim, \"Need indexes equal to number of dimensions\"\n",
" shape = tuple((idx.stop - idx.start) // idx.step for idx in idxs)\n",
" \n",
" shape = []\n",
" for i in idxs:\n",
" d = i.stop - i.start\n",
" dim_size = d // i.step + d % i.step\n",
" shape.append(dim_size)\n",
" \n",
" offset = sum(idx.start * stride for idx, stride in zip(idxs, self._strides))\n",
" strides = tuple(idx.step * stride for idx, stride in zip(idxs, self._strides)) # Corrected line -> haha was FUN!!\n",
" return NDArray.make(shape, strides=strides, device=self._device, handle=self._handle, offset=offset)\n",
Expand Down Expand Up @@ -872,18 +878,6 @@
"display_name": "python3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 6f0cc05

Please sign in to comment.