In [None]:
using DifferentialEquations
using Plots
using LinearAlgebra
using Statistics
using QuadGK
using ForwardDiff
using TaylorSeries
using Printf


In [None]:
# Parameters
const Order = 5  # degree of the Taylor expansion (Taylor1 variable used to expand V_0)



$$
-\frac{\partial V_0}{\partial x} = f_0(x) = \frac{dx}{dt}
$$

In [None]:
# Shift of the potential along the x axis
const shift_V_0 = 1

# Quartic potential V₀ and its force f₀ = -dV₀/dx.
# Function bindings are already constant in Julia, so the method definitions
# need no `const` prefix.
V_0(x) = (x-shift_V_0)^4/12 - 1/2*(x-shift_V_0)^2 +2/10 * (shift_V_0-4) + 1
f_0(x) = -ForwardDiff.derivative(V_0, x)

# Taylor expansion of degree `Order` around 0. V₀ is a quartic polynomial, so the
# expansion is exact whenever Order ≥ 4.
t = Taylor1(Float64, Order)
tV_0 = V_0(t)
tf_0 = -derivative(tV_0)

x = -4:0.2:4
y1 = V_0.(x)
y2 = tV_0.(x)
y3 = f_0.(x)
y4 = tf_0.(x)

# Build each panel explicitly instead of relying on the implicit "current plot":
# left = potential and its Taylor approximation, right = force and its approximation
p1 = plot(x, y1, ylims=(-2, 5), label="V_0")
plot!(p1, x, y2, ylims=(-2, 5), label="taylor approx V_0")
p2 = plot(x, y3, ylims=(-2, 5), label="f_0")
plot!(p2, x, y4, ylims=(-2,5), label="taylor approx f_0")
plot(p1, p2, layout = (1, 2))

# Scalar function (m=1)

In [None]:
# Extended dimension: number of coordinates the scalar state is embedded into
N = 2
# Averaging matrix: every entry equals 1/N, so Ω₁ * v replaces each entry of v by mean(v)
Ω₁ = ones(N, N) ./ N


"""
The embedding function
Argument:
    - X: value to evaluate (N, 1)
    - a: the Taylor coefficients of the function
    - commutative (default true): if the embedding function is commutative
Return:
    - The value of the embedding (N, 1)
"""
function CommF(X, a, commutative=true)
    N = length(X)
    M = length(a) - 1
    result = zeros(N, N)
    for i=0:M
        if commutative == 1
            result += a[i+1] * Ω₁ * diagm(0 => X)^i
        else
            result += a[i+1] * (Ω₁ * diagm(0 => X))^i
        end
    end
    return result * fill(1, (N, 1)) # b is set to be 1 here, but it can be more generalized
end


# Sample state and the Taylor coefficients of f₀ (a[i+1] multiplies xⁱ)
X = [1, 2]
const a = tf_0.coeffs
const M = length(a) - 1   # highest power index in `a`

# Show the commutative embedding first, then the non-commutative one
for commutative in (true, false)
    display(CommF(X, a, commutative))
end


$$
\frac{dX}{dt} = \sum_{i=0}^{M} a_i \, \Omega_1 \operatorname{diag}(X)^i \vec{1} - \alpha (I - \Omega_1) X
$$

In [None]:

"""
The derivative function for the dynamics
Argument:
    - dX: the placeholder to receive output
    - X: the value to evaluate
    - p: the parameter list
        - p[1]: the α in the argument
        - p[2]: whether to use commutative embedding function
        - p[3]: the number of a (i.e. N in the formula)
        - p[4..]: the list of a
Return:
    Inplace output in dX

"""
function f!(dX, X, p, t=0)
    α = p[1] # First parameter for α
    comm = p[2] # Second parameter for commutative function
    len_a = Integer(p[3])
    a = @view p[4:4+len_a - 1]
    
    result = CommF(X, a, comm)
    result -= α * (I - Ω₁) * X
    dX[1] = result[1,1]
    dX[2] = result[2,1]
    dX
end


"""
    direct_f(X, p, t=0)

Out-of-place wrapper around `f!`: allocate a `Float64` vector with one entry per
coordinate of `X`, fill it with `f!` and return it.
"""
function direct_f(X, p, t=0)
    dX = zeros(Float64, length(X))
    f!(dX, X, p, t)
    return dX
end

# Output buffer for f!: a vector with one entry per coordinate (the same shape that
# direct_f returns), not a 1×2 row matrix
dX = zeros(2)

# Parameter vector: α = 1, commutative embedding, number of coefficients, coefficients
p = vcat([1, true, length(a)], a)

println(f!(dX, [1, 2], p))
println(dX)
println(direct_f([1, 2], p))

## One dynamic

In [None]:
"""
Get the solution of the ODE problem defined by f!
Argument:
    - X₀: the initial value
    - tspan: the time range for the problem
    - p: the parameter to pass in the problem
Return:
    - The solution

"""
function getSolution(X₀, tspan, p)
    prob = ODEProblem(f!, X₀, tspan, p)
    sol = solve(prob)
    return sol
end

# Floating-point initial state and time span for the solver
X₀ = [3.0, 2.0]
tspan = (0.0, 10.0)

sol = getSolution(X₀, tspan, p)

# Mean of the coordinates at every saved time point
x = [mean(u) for u in sol.u]

println(x[end])
# Plot against the solver's time points: with adaptive steps the sample index is
# not proportional to time
plot(sol.t, x, label="x", xlabel="t")



