In [1]:
import torch
import ipycanvas
import numpy as np


In [2]:
def compute(x, h1, h2, gt1_w1, gt1_w2, w1, w2):
    # Single recursive node
    y = w1 * x + w2 * h1
    # dh1/dh2 * dh2/dw + dh1/dw
    gt_w1 = w2 * gt1_w1 + x
    gt_w2 = w2 * gt1_w2 + h1
    return y, gt_w1, gt_w2


w1 = torch.rand(2, requires_grad=True)
w2 = torch.rand(2, requires_grad=True)
h0 = 0
x1 = torch.rand(2, requires_grad=True)
h1, dh_w1, dh_w2 = compute(x1, h0, 0, 0, 0, w1, w2)
#assert h1 == 1

x2 = torch.rand(2, requires_grad=True)
h2, dh_w1, dh_w2 = compute(x2, h1, h0, dh_w1, dh_w2, w1, w2)
#assert h2 == (1.5 + 0.5)

x3 = torch.rand(2, requires_grad=True)
h3, dh_w1, dh_w2 = compute(x3, h2, h1, dh_w1, dh_w2, w1, w2)

y = torch.rand(2)
loss = torch.nn.MSELoss()(h3, y)
loss.backward()

error_grad = 2 * (h3 - y) / 2
# YAY! gt1_w1 is correct! With gt_w1 = w2 * gt1_w1 + x
# YAY! gt1_w2 is correct! With gt_w2 = w2 * gt1_w2 + h1
print(w1.grad, w2.grad, dh_w1*error_grad, dh_w2*error_grad)

tensor([0.2634, 4.0096]) tensor([0.1255, 3.0548]) tensor([0.2634, 4.0096], grad_fn=<MulBackward0>) tensor([0.1255, 3.0548], grad_fn=<MulBackward0>)


In [51]:


c = ipycanvas.Canvas(width=800, height=200)

def draw_node(c, x, y):
    c.stroke_style = "red"
    c.stroke_circle(x, y, 40)

