In [1]:
import equinox as eqx
import gym
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax
import sys
import timeit
import wandb

from argparse import Namespace
from functools import partial
from jax import grad, jit, vmap
from typing import Sequence

from jax_learning.agents.rl_agents import EpsilonGreedyAgent
from jax_learning.buffers.ram_buffers import NextStateNumPyBuffer
from jax_learning.buffers.utils import batch_flatten, to_jnp
from jax_learning.common import init_wandb
from jax_learning.constants import DISCRETE
from jax_learning.envs.wrappers.image import HWC2CHW
from jax_learning.learners.q_learning import QLearning
from jax_learning.models.q_functions import SoftmaxQ, NatureQ
from jax_learning.rl_utils import interact, evaluate

In [2]:
# W&B tracking is switched off for this test run.
# NOTE(review): `mode` is presumably forwarded to `wandb.init` — confirm in
# jax_learning.common.init_wandb before relying on other mode values.
wandb_kwargs = dict(project="test_jax_rl", group="breakout-dqn_test", mode="disabled")
init_wandb(**wandb_kwargs)

In [3]:
cfg_dict = {
    # Environment setup
    "env": "Breakout-v4",
    "seed": 0,
    "render": False,
    # Experiment progress
    "load_step": 0,
    "log_interval": 5000,
    "checkpoint_frequency": 0,
    "save_path": None,
    # Learning hyperparameters
    "max_timesteps": 1000000,
    # NOTE(review): Breakout observations here are raw (3, 210, 160) frames,
    # i.e. ~100 KB each at one byte per pixel. If the buffer also stores next
    # states (as NextStateNumPyBuffer's name suggests), 1e6 transitions is
    # roughly 200 GB, and ~4x that if frames are kept as float32. Confirm the
    # buffer's dtype/allocation strategy, or lower this for local runs.
    "buffer_size": 1000000,
    "buffer_warmup": 1000,
    "num_gradient_steps": 1,
    "batch_size": 64,
    "lr": 3e-4,
    "max_grad_norm": 10.0,
    "gamma": 0.99,
    "update_frequency": 4,
    "target_update_frequency": 1,
    "tau": 0.005,  # This is for polyak averaging of target network
    "omega": 1.0,  # This is for residual gradient: 1 for semi-gradient
    # Epsilon greedy hyperparameters
    "init_eps": 1.0,
    "min_eps": 0.02,
    "eps_decay": 0.9999,
    "eps_warmup": 1000,
    # Normalization
    "normalize_obs": False,
    "normalize_value": False,
    # Model architecture
    "hidden_dim": 64,
    "num_hidden": 1,
    # Evaluation
    "evaluation_frequency": 5000,
    "eval_cfg": {
        "num_episodes": 10,
        "seed": 1,
        "render": True,
    },
}
cfg = Namespace(**cfg_dict)
eval_cfg = Namespace(**cfg.eval_cfg)
# Record the hyperparameters on the active W&B run. Assigning
# `wandb.config = cfg_dict` would only rebind the module attribute to a plain
# dict, leaving the run's own config object untouched; `update` writes into it.
wandb.config.update(cfg_dict)

In [4]:
# Seed NumPy's legacy global RNG; the buffer, environment and JAX components
# additionally receive their own explicitly seeded generators/keys.
np.random.seed(seed=cfg.seed)

In [5]:
# Build the gym environment and move the channel axis first with HWC2CHW,
# giving observations of shape (3, 210, 160) for Breakout.
# NOTE(review): `scale=255` looks like a pixel-normalisation divisor — confirm
# in jax_learning.envs.wrappers.image.
env = HWC2CHW(gym.make(cfg.env), scale=255)

A.L.E: Arcade Learning Environment (version 0.8.0+919230b)
[Powered by Stella]
  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [6]:
# Derive network/buffer dimensions from the live environment's spaces.
vars(cfg).update(
    obs_dim=env.observation_space.shape,
    act_dim=(env.action_space.n,),
    action_space=DISCRETE,
)

