# Cart and Inverted Pendulum Simulation

- This is a Julia version of the system used in my Open AI Gym environment:
https://github.com/billtubbs/gym-CartPole-bt-v0/
- Both mimic the system used in Steve Brunton's Control Bootcamp video lecture series.

In [1]:
using Printf
using Test
using DifferentialEquations

In [2]:
# Prepare directories
plot_dir = "plots"
data_dir = "data"

for dir_path in [plot_dir, data_dir]
    if ~isdir(dir_path)
        mkdir(dir_path)
    end
end

In [3]:
function cartpend_dydt(t, y, m=1, M=5, L=2, g=-10, d=1, u=0)
    """Simulates the non-linear dynamics of a simple cart-pendulum system.
    These non-linear ordinary differential equations (ODEs) return the
    time-derivative at the current time given the current state of the
    system.
    Args:
        t (float): Time variable - not used here but included for
            compatibility with solvers like scipy.integrate.solve_ivp.
        y (array): State vector. This should be an array of
            shape (4, ) containing the current state of the system.
            y[0] is the x-position of the cart, y[1] is the velocity
            of the cart (dx/dt), y[2] is the angle of the pendulum
            (theta) from the vertical in radians, and y[3] is the
            rate of change of theta (dtheta/dt).
        m (float): Mass of pendulum.
        M (float): Mass of cart.
        L (float): Length of pendulum.
        g (float): Acceleration due to gravity.
        d (float): Damping coefficient for friction between cart and
            ground.
        u (float): Force on cart in x-direction.
    Returns:
        dy (array): The time derivate of the state (dy/dt) as a
            shape (4, ) array.
    """

    # Temporary variables
    Sy = sin(y[3])
    Cy = cos(y[3])
    mL = m*L
    D = 1/(L*(M + m*(1 - Cy^2)))
    b = mL*y[4]^2*Sy - d*y[2] + u
    dy = zeros(4)

    # Non-linear ordinary differential equations describing
    # simple cart-pendulum system dynamics
    dy[1] = y[2]
    dy[2] = D*(-mL*g*Cy*Sy + L*b)
    dy[3] = y[4]
    dy[4] = D*((m + M)*g*Sy - Cy*b)

    return dy
end;

function cartpend_ss(m=1, M=5, L=2, g=-10, d=1, s=1)
    """Calculates the linearized approximation of the cart-pendulum
    system dynamics at either the vertical-up position (s=1) or
    vertical-down position (s=-1).

    Returns two arrays, A, B which are the system and input matrices
    in the state-space system of differential equations:

        x_dot = Ax + Bu

    where x is the state vector, u is the control vector and x_dot
    is the time derivative (dx/dt).

    Args:
        m (float): Mass of pendulum.
        M (float): Mass of cart.
        L (float): Length of pendulum.
        g (float): Acceleration due to gravity.
        d (float): Damping coefficient for friction between cart and
            ground.
        s (int): 1 for pendulum up position or -1 for down.

    Returns:
        dy (np.array): The time derivate of the state (dy/dt) as a
            shape (4, ) array.
    """

    A = [     0.0        1.0              0.0      0.0;
                0       -d/M           -m*g/M      0.0;
              0.0        0.0              0.0      1.0;
              0.0 -s*d/(M*L) -s*(m+M)*g/(M*L)      0.0]

    B = [        0.0;
               1.0/M;
                 0.0;
         s*1.0/(M*L)]

    return A, B
end;

## Run tests

In [4]:
# Fixed parameter values
m = 1
M = 5
L = 2
g = -10
d = 1
u = 0

0

In [5]:
y_test_values = Dict(
    1 => [0, 0, 0, 0],  # Pendulum down position
    2 => [0, 0, pi, 0],  # Pendulum up position
    3 => [0, 0, 0, 0],
    4 => [0, 0, pi, 0],
    5 => [2.260914, 0.026066, 0.484470, -0.026480]
);

u_test_values = Dict(
    1 => 0.,
    2 => 0.,
    3 => 1.,
    4 => 1.,
    5 => -0.59601
);

# dy values below calculated with MATLAB script from
# Steven L. Brunton's Control Bootcamp videos
expected_results = Dict(
    1 => [0., 0., 0., 0.],
    2 => [0., -2.44929360e-16, 0., -7.34788079e-16],
    3 => [0., 0.2, 0., -0.1],
    4 => [0., 0.2, 0. ,0.1],
    5 => [0.026066, 0.670896, -0.026480, -2.625542]
);

