# In [None]:
using Distributions;
using Random;
using DataFrames;
using CSV;
using Statistics;
using LinearAlgebra;
using ForwardDiff;
using Tracker;
using Flux;
using Flux: params, 
            Dense, 
            Chain, 
            glorot_normal, 
            normalise, 
            Optimiser,
            train!;


#initialise the results table: a one-row DataFrame whose entries are the
#column headers ("step", "value") for the results CSV
df_row = DataFrame(step = "step",value="value")

#writing the header row is currently disabled; re-enable to start a fresh results file
#CSV.write("deep_pde_approx/non_linear_bs_test_4.csv", df_row, append = true);

# ---- time discretisation ---------------------------------------------------
T = 1/3      # time horizon
N = 96       # number of splitting / Euler-Maruyama steps
h = T/N      # step size

# reference value for d = 10: the result we want to reproduce at T = 1/3
u_ref = 40.7611353

# ---- parameters used by f and y_sde ----------------------------------------
# delta and R enter the nonlinearity f
delta = 2/3
R = 0.02
# mu_bar multiplies the time step and sigma_bar the Brownian increment in y_sde
mu_bar = 0.02
sigma_bar = 0.2
# breakpoints (v_h, gamma_h) and (v_l, gamma_l) of the piecewise-linear rate in f
v_h = 50
v_l = 70
gamma_h = 0.2
gamma_l = 0.02

# ---- problem size and training budget --------------------------------------
d = 10
batch_samples = 4096   # a smaller value such as 100 is handy for quick tests
train_steps = 3000

learning_rate = 0.1          # initial learning rate
learn_rate_decrease = 2500   # number of steps between learning-rate decays

# d x batch_samples matrix with every entry equal to 50.0
y_start = fill(50.0, d, batch_samples)

"""
    phi(y)

Initial condition of the recursion (the starting `V_0`): the smallest entry of
every column of `y`, returned as a 1×ncols row (a 1-element vector when `y`
is a vector).
"""
phi(y) = minimum(y, dims = 1)

"""
    f(x, y, z)

Nonlinear term of the PDE. Only `y` is used; `x` and `z` are accepted so that
the call shape is f(x, y, z). Element-wise for arrays (also works for scalars):

    f = -(1 - delta) * rate(y) * y - R * y

where `rate` interpolates linearly from `gamma_h` at `y = v_h` to `gamma_l` at
`y = v_l` and is held at `gamma_h` / `gamma_l` outside that interval.
"""
function f(x, y, z)
    # straight line through (v_h, gamma_h) and (v_l, gamma_l) ...
    line = @. (y - v_h) * (gamma_h - gamma_l) / (v_h - v_l) + gamma_h
    # ... clipped to the interval [gamma_l, gamma_h]
    rate = @. min(max(line, gamma_l), gamma_h)
    return @. -(1 - delta) * (rate * y) - R * y
end

"""
    y_sde(m)

Simulate `batch_samples` independent `d`-dimensional paths for `m`
Euler–Maruyama steps of size `T/N`, all starting from 50. Each step multiplies
the state by `1 + (T/N)*mu_bar + sigma_bar*ΔW` with `ΔW ~ Normal(0, sqrt(T/N))`.

Returns `(state after m-1 steps, state after m steps)` as two
d×batch_samples matrices. For `m == 1` the first entry is the starting state;
for `m <= 0` both entries are the starting state.
"""
function y_sde(m)
    dt = T/N
    growth = 1 + dt*mu_bar
    increment = Normal(0, sqrt(dt))

    previous = fill(50.0, d, batch_samples)
    current = fill(50.0, d, batch_samples)
    for _ in 1:m
        # remember the state one step before the last
        previous = current
        current = current .* (growth .+ (sigma_bar .* rand(increment, (d, batch_samples))))
    end

    return previous, current
end

#define network layers for the fixed model m_fix; the training loop copies each
#newly learned parameter set into it (Flux.loadparams!(m_fix, ps))
input = Dense(d, d + 10 + 40, relu;
              bias = false,
              init = glorot_normal)

hidden1 = Dense(d + 10 + 40, d + 10 + 40, relu;
                bias = false,
                init = glorot_normal)

hidden2 = Dense(d + 10 + 40, d + 10 + 40, relu;
                bias = false,
                init = glorot_normal)

#no activation on the last layer
output = Dense(d + 10 + 40, 1, identity)

# Batch normalisation for the hidden activations (currently disabled in the
# chains). Its channel count must equal the hidden width d + 10 + 40. It was
# previously built with d + d channels, which does not match the activations
# in any of the commented-out slots, so re-enabling it would fail.
# NOTE(review): the commented-out chains reuse this single instance at several
# positions, so those positions would share one set of BatchNorm parameters and
# statistics. Construct one BatchNorm per slot if it is re-enabled.
batch_norm_layer = BatchNorm(d + 10 + 40, identity;
                             initβ = zeros,
                             initγ = ones,
                             ϵ = 1e-6,
                             momentum = 0.9)

#define network architecture for fixed model
m_fix = Chain(input,
              # batch_norm_layer,
              hidden1,
              # batch_norm_layer,
              hidden2,
              # batch_norm_layer,
              output)

