# scalar

In [1]:
import torch

def x_square(x):
    return x**2

def d_x_square(x):
    return 2*x

X = torch.rand(1, requires_grad=True)
x = X.item()

Y = x_square(X)
Y.backward()
print(X.grad, "<>", d_x_square(x))

tensor([0.8622]) <> 0.8622314929962158


In [2]:
def lossfn(y_actual, y_pred):
    return y_pred - y_actual

def d_x_lossfn(y_actual, y_pred, x):
    # d(loss)/dx => d(loss)dy * dy/dx
    return (0 - 1) * d_x_square(x)

y_actual = 10.0
X = torch.rand(1, requires_grad=True)
x = X.item()
Y_pred = x_square(X)
y_pred = x_square(x)

loss = lossfn(Y_pred, y_actual)
loss.backward()
print(X.grad, "<>", d_x_lossfn(y_actual, y_pred, x))

tensor([-1.5458]) <> -1.545842170715332


In [3]:
def lossfn2(y_actual, y_pred):
    return (y_pred - y_actual) ** 2

def d_lossfn2(y_actual, y_pred, x):
    # df/dx where f = g**2 and g = y - y_pred
    # df/dx = df/dg . dg/dx = 2g . dg/dx
    return  2 * lossfn(y_pred, y_actual) * d_x_lossfn(y_actual, y_pred, x)
    
x = 1
y_actual = 10
X = torch.tensor(x, dtype=float, requires_grad=True)
loss = lossfn2(y_actual, x_square(X))
loss.backward()

print(X.grad, "<>", d_lossfn2(y_actual, x_square(x), x))

tensor(-36., dtype=torch.float64) <> -36


In [4]:
def linear_fn(w, x, b):
    return w*x+b

def d_w_linear_fn(w, x, b):
    # d(w.x+b)/dw = x
    return x

def d_b_linear_fn(w, x, b):
    # d(w.x+b)/db = 1
    return 1

w = 5
b = 4
x = 2.5
W = torch.tensor(w, dtype=float, requires_grad=True)
B = torch.tensor(b, dtype=float, requires_grad=True)
X = torch.tensor(x, dtype=float, requires_grad=True)
Z = linear_fn(W, x, B)
Z.backward()
print(W.grad, B.grad, "<>", d_w_linear_fn(w, x, b), d_b_linear_fn(w, x, b))


tensor(2.5000, dtype=torch.float64) tensor(1., dtype=torch.float64) <> 2.5 1


# Vector

In [5]:
import numpy as np

def f(x):
    return sum(x**2);

x = np.array([1.0, 5.0, 6.0])
X  = torch.tensor(x, requires_grad=True, dtype=float)
Y = f(X)
Y.backward()
print(Y, X)
print(X.grad)


tensor(62., dtype=torch.float64, grad_fn=<AddBackward0>) tensor([1., 5., 6.], dtype=torch.float64, requires_grad=True)
tensor([ 2., 10., 12.], dtype=torch.float64)


## Linear function z = w.x + b

In [6]:
import torch 

X = torch.randn(5, 2, requires_grad=True)
W = torch.randn(2, 3, requires_grad=True)
B = torch.randn(1, 3, requires_grad=True)

print("B=", B)
print("X=", X)
print("W=", W)

print(X.shape, W.shape)
Q = torch.tensordot(X, W, dims=1)
Q.retain_grad()
print("Q=", Q.shape, Q)

Z = Q + B
Z.retain_grad()
print("Z=", Z)

O = torch.sum(Z)
print("O=", O)

O.backward()
#print("dz", Z.grad)
print("dw", W.grad)
print("dx", X.grad, "<>", sum(W[0]), sum(W[1]))
print("db", B.grad)

B= tensor([[1.1714, 0.1635, 0.9987]], requires_grad=True)
X= tensor([[ 1.1939,  0.1343],
        [-0.1130,  0.0077],
        [-0.2277, -1.1534],
        [ 1.6545, -0.7241],
        [-0.4871,  0.5044]], requires_grad=True)
W= tensor([[-0.3280,  1.1977,  0.4110],
        [ 1.8343, -0.7891,  0.8515]], requires_grad=True)
torch.Size([5, 2]) torch.Size([2, 3])
Q= torch.Size([5, 3]) tensor([[-0.1453,  1.3240,  0.6051],
        [ 0.0513, -0.1415, -0.0399],
        [-2.0409,  0.6374, -1.0757],
        [-1.8709,  2.5530,  0.0635],
        [ 1.0849, -0.9814,  0.2293]], grad_fn=<ReshapeAliasBackward0>)