t = 0.0
atol = 1e-6
for i in 1:5
    u = u_test_values[i]
    y = y_test_values[i]
    dy_calculated = cartpend_dydt(t, y, m, M, L, g, d, u)
    dy_expected = expected_results[i]
    @test maximum(abs.(dy_calculated - expected_results[i])) < atol
end

In [6]:
# K values below calculated with MATLAB script from
# Steven L. Brunton's Control Bootcamp videos
test_values = Dict(
    5 => 1,  # Pendulum up position
    6 => -1  # Pendulum down position
)

expected_results = Dict(
    5 => ([0.0   1.0   0.0   0.0;
           0.0  -0.2   2.0   0.0;
           0.0   0.0   0.0   1.0;
           0.0  -0.1   6.0   0.0],
         [ 0.0;  0.2;  0.0;  0.1]),
    6 => ([0.0   1.0   0.0   0.0;
           0.0  -0.2   2.0   0.0;
           0.0   0.0   0.0   1.0;
           0.0   0.1  -6.0   0.0],
         [ 0.0;  0.2;  0.0; -0.1])
);

atol = 1e-6
for i in 5:6
    s = test_values[i]
    A_calculated, B_calculated = cartpend_ss(m, M, L, g, d, s)
    A_expected, B_expected = expected_results[i]
    @test maximum(abs.(A_calculated - A_expected)) < atol
    @test maximum(abs.(B_calculated - B_expected)) < atol
end

In [7]:
struct Foo
   bar
   baz
end

foo = Foo(1, 2)

Foo(1, 2)

In [8]:
struct CartPoleBTEnv
    gravity::Float64
    masscart::Float64
    masspole::Float64
    length::Float64
    friction::Float64
    max_force::Float64
    goal_state::Array
    initial_state::String
    disturbances::String
    initial_state_variance::String
    measurement_error::String
    hidden_states::Bool
    variance_levels::Dict
    tau::Float64
    n_steps::Int
    time_step::Int
    kinematics_integrator::String
    observation_space::Array
    action_space::Array
    seed::Int
    state::Array
end

# Set defaults with keyword arguments
CartPoleBTEnv(;
    gravity=-10.0, 
    masscart=5.0, 
    masspole=1.0, 
    length=2.0, 
    friction=1.0, 
    max_force=200.0,
    goal_state=[0.0; 0.0; pi; 0.0],
    initial_state="goal",
    disturbances="none",
    initial_state_variance="none",
    measurement_error="none",  # Not implemented yet
    hidden_states=false,  # Not implemented yet
    variance_levels=Dict("none"=>0.0, "low"=>0.01, "high"=>0.2),
    tau=0.05,
    n_steps=100,
    time_step=0,
    kinematics_integrator="RK45",
    observation_space=[[-Inf64; -Inf64; -Inf64; -Inf64],
                       [Inf64; Inf64; Inf64; Inf64]],
    action_space=[[-Inf64; -Inf64]],
    seed=1,
    state=zeros(4)
) = CartPoleBTEnv(
    gravity, masscart, masspole, length, friction, max_force,
    goal_state, initial_state, disturbances, initial_state_variance,
    measurement_error, hidden_states, variance_levels, 
    tau, n_steps, time_step, kinematics_integrator,
    observation_space, action_space, seed, state
)

# Usage:
# CartPoleBTEnv()  # Defaults
# CartPoleBTEnv(;friction=2)  # Specify non-default values

gym = CartPoleBTEnv()
@test gym.friction == 1.0
@test gym.state == [0.0; 0.0; 0.0; 0.0]
@test gym.goal_state == [0.0; 0.0; 3.141592653589793; 0.0]

