In [None]:
import sys
IN_COLAB = "google.colab" in sys.modules

import math
import os
from collections import deque
import random
from typing import Deque, Dict, List, Tuple

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.display import clear_output
import matplotlib.pyplot as plt
from torch.nn.utils import clip_grad_norm_

# Reproducibility: have cuDNN use deterministic kernels rather than autotuning.
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed the Python, NumPy and PyTorch global RNGs with one shared value.
# (The gym environment created later is not seeded here.)
seed = 777
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

class SumTree:
    """Binary sum tree holding one priority per leaf plus the subtree sums above.

    The tree lives in a flat array of ``2 * capacity - 1`` nodes. Node ``i``
    has children ``2i + 1`` and ``2i + 2``, the last ``capacity`` slots are
    leaves, and the root (index 0) holds the total priority. ``data`` keeps the
    experience attached to each leaf and is overwritten in ring-buffer order.
    """

    data_pointer = 0  # class-level default; every instance sets its own in __init__

    def __init__(self, capacity):
        self.capacity = capacity  # number of leaf nodes
        self.tree = np.zeros(2 * capacity - 1)  # all nodes; leaves hold priorities
        self.data = np.zeros(capacity, dtype=object)  # experience stored per leaf
        self.data_pointer = 0  # next data slot to (over)write
        self.n_entries = 0  # number of slots filled so far (never exceeds capacity)

    def add(self, priority, data):
        """Store ``data`` with ``priority``, overwriting the oldest slot when full."""
        tree_index = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_index, priority)
        # Advance the write position, wrapping back to slot 0 after the last one.
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        self.n_entries = min(self.n_entries + 1, self.capacity)

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx``, up to and including the root."""
        # Iterative walk that stops at the root. When capacity == 1 the single
        # leaf *is* the root, so there is nothing to propagate (the recursive
        # version indexed parent -1 and never terminated).
        while idx > 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def update(self, tree_index, priority):
        """Set the priority of node ``tree_index`` and refresh all ancestor sums."""
        change = priority - self.tree[tree_index]
        self.tree[tree_index] = priority
        self._propagate(tree_index, change)

    def _retrieve(self, idx, s):
        """Descend from node ``idx`` to the leaf whose cumulative range contains ``s``.

        At each node, go left if ``s`` is at most the left child's sum;
        otherwise subtract that sum and go right.
        """
        n_nodes = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= n_nodes:  # no children: idx is a leaf
                return idx
            if s <= self.tree[left]:
                idx = left
            else:
                s -= self.tree[left]
                idx = left + 1

    def get_leaf(self, s):
        """Return ``(leaf_index, priority, data)`` for the leaf selected by value ``s``."""
        leaf_index = self._retrieve(0, s)
        data_index = leaf_index - self.capacity + 1
        return (leaf_index, self.tree[leaf_index], self.data[data_index])

    def total_priority(self):
        """Return the sum of all leaf priorities (stored at the root)."""
        return self.tree[0]

class PrioritizedReplayBuffer(object):
    """Proportional prioritized replay memory backed by a SumTree.

    Each transition is stored with priority ``(error + PER_e) ** PER_a``.
    ``sample`` draws one transition from each of ``n`` equal-width slices of
    the total priority and returns normalized importance-sampling weights.
    """

    PER_e = 0.001  # offset so a zero error still yields a positive priority
    PER_a = 0.6  # priority exponent: 0 -> uniform sampling, 1 -> fully priority-proportional
    PER_b = 0.4  # initial importance-sampling exponent; sample() anneals it toward 1
    PER_b_increment_per_sampling = 0.001

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.capacity = capacity

    def __len__(self):
        """Return how many transitions are currently stored (enables ``len(buffer)``)."""
        return self.tree.n_entries

    def _getPriority(self, error):
        """Map a TD error to a sampling priority."""
        return (error + self.PER_e) ** self.PER_a

    def store(self, error, sample):
        """Insert ``sample`` with the priority derived from its TD ``error``."""
        priority = self._getPriority(error)
        self.tree.add(priority, sample)

    def sample(self, n):
        """Draw ``n`` transitions with stratified, priority-proportional sampling.

        Returns:
            (minibatch, idxs, is_weight): the sampled transitions, their tree
            indices (to pass back to ``batch_update``), and importance-sampling
            weights scaled so the largest weight is 1.

        Raises:
            ValueError: if ``n`` is not positive or the buffer is empty.
        """
        if n <= 0:
            raise ValueError(f"n must be positive, got {n}")
        if self.tree.n_entries == 0:
            raise ValueError("cannot sample from an empty buffer")

        minibatch = []
        idxs = []
        priorities = []
        priority_segment = self.tree.total_priority() / n
        # Anneal the importance-sampling exponent toward 1 on every call.
        self.PER_b = min(1.0, self.PER_b + self.PER_b_increment_per_sampling)

        for i in range(n):
            # One uniform draw inside the i-th slice of the cumulative priority.
            low = priority_segment * i
            high = priority_segment * (i + 1)
            value = np.random.uniform(low, high)
            idx, p, data = self.tree.get_leaf(value)
            priorities.append(p)
            minibatch.append(data)
            idxs.append(idx)

        sampling_probabilities = (
            np.asarray(priorities, dtype=np.float64) / self.tree.total_priority()
        )
        is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.PER_b)
        # NOTE(review): a sampled leaf with priority 0 (e.g. an unfilled slot
        # reached through floating-point rounding) yields an infinite weight here.
        is_weight /= is_weight.max()

        return minibatch, idxs, is_weight

    def batch_update(self, idx, error):
        """Recompute the priority of tree node ``idx`` from a new TD ``error``."""
        p = self._getPriority(error)
        self.tree.update(idx, p)

def append_sample(self, state, action, reward, next_state, done):
    """Compute one transition's TD error and store it in prioritized memory.

    Written to be used as an agent method: ``self`` must provide ``dqn``,
    ``dqn_target``, ``gamma`` and ``MEMORY`` (an object with
    ``store(error, sample)``), and optionally ``device``. The TD target is the
    Double-DQN target: the online network picks the next action and the target
    network evaluates it.

    NOTE(review): ``DQNAgent`` in this file names its buffer ``memory`` rather
    than ``MEMORY`` -- confirm which agent this helper is meant for.

    Args:
        state: current observation (array-like).
        action: integer action taken.
        reward: scalar reward received.
        next_state: next observation (array-like).
        done: whether the episode terminated (bool or 0/1).
    """
    # Tensors with a leading batch dimension of 1; these are what get stored.
    state = torch.FloatTensor(state).unsqueeze(0)
    next_state = torch.FloatTensor(next_state).unsqueeze(0)
    action = torch.LongTensor([action])
    reward = torch.FloatTensor([reward])
    done = torch.FloatTensor([done])

    # Run the networks on their own device; fall back to CPU when the caller
    # exposes none.
    device = getattr(self, "device", torch.device("cpu"))

    # Priority computation only: no autograd graph is needed for either network.
    with torch.no_grad():
        next_state_dev = next_state.to(device)
        next_action = self.dqn(next_state_dev).max(1)[1].view(1, 1)
        next_q = self.dqn_target(next_state_dev).gather(1, next_action).item()
        target_value = reward.item() + self.gamma * next_q * (1.0 - done.item())

        # Q value of the action actually taken in the current state.
        main_q = self.dqn(state.to(device)).gather(1, action.view(1, 1).to(device)).item()

    # Plain float TD error so the SumTree stores a scalar priority, not a tensor.
    td_error = abs(target_value - main_q)

    self.MEMORY.store(td_error, (state, action, reward, next_state, done))


class Network(nn.Module):
    """Fully connected Q-network mapping a state vector to one Q-value per action.

    Architecture: Linear(state_size, 128) -> ReLU -> Linear(128, 128) -> ReLU
    -> Linear(128, action_size).
    """

    def __init__(self, state_size: int, action_size: int):
        """Build the MLP held in ``self.layers``."""
        super().__init__()
        hidden = 128
        self.layers = nn.Sequential(
            nn.Linear(state_size, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, action_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Q-values the MLP produces for input ``x``."""
        q_values = self.layers(x)
        return q_values

