In [1]:
import argparse
import os
import sys
import random
import time
import re
from dataclasses import dataclass
from distutils.util import strtobool
from typing import Any, List, Optional, Union, Tuple, Iterable
import gym
import numpy as np
import torch
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from einops import rearrange
from gym.spaces import Discrete, Box
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from numpy.random import Generator
import gym.envs.registration
import pandas as pd
import utils
from utils_tabular import make_env
import wandb

t.set_default_dtype(t.float32)


MAIN = __name__ == "__main__"
os.environ["SDL_VIDEODRIVER"] = "dummy"


  from collections import Mapping, Iterable
  if not hasattr(tensorboard, "__version__") or LooseVersion(


In [3]:
class QNetwork(nn.Module):
    def __init__(
        self,
        dim_observation: int,
        num_actions: int,
        hidden_sizes: list[int] = [180, 120],
    ):
        super().__init__()
        self.fc1 = nn.Linear(dim_observation, hidden_sizes[0])
        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.fc3 = nn.Linear(hidden_sizes[1], num_actions)
        self.relu = nn.ReLU()

    def forward(self, x: t.Tensor) -> t.Tensor:
        out = self.relu(self.fc1(x))
        out = self.relu(self.fc2(out))

        return self.fc3(out)


if MAIN:
    net = QNetwork(dim_observation=4, num_actions=2)
    n_params = sum((p.nelement() for p in net.parameters()))
    print(net)
    print(f"Total number of parameters: {n_params}")
    print("You should manually verify network is Linear-ReLU-Linear-ReLU-Linear")
    assert n_params == 10934


QNetwork(
  (fc1): Linear(in_features=4, out_features=180, bias=True)
  (fc2): Linear(in_features=180, out_features=120, bias=True)
  (fc3): Linear(in_features=120, out_features=2, bias=True)
  (relu): ReLU()
)
Total number of parameters: 22862
You should manually verify network is Linear-ReLU-Linear-ReLU-Linear


AssertionError: 

In [4]:
@dataclass
class ReplayBufferSamples:
    """
    Samples from the replay buffer, converted to PyTorch for use in neural network 
    training.
    obs: shape (sample_size, *observation_shape), dtype t.float
    actions: shape (sample_size, ) dtype t.int
    rewards: shape (sample_size, ), dtype t.float
    dones: shape (sample_size, ), dtype t.bool
    next_observations: shape (sample_size, *observation_shape), dtype t.float
    """

    observations: t.Tensor
    actions: t.Tensor
    rewards: t.Tensor
    dones: t.Tensor
    next_observations: t.Tensor


class ReplayBuffer:
    rng: Generator
    observations: t.Tensor
    actions: t.Tensor
    rewards: t.Tensor
    dones: t.Tensor
    next_observations: t.Tensor

    def __init__(
        self,
        buffer_size: int,
        num_actions: int,
        observation_shape: tuple,
        num_environments: int,
        seed: int,
    ):
        assert (
            num_environments == 1
        ), "This buffer only supports SyncVectorEnv with 1 environment inside."
        self.observations = t.zeros((buffer_size, *observation_shape), dtype=t.float32)
        self.actions = t.zeros((buffer_size,), dtype=t.int64)
        self.rewards = t.zeros((buffer_size,), dtype=t.float32)
        self.dones = t.zeros((buffer_size,), dtype=t.bool)
        self.next_observations = t.zeros(
            (buffer_size, *observation_shape), dtype=t.float32
        )
        self.buffer_pointer = 0
        self.buffer_size = buffer_size

        self.rng = np.random.default_rng(seed)

    def right_push_to_tensor(self, tensor, x):
        return torch.cat((tensor[1:], t.tensor([x], dtype=tensor.dtype)))

    def add(
        self,
        obs: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        dones: np.ndarray,
        next_obs: np.ndarray,
    ) -> None:
        """
        obs: shape (num_environments, *observation_shape)
            Observation before the action
        actions: shape (num_environments, )
            Action chosen by the agent
        rewards: shape (num_environments, )
            Reward after the action
        dones: shape (num_environments, )
            If True, the episode ended and was reset automatically
        next_obs: shape (num_environments, *observation_shape)
            Observation after the action
            If done is True, this should be the terminal observation, NOT the first 
            observation of the next episode.
        """
        pointer = self.buffer_pointer % self.buffer_size
        self.observations[pointer] = t.tensor(obs)
        self.actions[pointer] = t.tensor(actions)
        self.rewards[pointer] = t.tensor(rewards)
        self.dones[pointer] = t.tensor(dones)
        self.next_observations[pointer] = t.tensor(next_obs)

        self.buffer_pointer += 1

    def sample(self, sample_size: int, device: t.device) -> ReplayBufferSamples:
        """Uniformly sample sample_size entries from the buffer and convert them to 
           PyTorch tensors on device.

        Sampling is with replacement, and sample_size may be larger than the buffer size.
        """
        sample_idx = self.rng.integers(
            low=0, high=min(self.buffer_pointer, self.buffer_size), size=sample_size
        )
        obs = self.observations[sample_idx].to(device)
        act = self.actions[sample_idx].to(device)
        rew = self.rewards[sample_idx].to(device)
        don = self.dones[sample_idx].to(device)
        nex_obs = self.next_observations[sample_idx].to(device)

        return ReplayBufferSamples(obs, act, rew, don, nex_obs)


# if MAIN:
#     utils.test_replay_buffer_single(ReplayBuffer)
#     utils.test_replay_buffer_deterministic(ReplayBuffer)
#     utils.test_replay_buffer_wraparound(ReplayBuffer)


In [5]:
# if MAIN:
#     rb = ReplayBuffer(
#         buffer_size=256,
#         num_actions=2,
#         observation_shape=(4,),
#         num_environments=1,
#         seed=0,
#     )
#     envs = gym.vector.SyncVectorEnv(
#         [utils.make_env("CartPole-v1", 0, 0, False, "test")]
#     )
#     obs = envs.reset()
#     for i in range(512):
#         actions = np.array([0])
#         (next_obs, rewards, dones, infos) = envs.step(actions)
#         real_next_obs = next_obs.copy()
#         for (i, done) in enumerate(dones):
#             if done:
#                 real_next_obs[i] = infos[i]["terminal_observation"]
#         rb.add(obs, actions, rewards, dones, next_obs)
#         obs = next_obs
#     sample = rb.sample(128, t.device("cpu"))
#     columns = ["cart_pos", "cart_v", "pole_angle", "pole_v"]
#     df = pd.DataFrame(rb.observations, columns=columns)
#     df.plot(subplots=True, title="Replay Buffer")
#     df2 = pd.DataFrame(sample.observations, columns=columns)
#     df2.plot(subplots=True, title="Shuffled Replay Buffer")


In [6]:
def linear_schedule(
    current_step: int,
    start_e: float,
    end_e: float,
    exploration_fraction: float,
    total_timesteps: int,
) -> float:
    """Return the appropriate epsilon for the current step.

    Epsilon should be start_e at step 0 and decrease linearly to end_e at step 
    (exploration_fraction * total_timesteps).

    It should stay at end_e for the rest of the episode.
    """
    return start_e + (end_e - start_e) * min(
        current_step / (exploration_fraction * total_timesteps), 1
    )


# if MAIN:
#     epsilons = [
#         linear_schedule(
#             step, start_e=1.0, end_e=0.05, exploration_fraction=0.5, total_timesteps=500
#         )
#         for step in range(500)
#     ]
#     utils.test_linear_schedule(linear_schedule)


In [7]:
def epsilon_greedy_policy(
    envs: gym.vector.SyncVectorEnv,
    q_network: QNetwork,
    rng: Generator,
    obs: t.Tensor,
    epsilon: float,
) -> np.ndarray:
    """With probability epsilon, take a random action. Otherwise, take a greedy action 
        according to the q_network.
    Inputs:
        envs : gym.vector.SyncVectorEnv, the family of environments to run against
        q_network : QNetwork, the network used to approximate the Q-value function
        obs : The current observation
        epsilon : exploration percentage
    Outputs:
        actions: (n_environments, ) the sampled action for each environment.
    """
    determiner = rng.random()
    n_envs = envs.num_envs
    n_actions = envs.single_action_space.n
    if determiner < epsilon:
        return rng.integers(0, n_actions, n_envs)
    else:
        out = q_network(obs).detach().numpy()
        out = out.argmax(axis=1)
        return out


# if MAIN:
#     utils.test_epsilon_greedy_policy(epsilon_greedy_policy)


In [8]:
ObsType = np.ndarray
ActType = int


class Probe1(gym.Env):
    """One action, observation of [0.0], one timestep long, +1 reward.

    We expect the agent to rapidly learn that the value of the constant [0.0] 
    observation is +1.0. Note we're using a continuous observation space for consistency 
    with CartPole.
    """

    action_space: Discrete
    observation_space: Box

    def __init__(self):
        super().__init__()
        self.observation_space = Box(np.array([0]), np.array([0]))
        self.action_space = Discrete(1)
        self.reset()

    def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
        return (np.array([0]), 1.0, True, {})

    def reset(
        self, seed: Optional[int] = None, return_info=False, options=None
    ) -> Union[ObsType, tuple[ObsType, dict]]:
        super().reset(seed=seed)
        if return_info:
            return (np.array([0.0]), {})
        return np.array([0.0])


gym.envs.registration.register(id="Probe1-v0", entry_point=Probe1)
if MAIN:
    env = gym.make("Probe1-v0")
    assert env.observation_space.shape == (1,)
    assert env.action_space.shape == ()


  logger.warn(f"Overriding environment {id}")
  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


In [9]:
class Probe2(gym.Env):
    """One action, observation of [-1.0] or [+1.0], one timestep long, reward equals 
       observation.

    We expect the agent to rapidly learn the value of each observation is equal to the 
    observation.
    """

    action_space: Discrete
    observation_space: Box

    def __init__(self):
        pass

    def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
        pass

    def reset(
        self, seed: Optional[int] = None, return_info=False, options=None
    ) -> Union[ObsType, tuple[ObsType, dict]]:
        pass


gym.envs.registration.register(id="Probe2-v0", entry_point=Probe2)


class Probe3(gym.Env):
    """One action, [0.0] then [1.0] observation, two timesteps, +1 reward at the end.

    We expect the agent to rapidly learn the discounted value of the initial observation.
    """

    action_space: Discrete
    observation_space: Box

    def __init__(self):
        pass

    def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
        pass

    def reset(
        self, seed: Optional[int] = None, return_info=False, options=None
    ) -> Union[ObsType, tuple[ObsType, dict]]:
        pass


gym.envs.registration.register(id="Probe3-v0", entry_point=Probe3)


class Probe4(gym.Env):
    """Two actions, [0.0] observation, one timestep, reward is -1.0 or +1.0 dependent on 
       the action.

    We expect the agent to learn to choose the +1.0 action.
    """

    action_space: Discrete
    observation_space: Box

    def __init__(self):
        pass

    def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
        pass

    def reset(
        self, seed: Optional[int] = None, return_info=False, options=None
    ) -> Union[ObsType, tuple[ObsType, dict]]:
        pass


gym.envs.registration.register(id="Probe4-v0", entry_point=Probe4)


class Probe5(gym.Env):
    """Two actions, random 0/1 observation, one timestep, reward is 1 if action equals 
       observation otherwise -1.

    We expect the agent to learn to match its action to the observation.
    """

    action_space: Discrete
    observation_space: Box

    def __init__(self):
        pass

    def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
        pass

    def reset(
        self, seed: Optional[int] = None, return_info=False, options=None
    ) -> Union[ObsType, tuple[ObsType, dict]]:
        pass


gym.envs.registration.register(id="Probe5-v0", entry_point=Probe5)


  logger.warn(f"Overriding environment {id}")
  logger.warn(f"Overriding environment {id}")
  logger.warn(f"Overriding environment {id}")
  logger.warn(f"Overriding environment {id}")


In [10]:
@dataclass
class DQNArgs:
    exp_name: str = os.path.basename(
        globals().get("__file__", "DQN_implementation").rstrip(".py")
    )
    seed: int = 1
    torch_deterministic: bool = True
    cuda: bool = True
    track: bool = True
    wandb_project_name: str = "Curt-CartPoleDQN"
    wandb_entity: Optional[str] = None
    capture_video: bool = True
    env_id: str = "CartPole-v1"
    total_timesteps: int = 500000
    learning_rate: float = 0.00025
    buffer_size: int = 10000
    gamma: float = 0.99
    target_network_frequency: int = 500
    batch_size: int = 128
    start_e: float = 1.0
    end_e: float = 0.05
    exploration_fraction: float = 0.5
    learning_starts: int = 10000
    train_frequency: int = 10


arg_help_strings = dict(
    exp_name="the name of this experiment",
    seed="seed of the experiment",
    torch_deterministic="if toggled, `torch.backends.cudnn.deterministic=False`",
    cuda="if toggled, cuda will be enabled by default",
    track="if toggled, this experiment will be tracked with Weights and Biases",
    wandb_project_name="the wandb's project name",
    wandb_entity="the entity (team) of wandb's project",
    capture_video="whether to capture videos of the agent performances (check out `videos` folder)",
    env_id="the id of the environment",
    total_timesteps="total timesteps of the experiments",
    learning_rate="the learning rate of the optimizer",
    buffer_size="the replay memory buffer size",
    gamma="the discount factor gamma",
    target_network_frequency="the timesteps it takes to update the target network",
    batch_size="the batch size of samples from the replay memory",
    start_e="the starting epsilon for exploration",
    end_e="the ending epsilon for exploration",
    exploration_fraction="the fraction of `total-timesteps` it takes from start-e to go end-e",
    learning_starts="timestep to start learning",
    train_frequency="the frequency of training",
)
toggles = ["torch_deterministic", "cuda", "track", "capture_video"]


def parse_args(arg_help_strings=arg_help_strings, toggles=toggles) -> DQNArgs:
    parser = argparse.ArgumentParser()
    for (name, field) in DQNArgs.__dataclass_fields__.items():
        flag = "--" + name.replace("_", "-")
        type_function = (
            field.type if field.type != bool else lambda x: bool(strtobool(x))
        )
        toggle_kwargs = {"nargs": "?", "const": True} if name in toggles else {}
        parser.add_argument(
            flag,
            type=type_function,
            default=field.default,
            help=arg_help_strings[name],
            **toggle_kwargs,
        )
    return DQNArgs(**vars(parser.parse_args()))


def setup(
    args: DQNArgs,
) -> Tuple[str, SummaryWriter, np.random.Generator, t.device, gym.vector.SyncVectorEnv]:
    """Helper function to set up useful variables for the DQN implementation"""
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s"
        % "\n".join([f"|{key}|{value}|" for (key, value) in vars(args).items()]),
    )
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic
    rng = np.random.default_rng(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    envs = gym.vector.SyncVectorEnv(
        [utils.make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]
    )
    assert isinstance(
        envs.single_action_space, Discrete
    ), "only discrete action space is supported"
    return (run_name, writer, rng, device, envs)


def log(
    writer: SummaryWriter,
    start_time: float,
    step: int,
    predicted_q_vals: t.Tensor,
    loss: Union[float, t.Tensor],
    infos: Iterable[dict],
    epsilon: float,
):
    """Helper function to write relevant info to TensorBoard logs, and print some things to stdout"""
    if step % 100 == 0:
        writer.add_scalar("losses/td_loss", loss, step)
        writer.add_scalar("losses/q_values", predicted_q_vals.mean().item(), step)
        writer.add_scalar("charts/SPS", int(step / (time.time() - start_time)), step)
        if step % 10000 == 0:
            print("SPS:", int(step / (time.time() - start_time)))
    episodic_return = 0
    for info in infos:
        if "episode" in info.keys():
            print(f"global_step={step}, episodic_return={info['episode']['r']}")
            episodic_return = info['episode']['r']
            writer.add_scalar("charts/episodic_return", info["episode"]["r"], step)
            writer.add_scalar("charts/episodic_length", info["episode"]["l"], step)
            writer.add_scalar("charts/epsilon", epsilon, step)
            break
    # OPTIONAL: ADD CODE HERE TO LOG TO WANDB
    wandb.log({"loss":loss, "episodic_return":episodic_return})


if MAIN:
    if "ipykernel_launcher" in os.path.basename(sys.argv[0]):
        filename = globals().get("__file__", "<filename of this script>")
        print(
            f"Try running this file from the command line instead: python {os.path.basename(filename)} --help"
        )
        args = DQNArgs()
    else:
        args = parse_args()
    # train_dqn(args)


Try running this file from the command line instead: python <filename of this script> --help


In [11]:
def train_dqn(args: DQNArgs):
    (run_name, writer, rng, device, envs) = setup(args)
    "YOUR CODE: Create your Q-network, Adam optimizer, and replay buffer here."
    n_actions = envs.single_action_space.n
    obs_dim = envs.single_observation_space.shape
    num_obs = np.array(obs_dim, dtype=int).prod()

    q_network = QNetwork(num_obs, n_actions).to(device)
    target_network = QNetwork(num_obs, n_actions).to(device)
    target_network.load_state_dict(q_network.state_dict())

    rb = ReplayBuffer(args.buffer_size, n_actions, obs_dim, envs.num_envs, args.seed)
    optimizer = t.optim.Adam(q_network.parameters(), args.learning_rate)

    start_time = time.time()
    obs = envs.reset()
    for step in range(args.total_timesteps):
        "YOUR CODE: Sample actions according to the epsilon greedy policy using the linear schedule for epsilon, and then step the environment"
        epsilon = linear_schedule(step, args.start_e, args.end_e, args.exploration_fraction, args.total_timesteps)
        actions = epsilon_greedy_policy(envs, q_network, rng, t.Tensor(obs).to(dtype=t.float32).to(device), epsilon)
        next_obs, rewards, dones, infos = envs.step(actions)

        "Boilerplate to handle the terminal observation case"
        real_next_obs = next_obs.copy()
        for (i, done) in enumerate(dones):
            if done:
                real_next_obs[i] = infos[i]["terminal_observation"]
                
        rb.add(obs, actions, rewards, dones, next_obs)
        obs = next_obs
        if step > args.learning_starts and step % args.train_frequency == 0:
            "YOUR CODE: Sample from the replay buffer, compute the TD target, compute TD loss, and perform an optimizer step."
            sample = rb.sample(args.batch_size, device)
            s, a, r, d, s_new = sample.observations, sample.actions, sample.rewards, sample.dones, sample.next_observations

            with t.inference_mode():
                target_max = target_network(t.Tensor(s_new).to(dtype=t.float32).to(device)).max(-1).values
            predicted_q_vals = q_network(s)[t.arange(args.batch_size), a.flatten()]

            delta = r.flatten() + args.gamma * target_max * (1 - d.float().flatten()) - predicted_q_vals
            loss = delta.pow(2).sum() / args.buffer_size
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            log(writer, start_time, step, predicted_q_vals, loss, infos, epsilon)
        
        if step % args.target_network_frequency == 0:
            "(4) YOUR CODE: Copy weights to the target network"
            target_network.load_state_dict(q_network.state_dict())

    "If running one of the Probe environments, will test if the learned q-values are\n    sensible after training. Useful for debugging."
    probe_batches = [
        t.tensor([[0.0]]),
        t.tensor([[-1.0], [+1.0]]),
        t.tensor([[0.0], [1.0]]),
        t.tensor([[0.0]]),
        t.tensor([[0.0], [1.0]]),
    ]
    if re.match(r"Probe(\d)-v0", args.env_id):
        probe_no = int(re.match(r"Probe(\d)-v0", args.env_id).group(1))
        batch = probe_batches[probe_no]
        value = q_network(batch)
        print("Value: ", value)
        expected = t.tensor([[1.0]]).to(device)
        t.testing.assert_close(value, expected, 0.0001)

    envs.close()
    writer.close()


In [12]:
train_dqn(args)

  return LooseVersion(v) >= LooseVersion(check)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcurt-tigges[0m ([33marena-ldn[0m). Use [1m`wandb login --relogin`[0m to force relogin
  from IPython.core.display import HTML, display  # type: ignore


  loader = importlib.find_loader(fullname, path)
  deprecation(
  logger.deprecation(
  if distutils.version.LooseVersion(
  logger.deprecation(
  if distutils.version.LooseVersion(
  logger.deprecation(
  if distutils.version.LooseVersion(
  logger.deprecation(
  if distutils.version.LooseVersion(
  logger.deprecation(
  if distutils.version.LooseVersion(
  logger.deprecation(
  if distutils.version.LooseVersion(
  logger.deprecation(
  if distutils.version.LooseVersion(
  logger.deprecation(
  if distutils.version.LooseVersion(


global_step=10120, episodic_return=14.0
global_step=10470, episodic_return=13.0
global_step=10540, episodic_return=41.0
global_step=10640, episodic_return=13.0
global_step=10780, episodic_return=33.0
global_step=11010, episodic_return=9.0
global_step=11140, episodic_return=14.0


  logger.deprecation(
  if distutils.version.LooseVersion(


global_step=12040, episodic_return=14.0
global_step=12310, episodic_return=16.0
global_step=12640, episodic_return=12.0
global_step=12950, episodic_return=22.0
global_step=13010, episodic_return=27.0
global_step=13130, episodic_return=11.0
global_step=13670, episodic_return=24.0
global_step=13770, episodic_return=12.0
global_step=13810, episodic_return=23.0
global_step=14030, episodic_return=27.0
global_step=14320, episodic_return=15.0
global_step=14400, episodic_return=32.0
global_step=15110, episodic_return=18.0
global_step=15470, episodic_return=17.0
global_step=15500, episodic_return=15.0
global_step=15590, episodic_return=18.0
global_step=15720, episodic_return=31.0
global_step=15810, episodic_return=16.0


  logger.deprecation(
  if distutils.version.LooseVersion(


global_step=16190, episodic_return=28.0
global_step=16230, episodic_return=15.0
global_step=16580, episodic_return=41.0
global_step=16600, episodic_return=20.0
global_step=16660, episodic_return=15.0
global_step=16870, episodic_return=18.0
global_step=17180, episodic_return=48.0
global_step=17400, episodic_return=12.0
global_step=18150, episodic_return=15.0
global_step=19030, episodic_return=17.0
global_step=19090, episodic_return=11.0
global_step=19240, episodic_return=31.0
global_step=19290, episodic_return=19.0
global_step=19650, episodic_return=13.0
global_step=19780, episodic_return=25.0
global_step=19850, episodic_return=23.0
global_step=19900, episodic_return=12.0
SPS: 2198
global_step=20010, episodic_return=52.0
global_step=20300, episodic_return=31.0
global_step=20420, episodic_return=11.0
global_step=20610, episodic_return=16.0
global_step=21280, episodic_return=25.0
global_step=21430, episodic_return=16.0
global_step=21990, episodic_return=35.0
global_step=22440, episodic_re

  logger.deprecation(
  if distutils.version.LooseVersion(


global_step=22980, episodic_return=14.0
global_step=23260, episodic_return=25.0
global_step=23660, episodic_return=36.0
global_step=23870, episodic_return=24.0
global_step=23930, episodic_return=29.0
global_step=24020, episodic_return=26.0
global_step=24630, episodic_return=23.0
global_step=25690, episodic_return=35.0
global_step=25750, episodic_return=46.0
global_step=25810, episodic_return=29.0
global_step=25920, episodic_return=41.0
global_step=26040, episodic_return=14.0
global_step=26630, episodic_return=54.0
global_step=26970, episodic_return=43.0
global_step=27030, episodic_return=16.0
global_step=27060, episodic_return=30.0
global_step=27760, episodic_return=29.0
global_step=28060, episodic_return=33.0
global_step=28220, episodic_return=17.0
global_step=28750, episodic_return=16.0
global_step=29000, episodic_return=18.0
global_step=29170, episodic_return=16.0
global_step=29310, episodic_return=52.0
global_step=29440, episodic_return=13.0
global_step=29510, episodic_return=28.0


  logger.deprecation(
  if distutils.version.LooseVersion(


global_step=53930, episodic_return=17.0
global_step=54830, episodic_return=47.0
global_step=55000, episodic_return=73.0
global_step=55190, episodic_return=71.0
global_step=55670, episodic_return=11.0
global_step=55790, episodic_return=66.0
global_step=56310, episodic_return=15.0
global_step=56620, episodic_return=18.0
global_step=57490, episodic_return=74.0
global_step=57810, episodic_return=32.0
global_step=58160, episodic_return=29.0
global_step=58390, episodic_return=42.0
global_step=58460, episodic_return=48.0
global_step=58790, episodic_return=79.0
global_step=58870, episodic_return=31.0
global_step=59200, episodic_return=99.0
SPS: 2201
global_step=60010, episodic_return=79.0
global_step=60420, episodic_return=89.0
global_step=60560, episodic_return=59.0
global_step=60640, episodic_return=37.0
global_step=60900, episodic_return=35.0
global_step=60940, episodic_return=40.0
global_step=61150, episodic_return=22.0
global_step=62240, episodic_return=46.0
global_step=62410, episodic_re

  logger.deprecation(
  if distutils.version.LooseVersion(


global_step=109880, episodic_return=117.0
SPS: 2199
global_step=110530, episodic_return=48.0
global_step=110990, episodic_return=25.0
global_step=111970, episodic_return=65.0
global_step=112180, episodic_return=103.0
global_step=112420, episodic_return=101.0
global_step=112430, episodic_return=10.0
global_step=112630, episodic_return=200.0
global_step=113410, episodic_return=78.0
global_step=113940, episodic_return=94.0
global_step=114250, episodic_return=16.0
global_step=117080, episodic_return=49.0
global_step=117830, episodic_return=44.0
global_step=119390, episodic_return=45.0
SPS: 2197
global_step=120530, episodic_return=64.0
global_step=120760, episodic_return=184.0
global_step=121260, episodic_return=164.0
global_step=121530, episodic_return=182.0
global_step=122660, episodic_return=99.0
global_step=123510, episodic_return=64.0
global_step=124170, episodic_return=185.0
global_step=124470, episodic_return=214.0
global_step=124780, episodic_return=24.0
global_step=125810, episodic

  logger.deprecation(
  if distutils.version.LooseVersion(


SPS: 2021
global_step=421880, episodic_return=424.0
SPS: 2005
global_step=435740, episodic_return=206.0
SPS: 1989
SPS: 1973
SPS: 1965
SPS: 1957
SPS: 1952
SPS: 1945


In [None]:
if MAIN:
    if "ipykernel_launcher" in os.path.basename(sys.argv[0]):
        filename = globals().get("__file__", "<filename of this script>")
        print(
            f"Try running this file from the command line instead: python {os.path.basename(filename)} --help"
        )
        args = DQNArgs()
    else:
        args = parse_args()
    # train_dqn(args)
