In [None]:
import torch
import torch.nn as nn
import numpy as np
import random
from collections import deque

from helpers_HW.HW3.ion_trap import IonTrapEnv
from helpers_HW.HW3.utils import is_valid_srv

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class QNetwork(nn.Module):
    """Fully connected network mapping a state vector to one Q-value per action.

    Architecture: ``state_dim -> 128 -> 128 -> n_actions``, with a ReLU after
    each of the two hidden layers and no activation on the output layer.

    Args:
        state_dim: Number of input features of the first linear layer.
        n_actions: Number of outputs of the last linear layer.
    """

    def __init__(self, state_dim, n_actions):
        super().__init__()
        hidden_width = 128
        widths = (state_dim, hidden_width, hidden_width)

        # Every hidden stage is a Linear layer immediately followed by a ReLU.
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Output head: one raw (un-activated) value per action.
        stages.append(nn.Linear(hidden_width, n_actions))

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the output of ``self.net`` applied to ``x``."""
        q_values = self.net(x)
        return q_values

class ReplayBuffer:
    """Bounded FIFO store of ``(s, a, r, s2, done)`` transitions.

    Backed by a ``deque(maxlen=capacity)``: once ``capacity`` transitions are
    stored, appending a new one drops the oldest.
    """

    # Target dtype of each field returned by ``sample``, in
    # (states, actions, rewards, next_states, dones) order.
    _FIELD_DTYPES = (
        torch.float32,
        torch.int64,
        torch.float32,
        torch.float32,
        torch.float32,
    )

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s2, done):
        """Append one transition to the buffer."""
        transition = (s, a, r, s2, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions, uniformly at random.

        Returns:
            A 5-tuple of tensors ``(states, actions, rewards, next_states,
            dones)`` placed on the notebook-level ``device``. ``actions`` is
            int64; every other field is float32.
        """
        transitions = random.sample(self.buffer, batch_size)

        # Transpose the list of transitions into one NumPy array per field.
        columns = [np.array(column) for column in zip(*transitions)]
        states, actions, rewards, next_states, dones = columns

        fields = (states, actions, rewards, next_states, dones)
        return tuple(
            torch.tensor(values, dtype=dtype, device=device)
            for values, dtype in zip(fields, self._FIELD_DTYPES)
        )

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)
   
class DQNAgent:
    """DQN agent with an epsilon-greedy policy and a separate target network.

    Args:
        state_dim: Size of the observation vector fed to the Q-network.
        n_actions: Number of discrete actions.
        gamma: Discount factor used in the bootstrapped target.
        lr: Learning rate of the Adam optimizer.
        eps_start: Exploration rate at ``steps_done == 0``.
        eps_end: Asymptotic exploration rate.
        eps_decay: Time constant (in calls to ``act``) of the exponential
            epsilon decay.
    """

    def __init__(self, state_dim, n_actions,
                 # Q update parameter
                 gamma=0.99,
                 # Learning rate
                 lr=1e-3,
                 # Parameters for epsilon decay
                 eps_start=1.0, eps_end=0.05, eps_decay=500):

        self.n_actions = n_actions
        self.gamma = gamma

        self.q_net = QNetwork(state_dim, n_actions).to(device)

        # The target network starts as a copy of the Q-network. It is never
        # trained directly, so it is put in eval mode right away.
        self.target_net = QNetwork(state_dim, n_actions).to(device)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()

        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)

        # Parameters for epsilon decay
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        self.steps_done = 0

    def epsilon(self):
        """Return the current exploration rate.

        ``eps_end + (eps_start - eps_end) * exp(-steps_done / eps_decay)``.
        """
        return self.eps_end + (self.eps_start - self.eps_end) * \
               np.exp(-1.0 * self.steps_done / self.eps_decay)

    def act(self, state):
        """Pick an action for ``state`` with an epsilon-greedy policy.

        Every call increments ``steps_done`` and therefore advances the
        epsilon schedule.

        Args:
            state: A single observation (array-like accepted by
                ``torch.tensor``), without a batch dimension.

        Returns:
            The chosen action index as a Python ``int``.
        """
        # Advance the step counter that drives the epsilon decay.
        self.steps_done += 1
        eps = self.epsilon()

        # Exploration: uniformly random action.
        if random.random() < eps:
            return random.randrange(self.n_actions)

        # Exploitation: greedy action w.r.t. the Q-network. No parameter update
        # happens here, so gradients are not needed.
        with torch.no_grad():
            obs = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
            q = self.q_net(obs)               # [1, n_actions]
            return int(torch.argmax(q, dim=1).item())

    def update(self, replay_buffer, batch_size):
        """Run one gradient step on a batch sampled from ``replay_buffer``.

        Does nothing while the buffer holds fewer than ``batch_size``
        transitions. Returns ``None`` in every case.
        """
        # Not enough samples to fill a full batch yet: skip the update.
        if len(replay_buffer) < batch_size:
            return

        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        # Only the Q-values of the actions actually taken are needed, so
        # `gather` selects them along the action dimension.
        q_values = self.q_net(states)
        q_sa = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped target from the frozen target network; `(1 - dones)`
        # removes the bootstrap term for terminal transitions.
        with torch.no_grad():
            max_next_q = self.target_net(next_states).max(dim=1).values
            target = rewards + self.gamma * (1.0 - dones) * max_next_q

        loss = nn.MSELoss()(q_sa, target)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 5.0 before stepping.
        nn.utils.clip_grad_norm_(self.q_net.parameters(), 5.0)
        self.optimizer.step()

    def update_target(self):
        """Copy the Q-network weights into the target network."""
        self.target_net.load_state_dict(self.q_net.state_dict())