class DQNAgent:
    """DQN Agent interacting with environment.

    Attribute:
        env (gym.Env): openAI Gym environment
        memory (PrioritizedReplayBuffer): replay memory to store transitions
        batch_size (int): batch size for sampling
        epsilon (float): parameter for epsilon greedy policy
        epsilon_decay (float): step size to decrease epsilon
        max_epsilon (float): max value of epsilon
        min_epsilon (float): min value of epsilon
        target_update (int): period for target model's hard update
        gamma (float): discount factor
        dqn (Network): model to train and select actions
        dqn_target (Network): target model to update
        optimizer (torch.optim): optimizer for training dqn
        transition (list): transition information including
                           state, action, reward, next_state, done
        beta (float): determines how much importance sampling is used
        prior_eps (float): guarantees every transition can be sampled
        use_n_step (bool): whether to use n_step memory
        n_step (int): step number to calculate n-step td error
        memory_n (ReplayBuffer): n-step replay buffer
    """

    def __init__(
        self,
        env: "gym.Env",
        memory_size: int,
        batch_size: int,
        target_update: int,
        epsilon_decay: float,
        max_epsilon: float = 1.0,
        min_epsilon: float = 0.1,
        gamma: float = 0.99,
        # PER parameters
        alpha: float = 0.2,
        beta: float = 0.6,
        prior_eps: float = 1e-6,
        # N-step Learning
        n_step: int = 3,
    ):
        """
        Initialization.

        Args:
            env (gym.Env): openAI Gym environment
            memory_size (int): length of memory
            batch_size (int): batch size for sampling
            target_update (int): period for target model's hard update
            epsilon_decay (float): step size to decrease epsilon
            max_epsilon (float): max value of epsilon
            min_epsilon (float): min value of epsilon
            gamma (float): discount factor
            alpha (float)    : determines how much prioritization is used
            beta (float)     : determines how much importance sampling is used
            prior_eps (float): guarantees every transition can be sampled
            n_step (int): step number to calculate n-step td error
        """
        self.env = env
        # network parameters
        self.state_size = self.env.observation_space.shape[0]
        self.action_size = self.env.action_space.n

        # hyperparameters
        self.batch_size = batch_size
        self.epsilon = max_epsilon
        self.epsilon_decay = epsilon_decay
        self.max_epsilon = max_epsilon
        self.min_epsilon = min_epsilon
        self.target_update = target_update
        self.gamma = gamma

        # device: cpu / gpu
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)

        # PER: memory for 1-step learning.
        # NOTE(review): the PrioritizedReplayBuffer class defined in this file
        # takes only `capacity` and has no `sample_batch` / `update_priorities`
        # (used in train_step). This call matches a different buffer
        # implementation -- confirm which one is meant.
        self.beta = beta
        self.prior_eps = prior_eps
        self.memory = PrioritizedReplayBuffer(
            self.state_size, memory_size, batch_size, alpha
        )

        # Memory for N-step learning.
        # NOTE(review): ReplayBuffer is not defined in the visible source.
        self.use_n_step = n_step > 1
        if self.use_n_step:
            self.n_step = n_step
            self.memory_n = ReplayBuffer(
                self.state_size,
                memory_size,
                batch_size,
                n_step=n_step,
                gamma=gamma
            )

        # networks: dqn (online) and dqn_target (frozen copy, eval mode)
        self.dqn = Network(self.state_size, self.action_size).to(self.device)
        self.dqn_target = Network(self.state_size, self.action_size).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()

        # optimizer (Adam with PyTorch's default learning rate)
        self.optimizer = optim.Adam(self.dqn.parameters())

        # transition to store in memory
        self.transition = list()

        # mode: train / test
        self.is_test = False

    def get_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action for ``state`` with an epsilon-greedy policy.

        In training mode, ``[state, action]`` is saved in ``self.transition``
        so the caller can append reward, next_state and done.
        """
        if self.epsilon > np.random.random():
            selected_action = self.env.action_space.sample()
        else:
            # Action selection is pure inference: don't build an autograd graph.
            with torch.no_grad():
                q_values = self.dqn(torch.FloatTensor(state).to(self.device))
            selected_action = q_values.argmax().cpu().numpy()

        if not self.is_test:
            self.transition = [state, selected_action]

        return selected_action

    def _compute_dqn_loss(
        self, samples: Dict[str, np.ndarray], gamma: float
    ) -> torch.Tensor:
        """Return the element-wise (per-sample) smooth-L1 DQN loss.

        Args:
            samples: batch dict with keys "obs", "acts", "rews", "next_obs", "done".
            gamma: discount applied to the bootstrap term -- ``self.gamma`` for
                1-step targets, ``self.gamma ** n_step`` for n-step targets.
        """
        device = self.device  # for shortening the following lines

        state = torch.FloatTensor(samples["obs"]).to(device)
        action = torch.LongTensor(samples["acts"].reshape(-1, 1)).to(device)
        reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
        next_state = torch.FloatTensor(samples["next_obs"]).to(device)
        done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)

        # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
        #       = r                       otherwise
        curr_q = self.dqn(state).gather(1, action)
        next_q_targ = self.dqn_target(next_state).max(dim=1, keepdim=True)[0].detach()
        mask = 1 - done
        # Use the *passed* gamma so the n-step loss gets gamma ** n_step.
        target_value = reward + gamma * next_q_targ * mask

        return F.smooth_l1_loss(curr_q, target_value, reduction="none")

    def train_step(self) -> float:
        """Run one gradient step and refresh the sampled priorities.

        Returns:
            The importance-weighted mean loss as a Python float.
        """
        # PER needs beta to calculate weights.
        samples = self.memory.sample_batch(self.beta)
        weights = torch.FloatTensor(samples["weights"].reshape(-1, 1)).to(self.device)
        indices = samples["indices"]

        # 1-step learning loss
        elementwise_loss = self._compute_dqn_loss(samples, self.gamma)

        # N-step learning loss: combine the 1-step and n-step losses to reduce
        # variance (the original Rainbow uses the n-step loss only).
        if self.use_n_step:
            gamma_n = self.gamma ** self.n_step
            samples_n = self.memory_n.sample_batch_from_idxs(indices)
            elementwise_loss = elementwise_loss + self._compute_dqn_loss(samples_n, gamma_n)

        # PER: importance sampling before averaging
        loss = torch.mean(elementwise_loss * weights)

        self.optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(self.dqn.parameters(), 10.0)
        self.optimizer.step()

        # PER: update priorities from the (combined) per-sample loss
        loss_for_prior = elementwise_loss.detach().cpu().numpy()
        new_priorities = loss_for_prior + self.prior_eps
        self.memory.update_priorities(indices, new_priorities)

        return loss.item()

    def _target_hard_update(self):
        """Hard update: target <- local."""
        self.dqn_target.load_state_dict(self.dqn.state_dict())

# Environment
env_name = "CartPole-v0"
env = gym.make(env_name)

# Replay-memory / target-network / exploration hyperparameters
memory_size = 2000
target_update = 100
epsilon_decay = 1 / 2000
initial_random_steps = 5000  # NOTE(review): not read anywhere in the visible code

# Training-loop settings
max_episodes = 100
batch_size = 32

agent = DQNAgent(
    env,
    memory_size=memory_size,
    batch_size=batch_size,
    target_update=target_update,
    epsilon_decay=epsilon_decay,
)

if __name__ == "__main__":
    # Train the agent.
    agent.is_test = False

    update_cnt = 0
    epsilons = []
    losses = []
    scores = []
    frame_idx = 0
    num_frames = 100000
    # PER: remember the initial importance-sampling exponent so beta can be
    # annealed linearly from beta_start to 1 over num_frames frames.
    beta_start = agent.beta

    for episode in range(max_episodes):
        # Reset the environment and get the first observation.
        state = agent.env.reset()
        episode_reward = 0
        done = False  # has the environment finished?

        while not done:
            # Get action (in training mode this also records [state, action]).
            action = agent.get_action(state)

            # Execute action and observe.
            next_state, reward, done, _ = agent.env.step(action)

            # Complete the transition: [state, action, reward, next_state, done].
            agent.transition += [reward, next_state, done]

            if agent.use_n_step:
                # N-step transition; presumably falsy until enough steps are
                # buffered -- confirm against the n-step buffer's store().
                one_step_transition = agent.memory_n.store(*agent.transition)
            else:
                # 1-step transition
                one_step_transition = agent.transition

            # Add a single-step transition to prioritized memory.
            if one_step_transition:
                agent.memory.store(*one_step_transition)

            state = next_state
            episode_reward += reward
            frame_idx += 1

            # PER: anneal beta linearly toward 1. Updating from the *current*
            # beta (beta += fraction * (1 - beta)) would compound every frame
            # and saturate far too early.
            fraction = min(frame_idx / num_frames, 1.0)
            agent.beta = beta_start + fraction * (1.0 - beta_start)

            if done:
                # The outer loop resets the env for the next episode, so no
                # extra reset is needed here.
                scores.append(episode_reward)
                print(f"Episode {episode + 1}: {episode_reward}")

            # Train once enough transitions are stored.
            if len(agent.memory) >= agent.batch_size:
                loss = agent.train_step()
                losses.append(loss)
                update_cnt += 1

                # Linearly decrease epsilon.
                agent.epsilon = max(
                    agent.min_epsilon,
                    agent.epsilon
                    - (agent.max_epsilon - agent.min_epsilon) * agent.epsilon_decay,
                )
                epsilons.append(agent.epsilon)

                # Periodic hard update of the target network.
                if update_cnt % agent.target_update == 0:
                    agent._target_hard_update()



--2024-01-09 11:58:21--  https://raw.githubusercontent.com/curt-park/rainbow-is-all-you-need/master/segment_tree.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4283 (4.2K) [text/plain]
Saving to: ‘segment_tree.py’


2024-01-09 11:58:21 (33.7 MB/s) - ‘segment_tree.py’ saved [4283/4283]



  logger.warn(
  deprecation(
  deprecation(


cpu
Episode 1: 11.0
Episode 2: 27.0
Episode 3: 17.0
Episode 4: 12.0
Episode 5: 20.0
Episode 6: 20.0
Episode 7: 12.0
Episode 8: 11.0
Episode 9: 20.0
Episode 10: 11.0
Episode 11: 23.0
Episode 12: 16.0
Episode 13: 49.0
Episode 14: 32.0
Episode 15: 27.0
Episode 16: 15.0
Episode 17: 13.0
Episode 18: 11.0
Episode 19: 27.0
Episode 20: 59.0
Episode 21: 17.0
Episode 22: 25.0
Episode 23: 16.0
Episode 24: 19.0
Episode 25: 17.0
Episode 26: 28.0
Episode 27: 43.0
Episode 28: 30.0
Episode 29: 34.0
Episode 30: 43.0
Episode 31: 15.0
Episode 32: 125.0
Episode 33: 136.0
Episode 34: 40.0
Episode 35: 156.0
Episode 36: 161.0
Episode 37: 134.0
Episode 38: 200.0
Episode 39: 200.0
Episode 40: 200.0
Episode 41: 200.0
Episode 42: 165.0
Episode 43: 160.0
Episode 44: 200.0
Episode 45: 180.0
Episode 46: 170.0
Episode 47: 153.0
Episode 48: 156.0
Episode 49: 160.0
Episode 50: 139.0
Episode 51: 200.0
Episode 52: 156.0
Episode 53: 163.0
Episode 54: 200.0
Episode 55: 140.0
Episode 56: 154.0
Episode 57: 194.0
Episode 58: