In [1]:
import torch
import torch.nn as nn
import random
from collections import deque
import numpy as np
import torch.optim as optim
import gym

In [2]:
def cartpole_model(observation_space, action_space):
    return nn.Sequential(
        nn.Linear(observation_space, 24),
        nn.ReLU(),
        nn.Linear(24, 24),
        nn.ReLU(),
        nn.Linear(24, action_space)
    )

In [3]:
class DQN:
    def __init__(self, observation_space, action_space):
        self.exploration_rate = MAX_EXPLORE
        self.action_space = action_space
        self.observation_space = observation_space
        self.memory = deque(maxlen=MEMORY_LEN)
        
        self.target_net = cartpole_model(self.observation_space, self.action_space)
        self.policy_net = cartpole_model(self.observation_space, self.action_space)
        
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.policy_net.parameters())
        
        self.explore_limit = False
        
    def load_memory(self, state, action, reward, next_state, terminal):
        self.memory.append((state, action, reward, next_state, terminal))
        
    def predict_action(self, state):
        random_number = np.random.rand()
        
        if random_number < self.exploration_rate:
            return random.randrange(self.action_space)
        
        q_values = self.target_net(state).detach().numpy()
        return np.argmax(q_values[0])
    
    def experience_replay(self):
        if len(self.memory) < BATCH_SIZE:
            return
        
        for state, action, reward, next_state, terminal in batch:
            q_update = reward
            
            if not terminal:
                    q_update = reward + GAMMA * self.target_net(next_state).max(axis=1)[0]
                    
            q_values = self.target_net(state)
            q_values[0][action] = q_update
            
            loss = self.criterion(self.policy_net(state), q_values)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            if not self.explore_limit:
                self.exploration_rate *= EXPLORE_DECAY
                if self.exploration_rate < MIN_EXPLORE:
                    self.exploration_rate = MIN_EXPLORE
                    self.explore_limit = True

In [4]:
ENV_NAME = "CartPole-v1"
BATCH_SIZE = 20
GAMMA = 0.95
LEARNING_RATE = 0.001
MAX_EXPLORE = 1.0
MIN_EXPLORE = 0.01
EXPLORE_DECAY = 0.995
MEMORY_LEN = 1_000_000
UPDATE_FREQ = 10

In [5]:
env = gym.make(ENV_NAME)
observation_space = env.observation_space.shape[0]
action_space = env.action_space.n
dqn = DQN(observation_space, action_space)

In [25]:
print(f'| Run  | Exploration Rate | Score |')
for i in range(50):
    state = env.reset()
    state = np.reshape(state, [1, observation_space])
    state = torch.from_numpy(state).float()
    
    score = 0
    while True:
        score += 1
        action = dqn.predict_action(state)
        next_state, reward, terminal, info = env.step(action)
        
        next_state = torch.from_numpy(np.reshape(next_state, [1, observation_space])).float()
        dqn.load_memory(state, action, reward, next_state, terminal)
        state = next_state
        
        if terminal:
            print(f'|  {i+1:02}  | {dqn.exploration_rate:.4f}           | {score:03}   |')
            break

| Run  | Exploration Rate | Score |
|  01  | 1.0000           | 035   |
|  02  | 1.0000           | 050   |
|  03  | 1.0000           | 017   |
|  04  | 1.0000           | 030   |
|  05  | 1.0000           | 039   |
|  06  | 1.0000           | 035   |
|  07  | 1.0000           | 020   |
|  08  | 1.0000           | 036   |
|  09  | 1.0000           | 035   |
|  10  | 1.0000           | 014   |
|  11  | 1.0000           | 019   |
|  12  | 1.0000           | 025   |
|  13  | 1.0000           | 012   |
|  14  | 1.0000           | 029   |
|  15  | 1.0000           | 014   |
|  16  | 1.0000           | 031   |
|  17  | 1.0000           | 021   |
|  18  | 1.0000           | 013   |
|  19  | 1.0000           | 026   |
|  20  | 1.0000           | 020   |
|  21  | 1.0000           | 014   |
|  22  | 1.0000           | 029   |
|  23  | 1.0000           | 013   |
|  24  | 1.0000           | 012   |
|  25  | 1.0000           | 017   |
|  26  | 1.0000           | 019   |
|  27  | 1.0000           | 