In [4]:
import torch
import ipycanvas
import numpy as np


In [5]:
def compute(x, h1, h2, gt1_w1, gt1_w2, w1, w2):
    """Advance a single linear recurrent node by one step.

    Alongside the new state, the running forward-mode sensitivities of the
    state with respect to ``w1`` and ``w2`` are updated via the chain rule:
    d(state_t)/dw = (d state_t / d state_{t-1}) * d(state_{t-1})/dw + (direct term).

    Args:
        x: Input at the current step.
        h1: State from the previous step.
        h2: State from two steps back. Not read by this function; it is
            kept only so existing call sites keep working.
        gt1_w1: Sensitivity of the previous state with respect to ``w1``.
        gt1_w2: Sensitivity of the previous state with respect to ``w2``.
        w1: Weight applied to the input.
        w2: Weight applied to the previous state (the recurrent weight).

    Returns:
        Tuple ``(state, d_state_d_w1, d_state_d_w2)``.
    """
    # Single recursive node: new state is a weighted sum of input and previous state.
    state = w1 * x + w2 * h1

    # The previous state feeds in through w2, so the old sensitivity is scaled by
    # w2; the direct partials are x (for w1) and h1 (for w2).
    d_state_d_w1 = w2 * gt1_w1 + x
    d_state_d_w2 = w2 * gt1_w2 + h1

    return state, d_state_d_w1, d_state_d_w2


# Unroll the single recurrent node for three steps on random inputs and check
# that the sensitivities accumulated by compute() match what autograd reports.
w1 = torch.rand(2, requires_grad=True)
w2 = torch.rand(2, requires_grad=True)
h0 = 0
x1 = torch.rand(2, requires_grad=True)
# First step: the previous state and both running sensitivities start at zero.
h1, dh_w1, dh_w2 = compute(x1, h0, 0, 0, 0, w1, w2)

x2 = torch.rand(2, requires_grad=True)
h2, dh_w1, dh_w2 = compute(x2, h1, h0, dh_w1, dh_w2, w1, w2)

x3 = torch.rand(2, requires_grad=True)
h3, dh_w1, dh_w2 = compute(x3, h2, h1, dh_w1, dh_w2, w1, w2)

y = torch.rand(2)
loss = torch.nn.MSELoss()(h3, y)
loss.backward()

# dLoss/dh3 for the mean-reduced MSE: 2 * (h3 - y) / N, with N taken from the
# tensor instead of being hard-coded.
error_grad = 2 * (h3 - y) / h3.numel()
fwd_grad_w1 = dh_w1 * error_grad
fwd_grad_w2 = dh_w2 * error_grad

# The recurrences gt_w1 = w2 * gt1_w1 + x and gt_w2 = w2 * gt1_w2 + h1
# should reproduce the autograd gradients.
print(w1.grad, w2.grad, fwd_grad_w1, fwd_grad_w2)

# Fail loudly on a mismatch instead of relying on eyeballing the printout.
torch.testing.assert_close(fwd_grad_w1.detach(), w1.grad)
torch.testing.assert_close(fwd_grad_w2.detach(), w2.grad)

tensor([ 0.7943, -0.3047]) tensor([ 0.2122, -0.0638]) tensor([ 0.7943, -0.3047], grad_fn=<MulBackward0>) tensor([ 0.2122, -0.0638], grad_fn=<MulBackward0>)


In [89]:


# 800x200 canvas bound to `c`. The drawing cells further down create their own
# 800x300 canvas and rebind `c`, so this instance is not the one they draw on.
c = ipycanvas.Canvas(width=800, height=200)

def draw_node(c, x, y, radius=40, color="red"):
    """Stroke a circle outline representing a network node.

    Args:
        c: Canvas-like object exposing a ``stroke_style`` attribute and a
            ``stroke_circle(x, y, radius)`` method.
        x: Centre x coordinate.
        y: Centre y coordinate.
        radius: Circle radius; defaults to the previously hard-coded 40.
        color: Stroke colour; defaults to the previously hard-coded "red".

    Note:
        ``c.stroke_style`` is left set to ``color`` afterwards, so later
        strokes on the same canvas use it until something changes it.
    """
    c.stroke_style = color
    c.stroke_circle(x, y, radius)

