In [1]:
import numpy as np
from primitive import arange, where, tensor

## [ones](https://numpy.org/doc/stable/reference/generated/numpy.ones.html)
- To think in a tensor way, replace traditional iterative programming primitives with basic tensor ops. 
- `arange` and `broadcast` to replace `for loop`. 
- `where` to replace `if-else`. 

In [2]:
# example
np.ones(5)

array([1., 1., 1., 1., 1.])

In [3]:
# v1
# def ones(n: int):
#     return arange(n) * 0 + 1


# v2
def ones(n: int):
    """Return a length-``n`` vector in which every entry is 1.

    The comparison is only there to produce a tensor of the right shape:
    both branches of ``where`` are 1, so the outcome of the test never
    influences the values.
    """
    positions = arange(n)
    any_condition = positions == 1  # truth values are irrelevant, only the shape is used
    return where(any_condition, 1, 1)

In [4]:
ones(5)

array([1, 1, 1, 1, 1])

In [5]:
# test
assert (ones(5) == np.ones(5)).all(), f"Mismatch: {ones(5)} != {np.ones(5)}"
assert (ones(0) == np.ones(0)).all(), f"Mismatch: {ones(0)} != {np.ones(0)}"

## [sum](https://numpy.org/doc/stable/reference/generated/numpy.sum.html)
- Think about `@`, which could be `matmul` or `dot product`, to reduce dimension. 

In [6]:
# example
arange(5), np.sum(arange(5))

(array([0, 1, 2, 3, 4]), 10)

In [7]:
# v1
def sum(a_i):
    """Add up all entries of a 1-D tensor using a dot product.

    Contracting ``a_i`` against an all-ones vector of the same length
    reduces the single dimension and leaves the total.
    """
    length = a_i.shape[0]
    all_ones = ones(length)
    return a_i @ all_ones

In [8]:
sum(arange(5))

10

In [9]:
# test
assert sum(np.array([1, 2, 3, 4])) == 10, f"Sum mismatch: {sum(np.array([1, 2, 3, 4]))} != 10"
assert sum(np.array([])) == 0, f"Sum mismatch: {sum(np.array([]))} != 0"
assert sum(np.array([-1, 1])) == 0, f"Sum mismatch: {sum(np.array([-1, 1]))} != 0"

## [outer](https://numpy.org/doc/stable/reference/generated/numpy.outer.html)
- `a[:, None]` appends a new axis of size 1 at the end (shape `(i,)` becomes `(i, 1)`). 
- Think about adding dummy dimension and use broadcast rule to expand the dimension. 

In [10]:
# example
np.outer(arange(3), ones(5))

array([[0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2]])

In [11]:
def outer(a_i, b_j):
    """Outer product of two 1-D tensors via broadcasting.

    ``a_i`` is turned into a column by adding a trailing axis; multiplying
    that column by ``b_j`` lets broadcasting expand both operands into a
    2-D grid whose entry ``[i, j]`` is ``a_i[i] * b_j[j]``.
    """
    a_column = a_i[:, None]
    return a_column * b_j

In [12]:
outer(arange(3), ones(5))

array([[0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2]])

In [13]:
# test
assert np.array_equal(outer(np.array([1, 2]), np.array([1, 2])), np.outer(np.array([1, 2]), np.array([1, 2]))), "Outer product mismatch"
assert np.array_equal(outer(np.array([0, 1]), np.array([0, 1])), np.outer(np.array([0, 1]), np.array([0, 1]))), "Outer product mismatch"
assert np.array_equal(outer(np.array([-1, 1]), np.array([-1, 1])), np.outer(np.array([-1, 1]), np.array([-1, 1]))), "Outer product mismatch"

## [diag](https://numpy.org/doc/stable/reference/generated/numpy.diag.html)
- Think about indexing for retrieval puzzles. 

In [14]:
# example
a_33 = arange(9).reshape((3, 3))
a_33, np.diag(a_33)

(array([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]),
 array([0, 4, 8]))

