## 1. A gentle example of using ReinforcementLearning.jl

In [1]:
import Pkg;
# uncomment the following if you have not installed them
# Pkg.add("ReinforcementLearning");
# Pkg.add("Flux");
# Pkg.add("StableRNGs");
# Pkg.add("Distributions");
using Flux: InvDecay;
using ReinforcementLearning;
using StableRNGs;
using Flux;
using Flux.Losses;
using Distributions;

In [2]:
env = RandomWalk1D()

# RandomWalk1D

## Traits

| Trait Type        |                Value |
|:----------------- | --------------------:|
| NumAgentStyle     |        SingleAgent() |
| DynamicStyle      |         Sequential() |
| InformationStyle  | PerfectInformation() |
| ChanceStyle       |      Deterministic() |
| RewardStyle       |     TerminalReward() |
| UtilityStyle      |         GeneralSum() |
| ActionStyle       |   MinimalActionSet() |
| StateStyle        | Observation{Int64}() |
| DefaultStateStyle | Observation{Int64}() |

## Is Environment Terminated?

No

## State Space

`Base.OneTo(7)`

## Action Space

`Base.OneTo(2)`

## Current State

```
4
```


### random policy

In [3]:
A = action_space(env)
while true
    env(rand(A))
    is_terminated(env) && break
end