def draw_recurrent_node(c, x=60, y=100):
    """Draw a node with a self-loop on top to mark it as recurrent.

    The node itself comes from ``draw_node``. The loop is made of two
    quadratic curves running from 20 px right of the centre to 20 px left of
    it, peaking 80 px above the centre, and ends in a small filled triangle.

    Args:
        c: Canvas to draw on.
        x: Node centre x coordinate.
        y: Node centre y coordinate.
    """
    draw_node(c, x, y)

    # Geometry of the loop, all relative to the node centre.
    loop_base_y = y - 30
    loop_top_y = loop_base_y - 50
    start_x, end_x = x + 20, x - 20
    apex_x = start_x + (end_x - start_x) // 2

    # Arc: rise from the right foot to the apex, then drop to the left foot.
    c.begin_path()
    c.move_to(start_x, loop_base_y)
    c.quadratic_curve_to(start_x + 5, loop_top_y, apex_x, loop_top_y)
    c.quadratic_curve_to(end_x - 5, loop_top_y, end_x, loop_base_y)
    c.stroke()

    # Arrowhead at the left foot; its middle vertex sits `head` px below the foot.
    head = 5
    head_points = (
        (end_x - head, loop_base_y),
        (end_x, loop_base_y + head),
        (end_x + head, loop_base_y),
    )
    c.begin_path()
    for px, py in head_points:
        c.line_to(px, py)
    c.fill()
    
def draw_arrow(c, x0, y0, x1, y1, arrow_width=5, direction="right"):
    """Draw a straight line from (x0, y0) to (x1, y1) with a filled head at the end.

    The head is an axis-aligned triangle whose tip is (x1, y1). It is NOT
    rotated to follow the line; ``direction`` alone picks its orientation.

    Args:
        c: Canvas to draw on.
        x0, y0: Start of the shaft.
        x1, y1: End of the shaft and tip of the arrowhead.
        arrow_width: Offset of the two base corners from the tip along each axis.
        direction: "right", "left", "up" or "down". Any other value falls
            back to the "up" head, which is what earlier versions drew for
            every unrecognised direction.
    """
    # Offsets of the two base corners from the tip, in units of arrow_width.
    # "down" used to fall through to the "up" head, which pointed the wrong way.
    base_offsets = {
        "right": ((-1, -1), (-1, 1)),
        "left": ((1, -1), (1, 1)),
        "up": ((1, 1), (-1, 1)),
        "down": ((1, -1), (-1, -1)),
    }
    corners = base_offsets.get(direction, base_offsets["up"])

    # Shaft.
    c.begin_path()
    c.move_to(x0, y0)
    c.line_to(x1, y1)
    c.stroke()

    # Head. NOTE(review): starts with line_to on a fresh path, presumably relying
    # on canvas semantics where that just sets the starting point - confirm.
    c.begin_path()
    for dx, dy in corners:
        c.line_to(x1 + dx * arrow_width, y1 + dy * arrow_width)
    c.line_to(x1, y1)
    c.fill()


In [90]:
# Diagram: two recurrent nodes (labelled 0 and 1) connected to each other in
# both directions, each fed by both plain nodes below (labelled 2 and 3).
c = ipycanvas.Canvas(width=800, height=300)

# Top row: the recurrent nodes.
for node_x in (60, 200):
    draw_recurrent_node(c, x=node_x, y=100)

# Connections between the two recurrent nodes, one per direction.
draw_arrow(c, x0=100, y0=100, x1=160, y1=100)
draw_arrow(c, x0=160, y0=110, x1=100, y1=110, direction="left")

# Straight connections from each bottom node to the node directly above it.
for node_x in (60, 200):
    draw_arrow(c, x0=node_x, y0=180, x1=node_x, y1=140, direction="up")

# Crossed connections from each bottom node to the opposite top node.
for src_x, dst_x in ((60, 190), (200, 70)):
    draw_arrow(c, x0=src_x, y0=180, x1=dst_x, y1=140, direction="up")

# Bottom row: the plain nodes.
for node_x in (60, 200):
    draw_node(c, x=node_x, y=220)

# Node index labels.
c.font = '24px serif'
for label, text_x, text_y in (("2", 55, 230), ("0", 55, 110), ("3", 195, 230), ("1", 195, 110)):
    c.fill_text(label, text_x, text_y)

c

Canvas(height=300, width=800)