In [15]:
def diag(a_ii):
    """Extract the main diagonal of a 2-D tensor.

    Works for square and rectangular matrices: the diagonal has as many
    entries as the shorter of the two sides. Only 2-D input is supported
    (building a matrix from a 1-D vector is not handled here).

    Args:
        a_ii: 2-D tensor.

    Returns:
        1-D tensor holding ``a_ii[k, k]`` for every valid ``k``.
    """
    # Stop at the shorter side so the column index never runs out of bounds
    # on a tall matrix (and the row index never does on a wide one).
    k = min(a_ii.shape[0], a_ii.shape[1])
    idx = arange(k)
    # Pairing the same index tensor on both axes picks positions (0,0), (1,1), ...
    return a_ii[idx, idx]

In [16]:
diag(a_33)

array([0, 4, 8])

In [17]:
# test
assert np.array_equal(diag(np.array([[1, 2], [3, 4]])), np.diag(np.array([[1, 2], [3, 4]]))), "Diagonal extraction mismatch"
assert np.array_equal(diag(np.array([[0, 1], [1, 0]])), np.diag(np.array([[0, 1], [1, 0]]))), "Diagonal extraction mismatch"
assert np.array_equal(diag(np.array([[-1, -2], [-3, -4]])), np.diag(np.array([[-1, -2], [-3, -4]]))), "Diagonal extraction mismatch"

## [eye](https://numpy.org/doc/stable/reference/generated/numpy.eye.html)
- Semantically, it's similar to `outer` above. Add a dim plus simple op to expand to 2d, then use `where` to cast `bool` to `int` for final result. 

In [18]:
# example
np.eye(3)

array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]])

In [19]:
def eye(n: int):
    """Build the ``n`` x ``n`` identity matrix with integer 1/0 entries.

    A column of row indices is compared against a row of column indices;
    broadcasting expands the comparison to 2-D, and ``where`` converts the
    resulting booleans into 1 (on the diagonal) and 0 (elsewhere).
    """
    row_idx = arange(n)[:, None]
    col_idx = arange(n)
    on_diagonal = row_idx == col_idx
    return where(on_diagonal, 1, 0)

In [20]:
eye(3)

array([[1, 0, 0],
       [0, 1, 0],
       [0, 0, 1]])

In [21]:
# test
assert np.array_equal(eye(3), np.eye(3)), "Identity matrix generation mismatch"
assert np.array_equal(eye(5), np.eye(5)), "Identity matrix generation mismatch"
assert np.array_equal(eye(1), np.eye(1)), "Identity matrix generation mismatch"

## [triu](https://numpy.org/doc/stable/reference/generated/numpy.triu.html)
- In a broader context, it is the key to the `attention` mechanism and its attention mask. 
- Semantically, use same trick of `eye`. 

In [22]:
# example
a_33 = arange(9).reshape((3, 3))
a_33, np.triu(a_33)

(array([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]),
 array([[0, 1, 2],
        [0, 4, 5],
        [0, 0, 8]]))

In [23]:
def triu(a_ii):
    """Keep the upper triangle (diagonal included) and zero everything below.

    Works for square and rectangular 2-D tensors: the mask is built from the
    row count and the column count separately instead of assuming they match.

    Args:
        a_ii: 2-D tensor.

    Returns:
        Tensor of the same shape with entries where ``row > col`` replaced by 0.
    """
    rows = a_ii.shape[0]
    cols = a_ii.shape[-1]
    # Column of row indices vs. row of column indices -> (rows, cols) boolean mask
    # that is True on and above the main diagonal.
    cond = arange(rows)[:, None] <= arange(cols)
    return where(cond, a_ii, 0)

In [24]:
triu(a_33)

array([[0, 1, 2],
       [0, 4, 5],
       [0, 0, 8]])

In [25]:
# test
assert np.array_equal(triu(np.arange(9).reshape((3, 3))), np.triu(np.arange(9).reshape((3, 3)))), "Upper triangle matrix generation mismatch"
assert np.array_equal(triu(np.arange(4).reshape((2, 2))), np.triu(np.arange(4).reshape((2, 2)))), "Upper triangle matrix generation mismatch"
assert np.array_equal(triu(np.zeros((3, 3))), np.triu(np.zeros((3, 3)))), "Upper triangle matrix generation mismatch"

