In [19]:
import torch
import ipycanvas
import numpy as np


In [20]:
def compute(x, h1, h2, gt1_w1, gt1_w2, w1, w2):
    # Single recursive node
    y = w1 * x + w2 * h1
    # dh1/dh2 * dh2/dw + dh1/dw
    gt_w1 = w2 * gt1_w1 + x
    gt_w2 = w2 * gt1_w2 + h1
    return y, gt_w1, gt_w2


w1 = torch.rand(2, requires_grad=True)
w2 = torch.rand(2, requires_grad=True)
h0 = 0
x1 = torch.rand(2, requires_grad=True)
h1, dh_w1, dh_w2 = compute(x1, h0, 0, 0, 0, w1, w2)
#assert h1 == 1

x2 = torch.rand(2, requires_grad=True)
h2, dh_w1, dh_w2 = compute(x2, h1, h0, dh_w1, dh_w2, w1, w2)
#assert h2 == (1.5 + 0.5)

x3 = torch.rand(2, requires_grad=True)
h3, dh_w1, dh_w2 = compute(x3, h2, h1, dh_w1, dh_w2, w1, w2)

y = torch.rand(2)
loss = torch.nn.MSELoss()(h3, y)
loss.backward()

error_grad = 2 * (h3 - y) / 2
# YAY! gt1_w1 is correct! With gt_w1 = w2 * gt1_w1 + x
# YAY! gt1_w2 is correct! With gt_w2 = w2 * gt1_w2 + h1
print(w1.grad, w2.grad, dh_w1*error_grad, dh_w2*error_grad)

tensor([ 0.2090, -0.3023]) tensor([ 0.0921, -0.0306]) tensor([ 0.2090, -0.3023], grad_fn=<MulBackward0>) tensor([ 0.0921, -0.0306], grad_fn=<MulBackward0>)


In [21]:


c = ipycanvas.Canvas(width=800, height=200)


