# Conv2d

In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
from nets.functional import conv2d_forward, conv2d_backward, mse_loss_backward, mse_loss_forward

from tests.check import check_equals

N, C_in, C_out, H, W = 10, 4, 64, 123, 123
stride= 3
pad = 5
K = 4
num_groups= 2

x = np.random.randn(N, C_in, H, W)
kernel = np.random.randn(C_out, int(C_in/num_groups), K, K)
b = np.random.randn(C_out)

Hp = (H + 2*pad)
Wp = (W + 2*pad)
outH = int(np.ceil((Hp - (K - 1))/stride))
outW = int(np.ceil((Wp - (K - 1)) / stride))

y = np.random.randn(N, C_out, outH, outW)

In [2]:
# Numpy Example
y_hat, cache1 = conv2d_forward(x, kernel, b, pad=pad, stride=stride, num_groups=num_groups)
loss, cache2 = mse_loss_forward(y_hat, y)

dL_dyhat = mse_loss_backward(cache2)
dL_dx, dL_dkernel, dL_db = conv2d_backward(dL_dyhat, cache1)

In [3]:
# Checking against Torch
import torch
x_torch = torch.tensor(x, requires_grad=True)
kernel_torch = torch.tensor(kernel, requires_grad=True)
b_torch = torch.tensor(b, requires_grad=True)
y_torch = torch.tensor(y)

yhat_torch = torch.nn.functional.conv2d(
    x_torch, kernel_torch, b_torch, padding=pad, stride=stride, groups=num_groups
)

# Checking Forward Pass
check_equals(y_hat, yhat_torch)
loss = torch.nn.functional.mse_loss(yhat_torch, y_torch)
loss.backward()

check_equals(x_torch.grad, dL_dx)
check_equals(kernel_torch.grad, dL_dkernel)
check_equals(b_torch.grad, dL_db)

1.7763568394002505e-14
3.2526065174565133e-19
1.8041124150158794e-16
1.3877787807814457e-17


# Softmax

In [4]:
from nets.functional import softmax_forward, softmax_backward
N, K = 8, 10

x = np.random.randn(N, K)
y = np.random.randn(N, K)

out, cache1 = softmax_forward(x)
loss, cache2 = mse_loss_forward(out, y)

dL_dout = mse_loss_backward(cache2)
dL_dx = softmax_backward(dL_dout, cache1)

In [5]:
import torch

x_torch = torch.tensor(x, requires_grad=True)
y_torch = torch.tensor(y, requires_grad=True)

out_torch = torch.nn.functional.softmax(x_torch, dim=-1)

# Checking Forward Pass
check_equals(out_torch, out)
loss = torch.nn.functional.mse_loss(out_torch, y_torch)
loss.backward()

check_equals(x_torch.grad, dL_dx)

5.551115123125783e-17
1.734723475976807e-18


# BatchNorm

In [6]:
from nets.functional import batchnorm_forward, batchnorm_backward

N, D = 8, 10

x = np.random.randn(N, D)
g = np.random.randn(D)
b = np.random.randn(D)
y = np.random.randn(N, D)

y_hat, cache1 = batchnorm_forward(x, g, b)
loss, cache2 = mse_loss_forward(y_hat, y)

dL_dout = mse_loss_backward(cache2)
dL_dx, dL_dg, dL_db = batchnorm_backward(dL_dout, cache1)


# Checking against Torch
x_torch = torch.tensor(x, requires_grad=True)
g_torch = torch.tensor(g, requires_grad=True)
b_torch = torch.tensor(b, requires_grad=True)
y_torch = torch.tensor(y)

yhat_torch = torch.nn.functional.batch_norm(
    x_torch, None, None, g_torch, b_torch, training=True
)

# Checking Forward Pass
check_equals(y_hat, yhat_torch)
loss = torch.nn.functional.mse_loss(yhat_torch, y_torch)
loss.backward()

check_equals(x_torch.grad, dL_dx)
check_equals(g_torch.grad, dL_dg)
check_equals(b_torch.grad, dL_db)

8.881784197001252e-16
9.367506770274758e-17
5.551115123125783e-17
5.551115123125783e-17


# LayerNorm

In [7]:
from nets.functional import layernorm_forward, layernorm_backward

N, D = 8, 10

x = np.random.randn(N, D)
g = np.random.randn(D)
b = np.random.randn(D)
y = np.random.randn(N, D)