In [92]:
def calc_next_G_fc_loop(G, W, inputs, hiddens):
    """Element-by-element update of the sensitivity tensor for one linear step.

    For every entry this computes

        G_new[k, i, j] = sum_p W[k, p] * G[p, i, j]  +  (z[j] if i == k else 0)

    where ``p`` runs over ``range(G.shape[0])`` and ``z`` is column 0 of
    ``hiddens`` followed by column 0 of ``inputs``.

    Args:
        G: 3-D tensor, indexed ``[k, i, j]``. Presumably
            d(hidden_k)/d(W[i, j]) - TODO confirm.
        W: 2-D weight matrix; only ``W[k, p]`` with ``k, p < G.shape[0]`` is read.
        inputs: 2-D tensor; column 0 is read.
        hiddens: 2-D tensor; column 0 is read.

    Returns:
        A new tensor shaped like ``G``; ``G`` itself is not modified.
    """
    n_out, n_rows, n_cols = G.shape[:3]
    n_hidden = hiddens.shape[0]
    G_next = torch.zeros_like(G, requires_grad=False)

    for out in range(n_out):
        for row in range(n_rows):
            for col in range(n_cols):
                # Recurrent part: previous sensitivities pushed through W.
                # p should just belong to the hiddens 0,1
                for src in range(n_out):
                    G_next[out, row, col] += W[out, src] * G[src, row, col]

                # Direct part: only weights in row `out` feed unit `out` directly.
                if row != out:
                    continue
                if col < n_hidden:
                    G_next[out, row, col] += hiddens[col, 0]
                else:
                    G_next[out, row, col] += inputs[col - n_hidden, 0]

    return G_next

def calc_next_G_fc(G, W, inputs, hiddens):
    """Vectorised update of the sensitivity tensor for one linear step.

    Computes, for all entries at once,

        G_new[k, i, j] = sum_p W[k, p] * G[p, i, j]  +  (z[j] if i == k else 0)

    where ``p`` runs over ``range(G.shape[0])`` and ``z`` is column 0 of
    ``hiddens`` followed by column 0 of ``inputs``. This is the same formula
    as ``calc_next_G_fc_loop`` but without Python-level loops.

    Args:
        G: 3-D tensor of shape ``(K, I, J)`` with ``K <= I``. Presumably
            ``G[k, i, j]`` is d(hidden_k)/d(W[i, j]) - TODO confirm.
        W: 2-D weight matrix; only its top-left ``K x K`` block is read.
        inputs: 2-D tensor; column 0 is read.
        hiddens: 2-D tensor; column 0 is read. ``hiddens.shape[0] +
            inputs.shape[0]`` must be at least ``J``.

    Returns:
        A new tensor shaped like ``G``. ``G`` is not modified.
    """
    num_out, _, num_cols = G.shape[:3]

    # Recurrent part: one (K x K) @ (K x I*J) product replaces the per-(k, i)
    # row-vector products.
    propagated = torch.mm(W[:num_out, :num_out], G.flatten(start_dim=1)).reshape(G.shape)

    # Direct part: row k of the weights feeds unit k directly, and the partial
    # with respect to W[k, j] is the j-th entry of the stacked [hiddens; inputs].
    z = torch.cat((hiddens[:, 0], inputs[:, 0]))[:num_cols]
    direct = torch.zeros_like(propagated)
    diag = torch.arange(num_out, device=G.device)
    direct[diag, diag] = z  # z broadcasts over the K selected rows

    return propagated + direct

def forward_grad_2hidden_calc_activation(hiddens, G):
    """Apply a sigmoid to the pre-activations and chain its derivative into G.

    Args:
        hiddens: Pre-activation values, one row per hidden unit (shape
            ``(K, 1)`` in this notebook; a 1-D ``(K,)`` tensor also works).
        G: 3-D tensor of shape ``(K, I, J)``; slice ``G[k]`` is scaled by
            the sigmoid derivative of unit ``k``.

    Returns:
        Tuple ``(new_hiddens, G_scaled)`` where ``new_hiddens`` is
        ``sigmoid(hiddens)`` and ``G_scaled[k] = G[k] * s_k * (1 - s_k)``.

    The scaling is done out-of-place: the caller's ``G`` is left untouched,
    which also makes the function usable when ``G`` is a leaf tensor with
    ``requires_grad=True`` (in-place writes to such a tensor raise).
    """
    # Evaluate the sigmoid once and reuse it for the derivative s * (1 - s).
    new_hiddens = torch.sigmoid(hiddens)
    sigmoid_deriv = new_hiddens * (1 - new_hiddens)

    # One scale factor per leading index of G, broadcast over the other two axes.
    num_units = G.shape[0]
    scale = sigmoid_deriv[:num_units].reshape(num_units, 1, 1)

    return new_hiddens, G * scale

