### Colab link: [Tensor Puzzles](https://colab.research.google.com/github/srush/Tensor-Puzzles/blob/main/Tensor%20Puzzlers.ipynb#scrollTo=e46e4ada)

In [127]:
# # Grab the helper functions provided by the Tensor Puzzles repo: https://github.com/srush/Tensor-Puzzles?tab=readme-ov-file
# !mkdir -p util
# !wget -q https://github.com/srush/Tensor-Puzzles/raw/main/lib.py -P util
# !touch util/__init__.py

In [None]:
from util.lib import draw_examples, make_test, run_test
import torch
import numpy as np
from torchtyping import TensorType as TT
tensor = torch.tensor

In [None]:
def arange(i: int):
    """Use this function to replace a for-loop.

    Args:
        i: Number of elements to generate. Values <= 0 produce an empty tensor.

    Returns:
        A 1-D int64 tensor holding 0, 1, ..., i - 1.
    """
    # torch.tensor(range(0)) infers the default float dtype for an empty
    # sequence; torch.arange keeps the integer dtype for every length.
    # Clamping at 0 keeps "negative length -> empty tensor" instead of raising.
    return torch.arange(max(i, 0))

# Show what arange produces for a few sample lengths.
draw_examples(
    "arange",
    [{"": arange(length)} for length in (5, 3, 9)],
)

In [None]:
# Broadcasting demo: adding a flat vector to a column vector (built with
# [:, None]) expands both operands to a common 2-D shape.
examples = [
    (arange(4), arange(5)[:, None]),
    (arange(3)[:, None], arange(2)),
]
draw_examples(
    "broadcast",
    [{"a": lhs, "b": rhs, "ret": lhs + rhs} for lhs, rhs in examples],
)

In [None]:
# Indexing with [:, None] appends a length-1 axis, turning the flat
# length-5 vector into a 5x1 column; both forms are displayed side by side.
arange(5), arange(5)[:, None]

In [None]:
def where(q, a, b):
    """Use this function to replace an if-statement.

    Picks elements from `a` where `q` is true and from `b` elsewhere; the
    three arguments are broadcast against each other.

    Args:
        q: Condition mask. Non-boolean input is coerced to bool
            (non-zero means True).
        a: Values used where the mask is True.
        b: Values used where the mask is False.

    Returns:
        A tensor of the broadcast shape with each element taken from `a` or `b`.
    """
    # Coerce first: `~` on a non-bool mask is bitwise NOT, not logical negation.
    mask = torch.as_tensor(q, dtype=torch.bool)
    # Select rather than blend arithmetically: with (q * a) + (~q) * b an
    # inf/nan in the *unselected* branch still poisons the result (0 * inf = nan).
    return torch.where(mask, a, b)

# In diagrams, orange is positive/True, white is zero/False, and blue is negative.

# Sample (condition, if-true, if-false) triples, including cases where the
# operands have different shapes and must broadcast.
examples = [
    (tensor([False]), tensor([10]), tensor([0])),
    (tensor([False, True]), tensor([1, 1]), tensor([-10, 0])),
    (tensor([False, True]), tensor([1]), tensor([-10, 0])),
    (tensor([[False, True], [True, False]]), tensor([1]), tensor([-10, 0])),
    (tensor([[False, True], [True, False]]), tensor([[0], [10]]), tensor([-10, 0])),
]
draw_examples(
    "where",
    [
        {"q": cond, "a": if_true, "b": if_false, "ret": where(cond, if_true, if_false)}
        for cond, if_true, if_false in examples
    ],
)

In [None]:
# `~` on an integer tensor is bitwise NOT (each x becomes -x - 1), not logical
# negation; True/False flipping only happens for boolean tensors.
~arange(5)

In [None]:
### puzzle 1
def ones_spec(out):
    """Reference spec for puzzle 1: overwrite every slot of `out` with 1, in place."""
    for position in range(len(out)):
        out[position] = 1
        
def ones(i: int) -> TT["i"]:
    """Return a length-`i` tensor filled with ones.

    Built from arange alone: scaling by 0 wipes the values while keeping the
    length, and the +1 then fills every slot.
    """
    zeroed = arange(i) * 0
    return zeroed + 1
    
# Pair the `ones` solution with its loop-based spec. make_test comes from
# util.lib; presumably add_sizes names the integer size arguments to generate
# (confirm in lib.py).
test_ones = make_test(
    "one",
    ones,
    ones_spec,
    add_sizes=["i"],
)

In [None]:
# Run the puzzle-1 check built above. run_test is imported from util.lib;
# presumably it executes the test and reports the outcome (confirm in lib.py).
run_test(test_ones)

In [None]:
# puzzle 2
def sum_spec(a, out):
    """Reference spec for puzzle 2: accumulate every element of `a` into out[0]."""
    # Reset the accumulator slot before adding the elements one by one.
    out[0] = 0
    for position in range(len(a)):
        out[0] += a[position]
        
def sum(a: TT["i"]) -> TT[1]:
    """Sum the elements of `a` via a matrix product with a vector of ones.

    A (1, X) row times an (X, 1) column of ones yields a (1, 1) matrix holding
    the total; indexing [0] drops the leading axis to give shape (1,).
    """
    row = a[None, :]
    ones_column = ones(len(a))[:, None]
    product = row @ ones_column
    return product[0]

# Pair the `sum` solution with its loop-based spec (make_test is imported
# from util.lib).
test_sum = make_test(
    "sum",
    sum,
    sum_spec,
)

In [None]:
# Run the puzzle-2 check built above. run_test is imported from util.lib;
# presumably it executes the test and reports the outcome (confirm in lib.py).
run_test(test_sum)

In [None]:
# puzzle 3
def outer_spec(a, b, out):
    """Reference spec for puzzle 3: out[r][c] = a[r] * b[c], written in place."""
    for row in range(len(out)):
        # Column count is read from the first row of `out`.
        for col in range(len(out[0])):
            out[row][col] = a[row] * b[col]
            
def outer(a: TT["i"], b: TT["j"]) -> TT["i", "j"]:
    """Outer product of two vectors.

    Args:
        a: 1-D tensor of length m.
        b: 1-D tensor of length n.

    Returns:
        An (m, n) tensor whose [r, c] entry is a[r] * b[c].
    """
    # View a as an (m, 1) column and b as a (1, n) row, then let broadcasting
    # form every pairwise product. Compared with (m, 1) @ (1, n) this gives the
    # same values but does not require both operands to share a dtype.
    return a[:, None] * b[None, :]
    
# Pair the `outer` solution with its loop-based spec (make_test is imported
# from util.lib).
test_outer = make_test(
    "outer",
    outer,
    outer_spec,
)

In [None]:
# Run the puzzle-3 check built above. run_test is imported from util.lib;
# presumably it executes the test and reports the outcome (confirm in lib.py).
run_test(test_outer)