In [None]:
from itertools import product

def all_valid_srvs(num_ions: int, dim: int, is_valid_srv_fn):
    """Enumerate every length-``num_ions`` vector over ``1..dim`` that passes a validity check.

    Args:
        num_ions: Length of each candidate vector.
        dim: Largest entry value; entries range over ``1, ..., dim``.
        is_valid_srv_fn: Callable invoked as ``is_valid_srv_fn(candidate, d=dim)``
            with ``candidate`` a tuple. It must return a 2-item result whose
            first item is truthy when the candidate is accepted.

    Returns:
        The accepted candidates as lists, in ``itertools.product`` order.
    """
    entry_values = range(1, dim + 1)
    accepted = []
    for candidate in product(entry_values, repeat=num_ions):
        verdict, _details = is_valid_srv_fn(candidate, d=dim)
        if not verdict:
            continue
        accepted.append(list(candidate))
    return accepted

[[1, 1, 1], [1, 2, 2], [2, 1, 2], [2, 2, 1], [2, 2, 2], [2, 2, 3], [2, 3, 2], [2, 3, 3], [3, 2, 2], [3, 2, 3], [3, 3, 2], [3, 3, 3]]


### Training the DQN agent on the ion-trap environment

In [None]:
def train_dqn_iontrap(
    episodes=20000,
    buffer_capacity=50000,
    batch_size=128,
    target_update_freq=500,
    train_every=1,
    seed=0,
):
    """Train a goal-conditioned DQN agent on ``IonTrapEnv``.

    At the start of each episode a goal SRV is drawn uniformly from the valid
    SRVs of the environment and written to ``env.goal``. The agent observes
    ``make_obs(state, goal, env)``.

    NOTE(review): ``make_obs`` is not defined in the visible notebook cells; it
    must exist in the notebook namespace before this function is called, and is
    assumed to return a flat vector of length ``2 * dim**num_ions + num_ions``
    (the width the Q-network is built with) -- confirm.

    Args:
        episodes: Number of training episodes.
        buffer_capacity: Maximum number of transitions kept in the replay buffer.
        batch_size: Batch size of each gradient step.
        target_update_freq: Environment steps between target-network syncs.
        train_every: Environment steps between gradient steps.
        seed: Seed applied to ``random``, NumPy and PyTorch.

    Returns:
        ``(agent, returns)``: the trained agent and the list of per-episode
        returns.

    Raises:
        RuntimeError: If no valid SRV exists for the environment's
            ``(num_ions, dim)``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    env = IonTrapEnv()
    n_actions = env.num_actions

    valid_goals = all_valid_srvs(env.num_ions, env.dim, is_valid_srv)
    if not valid_goals:
        raise RuntimeError("No valid SRVs found for this (num_ions, dim).")

    # Observation dim = 2*(d^n) + num_ions
    n_amplitudes = env.dim ** env.num_ions
    obs_dim = 2 * n_amplitudes + env.num_ions

    agent = DQNAgent(state_dim=obs_dim, n_actions=n_actions,
                     gamma=0.99, lr=1e-3,
                     eps_start=1.0, eps_end=0.05, eps_decay=20000)
    replay = ReplayBuffer(capacity=buffer_capacity)

    total_steps = 0
    returns = []
    # Sliding window over the most recent 200 episodes for progress reporting.
    success_rate_window = deque(maxlen=200)

    for ep in range(episodes):
        goal = random.choice(valid_goals)
        env.goal = [goal]  # env expects list of SRVs

        state = env.reset()
        obs = make_obs(state, goal, env)

        ep_return = 0.0
        done = False

        while not done:
            action = agent.act(obs)
            next_state, reward, done = env.step(action)
            next_obs = make_obs(next_state, goal, env)

            replay.push(obs, action, reward, next_obs, float(done))

            obs = next_obs
            ep_return += reward
            total_steps += 1

            if total_steps % train_every == 0:
                agent.update(replay, batch_size)

            if total_steps % target_update_freq == 0:
                agent.update_target()

        returns.append(ep_return)
        # An episode counts as a success when its return is strictly positive.
        success_rate_window.append(1 if ep_return > 0.0 else 0)

        # print progress occasionally
        if (ep + 1) % 500 == 0:
            sr = sum(success_rate_window) / len(success_rate_window)
            print(f"Episode {ep+1:6d} | recent success rate: {sr:.3f} | eps: {agent.epsilon():.3f}")

    return agent, returns

# Example: train for 10,000 episodes. `train_dqn_iontrap` returns exactly two
# values (the trained agent and the per-episode returns), so unpack two names.
agent, returns = train_dqn_iontrap(episodes=10000)


In [None]:
from tqdm import trange
import matplotlib.pyplot as plt

def train_dqn_iontrap(
    episodes=300):
    """Train a ``DQNAgent`` on ``IonTrapEnv`` for ``episodes`` episodes.

    NOTE(review): this definition reuses the name of the ``train_dqn_iontrap``
    defined in an earlier cell and replaces it once this cell runs. It also
    drives the environment through a Gym-style interface
    (``observation_space`` / ``action_space``, ``reset()`` returning a pair,
    ``step()`` returning five values), unlike the earlier cell -- confirm which
    interface ``IonTrapEnv`` actually provides.

    Returns:
        ``(agent, returns)``: the agent and the list of per-episode returns.
    """
    # Fixed hyper-parameters of this training loop.
    batch_size = 64
    target_update_freq = 1000

    # Environment, agent and replay memory.
    env = IonTrapEnv()
    agent = DQNAgent(env.observation_space.shape[0], env.action_space.n)
    replay = ReplayBuffer(capacity=50000)

    total_steps = 0   # global step counter, drives the target-network sync
    returns = []      # one entry per finished episode

    pbar = trange(episodes, desc="Training DQN", leave=True)
    for _episode in pbar:
        observation, _ = env.reset()
        ep_return = 0

        # Interact until the environment reports termination or truncation.
        while True:
            action = agent.act(observation)
            next_observation, reward, terminated, truncated, _ = env.step(action)
            finished = terminated or truncated

            # Store the transition, then move on to the next observation.
            replay.push(observation, action, reward, next_observation, finished)
            observation = next_observation
            ep_return += reward
            total_steps += 1

            # One learning step per environment step.
            agent.update(replay, batch_size)
            if total_steps % target_update_freq == 0:
                agent.update_target()

            if finished:
                break

        returns.append(ep_return)
        pbar.set_postfix(
            return_=f"{ep_return:.1f}",
            eps=f"{agent.epsilon():.3f}"
        )

    return agent, returns

# Plot the per-episode returns. `returns` is read from the notebook namespace
# (this cell only defines the training function; it does not call it).
fig, ax = plt.subplots()
ax.plot(returns)
ax.set_xlabel("Episode")
ax.set_ylabel("Return")
# The environment trained on in this notebook is IonTrapEnv, not CartPole.
ax.set_title("DQN on IonTrapEnv")
plt.show()