In [375]:
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Union,
)

import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import einops

In [331]:
SEED=42
rng = np.random.default_rng(SEED)

## Defining the Tensor and Recipe classes

In [16]:
@dataclass(frozen=True)
class Recipe:
    func: Callable
    args: tuple
    kwargs: Dict[str, Any]
    parents: Dict[int, Tensor]

In [163]:
Arr = np.ndarray

class Tensor:
    '''
    A drop-in replacement for torch.Tensor supporting a subset of features.
    '''

    array: Arr
    "The underlying array. Can be shared between multiple Tensors."
    requires_grad: bool
    "If True, calling functions or methods on this tensor will track relevant data for backprop."
    grad: Optional["Tensor"]
    "Backpropagation will accumulate gradients into this field."
    recipe: Optional[Recipe]
    "Extra information necessary to run backpropagation."

    def __init__(self, array: Union[Arr, list], requires_grad=False):
        self.array = array if isinstance(array, Arr) else np.array(array)
        if self.array.dtype == np.float64:
            self.array = self.array.astype(np.float32)
        self.requires_grad = requires_grad
        self.grad = np.zeros_like(self.array).astype(np.float32)
        self.recipe = None
        "If not None, this tensor's array was created via recipe.func(*recipe.args, **recipe.kwargs)."

    def __neg__(self) -> "Tensor":
        return negative(self)

    def __add__(self, other) -> "Tensor":
        return add(self, other)

    def __radd__(self, other) -> "Tensor":
        return add(other, self)

    def __sub__(self, other) -> "Tensor":
        return subtract(self, other)

    def __rsub__(self, other):
        return subtract(other, self)

    def __mul__(self, other) -> "Tensor":
        return multiply(self, other)

    def __rmul__(self, other) -> "Tensor":
        return multiply(other, self)

    def __truediv__(self, other) -> "Tensor":
        return true_divide(self, other)

    def __rtruediv__(self, other) -> "Tensor":
        return true_divide(other, self)

    def __matmul__(self, other) -> "Tensor":
        return matmul(self, other)

    def __rmatmul__(self, other) -> "Tensor":
        return matmul(other, self)

    def __eq__(self, other) -> "Tensor":
        return eq(self, other)

    def __repr__(self) -> str:
        return f"Tensor({repr(self.array)}, requires_grad={self.requires_grad})"

    def __len__(self) -> int:
        if self.array.ndim == 0:
            raise TypeError
        return self.array.shape[0]

    def __hash__(self) -> int:
        return id(self)

    def __getitem__(self, index) -> "Tensor":
        return getitem(self, index)

    def add_(self, other: "Tensor", alpha: float = 1.0) -> "Tensor":
        add_(self, other, alpha=alpha)
        return self

    @property
    def T(self) -> "Tensor":
        return permute(self, axes=(-1, -2))

    def item(self):
        return self.array.item()

    def sum(self, dim=None, keepdim=False):
        return sum(self, dim=dim, keepdim=keepdim)

    def log(self):
        return log(self)

    def exp(self):
        return exp(self)

    def reshape(self, new_shape):
        return reshape(self, new_shape)

    def expand(self, new_shape):
        return expand(self, new_shape)

    def permute(self, dims):
        return permute(self, dims)

    def maximum(self, other):
        return maximum(self, other)

    def relu(self):
        return relu(self)

    def argmax(self, dim=None, keepdim=False):
        return argmax(self, dim=dim, keepdim=keepdim)

    def uniform_(self, low: float, high: float) -> "Tensor":
        self.array[:] = np.random.uniform(low, high, self.array.shape)
        return self

    def backward(self, end_grad: Union[Arr, "Tensor", None] = None) -> None:
        if isinstance(end_grad, Arr):
            end_grad = Tensor(end_grad)
        return backprop(self, end_grad)

    def size(self, dim: Optional[int] = None):
        if dim is None:
            return self.shape
        return self.shape[dim]

    @property
    def shape(self):
        return self.array.shape

    @property
    def ndim(self):
        return self.array.ndim

    @property
    def is_leaf(self):
        '''Same as https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf.html'''
        if self.requires_grad and self.recipe and self.recipe.parents:
            return False
        return True

    def __bool__(self):
        if np.array(self.shape).prod() != 1:
            raise RuntimeError("bool value of Tensor with more than one value is ambiguous")
        return bool(self.item())

def empty(*shape: int) -> Tensor:
    '''Like torch.empty.'''
    return Tensor(np.empty(shape))

def zeros(*shape: int) -> Tensor:
    '''Like torch.zeros.'''
    return Tensor(np.zeros(shape))

def arange(start: int, end: int, step=1) -> Tensor:
    '''Like torch.arange(start, end).'''
    return Tensor(np.arange(start, end, step=step))