In [7]:
# One-element shapes for the buffer's hidden-state and reward slots.
# NOTE(review): the hidden state is presumably unused by this feed-forward
# Q-network — confirm.
cfg.h_state_dim = cfg.rew_dim = (1,)

In [8]:
# Give the replay buffer and the environment independent NumPy streams.
# Seeding both `RandomState`s with the same integer would make their draws
# identical, correlating buffer sampling with environment-side randomness.
# SeedSequence.spawn derives statistically independent children from cfg.seed.
buffer_seed_seq, env_seed_seq = np.random.SeedSequence(cfg.seed).spawn(2)
cfg.buffer_rng = np.random.RandomState(np.random.MT19937(buffer_seed_seq))
cfg.env_rng = np.random.RandomState(np.random.MT19937(env_seed_seq))
# JAX keys: one for the agent (`key=`), one for model initialisation.
cfg.agent_key, cfg.model_key = jrandom.split(jrandom.PRNGKey(cfg.seed), num=2)
# The evaluation environment gets its own seed from eval_cfg.
eval_cfg.env_rng = np.random.RandomState(eval_cfg.seed)
cfg.evaluation_cfg = eval_cfg

In [9]:
# Display the fully resolved config: hyperparameters plus the derived
# dimensions, RNGs and JAX keys attached in the preceding cells.
cfg

Namespace(env='Breakout-v4', seed=0, render=False, load_step=0, log_interval=5000, checkpoint_frequency=0, save_path=None, max_timesteps=1000000, buffer_size=1000000, buffer_warmup=1000, num_gradient_steps=1, batch_size=64, lr=0.0003, max_grad_norm=10.0, gamma=0.99, update_frequency=4, target_update_frequency=1, tau=0.005, omega=1.0, init_eps=1.0, min_eps=0.02, eps_decay=0.9999, eps_warmup=1000, normalize_obs=False, normalize_value=False, hidden_dim=64, num_hidden=1, evaluation_frequency=5000, eval_cfg={'num_episodes': 10, 'seed': 1, 'render': True}, obs_dim=(3, 210, 160), act_dim=(4,), action_space='discrete', h_state_dim=(1,), rew_dim=(1,), buffer_rng=RandomState(MT19937) at 0x16A2F2C40, env_rng=RandomState(MT19937) at 0x16A2F2D40, agent_key=DeviceArray([4146024105,  967050713], dtype=uint32), model_key=DeviceArray([2718843009, 1272950319], dtype=uint32), evaluation_cfg=Namespace(num_episodes=10, seed=1, render=True, env_rng=RandomState(MT19937) at 0x110F48940))

In [10]:
# Dictionary key under which the Q-network and its optimizer are stored.
Q = "q"

# For discrete action spaces the buffer keeps a single value per transition
# (presumably the action index); otherwise it stores the full action vector.
buffer_act_dim = (1,) if cfg.action_space == DISCRETE else cfg.act_dim
buffer = NextStateNumPyBuffer(
    rng=cfg.buffer_rng,
    buffer_size=cfg.buffer_size,
    obs_dim=cfg.obs_dim,
    h_state_dim=cfg.h_state_dim,
    act_dim=buffer_act_dim,
    rew_dim=cfg.rew_dim,
)

def _make_q_model(key):
    """Build the softmax-wrapped NatureQ network used in this experiment.

    Args:
        key: JAX PRNG key passed to NatureQ for parameter initialisation.

    Returns:
        A ``SoftmaxQ`` wrapping a ``NatureQ`` sized from ``cfg`` (channels,
        height, width from ``cfg.obs_dim``; outputs from ``cfg.act_dim``).
    """
    return SoftmaxQ(
        q_function=NatureQ(
            in_channel=cfg.obs_dim[0],
            height=cfg.obs_dim[1],
            width=cfg.obs_dim[2],
            out_dim=cfg.act_dim,
            hidden_dim=cfg.hidden_dim,
            num_hidden=cfg.num_hidden,
            key=key,
        )
    )


