In [1]:
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple

In [8]:
GAMMA = 0.9
LEARNING_RATE = 0.001

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'done'))

In [14]:
class QLearningAgent:
    def __init__(self, num_states, num_actions):
        self.num_states = num_states
        self.num_actions = num_actions
        
        self.epsilon = 1.0
        self.min_epsilon = 0.01
        self.epsilon_decay = 0.5
        self.action_list = np.arange(num_actions)
        
        self.model = self._get_model()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        
    def get_action(self, state, episode=None, train=True):
        if train:
            if self.epsilon <= np.random.uniform(0, 1):
                with torch.no_grad():
                    state = torch.FloatTensor(state).view(1, -1)
                    action = self.model(state).max(1)[1].item()
            else:
                action = np.random.choice(self.action_list)
        else:
            with torch.no_grad():
                state = torch.FloatTensor(state).view(1, -1)
                action = self.model(state).max(1)[1].item()
            
        return action
    
    def train(self, episode, state, action, next_state, reward, done):
        next_action = self.get_action(next_state, episode)
        
        state = torch.FloatTensor(state).view(-1, self.num_states)
        action = torch.LongTensor([action]).view(-1, 1)
        reward = torch.FloatTensor([reward]).view(-1, 1)
        next_state = torch.FloatTensor(next_state).view(-1, self.num_states)
        done = torch.ByteTensor([done]).view(-1, 1)
        next_action = torch.LongTensor([next_action]).view(-1, 1)
        
        state_action_value = self.model(state).gather(1, action)
#         next_state_action_value = self.model(next_state).gather(1, next_action)
        next_state_action_value = self.model(next_state).max(1)[0].view(-1, 1)
        
        td_target = reward + GAMMA * (1 - done) * next_state_action_value
        
        loss = F.smooth_l1_loss(state_action_value, td_target)
#         loss = F.mse_loss(state_action_value, td_target)
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return loss.item()
    
    def decay_epsilon(self, episode):
        if self.epsilon > self.min_epsilon:
            self.epsilon *= self.epsilon_decay    
        
    def _get_model(self):
        model = nn.Sequential()
        model.add_module('fc1', nn.Linear(self.num_states, 32))
        model.add_module('relu1', nn.ReLU())
        model.add_module('fc2', nn.Linear(32, 32))
        model.add_module('relu2', nn.ReLU())
        model.add_module('fc3', nn.Linear(32, self.num_actions))
        return model

In [15]:
env = gym.make('CartPole-v0')

In [16]:
agent = QLearningAgent(4, 2)

In [17]:
N_EPOCH = 500

In [18]:
for episode in range(N_EPOCH):
    done = False
    state = env.reset()
    step = 0
    while not done:
        action = agent.get_action(state, episode)
        next_state, reward, done, _ = env.step(action)
        if done:
            if step < 195:
                reward = -1.0
            else:
                reward = 1.0
        else:
            reward = 0.0
            
        agent.train(episode, state, action, next_state, reward, done)
        state = next_state
        step += 1
    agent.decay_epsilon(episode)
    print("Episdoe {} Step {}".format(episode, step))

Episdoe 0 Step 25
Episdoe 1 Step 12
Episdoe 2 Step 9
Episdoe 3 Step 10
Episdoe 4 Step 9
Episdoe 5 Step 14
Episdoe 6 Step 13
Episdoe 7 Step 13
Episdoe 8 Step 17
Episdoe 9 Step 12
Episdoe 10 Step 13
Episdoe 11 Step 30
Episdoe 12 Step 9
Episdoe 13 Step 33
Episdoe 14 Step 28
Episdoe 15 Step 34
Episdoe 16 Step 19
Episdoe 17 Step 19
Episdoe 18 Step 38
Episdoe 19 Step 18
Episdoe 20 Step 21
Episdoe 21 Step 55
Episdoe 22 Step 35
Episdoe 23 Step 34
Episdoe 24 Step 20
Episdoe 25 Step 27
Episdoe 26 Step 28
Episdoe 27 Step 19
Episdoe 28 Step 73
Episdoe 29 Step 16
Episdoe 30 Step 27
Episdoe 31 Step 17
Episdoe 32 Step 17
Episdoe 33 Step 12
Episdoe 34 Step 11
Episdoe 35 Step 24
Episdoe 36 Step 31
Episdoe 37 Step 20
Episdoe 38 Step 45
Episdoe 39 Step 81
Episdoe 40 Step 26
Episdoe 41 Step 16
Episdoe 42 Step 50
Episdoe 43 Step 28
Episdoe 44 Step 52
Episdoe 45 Step 25
Episdoe 46 Step 65
Episdoe 47 Step 21
Episdoe 48 Step 37
Episdoe 49 Step 34
Episdoe 50 Step 41
Episdoe 51 Step 61
Episdoe 52 Step 18
Episdo

Episdoe 416 Step 9
Episdoe 417 Step 11
Episdoe 418 Step 10
Episdoe 419 Step 200
Episdoe 420 Step 200
Episdoe 421 Step 200
Episdoe 422 Step 200
Episdoe 423 Step 200
Episdoe 424 Step 191
Episdoe 425 Step 200
Episdoe 426 Step 200
Episdoe 427 Step 200
Episdoe 428 Step 160
Episdoe 429 Step 200
Episdoe 430 Step 200
Episdoe 431 Step 177
Episdoe 432 Step 194
Episdoe 433 Step 200
Episdoe 434 Step 200
Episdoe 435 Step 193
Episdoe 436 Step 194
Episdoe 437 Step 200
Episdoe 438 Step 200
Episdoe 439 Step 200
Episdoe 440 Step 200
Episdoe 441 Step 10
Episdoe 442 Step 10
Episdoe 443 Step 9
Episdoe 444 Step 8
Episdoe 445 Step 10
Episdoe 446 Step 8
Episdoe 447 Step 9
Episdoe 448 Step 10
Episdoe 449 Step 9
Episdoe 450 Step 10
Episdoe 451 Step 9
Episdoe 452 Step 11
Episdoe 453 Step 10
Episdoe 454 Step 13
Episdoe 455 Step 10
Episdoe 456 Step 11
Episdoe 457 Step 10
Episdoe 458 Step 10
Episdoe 459 Step 11
Episdoe 460 Step 11
Episdoe 461 Step 12
Episdoe 462 Step 200
Episdoe 463 Step 200
Episdoe 464 Step 200
Ep