## [cumsum](https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html)

In [26]:
# example
np.cumsum(ones(10))

array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])

In [27]:
def cumsum(a_i):
    """Running total of a 1-D tensor, expressed as a vector-matrix product.

    An all-ones square matrix is masked down to its upper triangle, so
    column ``j`` has ones in rows ``0..j``. Multiplying ``a_i`` by that
    matrix therefore yields, in position ``j``, the sum of the first
    ``j + 1`` entries.
    """
    length = a_i.shape[0]
    all_ones_ii = ones(length)[:, None] * ones(length)
    prefix_mask_ii = triu(all_ones_ii)
    return a_i @ prefix_mask_ii

In [28]:
cumsum(ones(10))

array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])

In [29]:
# test
assert np.array_equal(cumsum(np.arange(10)), np.cumsum(np.arange(10))), "Cumulative sum mismatch"
assert np.array_equal(cumsum(np.zeros(10)), np.cumsum(np.zeros(10))), "Cumulative sum mismatch for zeros"
assert np.array_equal(cumsum(np.ones(10)), np.cumsum(np.ones(10))), "Cumulative sum mismatch for ones"

## [diff](https://numpy.org/doc/stable/reference/generated/numpy.diff.html)

In [30]:
# example
np.diff(arange(5))

array([1, 1, 1, 1])

In [31]:
def diff(a_i):
    """Difference between each entry of a 1-D tensor and its predecessor.

    Two index tensors, one skipping the first position and one skipping the
    last, gather the "next" and "previous" values so the subtraction happens
    for all adjacent pairs at once.
    """
    length = a_i.shape[0]
    next_idx = arange(length)[1:]
    prev_idx = arange(length)[:-1]
    return a_i[next_idx] - a_i[prev_idx]

In [32]:
diff(arange(5))

array([1, 1, 1, 1])

In [33]:
# test
random_array = np.random.randint(0, 10, size=5)
assert np.array_equal(diff(random_array), np.diff(random_array)), "Difference mismatch for random array"
assert np.array_equal(diff(np.zeros(5)), np.diff(np.zeros(5))), "Difference mismatch for zeros"
assert np.array_equal(diff(np.ones(5)), np.diff(np.ones(5))), "Difference mismatch for ones"


## [vstack](https://numpy.org/doc/stable/reference/generated/numpy.vstack.html)
- Don't have to fully develop the condition, think about broadcast. 

In [34]:
# example
np.vstack((ones(5), arange(5)))

array([[1, 1, 1, 1, 1],
       [0, 1, 2, 3, 4]])

In [35]:
def vstack(a_i, b_i):
    """Stack two 1-D tensors of equal length as the rows of a 2-row matrix.

    A two-entry selector with a trailing axis broadcasts against both
    inputs: the ``True`` row takes its values from ``a_i`` and the
    ``False`` row from ``b_i``.
    """
    row_selector = tensor([True, False])[:, None]
    return where(row_selector, a_i, b_i)

In [36]:
vstack(ones(5), arange(5))

array([[1, 1, 1, 1, 1],
       [0, 1, 2, 3, 4]])

In [37]:
# test
length1 = np.random.randint(1, 10)
assert np.array_equal(vstack(ones(length1), arange(length1)), np.vstack((ones(length1), arange(length1)))), "vstack mismatch for ones and arange with random length"

length2 = np.random.randint(1, 10)
assert np.array_equal(vstack(arange(length2), arange(length2)), np.vstack((arange(length2), arange(length2)))), "vstack mismatch for two aranges with same random length"

length3 = np.random.randint(1, 10)
assert np.array_equal(vstack(arange(length3), ones(length3)), np.vstack((arange(length3), ones(length3)))), "vstack mismatch for arange and ones with random length"


## [roll](https://numpy.org/doc/stable/reference/generated/numpy.roll.html)
- My first implementation used `where` directly to handle index `i - 1`, which works but is iterative-programming thinking. Get used to thinking in tensors. 
- Whenever you want to handle a special case and the traditional if-else kicks in, think again in tensor form.
- Remember the goal of these exercises is not to implement functions, but to twist the brain enough to learn the pattern of tensor ops. 

