In [None]:
import sys
import os
from typing import Dict, List, Tuple

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.display import clear_output
import random
import time


## Replay buffer
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions stored in numpy arrays.

    Once ``size`` transitions have been written, new ones overwrite the
    oldest slot. Every column (including actions and done flags) is kept as
    ``float32``.
    """

    def __init__(self, obs_dim: int, size: int, batch_size: int = 32):
        """Allocate zeroed storage for ``size`` transitions.

        Args:
            obs_dim: length of a flat observation vector.
            size: maximum number of transitions kept.
            batch_size: number of transitions returned by ``sample_batch``.
        """
        obs_shape = (size, obs_dim)
        self.obs_buf = np.zeros(obs_shape, dtype=np.float32)
        self.next_obs_buf = np.zeros(obs_shape, dtype=np.float32)
        self.acts_buf = np.zeros(size, dtype=np.float32)
        self.rews_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.max_size = size
        self.batch_size = batch_size
        # ptr: next slot to write; size: number of valid transitions so far.
        self.ptr = 0
        self.size = 0

    def store(
        self,
        obs: np.ndarray,
        act: np.ndarray,
        rew: float,
        next_obs: np.ndarray,
        done: bool,
    ):
        """Write one transition at the write pointer and advance it."""
        slot = self.ptr
        self.obs_buf[slot] = obs
        self.next_obs_buf[slot] = next_obs
        self.acts_buf[slot] = act
        self.rews_buf[slot] = rew
        self.done_buf[slot] = done
        # Wrap around so the oldest entry is overwritten once the buffer is full.
        self.ptr = (slot + 1) % self.max_size
        if self.size < self.max_size:
            self.size += 1

    def sample_batch(self) -> Dict[str, np.ndarray]:
        """Return ``batch_size`` distinct stored transitions chosen uniformly.

        Draws from numpy's global RNG and prints the chosen indices (debug
        trace). Sampling is without replacement, so numpy raises
        ``ValueError`` when fewer than ``batch_size`` transitions are stored.
        """
        idxs = np.random.choice(self.size, size=self.batch_size, replace=False)
        print(idxs)
        batch = {
            "obs": self.obs_buf[idxs],
            "next_obs": self.next_obs_buf[idxs],
            "acts": self.acts_buf[idxs],
            "rews": self.rews_buf[idxs],
            "done": self.done_buf[idxs],
        }
        return batch

    def __len__(self) -> int:
        """Number of transitions currently stored."""
        return self.size


## Network
class Network(nn.Module):
    """Fully-connected Q-network: two 128-unit ReLU hidden layers, linear output."""

    def __init__(self, in_dim: int, out_dim: int):
        """Build the layer stack.

        Args:
            in_dim: size of the input feature vector.
            out_dim: number of outputs (one per action in this file's usage).
        """
        super().__init__()

        hidden = 128
        # Layers are created in forward order so seeded initialisation is
        # reproducible and the state_dict keys stay "layers.<index>.*".
        stack = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x`` and return the result."""
        out = self.layers(x)
        return out


## Double DQN Agent
class DQNAgent:
    """Double DQN agent interacting with a Gymnasium environment.

    Most methods print intermediate values (random draws, epsilon, actions,
    sampled batches, Q-values, losses). These prints are a deliberate
    step-by-step debugging trace and are kept as-is.

    Attribute:
        env (gym.Env): Gymnasium environment
        memory (ReplayBuffer): replay memory to store transitions
        batch_size (int): batch size for sampling
        epsilon (float): current exploration rate of the epsilon greedy policy
        epsilon_decay (float): fraction of (max_epsilon - min_epsilon)
                               subtracted from epsilon on every frame
        seed (int): seed passed to every ``env.reset`` call
        max_epsilon (float): max value of epsilon
        min_epsilon (float): min value of epsilon
        target_update (int): period (in frames) of the target model's hard update
        gamma (float): discount factor
        device (torch.device): cuda when available, otherwise cpu
        dqn (Network): model to train and select actions
        dqn_target (Network): target model to update
        optimizer (torch.optim): optimizer for training dqn
        transition (list): transition information including
                           state, action, reward, next_state, done
        is_test (bool): when True, transitions are not recorded
    """

    def __init__(
        self,
        env: gym.Env,
        memory_size: int,
        batch_size: int,
        target_update: int,
        epsilon_decay: float,
        seed: int,
        max_epsilon: float = 1.0,
        min_epsilon: float = 0.0,
        gamma: float = 0.99,
        lr: float = 1e-3,
    ):
        """Initialization.

        Args:
            env (gym.Env): Gymnasium environment with a flat observation
                space and a discrete action space
            memory_size (int): length of memory
            batch_size (int): batch size for sampling
            target_update (int): period for target model's hard update
            epsilon_decay (float): step size to decrease epsilon
            seed (int): seed passed to ``env.reset``
            max_epsilon (float): max value of epsilon
            min_epsilon (float): min value of epsilon
            gamma (float): discount factor
            lr (float): Adam learning rate; the default equals Adam's own
                default, so omitting it behaves as before
        """
        obs_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        self.env = env
        self.memory = ReplayBuffer(obs_dim, memory_size, batch_size)
        self.batch_size = batch_size
        self.epsilon = max_epsilon
        self.epsilon_decay = epsilon_decay
        self.seed = seed
        self.max_epsilon = max_epsilon
        self.min_epsilon = min_epsilon
        self.target_update = target_update
        self.gamma = gamma

        # device: cpu / gpu
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)

        # networks: dqn, dqn_target (target starts as an exact copy and is
        # only ever used for inference)
        self.dqn = Network(obs_dim, action_dim).to(self.device)
        self.dqn_target = Network(obs_dim, action_dim).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()

        # optimizer
        self.optimizer = optim.Adam(self.dqn.parameters(), lr=lr)

        # transition to store in memory
        self.transition = list()

        # mode: train / test
        self.is_test = False

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action from the input state with an epsilon greedy policy.

        Exactly one ``np.random.rand()`` draw is consumed per call; the
        exploratory branch additionally consumes one ``random.choice`` draw.

        Returns:
            A plain ``int`` on the exploratory branch, a 0-d numpy integer
            array on the greedy branch.
        """
        rand = np.random.rand()
        print(rand)
        print(self.epsilon)
        if self.epsilon > rand:
            print("random action!")
            selected_action = random.choice(range(self.env.action_space.n))
        else:
            # Inference only: no autograd graph is needed to pick an action.
            # argmax returns the first maximal index (no random tie-breaking).
            with torch.no_grad():
                q_values = self.dqn(torch.FloatTensor(state).to(self.device))
            selected_action = q_values.argmax().detach().cpu().numpy()

        if not self.is_test:
            # Start a new transition; step() appends the rest.
            self.transition = [state, selected_action]

        return selected_action

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
        """Take an action and return the response of the env.

        ``done`` is ``terminated or truncated``, so a truncated episode is
        stored as terminal and is not bootstrapped from in the loss.
        Outside test mode the completed transition is written to memory.
        """
        next_state, reward, terminated, truncated, _ = self.env.step(action)
        done = terminated or truncated

        if not self.is_test:
            self.transition += [reward, next_state, done]
            self.memory.store(*self.transition)

        return next_state, reward, done

    def update_model(self) -> float:
        """Run one gradient step on a sampled batch and return the scalar loss."""
        samples = self.memory.sample_batch()

        loss = self._compute_dqn_loss(samples)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self, num_frames: int, plotting_interval: int = 200):
        """Train the agent for ``num_frames`` environment steps.

        Args:
            num_frames: number of environment steps to run.
            plotting_interval: only used by the plotting calls, which are
                currently commented out.

        The environment is reset with the same seed at the start and after
        every episode, and is closed when training finishes.
        """
        self.is_test = False

        state, _ = self.env.reset(seed=self.seed)
        update_cnt = 0
        epsilons = []
        losses = []
        scores = []
        score = 0

        for frame_idx in range(0, num_frames):
            print("training step", frame_idx)
            action = self.select_action(state)
            print("action", action)
            next_state, reward, done = self.step(action)

            state = next_state
            score += reward

            # if episode ends
            if done:
                state, _ = self.env.reset(seed=self.seed)
                scores.append(score)
                score = 0

            # linearly decrease epsilon on every frame, floored at min_epsilon
            self.epsilon = max(
                self.min_epsilon,
                self.epsilon
                - (self.max_epsilon - self.min_epsilon) * self.epsilon_decay,
            )
            epsilons.append(self.epsilon)

            # hard update of the target network, keyed on the frame index
            # (so it also fires on frame 0)
            if frame_idx % self.target_update == 0:
                self._target_hard_update()

            # if training is ready
            if len(self.memory) >= self.batch_size:
                loss = self.update_model()
                losses.append(loss)
                update_cnt += 1

        # Plotting is disabled while tracing; re-enable with:
        #     if frame_idx % plotting_interval == 0:
        #         self._plot(frame_idx, scores, losses, epsilons)
        self.env.close()

    def test(self, video_folder: str) -> None:
        """Run one recorded episode and print its score.

        The environment is temporarily wrapped in ``RecordVideo``. The
        recorder is closed and the original environment restored even if the
        episode raises or is interrupted. ``is_test`` is left set to True.
        """
        self.is_test = True

        # for recording a video
        naive_env = self.env
        self.env = gym.wrappers.RecordVideo(self.env, video_folder=video_folder)

        try:
            state, _ = self.env.reset(seed=self.seed)
            done = False
            score = 0

            while not done:
                action = self.select_action(state)
                next_state, reward, done = self.step(action)

                state = next_state
                score += reward

            print("score: ", score)
        finally:
            self.env.close()

            # reset
            self.env = naive_env

    def _compute_dqn_loss(self, samples: Dict[str, np.ndarray]) -> torch.Tensor:
        """Return the Double DQN smooth-L1 loss for a sampled batch.

        Args:
            samples: dict with keys "obs", "next_obs", "acts", "rews", "done"
                as produced by ``ReplayBuffer.sample_batch``.
        """
        device = self.device  # for shortening the following lines
        state = torch.FloatTensor(samples["obs"]).to(device)
        next_state = torch.FloatTensor(samples["next_obs"]).to(device)
        action = torch.LongTensor(samples["acts"].reshape(-1, 1)).to(device)
        reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
        done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)
        print(action)
        print(state)
        # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
        #       = r                       otherwise
        curr_q_value = self.dqn(state).gather(1, action)
        # Debug trace: these prints run extra forward passes on purpose.
        print(self.dqn(state))
        print(curr_q_value)
        print(self.dqn(next_state))
        # Double DQN: the online network picks the next action, the target
        # network evaluates it. No gradient flows through the target.
        with torch.no_grad():
            next_action = self.dqn(next_state).argmax(dim=1, keepdim=True)
            next_q_value = self.dqn_target(next_state).gather(1, next_action)
        mask = 1 - done
        target = (reward + self.gamma * next_q_value * mask).to(self.device)
        # calculate dqn loss
        loss = F.smooth_l1_loss(curr_q_value, target)
        print(loss)
        return loss

    def _target_hard_update(self):
        """Hard update: target <- local."""
        self.dqn_target.load_state_dict(self.dqn.state_dict())

    def _plot(
        self,
        frame_idx: int,
        scores: List[float],
        losses: List[float],
        epsilons: List[float],
    ):
        """Plot the training progresses (scores and losses).

        ``epsilons`` is accepted but not drawn; its subplot is commented out.
        """
        clear_output(True)
        plt.figure(figsize=(20, 5))
        plt.subplot(121)
        plt.title("frame %s. score: %s" % (frame_idx, np.mean(scores[-10:])))
        plt.plot(scores)
        plt.subplot(122)
        plt.title("loss")
        plt.plot(losses)
        # plt.subplot(133)
        # plt.title("epsilons")
        # plt.plot(epsilons)
        plt.show()


## Environment
# CartPole-v1 created with render_mode="rgb_array" — presumably so that
# DQNAgent.test can record video frames; confirm if rendering is not needed.
env = gym.make("CartPole-v1", render_mode="rgb_array")

## Set random seed

import torch
import random
import numpy as np

# Single seed shared by numpy, Python's `random`, torch and env.reset below.
seed = 777


def seed_torch(seed):
    """Seed torch's RNG and, when cuDNN is enabled, force deterministic cuDNN.

    With cuDNN enabled this also seeds the current CUDA device, turns off the
    cuDNN benchmark mode and turns on deterministic mode.
    """
    torch.manual_seed(seed)
    cudnn = torch.backends.cudnn
    if not cudnn.enabled:
        return
    torch.cuda.manual_seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True


# Seed every RNG the agent draws from (numpy, Python's `random`, torch)
# before the networks are created, so weight initialisation is reproducible.
np.random.seed(seed)
random.seed(seed)
seed_torch(seed)
## Initialize

# parameters
num_frames = 10000  # total environment steps to train for
memory_size = 1000  # replay buffer capacity
batch_size = 32  # transitions per gradient step
target_update = 200  # frames between hard target-network updates
# With the default max_epsilon=1.0 and min_epsilon=0.0, epsilon falls by
# 1/2000 per frame and reaches 0 after 2000 frames.
epsilon_decay = 1 / 2000

# train
agent = DQNAgent(env, memory_size, batch_size, target_update, epsilon_decay, seed)

# for i in range(75):
#   print(i, random.choice(range(2)))

## Train

agent.train(num_frames)

In [None]:
import torch
import random
import numpy as np

# Same seed as the training cell, so the RNG stream printed below matches it.
seed = 777


def seed_torch(seed):
    """Seed torch's RNG and, when cuDNN is enabled, force deterministic cuDNN.

    With cuDNN enabled this also seeds the current CUDA device, turns off the
    cuDNN benchmark mode and turns on deterministic mode.
    """
    torch.manual_seed(seed)
    cudnn = torch.backends.cudnn
    if not cudnn.enabled:
        return
    torch.cuda.manual_seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True


# Re-seed numpy, Python's `random` and torch with the shared seed.
np.random.seed(seed)
random.seed(seed)
seed_torch(seed)
# Print the first 4000 draws of numpy's global RNG for this seed — presumably
# a reference stream to compare against the `rand` values printed by the
# agents' epsilon-greedy steps; confirm intent with the author.
for i in range(4000):
    print(i, np.random.rand())

In [1]:
import torch
import random
import numpy as np

# Same seed value as the DQNAgent training cell.
seed = 777


def seed_torch(seed):
    """Seed torch's RNG and, when cuDNN is enabled, force deterministic cuDNN.

    With cuDNN enabled this also seeds the current CUDA device, turns off the
    cuDNN benchmark mode and turns on deterministic mode.
    """
    torch.manual_seed(seed)
    cudnn = torch.backends.cudnn
    if not cudnn.enabled:
        return
    torch.cuda.manual_seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True


# Re-seed numpy, Python's `random` and torch. No agent is built in this cell,
# so it presumably exists to reset the RNGs before the RainbowAgent cells
# run — confirm the intended cell execution order.
np.random.seed(seed)
random.seed(seed)
seed_torch(seed)

In [None]:
import gc
from math import e
from time import time
from pkg_resources import get_distribution
import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim.sgd import SGD
from torch.optim.adam import Adam
import numpy as np
from agent_configs import RainbowConfig
from utils import (
    update_per_beta,
    get_legal_moves,
    current_timestamp,
    action_mask,
    epsilon_greedy_policy,
    CategoricalCrossentropyLoss,
    HuberLoss,
    KLDivergenceLoss,
    MSELoss,
    update_inverse_sqrt_schedule,
    update_linear_schedule,
)

import sys

from utils.utils import epsilon_greedy_policy

sys.path.append("../../")

from base_agent.agent import BaseAgent
from replay_buffers.prioritized_n_step_replay_buffer import PrioritizedNStepReplayBuffer
from dqn.rainbow.rainbow_network import RainbowNetwork


class RainbowAgent(BaseAgent):
    """Value-based agent built from a RainbowNetwork pair and a prioritized n-step replay buffer.

    This copy is instrumented with print statements (actions, observations,
    predictions, losses) for step-by-step tracing.

    Attributes such as ``self.config``, ``self.env``, ``self.num_actions``,
    ``self.observation_dimensions``, ``self.preprocess``,
    ``self.start_training_step``, ``self.training_steps``,
    ``self.checkpoint_interval`` and ``self.save_checkpoint`` are not defined
    in this class — presumably provided by BaseAgent; verify there.
    """

    def __init__(
        self,
        env,
        config: RainbowConfig,
        name=f"rainbow_{current_timestamp():.1f}",
        device: torch.device = (
            torch.device("cuda")
            if torch.cuda.is_available()
            # MPS is sometimes useful for M2 instances, but only for large models/matrix multiplications otherwise CPU is faster
            else (
                torch.device("mps")
                if torch.backends.mps.is_available() and torch.backends.mps.is_built()
                else torch.device("cpu")
            )
        ),
        num_players: int = 1,
    ):
        """Build the online/target networks, optimizer, replay buffer and stats.

        Args:
            env: environment, forwarded to BaseAgent.
            config: hyperparameters read throughout this class.
            name: agent name, forwarded to BaseAgent.
                NOTE(review): the default f-string is evaluated once, when the
                class body is executed, so every agent created without an
                explicit name shares the same timestamp.
            device: device the two networks are moved to. The default
                (cuda, then mps, then cpu) is likewise chosen once at
                definition time.
            num_players: forwarded to the replay buffer.
        """
        super(RainbowAgent, self).__init__(env, config, name, device=device)
        # Online network (trained) and target network (periodically synced copy).
        self.model = RainbowNetwork(
            config=config,
            output_size=self.num_actions,
            input_shape=(self.config.minibatch_size,) + self.observation_dimensions,
        )
        self.target_model = RainbowNetwork(
            config=config,
            output_size=self.num_actions,
            input_shape=(self.config.minibatch_size,) + self.observation_dimensions,
        )

        self.model.to(device)
        self.target_model.to(device)
        # Start the target as an exact copy; it is only used for inference.
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()

        # The optimizer class comes from the config and is called with
        # lr / eps / weight_decay keyword arguments.
        self.optimizer: torch.optim.Optimizer = self.config.optimizer(
            params=self.model.parameters(),
            lr=self.config.learning_rate,
            eps=self.config.adam_epsilon,
            weight_decay=self.config.weight_decay,
        )

        self.replay_buffer = PrioritizedNStepReplayBuffer(
            observation_dimensions=self.observation_dimensions,
            observation_dtype=self.env.observation_space.dtype,
            max_size=self.config.replay_buffer_size,
            batch_size=self.config.minibatch_size,
            max_priority=1.0,
            alpha=self.config.per_alpha,
            beta=self.config.per_beta,
            # epsilon=config["per_epsilon"],
            n_step=self.config.n_step,
            gamma=self.config.discount_factor,
            # Use the env's lz4_compress attribute when it has one, else False.
            compressed_observations=(
                self.env.lz4_compress if hasattr(self.env, "lz4_compress") else False
            ),
            num_players=num_players,
        )

        # Current epsilon-greedy exploration rate; updated in update_eg_epsilon.
        self.eg_epsilon = self.config.eg_epsilon

        # Per-run statistics appended to during train().
        self.stats = {
            "score": [],
            "loss": [],
            "test_score": [],
        }
        self.targets = {
            # "score": self.env.spec.reward_threshold,
            # "test_score": self.env.spec.reward_threshold,
        }

    def predict(self, states) -> torch.Tensor:
        """Preprocess ``states`` and return the online network's output."""
        # could change type later
        state_input = self.preprocess(states)
        q_distribution: torch.Tensor = self.model(state_input)
        return q_distribution

    def predict_target(self, states) -> torch.Tensor:
        """Preprocess ``states`` and return the target network's output."""
        # could change type later
        state_input = self.preprocess(states)
        q_distribution: torch.Tensor = self.target_model(state_input)
        return q_distribution

    def select_actions(
        self, distribution, info: dict = None, mask_actions: bool = False
    ):
        """Pick one greedy action per row, breaking ties at random.

        Args:
            distribution: network output, one row per state.
            info: required when ``mask_actions`` is True; legal moves are
                read from it via ``get_legal_moves``.
            mask_actions: when True, illegal actions are set to -inf before
                the maximum is taken.

        Returns:
            A stacked tensor with one selected action index per row.

        Tie-breaking calls ``np.random.choice`` once per row, so every call
        consumes draws from numpy's global RNG even when there is no tie.
        """
        assert info is not None if mask_actions else True, "Need info to mask actions"
        # print(info)
        if self.config.atom_size > 1:
            # Weighted sum over dim 2 of distribution * support.
            # NOTE(review): self.support is not defined in this class —
            # presumably set elsewhere (BaseAgent?); verify before using
            # atom_size > 1.
            q_values = distribution * self.support
            q_values = q_values.sum(2, keepdim=False)
        else:
            q_values = distribution
        if mask_actions:
            legal_moves = get_legal_moves(info)
            q_values = action_mask(
                q_values, legal_moves, mask_value=-float("inf"), device=self.device
            )
        # print("Q Values", q_values)
        # q_values with argmax ties
        selected_actions = torch.stack(
            [
                torch.tensor(np.random.choice(np.where(x.cpu() == x.cpu().max())[0]))
                for x in q_values
            ]
        )
        # print(selected_actions)
        # selected_actions = q_values.argmax(1, keepdim=False)
        return selected_actions

    def learn(self) -> float:
        """Sample a batch from the replay buffer, train on it, return the loss."""
        samples = self.replay_buffer.sample()
        loss = self.learn_from_sample(samples)
        return loss

    def learn_from_sample(self, samples: dict) -> float:
        """Run one Double-DQN style gradient step on ``samples``.

        Args:
            samples: batch dict; the keys read here are "observations",
                "weights", "actions", "next_observations", "rewards" and
                "dones". Array-valued entries other than "observations" are
                converted with ``torch.from_numpy``, so they are presumably
                numpy arrays — confirm against the replay buffer.

        Returns:
            The mean loss as a Python float.
        """
        # NOTE(review): `weights` is unpacked but never used below, so the
        # loss is not importance-weighted here.
        observations, weights, actions = (
            samples["observations"],
            samples["weights"],
            torch.from_numpy(samples["actions"]).to(self.device).long(),
        )
        print("actions", actions)

        print("Observations", observations)
        # Online prediction for the action actually taken in each row.
        online_predictions = self.predict(observations)[
            range(self.config.minibatch_size), actions
        ]
        # for param in self.model.parameters():
        #     print(param)
        # Debug trace: this print runs an extra forward pass.
        print(self.predict(observations))
        print(online_predictions)
        # (B, atom_size)
        # print("using default dqn loss")
        next_observations, rewards, dones = (
            torch.from_numpy(samples["next_observations"]).to(self.device),
            torch.from_numpy(samples["rewards"]).to(self.device),
            torch.from_numpy(samples["dones"]).to(self.device),
        )
        target_predictions = self.predict_target(next_observations)  # next q values
        # print("Next q values", target_predictions)
        # print("Current q values", online_predictions)
        # Debug trace: another extra forward pass, for printing only.
        print(self.predict(next_observations))
        # Online network chooses the next action; the target network's value
        # for that action is used below.
        next_actions = self.select_actions(
            self.predict(next_observations),  # current q values
        )
        print("Next actions", next_actions)
        target_predictions = target_predictions[
            range(self.config.minibatch_size), next_actions
        ]  # this might not work
        print(target_predictions)
        # `~dones` is a bitwise NOT — assumes `dones` is a boolean tensor so
        # terminal rows get no bootstrap term; TODO confirm the buffer dtype.
        target_predictions = (
            rewards + self.config.discount_factor * (~dones) * target_predictions
        )
        print(target_predictions)
        # print("predicted", online_distributions)
        # print("target", target_distributions)

        loss = self.config.loss_function(online_predictions, target_predictions)
        print("Loss", loss.mean())
        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()
        return loss.detach().to("cpu").mean().item()

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def update_eg_epsilon(self, training_step: int):
        """Set ``eg_epsilon`` from ``update_linear_schedule`` for ``training_step``."""
        # print("decaying eg epsilon linearly")
        self.eg_epsilon = update_linear_schedule(
            self.config.eg_epsilon_final,
            self.config.eg_epsilon_final_step,
            self.config.eg_epsilon,
            training_step,
        )

    def train(self):
        """Interact with the environment and learn, one step per iteration.

        Each iteration takes one epsilon-greedy action, stores the
        transition, syncs the target network every ``transfer_interval``
        steps, learns once when the buffer holds at least
        ``min_replay_buffer_size`` transitions, and checkpoints every
        ``checkpoint_interval`` steps. The environment is reset with the
        hard-coded seed 777 at the start and after every episode.
        """
        start_time = time()
        score = 0
        # Flags recording whether a target sync happened since the last
        # logged score / loss entry.
        target_model_updated = (False, False)  # (score, loss)

        # self.fill_replay_buffer()
        state, info = self.env.reset(seed=777)

        # self.training_steps += self.start_training_step
        for training_step in range(self.start_training_step, self.training_steps):
            print("training step", training_step)
            # Acting and bookkeeping need no gradients.
            with torch.no_grad():
                values = self.predict(state)
                # print(values)
                action = epsilon_greedy_policy(
                    values,
                    info,
                    self.eg_epsilon,
                    wrapper=lambda values, info: self.select_actions(values).item(),
                )
                print("Action", action)
                self.update_eg_epsilon(training_step + 1)
                next_state, reward, terminated, truncated, next_info = self.env.step(
                    action
                )
                # Truncation is treated the same as termination.
                done = terminated or truncated
                print("State", state)
                self.replay_buffer.store(
                    state, info, action, reward, next_state, next_info, done
                )
                state = next_state
                info = next_info
                score += reward

                if done:
                    state, info = self.env.reset(seed=777)
                    score_dict = {
                        "score": score,
                        "target_model_updated": target_model_updated[0],
                    }
                    self.stats["score"].append(score_dict)
                    target_model_updated = (False, target_model_updated[1])
                    score = 0

                # Also true on step 0 when start_training_step is 0.
                if training_step % self.config.transfer_interval == 0:
                    target_model_updated = (True, True)
                    # stats["test_score"].append(
                    #     {"target_model_weight_update": training_step}
                    # )
                    self.update_target_model()

            print("replay buffer size", len(self.replay_buffer))
            if len(self.replay_buffer) >= self.config.min_replay_buffer_size:
                loss = self.learn()
                # print(losses)
                # could do things other than taking the mean here
                self.stats["loss"].append(
                    {"loss": loss, "target_model_updated": target_model_updated[1]}
                )
                target_model_updated = (target_model_updated[0], False)

            if (
                training_step % self.checkpoint_interval == 0
                and training_step > self.start_training_step
            ):
                # print(self.stats["score"])
                # print(len(self.replay_buffer))
                self.save_checkpoint(
                    training_step,
                    training_step * self.config.replay_interval,
                    time() - start_time,
                )

        # NOTE(review): `training_step` is the loop variable; if the range
        # above is empty it is unbound here and this call raises NameError.
        self.save_checkpoint(
            training_step,
            training_step * self.config.replay_interval,
            time() - start_time,
        )
        self.env.close()