def draw_recurrent_node(c, x=60, y=100):

    draw_node(c, x, y)

    c.begin_path()
    y_arc = y - 40
    x0 = x+15
    x1 = x-15
    c.move_to(x0, y_arc)
    c.quadratic_curve_to(x0 + 5, y_arc-25, x0 + (x1 - x0)//2, y_arc - 25)
    c.quadratic_curve_to(x1 - 5, y_arc-25, x1, y_arc)
    c.stroke()
    c.begin_path()
    arrow_width = 5
    c.line_to(x1-arrow_width, y_arc)
    c.line_to(x1, y_arc+arrow_width)
    c.line_to(x1+arrow_width, y_arc)
    c.fill()
    
def draw_arrow(c, x0, y0, x1, y1, arrow_width=5, direction="right"):
    c.begin_path()
    c.move_to(x0, y0)
    c.line_to(x1, y1)
    c.stroke()
    c.begin_path()
    if direction == "right":
        c.line_to(x1-arrow_width, y1-arrow_width)
        c.line_to(x1-arrow_width, y1+arrow_width)
        c.line_to(x1, y1)
    elif direction == "left":
        c.line_to(x1+arrow_width, y1-arrow_width)
        c.line_to(x1+arrow_width, y1+arrow_width)
        c.line_to(x1, y1)
    elif direction == "down":
        c.line_to(x1+arrow_width, y1-arrow_width)
        c.line_to(x1-arrow_width, y1-arrow_width)
        c.line_to(x1, y1)
    else:
        c.line_to(x1+arrow_width, y1+arrow_width)
        c.line_to(x1-arrow_width, y1+arrow_width)
        c.line_to(x1, y1)
    c.fill()


In [52]:
c = ipycanvas.Canvas(width=800, height=300)

draw_recurrent_node(c, x=60, y=100)
draw_recurrent_node(c, x=200, y= 100)
draw_arrow(c, 100, 100, 160, 100)
draw_arrow(c, 160, 110, 100, 110, direction="left")
draw_arrow(c, x0=60, y0=180, x1=60, y1=140, direction="up")
draw_arrow(c, x0=200, y0=180, x1=200, y1=140, direction="up")

draw_arrow(c, x0=60, y0=180, x1=190, y1=140, direction="up")
draw_arrow(c, x0=200, y0=180, x1=70, y1=140, direction="up")

draw_node(c, x=60, y=220)
draw_node(c, x=200, y=220)

c.font = '24px serif'
c.fill_text("2", 55, 230)
c.fill_text("0", 55, 110)
c.fill_text("3", 195, 230)
c.fill_text("1", 195, 110)
c.font = '12px serif'
c.fill_text("num_hidden=2 num_inputs=2", 40, 10)
    
c

Canvas(height=300, width=800)

In [73]:
def calc_next_G_fc_loop(G, W, inputs, hiddens):
    G_new = torch.zeros_like(G, requires_grad=False)
    for k in range(G.shape[0]):
        for i in range(G.shape[1]):
            for j in range(G.shape[2]):
                # p should just belong to the hiddens 0,1
                for p in range(G.shape[0]):
                    G_new[k,i,j] += W[k,p] * G[p,i,j]
                if i == k:
                    # hiddens = 0,1
                    if j >= hiddens.shape[0]:
                        G_new[k,i,j] += inputs[j-hiddens.shape[0], 0]
                    else:
                        G_new[k,i,j] += hiddens[j,0]
    
    return G_new

def calc_next_G_fc(G, W, inputs, hiddens):
    G_new = torch.zeros_like(G, requires_grad=False)
    for k in range(G.shape[0]):
        for i in range(G.shape[1]):
            G_new[k,i:i+1,:] += torch.mm(W[k:k+1,0:G.shape[0]], G[:,i,:])
        
        i = k
        for j in range(G.shape[2]):
            if j >= hiddens.shape[0]:
                G_new[k,i,j] += inputs[j-hiddens.shape[0], 0]
            else:
                G_new[k,i,j] += hiddens[j,0]

    return G_new

def forward_grad_2hidden_calc_activation(hiddens, G):
    new_hiddens = torch.sigmoid(hiddens)
    sigmoid_deriv = torch.sigmoid(hiddens) * (1 - torch.sigmoid(hiddens))
    for k in range(G.shape[0]):
        G[k,:,:] *= sigmoid_deriv[k]
    return new_hiddens, G

def forward_grad_2hidden_calc_fc(G, W, inputs, hiddens):
    # Forward prop
    z = torch.cat((hiddens, inputs))
    new_hiddens = torch.mm(W[:hiddens.shape[0]], z)
    # Calculate gradients
    G = calc_next_G_fc(G, W, inputs, hiddens)
    # Nonlinear activation
    new_hiddens, G = forward_grad_2hidden_calc_activation(new_hiddens, G)
    
    return (new_hiddens, G)
    

def forward_grad_2hidden_fc():
    num_hidden = 10
    num_input = 30
    num_nodes = num_hidden + num_input
    # Gij^k
    G = torch.zeros((num_hidden, num_nodes, num_nodes), requires_grad=True)
    W = torch.rand((num_nodes, num_nodes), requires_grad=True)
    num_itr = 10
    hidden = torch.zeros((num_hidden,1), requires_grad=True)
    
    for itr in range(num_itr):
        inputs = torch.rand((num_input,1))        
        hidden, G = forward_grad_2hidden_calc_fc(G, W, inputs, hidden)
        
    y = torch.rand((num_hidden,1))
    error = torch.nn.MSELoss()(hidden,y)
    error_grad = 2 * (hidden - y) / num_hidden
    print(error, error_grad)
    error.backward()
    G_grad = G
    G_grad[0,:,:] *= error_grad[0]
    G_grad[1,:,:] *= error_grad[1]
    print("===== Calculated =====")
    print(G_grad.sum(dim=0))
    print("===== Actual =====")
    print(W.grad)
    torch.testing.assert_allclose(G_grad.sum(dim=0), W.grad)
    

    
    
forward_grad_2hidden_fc()

tensor(0.2861, grad_fn=<MseLossBackward0>) tensor([[0.0676],
        [0.1318],
        [0.1086],
        [0.0532],
        [0.1925],
        [0.0209],
        [0.1618],
        [0.0654],
        [0.0733],
        [0.0675]], grad_fn=<DivBackward0>)
===== Calculated =====
tensor([[2.7418e-07, 2.7418e-07, 2.7418e-07,  ..., 5.8324e-08, 7.3706e-08,
         1.8916e-07],
        [7.8619e-08, 7.8619e-08, 7.8619e-08,  ..., 1.6725e-08, 2.1139e-08,
         5.4237e-08],
        [6.9141e-06, 6.9141e-06, 6.9141e-06,  ..., 1.4705e-06, 1.8578e-06,
         4.7706e-06],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]], grad_fn=<SumBackward1>)
