In [78]:
using Distributions;
using Random;
using DataFrames;
using CSV;
using ForwardDiff;
using Statistics;
using LinearAlgebra;
using Tracker;
using Flux;
using Flux: params, 
            Dense, 
            Chain, 
            glorot_normal, 
            normalise, 
            Optimiser,
            train!;


# Start the results file with a literal "step,value" line.
# The row is written as data with append = true, so an existing file is extended
# rather than replaced.
# NOTE(review): presumably this row stands in for a header because CSV.write omits
# column names when appending — confirm against the CSV.jl version in use.
df_row = DataFrame(; step="step", value="value")
CSV.write("hjb_test_v2.csv", df_row; append=true);


# --- time grid ---
N = 8         # number of time steps
T = 1 / 3     # length of the time interval covered by the N steps
h = T / N     # size of one time step

# --- parameters of the function ---
mu_bar = 0             # not referenced anywhere else in the visible code
sigma_bar = sqrt(2)    # scale applied to the Gaussian increments in y_sde

# --- dimensions ---
d = 10

# --- data sizes: number of sample paths drawn per training step ---
batch_samples = 10000

# --- training epochs parameters ---
train_steps = 6000            # gradient updates per network
learning_rate = 0.1           # initial learning rate
learn_rate_decrease = 400     # frequency (in steps) of the learning-rate decay

"""
    phi(u)

Initial condition: the sum of squares of `u` along its first dimension, raised to
the power 0.25.

The first dimension is reduced with `dims=1` and therefore kept with length 1, so a
matrix input gives a one-row matrix and a vector input gives a one-element vector.
"""
function phi(u)
    sum_of_squares = sum(u .^ 2; dims=1)
    return sum_of_squares .^ (0.25)
end


"""
    f(x, y, z)

Non-linear function: the negated sum of squares of `z` along its first dimension
(the reduced dimension is kept with length 1).

`x` and `y` are accepted but not used.
"""
function f(x, y, z)
    sum_of_squares = sum(z .^ 2; dims=1)
    return -sum_of_squares
end

"""
    y_sde(m)

Draw the two consecutive states surrounding time step `m` for a fresh batch.

Returns `(y_1, y_2)`, both of size `(d, batch_samples)`:

- `y_1` is all zeros when `m == 1`; otherwise it is `sigma_bar` times a single
  zero-mean Gaussian draw with standard deviation `sqrt((m-1)*T/N)`.
- `y_2` is `y_1` plus `sigma_bar` times an independent zero-mean Gaussian draw with
  standard deviation `sqrt(T/N)`.

Reads the globals `d`, `batch_samples`, `sigma_bar`, `T` and `N`.
"""
function y_sde(m)
    dims = (d, batch_samples)
    increment = Normal(0, sqrt(T / N))

    y_1 = if m == 1
        # y_0
        zeros(dims)
    else
        sigma_bar .* rand(Normal(0, sqrt((m - 1) * T / N)), dims)
    end

    y_2 = y_1 .+ (sigma_bar .* rand(increment, dims))

    return y_1, y_2
end

# Layers of the fixed model: three bias-free ReLU layers of width d + 10 followed by
# a single linear output unit.
input = Dense(d, d + 10, relu; bias=false, init=glorot_normal)
hidden1 = Dense(d + 10, d + 10, relu; bias=false, init=glorot_normal)
hidden2 = Dense(d + 10, d + 10, relu; bias=false, init=glorot_normal)

#no activation on the last layer
output = Dense(d + 10, 1, identity)

# Batch normalisation is currently disabled: this layer is not part of any Chain.
# Its width must equal the hidden width d + 10 (it used to be d + d, which only
# coincides with d + 10 when d == 10).
batch_norm_layer = BatchNorm(
    d + 10, identity; initβ=zeros, initγ=ones, ϵ=1e-6, momentum=0.9
)

#define network architecture for fixed model
# (batch_norm_layer was previously slotted in after input, hidden1 and hidden2)
m_fix = Chain(input, hidden1, hidden2, output)

# Layers holding a clean set of parameters: same shape as the other networks
# (three bias-free ReLU layers of width d + 10, then one linear output unit).
input_params = Dense(d, d + 10, relu; bias=false, init=glorot_normal)
hidden_params1 = Dense(d + 10, d + 10, relu; bias=false, init=glorot_normal)
hidden_params2 = Dense(d + 10, d + 10, relu; bias=false, init=glorot_normal)

#no activation on the last layer
output_params = Dense(d + 10, 1, identity)

#define network architecture for the reference-parameter model
# (batch_norm_layer is not inserted between the layers)
m_parameters = Chain(input_params, hidden_params1, hidden_params2, output_params)

#fixed set of starting parameters for re-setting the nn at each step.
ps_start = Flux.params(m_parameters)

# Layers of the trainable model: three bias-free ReLU layers of width d + 10
# followed by a single linear output unit.
input_train = Dense(d, d + 10, relu; bias=false, init=glorot_normal)
hidden_train1 = Dense(d + 10, d + 10, relu; bias=false, init=glorot_normal)
hidden_train2 = Dense(d + 10, d + 10, relu; bias=false, init=glorot_normal)