def tensor(array: Arr, requires_grad=False) -> Tensor:
    '''Like torch.tensor.'''
    return Tensor(array, requires_grad=requires_grad)

## Forward and Backward Functions

In [382]:
# TODO: implement broadcasting and unbroadcasting?
def log_back(grad_out, out, x):
    """Returns the gradient of the loss w.r.t. x via out = log(x)."""
    return grad_out / x

def add_back0(grad_out, out, a, b):
    """Returns the gradient of the loss w.r.t. a via out = a + b, where + is the element-wise addition."""
    return grad_out

def add_back1(grad_out, out, a, b):
    """Returns the gradient of the loss w.r.t. b via out = a + b, where + is element-wise addition."""
    return grad_out

def multiply_back0(grad_out, out, a, b):
    """Returns the gradient of the loss w.r.t. a via out = a * b, where * is element-wise multiplication."""
    return grad_out * b

def multiply_back1(grad_out, out, a, b):
    """Returns the gradient of the loss w.r.t. b via out = a * b, where * is element-wise multiplication."""
    return grad_out * a

def true_divide_back0(grad_out, out, a, b):
    return grad_out / b

def true_divide_back1(grad_out, out, a, b):
    return -grad_out * a / (b ** 2)

def maximum_back0(grad_out, out, a, b):
    return np.where(a >= b, grad_out, np.zeros_like(grad_out))

def maximum_back1(grad_out, out, a, b):
    return np.where(a < b, grad_out, np.zeros_like(grad_out))

def matmul_back0(grad_out, out, a, b):
    return grad_out @ b.T

def matmul_back1(grad_out, out, a, b):
    return a.T @ grad_out

def sum_back(grad_out, out, x, axis=None, keepdims=False):
    if not isinstance(grad_out, Arr):
        grad_out = np.array(grad_out)

    if axis is None:
        axis = list(range(x.ndim))

    if keepdims == False:
        print(grad_out, axis)
        grad_out = np.expand_dims(grad_out, axis)

    return np.broadcast_to(grad_out, x.shape)

In [383]:
from collections import defaultdict

BACK_FUNCS = defaultdict(dict)

BACK_FUNCS[np.log][0] = log_back
BACK_FUNCS[np.add][0] = add_back0
BACK_FUNCS[np.add][1] = add_back1
BACK_FUNCS[np.multiply][0] = multiply_back0
BACK_FUNCS[np.multiply][1] = multiply_back1
BACK_FUNCS[np.true_divide][0] = true_divide_back0
BACK_FUNCS[np.true_divide][1] = true_divide_back1
BACK_FUNCS[np.maximum][0] = maximum_back0
BACK_FUNCS[np.maximum][1] = maximum_back1
BACK_FUNCS[np.matmul][0] = matmul_back0
BACK_FUNCS[np.matmul][1] = matmul_back1
BACK_FUNCS[np.sum][0] = sum_back

### Practice implementing a forward function

In [66]:
def log(x: Tensor) -> Tensor:
    t = tensor(np.log(x.array), requires_grad=True)
    t.recipe = Recipe(
        func=np.log,
        args=x,
        kwargs=None,
        parents={0: x}
    )
    return t

### Generalized forward function wrapper

In [91]:
def wrap_forward_fn(inner_func: Callable[[Arr], Arr]) -> Callable[[Tensor], Tensor]:
    def func(*args, **kwargs):
        np_args = tuple(arg.array for arg in args)
        np_kwargs = {key: arg.array for key, arg in kwargs}
        t = tensor(inner_func(*np_args, **np_kwargs), requires_grad=True)
        t.recipe = Recipe(
            func=inner_func,
            args=np_args,
            kwargs=np_kwargs,
            parents={i: arg for i, arg in enumerate(args)}
        )
        return t

    return func

### Forward functions

In [328]:
log = wrap_forward_fn(np.log)
add = wrap_forward_fn(np.add)
multiply = wrap_forward_fn(np.multiply)
true_divide = wrap_forward_fn(np.true_divide)
maximum = wrap_forward_fn(np.maximum)
matmul = wrap_forward_fn(np.matmul)
sum = wrap_forward_fn(np.sum)

In [359]:
def relu(x: Tensor) -> Tensor:
    return maximum(x, tensor(np.zeros_like(x.array)))

## Autograd

