In [None]:
using DataFrames, CSV, Plots, Statistics, DataStructures, JLD, Plots.PlotMeasures, LaTeXStrings, DynamicalSystems, Distributions, Random, StatsBase, KernelDensity, Interpolations, StatsPlots

# Absolute, machine-specific root of the project checkout (note the trailing slash, which the
# string concatenation below relies on).
path_to_files = "/home/matt/Desktop/Advanced_Analytics/Dissertation/Code/MDTG-MALABM/"
# Evaluate Scripts/Moments.jl from the project root in this session.
include(path_to_files * "Scripts/Moments.jl"); 

In [None]:
# Load the "rl_results" entry of three saved MARL (Type 2) training result files.
# NOTE(review): the variable names do not match the file names — `b0_s1` and `b1_s0` both read
# Buy_0_Sell_1 files (delta_3_6 and delta_5_10 respectively), and `b1_s1` reads a Buy_1_Sell_0
# file. Confirm which files were intended before interpreting the plots below.
b0_s1 = load("../Data/RL/Training/MARL/Results_Type2_Buy_0_Sell_1_alpha0.1_lambda0.003_gamma0.25_delta_3_6.jld")["rl_results"]
b1_s0 = load("../Data/RL/Training/MARL/Results_Type2_Buy_0_Sell_1_alpha0.1_lambda0.003_gamma0.25_delta_5_10.jld")["rl_results"]
b1_s1 = load("../Data/RL/Training/MARL/Results_Type2_Buy_1_Sell_0_alpha0.1_lambda0.003_gamma0.25_delta_3_6.jld")["rl_results"]


In [None]:
# Collect the three loaded result sets in one dictionary keyed by their variable names.
marl_data = Dict("b1_s0" => b1_s0, "b0_s1" => b0_s1, "b1_s1" => b1_s1)

In [None]:
# Inspect the "ActionType" entry stored for "rlAgent_1" under key 1 of the b1_s1 results.
marl_data["b1_s1"][1]["rlAgent_1"]["ActionType"]

In [None]:
"""
    GenerateActionsRL2(deltas, MA2)

Build the numbered action table for the Type 2 RL agent.

For every entry of `deltas` (in order), `MA2` evenly spaced values between 0 and 2 are paired
with that delta, and each `(value, delta)` pair is stored under the next consecutive integer
key, starting at 1. The result therefore holds `length(deltas) * MA2` actions.

Prints a banner line to standard output before building the table.

# Returns
An `OrderedDict{Int64, Tuple{Float64, Float64}}` mapping action number to `(value, delta)`.
"""
function GenerateActionsRL2(deltas, MA2)
    println("-------------------------------- Generating Type 2 RL Actions --------------------------------")
    table = OrderedDict{Int64, Tuple{Float64, Float64}}()
    for depth in deltas
        # the grid is rebuilt per delta so it is only evaluated when there is a delta to pair with
        for level in range(0, 2, MA2)
            # keys are consecutive, so the next key is always one past the current size
            table[length(table) + 1] = (level, 1.0 * depth)
        end
    end
    return table
end

In [None]:
"""
    multTuple(t::Tuple)

Return the product of the first two elements of `t`. Any further elements are ignored.
"""
function multTuple(t::Tuple)
    a, b = t
    return a * b
end

In [None]:
# Build the line-plot panel of the "Actions" values recorded for "rlAgent_1" in one episode.
function _episode_volume_panel(l::Dict, episode)
    agent = l[episode]["rlAgent_1"]
    # typed comprehension converts every recorded action to Float64, like the original push! loop
    volumes = Float64[a for a in agent["Actions"]]
    return plot(
        1:agent["NumberActions"], volumes,
        size = (800, 400), seriestype = :line, fillcolor = :blue, linecolor = :blue,
        legend = false, xlabel = "Action Number", ylabel = "Action volume",
        title = "Episode " * string(episode),
        titlefontsize = 9, guidefontsize = 8, tickfontsize = 8,
        left_margin = 5mm, bottom_margin = 5mm,
    )
end

"""
    PlotVolumeTragectory(l::Dict; n = 200)

Plot the actions of `"rlAgent_1"` in episode 1 (top panel) and in episode `n` (bottom panel)
so the two can be compared for differences in the actions selected.

# Arguments
- `l::Dict`: results keyed by episode number; each entry must provide
  `["rlAgent_1"]["Actions"]` and `["rlAgent_1"]["NumberActions"]`.

# Keywords
- `n = 200`: episode shown in the bottom panel. The default matches the value that was
  previously hard-coded, so existing calls behave as before.

# Returns
The combined two-panel plot.

# Throws
- `KeyError` if `l` has no entry for episode 1 or episode `n`.
"""
function PlotVolumeTragectory(l::Dict; n = 200)
    first_panel = _episode_volume_panel(l, 1)
    last_panel = _episode_volume_panel(l, n)
    # separate name for the layout so the `l` argument is not rebound to a different type
    stacked = @layout([a; b])
    return plot(first_panel, last_panel, layout = stacked, fontfamily = "Computer Modern")
end

In [None]:
# Plot the first-vs-final episode action volumes for the b1_s1 results.
l = b1_s1
# NOTE(review): this global `n` is not passed to the call below.
n = 200
PlotVolumeTragectory(l)

