In [7]:
using Distributions;
using Random;
using DataFrames;
using CSV;
using ForwardDiff;
using Statistics;
using LinearAlgebra;
using Tracker;
using Flux;
using Flux: params, 
            Dense, 
            Chain, 
            glorot_normal, 
            normalise, 
            Optimiser,
            train!;


#initialise error file and create row headers
df_row = DataFrame(step = "step",value="value")
    #,mc_val = "mc value", sd_err = "sd error")

CSV.write("hjb_test.csv", df_row, append = true);


N = 8
T = 8/24
h = T/N

#parameters of the function
mu_bar = 0
sigma_bar = sqrt(2)

#dimensions
d = 10

#data sizes
batch_samples = 256

#training epochs parameters
train_steps = 600
learning_rate = 0.1 #initial learning rate
learn_rate_decrease = 400 #frequency of LR decay

#initial condition phi 
function phi(u)
    return sum(u.^2; dims=1).^(0.25)
end


#non-linear function 
f(x,y,z) = - sum(z.^2; dims=1)

function y_sde(m)  

    y_evolve = Array(zeros((d,batch_samples)))
    y_1 = Array(zeros((d,batch_samples))) 
    y_2 = Array(zeros((d,batch_samples))) 
    
    for i in 1:m 
        
        y_evolve .+= rand(Normal(0,sqrt(2*T/N)),(d, batch_samples))
                        #evolve the SDE for N-n steps
        
        #one step earlier
        if i == m-1 y_1 = y_evolve end 
        #all the way
        if i == m y_2 = y_evolve end
    end   
    
    #see if this makes a difference? 
#    if m == 1
#        y_1 = Array(zeros((d,batch_samples))) 
#    end

    return y_1,y_2
end

#define network layers
input = Dense(d, d + 10, relu; 
                       bias = false, 
                       init = glorot_normal)

hidden1 = Dense(d + 10 , d + 10, relu;
                        bias = false,
                        init = glorot_normal)

hidden2 = Dense(d + 10, d + 10 , relu;
                        bias = false,
                        init = glorot_normal)

#no activation on the last layer
output = Dense(d + 10 ,1,identity)

batch_norm_layer = BatchNorm(d + d, identity;
                                        initβ = zeros, 
                                        initγ = ones,
                                        ϵ = 1e-6, 
                                        momentum = 0.9)

#define network architecture for fixed model
m_fix = Chain(input,
    #        batch_norm_layer,
            hidden1,
    #        batch_norm_layer,
            hidden2,
    #        batch_norm_layer,
            output)

#define network layers for a clean set of parameters
input_params = Dense(d, d + 10 , relu; 
                       bias = false, 
                       init = glorot_normal)

hidden_params1 = Dense(d + 10 , d + 10, relu;
                        bias = false,
                        init = glorot_normal)

hidden_params2 = Dense(d + 10 , d + 10 , relu;
                        bias = false,
                        init = glorot_normal)

#no activation on the last layer
output_params = Dense(d + 10,1,identity)

#define network architecture for trianing model
m_parameters = Chain(input_params,
    #        batch_norm_layer,
            hidden_params1,
    #        batch_norm_layer,
            hidden_params2,
    #        batch_norm_layer,
            output_params)

#fixed set of starting parameters for re-setting the nn at each step. 
ps_start = Flux.params(m_parameters)

#define network layers
input_train = Dense(d, d + 10 , relu; 
                       bias = false, 
                       init = glorot_normal)

hidden_train1 = Dense(d + 10 , d + 10, relu;
                        bias = false,
                        init = glorot_normal)

hidden_train2 = Dense(d + 10, d + 10, relu;
                        bias = false,
                        init = glorot_normal)

#no activation on the last layer
output_train = Dense(d + 10,1,identity)

#define network architecture for trianing model
m_train = Chain(input_train,
    #        batch_norm_layer,
            hidden_train1,
    #        batch_norm_layer,
            hidden_train2,
    #        batch_norm_layer,
            output_train)


loss(u,v) = mean((m_train(u) - v).^2)

#initialise V_0
V_i = phi

y_test = Array(ones((d,batch_samples)))

#set the starting point

for i in 1:N
    
    #reset parameters on trainable model
    Flux.loadparams!(m_train, ps_start)     

    ps = params(m_train)
    
    opt = Optimiser(ExpDecay(learning_rate,0.1,learn_rate_decrease,1e-8),ADAM()) #optimiser

    #find parameters for nth nn 
    for k in 1:train_steps

        #find the 2 last time steps of SDE
        y = y_sde(N - i + 1)
        
        #format into training data y_1 and y_2
        y_n = y[1]
        y_n_minus_1 = y[2]
                        
        #Finding Gradients -
        grad_V_i = Array(ones((d,batch_samples)))

        for j in 1:batch_samples
            grad_V_i_j = ForwardDiff.jacobian(V_i,y_n_minus_1[:,j])
            grad_V_i[:,j] = grad_V_i_j
        end 
     #   grad_V_i = ForwardDiff.jacobian(V_i,y_n_minus_1)
        
       # Vv_y = V_i(y_n_minus_1) + h*f(y_n_minus_1,V_i(y_n_minus_1),grad_V_i)
         Vv_y = V_i(y_n_minus_1) + h*f(y_n_minus_1,V_i(y_n_minus_1),grad_V_i)

        data = [(y_n,Vv_y)]  

        #parameter update step
        train!(loss,ps,data,opt)
        
        #manual update to opt at 500 steps
   #     if k== 500 
   #         opt = ADAM(learning_rate/100) #optimiser
   #     end

    y_test = y_n
        
    end

    
    #load learned params into fixed model
    Flux.loadparams!(m_fix, ps)     
    
    #set new V_i to be the fixed model with new parameters
    V_i = m_fix #nn function with parameters 
    
    #take the test value from nn
    val = mean(V_i(y_test))
    
    println("v_i output",val)    
 #   println("y_n", mean(y_n;dims = 2))
    
    df_row = DataFrame(step =i,
                        value=val)

    CSV.write("hjb_test.csv", df_row, append = true)

end



v_i output1.5201885478815134
v_i output1.4525123668730013
v_i output1.2670165923403
v_i output1.1685132222585828
v_i output1.0351758698890043
v_i output0.8859008126094915
v_i output0.7266224163872155
v_i output0.3542717397212982