Z= tensor([[ 1.0261,  1.4875,  1.6037],
        [ 1.2227,  0.0220,  0.9588],
        [-0.8695,  0.8009, -0.0770],
        [-0.6995,  2.7165,  1.0622],
        [ 2.2564, -0.8179,  1.2279]], grad_fn=<AddBackward0>)
O= tensor(11.9209, grad_fn=<SumBackward0>)
dw tensor([[ 2.0206,  2.0206,  2.0206],
        [-1.2311, -1.2311, -1.2311]])
dx tensor([[1.2807, 1.8967],
        [1.2807, 1.8967],
        [1.2

# Recursive function

Computation graph (CS231 course)

h(t) = w_h * h(t-1)
![comp_graph_recursive_function.jpg](./comp_graph_recursive_function.jpg)

In [7]:
import torch

W_h = torch.tensor(2.0, requires_grad=True)
w_h = W_h.item()

H0 = torch.tensor(0.5, requires_grad=True)
h0 = H0.item()

H1 = W_h * H0
h1 = H1.item()

H2 = W_h * H1
h2 = H2.item()

H2.backward()
dw_h_over_h0 = 0 #because H0 is constant for w_h
dw_h_over_h1 = h0 + w_h * dw_h_over_h0
dw_h_over_h2 = h1 + w_h * dw_h_over_h1

dh0_over_h0 = 1
dh0_over_h1 = w_h * dh0_over_h0
dh0_over_h2 = w_h * dh0_over_h1

print(W_h.grad, "<>", dw_h_over_h2, "<>", (h0 * w_h + h1))
print(H0.grad, "<>", dh0_over_h2, "<>", (w_h * w_h))

tensor(2.) <> 2.0 <> 2.0
tensor(4.) <> 4.0 <> 4.0


## recursive with tanh function

In [8]:
# h0 = 0.5
# w_h = 2.0
# w_y = 40.0
# h1_raw = w_h * h0
# h1 = tanh(h1_raw)
# h2_raw = w_h * h1
# h2 = tanh(h2_raw)
# y2 = w_y * h2
# find dy2/dw_h & dy2/dw_y ?
# y2 = 40 * tanh(w_h * tanh(w_h * 0.5))

import torch

W_h = torch.tensor(2.0, requires_grad=True)
w_h = W_h.detach().item()

W_y = torch.tensor(40.0, requires_grad=True)
w_y = W_y.detach().item()

H0 = torch.tensor(0.5, requires_grad=False)
h0 = H0.item()

H1_raw = W_h * H0
h1_raw = H1_raw.item()
H1 = torch.tanh(H1_raw)
h1 = H1.detach().item()

H2_raw = W_h * H1
h2_raw = H2_raw.detach().item()
H2 = torch.tanh(H2_raw)
Y2 = W_y * H2
h2 = H2.detach().item()

Y2.backward()

# analytic method for derivatives
#   dy/dw_y = h2 * dw_y/dw_y = h2 (product rule)
#   dy/dw_h = h2 * dw_y/dw_h + w_y * dh2/dw_h = w_y * dh2/dw_h (product rule & w_y is contant for w_h)
#   dh2/dw_h = 1 - h2_raw**2 * dh2_raw/dw_h (chain rule)
#   dh2_raw/dw_h = h1 * dw_h/dw_h + w_h * dh1/dw_h = h1 + w_h * dh1/dw_h (product rule)
#   dh1/dw_h = 1 - h1_raw**2 * dh1_raw/dw_h (chain rule)
#   dh1_raw/dw_h = h0 * dw_h/dw_h + w_h * dh0/dw_h = h0 (product rule & h0 is constant)
dh0_dwh = 0 #because H0 is constant for w_h
dh1raw_dwh = h0 + w_h * dh0_dwh
dh1_dwh = (1 - h1**2) * dh1raw_dwh
dh2raw_dwh = h1 + w_h * dh1_dwh
dh2_dwh = (1 - h2**2) * dh2raw_dwh
dwy_dwh = 0
dy2_dwh = h2 * dwy_dwh + w_y * dh2_dwh

dh2_dwy = 0
dy2_dwy = h2 + w_y * dh2_dwy

# backward pass over computation graph
dy2 = 1
dw_y = h2 * 1
dh2 = w_y * 1
dh2_raw = (1 - h2**2) * dh2
dh1 = w_h * dh2_raw
dh1_raw = (1 - h1**2) * dh1
dw_h = h1 * dh2_raw + h0 * dh1_raw

print(W_h.grad, "<>", dy2_dwh, "<>", dw_h)
print(W_y.grad, "<>", dy2_dwy, "<>", dw_y)

tensor(8.1888) <> 8.18880672682621 <> 8.18880672682621
tensor(0.9093) <> 0.9092516899108887 <> 0.9092516899108887


In [26]:
# Z_x = X . W_x.T + B_x
# Z_h = H . W_h.T + B_h
# H = tanh(Z_h)
# Y = H . W_y.T + B_y
# Y_sum += Y

import torch
import numpy as np

class TorchRNN:
    def __init__(self, input_dim=1, hidden_dim=3, output_dim=1, N=2):
        self.f_x = torch.nn.Linear(in_features=1, out_features=3)
        self.f_h = torch.nn.Linear(in_features=3, out_features=3)
        self.f_y = torch.nn.Linear(in_features=3, out_features=1)
        
        self.H = torch.tensor(np.zeros((1, 3), dtype=np.float32))
        self.N = N
        self.loss_fn = torch.nn.MSELoss(reduction='sum')
    
    def forward(self, X, Y):
        J_sum = 0
        loss = 0
        Hs, Ys_hat, diffs, Xs = [self.H], [], [], []
        for i in range(self.N):
            Z_x = self.f_x(X[i])
            Z_h = self.f_h(self.H)
            self.H = torch.tanh(Z_x + Z_h)
            Y_hat = self.f_y(self.H)
            diff = (Y[i] - Y_hat).squeeze(dim=1)
            J_sum += diff**2
            loss += self.loss_fn(Y_hat.squeeze(dim=1), Y[i])
            
            Hs.append(self.H)
            Ys_hat.append(Y_hat)
            diffs.append(diff)
            Xs.append(X[i])
            
        
        self.cache = (Hs, Ys_hat, diffs, Xs)
        self.J_sum = J_sum
        self.loss = loss
        return loss, J_sum
    
    def backward(self):
        self.loss.backward()
    
    def gradients(self):
        return (
            self.f_x.weight.grad, 
            self.f_x.bias.grad,
            self.f_h.weight.grad,
            self.f_h.bias.grad,
            self.f_y.weight.grad,
            self.f_y.bias.grad
        )

class VanillaRNN:
    def __init__(self, W_x, B_x, W_h, B_h, W_y, B_y, H, N=2):
        self.W_x = W_x
        self.B_x = B_x
        self.W_h = W_h
        self.B_h = B_h
        self.W_y = W_y
        self.B_y = B_y
        self.H = H
        self.N = N
    
    def forward(self, X, Y):
        if X.shape[0] < self.N:
            raise AssertionError("X.shape[0] is less than {}".format(self.N))
        
        J_sum = 0
        Hs, Ys_hat, diffs, Xs = [self.H], [], [], []
        for i in range(self.N):
            Z_x = np.dot(X[i], self.W_x.T) + self.B_x
            Z_h = np.dot(self.H, self.W_h.T) + self.B_h
            self.H = np.tanh(Z_x + Z_h)
            Y_hat = np.dot(self.H, self.W_y.T) + self.B_y
            diff = np.squeeze(Y[i] - Y_hat, axis=1)
            J_sum += diff**2
            
            Hs.append(self.H)
            Ys_hat.append(Y_hat)
            diffs.append(diff)
            Xs.append(X[i])
        
        self.cache = (Hs, Ys_hat, diffs, Xs)
        self.J_sum = J_sum
        return J_sum

    def backward(self):
        Hs, Ys_hat, diffs, Xs = self.cache
        dW_x, dB_x, dW_h, dB_h, dW_y, dB_y = 0, 0, 0, 0, 0, 0
        H_iter = reversed(Hs)
        H_cur = next(H_iter)
        dH_next = np.zeros(Hs[0].shape)
        dJ_sum = 1
        for t in reversed(range(self.N)):
            H_prev = next(H_iter)
            
            #diffs[t] = Y[t] - Y_hat[t]
            #J[t] = ( diffs[t] ) ** 2
            dJ = 2 * diffs[t] * (0 - 1) * dJ_sum
            dY = np.full(Ys_hat[0].shape, dJ)
        
            #Y_hat[t] = H[t] . W_y.T + B_y
            dB_y = dB_y + dY
            dW_y = dW_y + np.dot(H_cur.T, dY).T
            dH = np.dot(dY, self.W_y) + dH_next
            
            # H[t] = tanh( H_raw[t] )
            dH_raw_local = 1 - H_cur * H_cur
        
            # Z_xs[t] = X[t]. W_x.T + B_x
            # Z_hs[t] = H[t-1]. W_h.T + B_h
            # H_raw[t] = Z_xs[t] + Z_hs[t] 
            dH_raw = dH_raw_local * dH
            dB_h = dB_h + dH_raw
            dW_h = dW_h + np.dot(H_prev.T, dH_raw).T
            dW_x = dW_x + np.dot(Xs[t].T, dH_raw).T
            
            dH_next = np.dot(dH_raw, self.W_h)
            H_cur = H_prev
        
        self.grads = (dW_x, dB_x, dW_h, dB_h, dW_y, dB_y)
    
    def gradients(self):
        return self.grads



In [27]:
N = 2
torchRNN = TorchRNN(
    input_dim=1, 
    hidden_dim=3, 
    output_dim=1, 
    N=N)

vanilaRNN = VanillaRNN(
    torchRNN.f_x.weight.detach().numpy(),
    torchRNN.f_x.bias.detach().numpy(),
    torchRNN.f_h.weight.detach().numpy(),
    torchRNN.f_h.bias.detach().numpy(),
    torchRNN.f_y.weight.detach().numpy(),
    torchRNN.f_y.bias.detach().numpy(),
    torchRNN.H.detach().numpy(),
    N=N
)

X = np.array([ [1.0], [2.0], [3.0] ])
X_torch = torch.tensor(np.array([ [1.0], [2.0], [3.0] ], dtype=np.float32))

Y = np.array([ [5.0], [10.0], [15.0] ])
Y_torch = torch.tensor(np.array([ [5.0], [10.0], [15.0] ], dtype=np.float32))

#torchRNN.f_h.weight.register_hook(lambda grad: print("from hook", grad))

tJ_sum, _ = torchRNN.forward(X_torch, Y_torch)
torchRNN.backward()
(tdW_x, tdB_x, tdW_h, tdB_h, tdW_y, tdB_y) = torchRNN.gradients()

J_sum = vanilaRNN.forward(X, Y)
vanilaRNN.backward()
(dW_x, dB_x, dW_h, dB_h, dW_y, dB_y) = vanilaRNN.gradients()

print("J_sum: by_torch={}, by_hand={}".format(tJ_sum, J_sum))
print("diffs: by_torch={}, by_hand={}".format(torchRNN.cache[2], vanilaRNN.cache[2]))

print("dB_y: by_torch={} by_hand={}".format(tdB_y, dB_y))
print("dW_y: by_torch={} by_hand={}".format(tdW_y, dW_y))

print("dW_H: by_torch={} by_hand={}".format(tdW_h, dW_h))
print("dB_H: by_torch={} by_hand={}".format(tdB_h, dB_h))

print("dW_X: by_torch={} by_hand={}".format(tdW_x, dW_x))
print("dB_X: by_torch={} by_hand={}".format(tdB_x, dB_x))


J_sum: by_torch=147.34588623046875, by_hand=[147.34589557]
diffs: by_torch=[tensor([5.6724], grad_fn=<SqueezeBackward1>), tensor([10.7317], grad_fn=<SqueezeBackward1>)], by_hand=[array([5.6724266]), array([10.73170406])]
dB_y: by_torch=tensor([-32.8083]) by_hand=[[-32.80826132]]
dW_y: by_torch=tensor([[-31.6833, -25.5239,  -5.9449]]) by_hand=[[-31.68331576 -25.52391181  -5.94494814]]
dW_H: by_torch=tensor([[0.4772, 0.3028, 0.1222],
        [0.9417, 0.5975, 0.2411],
        [2.9400, 1.8654, 0.7527]]) by_hand=[[0.47715462 0.30275505 0.12216023]
 [0.94172597 0.59752601 0.24109892]
 [2.93996332 1.86540947 0.75268389]]
dB_H: by_torch=tensor([1.3071, 1.0650, 3.6685]) by_hand=[[1.30710812 1.06500711 3.66852105]]
dW_X: by_torch=tensor([[1.8133],
        [2.0641],
        [6.7876]]) by_hand=[1.81332578 2.06409273 6.78755478]
dB_X: by_torch=tensor([1.3071, 1.0650, 3.6685]) by_hand=0


# References
https://arxiv.org/pdf/1802.01528.pdf

https://youtu.be/d14TUNcbn1k?si=hyEeGpEt5hP1XVHA

https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html
