In [None]:
  class SumTree:
      """Array-backed binary sum tree used for proportional priority sampling.

      The tree has ``capacity`` leaves, each holding the priority of one stored
      item; every internal node holds the sum of its two children, so the root
      holds the total priority.  Items are written into a ring buffer, so once
      ``capacity`` items have been added the oldest one is overwritten.

      Layout: ``tree`` has ``2 * capacity - 1`` nodes, node ``i`` has children
      ``2 * i + 1`` and ``2 * i + 2``, and the leaf for ``data[k]`` lives at
      ``tree[k + capacity - 1]``.
      """

      # Class-level default kept for backward compatibility; every instance
      # gets its own write position in ``__init__``.
      data_pointer = 0

      def __init__(self, capacity):
          """Create an empty tree able to hold ``capacity`` items.

          Raises:
              ValueError: if ``capacity`` is smaller than 1.
          """
          if capacity < 1:
              raise ValueError("capacity must be at least 1")
          self.capacity = capacity  # number of leaves
          self.tree = np.zeros(2 * capacity - 1)  # priorities (leaves) and partial sums
          self.data = np.zeros(capacity, dtype=object)  # stored items; unused slots hold 0
          self.n_entries = 0  # number of slots written so far, capped at capacity
          self.data_pointer = 0  # next slot of ``data`` to write

      def add(self, priority, data):
          """Store ``data`` with ``priority``, overwriting the oldest slot when full."""
          tree_index = self.data_pointer + self.capacity - 1
          self.data[self.data_pointer] = data
          self.update(tree_index, priority)
          # Advance the ring-buffer position, wrapping to the first slot.
          self.data_pointer += 1
          if self.data_pointer >= self.capacity:
              self.data_pointer = 0
          if self.n_entries < self.capacity:
              self.n_entries += 1

      def _propagate(self, idx, change):
          """Add ``change`` to every ancestor of node ``idx`` up to the root.

          Written as a loop so that ``idx == 0`` (the root itself, which is also
          the only leaf when ``capacity == 1``) is a no-op instead of recursing
          forever.
          """
          while idx != 0:
              idx = (idx - 1) // 2
              self.tree[idx] += change

      def update(self, tree_index, priority):
          """Set the priority of leaf ``tree_index`` and refresh the sums above it."""
          change = priority - self.tree[tree_index]
          self.tree[tree_index] = priority
          self._propagate(tree_index, change)

      def _retrieve(self, idx, s):
          """Return the index of the leaf reached by descending from ``idx`` with ``s``.

          At each internal node the search goes left when ``s`` falls within the
          left subtree's sum, otherwise it goes right with the left sum
          subtracted.  A subtree whose sum is zero is never entered while its
          sibling has positive priority, so boundary values of ``s`` (for example
          exactly 0, or a value pushed past the total by rounding) cannot select
          a leaf that was never written.
          """
          n_nodes = len(self.tree)
          while True:
              left_child_index = 2 * idx + 1
              if left_child_index >= n_nodes:  # no children: idx is a leaf
                  return idx
              right_child_index = left_child_index + 1
              left_sum = self.tree[left_child_index]
              right_sum = self.tree[right_child_index]
              if (s <= left_sum and left_sum > 0) or right_sum <= 0:
                  idx = left_child_index
              else:
                  s -= left_sum
                  idx = right_child_index

      def get_leaf(self, s):
          """Look up the item whose cumulative-priority interval contains ``s``.

          Returns:
              A tuple ``(leaf_index, priority, data)`` where ``leaf_index`` is the
              position in ``tree`` (the value expected by ``update``).
          """
          leaf_index = self._retrieve(0, s)
          data_index = leaf_index - self.capacity + 1
          return (leaf_index, self.tree[leaf_index], self.data[data_index])

      def total_priority(self):
          """Return the sum of all leaf priorities (the value stored at the root)."""
          return self.tree[0]