#no activation on the last layer
output_train = Dense(d + 10, 1, identity)

#define network architecture for training model
# (batch_norm_layer is not inserted between the layers)
m_train = Chain(input_train, hidden_train1, hidden_train2, output_train)


"""
    loss(u, v)

Mean squared difference between the output of the global model `m_train` on `u`
and the target `v`.
"""
function loss(u, v)
    residual = m_train(u) - v
    return mean(residual .^ 2)
end

# V_i is the current value-function approximation; it starts as phi (V_0).
V_i = phi

# Zero-filled (d, batch_samples) buffers. y_test is overwritten during training;
# y_test_mc is only referenced from commented-out code.
y_test = zeros(d, batch_samples)
y_test_mc = zeros(d, batch_samples)

# Backward recursion over the N time steps.
#
# At outer step i a freshly reset m_train is fitted so that
#     m_train(y_n) ≈ V_i(y_n_minus_1) + h * f(y_n_minus_1, V_i(y_n_minus_1), ∇V_i(y_n_minus_1))
# where (y_n, y_n_minus_1) = y_sde(N - i). The fitted weights are then copied into
# m_fix, which becomes V_i for the next outer step. One row (step, value) is appended
# to hjb_test_v2.csv per outer step.
for i in 0:N-1
    # These globals are reassigned inside the loop. Without the declaration the
    # assignments only reach the globals under the notebook/REPL soft-scope rule;
    # run as a script they would create loop-local variables and V_i would be
    # undefined at its first use.
    global V_i, y_test, df_row

    println("Iteration ", i)

    #reset parameters on trainable model
    Flux.loadparams!(m_train, ps_start)

    ps = params(m_train)

    # NOTE(review): ExpDecay rescales the gradient before ADAM sees it. ADAM
    # normalises its input, so the decayed factor may have little effect on the
    # actual step size; later Flux documentation composes these as
    # Optimiser(ADAM(), ExpDecay(1.0)). Left unchanged because swapping the order
    # alters training — confirm the intent.
    opt = Optimiser(ExpDecay(learning_rate, 0.1, learn_rate_decrease, 1e-8), ADAM())

    #find parameters for nth nn
    for k in 1:train_steps

        #last two time steps of the SDE: y_n is the network input,
        #y_n_minus_1 is where the previous V_i is evaluated
        y_n, y_n_minus_1 = y_sde(N - i)

        #gradient of V_i at every sample, one column at a time
        #(every column is overwritten, so the buffer needs no initial values)
        grad_V_i = similar(y_n_minus_1)
        for j in 1:batch_samples
            grad_V_i[:, j] = ForwardDiff.jacobian(V_i, y_n_minus_1[:, j])
        end

        #evaluate V_i once and reuse it for both terms of the target
        V_prev = V_i(y_n_minus_1)
        Vv_y = V_prev + h * f(y_n_minus_1, V_prev, grad_V_i)

        data = [(y_n, Vv_y)]

        #parameter update step
        train!(loss, ps, data, opt)

        #manual update to opt at 500 steps: from here on a new plain ADAM with a
        #100x smaller learning rate replaces the composite optimiser
        if k == 500
            opt = ADAM(learning_rate / 100) #optimiser
        end

        #keep the latest input batch; the one from the final step is used for the
        #reported value below
        y_test = y_n
    end

    println("time ", (N - i) / N)

    #load learned params into fixed model
    Flux.loadparams!(m_fix, ps)

    #set new V_i to be the fixed model with new parameters
    V_i = m_fix #nn function with parameters

    #take the test value from nn (mean output over the last training batch)
    val = mean(V_i(y_test))

    println("v_i output",val)

    # Disabled Monte Carlo reference value:
    #  y_test_mc = if i==0 y_test
    #  else (y_test .+ sigma_bar .* rand(Normal(0,sqrt(i*T/N)),(d, batch_samples))) end
    #  mc_sample = sum(exp.( - phi(y_test_mc)))./batch_samples
    #  mc_val = -log.(mc_sample)
    #
    # Disabled error bars (untranslated Python sketch):
    #  sd_errs = np.sqrt((mean_sq_samples - mean_samples_sq)/mc_samples)
    #  print("u_ref = ", -np.log(mean_samples))
    #  print("upper end = ", -np.log(mean_samples-sd_errs))
    #  print("lower end = ", -np.log(mean_samples+sd_errs))
    #
    #  println("m_c output",mc_val)

    println("#################")

    df_row = DataFrame(step = i, value = val)

    CSV.write("hjb_test_v2.csv", df_row, append = true)

end



Iteration 0
time 1.0
v_i output1.5697021981161812
#################
Iteration 1
time 0.875
v_i output1.5684641029119015
#################
Iteration 2
time 0.75
v_i output1.5625712409239978
#################
Iteration 3
time 0.625
v_i output1.5608615538468062
#################
Iteration 4
time 0.5
v_i output1.5596993616690438
#################
Iteration 5
time 0.375
v_i output1.557649229891107
#################
Iteration 6
time 0.25
v_i output1.5560061608311997
#################
Iteration 7
time 0.125
v_i output1.5553784370422363
#################