# Online network (trained) and target network (polyak-averaged with `tau`).
# Both come from the same factory and the same key, so they share an
# architecture and start from the same initialisation (assuming NatureQ's init
# is a deterministic function of `key`).
model = {Q: _make_q_model(cfg.model_key)}
target_model = {Q: _make_q_model(cfg.model_key)}

# RMSProp-style optimizer: rescale gradients by their running RMS, then step
# by -lr. Global-norm clipping is prepended only when max_grad_norm is truthy.
clip_transforms = (
    [optax.clip_by_global_norm(cfg.max_grad_norm)] if cfg.max_grad_norm else []
)
opt_transforms = clip_transforms + [optax.scale_by_rms(), optax.scale(-cfg.lr)]
opt = {Q: optax.chain(*opt_transforms)}

# Q-learning learner that updates `model` against `target_model` using
# transitions from `buffer`.
learner = QLearning(
    cfg=cfg,
    model=model,
    target_model=target_model,
    buffer=buffer,
    opt=opt,
)

# Epsilon-greedy agent acting with the online Q-network stored under `Q`.
# Epsilon starts at init_eps and decays towards min_eps after eps_warmup
# (exact schedule lives in EpsilonGreedyAgent).
agent = EpsilonGreedyAgent(
    model=model,
    model_key=Q,
    learner=learner,
    buffer=buffer,
    key=cfg.agent_key,
    action_space=DISCRETE,
    action_dim=cfg.act_dim[0],
    init_eps=cfg.init_eps,
    min_eps=cfg.min_eps,
    eps_decay=cfg.eps_decay,
    eps_warmup=cfg.eps_warmup,
)

In [11]:
# Embed the W&B run view in the notebook (IPython magic presumably registered
# by the wandb package).
# NOTE(review): W&B was initialised with mode="disabled", so there is likely
# no run page to show here — confirm or drop this cell.
%wandb

In [12]:
# Main training loop: environment interaction plus learner updates.
# NOTE(review): in the recorded run, throughput fell to ~10 it/s once updates
# began (~0.38 s per "update Q"). That projects to more than a day for
# max_timesteps, and the run was interrupted by hand. Consider preprocessing
# frames (grayscale/resize) or a smaller budget.
interact(env, agent, cfg)

  0%|                         | 972/1000000 [00:02<43:18, 384.40it/s]

sample: 0.0065018750028684735
put to device: 0.0005341670184861869
update Q: 1.3831927920109592
sample: 0.005517332989256829
put to device: 0.00023899998632259667
update Q: 0.37944674998288974
sample: 0.00501412499579601
put to device: 0.0002509589830879122
update Q: 0.38247829201282
sample: 0.005646042001899332
put to device: 0.00023437500931322575


  0%|                       | 1013/1000000 [00:05<5:43:36, 48.46it/s]

update Q: 0.3772512919967994
sample: 0.0059889999974984676
put to device: 0.0002976669929921627
update Q: 0.3797273749951273
sample: 0.0050945420225616544
put to device: 0.0002346250112168491
update Q: 0.38622470799600706
sample: 0.005729041004087776
put to device: 0.00023004200193099678
update Q: 0.37737887498224154
sample: 0.005105666990857571
put to device: 0.0002213749976363033
update Q: 0.37763087498024106
sample: 0.004754624998895451
put to device: 0.00022183300461620092
update Q: 0.37970033401506953
sample: 0.005145290982909501
put to device: 0.00023616699036210775
update Q: 0.39032970799598843
sample: 0.005505083972821012
put to device: 0.0002307090035174042


  0%|                      | 1042/1000000 [00:08<10:19:28, 26.88it/s]

update Q: 0.3793186660041101
sample: 0.005992041988065466
put to device: 0.0002778329944703728
update Q: 0.3787002910103183
sample: 0.004799124988494441
put to device: 0.00022799998987466097
update Q: 0.37731362498016097
sample: 0.004571291996398941
put to device: 0.00022987500415183604
update Q: 0.3791487090056762
sample: 0.004459040996152908
put to device: 0.00023720800527371466
update Q: 0.3809495000168681
sample: 0.00433354199049063
put to device: 0.0002310840063728392


  0%|                      | 1063/1000000 [00:10<13:15:41, 20.92it/s]