In [38]:
# example
np.roll(arange(5), shift=-1)  # negative, shift left

array([1, 2, 3, 4, 0])

In [39]:
# v1
# def roll(a_i):
#     i = a_i.shape[0]
#     return a_i[where(arange(i) == i - 1, 0, arange(i) + 1)]


# v2
def roll(a_i, shift: int = -1):
    """Cyclically shift the entries of a 1-D tensor.

    Output position ``j`` reads input position ``(j - shift) % i``, so the
    wrap-around is handled by modular arithmetic instead of a special case.

    Args:
        a_i: 1-D tensor.
        shift: number of positions to move each entry; negative shifts left,
            positive shifts right. Defaults to -1 (shift left by one), which
            is the behaviour this function had before ``shift`` was added.

    Returns:
        Tensor of the same length with entries rotated by ``shift``.
    """
    i = a_i.shape[0]
    # The modulo keeps every index inside [0, i) for any integer shift.
    return a_i[(arange(i) - shift) % i]

In [40]:
roll(arange(5))

array([1, 2, 3, 4, 0])

In [41]:
# test
# fmt: off
# Each case: (length, factory that builds an array of that length, description for the failure message).
test_arrays = [
    (np.random.randint(1, 10), arange, "arange with random length"),
    (np.random.randint(1, 10), lambda x: np.random.randint(0, 100, size=x), "random array with random length"),
    (10, lambda x: np.linspace(0, 1, x), "linspace array with fixed length")
]

for length, array_func, description in test_arrays:
    test_array = array_func(length)
    # The failure message must print the same reference (shift=-1) that the assertion compares against.
    assert np.array_equal(roll(test_array), np.roll(test_array, shift=-1)), f"roll mismatch for {description}: {length}, roll: {roll(test_array)}, np.roll: {np.roll(test_array, shift=-1)}"
# fmt: on

## [flip](https://numpy.org/doc/stable/reference/generated/numpy.flip.html)
- For each position, original index and the flipped index sums to `i - 1`. 

In [42]:
# example
np.flip(arange(5))

array([4, 3, 2, 1, 0])

In [43]:
def flip(a_i):
    """Reverse the order of the entries of a 1-D tensor.

    An entry at position ``k`` ends up at ``length - 1 - k``, so subtracting
    the index tensor from the last valid index gives the gather order.
    """
    length = a_i.shape[0]
    last = length - 1
    mirrored_idx = last - arange(length)
    return a_i[mirrored_idx]

In [44]:
flip(arange(5))

array([4, 3, 2, 1, 0])

In [45]:
# test
# fmt: off
test_arrays_flip = [
    (arange(5), "arange with fixed length"),
    (np.random.randint(0, 100, size=10), "random array with fixed length"),
    (np.array([0, -1, -2, -3, 0, 1, 2, 3, 0]), "array with negative numbers and zeros")
]

for test_array, description in test_arrays_flip:
    assert np.array_equal(flip(test_array), np.flip(test_array)), f"flip mismatch for {description}: {test_array}, flip: {flip(test_array)}, np.flip: {np.flip(test_array)}"
# fmt: on

## [compress](https://numpy.org/doc/stable/reference/generated/numpy.compress.html)

In [46]:
# example
np.compress(condition=tensor([True, False, True, False]), a=arange(4))

array([0, 2])

In [None]:
def compress(condition_i, a_i):
    """Select the entries of a 1-D tensor at positions where the condition is true.

    Mirrors the 1-D behaviour of ``np.compress``: if the condition is shorter
    than ``a_i``, only the leading entries covered by the condition are
    considered. A condition longer than ``a_i`` is not supported.

    Args:
        condition_i: 1-D tensor of booleans (non-boolean values are treated
            as truthy/falsy).
        a_i: 1-D tensor to select from.

    Returns:
        1-D tensor with the kept entries, in their original order.
    """
    n = condition_i.shape[0]
    # Cast so integer conditions act as a mask rather than as positional indices.
    mask = condition_i.astype(bool)
    # Trim a_i to the condition's length first so the mask shape always matches.
    return a_i[:n][mask]