# DQN CartPole

### Import components and set the random seed

In [None]:
import torch as th
import torch.nn as nn
import torch.optim as optim
import matplotlib as plt

from components.agent import ReplayMemory, DQN, DQNAgent
from components.environments import CartPoleEnv
from components.utils import set_random_seed, get_parameters

In [None]:
# Fix the seed through the project's helper before anything stochastic is
# built. NOTE(review): presumably this seeds torch / numpy / random so runs
# are repeatable -- confirm in components.utils.
set_random_seed(2)

### Hyperparameters

In [None]:
# Learning rate handed to optim.Adam as `lr`.
LR = 0.01
# Passed to DQNAgent as `gamma` (presumably the reward discount factor).
GAMMA = 0.95
# Passed to DQNAgent as `eps`, `eps_decay` and `eps_min`.
# NOTE(review): presumably an epsilon-greedy schedule (start value,
# multiplicative decay, floor) -- confirm against DQNAgent.
EPS = 1
EPS_DECAY = 0.995
EPS_MIN = 0.1
# Passed to DQNAgent as `tau` (presumably the soft target-update rate;
# verify in components.agent).
TAU = 0.05

# Capacity argument for ReplayMemory.
MEMORY_CAPACITY = 4000

# Passed to agent.train as `batch_size`.
BATCH_SIZE = 512


### Define the DQN agent and the environment

In [None]:
# Instantiate the project's CartPole environment. Its `n_observations` and
# `n_actions` attributes are read when the network is built, and the same
# instance is passed to agent.train / agent.evaluate.
env = CartPoleEnv()

In [None]:
# Q-network: takes an observation vector and produces one value per action,
# so the first layer's input width is the observation size and the last
# layer's output width is the number of actions (nn.Linear takes
# in_features first, out_features second).
model = nn.Sequential(
    nn.Linear(env.n_observations, 64),
    nn.ReLU(),
    nn.Linear(64, env.n_actions),
)

# Loss and optimiser are bundled with the network in the project's DQN wrapper.
criterion = nn.MSELoss()
optimiser = optim.Adam(get_parameters(model), lr=LR)
dqn = DQN(model, criterion, optimiser)

# Replay buffer bounded by MEMORY_CAPACITY.
memory = ReplayMemory(MEMORY_CAPACITY)

# The agent receives the wrapped network, the replay buffer and the
# hyperparameters defined in the hyperparameter cell.
agent = DQNAgent(
    dqn,
    memory,
    gamma=GAMMA,
    eps=EPS,
    eps_decay=EPS_DECAY,
    eps_min=EPS_MIN,
    tau=TAU,
)

### Train the agent

In [None]:
# Train on the environment and keep the returned history for plotting.
# NOTE(review): the positional 1000 is presumably the number of training
# episodes -- confirm against DQNAgent.train.
history = agent.train(
    env,
    1000,
    batch_size=BATCH_SIZE,
    evaluation_episodes=25,
    apply_best_model=True,
)

# Evaluate the trained agent (the positional 100 is presumably the number of
# evaluation episodes) and report the result.
final_score = agent.evaluate(env, 100)
print("Final score:", final_score)

### Plot the learning curve

In [None]:
# The notebook's top-level `import matplotlib as plt` binds the bare
# matplotlib package, which has no plot()/show(). Bind the pyplot interface
# here so the calls below resolve.
import matplotlib.pyplot as plt

# Draw the history returned by agent.train and display the figure.
plt.plot(history)
plt.show()