In [None]:
#----- Average Greedy Action per State Group -----#

# Mean of `multTuple(action)` over the greedy (argmax-Q) actions of the states selected by
# `pred`. Returns NaN when no state qualifies, which leaves a gap in the plotted line.
function _mean_greedy_action(Q, states, actionsMap, pred)
    picked = Float64[]
    for s in states
        # every grouping additionally requires a strictly positive first state component
        (pred(s) && s[1] > 0) || continue
        q = Q[s]
        # An all-zero Q row has never been updated; argmax would report action 1 for it, so
        # skip it. The check is independent of the number of actions.
        all(iszero, q) && continue
        push!(picked, multTuple(actionsMap[argmax(q)]))
    end
    return isempty(picked) ? NaN : mean(picked)
end

# One panel comparing two per-episode series: the first in blue, the second in red.
function _comparison_panel(blue_series, blue_label, red_series, red_label)
    panel = plot(
        blue_series, label = blue_label, color = :blue, fg_legend = :transparent,
        xlabel = "Episodes", ylabel = "Average Action", title = "", titlefontsize = 11,
        fontfamily = "Computer Modern",
    )
    plot!(panel, red_series, label = red_label, color = :red, fg_legend = :transparent)
    return panel
end

"""
    PlotAverageActionsOverTime1(l::Dict, actionsMap::OrderedDict)

Plot, per episode, the average greedy action of `"rlAgent_1"` within several groups of states.

For each episode `i in 1:length(l)` the Q-table `l[i]["rlAgent_1"]["Q"]` is read. Within each
state group, the highest-valued action of every visited state is looked up in `actionsMap`,
reduced to a single number with `multTuple`, and averaged. Six panels compare opposing groups:
volume, spread, remaining time, remaining inventory, volume/spread and time/inventory.

The groups index the state as `x[1]` (time), `x[2]` (inventory), `x[3]` (spread) and `x[4]`
(volume), following the group names used here; all groups also require `x[1] > 0`.

States whose Q row is entirely zero are excluded from every group. Previously that filter
compared against a fixed-length zero vector (`zeros(9)`) for all but the volume groups, so it
had no effect there whenever the Q rows had a different length.

# Arguments
- `l::Dict`: results keyed by consecutive episode numbers starting at 1.
- `actionsMap::OrderedDict`: maps an action index (position in a Q row) to a tuple.

# Returns
The combined 3x2 plot.
"""
function PlotAverageActionsOverTime1(l::Dict, actionsMap::OrderedDict)

    # state-group predicates, in the order the series are collected
    groups = (
        highvol = x -> x[4] >= 4,
        lowvol = x -> x[4] <= 2,
        highspr = x -> x[3] >= 3,
        lowspr = x -> x[3] <= 2,
        hightime = x -> x[1] >= 4,
        lowtime = x -> x[1] <= 2,
        highinv = x -> x[2] >= 4,
        lowinv = x -> x[2] <= 2,
        highvol_lowspr = x -> x[4] >= 4 && x[3] <= 2,
        lowvol_highspr = x -> x[4] <= 2 && x[3] >= 3,
        hightime_highinv = x -> x[1] >= 4 && x[2] >= 4,
        lowtime_lowinv = x -> x[1] <= 2 && x[2] <= 2,
    )

    # one per-episode series for each group (a NamedTuple with the same field names)
    series = map(_ -> Float64[], groups)

    for episode in 1:length(l)
        Q = l[episode]["rlAgent_1"]["Q"]
        states = collect(keys(Q))
        for name in keys(groups)
            push!(series[name], _mean_greedy_action(Q, states, actionsMap, groups[name]))
        end
    end

    # plot the volume
    v = _comparison_panel(series.highvol, "High Volume", series.lowvol, "Low Volume")

    # plot the spread
    s = _comparison_panel(series.lowspr, "Narrow Spread", series.highspr, "Wide Spread")

    # plot the time
    t = _comparison_panel(series.hightime, "More remaining time", series.lowtime, "Less remaining time")

    # plot the inventory
    inv = _comparison_panel(series.highinv, "More remaining inventory", series.lowinv, "Less remaining inventory")

    # plot the spread vs volume interaction
    sv = _comparison_panel(series.highvol_lowspr, "High Volume, Narrow Spread", series.lowvol_highspr, "Low Volume, Wide Spread")

    # plot the time and inventory interactions
    ti = _comparison_panel(series.hightime_highinv, "More remaining time, more remaining inventory", series.lowtime_lowinv, "Less remaining time, less remaining inventory")

    # separate name for the layout so the `l` argument is not rebound to a different type
    grid = @layout([a b;c d;e f])
    P = plot(v, s, t, inv, sv, ti, layout=grid, size = (1000, 700))
#     savefig(P, "./MarlLearningDynamics.pdf")
    return P
end


In [None]:
# Plot the average greedy action per state group for the b0_s1 results.
l = b0_s1
deltas = [-1, 0, 3, 6]      # placement depth for limit orders, -1 is for market orders
MA2 = 5                     # number of evenly spaced values in [0, 2] generated per delta
# 4 deltas x 5 values gives an action table with 20 entries, keyed 1..20
actionsMap = GenerateActionsRL2(deltas, MA2)
PlotAverageActionsOverTime1(l, actionsMap)