In [1]:
import gym
import torch
import time
import numpy as np
from tqdm import tqdm

from rl_toolkit.mlp import MLP
from rl_toolkit.models import DQN
from rl_toolkit.experience_replay import SARSReplayBuffer
from rl_toolkit.agents import DQNAgent

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
env = gym.make("CartPole-v1")

## Initialize replay buffer

In [3]:
replay_buffer = SARSReplayBuffer(200)
observation, _ = env.reset(return_info=True)
for _ in range(200):
    action = env.action_space.sample()
    new_observation, reward, done, info = env.step(action)
    sars = (observation, action, reward - done, new_observation)
    replay_buffer.add(sars)
    if done:
        new_observation, info = env.reset(return_info=True)
    observation = new_observation

# Train & Play

In [4]:
from torch.optim import Adam

class DQNTrainer:
    def __init__(self, dqn_agent: DQNAgent):
        self.dqn_agent = dqn_agent
        self.gamma = 0.98
        self.optimizer = Adam(dqn_agent.backbone.parameters(), lr=1e-5)
        
    def train(self, batch: list):
        batch = list(map(torch.tensor, batch))
        states, actions, rewards, new_states = batch
        state_q = self.dqn_agent.backbone(torch.tensor(states))
        action_q = state_q[list(range(len(state_q))), actions]
        with torch.no_grad():
            new_state_q = self.dqn_agent.backbone(torch.tensor(new_states))
            new_state_best_action_q = new_state_q.max(axis=-1).values
            target_q = rewards + self.gamma * new_state_best_action_q
            target_q[~rewards.bool()] = 0
        loss = torch.mean((target_q - action_q) ** 2, axis=0)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.detach().numpy()

In [5]:
def train():
    observation, _ = env.reset(return_info=True)
    losses = []
    for iteration in range(8000):
        # play
        agent.eval()
        action = agent.sample_action(observation)
        new_observation, reward, done, info = env.step(action)
        sars = (observation, action, reward - done, new_observation)
        replay_buffer.add(sars)
        if done:
            new_observation, _ = env.reset(return_info=True)
        observation = new_observation
        
        # train
        agent.train()
        batch = replay_buffer.sample_batch()
        loss = trainer.train(batch)
        losses.append(loss)
    mean_loss = np.mean(losses, axis=0)
    return mean_loss
        

In [6]:
def evaluate():
    agent.eval()
    game_rewards = []
    for i in range(20):
        sum_reward = 0
        observation, _ = env.reset(return_info=True)
        for _ in range(200):
            # action = env.action_space.sample()
            action = agent.get_best_action(observation)
            new_observation, reward, done, info = env.step(action)
            sum_reward += reward
            if done:
                break
        game_rewards.append(sum_reward)
    return np.mean(game_rewards, axis=0)

In [7]:
agent = DQNAgent(env.observation_space, env.action_space)
trainer = DQNTrainer(agent)

In [8]:
evaluate()

9.0

In [None]:
scores = []
for i in range(200):
    if (i % 10 == 0):
        score = evaluate()
        scores.append(score)
        print(f"Test cumulative reward: {score}")
    train_loss = train()
    print(f"Train loss: {train_loss}")

  state_q = self.dqn_agent.backbone(torch.tensor(states))
  new_state_q = self.dqn_agent.backbone(torch.tensor(new_states))


Test cumulative reward: 9.2
Train loss: 1.007729270152436
Train loss: 1.1049978982669335
Train loss: 1.2404377462185696
Train loss: 1.417643605745585
Train loss: 1.6178649829642149
Train loss: 1.8349394541539923
Train loss: 2.058304997963266
Train loss: 2.3013122145539664
Train loss: 2.515867197002885
Train loss: 2.7341497048270393
Test cumulative reward: 9.55
Train loss: 2.95156842238977
Train loss: 3.145705130821252
Train loss: 3.3251821132933745
Train loss: 3.4804464753773807
Train loss: 3.6582658036661493
Train loss: 3.778250332602997
Train loss: 3.937458149895746
Train loss: 4.001721282965039
Train loss: 4.09192989608794
Train loss: 4.14239419361984
Test cumulative reward: 9.2
Train loss: 4.252531904696947
Train loss: 4.284044665841863
Train loss: 4.3237428002180645
Train loss: 4.345989202062832
Train loss: 4.42457718966566
Train loss: 4.381616716136221
Train loss: 4.4390048354587215
Train loss: 4.381297177834338
Train loss: 4.432477213003804
Train loss: 4.419156805879639
Test cum

In [None]:
from matplotlib import pyplot as plt
plt.plot(scores)

In [None]:
replay_buffer.sample_batch()[2].sum()

In [None]:
replay_buffer.sample_batch()[2].sum()

In [None]:
sum(replay_buffer._buffer.buffer[2])