In [None]:
class PrioritizedReplayBuffer(object):
    """Replay buffer that samples items in proportion to their stored priority.

    Priorities are kept in a ``SumTree``; ``sample`` splits the total priority
    into ``n`` equal segments, draws one value uniformly from each and returns
    the matching items together with normalised importance-sampling weights.
    """

    PER_e = 0.001  # added to every error so that no item ends up with zero priority
    PER_a = 0.6  # exponent trading off prioritised (1) against uniform (0) sampling
    PER_b = 0.4  # importance-sampling exponent; raised towards 1 on every sample() call
    PER_b_increment_per_sampling = 0.001

    def __init__(self, capacity):
        """Create a buffer holding at most ``capacity`` items."""
        self.tree = SumTree(capacity)
        self.capacity = capacity

    def _getPriority(self, error):
        """Convert an error magnitude into a priority: ``(|error| + PER_e) ** PER_a``.

        The absolute value is taken so that a signed error cannot produce NaN
        through a fractional power of a negative number.
        """
        return (abs(error) + self.PER_e) ** self.PER_a

    def store(self, error, sample):
        """Add ``sample`` to the buffer with a priority derived from ``error``."""
        self.tree.add(self._getPriority(error), sample)

    def sample(self, n):
        """Draw ``n`` items by stratified proportional sampling.

        Each call also increases ``PER_b`` by ``PER_b_increment_per_sampling``
        (capped at 1).

        Returns:
            ``(minibatch, idxs, is_weight)``: the sampled items, their tree
            indices (to be passed back to ``batch_update``) and a NumPy array of
            importance-sampling weights scaled so that the largest is 1.

        Raises:
            ValueError: if ``n`` is not positive or nothing has been stored yet.
        """
        if n <= 0:
            raise ValueError("n must be a positive integer")
        total = self.tree.total_priority()
        if total <= 0:
            raise ValueError("cannot sample from an empty replay buffer")

        self.PER_b = min(1.0, self.PER_b + self.PER_b_increment_per_sampling)

        priority_segment = total / n
        minibatch = []
        idxs = []
        priorities = []
        for i in range(n):
            # One uniform draw from the i-th equal-width slice of [0, total].
            value = np.random.uniform(priority_segment * i, priority_segment * (i + 1))
            idx, p, data = self.tree.get_leaf(value)
            priorities.append(p)
            minibatch.append(data)
            idxs.append(idx)

        # Explicit array conversion instead of relying on list / numpy-scalar
        # reflected division.
        sampling_probabilities = np.asarray(priorities, dtype=np.float64) / total
        is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.PER_b)
        is_weight /= is_weight.max()

        return minibatch, idxs, is_weight

    def batch_update(self, idx, error):
        """Refresh stored priorities from new errors.

        Args:
            idx: a single tree index as returned by ``sample``, or a sequence
                of them.
            error: the error for ``idx``; a sequence of the same length when
                ``idx`` is a sequence.
        """
        if np.ndim(idx) == 0:
            self.tree.update(idx, self._getPriority(error))
            return
        for tree_index, err in zip(idx, error):
            self.tree.update(tree_index, self._getPriority(err))


In [None]:
def append_sample(self, state, action, reward, next_state, done):
    """Store one transition in ``self.MEMORY`` with its TD error as priority input.

    The transition is converted to tensors (``state``/``next_state`` get a
    leading batch dimension of 1; ``action``, ``reward`` and ``done`` become
    1-element tensors) and stored as the tuple
    ``(state, action, reward, next_state, done)``.

    The TD error uses the next action chosen by ``self.dqn`` and evaluated by
    ``self.dqn_target``.  It is passed to ``self.MEMORY.store`` as a plain
    Python float so that no tensor ends up inside the priority structure.
    """
    # Convert to PyTorch tensors.
    state = torch.FloatTensor(state).unsqueeze(0)
    next_state = torch.FloatTensor(next_state).unsqueeze(0)
    action = torch.LongTensor([action])
    reward = torch.FloatTensor([reward])
    done = torch.FloatTensor([done])

    # Only a scalar priority is needed here, so no autograd graph is built.
    with torch.no_grad():
        # Action selected by the online network, valued by the target network.
        next_action = self.dqn(next_state).argmax(dim=1, keepdim=True)
        next_value = self.dqn_target(next_state).gather(1, next_action).squeeze(1)
        # (1 - done) removes the bootstrap term for terminal transitions.
        target_value = reward + self.gamma * next_value * (1 - done)

        # Q value of the action actually taken in the current state.
        main_q = self.dqn(state).gather(1, action.unsqueeze(1)).squeeze(1)

        # Absolute TD error, reduced to a Python float.
        td_error = (target_value - main_q).abs().item()

    # Store the experience in memory.
    self.MEMORY.store(td_error, (state, action, reward, next_state, done))


In [None]:
def train_step(self):
    """Run one optimisation step on a prioritised minibatch.

    Samples ``self.batch_size`` transitions from ``self.MEMORY``, minimises the
    importance-weighted squared TD error of ``self.dqn`` against targets from
    ``self.dqn_target``, and writes the new absolute TD errors back to the
    memory as priorities.

    Each sampled transition is expected to be a tuple of tensors
    ``(state, action, reward, next_state, done)`` whose first dimension has
    size 1, which is the format ``append_sample`` stores.
    """
    mini_batch, idxs, IS_weights = self.MEMORY.sample(self.batch_size)

    states, actions, rewards, next_states, dones = zip(*mini_batch)
    # The stored items are already tensors with a leading dimension of 1, so
    # they are concatenated along that dimension; the tensor constructors
    # cannot build a batch from a sequence of multi-element tensors.
    states = torch.cat(states).float()
    actions = torch.cat(actions).long().reshape(-1)
    rewards = torch.cat(rewards).float().reshape(-1)
    next_states = torch.cat(next_states).float()
    dones = torch.cat(dones).float().reshape(-1)
    weights = torch.as_tensor(IS_weights, dtype=torch.float32)

    # Q values of the actions taken, and bootstrapped targets.
    current_q_values = self.dqn(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # NOTE(review): the bootstrap is the target network's own maximum;
        # confirm this is meant to differ from the online-argmax target used
        # when priorities are first computed in append_sample.
        max_next_q_values = self.dqn_target(next_states).max(1)[0]
    expected_q_values = rewards + (self.gamma * max_next_q_values * (1 - dones))

    # Absolute TD errors (also the new priorities).
    td_errors = (current_q_values - expected_q_values).abs()

    # Squared error weighted by the importance-sampling weights.
    loss = (td_errors.pow(2) * weights).mean()

    # Backpropagation.
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    # Update the memory priorities with plain floats detached from the graph.
    for idx, err in zip(idxs, td_errors.detach().tolist()):
        self.MEMORY.batch_update(idx, err)
