# Q-Learning 

In [34]:
using CSV
using DataFrames 
using Printf
using LinearAlgebra 
using Distributions 
using Tables

# Load the logged transitions (columns s, a, r, sp) and keep only the leading rows
# for the Q-learning run.
# NOTE(review): the row count is hard-coded — confirm it is the intended subset.
df = CSV.read("pitchData.csv", DataFrame)
y2df = first(df, 1_433_578)

Row,s,a,r,sp
Unnamed: 0_level_1,Int64,Int64,Float64,Int64
1,1,5,2.5,9
2,9,5,2.5,17
3,17,5,2.5,17
4,17,5,-2.0,41
5,41,2,-2.0,65
6,65,5,-10.0,385
7,1,5,-2.0,25
8,25,4,-10.0,385
9,195,5,-2.0,219
10,219,5,-2.0,243


In [36]:


# Algorithm 17.2: Q-learning Algorithm
"""
    QLearning(𝒮, 𝒜, γ, Q, α)

Tabular Q-learning model. `Q[s, a]` is the current action-value estimate and is
modified in place by `update!`.
"""
mutable struct QLearning
    𝒮 # state space (assumes 1:nstates)
    𝒜 # action space (assumes 1:nactions)
    γ # discount
    Q # action value function
    α # learning rate
end

"""
    lookahead(model::QLearning, s, a)

Return the stored action value `model.Q[s, a]`.
"""
lookahead(model::QLearning, s, a) = model.Q[s, a]

"""
    update!(model::QLearning, s, a, r, s′)

Apply one Q-learning update for the transition `(s, a, r, s′)`:
`Q[s, a] += α * (r + γ * max_a′ Q[s′, a′] - Q[s, a])`.

Mutates `model.Q` in place and returns `model`.
"""
function update!(model::QLearning, s, a, r, s′)
    γ, Q, α = model.γ, model.Q, model.α
    # A view avoids copying the successor row on every call.
    best_next = maximum(view(Q, s′, :))
    Q[s, a] += α * (r + γ * best_next - Q[s, a])
    return model
end

# Q-learning parameters
S = 385      # number of states
A = 16       # number of actions
γ = 0.95     # discount
α = 0.001    # learning rate
Q = zeros(Float64, (S, A))   # action values, all starting at zero

# The model keeps a reference to `Q`, so updates through `model` are visible in `Q`.
model = QLearning(S, A, γ, Q, α)

# Train: replay the selected transitions through the Q-learning update, five passes in
# file order. Columns are read by position: 1 = s, 2 = a, 3 = r, 4 = sp.
for epoch in 1:5
    for row in eachrow(y2df)
        update!(model, row[1], row[2], row[3], row[4])
    end
end

# Greedy policy extraction: for every state take the index of the largest entry in its
# row of Q (the first such index on ties), stored as an S×1 integer column.
Policy = reshape(Int64[argmax(Q[state, :]) for state in 1:S], S, 1)

# Export the policy column to CSV with no header row.
CSV.write("QLearn_pitch2.csv", Tables.table(Policy); writeheader=false)

"QLearn_pitch2.csv"

## Sarsa 

In [40]:

using CSV
using DataFrames 
using DataStructures
using Printf
using LinearAlgebra 
using Distributions 
using Tables

# Reload the transition log and keep only its leading rows for the Sarsa run.
# NOTE(review): the row count is hard-coded — confirm it is the intended subset.
df = CSV.read("pitchData.csv", DataFrame)
y1df = first(df, 716_788)

"""
    Sarsa(𝒮, 𝒜, γ, Q, α, ℓ)

Tabular Sarsa model. `Q[s, a]` is the current action-value estimate; `ℓ` holds the
most recent `(s, a, r)` tuple, or `nothing` before the first call to `update!`.
"""
mutable struct Sarsa
    𝒮 # state space (assumes 1:nstates)
    𝒜 # action space (assumes 1:nactions)
    γ # discount
    Q # action value function
    α # learning rate
    ℓ # most recent experience tuple (s,a,r)
end

"""
    lookahead(model::Sarsa, s, a)

Return the stored action value `model.Q[s, a]`.
"""
lookahead(model::Sarsa, s, a) = model.Q[s, a]

# Sarsa learning parameters
S = 385      # number of states
A = 16       # number of actions

γ = 0.95     # discount
α = 0.001    # learning rate
SarsaQ = zeros(Float64, (S, A));

# The model keeps a reference to `SarsaQ`; no experience tuple is stored yet.
SarsaModel = Sarsa(S, A, γ, SarsaQ, α, nothing)

"""
    update!(model::Sarsa, s, a, r, s′)

Apply one Sarsa step. If a previous tuple `ℓ = (s, a, r)` is stored, its value is
moved toward `ℓ.r + γ * Q[s, a]`, i.e. the incoming `(s, a)` is used as the successor
state-action pair of the stored tuple. The incoming `(s, a, r)` is then stored as the
new `ℓ`. The `s′` argument is accepted but not used.

Mutates `model.Q` and `model.ℓ`; returns `model`.
"""
function update!(model::Sarsa, s, a, r, s′)
    if !isnothing(model.ℓ)
        γ, Q, α, ℓ = model.γ, model.Q, model.α, model.ℓ
        Q[ℓ.s, ℓ.a] += α * (ℓ.r + γ * Q[s, a] - Q[ℓ.s, ℓ.a])
    end
    model.ℓ = (s=s, a=a, r=r)
    return model