===== Actual =====
tensor([[2.7410e-07, 2.7410e-07, 2.7410e-07,  ..., 5.8298e-08, 7.3653e-08,
         1.8912

In [61]:
c = ipycanvas.Canvas(width=800, height=300)



def draw_node_layer(c, x_left, x_right, y_top, y_bottom, radius, n1, n2):
    draw_recurrent_node(c, x=x_right, y=y_top)
    draw_recurrent_node(c, x=x_right, y=y_bottom)

    draw_arrow(c, x0=x_left+radius, y0=y_top,    x1=x_right-radius, y1=y_top,    direction="right")
    draw_arrow(c, x0=x_left+radius, y0=y_bottom, x1=x_right-radius, y1=y_bottom, direction="right")
    draw_arrow(c, x0=x_left+radius, y0=y_top,    x1=x_right-radius, y1=y_bottom, direction="right")
    draw_arrow(c, x0=x_left+radius, y0=y_bottom, x1=x_right-radius, y1=y_top,    direction="right")

    draw_arrow(c, x0=x_right+5, y0=y_bottom-radius, x1=x_right+5, y1=y_top+radius,    direction="up")
    draw_arrow(c, x0=x_right-5, y0=y_top+radius,    x1=x_right-5, y1=y_bottom-radius, direction="down")
    
    c.fill_text(n1, x_right-5, y_top+10)
    c.fill_text(n2, x_right-5, y_bottom+10)

draw_node(c, x=60, y=100)
draw_node(c, x=60, y=220)
c.font = '24px serif'
c.fill_text("4", 55, 230)
c.fill_text("5", 55, 110)

draw_node_layer(c, x_left=60, x_right=200, y_top=100, y_bottom=220, radius=40, n1="2", n2="3")
draw_node_layer(c, x_left=200, x_right=340, y_top=100, y_bottom=220, radius=40, n1="0", n2="1")


c.font = '12px serif'
c.fill_text("num_hidden=2 num_inputs=2", 40, 10)
    
c

Canvas(height=300, width=800)

In [90]:
def forward_grad_4hidden_fc():
    num_hidden_layers = 2
    num_hidden = 2
    num_input = 2
    num_nodes = num_hidden + num_input
    # Gij^k
    G = torch.zeros((num_hidden_layers, num_hidden, num_nodes, num_nodes), requires_grad=False)
    W = torch.rand((num_hidden_layers, num_nodes, num_nodes), requires_grad=True)
    num_itr = 5
    
    hidden1 = torch.zeros((num_hidden,1), requires_grad=True)
    hidden2 = torch.zeros((num_hidden,1), requires_grad=True)
    
    for itr in range(num_itr):
        inputs = torch.rand((num_input,1))        
        hidden1, G1 = forward_grad_2hidden_calc_fc(G[0], W[0], inputs, hidden1)
        G[0] = G1
        inputs = hidden1
        #G[1,:,:,2] = G[0,]
        hidden2, G2 = forward_grad_2hidden_calc_fc(G[1], W[1], inputs, hidden2)
        G[1] = G2
        
    y = torch.rand((num_hidden,1))
    error = torch.nn.MSELoss()(hidden2,y)
    error_grad = 2 * (hidden2 - y) / num_hidden
    print(error, error_grad)
    error.backward()
    # TODO: Check W.grad[0] for the first layer
    G_grad = G[1]
    G_grad[0,:,:] *= error_grad[0]
    G_grad[1,:,:] *= error_grad[1]
    print("===== Calculated =====")
    print(G_grad.sum(dim=0))
    print("===== Actual =====")
    print(W.grad[1])
    torch.testing.assert_allclose(G_grad.sum(dim=0), W.grad[1])
    
forward_grad_4hidden_fc()

tensor(0.0218, grad_fn=<MseLossBackward0>) tensor([[-0.1134],
        [ 0.1751]], grad_fn=<DivBackward0>)
===== Calculated =====
tensor([[-0.0170, -0.0179, -0.0156, -0.0178],
        [ 0.0220,  0.0231,  0.0201,  0.0229],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<SumBackward1>)
===== Actual =====
tensor([[-0.0170, -0.0179, -0.0156, -0.0178],
        [ 0.0220,  0.0231,  0.0201,  0.0229],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])