def forward_grad_2hidden_calc_fc(G, W, inputs, hiddens):
    """Run one recurrent step: linear layer, sensitivity update, then activation.

    Args:
        G: Sensitivity tensor carried over from the previous step.
        W: Weight matrix; its first ``hiddens.shape[0]`` rows are multiplied
            with the stacked ``[hiddens; inputs]`` column.
        inputs: Input column for this step.
        hiddens: Hidden-state column from the previous step.

    Returns:
        Tuple ``(new_hiddens, G)`` as produced by
        ``forward_grad_2hidden_calc_activation``.
    """
    num_hidden = hiddens.shape[0]

    # Linear part: pre-activation of every hidden unit.
    stacked = torch.cat((hiddens, inputs))
    pre_activation = torch.mm(W[:num_hidden], stacked)

    # Sensitivities of the pre-activation, computed from the pre-step state.
    G_linear = calc_next_G_fc(G, W, inputs, hiddens)

    # Non-linearity, applied to both the state and its sensitivities.
    activated, G_activated = forward_grad_2hidden_calc_activation(pre_activation, G_linear)
    return (activated, G_activated)
    

def forward_grad_2hidden_fc(num_hidden=2, num_input=2, num_itr=10):
    """Check forward-accumulated weight gradients against autograd.

    A small recurrent layer is run for ``num_itr`` steps on random inputs
    while the sensitivity tensor ``G`` is carried forward. The loss gradient
    is then pushed through ``G`` and compared with ``W.grad`` obtained from
    ``backward()``.

    Args:
        num_hidden: Number of hidden units (default 2, the original setup).
        num_input: Number of input units (default 2).
        num_itr: Number of recurrent steps; must be at least 1 so that the
            loss depends on ``W``.

    Raises:
        AssertionError: If the two gradients disagree beyond the tolerance.

    Prints the loss, the loss gradient with respect to the final hidden state,
    the forward-accumulated gradient and ``W.grad``.
    """
    num_nodes = num_hidden + num_input
    # Gij^k
    # NOTE(review): open question from the original author -
    # should G be k=num_hidden, i=num_nodes, j=num_hidden?
    # The seed sensitivities and the initial state are constants, so they do
    # not need to take part in autograd themselves.
    G = torch.zeros((num_hidden, num_nodes, num_nodes))
    W = torch.rand((num_nodes, num_nodes), requires_grad=True)
    hidden = torch.zeros((num_hidden, 1))

    for _ in range(num_itr):
        inputs = torch.rand((num_input, 1))
        hidden, G = forward_grad_2hidden_calc_fc(G, W, inputs, hidden)

    y = torch.rand((num_hidden, 1))
    error = torch.nn.MSELoss()(hidden, y)
    # Gradient of the mean-reduced MSE with respect to the final hidden state.
    error_grad = 2 * (hidden - y) / hidden.numel()
    print(error, error_grad)
    error.backward()

    # Chain rule: weight each unit's sensitivities by dLoss/d(hidden_k), then sum over k.
    G_grad = G * error_grad.reshape(-1, 1, 1)
    forward_mode_grad = G_grad.sum(dim=0)
    print(forward_mode_grad)
    print(W.grad)
    # assert_close replaces the deprecated assert_allclose; tolerances are set
    # explicitly and are looser than assert_close's float32 defaults.
    torch.testing.assert_close(forward_mode_grad.detach(), W.grad, rtol=1e-4, atol=1e-5)
    

    
    
# Run the check: prints the loss and both gradients, and raises if the
# forward-accumulated gradient does not match W.grad.
forward_grad_2hidden_fc()

tensor(0.0969, grad_fn=<MseLossBackward0>) tensor([[-0.2021],
        [ 0.3911]], grad_fn=<DivBackward0>)
tensor([[-0.0304, -0.0329, -0.0031, -0.0189],
        [ 0.0416,  0.0453,  0.0060,  0.0279],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<SumBackward1>)
tensor([[-0.0304, -0.0329, -0.0031, -0.0189],
        [ 0.0416,  0.0453,  0.0060,  0.0279],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])


In [77]:
# Diagram: two recurrent nodes (labelled 0 and 1) with a single connection
# from node 0 to node 1, each fed only by the plain node directly below it
# (labelled 2 and 3).
c = ipycanvas.Canvas(width=800, height=300)

# Top row: the recurrent nodes.
for node_x in (60, 200):
    draw_recurrent_node(c, x=node_x, y=100)

# One-way connection between the recurrent nodes.
draw_arrow(c, x0=100, y0=100, x1=160, y1=100)

# Straight connections from each bottom node to the node directly above it.
for node_x in (60, 200):
    draw_arrow(c, x0=node_x, y0=180, x1=node_x, y1=140, direction="up")

# Bottom row: the plain nodes.
for node_x in (60, 200):
    draw_node(c, x=node_x, y=220)

# Node index labels.
c.font = '24px serif'
for label, text_x, text_y in (("2", 55, 230), ("0", 55, 110), ("3", 195, 230), ("1", 195, 110)):
    c.fill_text(label, text_x, text_y)

c

Canvas(height=300, width=800)