In [1]:
import torch
import ipycanvas
import numpy as np

from IPython.display import display


In [2]:
def compute(x, h1, gt1_w1, gt1_w2, w1, w2):
    # Single recursive node
    y = w1 * x + w2 * h1
    # dh1/dh2 * dh2/dw + dh1/dw
    gt_w1 = w2 * gt1_w1 + x
    gt_w2 = w2 * gt1_w2 + h1
    return y, gt_w1, gt_w2


w1 = torch.rand(2, requires_grad=True)
w2 = torch.rand(2, requires_grad=True)
h0 = 0
x1 = torch.rand(2, requires_grad=True)
h1, dh_w1, dh_w2 = compute(x1, h0, 0, 0, w1, w2)
#assert h1 == 1

x2 = torch.rand(2, requires_grad=True)
h2, dh_w1, dh_w2 = compute(x2, h1, dh_w1, dh_w2, w1, w2)
#assert h2 == (1.5 + 0.5)

x3 = torch.rand(2, requires_grad=True)
h3, dh_w1, dh_w2 = compute(x3, h2, dh_w1, dh_w2, w1, w2)

y = torch.rand(2)
loss = torch.nn.MSELoss()(h3, y)
loss.backward()

error_grad = 2 * (h3 - y) / 2
# YAY! gt1_w1 is correct! With gt_w1 = w2 * gt1_w1 + x
# YAY! gt1_w2 is correct! With gt_w2 = w2 * gt1_w2 + h1
print(w1.grad, w2.grad, dh_w1*error_grad, dh_w2*error_grad)

tensor([ 0.3456, -0.3735]) tensor([ 0.1448, -0.2196]) tensor([ 0.3456, -0.3735], grad_fn=<MulBackward0>) tensor([ 0.1448, -0.2196], grad_fn=<MulBackward0>)


In [3]:


c = ipycanvas.Canvas(width=800, height=200)

def draw_node(c, x, y):
    c.stroke_style = "red"
    c.stroke_circle(x, y, 40)