In [4]:
run(
    RandomPolicy(),
    RandomWalk1D(),
    StopAfterEpisode(10),
    TotalRewardPerEpisode())

            ⠀⠀⠀⠀⠀⠀⠀⠀⠀[97;1mTotal reward per episode[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀ 
            [38;5;8m┌────────────────────────────────────────┐[0m 
          [38;5;8m1[0m [38;5;8m│[0m⠀⠀⠀⠀[38;5;2m⣷[0m⠀⠀⠀⠀⠀⠀⠀⠀[38;5;2m⣇[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;2m⣸[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀[38;5;2m⢠[0m[38;5;2m⢻[0m⠀⠀⠀⠀⠀⠀⠀[38;5;2m⢰[0m[38;5;2m⢹[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;2m⡏[0m[38;5;2m⡆[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀[38;5;2m⢸[0m⠀[38;5;2m⡇[0m⠀⠀⠀⠀⠀⠀[38;5;2m⢸[0m[38;5;2m⠈[0m[38;5;2m⡆[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;2m⢰[0m[38;5;2m⠁[0m[38;5;2m⡇[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀[38;5;2m⡎[0m⠀[38;5;2m⢇[0m⠀⠀⠀⠀⠀⠀[38;5;2m⡇[0m⠀[38;5;2m⡇[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;2m⢸[0m⠀[38;5;2m⢸[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀[38;5;2m⡇[0m⠀[38;5;2m⢸[0m⠀⠀⠀⠀⠀[38;5;2m⢀[0m[38;5;2m⠇[0m⠀[38;5;2m⢸[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38

TotalRewardPerEpisode([-1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0], 0.0, true)

### tabular policy

In [5]:
# create a tabular policy
S = state_space(env);
A = action_space(env);
NS, NA = length(S),A;
tabular_policy = TabularPolicy(;table=Dict(zip(1:NS, fill(2,NS))));

In [6]:
run(
   tabular_policy,
   RandomWalk1D(),
   StopAfterEpisode(10),
   TotalRewardPerEpisode()
)

           ⠀⠀⠀⠀⠀⠀⠀⠀⠀[97;1mTotal reward per episode[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀ 
           [38;5;8m┌────────────────────────────────────────┐[0m 
         [38;5;8m2[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
          [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
          [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
          [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
          [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
          [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
          [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
   Score  [38;5;8m[0m [38;5;8m│[0m[38;5;2m⠤[0m[38;5;2m⠤[0m[38;5;2m⠤[0m[38;5;2m⠤[0m[38;5;2m⠤[0m[38;5;2m

TotalRewardPerEpisode([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.0, true)

### `QBasedPolicy`

In [8]:
# `MonteCarloLearner + EpsilonGreedyExplorer`
policy = QBasedPolicy(
   learner = MonteCarloLearner(;
           approximator=TabularQApproximator(
               ;n_state = NS,
               n_action = NA,
               opt = InvDecay(1.0)
           )
       ),
   explorer = EpsilonGreedyExplorer(0.1)
)

typename(QBasedPolicy)
├─ learner => typename(MonteCarloLearner)
│  ├─ approximator => typename(TabularApproximator)
│  │  ├─ table => 2×7 Matrix{Float64}
│  │  └─ optimizer => typename(InvDecay)
│  │     ├─ gamma => 1.0
│  │     └─ state => typename(IdDict)
│  ├─ γ => 1.0
│  ├─ kind => typename(ReinforcementLearningZoo.FirstVisit)
│  └─ sampling => typename(ReinforcementLearningZoo.NoSampling)
└─ explorer => typename(EpsilonGreedyExplorer)
   ├─ ϵ_stable => 0.1
   ├─ ϵ_init => 1.0
   ├─ warmup_steps => 0
   ├─ decay_steps => 0
   ├─ step => 1
   ├─ rng => typename(Random._GLOBAL_RNG)
   └─ is_training => true


In [9]:
run(
   policy,
   RandomWalk1D(),
   StopAfterEpisode(10),
   TotalRewardPerEpisode()
)

            ⠀⠀⠀⠀⠀⠀⠀⠀⠀[97;1mTotal reward per episode[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀ 
            [38;5;8m┌────────────────────────────────────────┐[0m 
          [38;5;8m0[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
   Score   [38;5;8m[0m [38;5;8m│[0m[38;5;2m⠤[0m[38;5;2m⠤[0m[38;5;2m⠤[0m[38;5;2m⠤[0m[38;5;2m⠤[0

TotalRewardPerEpisode([-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], 0.0, true)

### wrap the policy + trajectory into the 'agent'

In [10]:
agent = Agent(policy=policy, trajectory=VectorSARTTrajectory())

typename(Agent)
├─ policy => typename(QBasedPolicy)
│  ├─ learner => typename(MonteCarloLearner)
│  │  ├─ approximator => typename(TabularApproximator)
│  │  │  ├─ table => 2×7 Matrix{Float64}
│  │  │  └─ optimizer => typename(InvDecay)
│  │  │     ├─ gamma => 1.0
│  │  │     └─ state => typename(IdDict)
│  │  ├─ γ => 1.0
│  │  ├─ kind => typename(ReinforcementLearningZoo.FirstVisit)
│  │  └─ sampling => typename(ReinforcementLearningZoo.NoSampling)
│  └─ explorer => typename(EpsilonGreedyExplorer)
│     ├─ ϵ_stable => 0.1
│     ├─ ϵ_init => 1.0
│     ├─ warmup_steps => 0
│     ├─ decay_steps => 0
│     ├─ step => 31
│     ├─ rng => typename(Random._GLOBAL_RNG)
│     └─ is_training => true
└─ trajectory => typename(Trajectory)
   └─ traces => typename(NamedTuple)
      ├─ state => 0-element Vector{Int64}
      ├─ action => 0-element Vector{Int64}
      ├─ reward => 0-element Vector{Float32}
      └─ terminal => 0-element Vector{Bool}


In [11]:
run(agent, env, StopAfterEpisode(10), TotalRewardPerEpisode())

            ⠀⠀⠀⠀⠀⠀⠀⠀⠀[97;1mTotal reward per episode[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀ 
            [38;5;8m┌────────────────────────────────────────┐[0m 
          [38;5;8m1[0m [38;5;8m│[0m⠀⠀⠀⠀[38;5;2m⡏[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;2m⠉[0m[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀[38;5;2m⢠[0m[38;5;2m⠃[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀[38;5;2m⢸[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
           [38;5;8m[0m [38;5;8

TotalRewardPerEpisode([-1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.0, true)

## 2. PPO algorithm for pendulum problem (built-in experiment)

In [12]:

function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:PPO},
    ::Val{:Pendulum},
    ::Nothing;
    save_dir = nothing,
    seed = 123,
)
    rng = StableRNG(seed)
    inner_env = PendulumEnv(T = Float32, rng = rng)
    A = action_space(inner_env)
    low = A.left
    high = A.right
    ns = length(state(inner_env))

    N_ENV = 8
    UPDATE_FREQ = 2048
    env = MultiThreadEnv([
        PendulumEnv(T = Float32, rng = StableRNG(hash(seed + i))) |>
        env -> ActionTransformedEnv(env, action_mapping = x -> clamp(x * 2, low, high)) for i in 1:N_ENV
    ])

    init = glorot_uniform(rng)

    agent = Agent(
        policy = PPOPolicy(
            approximator = ActorCritic(
                actor = GaussianNetwork(
                    pre = Chain(
                        Dense(ns, 64, relu; init = glorot_uniform(rng)),
                        Dense(64, 64, relu; init = glorot_uniform(rng)),
                    ),
                    μ = Chain(Dense(64, 1, tanh; init = glorot_uniform(rng)), vec),
                    logσ = Chain(Dense(64, 1; init = glorot_uniform(rng)), vec),
                ),
                critic = Chain(
                    Dense(ns, 64, relu; init = glorot_uniform(rng)),
                    Dense(64, 64, relu; init = glorot_uniform(rng)),
                    Dense(64, 1; init = glorot_uniform(rng)),
                ),
                optimizer = ADAM(3e-4),
            ) |> gpu,
            γ = 0.99f0,
            λ = 0.95f0,
            clip_range = 0.2f0,
            max_grad_norm = 0.5f0,
            n_epochs = 10,
            n_microbatches = 32,
            actor_loss_weight = 1.0f0,
            critic_loss_weight = 0.5f0,
            entropy_loss_weight = 0.00f0,
            dist = Normal,
            rng = rng,
            update_freq = UPDATE_FREQ,
        ),
        trajectory = PPOTrajectory(;
            capacity = UPDATE_FREQ,
            state = Matrix{Float32} => (ns, N_ENV),
            action = Vector{Float32} => (N_ENV,),
            action_log_prob = Vector{Float32} => (N_ENV,),
            reward = Vector{Float32} => (N_ENV,),
            terminal = Vector{Bool} => (N_ENV,),
        ),
    )

    stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI"))
    hook = TotalBatchRewardPerEpisode(N_ENV)
    Experiment(agent, env, stop_condition, hook, "# Play Pendulum with PPO")
end

[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.7/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.7/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.7/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.7/Manifest.toml`


In [13]:
    end
    update!(opt, params(m), gs)
  end
  @show accuracy(valX, valY)
endPkg.add("Plots")
using Plots
using Statistics
ex = E`JuliaRL_PPO_Pendulum`
run(ex)
# n = minimum(map(length, ex.hook.rewards))
# m = mean([@view(x[1:n]) for x in ex.hook.rewards])
# s = std([@view(x[1:n]) for x in ex.hook.rewards])
# plot(m,ribbon=s)

[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.7/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.7/Manifest.toml`
┌ Info: The GPU function is being called but the GPU is not accessible. 
│ Defaulting back to the CPU. (No action is required if you want to run on the CPU).
└ @ Flux /home/richard/.julia/packages/Flux/7nTyc/src/functor.jl:187


# Play Pendulum with PPO


[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:50[39m9:52[39m


               ⠀⠀⠀⠀⠀⠀⠀[97;1mAvg total reward per episode[0m⠀⠀⠀⠀⠀⠀⠀ 
               [38;5;8m┌────────────────────────────────────────┐[0m 
             [38;5;8m0[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;1m⢠[0m[38;5;1m⣾[0m[38;5;1m⣷[0m[38;5;1m⡿[0m[38;5;1m⣶[0m[38;5;1m⣿[0m[38;5;1m⣿[0m[38;5;1m⣴[0m[38;5;1m⣷[0m[38;5;1m⠃[0m⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
              [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;1m⢠[0m[38;5;1m⣴[0m[38;5;1m⣸[0m[38;5;1m⠟[0m[38;5;3m⣯[0m[38;5;7m⣧[0m[38;5;6m⣿[0m[38;5;3m⣿[0m[38;5;6m⣶[0m[38;5;7m⣾[0m[38;5;3m⣸[0m[38;5;2m⡆[0m⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
              [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;1m⣼[0m[38;5;1m⣿[0m[38;5;3m⣿[0m[38;5;2m⣿[0m[38;5;6m⣿[0m[38;5;6m⣧[0m[38;5;6m⣿[0m[38;5;6m⣿[0m[38;5;6m⣿[0m[38;5;6m⣿[0m[38;5;6m⢻[0m[38;5;4m⡆[0m⠀⠀⠀⠀⠀⠀[38;5;8m│[0m [38;5;8m[0m
              [38;5;8m[0m [38;5;8m│[0m⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀[38;5;1m⢀[0m[38;5;1

# Play Pendulum with PPO


typename(Experiment)
├─ policy => typename(Agent)
│  ├─ policy => typename(PPOPolicy)
│  │  ├─ approximator => typename(ActorCritic)
│  │  │  ├─ actor => typename(GaussianNetwork)
│  │  │  │  ├─ pre => typename(Chain)
│  │  │  │  │  └─ layers
│  │  │  │  │     ├─ 1
│  │  │  │  │     │  └─ typename(Dense)
│  │  │  │  │     │     ├─ weight => 64×3 Matrix{Float32}
│  │  │  │  │     │     ├─ bias => 64-element Vector{Float32}
│  │  │  │  │     │     └─ σ => typename(typeof(relu))
│  │  │  │  │     └─ 2
│  │  │  │  │        └─ typename(Dense)
│  │  │  │  │           ├─ weight => 64×64 Matrix{Float32}
│  │  │  │  │           ├─ bias => 64-element Vector{Float32}
│  │  │  │  │           └─ σ => typename(typeof(relu))
│  │  │  │  ├─ μ => typename(Chain)
│  │  │  │  │  └─ layers
│  │  │  │  │     ├─ 1
│  │  │  │  │     │  └─ typename(Dense)
│  │  │  │  │     │     ├─ weight => 1×64 Matrix{Float32}
│  │  │  │  │     │     ├─ bias => 1-element Vector{Float32}
│  │  │  │  │     │     └─ σ => typen

## 3. Break down the experiment running implementation

[reference: JuliaRinforcementLearning Blog](https://juliareinforcementlearning.org/blog/an_introduction_to_reinforcement_learning_jl_design_implementations_thoughts/#ol_start2_design)

In [18]:
seed = 123;
rng = StableRNG(seed);
env = CartPoleEnv(; T = Float32);
ns, na = length(state_space(env)), length(action_space(env));

In [11]:
function custom_run(
        policy, 
        env, 
        stop_criterion = StopAfterEpisode(5),
        hook = EmptyHook(),
    )
    step_counter = 0
    while true
        reset!(env)
        #policy(PRE_EPISODE_STAGE)
        
        while !is_terminated(env)
            #env |> policy |> env
            action = policy(env)
            step_counter = step_counter +1
            
            #policy(PRE_ACT_STAGE, env, action)
            env(action)
            
            println(step_counter, reward(env))
            
            #policy(POST_ACT_STAGE, env)
            stop_criterion(policy, env) && return
        end
        #policy(POST_EPISODE_STAGE)
    end
end

custom_run (generic function with 3 methods)

In [12]:
function dumm_policy(env)
    return 1
end

dumm_policy (generic function with 1 method)

In [14]:
custom_run(dumm_policy, env, StopAfterEpisode(15))

11.0
21.0
31.0
41.0
51.0
61.0
71.0
80.0
91.0
101.0
111.0
121.0
131.0
141.0
151.0
161.0
170.0
181.0
191.0
201.0
211.0
221.0
231.0
241.0
251.0
261.0
270.0
281.0
291.0
301.0
311.0
321.0
331.0
341.0
351.0
361.0
370.0
381.0
391.0
401.0
411.0
421.0
431.0
441.0
450.0
461.0
471.0
481.0
491.0
501.0
511.0
521.0
531.0
541.0
550.0
561.0
571.0
581.0
591.0
601.0
611.0
621.0
630.0
641.0
651.0
661.0
671.0
681.0
691.0
701.0
711.0
721.0
730.0
741.0
751.0
761.0
771.0
781.0
791.0
801.0
811.0
820.0
831.0
841.0
851.0
861.0
871.0
881.0
891.0
901.0
911.0
920.0
931.0
941.0
951.0
961.0
971.0
981.0
991.0
1001.0
1011.0
1020.0
1031.0
1041.0
1051.0
1061.0
1071.0
1081.0
1091.0
1101.0
1111.0
1120.0
1131.0
1141.0
1151.0
1161.0
1171.0
1181.0
1191.0
1201.0
1211.0
1220.0
1231.0
1241.0
1251.0
1261.0
1271.0
1281.0
1291.0
1301.0
1310.0
1321.0
1331.0
1341.0
1351.0
1361.0
1371.0
1381.0
1391.0
1401.0
1410.0


### 4. A Q-learning Agent

For simplicity, we omit the explorer here for env with small state and action spaces.

In [15]:
export QBasedPolicy, TabularRandomPolicy

Pkg.add("MacroTools")
Pkg.add("Setfield")
using MacroTools: @forward
using Setfield: @set

[32m[1m    Updating[22m[39m registry at `~/.julia/registries/General.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.7/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.7/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.7/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.7/Manifest.toml`


In [68]:
model = Chain(
    Dense(ns, 128, relu; init = glorot_uniform(rng)),
    Dense(128, 128, relu; init = glorot_uniform(rng)),
    Dense(128, na; init = glorot_uniform(rng)),
) |> gpu;
optimizer = ADAM();

In [195]:
using Flux.Losses: mse
println("initial output: ", model(state(env)))

# a template update step
function loss(states, action_label)
    #action = argmax([model(states), model(states)], dims=2);
    action = 
    return sum(mse(action, action_label))
end

y = [100, 100];
states = state(env)

g = Flux.gradient(()->loss(states, y), Flux.params(model))

Flux.train!(loss, Flux.params(model), [(states, y)], optimizer)

# for i = 1:5
#     Flux.update!(optimizer, Flux.params(model), g)
# end

println("updated output: ", model(state(env)))

initial output: Float32[-0.0019286007, -0.0044424515]
updated output: Float32[-0.0019286007, -0.0044424515]


In [176]:
g[]

LoadError: Only reference types can be differentiated with `Params`.