# 7.9 BPTT (backpropagation through time)

In [None]:
from pylab import *
%matplotlib notebook

def theta(x):
    """Step function built from ``sign``.

    Returns 0 where ``x`` is negative, 0.5 where ``x`` is exactly zero and
    1 where ``x`` is positive. Works on scalars and, elementwise, on arrays.
    """
    shifted = sign(x) + 1  # maps {-1, 0, 1} onto {0, 1, 2}
    return shifted * 0.5


def f(x):
    """Activation function: the hyperbolic tangent of ``x`` (elementwise for arrays)."""
    activation = np.tanh(x)
    return activation


def df(x):
    """Derivative of ``f``: ``1/cosh(x)**2`` evaluated on a soft-clipped argument.

    The argument is first squashed into (-10, 10) with ``10*tanh(x/10)`` so that
    ``cosh`` cannot overflow for large ``|x|``; for ``|x|`` well below 10 the
    squashing is close to the identity.
    """
    clipped = 10*np.tanh(x/10)  # the tanh prevents overflow in cosh
    return 1/np.cosh(clipped)**2

In [None]:
using Base: @kwdef
using Parameters: @unpack # or using UnPack

In [1]:
using Random

In [7]:
sqrt(2)

1.4142135623730951

In [None]:
# Scalar hyper-parameters of the recurrent network, stored with float type `FT`.
# Every field has a default so that `RNNParameter{FT}()` can be called with no
# arguments (the `RNN` struct uses exactly that call as the default for `param`).
@kwdef struct RNNParameter{FT}
    α::FT = 1.0
    # Divisor applied to the state update and to the recurrent/input weight changes.
    # NOTE(review): 10.0 was added only to make the zero-argument constructor
    # usable; confirm the intended time constant.
    τm::FT = 10.0
end

In [None]:
# State and weights of a recurrent network whose floating-point type is `FT`.
# Keyword-constructed: `n_in`, `n_rec`, `n_out` and `h0` are required; the weight
# arrays default to random initial values sized from those counts.
@kwdef mutable struct RNN{FT}
    # Typed as `RNNParameter{FT}` (not the bare `RNNParameter`) so the field is
    # concrete. The default only works when `RNNParameter{FT}()` can be built with
    # no arguments; otherwise pass `param` explicitly.
    param::RNNParameter{FT} = RNNParameter{FT}()
    n_in::UInt32   # column count of `w_in`
    n_rec::UInt32  # row count of `w_in`/`w_rec`; column count of `w_rec`/`w_out`
    n_out::UInt32  # row count of `w_out`
    h0::Array{FT}  # presumably the initial recurrent state (length n_rec) — TODO confirm

    # weights
    # Uniform in [-0.1, 0). NOTE(review): not centred on zero, unlike `w_out` — confirm.
    w_in::Array{FT} = 0.1*(rand(n_rec, n_in) .- 1)
    # Gaussian entries scaled by 1.5/sqrt(n_rec).
    w_rec::Array{FT} = 1.5*randn(n_rec, n_rec)/sqrt(n_rec)
    # Uniform in [-0.1, 0.1), scaled by 1/sqrt(n_rec).
    w_out::Array{FT} = 0.1*(2*rand(n_out, n_rec) .- 1)/sqrt(n_rec)
end

In [None]:
# Activation of the recurrent units and its derivative. The derivative soft-clips
# its argument with a tanh so that `cosh` cannot overflow for large |u|.
_bptt_f(u) = tanh(u)
_bptt_df(u) = 1 / cosh(10 * tanh(u / 10))^2

# Run one trial of the recurrent network and, when `training` is true, apply one
# BPTT (backpropagation through time) weight update at the end of the trial.
#
# Arguments
#   variable : network whose `h0`, `w_in`, `w_rec`, `w_out` are read; the three
#              weight arrays are modified in place when `training` is true.
#   param    : supplies the time constant `τm` (`variable.param` is not consulted).
#   inputs   : input sequence with time along the first axis, `(t_max, n_in)`.
#   training : whether to run the backward pass and update the weights.
#
# Keywords
#   targets  : target output sequence, `(t_max, n_out)`. Required: the readout
#              error cannot be formed without it.
#   eta      : learning rates for `w_in`, `w_rec` and `w_out`, in that order.
#
# Returns the readout error `targets - output` as a `(t_max, n_out)` array whose
# first row is zero (the first time step only holds the initial state).
#
# Throws `ArgumentError` if `targets` is missing or `inputs` has no time steps,
# and `DimensionMismatch` if `targets` and `inputs` differ in number of steps.
function update!(variable::RNN, param::RNNParameter, inputs::Array, training::Bool;
                 targets::Union{Nothing,AbstractArray}=nothing, eta=(0.1, 0.1, 0.1))
    targets === nothing &&
        throw(ArgumentError("update! needs the keyword argument `targets`"))

    t_max = size(inputs, 1)  # number of timesteps
    t_max >= 1 || throw(ArgumentError("inputs must contain at least one time step"))
    size(targets, 1) == t_max || throw(DimensionMismatch(
        "targets has $(size(targets, 1)) time steps but inputs has $t_max"))

    τm = param.τm
    η_in, η_rec, η_out = eta  # learning rates for w_in, w_rec, and w_out
    n_rec = Int(variable.n_rec)
    n_out = Int(variable.n_out)
    FT = eltype(variable.w_rec)

    u = zeros(FT, t_max, n_rec)    # input (feedforward plus recurrent)
    h = zeros(FT, t_max, n_rec)    # time-dependent RNN activity vector
    h[1, :] .= vec(variable.h0)    # initial state
    y = zeros(FT, t_max, n_out)    # RNN output
    err = zeros(FT, t_max, n_out)  # readout error

    # forward pass
    @views for t in 1:(t_max - 1)
        u[t + 1, :] .= variable.w_rec * h[t, :] .+ variable.w_in * inputs[t + 1, :]
        h[t + 1, :] .= h[t, :] .+ (_bptt_f.(u[t + 1, :]) .- h[t, :]) ./ τm
        y[t + 1, :] .= variable.w_out * h[t + 1, :]
        err[t + 1, :] .= targets[t + 1, :] .- y[t + 1, :]
    end

    # backward pass for BPTT
    if training
        z = zeros(FT, t_max, n_rec)
        dw_in = zero(variable.w_in)
        dw_rec = zero(variable.w_rec)
        dw_out = zero(variable.w_out)

        @views begin
            z[t_max, :] .= variable.w_out' * err[t_max, :]
            for t in t_max:-1:2
                δ = z[t, :] .* _bptt_df.(u[t, :])
                z[t - 1, :] .= z[t, :] .* (1 - 1 / τm) .+
                               variable.w_out' * err[t, :] .+
                               (variable.w_rec' * δ) ./ τm

                # Accumulate the weight changes (outer products) over the trial.
                dw_out .+= (η_out / t_max) .* (err[t, :] * h[t, :]')
                dw_rec .+= (η_rec / (t_max * τm)) .* (δ * h[t - 1, :]')
                dw_in .+= (η_in / (t_max * τm)) .* (δ * inputs[t, :]')
            end
        end

        # wait until end of trial to update weights
        variable.w_out .+= dw_out
        variable.w_rec .+= dw_rec
        variable.w_in .+= dw_in
    end

    return err
end