In [203]:
def topological_sort(root: Tensor) -> List[Tensor]:
    # first do a DFS to get nodes with no parents and also a mapping from a node to its children
    node_to_num_remaining_parents = {}
    node_to_children = defaultdict(list)
    def scan(node: Tensor):
        if node.recipe is None:
            node_to_num_remaining_parents[node] = 0
            return
        node_to_num_remaining_parents[node] = len(node.recipe.parents)
        for parent in node.recipe.parents.values():
            scan(parent)
            node_to_children[parent].append(node)

    scan(root)

    # then iteratively add to the topological order from nodes with no remaining parents
    topological_order = []
    visited = set()

    nodes_with_no_remaining_parents = [
        node
        for node, num_remaining_parents in node_to_num_remaining_parents.items()
        if num_remaining_parents == 0
    ]
    
    while len(nodes_with_no_remaining_parents) > 0:
        cur_node = nodes_with_no_remaining_parents.pop()
        topological_order.append(cur_node)
        visited.add(cur_node)
        for child in node_to_children[cur_node]:
            if child in visited:
                raise AssertionError("Found a cycle in the DAG")
            if child not in visited:
                node_to_num_remaining_parents[child] -= 1
                if node_to_num_remaining_parents[child] == 0:
                    nodes_with_no_remaining_parents.append(child)
                
    return topological_order

In [371]:
def backprop(t: Tensor, end_grad: Optional[Tensor]):
    # do a topological sort.
    # visit the nodes in order and for each node call the right backward function
    # what is end grad for? it's the last grad_out -- if it's none just make it all ones.
    topologically_sorted_computational_graph = topological_sort(t)
    if end_grad is None:
        end_grad = np.ones_like(t.array)
    t.grad = end_grad
    
    for cur_tensor in topologically_sorted_computational_graph[::-1]:
        if cur_tensor.is_leaf:
            continue
        for cur_tensor_parent_index, cur_tensor_parent in cur_tensor.recipe.parents.items():
            print(cur_tensor.grad)
            print(cur_tensor.recipe.args)
            print(cur_tensor.recipe.kwargs)
            print(f"{cur_tensor.recipe.func=}")
            print(BACK_FUNCS[cur_tensor.recipe.func][cur_tensor_parent_index])
            cur_tensor_parent.grad += BACK_FUNCS[cur_tensor.recipe.func][cur_tensor_parent_index](
                cur_tensor.grad,
                cur_tensor.array,
                *cur_tensor.recipe.args,
                **cur_tensor.recipe.kwargs,
            )

## Defining nn.Parameter and nn.Module

In [225]:
class Parameter(Tensor):
    def __init__(self, t: Tensor, requires_grad=True):
        return super().__init__(t.array, requires_grad=requires_grad)

    def __repr__(self):
        return f"Parameter containing: {super().__repr__()}"

In [317]:
class Module:
    _modules: Dict[str, "Module"]
    _parameters: Dict[str, Parameter]

    def __init__(self):
        self._modules = {}
        self._parameters = {}

    def modules(self):
        '''Return the direct child modules of this module.'''
        return self.__dict__["_modules"].values()

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        '''
        Return an iterator over Module parameters.

        recurse: if True, the iterator includes parameters of submodules, recursively.
        '''
        def param_iter():
            for param in self._parameters.values():
                yield param

            if recurse:
                for submodule in self._modules.values():
                    for param in submodule.parameters(recurse=True):
                        yield param

        return param_iter()

    def __setattr__(self, key: str, val: Any) -> None:
        '''
        If val is a Parameter or Module, store it in the appropriate _parameters or _modules dict.
        Otherwise, call __setattr__ from the superclass.
        '''
        if isinstance(val, Parameter):
            self.__dict__["_parameters"][key] = val
        elif isinstance(val, Module):
            self.__dict__["_modules"][key] = val
        else:
            super().__setattr__(key, val)

    def __getattr__(self, key: str) -> Union[Parameter, "Module"]:
        '''
        If key is in _parameters or _modules, return the corresponding value.
        Otherwise, raise KeyError.
        '''
        if key in self.__dict__["_parameters"].keys():
            return self.__dict__["_parameters"][key]
        elif key in self.__dict__["_modules"].keys():
            return self.__dict__["_modules"][key]
        else:
            raise KeyError("key not in submodules or parameters")

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self):
        raise NotImplementedError("Subclasses must implement forward!")

    def __repr__(self):
        def _indent(s_, numSpaces):
            return re.sub("\n", "\n" + (" " * numSpaces), s_)
        lines = [f"({key}): {_indent(repr(module), 2)}" for key, module in self._modules.items()]
        return "".join([
            self.__class__.__name__ + "(",
            "\n  " + "\n  ".join(lines) + "\n" if lines else "", ")"
        ])

### Defining Linear and ReLU modules