def draw_recurrent_node(c, x=60, y=100):

    c.stroke_style = "red"
    c.stroke_circle(x, y, 40)

    c.begin_path()
    y_arc = y - 30
    x0 = x+20
    x1 = x-20
    c.move_to(x0, y_arc)
    c.quadratic_curve_to(x0 + 5, y_arc-50, x0 + (x1 - x0)//2, y_arc - 50)
    c.quadratic_curve_to(x1 - 5, y_arc-50, x1, y_arc)
    c.stroke()
    c.begin_path()
    arrow_width = 5
    c.line_to(x1-arrow_width, y_arc)
    c.line_to(x1, y_arc+arrow_width)
    c.line_to(x1+arrow_width, y_arc)
    c.fill()
    
def draw_arrow(c, x0, y0, x1, y1, arrow_width=5, direction="right"):
    c.begin_path()
    c.move_to(x0, y0)
    c.line_to(x1, y1)
    c.stroke()
    c.begin_path()
    if direction == "right":
        c.line_to(x1-arrow_width, y1-arrow_width)
        c.line_to(x1-arrow_width, y1+arrow_width)
        c.line_to(x1, y1)
    else:
        c.line_to(x1+arrow_width, y1+arrow_width)
        c.line_to(x1-arrow_width, y1+arrow_width)
        c.line_to(x1, y1)
    c.fill()

draw_recurrent_node(c, x=60, y=100)
draw_recurrent_node(c, x=200, y= 100)
draw_arrow(c, 100, 100, 160, 100)
draw_arrow(c, 60, 180, 60, 140, direction="up")
draw_arrow(c, 200, 180, 200, 140, direction="up")

c.font = '12px serif'
c.fill_text("w0", 40, 170)
c.fill_text("w1", 50, 15)
c.fill_text("w2", 120, 90)
c.fill_text("w3", 180, 170)
c.fill_text("w4", 190, 15)
    
c

Canvas(height=200, width=800)

In [60]:
# gt_w1 = w2 * gt1_w1 + x
# gt_w2 = w2 * gt1_w2 + h1

def compute_2(xs, hiddens, gs, weights):
    y0 = torch.mm(weights[0], xs[0]) + torch.mm(weights[1], hiddens[0])
    assert y0 == weights[0,0,0] * xs[0,0,0] + weights[1,0,0] * hiddens[0,0,0]
    y1 = torch.mm(weights[3], xs[1]) + torch.mm(weights[4], hiddens[1]) + torch.mm(weights[2], hiddens[0])
    assert y1 == weights[3,0,0] * xs[1,0,0] + weights[4,0,0] * hiddens[1,0,0] + weights[2,0,0] * hiddens[0,0,0]
    
    # [[ dh0/dh0_t, dh0/dh1_t ]
    #  [ dh1/dh0_t, dh1/dh1_t ]]
    Ht = torch.zeros((hiddens.shape[0], hiddens.shape[0]))
    Ht[0,0] = weights[1,0,0]
    Ht[0,1] = 0
    Ht[1,0] = weights[2,0,0]
    Ht[1,1] = weights[4,0,0]
    
    # [[ dh0/dw0 dh0/dw1 dh0/dw2 dh0/dw3 dh0/dw4]
    #  [ dh1/dw0 dh1/dw1 dh1/dw2 dh1/dw3 dh1/dw4]]
    Ft = torch.zeros((hiddens.shape[0], 5))
    Ft[0,0] = xs[0,0,0]
    Ft[0,1] = hiddens[0,0,0]
    
    # What are these?
    Ft[1,0] = 0
    Ft[1,1] = 0
    Ft[1,2] = hiddens[0,0,0]
    
    Ft[1,3] = xs[1,0,0]
    Ft[1,4] = hiddens[1,0,0] 
    
    gs2 = torch.mm(Ht, gs) + Ft
    #gs2[:] = 0
    #gs2[0,0] = weights[1,0,0] * gs[0,0] + xs[0,0,0]
    
    print(gs)
    #torch.testing.assert_allclose(gs_p, gs)
    
    y = torch.stack((y0, y1), dim=0)
    
    assert y.shape == hiddens.shape
    
    return (y, gs2)
    

def forward_grad_2node():
    num_nodes = 2
    weights = torch.rand((5, 1, 1), requires_grad=True)
    # two inputs at t=0
    x0 = torch.rand((num_nodes, 1, 1), requires_grad=True)
    hiddens = torch.zeros((num_nodes, 1, 1), requires_grad=True)
    gs = torch.zeros((num_nodes, 5), requires_grad=False)
    
    (hiddens, gs) = compute_2(x0, hiddens, gs, weights)
    torch.testing.assert_allclose(hiddens, torch.tensor([
        [[ x0[0,0,0] * weights[0,0,0] ]],
        [[ x0[1,0,0] * weights[3,0,0] ]]
    ]))
    
    if True:
        x1 = torch.rand((num_nodes, 1, 1), requires_grad=True)
        (hiddens, gs) = compute_2(x1, hiddens, gs, weights)
    
        if True:
            x2 = torch.rand((num_nodes, 1, 1), requires_grad=True)
            (hiddens, gs) = compute_2(x2, hiddens, gs, weights)
    
    
    y_actual = torch.rand(2,1,1)
    assert y_actual.shape == hiddens.shape
    if False:
        error = torch.nn.MSELoss()(hiddens[0], y_actual[0])
        error.backward()
        error_grad = 2 * (hiddens[0] - y_actual[0])
        gs_grad = (gs * error_grad)[0].view(5,1,1)
        print("Error backprop through node 0")
        print("GS:", gs_grad)
        print("Actual:", weights.grad)
        torch.testing.assert_allclose(gs_grad, weights.grad)
    
    elif True:
        error = torch.nn.MSELoss()(hiddens[1], y_actual[1])
        error.backward()
        error_grad = 2 * (hiddens[1] - y_actual[1])[0,0]
        gs_grad = (gs * error_grad)[1].view(5,1,1)
        print("Error backprop through node 1")
        print("GS:", gs_grad)
        print("Actual:", weights.grad)
        print(weights.grad[2,0,0]/error_grad)
        torch.testing.assert_allclose(gs_grad, weights.grad)
        
    else:
        error = torch.nn.MSELoss()(hiddens, y_actual)
        error.backward()
        error_grad = 2 * (hiddens - y_actual) / 2
        print(error_grad)
        gs_grad = (gs * error_grad)
        print("Error backprop through both")
        print("GS:", gs_grad)
        print("Actual:", weights.grad)
        #torch.testing.assert_allclose(gs_grad, weights.grad)
    
    
forward_grad_2node()

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
tensor([[0.9092, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.2666, 0.0000]], grad_fn=<AddBackward0>)
tensor([[0.4751, 0.7062, 0.0000, 0.0000, 0.0000],
        [0.8663, 0.0000, 0.7062, 1.0641, 0.0051]], grad_fn=<AddBackward0>)
Error backprop through node 1
GS: tensor([[[0.3818]],

        [[0.3714]],

        [[0.3113]],

        [[0.1893]],

        [[0.3834]]], grad_fn=<ViewBackward0>)
Actual: tensor([[[0.3818]],

        [[0.3714]],

        [[0.3113]],

        [[0.1893]],

        [[0.3834]]])
tensor(0.5639, grad_fn=<DivBackward0>)


In [57]:
torch.rand(2,3,5).sum(dim=0)


1.6085

    model_clone = copy.deepcopy(model)
    optimizer = torch.optim.SGD(model_clone.parameters(), 0.) #used for zero.grad() function only here
    T = x.shape[0]    
    theta = torch.nn.utils.convert_parameters.parameters_to_vector(model_clone.parameters())      
    n_params = len(theta)
    
    h = torch.randn(model.hidden_size, dtype=model.dtype, requires_grad=True)
    dh_dtheta = torch.zeros((model.hidden_size, n_params), dtype=model.dtype) 
    dhnext_dhprev = torch.zeros((model.hidden_size, model.hidden_size), dtype=model.dtype)
    partial_dh_dtheta = torch.zeros_like(dh_dtheta)
    
    for t in range(T): 
        h_next = model_clone.h_step(x[t].view(1,1), h.view(1, model.hidden_size)).view(model.hidden_size)
        #compute dh/dhprev and partial dh/dparams
        for i_h in range(model.hidden_size):
            v = torch.zeros(model.hidden_size, dtype=model.dtype)
            v[i_h] = 1.  
            if i_h == model.hidden_size-1:
                h_next.backward(v) 
            else:
                h_next.backward(v, retain_graph=True) 
            dhnext_dhprev[i_h] = h.grad.clone()  
            h.grad = None             
            grad_generator = (param.grad if param.grad is not None else torch.zeros_like(param) for param in model_clone.parameters())                         
            theta_grad = torch.nn.utils.convert_parameters.parameters_to_vector(grad_generator)                         
            partial_dh_dtheta[i_h] = theta_grad.clone()                
            optimizer.zero_grad()        
        dh_dtheta = torch.mm(dhnext_dhprev, dh_dtheta) + partial_dh_dtheta                  
        h_next = h_next.detach()
        h = h_next.clone()            
        h.requires_grad = True           
   
    y_pred = model_clone.h_to_logits(h.view(1, model.hidden_size))
    loss = loss_func(y_pred, y.view(1))    
    loss.backward()        
    #add partial derivative of loss wrt. params and (loss wrt h) times (h wrt params)
    grad_generator = (param.grad if param.grad is not None else torch.zeros_like(param) for param in model_clone.parameters())            
    partial_theta_grad = torch.nn.utils.convert_parameters.parameters_to_vector(grad_generator)    
    theta_grad = partial_theta_grad.clone() + h.grad.clone() @ dh_dtheta   
    return loss.item(), theta_grad

In [18]:

class RNNCell(torch.nn.Module):
    def __init__(self):
        super(RNNCell, self).__init__()
        self.h2h = torch.nn.Linear(1,1, bias=False)
        self.x2h = torch.nn.Linear(1,1, bias=False)
        
    def forward(self, x, h):
        h = self.h2h(h) + self.x2h(x)
        return h

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.r1 = RNNCell()
        self.r2 = RNNCell()
        
    def forward(self, x1, x2, h1, h2,):
        h1 = self.r1(x=x, h=h1)
        h2 = self.r2(x=h1, h=h2)
        return (h1, h2)

def model_example():
    BS = 1
    x0 = torch.rand((BS, 1))
    h1 = torch.zeros((1,))
    h2 = torch.zeros((1,))
    model = Net()
    h1, h2 = model(x0, h1, h2)
    print(h1, h2)
    x1 = torch.rand((BS, 1))
    h1, h2 = model(x1, h1, h2)
    print(h1, h2)
    print([p for p in model.parameters()])
    theta = torch.nn.utils.convert_parameters.parameters_to_vector(model.parameters())
    print(theta)
    
model_example()
    

tensor([[-0.1369]], grad_fn=<AddBackward0>) tensor([[-0.0677]], grad_fn=<AddBackward0>)
tensor([[0.1197]], grad_fn=<AddBackward0>) tensor([[0.1040]], grad_fn=<AddBackward0>)
[Parameter containing:
tensor([[-0.8895]], requires_grad=True), Parameter containing:
tensor([[-0.1743]], requires_grad=True), Parameter containing:
tensor([[-0.6627]], requires_grad=True), Parameter containing:
tensor([[0.4943]], requires_grad=True)]
tensor([-0.8895, -0.1743, -0.6627,  0.4943], grad_fn=<CatBackward0>)