## Slope field

In [None]:
# Flattened 2-D grid of all (x, y) pairs: x varies fastest, then y
meshgrid(x, y) = (vec([xi for xi in x, _ in y]), vec([yi for _ in x, yi in y]))
x, y = meshgrid(-4:0.2:4, -4:0.2:4)

# Evaluate the vector field once per grid point, then split it into components
tmp = map((xi, yi) -> direct_f([xi; yi], p), x, y)
u = [i[1] for i in tmp]
v = [i[2] for i in tmp]

# Shrink the arrows (in place) so neighbouring ones do not overlap
scale = 0.002
u .*= scale
v .*= scale
quiver(x, y, quiver=(u, v))
plot!(size=(800,800))

## Initial points

In [None]:
plot()

x, y = meshgrid(-4:0.1:4, -4:0.1:4)
grid = [[a, b] for (a, b) in zip(x, y)]

global_count, local_count = 0, 0

for i in 1:1000
    # Explicit `global`: when this runs as a script (hard scope) the counters would
    # otherwise be new locals and `+= 1` would throw UndefVarError
    global global_count, local_count

    # Uniformly random initial point in the square [-4, 4]²
    X0 =  -4. .+ 8. .* rand(2)
    sol = getSolution(X0, (0.0, 10.0), p)

    # Extract x and y values
    x_values = [point[1] for point in sol.u]
    y_values = [point[2] for point in sol.u]

    # Classify by the sign of the final y coordinate: y < 0 is counted as "global"
    # (blue), anything else as "local" (red)
    if y_values[end] < 0
        color = :blue
        global_count += 1
    else
        color = :red
        local_count += 1
    end

    plot!(x_values, y_values, title="Results of different initial points",
        xlabel="x", ylabel="y", legend=false, linecolor = color)
end

@printf("Probability of global convergence: %.2f%%\n", global_count / (global_count + local_count) * 100)

plot!(size=(800,800))


## Intersection of the basins of attraction

In [None]:
# Mean of the vector field on the diagonal x₁ = x₂ = x
coords_mean(x) = mean(direct_f([x,x], p))

f(x) = coords_mean(x)
# Potential recovered by integrating the mean dynamics from 0, anchored at V_0(0)
F(x) = V_0(0) - first(quadgk(f, 0, x))

# Sample points and the two curves to compare
x_values = range(-4, stop=4, length=100)
y_values = F.(x_values)
y_real = V_0.(x_values)

# Integrated potential together with the reference V_0
let fig = plot(x_values, y_values, title="Integral of function",
               xlabel="x", ylabel="Integral")
    plot!(fig, x_values, y_real, label="V_0")
end


### $x_2 = x_1 + p$

In [None]:
offset = 1
range_x = 4
x = -range_x:0.2:range_x

# Potential predicted from V₀ on the line x₂ = x₁ + p (p = offset): the average of
# V₀ over both coordinates
shift_F(x) = (V_0(x+offset) + V_0(x)) / 2

# Mean of the vector field on the line, integrated from 0 and anchored at shift_F(0)
f(x) = mean(direct_f([x; x+offset], p, 0))
F(x) = -quadgk(f, 0, x)[1] + shift_F(0)

y1 = F.(x)
y2 = shift_F.(x)

# Both curves are potentials (not x): y1 is integrated from the dynamics and
# y2 is predicted from V_0
plot(x, y1, xlims=(-range_x, range_x), label="actual V(x)")
plot!(x, y2, xlims=(-range_x, range_x), label="calculated from V_0")
title!("x₂ = x₁ + p")

### $x_1 + x_2 = p$

In [None]:
offset = 3
range_x = 3
x = -range_x:0.2:range_x

# Potential predicted from V₀ on the line x₁ + x₂ = p (p = offset)
shift_F(x) = (V_0(x) - V_0(offset-x))/2
# Mean of the vector field on that line
f(x) = mean(direct_f([x; offset-x], p))
# Integrate the mean dynamics from 0, anchored at the predicted value shift_F(0)
F(x) = shift_F(0) - first(quadgk(f, 0, x))

y1, y2 = F.(x), shift_F.(x)

plot(x, y1, xlims=(-range_x, range_x), label="actual V(x)")
plot!(x, y2, xlims=(-range_x, range_x), label="calculated from V_0")
title!("x₁ + x₂ = p")


### $x_2 = k x_1 + b$

In [None]:
offset = 3   # slope k of the line
b = 2        # intercept b of the line
range_x = 3
x = -range_x:0.2:range_x

# Potential predicted from V₀ on the line x₂ = k x₁ + b (k = offset); the 1/k factor
# comes from d/dx [V₀(k x + b) / k] = V₀'(k x + b)
shift_F(x) = (V_0(x) + 1/offset * V_0(offset * x + b)) / 2

# Mean of the vector field on the line, integrated from 0 and anchored at shift_F(0)
f(x) = mean(direct_f([x; offset * x + b], p, 0))
F(x) = -quadgk(f, 0, x)[1] + shift_F(0)

y1 = F.(x)
y2 = shift_F.(x)

# y1 is the potential integrated from the dynamics, y2 the one predicted from V_0
plot(x, y1, xlims=(-range_x, range_x), label="actual V(x)")
plot!(x, y2, xlims=(-range_x, range_x), label="calculated from V_0")
title!("x₂ = kx₁ + b")

## Two particles dynamics