In [386]:
class Linear(Module):
    def __init__(self, in_features, out_features, bias=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # TODO: implement permute and swap the order
        self.weight = Parameter(tensor(rng.standard_normal((in_features, out_features))))
        self.bias = Parameter(tensor(rng.standard_normal(out_features))) if bias is not None else None

    def forward(self, x):
        out = matmul(x, self.weight)
        if self.bias is not None:
            out = out + self.bias

        return out

In [387]:
class ReLU(Module):
    def forward(self, x: Tensor) -> Tensor:
        return relu(x)

In [388]:
class SimpleCIFAR10MLP(Module):
    def __init__(self):
        super().__init__()
        self.linear1 = Linear(3072, 200)
        self.relu1 = ReLU()
        self.linear2 = Linear(200, 200)
        self.relu2 = ReLU()
        self.output = Linear(200, 10)

    def forward(self, x):
        x = self.relu1(self.linear1(x))
        x = self.relu2(self.linear2(x))
        x = self.output(x)
        return x

## Defining our own XE Loss

In [None]:
def cross_entropy_loss(logits: Tensor, true_labels: Tensor) -> Tensor:
    # TODO: finish this
    pass

## Tests

### Calling Backward Functions directly

In [180]:
a = np.array([1, 2])
b = np.array([3, 4])
c = np.array([5, 3])
d = np.add(a, b)
e = np.log(c)
f = np.multiply(d, e)
g = np.sum(f)

dg_dg = np.ones_like(g)
dg_df = sum_back(dg_dg, g, f)
dg_de = multiply_back1(dg_df, f, d, e)
dg_dd = multiply_back0(dg_df, f, d, e)
dg_dc = log_back(dg_de, e, c)
dg_db = add_back1(dg_dd, d, a, b)
dg_da = add_back0(dg_dd, d, a, b)

In [181]:
old_g = g.copy()
a = np.array([1, 2])
b = np.array([3, 4])
c = np.array([5.01, 3])
d = np.add(a, b)
e = np.log(c)
f = np.multiply(d, e)
g = np.sum(f)

In [182]:
g - old_g

0.00799201065069255

In [183]:
dg_dc

array([0.8, 2. ])

### Forward Functions and Autograd

In [192]:
a = tensor([1, 2])
b = tensor([3, 4])
c = tensor([5, 3])
d = add(a, b)
e = log(c)
f = multiply(d, e)
g = sum(f)

g.backward()

1.0
(array([6.437752, 6.591674], dtype=float32),)
{}
<function sum_back at 0x13d0ecd60>
1.0 [0]
[1. 1.]
(array([4, 6]), array([1.609438 , 1.0986123], dtype=float32))
{}
<function multiply_back0 at 0x13d3c0ea0>
[1. 1.]
(array([4, 6]), array([1.609438 , 1.0986123], dtype=float32))
{}
<function multiply_back1 at 0x13d326980>
[1.609438  1.0986123]
(array([1, 2]), array([3, 4]))
{}
<function add_back0 at 0x13dcc6de0>
[1.609438  1.0986123]
(array([1, 2]), array([3, 4]))
{}
<function add_back1 at 0x13dcc7560>
[4. 6.]
(array([5, 3]),)
{}
<function log_back at 0x13dcc7100>


### testing the MLP

In [390]:
mlp = SimpleCIFAR10MLP()
sample_input = tensor(rng.random((5, 3072)))
sample_output = mlp(sample_input)

sample_output.backward()

[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
(array([[5.58115273e+01, 2.01142319e+02, 0.00000000e+00, 3.32242706e+02,
        0.00000000e+00, 4.11441193e+02, 0.00000000e+00, 1.73602966e+02,
        2.70570350e+00, 1.60386032e+02, 4.48463364e+01, 0.00000000e+00,
        3.17046623e+01, 1.09150528e+02, 0.00000000e+00, 3.09756165e+02,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 2.64124115e+02,
        1.93841187e+02, 3.54664337e+02, 1.81410568e+02, 9.26867447e+01,
        1.34789230e+02, 1.48257019e+02, 7.17004028e+02, 0.00000000e+00,
        0.00000000e+00, 4.43841476e+01, 1.57959610e+02, 3.76723755e+02,
        3.10669434e+02, 0.00000000e+00, 3.08302643e+02, 5.14865479e+02,
        0.00000000e+00, 1.98457611e+02, 0.00000000e+00, 1.14460510e+02,
        4.05174744e+02, 1.29503525e+02, 0.00000000e+00, 3.90764923e+02,
        0.00000000e+00, 0.00000000e+00, 0

## Defining a basic training + testing loop

In [None]:
# TODO: implement
class SGD:
    def __init__(self, params: Iterable[Parameter], lr: float):
        pass

    def zero_grad(self) -> None:
        pass

    def step(self) -> None:
        pass


def train(model: MLP, train_loader: DataLoader, optimizer: SGD, epoch: int, train_loss_list: Optional[list] = None):
    pass

def test(model: MLP, test_loader: DataLoader, test_loss_list: Optional[list] = None):
    pass

## Pulling in CIFAR-10

In [262]:
def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

In [278]:
data_batch_1 = unpickle('data/cifar-10-batches-py/data_batch_1')
data = data_batch_1[b'data']

## Training