[32m[1mTest Passed[22m[39m

In [9]:
CartPoleBTEnv()

CartPoleBTEnv(-10.0, 5.0, 1.0, 2.0, 1.0, 200.0, [0.0, 0.0, 3.14159, 0.0], "goal", "none", "none", "none", false, Dict("high"=>0.2,"none"=>0.0,"low"=>0.01), 0.05, 100, 0, "RK45", Array{Float64,1}[[-Inf, -Inf, -Inf, -Inf], [Inf, Inf, Inf, Inf]], Array{Float64,1}[[-Inf, -Inf]], 1, [0.0, 0.0, 0.0, 0.0])

In [10]:
function angle_normalize(theta)
    return theta % (2*pi)
end

@test angle_normalize(0) == 0.0
@test angle_normalize(pi*2.1) == angle_normalize(pi*0.1)
@test angle_normalize(-pi*2.1) == angle_normalize(-pi*0.1)
@test angle_normalize(pi*1.9) == angle_normalize(pi*3.9)

function cost_function(state, goal_state)
        """Evaluates the cost based on the current state y and
        the goal state.
        """
        return ((state[1] - goal_state[1])^2 +
                (angle_normalize(state[3]) - goal_state[3])^2)
end

cost_function(gym::CartPoleBTEnv) = cost_function(gym.state, gym.goal_state)
cost_function(gym::CartPoleBTEnv, state) = cost_function(state, gym.goal_state)

gym = CartPoleBTEnv()
@test cost_function(zeros(4), zeros(4)) == 0.0
@test cost_function(zeros(4), [0.0, 0.0, pi, 0.0]) == 9.869604401089358
@test cost_function(gym) == 9.869604401089358
@test cost_function(gym, zeros(4)) == 9.869604401089358

[32m[1mTest Passed[22m[39m

In [11]:
f(u,p,t) = 0.98u
u0 = 1.0
tspan = (0.0,1.0)
prob = ODEProblem(f,u0,tspan)

[36mODEProblem[0m with uType [36mFloat64[0m and tType [36mFloat64[0m. In-place: [36mfalse[0m
timespan: (0.0, 1.0)
u0: 1.0

In [12]:
sol = solve(prob)

retcode: Success
Interpolation: Automatic order switching interpolation
t: 5-element Array{Float64,1}:
 0.0                
 0.10042494449239292
 0.35218603951893646
 0.6934436028208104 
 1.0                
u: 5-element Array{Float64,1}:
 1.0               
 1.1034222047865465
 1.4121908848175453
 1.9730384275622992
 2.6644561424814506

In [13]:
sol.t

5-element Array{Float64,1}:
 0.0                
 0.10042494449239292
 0.35218603951893646
 0.6934436028208104 
 1.0                

In [14]:
sol.u

5-element Array{Float64,1}:
 1.0               
 1.1034222047865465
 1.4121908848175453
 1.9730384275622992
 2.6644561424814506

In [15]:
gym.max_force

200.0

In [16]:
function step(gym::CartPoleBTEnv, u)
    u = clamp(u, -gym.max_force, gym.max_force)
    y = gym.state
    t = gym.time_step * gym.tau
    
    if gym.kinematics_integrator == "euler"
        println("euler")
        y_dot = cartpend_dydt(t, y,
                              gym.masspole,
                              gym.masscart,
                              gym.length,
                              gym.gravity,
                              gym.friction,
                              u)
        gym.state += gym.tau * y_dot
    elseif gym.kinematics_integrator == "RK45"
        println("RK45")
        f(y, p, t) = cartpend_dydt(t, y,
                                   gym.masspole,
                                   gym.masscart,
                                   gym.length,
                                   gym.gravity,
                                   gym.friction,
                                   u)
        y0 = gym.state
        tspan = (t, t + gym.tau)
        prob = ODEProblem(f, y0, tspan)
        sol = solve(prob)
        gym.state = sol.u[end]
    end
end

step (generic function with 1 method)

In [31]:
t = gym.time_step * gym.tau
f(y, p, t) = cartpend_dydt(t, y,
                                   gym.masspole,
                                   gym.masscart,
                                   gym.length,
                                   gym.gravity,
                                   gym.friction,
                                   u)
y0 = gym.state
tspan = (t, t + gym.tau)
prob = ODEProblem(f, y0, tspan)
sol = solve(prob)

retcode: Success
Interpolation: Automatic order switching interpolation
t: 5-element Array{Float64,1}:
 0.0                  
 9.999999999999999e-5 
 0.0010999999999999998
 0.011099999999999997 
 0.05                 
u: 5-element Array{Array{Float64,1},1}:
 [0.0, 0.0, 0.0, 0.0]                                
 [-5.96006e-10, -1.19201e-5, 2.98003e-10, 5.96004e-6]
 [-7.21119e-8, -0.000131108, 3.60559e-8, 6.55538e-5] 
 [-7.33793e-6, -0.00132165, 3.66878e-6, 0.000660756] 
 [-0.000148476, -0.00592793, 7.41607e-5, 0.00295778] 

In [32]:
sol.u[end]

4-element Array{Float64,1}:
 -0.00014847616143074153
 -0.005927929443185163  
  7.416066919494952e-5  
  0.0029577764184044054 

In [33]:
gym.state

4-element Array{Float64,1}:
 0.0
 0.0
 0.0
 0.0

In [34]:
gym.state = sol.u[end]

ErrorException: setfield! immutable struct of type CartPoleBTEnv cannot be changed

In [19]:
step(gym, [0.5])

MethodError: MethodError: no method matching isless(::Float64, ::Array{Float64,1})
Closest candidates are:
  isless(::Float64, !Matched::Float64) at float.jl:459
  isless(!Matched::Missing, ::Any) at missing.jl:70
  isless(::AbstractFloat, !Matched::AbstractFloat) at operators.jl:148
  ...

In [21]:
f(y, p, t) = cartpend_dydt(t, y,
                           gym.masspole,
                           gym.masscart,
                           gym.length,
                           gym.gravity,
                           gym.friction,
                           u)

f (generic function with 1 method)

In [22]:
y0 = gym.state
tspan = (0.0,1.0)
prob = ODEProblem(f,y0,tspan)

[36mODEProblem[0m with uType [36mArray{Float64,1}[0m and tType [36mFloat64[0m. In-place: [36mfalse[0m
timespan: (0.0, 1.0)
u0: [0.0, 0.0, 0.0, 0.0]

In [23]:
sol = solve(prob)

retcode: Success
Interpolation: Automatic order switching interpolation
t: 11-element Array{Float64,1}:
 0.0                  
 9.999999999999999e-5 
 0.0010999999999999998
 0.011099999999999997 
 0.0602289268362002   
 0.151453304943196    
 0.27267236386087373  
 0.42530040904162536  
 0.610705730342367    
 0.8303569319992211   
 1.0                  
u: 11-element Array{Array{Float64,1},1}:
 [0.0, 0.0, 0.0, 0.0]                                
 [-5.96006e-10, -1.19201e-5, 2.98003e-10, 5.96004e-6]
 [-7.21119e-8, -0.000131108, 3.60559e-8, 6.55538e-5] 
 [-7.33793e-6, -0.00132165, 3.66878e-6, 0.000660756] 
 [-0.000215274, -0.00713203, 0.000107474, 0.00355521]
 [-0.00135086, -0.0177153, 0.000668968, 0.00868759]  
 [-0.00432541, -0.0312495, 0.00209581, 0.0146534]    
 [-0.0103299, -0.0472137, 0.00477972, 0.0200646]     
 [-0.0207403, -0.0647275, 0.00880564, 0.0225602]     
 [-0.0369875, -0.0827303, 0.0135174, 0.0192359]      
 [-0.0520694, -0.0948523, 0.0162593, 0.0125882]      

In [24]:
sol.t

11-element Array{Float64,1}:
 0.0                  
 9.999999999999999e-5 
 0.0010999999999999998
 0.011099999999999997 
 0.0602289268362002   
 0.151453304943196    
 0.27267236386087373  
 0.42530040904162536  
 0.610705730342367    
 0.8303569319992211   
 1.0                  

In [25]:
size(sol.u[end])

(4,)

In [26]:
sol.u

11-element Array{Array{Float64,1},1}:
 [0.0, 0.0, 0.0, 0.0]                                
 [-5.96006e-10, -1.19201e-5, 2.98003e-10, 5.96004e-6]
 [-7.21119e-8, -0.000131108, 3.60559e-8, 6.55538e-5] 
 [-7.33793e-6, -0.00132165, 3.66878e-6, 0.000660756] 
 [-0.000215274, -0.00713203, 0.000107474, 0.00355521]
 [-0.00135086, -0.0177153, 0.000668968, 0.00868759]  
 [-0.00432541, -0.0312495, 0.00209581, 0.0146534]    
 [-0.0103299, -0.0472137, 0.00477972, 0.0200646]     
 [-0.0207403, -0.0647275, 0.00880564, 0.0225602]     
 [-0.0369875, -0.0827303, 0.0135174, 0.0192359]      
 [-0.0520694, -0.0948523, 0.0162593, 0.0125882]      