end

# Train: five passes over the selected rows, each row handed to the Sarsa update in
# file order. Columns are read by position: 1 = s, 2 = a, 3 = r, 4 = sp.
# NOTE(review): the Sarsa update uses each incoming row as the successor of the row
# before it, but the data preview shows a row ending in sp = 385 directly followed by a
# row starting at s = 1, so adjacent rows are not always one trajectory. The stored
# tuple is also carried over from one pass to the next. Confirm how breaks should be
# handled.
for epoch in 1:5
    for row in eachrow(y1df)
        update!(SarsaModel, row[1], row[2], row[3], row[4])
    end
end

# Greedy policy extraction: for every state take the index of the largest entry in its
# row of SarsaQ (the first such index on ties), stored as an S×1 integer column.
SarsaPolicy = reshape(Int64[argmax(SarsaQ[state, :]) for state in 1:S], S, 1)

# Export the policy column to CSV with no header row.
CSV.write("Sarsa_pitch1.csv", Tables.table(SarsaPolicy); writeheader=false)



"Sarsa_pitch1.csv"

## SarsaLambda

In [41]:




## SarsaLambda
"""
    SarsaLambda(𝒮, 𝒜, γ, Q, N, α, λ, ℓ)

Tabular Sarsa(λ) model. `Q` holds the action values and `N` the eligibility trace
(same shape as `Q`); `ℓ` holds the most recent `(s, a, r)` tuple, or `nothing`
before the first call to `update!`.
"""
mutable struct SarsaLambda
    𝒮 # state space (assumes 1:nstates)
    𝒜 # action space (assumes 1:nactions)
    γ # discount
    Q # action value function
    N # trace
    α # learning rate
    λ # trace decay rate
    ℓ # most recent experience tuple (s,a,r)
end

# Sarsa learning parameters
S = 385      # number of states
A = 16       # number of actions

γ = 0.95     # discount
α = 0.001    # learning rate
SarsaLQ = zeros(Float64, (S, A));

"""
    lookahead(model::SarsaLambda, s, a)

Return the stored action value `model.Q[s, a]`.
"""
lookahead(model::SarsaLambda, s, a) = model.Q[s, a]

"""
    update!(model::SarsaLambda, s, a, r, s′)

Apply one Sarsa(λ) step. When no previous tuple is stored the trace `N` is zeroed.
Otherwise the trace entry of the stored tuple `ℓ` is incremented, the temporal
difference `δ = ℓ.r + γ * Q[s, a] - Q[ℓ.s, ℓ.a]` is computed, every entry of `Q` is
moved by `α * δ * N`, and every entry of `N` is decayed by `γ * λ`. The incoming
`(s, a, r)` is then stored as the new `ℓ`. The `s′` argument is accepted but not used.

The sweep covers all of `Q` and `N` rather than iterating `model.𝒮` / `model.𝒜`:
iterating an integer size yields only that one number, which would restrict the
update to the single entry `[𝒮, 𝒜]`.

Mutates `model.Q`, `model.N` and `model.ℓ`; returns `model`.
"""
function update!(model::SarsaLambda, s, a, r, s′)
    if isnothing(model.ℓ)
        fill!(model.N, 0)
    else
        γ, λ, Q, N, α, ℓ = model.γ, model.λ, model.Q, model.N, model.α, model.ℓ
        N[ℓ.s, ℓ.a] += 1
        δ = ℓ.r + γ * Q[s, a] - Q[ℓ.s, ℓ.a]
        # In-place update of every (state, action) entry, then trace decay.
        Q .+= (α * δ) .* N
        N .*= γ * λ
    end
    model.ℓ = (s=s, a=a, r=r)
    return model
end

N = zeros(S, A)   # eligibility trace, same shape as SarsaLQ

# Trace decay rate λ = 0.5; no experience tuple is stored yet.
SarsaLamdaModel = SarsaLambda(S, A, γ, SarsaLQ, N, α, 0.5, nothing)

# Train: five passes over every row of df, each row handed to the Sarsa(λ) update in
# file order. Columns are read by position: 1 = s, 2 = a, 3 = r, 4 = sp.
# NOTE(review): the update uses each incoming row as the successor of the row before
# it, but the data preview shows a row ending in sp = 385 directly followed by a row
# starting at s = 1, so adjacent rows are not always one trajectory. The trace is only
# zeroed on the very first call. Confirm how breaks should be handled.
for epoch in 1:5
    for row in eachrow(df)
        update!(SarsaLamdaModel, row[1], row[2], row[3], row[4])
    end
end

# Greedy policy extraction: for every state take the index of the largest entry in its
# row of SarsaLQ (the first such index on ties), stored as an S×1 integer column.
SarsaLamdaPolicy = reshape(Int64[argmax(SarsaLQ[state, :]) for state in 1:S], S, 1)

# Export the policy column with no header row.
CSV.write("SarsaLamda_pitch.policy", Tables.table(SarsaLamdaPolicy); writeheader=false)


"SarsaLamda_pitch.policy"