In [1]:
using Distributions, CSV, DataFrames, ProgressMeter, Plots
include("../../utils_1d.jl")

MethodOfLines (generic function with 1 method)

In [2]:
# Model parameters:
# - T is the horizon and S0 the initial spot.
# - μ and σ enter the state dynamics below.
# - K is the strike of the terminal payoff.
# - q is only read by the Black–Scholes benchmark.
# - r and R are the two rates in the driver: r multiplies the positive part of
#   y - z/σ and R the negative part. Setting R = r makes the driver linear.
T = 1.0
r = 0.03
R = r
μ = 0.03
S0 = 100.0
σ = 0.2
K = 100.0
q = 0.0

# State coefficients, both proportional to the state x.
drift(x) = μ * x
diffusion(x) = σ * x

# Generator f(t, x, y, z); t and x are accepted but not used.
# With R = r it reduces to -r*y - (μ - r)*z/σ.
function driver(t, x, y, z)
    gap = y - z / σ
    return -r * max(gap, 0.0) - R * min(gap, 0.0) - μ * z / σ
end

# Call payoff (x - K)⁺. For a put, use terminal(x) = max(K - x[1], 0) instead.
terminal(x) = max(x - K, zero(x - K))

# Bundle the problem for the solver. Later cells read bsde.T and bsde.X0;
# NOTE(review): bsde.X0 is presumably the S0 passed here — confirm against BSDE.
bsde = BSDE(
    T, S0,
    drift, diffusion,
    driver, terminal,
);

In [3]:
# Time-step counts to measure, and the column header "g", "measurement_type",
# then one column per count.
Nₜs = [10, 20, 50, 100, 200]
header = ["g"; "measurement_type"; string.(Nₜs)]

# Each entry is unpacked as (scheme, EXPINT) and passed on to MethodOfLines.
schemes = [[DP5(), false]]

### designing grids
# Spatial domain [0, 2·X0], split at X0 into a left part and a right part.
domain = [0.0, 2 * bsde.X0]
Nₗ = 1000
Δₗ = (bsde.X0 - domain[1]) / Nₗ   # average step on the left part
Nᵣ = 1000
Δᵣ = (domain[2] - bsde.X0) / Nᵣ   # average step on the right part

# g is passed as both of the first two TavellaRandallGrid arguments.
# NOTE(review): presumably the left/right clustering parameters — confirm.
g = 50
spatial_grid = TavellaRandallGrid(g, g, domain[1], bsde.X0, domain[2], Nₗ, Nᵣ)

In [4]:
### BENCHMARKs
"""
    price(t, s)

Return the Black–Scholes price of a European call at time `t` with spot `s`.

The strike `K`, rate `r`, dividend yield `q`, volatility `σ` and maturity `T`
are read from globals. The time left to expiry is `T - t`. When that is not
positive, the payoff `max(s - K, 0.0)` is returned instead.

The spot is replaced by `abs(s)` before pricing.
NOTE(review): presumably this tolerates slightly negative grid points — confirm.
"""
function price(t::Real, s::Real)
    s = abs(s)
    expiry = T - t
    expiry > 0 || return max(s - K, 0.0)
    vol_sqrt_t = σ * sqrt(expiry)   # σ√τ, shared by d1 and d2
    d1 = (log(s / K) + (r - q + σ^2 / 2) * expiry) / vol_sqrt_t
    d2 = d1 - vol_sqrt_t
    return s * exp(-q * expiry) * cdf(Normal(), d1) -
           K * exp(-r * expiry) * cdf(Normal(), d2)
end

price (generic function with 1 method)

In [5]:
# For each time-step count:
# 1. solve the BSDE with MethodOfLines;
# 2. compare against the Black–Scholes benchmark `price` and report timing and errors;
# 3. save a wireframe of the clipped numerical solution.
for Nₜ in [57990, 57995, 58000, 58005]
    scheme, EXPINT = schemes[1]
    println("Nₜ=", Nₜ, ":"); flush(stdout)
    exc_start = time()
    res = MethodOfLines(bsde, spatial_grid, Nₜ, scheme, EXPINT)
    exc_stop = time()
    sol = res[1]; s_grid = res[2]

    # The solution may carry one extra trailing time column, which is dropped.
    # Any other column count cannot be aligned with the benchmark, so report it.
    ncols = size(sol, 2)
    if ncols != Nₜ + 1 && ncols != Nₜ + 2
        @warn "skipping Nₜ=$Nₜ: expected $(Nₜ + 1) or $(Nₜ + 2) columns, got $ncols"
        continue
    end
    sol_t = sol[:, 1:Nₜ+1]

    # Benchmark column k holds the exact price with (k-1)·T/Nₜ left to expiry,
    # because price(T - t, s) prices with expiry t. Column 1 is the payoff.
    bs_sol = zeros(spatial_grid.N, Nₜ + 1)
    for (k, t) in enumerate((bsde.T / Nₜ) .* (0:Nₜ))
        bs_sol[:, k] = price.(Ref(T - t), s_grid)
    end
    abs_err = abs.(bs_sol - sol_t)

    # Rows whose grid point lies strictly inside (0.8·X0, 1.2·X0).
    ind_slice = findall(x -> 0.8 * bsde.X0 < x < 1.2 * bsde.X0, spatial_grid.grid)
    indmin, indmax = extrema(ind_slice)

    # NOTE(review): assumes grid point Nₗ+1 is X0 (the left/right split) — confirm.
    index = spatial_grid.Nₗ + 1
    abs_target = abs_err[index, end]

    println("elapsed_time: ", exc_stop - exc_start); flush(stdout)
    println("absolute error at ", S0, ":", abs_target); flush(stdout)
    println("maximum abs error in [", S0 * 0.8, ", ", S0 * 1.2, "] :",
            maximum(abs_err[indmin:indmax, :])); flush(stdout)

    # Plot on roughly 51×51 nodes. Strides are clamped to 1 so that Nₜ < 50 or
    # N < 50 does not produce a zero-step range.
    tstride = max(1, Nₜ ÷ 50)
    sstride = max(1, spatial_grid.N ÷ 50)
    tidx = 1:tstride:Nₜ+1
    sidx = 1:sstride:spatial_grid.N
    # Calendar time index i, i.e. t = (i-1)·T/Nₜ, is solution column Nₜ+2-i.
    # Taking it from sol_t keeps the plot aligned with the error check above.
    truncated = clamp.(sol_t[sidx, Nₜ + 2 .- tidx], 0.0, 250.0)
    times = range(0.0, T; length = Nₜ + 1)
    pl = Plots.wireframe(times[tidx], spatial_grid.grid[sidx], truncated,
                         camera = (48, 44))
    savefig(pl, string("plotNt", Nₜ, ".pdf"))
end

Nₜ=57990:
elapsed_time: 600.5065779685974
absolute error at 100.0:1.9890449211790293e12
maximum abs error in [80.0, 120.0] :3.303539372733063e31
Nₜ=57995:
elapsed_time: 629.9378490447998
absolute error at 100.0:1.324397305683405
maximum abs error in [80.0, 120.0] :1.727101740624271e18
Nₜ=58000:
elapsed_time: 665.0848369598389
absolute error at 100.0:0.00046117822622804283
maximum abs error in [80.0, 120.0] :121826.3221037855
Nₜ=58005:
elapsed_time: 668.8608529567719
absolute error at 100.0:0.0004611782219772209
maximum abs error in [80.0, 120.0] :0.013255616127729297