In [None]:
import gymnasium as gym
import sys

import torch
from utils.utils import HuberLoss

sys.path.append("../..")
from agent_configs import RainbowConfig
from game_configs import CartPoleConfig

# Same environment id and render mode as the DQNAgent cell.
env = gym.make("CartPole-v1", render_mode="rgb_array")

# Hyperparameters chosen to mirror the DQNAgent run above: 128-128 hidden
# layers, 10000 steps, batch 32, buffer 1000, target sync every 200 steps,
# epsilon from 1.0 to 0.0 over 2000 steps, discount 0.99.
config_dict = {
    "dense_layer_widths": [128, 128],
    "value_hidden_layer_widths": [],
    # NOTE(review): key is spelled "advatage" — verify that RainbowConfig
    # reads this exact spelling before renaming it.
    "advatage_hidden_layer_widths": [],
    "adam_epsilon": 1e-8,
    "learning_rate": 0.001,
    "training_steps": 10000,
    "per_epsilon": 0.0001,
    # per_alpha / per_beta of 0 presumably switch off prioritization and
    # importance weighting — confirm against the replay buffer.
    "per_alpha": 0,
    "per_beta": 0,
    "minibatch_size": 32,
    "replay_buffer_size": 1000,
    "min_replay_buffer_size": 32,
    "transfer_interval": 200,
    "n_step": 1,
    "loss_function": HuberLoss(),  # could do categorical cross entropy
    "clipnorm": 0.0,
    "discount_factor": 0.99,
    # atom_size 1 makes RainbowAgent.select_actions use the network output
    # directly instead of the distributional branch.
    "atom_size": 1,
    "replay_interval": 1,
    "dueling": False,
    "noisy_sigma": 0.0,
    "eg_epsilon": 1.0,
    "eg_epsilon_final": 0.0,
    "eg_epsilon_final_step": 2000,
    "eg_epsilon_decay_type": "linear",
}
game_config = CartPoleConfig()
config = RainbowConfig(config_dict, game_config)
# A device string is passed here, although RainbowAgent annotates the
# parameter as torch.device.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
agent = RainbowAgent(env, config, name="Rainbow_CartPole-v1-1", device=device)
# Checkpoint every 200 training steps.
agent.checkpoint_interval = 200

# Dump the initial weights so they can be compared with the other agent's.
for param in agent.model.parameters():
    print(param)

print("start")
agent.train()