update Q: 0.37941679201321676
sample: 0.005661957984557375
put to device: 0.0002871670003514737
update Q: 0.37765783301438205
sample: 0.0051626249914988875
put to device: 0.00023062500986270607
update Q: 0.380612334003672
sample: 0.004948416986735538
put to device: 0.00026058300863951445
update Q: 0.3799479169829283
sample: 0.005142749985679984
put to device: 0.00022887499653734267


  0%|                      | 1078/1000000 [00:11<15:50:34, 17.51it/s]

update Q: 0.3818934999871999
sample: 0.005613749992335215
put to device: 0.0003141670022159815
update Q: 0.38246512500336394
sample: 0.0042374169861432165
put to device: 0.00023704199702478945
update Q: 0.40919858298730105
sample: 0.00506083300570026
put to device: 0.00022487499518319964


  0%|                      | 1089/1000000 [00:13<18:06:22, 15.32it/s]

update Q: 0.4161198749789037
sample: 0.004970916983438656
put to device: 0.00024133300757966936
update Q: 0.38723270798800513
sample: 0.00443525001173839
put to device: 0.0002566669718362391


  0%|                      | 1097/1000000 [00:13<19:22:45, 14.32it/s]

update Q: 0.3799997090245597
sample: 0.0052621670183725655
put to device: 0.0002554579987190664


  0%|                      | 1103/1000000 [00:14<19:18:54, 14.37it/s]

update Q: 0.3800355830171611
sample: 0.005351125000743195
put to device: 0.00029420899227261543


  0%|                      | 1108/1000000 [00:14<19:41:08, 14.10it/s]

update Q: 0.37691262501175515
sample: 0.005022915982408449
put to device: 0.0002575830148998648


  0%|                      | 1112/1000000 [00:15<20:41:55, 13.41it/s]

update Q: 0.3802861669973936
sample: 0.005168958014110103
put to device: 0.0002822080277837813


  0%|                      | 1115/1000000 [00:15<22:43:47, 12.21it/s]

update Q: 0.39090479200240225
sample: 0.0048618330038152635
put to device: 0.00022683298448100686


  0%|                      | 1118/1000000 [00:15<24:50:41, 11.17it/s]

update Q: 0.3860368329915218
sample: 0.005022291996283457
put to device: 0.0002340409846510738


  0%|                      | 1121/1000000 [00:16<26:55:09, 10.31it/s]

update Q: 0.3826870839984622
sample: 0.005173250014195219
put to device: 0.0002644580090418458


  0%|                      | 1125/1000000 [00:16<27:23:02, 10.13it/s]

update Q: 0.3953966250119265
sample: 0.0051713330030906945
put to device: 0.0002842499816324562


  0%|                      | 1129/1000000 [00:17<27:42:58, 10.01it/s]

update Q: 0.3906417909893207
sample: 0.005082417017547414
put to device: 0.00023079197853803635


  0%|                      | 1133/1000000 [00:17<27:56:30,  9.93it/s]

update Q: 0.39003037501242943
sample: 0.004869583004619926
put to device: 0.0002910839975811541


  0%|                      | 1137/1000000 [00:17<27:59:27,  9.91it/s]

update Q: 0.384424541000044
sample: 0.0048650410026311874
put to device: 0.00024079199647530913


  0%|                      | 1141/1000000 [00:18<28:01:16,  9.90it/s]

update Q: 0.3847961249994114
sample: 0.005432374979136512
put to device: 0.00022399998852051795


  0%|                      | 1145/1000000 [00:18<27:58:18,  9.92it/s]

update Q: 0.38031558398506604
sample: 0.004584709007758647
put to device: 0.00024120800662785769


  0%|                      | 1149/1000000 [00:19<28:02:03,  9.90it/s]