def draw_recurrent_node(c, x=60, y=100):

    draw_node(c, x, y)

    c.begin_path()
    y_arc = y - 40
    x0 = x+15
    x1 = x-15
    c.move_to(x0, y_arc)
    c.quadratic_curve_to(x0 + 5, y_arc-25, x0 + (x1 - x0)//2, y_arc - 25)
    c.quadratic_curve_to(x1 - 5, y_arc-25, x1, y_arc)
    c.stroke()
    c.begin_path()
    arrow_width = 5
    c.line_to(x1-arrow_width, y_arc)
    c.line_to(x1, y_arc+arrow_width)
    c.line_to(x1+arrow_width, y_arc)
    c.fill()
    
def draw_arrow(c, x0, y0, x1, y1, arrow_width=5, direction="right"):
    c.begin_path()
    c.move_to(x0, y0)
    c.line_to(x1, y1)
    c.stroke()
    c.begin_path()
    if direction == "right":
        c.line_to(x1-arrow_width, y1-arrow_width)
        c.line_to(x1-arrow_width, y1+arrow_width)
        c.line_to(x1, y1)
    elif direction == "left":
        c.line_to(x1+arrow_width, y1-arrow_width)
        c.line_to(x1+arrow_width, y1+arrow_width)
        c.line_to(x1, y1)
    elif direction == "down":
        c.line_to(x1+arrow_width, y1-arrow_width)
        c.line_to(x1-arrow_width, y1-arrow_width)
        c.line_to(x1, y1)
    else:
        c.line_to(x1+arrow_width, y1+arrow_width)
        c.line_to(x1-arrow_width, y1+arrow_width)
        c.line_to(x1, y1)
    c.fill()


In [4]:
c = ipycanvas.Canvas(width=800, height=300)



def draw_node_layer(c, x_left, x_right, y_top, y_bottom, radius, n1, n2):
    draw_recurrent_node(c, x=x_right, y=y_top)
    draw_recurrent_node(c, x=x_right, y=y_bottom)

    draw_arrow(c, x0=x_left+radius, y0=y_top,    x1=x_right-radius, y1=y_top,    direction="right")
    draw_arrow(c, x0=x_left+radius, y0=y_bottom, x1=x_right-radius, y1=y_bottom, direction="right")
    draw_arrow(c, x0=x_left+radius, y0=y_top,    x1=x_right-radius, y1=y_bottom, direction="right")
    draw_arrow(c, x0=x_left+radius, y0=y_bottom, x1=x_right-radius, y1=y_top,    direction="right")

    draw_arrow(c, x0=x_right+5, y0=y_bottom-radius, x1=x_right+5, y1=y_top+radius,    direction="up")
    draw_arrow(c, x0=x_right-5, y0=y_top+radius,    x1=x_right-5, y1=y_bottom-radius, direction="down")
    
    c.fill_text(n1, x_right-5, y_top+10)
    c.fill_text(n2, x_right-5, y_bottom+10)

draw_node(c, x=60, y=100)
draw_node(c, x=60, y=220)
c.font = '24px serif'
c.fill_text("2", 55, 230)
c.fill_text("3", 55, 110)

draw_node_layer(c, x_left=60, x_right=200, y_top=100, y_bottom=220, radius=40, n1="0", n2="1")


c.font = '12px serif'
c.fill_text("num_hidden=2 num_inputs=2", 40, 10)

display(c)

Canvas(height=300, width=800)

In [5]:
def calc_next_G_fc_loop(G, W, inputs, hiddens):
    G_new = torch.zeros_like(G, requires_grad=False)
    for k in range(G.shape[0]):
        for i in range(G.shape[1]):
            for j in range(G.shape[2]):
                # p should just belong to the hiddens 0,1
                for p in range(G.shape[0]):
                    G_new[k,i,j] += W[k,p] * G[p,i,j]
                # if the destination equals a node to backprop from
                if i == k:
                    # hiddens = 0,1
                    if j >= hiddens.shape[0]:
                        G_new[k,i,j] += inputs[j-hiddens.shape[0], 0]
                    else:
                        G_new[k,i,j] += hiddens[j,0]
    
    return G_new

def calc_next_G_fc(G, W, inputs, hiddens):
    G_new = torch.zeros_like(G, requires_grad=False)
    for k in range(G.shape[0]):
        for i in range(G.shape[1]):
            G_new[k,i:i+1,:] += torch.mm(W[k:k+1,0:G.shape[0]], G[:,i,:])
        
        i = k
        for j in range(G.shape[2]):
            if j >= hiddens.shape[0]:
                G_new[k,i,j] += inputs[j-hiddens.shape[0], 0]
            else:
                G_new[k,i,j] += hiddens[j,0]

    return G_new

def forward_grad_2hidden_calc_activation(hiddens, G):
    new_hiddens = torch.sigmoid(hiddens)
    sigmoid_deriv = torch.sigmoid(hiddens) * (1 - torch.sigmoid(hiddens))
    for k in range(G.shape[0]):
        G[k,:,:] *= sigmoid_deriv[k]
    return new_hiddens, G

def forward_grad_2hidden_calc_fc(G, W, inputs, hiddens):
    # Forward prop
    z = torch.cat((hiddens, inputs))
    new_hiddens = torch.mm(W[:hiddens.shape[0]], z)
    # Calculate gradients
    G = calc_next_G_fc(G, W, inputs, hiddens)
    # Nonlinear activation
    #new_hiddens, G = forward_grad_2hidden_calc_activation(new_hiddens, G)
    
    return (new_hiddens, G)
    

def forward_grad_2hidden_fc():
    num_hidden = 3
    num_input = 4
    num_nodes = num_hidden + num_input
    # Gij^k
    G = torch.zeros((num_hidden, num_hidden, num_nodes), requires_grad=False)
    W = torch.rand((num_hidden, num_nodes), requires_grad=True)
    num_itr = 5
    hidden = torch.zeros((num_hidden,1), requires_grad=False)
    
    
    for itr in range(num_itr):
        inputs = torch.rand((num_input,1))        
        hidden, G = forward_grad_2hidden_calc_fc(G, W, inputs, hidden)
        
    y = torch.rand((num_hidden,1))
    error = torch.nn.MSELoss()(hidden,y)
    error_grad = 2 * (hidden - y) / num_hidden
    print(error, error_grad)
    error.backward()
    G_grad = G
    for i in range(error_grad.shape[0]):
        G_grad[i,:,:] *= error_grad[i]
    print("===== Calculated =====")
    print(G_grad.sum(dim=0))
    print("===== Actual =====")
    print(W.grad)
    torch.testing.assert_allclose(G_grad.sum(dim=0), W.grad)
    

    
    
forward_grad_2hidden_fc()

tensor(330.9234, grad_fn=<MseLossBackward0>) tensor([[10.9442],
        [12.0475],
        [13.2783]], grad_fn=<DivBackward0>)
===== Calculated =====
tensor([[538.6920, 421.2439, 533.8552, 246.4551, 298.0990,  83.9137,  89.0699],
        [231.0111, 210.3058, 250.2911,  65.7034,  78.2123,  32.0094,  30.8143],
        [482.5910, 389.2635, 486.9231, 206.2603, 248.4964,  73.6297,  77.3962]],
       grad_fn=<SumBackward1>)
===== Actual =====
tensor([[538.6920, 421.2439, 533.8553, 246.4551, 298.0990,  83.9137,  89.0699],
        [231.0111, 210.3058, 250.2911,  65.7034,  78.2123,  32.0094,  30.8143],
        [482.5910, 389.2635, 486.9231, 206.2603, 248.4964,  73.6297,  77.3962]])




In [6]:
c = ipycanvas.Canvas(width=800, height=300)



def draw_node_layer(c, x_left, x_right, y_top, y_bottom, radius, n1, n2):
    draw_recurrent_node(c, x=x_right, y=y_top)
    draw_recurrent_node(c, x=x_right, y=y_bottom)

    draw_arrow(c, x0=x_left+radius, y0=y_top,    x1=x_right-radius, y1=y_top,    direction="right")
    draw_arrow(c, x0=x_left+radius, y0=y_bottom, x1=x_right-radius, y1=y_bottom, direction="right")
    draw_arrow(c, x0=x_left+radius, y0=y_top,    x1=x_right-radius, y1=y_bottom, direction="right")
    draw_arrow(c, x0=x_left+radius, y0=y_bottom, x1=x_right-radius, y1=y_top,    direction="right")

    draw_arrow(c, x0=x_right+5, y0=y_bottom-radius, x1=x_right+5, y1=y_top+radius,    direction="up")
    draw_arrow(c, x0=x_right-5, y0=y_top+radius,    x1=x_right-5, y1=y_bottom-radius, direction="down")
    
    c.fill_text(n1, x_right-5, y_top+10)
    c.fill_text(n2, x_right-5, y_bottom+10)

draw_node(c, x=60, y=100)
draw_node(c, x=60, y=220)
c.font = '24px serif'
c.fill_text("4", 55, 230)
c.fill_text("5", 55, 110)

draw_node_layer(c, x_left=60, x_right=200, y_top=100, y_bottom=220, radius=40, n1="2", n2="3")
draw_node_layer(c, x_left=200, x_right=340, y_top=100, y_bottom=220, radius=40, n1="0", n2="1")


c.font = '12px serif'
c.fill_text("num_hidden=2 num_inputs=2", 40, 10)
    
c

Canvas(height=300, width=800)

In [196]:
def forward_grad_4hidden_fc():
    num_hidden = 4
    num_input = 2
    num_nodes = num_hidden + num_input
    # Gij^k
    G = torch.zeros((num_hidden, num_hidden, num_nodes), requires_grad=False)
    W = torch.rand((num_hidden, num_nodes), requires_grad=True).detach()
    # 16 weights vs 4*6=24 weights
    # No connections to 0,1 from 4,5
    W[0:2, 4:6] = 0
    # No connections to 2,3 from 0,1
    W[2:4, 0:2] = 0
    W.requires_grad = True
    # G_new[k,i,j] += W[k,p] * G[p,i,j]
    # G[0] is still influenced by G[2]
    # At each time step, every node is influenced by every weight by a unique value
    num_itr = 5
    hidden = torch.zeros((num_hidden,1), requires_grad=False)
    
    
    for itr in range(num_itr):
        inputs = torch.rand((num_input,1))        
        hidden, G = forward_grad_2hidden_calc_fc(G, W, inputs, hidden)
        
    y = torch.rand((num_hidden,1))
    error = torch.nn.MSELoss()(hidden,y)
    error_grad = 2 * (hidden - y) / num_hidden
    print(error, error_grad)
    error.backward()
    G_grad = G
    for i in range(error_grad.shape[0]):
        G_grad[i,:,:] *= error_grad[i]
    print("===== Calculated =====")
    print(G_grad.sum(dim=0))
    print("===== Actual =====")
    print(W.grad)
    torch.testing.assert_allclose(G_grad.sum(dim=0), W.grad)
    
forward_grad_4hidden_fc()

tensor(47.9713, grad_fn=<MseLossBackward0>) tensor([[5.0851],
        [4.1420],
        [1.3573],
        [1.7647]], grad_fn=<DivBackward0>)
===== Calculated =====
tensor([[43.2505, 41.6537, 38.2249, 45.3476, 27.4623, 24.5477],
        [33.5932, 32.2978, 29.0316, 34.4059, 20.3953, 18.3798],
        [25.8956, 26.1423, 33.2734, 40.3334, 36.4044, 34.5322],
        [34.2687, 35.3809, 50.4421, 61.5805, 64.4329, 64.5077]],
       grad_fn=<SumBackward1>)
===== Actual =====
tensor([[43.2505, 41.6537, 38.2249, 45.3476, 27.4623, 24.5477],
        [33.5932, 32.2978, 29.0316, 34.4059, 20.3953, 18.3798],
        [25.8956, 26.1423, 33.2734, 40.3334, 36.4044, 34.5322],
        [34.2687, 35.3809, 50.4421, 61.5805, 64.4329, 64.5078]])


In [61]:
[i for i in range(2,4)]

[2, 3]

In [7]:
device = torch.device("cpu") # cuda" if torch.cuda.is_available() else "cpu")

class Node():
    def __init__(self, num_raw_inputs, num_hidden_inputs, num_nodes, is_output_node=True):
        self.num_raw_inputs = num_raw_inputs
        num_inputs = num_raw_inputs + num_hidden_inputs + 1
        self.W = torch.rand((1, num_inputs), requires_grad=True, device=device)
        self.hidden = torch.zeros((1,1), requires_grad=False, device=device)
        
        self.is_output_node = is_output_node
        if self.is_output_node:
            self.G = torch.zeros((num_inputs), requires_grad=False, device=device)
        
        
    def forward(self, x):
        old_hidden = self.hidden
        
        z = torch.cat((x, old_hidden))
        self.z = z
        self.pending_hidden = torch.mm(self.W, z)
        
    def calculate_update(self, g_in, Ws):
        # Need the G's of any node that leads to me including myself
        self.pending_G = self.G.clone().detach()
        with torch.no_grad():
            self.pending_G[-1] += self.W[0,-1] * self.G[-1]
            
            # My result is because of any weights in the path to me
            # Self k == p
            for i in range(self.G.shape[0]):
                update = self.W[0,k+self.num_raw_inputs] * g_in[k]
                self.pending_G[k+self.num_raw_inputs] += update
            # Others
            for j in range(g_in[k].shape[0]):
                # My G is increased by any path from a hidden to me
                update = Ws[k][j] * g_in[k][j]
                torch.testing.assert_allclose(update.shape, torch.Size([]))
                self.pending_G[j] += update

            for j in range(self.z.shape[0]):
                self.pending_G[j] += self.z[j,0]
            
    def commit_update(self):
        self.G = self.pending_G
        print("New G:", self.G[-2])
        self.hidden = self.pending_hidden
    
nodes = [Node(2, 1, 1) for _ in range(2)]
num_itrs = 2
for _ in range(num_itrs):
    x_in = torch.rand((2,1))
    x_in_n0 = torch.cat((x_in, nodes[1].hidden))
    nodes[0].forward(x_in_n0)
    x_in_n1 = torch.cat((x_in, nodes[0].hidden))
    nodes[1].forward(x_in_n1)
    
    nodes[0].calculate_update(g_in=[nodes[1].G)
    nodes[1].calculate_update(g_in=[nodes[0].G)
    
    nodes[0].commit_update()
    nodes[1].commit_update()
    
    if False:
        y = torch.rand((2,1))
        y_hat = torch.cat((nodes[0].hidden, nodes[1].hidden))
        error = torch.nn.MSELoss()(y_hat, y)
        error_grad = 2 * (y_hat - y) / 2
        error.backward(retain_graph=True)
        print("===== Actual =====")
        print(nodes[0].W.grad, nodes[1].W.grad)
        print("===== Calculated =====")
        print(nodes[0].G * error_grad[0].item(), nodes[1].G * error_grad[1].item())
        nodes[0].W.grad = None
        nodes[1].W.grad = None

y = torch.rand((2,1))
y_hat = torch.cat((nodes[0].hidden, nodes[1].hidden))
error = torch.nn.MSELoss()(y_hat, y)
error_grad = 2 * (y_hat - y) / 2
error.backward(retain_graph=False)
print(nodes[0].W.grad, nodes[1].W.grad)
print(nodes[0].G * error_grad[0].item(), nodes[1].G * error_grad[1].item())

IndentationError: expected an indented block (301224314.py, line 31)