In [102]:
from tinygrad import Device
# Show which backend tinygrad picked as its default device.
print(Device.DEFAULT)

GPU


### Data

In [103]:
from tinygrad import Tensor, nn
from tinygrad.nn.datasets import mnist

# Load the MNIST train/test splits and report the shape and dtype of the
# training images and labels.
X_train, Y_train, X_test, Y_test = mnist()
print(X_train.shape, X_train.dtype, Y_train.shape, Y_train.dtype)

(60000, 1, 28, 28) dtypes.uchar (60000,) dtypes.uchar


In [104]:
# Data preprocessing

# Scale the pixel values by the dataset maximum, then flatten every image into
# a single row.
# NOTE: this cell overwrites X_train / Y_train with their processed versions,
# so re-run the loading cell before executing it a second time.
X_train = (X_train / X_train.max()).reshape(X_train.shape[0], -1)

# Replace the integer class labels with one-hot rows (10 classes).
Y_train = Y_train.one_hot(10)

X_train.shape, Y_train.shape

((60000, 784), (60000, 10))

In [105]:
# Initialization

# Seed tinygrad's RNG so the parameters (and the minibatch sampled later) are
# reproducible after a kernel restart.
Tensor.manual_seed(42)

input_shape, hidden_shape, output_shape = X_train.shape[1], 64, Y_train.shape[1]

# Zero-centred weights scaled by 1/sqrt(fan_in). Weights drawn from uniform
# [0, 1) are all positive; summed over hundreds of non-negative inputs they
# drive every pre-activation deep into the flat region of the sigmoid, so the
# first-layer gradients vanish and checking them by hand proves nothing.
w1_bound = input_shape ** -0.5
w1 = Tensor.uniform(input_shape, hidden_shape, low=-w1_bound, high=w1_bound)
# Small non-zero biases: zeros could hide a mistake in a bias gradient.
b1 = Tensor.uniform(1, hidden_shape, low=-0.1, high=0.1)

w2_bound = hidden_shape ** -0.5
w2 = Tensor.uniform(hidden_shape, output_shape, low=-w2_bound, high=w2_bound)
b2 = Tensor.uniform(1, output_shape, low=-0.1, high=0.1)

In [106]:
# Forward pass
# Every intermediate is kept under its own name so it can be inspected (and
# differentiated against) individually later on.

minibatch_size = 32

# Draw minibatch_size independent random row indices (repeats are possible)
# and gather the corresponding training rows.
minibatch_idx = Tensor.randint(minibatch_size, low=0, high=X_train.shape[0])
minibatch = X_train[minibatch_idx]

# Hidden layer: affine transform followed by a sigmoid non-linearity.
z1 =  minibatch @ w1 + b1
hidden = z1.sigmoid()  # shape (minibatch_size, hidden_shape)

# Output layer: one logit per class.
logits = hidden @ w2 + b2

# Softmax by hand
# Subtracting the row-wise max keeps exp() from overflowing and does not
# change the resulting softmax.
normalized_logits = logits - logits.max(1, keepdim=True)
exp_logits = normalized_logits.exp()
sum_logits = exp_logits.sum(1, keepdim=True)
probs = exp_logits.div(sum_logits)

# Loss
# Mean squared error between the predicted probabilities and the one-hot
# labels; .mean() averages over every element of the
# (minibatch_size, output_shape) matrix, not just over the batch.
loss = ((probs - Y_train[minibatch_idx])**2).mean()

### Manual Backward Pass

In [108]:
# Function for comparing gradients by tinygrad to my calculated gradients

def compare_gradients(name, gradient, tensor, previous):
  """Print whether a hand-derived gradient matches tinygrad's autograd result.

  Args:
    name: Label printed at the start of the report line.
    gradient: Manually computed gradient of `tensor` with respect to `previous`.
    tensor: Tensor being differentiated (the loss in this notebook).
    previous: Tensor inside `tensor`'s graph to differentiate with respect to.
  """
  # Tensor.gradient returns one gradient per requested target, so unpack the
  # single entry instead of doing arithmetic on the list itself.
  (reference,) = tensor.gradient(previous)
  # Element-wise equality reduced with all(): True only if every entry matches.
  # (Reducing the raw difference with all() would be True only if every entry
  # *differed*.)
  exact = (gradient == reference).all().item()

  print(f'{name:15s} | exact: {str(exact):5s}')

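In this case our loss is:
$$
L = \frac{1}{N}\sum_{i=1}^{N}(\hat{\mathbf{y}}_{i}-\mathbf{y}_{i})^{2}
$$
where the sum runs over all $N = \text{batch size} \times \text{classes}$ entries of `probs`, because `.mean()` averages over every element.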
Differentiating the loss with respect to each probability produced by the model gives:
$$
\frac{ \partial L }{ \partial \hat{\mathbf{y}}_{i} } = \frac{2}{N} (\hat{\mathbf{y}}_{i}-\mathbf{y}_{i})
$$

Since $\hat{\mathbf{y}}$ (`probs`) is a $32 \times 10$ tensor (batch size × classes), `dprobs` ($\frac{ \partial L }{ \partial \hat{\mathbf{y}} }$) has the same shape: each element is the corresponding residual $\hat{y} - y$ scaled by $\frac{2}{N}$, where the $\frac{1}{N}$ comes from the mean over all $N = 32 \times 10 = 320$ elements.

In [109]:
# Manual backward pass

# .mean() in the loss averages over every element of probs, so the normaliser
# is the total element count rather than the batch size.
N = minibatch_size * output_shape  # 32 * 10 = 320
# dL/dprobs for the mean-squared-error loss: 2 * (prediction - target) / N.
dprobs = 2 * (probs - Y_train[minibatch_idx]) / N

# NOTE(review): the recorded output of this cell is a RuntimeError reporting
# that the `probs` graph is "not found in" the `loss` graph. Presumably part of
# the forward graph had already been realized by an earlier execution (the
# execution counts skip In [107]) — confirm by restarting the kernel and
# running every cell in order.
compare_gradients("probs", dprobs, loss, probs)




RuntimeError: UOp(Ops.MUL, dtypes.float, arg=None, src=(
  UOp(Ops.EXP2, dtypes.float, arg=None, src=(
    UOp(Ops.MUL, dtypes.float, arg=None, src=(
      UOp(Ops.ADD, dtypes.float, arg=None, src=(
        UOp(Ops.RESHAPE, dtypes.float, arg=(32, 10), src=(
          UOp(Ops.BUFFER, dtypes.float, arg=320, src=(
            x5:=UOp(Ops.DEVICE, dtypes.void, arg='GPU', src=()),
            UOp(Ops.UNIQUE, dtypes.void, arg=456, src=()),)),)),
        UOp(Ops.MUL, dtypes.float, arg=None, src=(
          UOp(Ops.EXPAND, dtypes.float, arg=(32, 10), src=(
            UOp(Ops.RESHAPE, dtypes.float, arg=(32, 1), src=(
              UOp(Ops.BUFFER, dtypes.float, arg=32, src=(
                 x5,
                UOp(Ops.UNIQUE, dtypes.void, arg=487, src=()),)),)),)),
          UOp(Ops.CONST, dtypes.float, arg=-1.0, src=(
            x13:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
               x5,)),)),)),)),
      UOp(Ops.CONST, dtypes.float, arg=1.4426950408889634, src=(
         x13,)),)),)),
  UOp(Ops.RECIP, dtypes.float, arg=None, src=(
    UOp(Ops.EXPAND, dtypes.float, arg=(32, 10), src=(
      UOp(Ops.RESHAPE, dtypes.float, arg=(32, 1), src=(
        UOp(Ops.BUFFER, dtypes.float, arg=32, src=(
           x5,
          UOp(Ops.UNIQUE, dtypes.void, arg=488, src=()),)),)),)),)),))

not found in

UOp(Ops.MUL, dtypes.float, arg=None, src=(
  UOp(Ops.RESHAPE, dtypes.float, arg=(), src=(
    UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=(
      UOp(Ops.POW, dtypes.float, arg=None, src=(
        UOp(Ops.RESHAPE, dtypes.float, arg=(32, 10), src=(
          UOp(Ops.BUFFER, dtypes.float, arg=320, src=(
            x5:=UOp(Ops.DEVICE, dtypes.void, arg='GPU', src=()),
            UOp(Ops.UNIQUE, dtypes.void, arg=497, src=()),)),)),
        UOp(Ops.CONST, dtypes.float, arg=2.0, src=(
          UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
             x5,)),)),)),)),)),
  UOp(Ops.CONST, dtypes.float, arg=0.003125, src=(
    UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=(
       x5,)),)),))