update Q: 0.3840427910035942
sample: 0.00484120799228549
put to device: 0.00022441698820330203


  0%|                      | 1153/1000000 [00:19<28:01:00,  9.90it/s]

update Q: 0.38327129199751653
sample: 0.004667250002967194
put to device: 0.00023537498782388866


  0%|                      | 1157/1000000 [00:19<28:06:08,  9.87it/s]

update Q: 0.38737216699519195
sample: 0.004797000001417473
put to device: 0.00023070801398716867


  0%|                      | 1161/1000000 [00:20<28:10:13,  9.85it/s]

update Q: 0.3883647079928778
sample: 0.004998625023290515
put to device: 0.00023654100368730724


  0%|                      | 1165/1000000 [00:20<28:04:42,  9.88it/s]

update Q: 0.37921741598984227
sample: 0.00487504099146463
put to device: 0.0002399170189164579


  0%|                      | 1169/1000000 [00:21<28:17:41,  9.81it/s]

update Q: 0.3940174999879673
sample: 0.004908792005153373
put to device: 0.00024300001678057015


  0%|                      | 1173/1000000 [00:21<28:25:05,  9.76it/s]

update Q: 0.3928618749778252
sample: 0.005158500018296763
put to device: 0.0002612910175230354


  0%|                      | 1177/1000000 [00:22<28:54:33,  9.60it/s]

update Q: 0.4109682079870254
sample: 0.004994791990611702
put to device: 0.000254459009738639


  0%|                      | 1181/1000000 [00:22<28:41:59,  9.67it/s]

update Q: 0.38509162500849925
sample: 0.004479416995309293
put to device: 0.00023166698520071805


  0%|                      | 1185/1000000 [00:22<28:28:18,  9.74it/s]

update Q: 0.381808790989453
sample: 0.004874708014540374
put to device: 0.0002749169943854213


  0%|                      | 1189/1000000 [00:23<28:28:33,  9.74it/s]

update Q: 0.38921483399462886
sample: 0.004630832991097122
put to device: 0.0002452080079820007


  0%|                      | 1193/1000000 [00:23<28:27:12,  9.75it/s]

update Q: 0.3867682080017403
sample: 0.004758166003739461
put to device: 0.00022437499137595296


  0%|                      | 1197/1000000 [00:24<28:38:30,  9.69it/s]

update Q: 0.3973773329926189
sample: 0.004875874990830198
put to device: 0.0002677080046851188


  0%|                      | 1201/1000000 [00:24<28:53:55,  9.60it/s]

update Q: 0.402388125017751
sample: 0.004593874997226521
put to device: 0.00024083300377242267


  0%|                      | 1205/1000000 [00:24<28:39:31,  9.68it/s]

update Q: 0.38263729101163335
sample: 0.0045545830216724426
put to device: 0.0002585829934105277


  0%|                      | 1209/1000000 [00:25<28:27:01,  9.75it/s]

update Q: 0.38087662498583086
sample: 0.004850750003242865
put to device: 0.00024295799084939063


  0%|                      | 1213/1000000 [00:25<28:17:43,  9.81it/s]

update Q: 0.3800620840047486
sample: 0.0044578750093933195
put to device: 0.0002348329871892929


  0%|                      | 1217/1000000 [00:26<28:34:29,  9.71it/s]

update Q: 0.40140487501048483
sample: 0.004944250016706064
put to device: 0.0002988339983858168


  0%|                      | 1221/1000000 [00:26<28:50:22,  9.62it/s]

update Q: 0.4055310410039965
sample: 0.004680166981415823
put to device: 0.0002831250021699816


  0%|                      | 1225/1000000 [00:26<28:47:28,  9.64it/s]

update Q: 0.39320695798960514
sample: 0.004726749990368262
put to device: 0.00022833398543298244


  0%|                       | 1228/1000000 [00:27<6:11:06, 44.86it/s]


KeyboardInterrupt: 

In [None]:
# Close the W&B run; run this even after interrupting the training cell.
wandb.finish()