In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import namedtuple

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
Transition = namedtuple('Transision', ('state', 'action', 'reward', 'next_state', 'done'))

class SumTree:
    def __init__(self, capacity):
        self.capacity = capacity
        self.data = np.zeros(capacity, dtype=object)
        self.priorities = np.zeros(2 * capacity - 1)
        self.write_idx = 0

    def _propagate(self, idx, change):
        parent = (idx - 1) // 2
        self.priorities[parent] += change
        if parent != 0:
            self._propagate(parent, change)

    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.priorities):
            return idx
        
        if s <= self.priorities[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s-self.priorities[left])
        
    def total(self):
        return self.priorities[0]  # root node is the total priority
    
    def add(self, priority , data):
        idx = self.write_idx + self.capacity - 1
        
        self.data[self.write_idx] = data
        self._propagate(idx, priority)

        self.write_idx += 1
        if self.write_idx >= self.capacity:
            self.write_idx = 0

    
    def get(self, s):
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1
        return (idx, self.priorities[idx], self.data[data_idx])

In [8]:
class PER:
    def __init__(self, capacity, alpha, beta, beta_increment):
        self.buffer = []
        self.priorities = SumTree(capacity)
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.max_priority = 1.0

    def add(self, experience):
        self.buffer.append(experience)
        self.priorities.add(self.max_priority, experience)

    def sample(self, batch_size):
        batch = []
        idxs = []
        segment = self.priorities.total() / batch_size
        is_weights = np.zeros((batch_size, 1))

        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            s = random.uniform(a, b)

            (idx, priority, data) = self.priorities.get(s)
            
            is_weights[i, 0] = (self.priorities.total() * priority) ** (-self.beta)  # imp_samp weights = ((1/N) * 1/prio(i))^beta
            
            batch.append(data)
            idxs.append(idx)
        
        is_wights /= is_weights.max()
        
        return batch, idx, is_weights
    
    def update_priorities(self, idxs, priorities):
        for idx, priority in zip(idxs, priorities):
            self.priorities._propagate(idx, priority - self.priorities[idx])  # update the priority
            self.max_priority = max(self.max_priority, priority)

        self.beta = min(1.0, self.beta + self.beta_increment)


In [9]:
class DDQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.layers = nn.Sequential(
            nn.Linear(self.input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.linear(128, self.output_dim)
        )
    
    def forward(self, state):
        QValues = self.layers(state)
        return QValues
    
    def actor(self, state, epsilon):
        if random.random() > epsilon:
            with torch.no_grad():
                state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
                q_values = self(state)
                action = q_values.max(1)[1].item()  # this is basically the greedy action

        else: 
            action = random.randrange(self.output_dim)  # if the epsilon value not higher then take random actions

        return action

In [10]:

gamma = 0.99
batch_size = 32
learning_rate = 1e-3

# input and output dim must be set dynamically somehow


memory = PER(capacity=10000, alpha=0.6)
policy_net = DDQN(input_dim=input_dim, output_dim=output_dim).to(device)
target_net = DDQN(input_dim=input_dim, output_dim=output_dim).to(device)

target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)

def update_model():
    # sample a batch with priorities
    batch, idxs, is_weights = memory.sample(batch_size)

    # Convert to tensors, move to device
    states = torch.tensor(states, dtype=torch.float32).to(device)
    actions = torch.tensor(actions, dtype=torch.long).to(device)
    rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
    next_states = torch.tensor(next_states, dtype=torch.float32).to(device)
    dones = torch.tensor(dones, dtype=torch.bool).to(device)


    current_q_value = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Calculate TD-Target
    # (Get next_state Q-values from both current and target networks for DDQN)
    with torch.no_grad():
        next_states_q_values = policy_net(next_states)
        best_next_actions = next_states_q_values.max(1)[1].unsqueeze(1)  # DDQN: Select actions according to policy_net

        next_states_target_q_values = target_net(next_states).gather(1, best_next_actions)  # DDQN: Evaluate with target_net

        td_target = rewards + (gamma * next_states_target_q_values * (1 - dones))


    # calculate the loss using importance sampling weights
    loss = (td_target - current_q_value) ** 2 * torch.tensor(is_weights).to(device)
    loss = loss.mean()

    # optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # update the priorities in sum-tree
    new_priorities = np.abs(td_target - current_q_value).detach().cpu().numpy()  #calculate absolute TD-Error
    # We take the absolute values as priorities should represent the magnitude of the errors.
    
    memory.update_priorities(idxs, new_priorities)


SyntaxError: invalid syntax (3108991969.py, line 51)