## Tensor backbone

In [1]:
from dataclasses import dataclass
from typing import Union, List, Callable, Optional, Tuple, Literal

import numpy as np

# Plain Python numbers accepted wherever a tensor value is expected.
Scalar = Union[int, float]

# Everything Tensor.build_ndarray knows how to turn into an ndarray.
Data = Union[Scalar, list, np.ndarray, "Tensor"]


@dataclass(frozen=True)
class Leaf:
    """Immutable edge of the autograd graph.

    Links a produced tensor back to one of its inputs: ``grad_fn`` turns
    the produced tensor's upstream gradient into the contribution that is
    passed to ``value.backward``.
    """

    # Input tensor that receives the transformed gradient.
    value: "Tensor"
    # Maps the output's gradient to this input's gradient contribution.
    grad_fn: Callable[[np.ndarray], np.ndarray]


class Tensor:
    """NumPy-backed array with reverse-mode automatic differentiation.

    ``data`` holds the value as an ndarray. When ``requires_grad`` is True,
    ``grad`` accumulates dL/d(self), and every operation that consumes this
    tensor records a :class:`Leaf` on its result pointing back here.
    ``backward`` walks those links and applies each ``grad_fn`` (chain rule).
    """

    def __init__(
        self,
        data: Data,
        requires_grad: bool = False,
        dependencies: Optional[List[Leaf]] = None,
        dtype=np.float32
    ):
        """Wrap ``data`` as a tensor.

        Args:
            data: Scalar, (nested) list, ndarray or Tensor; converted to an
                ndarray of ``dtype`` by ``build_ndarray``.
            requires_grad: Whether gradients are tracked for this tensor.
            dependencies: Links to parent tensors; filled in by operations.
            dtype: NumPy dtype of the stored array (float32 by default).
        """
        self._data = Tensor.build_ndarray(data, dtype)
        self.dtype = dtype

        self.requires_grad = requires_grad
        self.dependencies = dependencies or []

        self.grad = np.zeros_like(self._data) if requires_grad else None

    @property
    def data(self) -> np.ndarray:
        """Underlying ndarray value."""
        return self._data

    @data.setter
    def data(self, data: Data):
        # Re-convert to this tensor's dtype; the old gradient no longer
        # matches the new value, so grad-tracking tensors reset it.
        self._data = Tensor.build_ndarray(data, self.dtype)
        if self.requires_grad:
            self.zero_grad()

    @property
    def size(self) -> int:
        """Total number of elements."""
        return self.data.size

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the underlying array."""
        return self.data.shape

    @property
    def ndim(self) -> int:
        """Number of dimensions."""
        return self.data.ndim

    @staticmethod
    def build_ndarray(data: Data, dtype=np.float32) -> np.ndarray:
        """Convert ``data`` (Tensor, ndarray or array-like) to an ndarray of ``dtype``."""
        if isinstance(data, Tensor):
            return np.array(data.data, dtype=dtype)
        if isinstance(data, np.ndarray):
            return data.astype(dtype)
        return np.array(data, dtype=dtype)

    @staticmethod
    def data_gate(data: Data) -> "Tensor":
        """Return ``data`` unchanged if it is a Tensor, else wrap it (no grad)."""
        if isinstance(data, Tensor):
            return data
        return Tensor(data)

    def __repr__(self):
        return f"Tensor({self.data}, requires_grad={self.requires_grad}, shape={self.shape})"

    def zero_grad(self):
        """Reset ``grad`` to zeros matching the current ``data``.

        The buffer is reallocated when it is missing or when ``data`` was
        replaced by an array of a different shape; otherwise it is zeroed
        in place.
        """
        if self.grad is None or np.shape(self.grad) != self._data.shape:
            self.grad = np.zeros_like(self._data)
        else:
            self.grad.fill(0.0)

    def backward(self, grad: Optional[np.ndarray] = None) -> None:
        """Accumulate ``grad`` into this tensor and propagate it to parents.

        Args:
            grad: Upstream gradient dL/d(self), shaped like ``data``. May be
                omitted only for 0-d tensors, where it defaults to 1.0.

        Raises:
            RuntimeError: If this tensor does not require gradients.
            ValueError: If ``grad`` is omitted for a non-scalar tensor.
        """
        if not self.requires_grad:
            raise RuntimeError(
                "Cannot call backward() on a tensor that does not require gradients. "
                "If you need gradients, ensure that requires_grad=True when creating the tensor."
            )

        if grad is None:
            if self.shape == ():
                grad = np.array(1.0)
            else:
                raise ValueError("Grad must be provided if tensor has shape")
        else:
            grad = np.asarray(grad)

        # requires_grad may have been switched on after construction.
        if self.grad is None:
            self.zero_grad()

        # 0-d + 0-d yields a NumPy scalar; keep an ndarray so that
        # zero_grad() can still fill it in place.
        self.grad = np.asarray(self.grad + grad)

        # Recurse into every parent. Each path through the graph contributes
        # its own term, and those terms sum to the full chain-rule gradient.
        for dependency in self.dependencies:
            backward_grad = dependency.grad_fn(grad)
            dependency.value.backward(backward_grad)

    def transpose(self, axes: Tuple[int, ...] = None) -> "Tensor":
        """Permute axes like ``np.transpose`` (``axes=None`` reverses them).

        The backward pass applies the inverse permutation to the gradient.
        """
        output = np.transpose(self.data, axes=axes)

        dependencies: List[Leaf] = []

        if self.requires_grad:
            def _bkwd(grad: np.ndarray) -> np.ndarray:
                if axes is None:
                    # Reversing the axes again undoes the reversal.
                    return np.transpose(grad)
                # argsort of a permutation is its inverse permutation.
                inv_axes = tuple(np.argsort(axes))
                return np.transpose(grad, axes=inv_axes)

            dependencies.append(
                Leaf(value=self, grad_fn=_bkwd)
            )

        return Tensor(
            output,
            requires_grad=self.requires_grad,
            dependencies=dependencies
        )

    @property
    def T(self):
        """Shorthand for ``transpose()`` (all axes reversed)."""
        return self.transpose()

    @staticmethod
    def matmul(a: "Tensor", b: "Tensor") -> "Tensor":
        r"""
        Static method to perform matrix multiplication of two tensors.

        Follows ``np.matmul`` semantics: a 1-D operand acts as a vector, and
        operands with more than two dimensions are stacks of matrices whose
        leading (batch) dimensions broadcast. Gradients are summed back over
        any broadcast batch dimensions so they match each operand's shape.

        Args:
            a (Tensor): First matrix.
            b (Tensor): Second matrix.

        Returns:
            Tensor: Resulting tensor with tracked dependencies.
        """
        output = a.data @ b.data
        requires_grad = a.requires_grad or b.requires_grad
        dependencies = []

        if a.requires_grad:
            def _bkwd_a(grad: np.ndarray) -> np.ndarray:
                if b.ndim == 1:
                    # b's only axis was contracted: dL/da[..., k] = grad[...] * b[k].
                    grad_a = np.expand_dims(grad, -1) * b.data
                elif a.ndim == 1:
                    # a acted as a (1, n) row; restore that axis, then drop it.
                    grad_a = (np.expand_dims(grad, -2) @ b.data.swapaxes(-1, -2))[..., 0, :]
                else:
                    grad_a = grad @ b.data.swapaxes(-1, -2)
                # Sum over batch dimensions that were broadcast for `a`.
                return Tensor.bkwd_broadcast(a)(grad_a)

            dependencies.append(
                Leaf(
                    value=a,
                    grad_fn=_bkwd_a
                )
            )

        if b.requires_grad:
            def _bkwd_b(grad: np.ndarray) -> np.ndarray:
                if b.ndim == 1:
                    if a.ndim == 1:
                        # Inner product: dL/db = grad * a.
                        grad_b = grad * a.data
                    else:
                        # b acted as an (n, 1) column; grad has shape a.shape[:-1].
                        grad_b = (a.data.swapaxes(-1, -2) @ np.expand_dims(grad, -1))[..., 0]
                elif a.ndim == 1:
                    # Outer product of a with each gradient row.
                    grad_b = np.expand_dims(a.data, -1) * np.expand_dims(grad, -2)
                else:
                    grad_b = a.data.swapaxes(-1, -2) @ grad
                # Sum over batch dimensions that were broadcast for `b`.
                return Tensor.bkwd_broadcast(b)(grad_b)

            dependencies.append(
                Leaf(
                    value=b,
                    grad_fn=_bkwd_b
                )
            )

        return Tensor(output, requires_grad, dependencies)

    def dot(self, other: Data) -> "Tensor":
        """Matrix product ``self @ other`` (non-Tensors are wrapped first)."""
        return Tensor.matmul(self, Tensor.data_gate(other))

    def __matmul__(self, other: Data) -> "Tensor":
        return self.dot(other)

    @staticmethod
    def bkwd_broadcast(tensor: "Tensor"):
        """Build a grad_fn that sums a gradient back to ``tensor``'s shape.

        Undoes NumPy broadcasting: leading axes that were prepended are
        summed away, and axes where ``tensor`` had length 1 are summed with
        ``keepdims=True``.
        """
        def _bkwd(grad: np.ndarray) -> np.ndarray:
            if tensor.ndim == 0:
                return np.sum(grad)

            if grad.ndim == 0:
                return grad

            ndim_added = max(0, grad.ndim - tensor.ndim)

            if ndim_added > 0:
                grad = grad.sum(axis=tuple(range(ndim_added)), keepdims=False)

            # `!= 1` (not `> 1`) also catches axes broadcast to length 0,
            # which must reduce to length 1 instead of failing in reshape.
            reduce_axes = tuple(
                dim for dim in range(tensor.ndim)
                if tensor.shape[dim] == 1 and grad.shape[dim] != 1
            )

            if reduce_axes:
                grad = grad.sum(axis=reduce_axes, keepdims=True)

            if grad.shape != tensor.shape:
                grad = grad.reshape(tensor.shape)

            return grad

        return _bkwd

    @staticmethod
    def add(a: "Tensor", b: "Tensor") -> "Tensor":
        """Element-wise ``a + b``; each operand's gradient is the upstream
        gradient reduced to its own shape."""
        output = a.data + b.data

        requires_grad = a.requires_grad or b.requires_grad

        dependencies = []

        if a.requires_grad:
            dependencies.append(
                Leaf(
                    value=a,
                    grad_fn=Tensor.bkwd_broadcast(a)
                )
            )

        if b.requires_grad:
            dependencies.append(
                Leaf(
                    value=b,
                    grad_fn=Tensor.bkwd_broadcast(b)
                )
            )

        return Tensor(output, requires_grad, dependencies)

    def __add__(self, other: Data) -> "Tensor":
        return Tensor.add(self, Tensor.data_gate(other))

    def __radd__(self, other: Data) -> "Tensor":
        return Tensor.add(Tensor.data_gate(other), self)

    def __iadd__(self, other: Data) -> "Tensor":
        # In-place value update: not tracked by autograd, and the `data`
        # setter resets `grad` for grad-tracking tensors.
        self.data = self.data + Tensor.build_ndarray(other)
        return self

    def __neg__(self) -> "Tensor":
        output = -self.data
        dependencies = []

        if self.requires_grad:
            dependencies.append(
                Leaf(value=self, grad_fn=lambda grad: -grad)
            )

        return Tensor(output, self.requires_grad, dependencies)

    def __sub__(self, other: Data) -> "Tensor":
        return self + (-Tensor.data_gate(other))

    def __rsub__(self, other: Data) -> "Tensor":
        return Tensor.data_gate(other) + (-self)

    def __isub__(self, other: Data) -> "Tensor":
        # In-place, untracked; resets `grad` via the `data` setter.
        self.data = self.data - Tensor.build_ndarray(other)
        return self

    @staticmethod
    def mul(a: "Tensor", b: "Tensor") -> "Tensor":
        """Element-wise ``a * b``: d/da = b and d/db = a, each reduced to the
        operand's own shape."""
        output = a.data * b.data

        requires_grad = a.requires_grad or b.requires_grad
        dependencies = []

        def _backward(target: "Tensor", other: "Tensor"):
            def _bkwd(grad: np.ndarray) -> np.ndarray:
                # Multiply by the raw ndarray. `grad * other` would hand a
                # Tensor to NumPy, which wraps it in an object array and
                # yields an array of Tensors instead of numbers.
                return Tensor.bkwd_broadcast(target)(grad * other.data)
            return _bkwd

        if a.requires_grad:
            dependencies.append(
                Leaf(
                    value=a,
                    grad_fn=_backward(a, b)
                )
            )

        if b.requires_grad:
            dependencies.append(
                Leaf(
                    value=b,
                    grad_fn=_backward(b, a)
                )
            )

        return Tensor(output, requires_grad, dependencies)

    def __mul__(self, other: Data) -> "Tensor":
        return Tensor.mul(self, Tensor.data_gate(other))

    def __rmul__(self, other: Data) -> "Tensor":
        return Tensor.mul(Tensor.data_gate(other), self)

    def __imul__(self, other: Data) -> "Tensor":
        # In-place, untracked; resets `grad` via the `data` setter.
        self.data = self.data * Tensor.build_ndarray(other)
        return self

    def log(self) -> "Tensor":
        """Natural logarithm; d/dx log(x) = 1 / x."""
        output = np.log(self.data)

        dependencies = []

        if self.requires_grad:
            def _bkwd(grad: np.ndarray) -> np.ndarray:
                return grad / self.data

            dependencies.append(
                Leaf(
                    value=self,
                    grad_fn=_bkwd
                )
            )

        return Tensor(output, self.requires_grad, dependencies)

    def tanh(self) -> "Tensor":
        """Hyperbolic tangent; d/dx tanh(x) = 1 - tanh(x)**2."""
        output = np.tanh(self.data)

        dependencies = []

        if self.requires_grad:
            def _bkwd(grad: np.ndarray) -> np.ndarray:
                # Reuse the forward result instead of recomputing tanh.
                return grad * (1 - output**2)

            dependencies.append(
                Leaf(
                    value=self,
                    grad_fn=_bkwd
                )
            )

        return Tensor(output, self.requires_grad, dependencies)

    def pow(self, p: Union[int, float]) -> "Tensor":
        """Element-wise ``x ** p`` for a constant exponent; d/dx = p * x**(p-1)."""
        output = self.data**p

        dependencies = []

        if self.requires_grad:
            def _bkwd(grad: np.ndarray) -> np.ndarray:
                return grad * (p * (self.data**(p - 1)))

            dependencies.append(
                Leaf(
                    value=self,
                    grad_fn=_bkwd
                )
            )

        return Tensor(output, self.requires_grad, dependencies)

    def __pow__(self, p: Union[int, float]) -> "Tensor":
        return self.pow(p)

    def __truediv__(self, other: Data) -> "Tensor":
        # a / b is computed as a * b**-1 so existing backward rules apply.
        other = Tensor.data_gate(other)
        return self * (other**-1)

    def __rtruediv__(self, other: Data) -> "Tensor":
        other = Tensor.data_gate(other)
        return other * (self**-1)

    def __itruediv__(self, other: Data) -> "Tensor":
        # In-place, untracked; resets `grad` via the `data` setter.
        self.data = self.data / Tensor.build_ndarray(other)
        return self

    def exp(self) -> "Tensor":
        """Element-wise e**x; its derivative is the output itself."""
        output = np.exp(self.data)

        dependencies = []

        if self.requires_grad:
            def _bkwd(grad: np.ndarray) -> np.ndarray:
                return grad * output

            dependencies.append(
                Leaf(
                    value=self,
                    grad_fn=_bkwd
                )
            )

        return Tensor(output, self.requires_grad, dependencies)

    def squeeze(self, axis: Optional[Union[int, Tuple[int]]] = None) -> "Tensor":
        """Remove length-1 axes (all of them, or only ``axis``)."""
        output = np.squeeze(self.data, axis=axis)

        dependencies = []

        if self.requires_grad:
            def _bkwd(grad: np.ndarray) -> np.ndarray:
                if axis is None:
                    return grad.reshape(self.shape)
                return np.expand_dims(grad, axis=axis)

            dependencies.append(Leaf(value=self, grad_fn=_bkwd))

        return Tensor(output, self.requires_grad, dependencies)

    def unsqueeze(self, dim: int) -> "Tensor":
        """Insert a length-1 axis at position ``dim``."""
        output = np.expand_dims(self.data, axis=dim)

        dependencies = []

        if self.requires_grad:
            def _bkwd(grad: np.ndarray) -> np.ndarray:
                return np.squeeze(grad, axis=dim)

            dependencies.append(Leaf(value=self, grad_fn=_bkwd))

        return Tensor(output, self.requires_grad, dependencies)

    def view(self, shape: Tuple[int, ...]) -> "Tensor":
        """Reshape to ``shape``; the gradient is reshaped back."""
        output = self.data.reshape(shape)
        dependencies = []

        if self.requires_grad:
            def _bkwd(grad: np.ndarray) -> np.ndarray:
                return grad.reshape(self.shape)

            dependencies.append(Leaf(value=self, grad_fn=_bkwd))

        return Tensor(output, self.requires_grad, dependencies)

    # Comparison operators return untracked float32 tensors of 0.0/1.0
    # (the default Tensor dtype). Defining __eq__ makes Tensor unhashable.
    def __lt__(self, other: Data) -> "Tensor":
        other = Tensor.data_gate(other)
        return Tensor(self.data < other.data)

    def __gt__(self, other: Data) -> "Tensor":
        other = Tensor.data_gate(other)
        return Tensor(self.data > other.data)

    def __eq__(self, other: Data) -> "Tensor":
        other = Tensor.data_gate(other)
        return Tensor(self.data == other.data)

    def __le__(self, other: Data) -> "Tensor":
        other = Tensor.data_gate(other)
        return Tensor(self.data <= other.data)

    def __ge__(self, other: Data) -> "Tensor":
        other = Tensor.data_gate(other)
        return Tensor(self.data >= other.data)

    def __ne__(self, other: Data) -> "Tensor":
        other = Tensor.data_gate(other)
        return Tensor(self.data != other.data)

    @staticmethod
    def where(condition: "Tensor", a: "Tensor", b: "Tensor") -> "Tensor":
        """Pick ``a`` where ``condition`` is truthy, else ``b`` (broadcasting).

        Gradients flow only to the selected branch and are summed back to
        each operand's shape, so broadcast operands (e.g. scalars) get
        correctly shaped gradients.
        """
        output = np.where(condition.data, a.data, b.data)

        requires_grad = a.requires_grad or b.requires_grad
        dependencies = []

        if a.requires_grad:
            def _bkwd_a(grad: np.ndarray) -> np.ndarray:
                return Tensor.bkwd_broadcast(a)(np.where(condition.data, grad, 0.0))

            dependencies.append(Leaf(value=a, grad_fn=_bkwd_a))

        if b.requires_grad:
            def _bkwd_b(grad: np.ndarray) -> np.ndarray:
                return Tensor.bkwd_broadcast(b)(np.where(condition.data, 0.0, grad))

            dependencies.append(Leaf(value=b, grad_fn=_bkwd_b))

        return Tensor(output, requires_grad, dependencies)

    @staticmethod
    def maximum(a: Data, b: Data) -> "Tensor":
        """Element-wise max; on ties the value (and gradient) comes from ``b``."""
        a, b = Tensor.data_gate(a), Tensor.data_gate(b)

        return Tensor.where(a > b, a, b)

    @staticmethod
    def minimum(a: Data, b: Data) -> "Tensor":
        """Element-wise min; on ties the value (and gradient) comes from ``b``."""
        a, b = Tensor.data_gate(a), Tensor.data_gate(b)

        return Tensor.where(a < b, a, b)

    def threshold(self, threshold: float, value: float) -> "Tensor":
        """Keep elements greater than ``threshold``; replace the rest with ``value``."""
        return Tensor.where(self > threshold, self, Tensor(value))

    def masked_fill(self, mask: "Tensor", value: float) -> "Tensor":
        """Replace elements where ``mask`` is truthy with ``value``."""
        return Tensor.where(mask, Tensor(value), self)

    def sign(self) -> "Tensor":
        """Element-wise sign (1, -1 or 0); carries no gradient."""
        return Tensor.where(
            self > 0, Tensor(1),
            Tensor.where(self < 0, Tensor(-1), Tensor(0))
        )

    def clip(self, min_value: Optional[float] = None, max_value: Optional[float] = None) -> "Tensor":
        """Clamp values to ``[min_value, max_value]``; ``None`` disables a bound.

        Both conditions test the original values, and the lower bound takes
        precedence when the bounds are inverted.
        """
        result = self
        if max_value is not None:
            result = Tensor.where(self > max_value, Tensor(max_value), result)
        if min_value is not None:
            result = Tensor.where(self < min_value, Tensor(min_value), result)
        return result

    def __getitem__(self, index: Union[int, slice, List[int], Tuple[int, ...], np.ndarray, "Tensor"]) -> "Tensor":
        """NumPy-style indexing; the gradient is scattered back with ``np.add.at``."""
        # Unwrap Tensor indices. ndarray indices (integer or boolean) are used
        # as-is: wrapping them in a Tensor would cast them to float32, which
        # NumPy rejects as an index.
        if isinstance(index, Tensor):
            index = index.data

        output = self.data[index]
        dependencies = []

        if self.requires_grad:
            def _bkwd(grad: np.ndarray) -> np.ndarray:
                full_grad = np.zeros_like(self.data)
                # add.at accumulates correctly when an index repeats.
                np.add.at(full_grad, index, grad)
                return full_grad
            dependencies.append(Leaf(value=self, grad_fn=_bkwd))

        return Tensor(output, self.requires_grad, dependencies)

    def abs(self) -> "Tensor":
        """Absolute value; the gradient at exactly 0 is taken as +1."""
        return Tensor.where(self >= 0, self, -self)

    def bkwd_minmax(
        self,
        output: np.ndarray,
        axis: Optional[Union[int, Tuple[int, ...]]] = None,
        keepdims: bool = False
    ) -> Callable[[np.ndarray], np.ndarray]:
        """Build the grad_fn shared by ``min`` and ``max``.

        The gradient goes to every element equal to the reduced value and is
        split evenly among ties.
        """
        # Restore the reduced axes so the extremum lines up with self.data.
        # Without this, e.g. a (3, 3) input reduced over axis=1 would be
        # compared against the wrong positions by broadcasting.
        output_expanded = output if keepdims or axis is None \
            else np.expand_dims(output, axis=axis)

        def _bkwd(grad: np.ndarray) -> np.ndarray:
            mask = (self.data == output_expanded)

            count = np.sum(mask) if axis is None \
                else np.sum(mask, axis=axis, keepdims=True)

            grad_expanded = grad if keepdims or axis is None \
                else np.expand_dims(grad, axis=axis)

            return mask * (grad_expanded / count)

        return _bkwd

    def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> "Tensor":
        """Minimum over ``axis`` (all elements if None)."""
        output = np.min(self.data, axis=axis, keepdims=keepdims)
        dependencies = []

        if self.requires_grad:
            dependencies.append(
                Leaf(
                    value=self,
                    grad_fn=self.bkwd_minmax(output, axis, keepdims)
                )
            )

        return Tensor(output, self.requires_grad, dependencies)

    def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> "Tensor":
        """Maximum over ``axis`` (all elements if None)."""
        output = np.max(self.data, axis=axis, keepdims=keepdims)
        dependencies = []

        if self.requires_grad:
            dependencies.append(
                Leaf(
                    value=self,
                    grad_fn=self.bkwd_minmax(output, axis, keepdims)
                )
            )

        return Tensor(output, self.requires_grad, dependencies)

    def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> "Tensor":
        """Sum over ``axis`` (all elements if None); the gradient is broadcast back."""
        output = np.sum(self.data, axis=axis, keepdims=keepdims)
        dependencies = []

        if self.requires_grad:
            def _bkwd(grad: np.ndarray) -> np.ndarray:
                full_grad = np.ones_like(self.data)

                if axis is None:
                    return full_grad * grad

                # Re-insert the reduced axes so grad broadcasts over them.
                grad_expanded = grad if keepdims else np.expand_dims(grad, axis=axis)

                return full_grad * grad_expanded

            dependencies.append(
                Leaf(
                    value=self,
                    grad_fn=_bkwd
                )
            )

        return Tensor(output, self.requires_grad, dependencies)

    def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> "Tensor":
        """Arithmetic mean over ``axis`` (an int, a tuple of ints, or None for all)."""
        if axis is None:
            count = self.size
        else:
            axes = axis if isinstance(axis, tuple) else (axis,)
            count = int(np.prod([self.shape[ax] for ax in axes]))
        return self.sum(axis=axis, keepdims=keepdims) / count


In [None]:
# Autograd smoke test. `b` has shape (1,) and is broadcast against the 2x2
# `a`, so its gradient must be summed back to shape (1,) by bkwd_broadcast.
a = Tensor([[1, 2],
            [3, 4]], requires_grad=True)
b = Tensor([10], requires_grad=True)

c = a + b

# Exercises sub/neg, mul and their broadcasting gradients.
d = c - a * b + c * a * b

# Chain of log -> tanh -> pow -> exp -> div exercises the unary backward rules.
q = c / (d.log().tanh() ** 3).exp() / 2 + c

# q is non-scalar (same shape as d), so an explicit upstream gradient is required.
q.backward(np.ones_like(d.data))

a.grad, b.grad


## Chain Rule

$$
\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx}
$$

If we have a function composition:  

$$
f(x) = g(h(x))
$$

Then, by the chain rule:

$$
f'(x) = g'(h(x)) \cdot h'(x)
$$

In [None]:
t = Tensor([1, 2, 3], requires_grad=True)
# Reassigning `data` goes through the setter, which converts the list to an
# ndarray and calls zero_grad() because requires_grad is True.
t.data = [[1, 3, 5], [2, 3, 4]]
t_T = t.T  # all axes reversed: (2, 3) -> (3, 2)

# t_T is not a scalar, so backward needs an explicit upstream gradient.
t_T.backward(np.ones_like(t_T.data))

In [None]:
# Accepted values for Parameter's `init_method` argument.
InitMethod = Literal["xavier", "he", "he_leaky", "normal", "uniform"]


class Parameter(Tensor):
    """Trainable tensor (always ``requires_grad=True``) with random init.

    Args:
        *shape: Dimensions of the parameter; used only when ``data`` is None.
        data: Explicit initial value; skips random initialization.
        init_method: Initialization scheme (see ``_init``).
        gain: Scale multiplier applied to the initialization.
        alpha: Negative slope assumed by ``"he_leaky"``.

    Weights follow the ``(out_features, in_features, ...)`` layout used by
    ``Linear`` (which computes ``x @ W.T``), so fan-in is read from axis 1.
    """

    def __init__(
        self,
        *shape: int,
        data: Optional[np.ndarray] = None,
        init_method: InitMethod = "xavier",
        gain: float = 1.0,
        alpha: float = 0.01,
    ):
        if data is None:
            data = self._init(shape, init_method, gain, alpha)

        super().__init__(data=data, requires_grad=True)

    @staticmethod
    def _fans(shape: Tuple[int, ...]) -> Tuple[int, int]:
        """Return ``(fan_in, fan_out)`` for a parameter of ``shape``."""
        if len(shape) == 0:
            return 1, 1
        if len(shape) == 1:
            # Vectors (e.g. biases) have no in/out split; use their length.
            return shape[0], shape[0]
        # Trailing axes beyond (out, in), e.g. kernel sizes, scale both fans.
        receptive_field = int(np.prod(shape[2:]))
        return shape[1] * receptive_field, shape[0] * receptive_field

    def _init(
        self,
        shape: Tuple[int, ...],
        init_method: InitMethod = "xavier",
        gain: float = 1.0,
        alpha: float = 0.01
    ):
        """Draw initial values for ``shape`` according to ``init_method``.

        - ``"xavier"``: normal, std = gain * sqrt(2 / (fan_in + fan_out)) (Glorot).
        - ``"he"``: normal, std = gain * sqrt(2 / fan_in).
        - ``"he_leaky"``: normal, std = gain * sqrt(2 / ((1 + alpha**2) * fan_in)).
        - ``"normal"``: normal, std = gain.
        - ``"uniform"``: gain * U[-1, 1).

        For 1-D shapes fan_in == fan_out, so ``"xavier"`` reduces to
        gain * sqrt(1 / n).

        Raises:
            ValueError: If ``init_method`` is not one of the above.
        """
        fan_in, fan_out = self._fans(shape)

        if init_method == "xavier":
            std = gain * np.sqrt(2.0 / (fan_in + fan_out))
        elif init_method == "he":
            std = gain * np.sqrt(2.0 / fan_in)
        elif init_method == "he_leaky":
            std = gain * np.sqrt(2.0 / ((1 + alpha**2) * fan_in))
        elif init_method == "normal":
            std = gain
        elif init_method == "uniform":
            return gain * np.random.uniform(-1, 1, size=shape)
        else:
            raise ValueError(f"Unknown initialization method: {init_method}")

        return std * np.random.randn(*shape)

In [None]:
from typing import List

class Module:
    """Base class for layers and models.

    Subclasses implement ``forward``; calling the module runs it.
    Parameters are discovered by scanning instance attributes.
    """

    def __call__(self, *args, **kwds) -> Tensor:
        return self.forward(*args, **kwds)

    def forward(self, *args, **kwds):
        """Compute the module's output; subclasses must override this."""
        raise NotImplementedError()

    @staticmethod
    def _gather_parameters(item) -> List[Parameter]:
        """Return the parameters reachable from one attribute value."""
        if isinstance(item, Parameter):
            return [item]
        if isinstance(item, Module):
            return item.parameters()
        if isinstance(item, dict):
            item = list(item.values())
        if isinstance(item, (list, tuple)):
            return [param for sub in item for param in Module._gather_parameters(sub)]
        return []

    def parameters(self) -> List[Parameter]:
        r"""
        Returns a list of all parameters in the module and its submodules.

        Attributes holding a Parameter, a Module, or a list/tuple/dict of
        those are searched recursively. A parameter reachable through more
        than one attribute is listed once (first occurrence wins), so
        optimizers and ``params_count`` do not double count shared weights.
        """
        params = []
        seen = set()
        for item in self.__dict__.values():
            for param in Module._gather_parameters(item):
                if id(param) not in seen:
                    seen.add(id(param))
                    params.append(param)
        return params

    def zero_grad(self) -> None:
        r"""
        Zeroes the gradients of all parameters in the module and its submodules.
        """
        for param in self.parameters():
            param.zero_grad()

    def params_count(self) -> int:
        """Total number of scalar elements across all parameters."""
        return sum(param.size for param in self.parameters())


class Sequential(Module):
    """Container that applies its child modules one after another."""

    def __init__(self, *modules: Module):
        self.modules = modules

    def parameters(self) -> List[Parameter]:
        r"""
        Concatenate the parameters of every child module, in child order.
        """
        return [param for child in self.modules for param in child.parameters()]

    def forward(self, x):
        r"""
        Feed ``x`` through each child in turn and return the final output.
        """
        output = x
        for child in self.modules:
            output = child(output)
        return output
    

class DummyModule(Module):
    """Toy module that owns one 1-D parameter and only transposes its input."""

    def __init__(self, dims: int):
        super().__init__()

        # Store the size before creating the parameter (same order as before).
        self.dims = dims
        self.param = Parameter(dims)  # shape (dims,); counted by params_count()

    def forward(self, x: Tensor):
        """Return ``x`` with its axes reversed; ``self.param`` is not used."""
        transposed = x.transpose()
        return transposed

In [None]:
# Each DummyModule owns a 1-D Parameter of `dims` elements, so the total
# reported by params_count() is 10 + 100 + 1000 = 1110.
model = Sequential(
    DummyModule(10),
    DummyModule(100),
    DummyModule(1000),
)

model.params_count()

## Linear Layer: Matrix-Matrix Dot Product  

At layer $i$, the transformation is defined as:  

$$A_i(\mathbf{X}) = \mathbf{X} \mathbf{W}_i^T + \mathbf{B}_i$$

For a single layer:  

$$F_i(\mathbf{X}) = \sigma(A_i(\mathbf{X}))$$

where $A_i(\mathbf{X})$ is the linear transformation at layer $i$.  

A deep neural network applies these transformations layer by layer, leading to the final output:  

$$F(\mathbf{X}) = \sigma(A_L(\sigma(A_{L-1}(\dots \sigma(A_1(\mathbf{X})) \dots ))))$$

Using **functional composition**, this process is compactly written as:  

$$F(\mathbf{X}) = \sigma \circ A_L \circ \sigma \circ A_{L-1} \circ \dots \circ \sigma \circ A_1 (\mathbf{X})$$

In [None]:
class Linear(Module):
    """Fully connected layer computing ``x @ W.T + b``.

    Args:
        in_features: Size of the input's last axis.
        out_features: Size of the output's last axis.
        bias: Whether to add a learnable bias of shape ``(out_features,)``.
        init_method: Initialization scheme for the weight matrix.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        init_method: InitMethod = "xavier",
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Stored as (out_features, in_features), hence the W.T in forward.
        self.weights = Parameter(out_features, in_features, init_method=init_method)
        # Small random bias (normal, std 0.01) instead of zeros.
        self.bias = Parameter(out_features, init_method="normal", gain=0.01) if bias else None

    def forward(self, x: Tensor):
        """Apply the affine map to ``x``.

        Args:
            x: Tensor or array-like with 2 or 3 dimensions whose last axis
                has length ``in_features``.

        Returns:
            Tensor whose last axis has length ``out_features``.

        Raises:
            ValueError: If ``x`` is not 2-D/3-D or its last axis does not
                match ``in_features``.
        """
        # Accept raw ndarrays/lists too: `ndarray @ Tensor` would go through
        # NumPy's matmul, which cannot handle a Tensor operand.
        x = Tensor.data_gate(x)

        if x.ndim not in (2, 3):
            raise ValueError(f"Input must be 2D or 3D Tensor! x.ndim={x.ndim}")

        if x.shape[-1] != self.in_features:
            raise ValueError(
                f"Last dimension of input: {x.shape[-1]} does not match in_features: {self.in_features}"
            )

        output = x @ self.weights.T

        if self.bias is not None:
            # The bias broadcasts over the leading axes; its gradient is
            # summed back to (out_features,) by bkwd_broadcast.
            output = output + self.bias

        return output


## Dot product

**Backward Pass (Gradients Computation)**

Compute $\frac{\partial L}{\partial A}$ and $\frac{\partial L}{\partial B}$ using the chain rule.

**Gradient w.r.t. A**
The gradient of the loss $L$ with respect to $A$ is given by:

$$\frac{\partial L}{\partial A} = \frac{\partial L}{\partial Z} \times B^T$$

```python
if a.requires_grad:
    def _bkwd(grad: np.ndarray) -> np.ndarray:
        if b.ndim > 1:
            return grad @ b.data.swapaxes(-1, -2)  # grad * B^T
        return np.outer(grad, b.data.T).squeeze()  # Handles 1D case
```

- If $B$ is 2D, we use `b.data.swapaxes(-1, -2)` to compute $B^T$.
- If $B$ is 1D, we use `np.outer(grad, b.data.T)` to ensure correct shape.


**Gradient w.r.t. B**

The gradient of the loss $L$ with respect to $B$ is given by:

$$\frac{\partial L}{\partial B} = A^T \times \frac{\partial L}{\partial Z}$$

Where $A^T$ is the **transpose of A**.

This is implemented as:

```python
if b.requires_grad:
    def _bkwd(grad: np.ndarray) -> np.ndarray:
        if a.ndim > 1:
            return a.data.swapaxes(-1, -2) @ grad  # A^T * grad
        return np.outer(a.data.T, grad).squeeze()  # Handles 1D case
```

- If $A$ is 2D, we use `a.data.swapaxes(-1, -2)` to compute $A^T$.
- If $A$ is 1D, we use `np.outer(a.data.T, grad)`.


**Why Do We Use `swapaxes(-1, -2)` Instead of `.T`?**

`swapaxes(-1, -2)` is a **general approach** for transposing the last two dimensions. This ensures compatibility with **both 2D matrices and higher-dimensional tensors** (e.g., batches of matrices).

- `.T` reverses **all** axes, so it is a matrix transpose only for 2D arrays; on higher-dimensional tensors it also moves the batch axes.
- `swapaxes(-1, -2)` **preserves batch and other leading dimensions**, modifying only the last two.

Example:

| Shape of Tensor | `.T` Output | `swapaxes(-1, -2)` Output |
|----------------|------------|---------------------------|
| `(m, n)` | `(n, m)` | `(n, m)` |
| `(batch, m, n)` | `(n, m, batch)` (incorrect) | `(batch, n, m)` (correct) |
| `(batch, time, m, n)` | `(n, m, time, batch)` (incorrect) | `(batch, time, n, m)` (correct) |


Matrix multiplication follows the chain rule: the backward pass computes the gradients of both $A$ and $B$ using transposes, and `swapaxes(-1, -2)` generalizes this to higher-dimensional (batched) inputs.

| Tensor  | Gradient Formula | Code Implementation |
|---------|-----------------|----------------------|
| $A$ | $\frac{\partial L}{\partial A} = \frac{\partial L}{\partial Z} \times B^T$ | `grad @ b.data.swapaxes(-1, -2)` |
| $B$ | $\frac{\partial L}{\partial B} = A^T \times \frac{\partial L}{\partial Z}$ | `a.data.swapaxes(-1, -2) @ grad` |


In [None]:
def matmul(a: "Tensor", b: "Tensor") -> "Tensor":
    r"""
    Perform matrix multiplication of two tensors with gradient tracking.

    Follows ``np.matmul`` semantics: a 1-D operand acts as a vector, and
    operands with more than two dimensions are stacks of matrices whose
    leading (batch) dimensions broadcast. Gradients are summed back over
    any broadcast batch dimensions so they match each operand's shape.

    Args:
        a (Tensor): First matrix.
        b (Tensor): Second matrix.

    Returns:
        Tensor: Resulting tensor with tracked dependencies.
    """
    output = a.data @ b.data
    requires_grad = a.requires_grad or b.requires_grad
    dependencies = []

    if a.requires_grad:
        def _bkwd_a(grad: np.ndarray) -> np.ndarray:
            if b.ndim == 1:
                # b's only axis was contracted: dL/da[..., k] = grad[...] * b[k].
                grad_a = np.expand_dims(grad, -1) * b.data
            elif a.ndim == 1:
                # a acted as a (1, n) row; restore that axis, then drop it.
                grad_a = (np.expand_dims(grad, -2) @ b.data.swapaxes(-1, -2))[..., 0, :]
            else:
                grad_a = grad @ b.data.swapaxes(-1, -2)
            # Sum over batch dimensions that were broadcast for `a`.
            return Tensor.bkwd_broadcast(a)(grad_a)

        dependencies.append(
            Leaf(
                value=a,
                grad_fn=_bkwd_a
            )
        )

    if b.requires_grad:
        def _bkwd_b(grad: np.ndarray) -> np.ndarray:
            if b.ndim == 1:
                if a.ndim == 1:
                    # Inner product: dL/db = grad * a.
                    grad_b = grad * a.data
                else:
                    # b acted as an (n, 1) column; grad has shape a.shape[:-1].
                    grad_b = (a.data.swapaxes(-1, -2) @ np.expand_dims(grad, -1))[..., 0]
            elif a.ndim == 1:
                # Outer product of a with each gradient row.
                grad_b = np.expand_dims(a.data, -1) * np.expand_dims(grad, -2)
            else:
                grad_b = a.data.swapaxes(-1, -2) @ grad
            # Sum over batch dimensions that were broadcast for `b`.
            return Tensor.bkwd_broadcast(b)(grad_b)

        dependencies.append(
            Leaf(
                value=b,
                grad_fn=_bkwd_b
            )
        )

    return Tensor(output, requires_grad, dependencies)

## More operations: `bkwd_broadcast`

In [None]:
import numpy as np

# Define array A with shape (3, 1)
A = np.array([
    [1],
    [2],
    [3],
])
print(f"Array A shape: {A.shape}")

# Define array B with shape (1, 4)
B = np.array([
    [1, 2, 3, 4],
])
print(f"Array B shape: {B.shape}")

# Broadcasting addition: each length-1 axis is stretched to match the other
# operand, so (3, 1) + (1, 4) produces a (3, 4) result.
result = A + B

print("A + B result: ")
print(result)

print(f"Result of A + B shape: {result.shape}")


The `bkwd_broadcast` helper sums the incoming gradient over every axis that was broadcast in the forward pass, so each operand receives a gradient with its own original shape during `backward`.

In [None]:
def bkwd_broadcast(tensor: "Tensor"):
    """Build a grad_fn that sums a gradient back to ``tensor``'s shape.

    Undoes NumPy broadcasting in two steps: leading axes that broadcasting
    prepended are summed away, then axes where ``tensor`` had length 1 are
    summed with ``keepdims=True``.

    Args:
        tensor: Operand whose ``shape``/``ndim`` the gradient must match.

    Returns:
        A function mapping the output gradient to ``tensor``'s gradient.
    """
    def _bkwd(grad: np.ndarray) -> np.ndarray:
        if tensor.ndim == 0:
            return np.sum(grad)

        if grad.ndim == 0:
            return grad

        ndim_added = max(0, grad.ndim - tensor.ndim)

        if ndim_added > 0:
            grad = grad.sum(axis=tuple(range(ndim_added)), keepdims=False)

        # `!= 1` (not `> 1`) also catches axes broadcast to length 0, which
        # must reduce to length 1 instead of failing in the reshape below.
        reduce_axes = tuple(
            dim for dim in range(tensor.ndim)
            if tensor.shape[dim] == 1 and grad.shape[dim] != 1
        )

        if reduce_axes:
            grad = grad.sum(axis=reduce_axes, keepdims=True)

        if grad.shape != tensor.shape:
            grad = grad.reshape(tensor.shape)

        return grad

    return _bkwd

### Examples of backward broadcasting

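In **Scenario 1**, `b` has shape `(1,)`, so it is broadcast across both axes of `a`. Its gradient is the sum of `grad_c` over both axes `(0, 1)`. With `keepdims=False`, as in the cell below, that sum is a 0-d scalar. `bkwd_broadcast` instead keeps a length-1 axis, so the result has `b`'s original shape `(1,)`.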


In [None]:
a = np.array([[1, 2],
              [3, 4]])  # Shape: (2, 2)

b = np.array([10])      # Shape: (1,)  (broadcast across both axes)

c = a + b
print(f"c: {c}")

grad_c = np.ones_like(c)
print(f"grad_c: {grad_c}")

# Since `a` was not broadcast, the gradient just passes through
grad_a = grad_c
print(f"grad_a: {grad_a}")

# `b` was broadcast along both axes, so its gradient is the sum of grad_c
# over axes (0, 1). With keepdims=False this yields a 0-d scalar (4);
# bkwd_broadcast additionally keeps a length-1 axis so the result has
# `b`'s original shape (1,).
grad_b = grad_c.sum(axis=(0, 1), keepdims=False)
print(f"grad_b: {grad_b}")


In **Scenario 2**, `b` has shape `(2, 1)`, so it is broadcast along axis `1` to match `a`'s shape `(2, 2)`.

We **sum over axis 1** (`keepdims=True`) to restore `b`'s original shape `(2,1)`.  

In [None]:
a = np.array([[1, 2],
              [3, 4]])  # Shape: (2, 2)

b = np.array([[10],
              [20]])    # Shape: (2, 1)  (broadcast across axis 1)

c = a + b
print(f"c: {c}")

grad_c = np.ones_like(c)
print(f"grad_c: {grad_c}")

# Since `a` was not broadcast, the gradient just passes through
grad_a = grad_c
print(f"grad_a: {grad_a}")

# `b` was broadcast along axis 1, so we sum over that axis; keepdims=True
# preserves `b`'s original shape (2, 1).
grad_b = grad_c.sum(axis=1, keepdims=True)
print(f"grad_b: {grad_b}")

### `add`, `sub` and their friends

$$f(a, b) = a + b$$

The derivative of $a + b$ with respect to $a$ and $b$ is 1:

$$\frac{d}{da} (a + b) = 1$$

$$\frac{d}{db} (a + b) = 1$$

In [None]:
def add(a: "Tensor", b: "Tensor") -> "Tensor":
    """Element-wise ``a + b`` with broadcasting-aware gradients.

    Both partial derivatives are 1, so each grad-requiring operand receives
    the upstream gradient, summed back to its own shape by
    ``Tensor.bkwd_broadcast``.
    """
    total = a.data + b.data

    tracked = [
        Leaf(value=operand, grad_fn=Tensor.bkwd_broadcast(operand))
        for operand in (a, b)
        if operand.requires_grad
    ]

    return Tensor(total, a.requires_grad or b.requires_grad, tracked)

### `mul`

The **multiplication** operation computes the element-wise product of two tensors.

$$f(a, b) = a \cdot b$$

The derivative of $a \cdot b$ with respect to $a$ and $b$ is:

$$\frac{d}{da} (a \cdot b) = b$$

$$\frac{d}{db} (a \cdot b) = a$$

In [None]:
def mul(a: "Tensor", b: "Tensor") -> "Tensor":
    """Element-wise product ``a * b`` with broadcasting-aware gradients.

    d(a * b)/da = b and d(a * b)/db = a. Each contribution is then summed
    back to the operand's own shape by ``Tensor.bkwd_broadcast``.
    """
    output = a.data * b.data

    requires_grad = a.requires_grad or b.requires_grad

    dependencies = []

    def _backward(target: "Tensor", other: "Tensor"):
        def _bkwd(grad: np.ndarray) -> np.ndarray:
            # Multiply by the raw ndarray. `grad * other` would hand a Tensor
            # to NumPy, which wraps it in an object array and yields an array
            # of Tensors instead of numbers.
            return Tensor.bkwd_broadcast(target)(grad * other.data)

        return _bkwd

    if a.requires_grad:
        dependencies.append(
            Leaf(
                value=a,
                grad_fn=_backward(a, b)
            )
        )

    if b.requires_grad:
        dependencies.append(
            Leaf(
                value=b,
                grad_fn=_backward(b, a)
            )
        )

    return Tensor(output, requires_grad, dependencies)

### `log`

The **logarithmic** operation computes the natural logarithm of each element in the tensor.

$$f(x) = \log(x)$$

The derivative of $\log(x)$ is:

$$\frac{d}{dx} \log(x) = \frac{1}{x}$$

### `tanh`

The **tanh** operation computes the hyperbolic tangent of each element in the tensor.

$$\tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}$$

The derivative of $\tanh(x)$ is:

$$\frac{d}{dx} \tanh(x) = 1 - \tanh^2(x)$$

### `pow`

The **power** operation raises each element of the tensor to the specified power.

$$f(x) = x^p$$

The derivative of $x^p$ is:

$$\frac{d}{dx} x^p = p \cdot x^{p-1}$$

### `exp`

The **exponential** operation computes the exponent (base $e$) of each element in the tensor.

$$\exp(x) = e^x$$

The derivative of $e^x$ is:

$$\frac{d}{dx} e^x = e^x$$

## Activations

### Tanh Function

$$\tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$

A smooth, S-shaped activation function that maps input to the range $(-1, 1)$.

### Sigmoid Function

$$\sigma(x) = \frac{1}{1 + e^{-x}}$$

A smooth activation function that squashes input to the range $(0, 1)$, often used for probabilities.

In [None]:
class Tanh(Module):
    """Hyperbolic-tangent activation; squashes values into (-1, 1)."""

    def forward(self, input: Tensor) -> Tensor:
        squashed = input.tanh()
        return squashed

class Sigmoid(Module):
    """Logistic activation 1 / (1 + e^{-x}); maps values into (0, 1)."""

    def forward(self, input: Tensor) -> Tensor:
        denominator = 1 + Tensor.exp(-input)
        # int / Tensor dispatches to Tensor.__rtruediv__.
        return 1 / denominator

In [None]:
# NOTE: `input` shadows Python's built-in input() for the rest of the notebook.
input = Tensor([1, 2, 3], requires_grad=True)

tanh_activation = Tanh()
# Upstream gradient of ones: input.grad accumulates 1 - tanh(x)**2.
tanh_activation.forward(input).backward(np.ones_like(input.data))

In [None]:
sigmoid_activation = Sigmoid()
# backward() adds to input.grad; nothing here calls zero_grad(), so if the
# previous cell ran, the tanh gradients are still included in the sum.
sigmoid_activation.forward(input).backward(np.ones_like(input.data))

### `abs`

The **absolute value** operation computes the magnitude of each element in a tensor, disregarding the sign. 

$$\text{abs}(x) = |x|$$

The derivative of $\text{abs}(x)$ is:

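$$\frac{d}{dx} |x| = \text{sgn}(x) \quad (x \neq 0)$$

$|x|$ is not differentiable at $x = 0$. Because `abs` is implemented as `Tensor.where(self >= 0, self, -self)`, it passes a gradient of $1$ at $x = 0$ rather than $\text{sgn}(0) = 0$. The sign function $\text{sgn}(x)$ is defined as: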

$$\text{sgn}(x) =
\begin{cases}
  1, & \text{if } x > 0 \\
  -1, & \text{if } x < 0 \\
  0, & \text{if } x = 0
\end{cases}$$



### `max`

The **max operation** returns the maximum value of a tensor along a specified axis. If no axis is specified, it returns the maximum value from the entire tensor.

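For differentiation, the gradient flows only to the element(s) that attain the maximum. If $k$ elements tie for the maximum, the implementation splits the gradient evenly among them:

$$\frac{\partial \max(X)}{\partial x} =
\begin{cases}
  \frac{1}{k}, & \text{if } x \text{ is one of the } k \text{ maximal values} \\
  0, & \text{otherwise}
\end{cases}$$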

### More Activation Functions!

The numerically stable softmax calculation:

$$\text{Softmax}(x)_i = \frac{\exp(x_i - \max(x))}{\sum_j \exp(x_j - \max(x))}$$

In [None]:
class ReLU(Module):
    """Rectified linear unit: max(0, x) element-wise."""

    def forward(self, input: Tensor) -> Tensor:
        x = Tensor.data_gate(input)
        # Negative entries become 0; everything else (including 0) passes
        # through, so the gradient at exactly 0 flows to the input.
        return Tensor.where(x < 0, Tensor(0), x)

class LeakyReLU(Module):
    """ReLU variant that passes ``alpha`` times the negative part."""

    def __init__(self, alpha: float = 0.01):
        super().__init__()
        self.alpha = alpha

    def forward(self, input: Tensor) -> Tensor:
        # max(0, x) + alpha * min(0, x)
        positive_part = Tensor.maximum(0, input)
        negative_part = Tensor.minimum(0, input)
        return positive_part + self.alpha * negative_part

class Softmax(Module):
    """Softmax along ``dim``, shifted by the per-slice max for stability."""

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, input: Tensor) -> Tensor:
        axis = self.dim
        # Subtracting the max leaves the result unchanged but keeps exp() from overflowing.
        shifted = input - input.max(axis=axis, keepdims=True)
        numerators = shifted.exp()
        return numerators / numerators.sum(axis=axis, keepdims=True)