# Untrained copy of the architecture whose initial weights act as a fixed
# template: the training loop reloads them into m_train before fitting each
# new network. Layers are built in the same order as before (input, two hidden,
# output) so the random initialisation draws are unchanged.
let width = d + 10 + 40
    global input_params = Dense(d, width, relu; bias = false, init = glorot_normal)
    global hidden_params1 = Dense(width, width, relu; bias = false, init = glorot_normal)
    global hidden_params2 = Dense(width, width, relu; bias = false, init = glorot_normal)
    #no activation on the last layer
    global output_params = Dense(width, 1, identity)
end

#define network architecture for the template model
#(batch-norm slots left disabled, as in the other networks)
m_parameters = Chain(input_params,
                     hidden_params1,
                     hidden_params2,
                     output_params)

#starting parameters used to reset the trainable network at each step; they
#stay fixed because m_parameters itself is never passed to train!
ps_start = Flux.params(m_parameters)

# Network that is actually optimised. The training loop overwrites its weights
# with ps_start before every time step, so the random values drawn here are not
# the ones training starts from. Construction order (input, two hidden, output)
# is kept so the random initialisation draws are unchanged.
let width = d + 10 + 40
    global input_train = Dense(d, width, relu; bias = false, init = glorot_normal)
    global hidden_train1 = Dense(width, width, relu; bias = false, init = glorot_normal)
    global hidden_train2 = Dense(width, width, relu; bias = false, init = glorot_normal)
    #no activation on the last layer
    global output_train = Dense(width, 1, identity)
end

#define network architecture for training model (batch-norm slots disabled)
m_train = Chain(input_train, hidden_train1, hidden_train2, output_train)

# Optimiser configuration: an ExpDecay schedule (initial rate learning_rate,
# factor 0.1 every learn_rate_decrease steps, floor 1e-8) chained in front of
# ADAM with its default settings.
# NOTE(review): ADAM is (up to its ε term) invariant to a constant rescaling of
# the gradients it receives, so a schedule placed *before* it has little effect
# and the effective step size is ADAM's default. If the intent is ADAM with step
# size learning_rate plus decay, the chaining probably needs revisiting. Confirm.
opt = Optimiser(
    ExpDecay(learning_rate, 0.1, learn_rate_decrease, 1e-8),
    ADAM(),
)

# mean squared error between the training network's output on u and the target v
loss(u, v) = mean(abs2.(m_train(u) - v))

#initialise V_0 with the initial condition phi; V_i always holds the most
#recently learned approximation and is used to build the next regression target
V_i = phi
#evaluation points for the printed diagnostic if train_steps == 0
y_test = Array(ones((d,batch_samples))).*50

# Splitting recursion. For i = 1, ..., N the network is fitted so that
#     V_i(X_{N-i}) ≈ V_{i-1}(X_{N-i+1})
#                   + h * f(X_{N-i+1}, V_{i-1}(X_{N-i+1}), ∇V_{i-1}(X_{N-i+1}))
# where X_k is the state after k steps of y_sde. Iterating over 1:N performs
# exactly N steps of size h = T/N, so the last network approximates the value at
# time T. Starting at i = 0 would add an (N+1)-th step (y_sde(N + 1)) and
# integrate to T + h instead.
for i in 1:N
    # V_i and df_row are globals updated by this loop. Declaring them keeps the
    # loop working in a non-interactive script, where soft-scope rules would
    # otherwise make them loop-local (leaving V_i undefined on first use).
    global V_i, df_row

    #reset parameters on trainable model to the common starting weights
    Flux.loadparams!(m_train, ps_start)
    ps = params(m_train)

    # Fresh optimiser for this time step: a copy of the unused template `opt`.
    # This way ADAM's moment estimates and the ExpDecay schedule restart, just
    # like the weights, instead of carrying over from the previous network.
    step_opt = deepcopy(opt)

    y_input = y_test

    #find parameters for the i-th network
    for k in 1:train_steps
        # y_sde(m) returns the states after m - 1 and m steps:
        #   y_n         = X_{N-i}    (network input)
        #   y_n_minus_1 = X_{N-i+1}  (where the previous approximation is evaluated)
        y_n, y_n_minus_1 = y_sde(N - i + 1)

        # Jacobian of the previous approximation at every sample (one column each).
        # NOTE(review): f currently ignores this argument, and these per-sample
        # ForwardDiff calls are probably the bulk of each training step's cost.
        grad_V_i = Array(ones((d,batch_samples)))
        for j in 1:batch_samples
            grad_V_i[:,j] = ForwardDiff.jacobian(V_i, y_n_minus_1[:,j])
        end

        # regression target; the previous approximation is evaluated once and reused
        V_prev = V_i(y_n_minus_1)
        Vv_y = V_prev + h*f(y_n_minus_1, V_prev, grad_V_i)

        #parameter update step on a single batch
        train!(loss, ps, [(y_n, Vv_y)], step_opt)

        #used for calculating the output
        y_input = y_n
    end

    #load learned params into the fixed model and make it the new V_i
    Flux.loadparams!(m_fix, ps)
    V_i = m_fix

    #take the test value (mean over the last batch of input points)
    val = mean(V_i(y_input))
    println("v_i output",val, "step", i)

    df_row = DataFrame(step =i,value=val)

  #  CSV.write("deep_pde_approx/non_linear_bs_test_4.csv", df_row, append = true)

end