y_hat, cache1 = layernorm_forward(x, g, b)
loss, cache2 = mse_loss_forward(y_hat, y)

dL_dout = mse_loss_backward(cache2)
dL_dx, dL_dg, dL_db = layernorm_backward(dL_dout, cache1)


# Checking against Torch
x_torch = torch.tensor(x, requires_grad=True)
g_torch = torch.tensor(g, requires_grad=True)
b_torch = torch.tensor(b, requires_grad=True)
y_torch = torch.tensor(y)

yhat_torch = torch.nn.functional.layer_norm(
    x_torch, (D,), g_torch, b_torch, eps=1e-5
)

# Checking Forward Pass
check_equals(y_hat, yhat_torch)
loss = torch.nn.functional.mse_loss(yhat_torch, y_torch)
loss.backward()

check_equals(x_torch.grad, dL_dx)
check_equals(g_torch.grad, dL_dg)
check_equals(b_torch.grad, dL_db)


6.661338147750939e-16
5.551115123125783e-17
1.1102230246251565e-16
1.1102230246251565e-16


# RMSNorm

In [8]:
from nets.functional import rms_norm_forward, rms_norm_backward

N, D = 8, 10

x = np.random.randn(N, D)
g = np.random.randn(D)
y = np.random.randn(N, D)

y_hat, cache1 = rms_norm_forward(x, g)
loss, cache2 = mse_loss_forward(y_hat, y)

dL_dout = mse_loss_backward(cache2)
dL_dx, dL_dg = rms_norm_backward(dL_dout, cache1)

# Checking against Torch
x_torch = torch.tensor(x, requires_grad=True)
g_torch = torch.tensor(g, requires_grad=True)
y_torch = torch.tensor(y)

yhat_torch = torch.nn.functional.rms_norm(x_torch, (D,), g_torch, eps=1e-5)

# Checking Forward Pass
check_equals(y_hat, yhat_torch)
loss = torch.nn.functional.mse_loss(yhat_torch, y_torch)
loss.backward()

check_equals(x_torch.grad, dL_dx)
check_equals(g_torch.grad, dL_dg)


4.440892098500626e-16
5.551115123125783e-17
1.1102230246251565e-16


# BCE Loss

In [9]:
import numpy as np
import torch
from nets.functional import bce_loss_forward, bce_loss_backward
from tests.check import check_equals

N = 12
p = np.random.rand(N)*0.5 + 0.25
y = np.random.randint(2, size=(N,))



loss, cache = bce_loss_forward(p, y)
p_torch = torch.tensor(p, requires_grad=True)
y_torch = torch.tensor(y, dtype=torch.float64)
loss_torch = torch.nn.functional.binary_cross_entropy(p_torch, y_torch)


dL_dp = bce_loss_backward(cache)
check_equals(loss, loss_torch)
loss_torch.backward()
check_equals(p_torch.grad, dL_dp)

1.1102230246251565e-16
5.551115123125783e-17


# Self Attention

In [39]:
from nets.functional import attn_forward, attn_backward, mse_loss_forward, mse_loss_backward 
B, H, S, T, D, C = 12, 4, 100, 200, 64, 7
Q = np.random.randn(B, H, S, D)
K = np.random.randn(B, H, T, D)
V = np.random.randn(B, H, T, C)

y = np.random.randn(B, H, S, C)


out, cache1 = attn_forward(Q, K, V)
print(out.shape)
print(y.shape)

loss, cache2 = mse_loss_forward(y, out)

dL_dout = mse_loss_backward(cache2)
dL_dQ, dL_dK, dL_dV = attn_backward(dL_dout, cache1)


# Torch Check
Q_torch = torch.tensor(Q, requires_grad=True)
K_torch = torch.tensor(K, requires_grad=True)
V_torch = torch.tensor(V, requires_grad=True)
y_torch = torch.tensor(y)

out_torch = torch.nn.functional.scaled_dot_product_attention(Q_torch, K_torch, V_torch)
check_equals(out, out_torch)


loss = torch.nn.functional.mse_loss(y_torch, out_torch)
loss.backward()
check_equals(Q_torch.grad, dL_dQ)
check_equals(K_torch.grad, dL_dK)
check_equals(V_torch.grad, dL_dV)


(12, 4, 100, 7)
(12, 4, 100, 7)
8.881784197001252e-16
6.396614999689303e-05
6.876525802156287e-05
9.898514233929432e-05
