In [3]:
import math, random

import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd 
import torch.nn.functional as F

import matplotlib.pyplot as plt

import gym
import numpy as np

from collections import deque
from tqdm import trange

In [4]:
class DQN(nn.Module):
    """Fully connected Q-network: n_state inputs -> 128 -> 128 (ReLU) -> n_action Q-values."""

    def __init__(self, n_state, n_action):
        super(DQN, self).__init__()
        # Kept so `act` can draw random actions without relying on a global `env`.
        self.n_action = n_action
        self.layers = nn.Sequential(
            nn.Linear(n_state, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_action)
        )

    def forward(self, x):
        """Return Q-values for ``x``; the last dimension of ``x`` must be ``n_state``."""
        return self.layers(x)

    def act(self, state, epsilon):
        """Return an epsilon-greedy action index for a single ``state``.

        With probability ``1 - epsilon`` this returns argmax_a Q(state, a).
        Otherwise it returns a uniformly random index in ``range(n_action)``.
        Action legality is not considered.
        """
        if random.random() > epsilon:  # exploit: argmax_a Q(s, a)
            # Evaluate on the device the network lives on; no autograd graph is needed.
            param_device = next(self.parameters()).device
            with torch.no_grad():
                state_t = torch.FloatTensor(state).unsqueeze(0).to(param_device)
                q_value = self.forward(state_t)
            return q_value.max(1)[1].item()
        # explore: uniformly random action
        return random.randrange(self.n_action)

In [5]:
class ReplayBuffer(object):
    """Bounded FIFO store of (state, action, reward, next_state, done) tuples.

    When ``capacity`` entries are held, each new push evicts the oldest entry.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Record one transition, giving both states a leading batch axis of size 1."""
        entry = (
            np.expand_dims(state, axis=0),
            action,
            reward,
            np.expand_dims(next_state, axis=0),
            done,
        )
        self.buffer.append(entry)

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly without replacement.

        Returns the states and next_states concatenated along axis 0, and the
        actions, rewards and done flags as tuples.
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        return np.concatenate(states), actions, rewards, np.concatenate(next_states), dones

    def __len__(self):
        return len(self.buffer)

In [10]:
from itertools import count
import random
#from dqn.replay_memory import Transition
import torch.nn.functional as F
import torch


def run_episode(policy_net, env, memory, device, epsilon = 0.05):
    """Play one episode of the agent (player 1) against the env's expert.

    The agent's transitions are stored in ``memory`` via
    ``memory.push(state, action, next_state, reward)``:
    - ``state`` is the observation the agent acted from.
    - ``next_state`` is the observation after the expert's reply, or ``None``
      when the game ended.
    - ``reward`` is from the agent's perspective; the expert's reward is negated.

    If the expert opens the game, its first move only advances the state.
    Nothing is stored until the agent has acted.

    Args:
        policy_net: network mapping a ``(1, n_state)`` float tensor to Q-values
            of shape ``(1, n_action)``.
        env: environment exposing ``reset``, ``to_play``, ``legal_actions``,
            ``expert_action`` and ``step(action) -> (obs, reward, done)``.
        memory: replay memory with a ``push(state, action, next_state, reward)`` method.
        device: torch device for the stored tensors.
        epsilon: probability of choosing a uniformly random legal action.

    Returns:
        float: the episode's final reward, from the agent's perspective.

    Raises:
        RuntimeError: if the agent has no legal action it can select.
    """
    state = torch.FloatTensor(env.reset()).unsqueeze(0).to(device)
    action_tensor = None  # agent's most recent action; None until the agent has moved
    while True:
        if env.to_play() == 1:
            legal = env.legal_actions()
            if len(legal) == 0:
                raise RuntimeError("no legal actions available for the agent")
            if random.uniform(0, 1) <= epsilon:
                action = random.choice(legal)
            else:
                with torch.no_grad():  # action selection only; no gradients needed
                    q_values = policy_net(state)
                legal_set = set(legal)
                # Pick the highest-ranked Q-value that is also a legal move.
                for candidate in torch.argsort(q_values[0], descending=True).tolist():
                    if candidate in legal_set:
                        action = candidate
                        break
                else:
                    raise RuntimeError("no legal action lies in the network's output range")

            action_tensor = torch.tensor([[action]], dtype=torch.int64, device=device)
            # The observation after the agent's own move is not stored; the
            # transition's next_state is taken after the expert replies.
            _, reward, done = env.step(action)
            reward = torch.FloatTensor([reward]).to(device)
            if done:
                memory.push(state, action_tensor, None, reward)
                break
        else:
            next_state, reward, done = env.step(env.expert_action())
            next_state = torch.FloatTensor(next_state).unsqueeze(0).to(device)
            reward = torch.FloatTensor([-reward]).to(device)  # agent's perspective
            if action_tensor is not None:  # skip when the expert opened the game
                memory.push(state, action_tensor, None if done else next_state, reward)
            if done:
                break
            state = next_state
    return reward.item()

def train_network(policy_net, target_net, optimizer, gamma, memory, batch_size):
    """Run one DQN gradient step on a minibatch sampled from ``memory``.

    Minimizes the MSE between Q(s, a) from ``policy_net`` and the one-step
    target ``r + gamma * max_a' Q_target(s', a')``. Terminal transitions
    (``next_state is None``) bootstrap from 0. The target network receives no
    gradients.

    Args:
        policy_net: network being trained.
        target_net: network providing bootstrap values.
        optimizer: optimizer over ``policy_net``'s parameters.
        gamma: discount factor.
        memory: object with ``__len__`` and ``sample(batch_size)``. Each sampled
            item is ``(state, action, next_state, reward)`` in that order.
        batch_size: number of transitions per update.

    Returns:
        float loss value, or None if ``memory`` holds fewer than ``batch_size`` items.
    """
    if len(memory) < batch_size:
        return None
    # Positional unpacking matches the Transition field order (state, action, next_state, reward).
    states, actions, next_states, rewards = zip(*memory.sample(batch_size))

    state_batch = torch.cat(states)
    action_batch = torch.cat(actions)
    reward_batch = torch.cat(rewards)
    device = state_batch.device

    # Q(s, a) for the actions that were actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    non_final_mask = torch.tensor([s is not None for s in next_states],
                                  dtype=torch.bool, device=device)
    next_state_values = torch.zeros(len(next_states), device=device)
    if non_final_mask.any():  # torch.cat([]) would fail on an all-terminal batch
        non_final_next_states = torch.cat([s for s in next_states if s is not None])
        with torch.no_grad():
            # Per-row max over actions, not a single max over the whole batch.
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(dim=1).values

    expected_state_action_values = reward_batch + gamma * next_state_values
    loss = F.mse_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

In [11]:
from collections import namedtuple
import random
Transition = namedtuple('Transition',('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` records.

    Appends until ``capacity`` entries are held, then overwrites the oldest slot.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0  # slot the next push writes to

    def push(self, *args):
        """Saves a transition built from ``(state, action, next_state, reward)``."""
        # Build the record first so a bad argument list raises without leaving
        # a placeholder None in memory (which would later break sampling).
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn uniformly without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [None]:
# Use the GPU when present; otherwise fall back to CPU so the notebook still runs.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# DQN requires explicit input/output sizes; DQN() with no arguments raises TypeError.
# n_action mirrors DQN.act's use of env.action_space.n. n_state assumes env.reset()
# returns a flat 1-D observation, because run_episode feeds it straight into nn.Linear.
# TODO confirm both against the environment definition.
n_state = len(env.reset())
n_action = env.action_space.n
policy_net = DQN(n_state, n_action).to(device)
target_net = DQN(n_state, n_action).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

gamma = 0.3  # discount factor for the bootstrapped target in train_network
memory = ReplayMemory(100000)
batch_size = 64
target_update = 10  # copy policy weights into target_net every N episodes

In [None]:
# Adam over the policy network only; target_net is synced by copying weights.
optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)
# StepLR multiplies the learning rate by 0.1 every `step_size` scheduler steps.
step_size = 10000
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)

In [None]:
EPOCH = 20000
LOG_EVERY = 100  # episodes per logged average
losses = []
rewards = []
reward_sum = 0
loss_sum = 0
reward_count = 0  # episodes summed into reward_sum since the last log
loss_count = 0    # training steps summed into loss_sum since the last log
max_epsilon = 0.20
min_epsilon = 0.0

for epoch in range(1,EPOCH+1):
    # Linearly anneal exploration from max_epsilon down to min_epsilon.
    epsilon = max_epsilon*(1-(epoch/EPOCH)) + min_epsilon*((epoch/EPOCH))
    reward = run_episode(policy_net, env, memory, device, epsilon = epsilon)
    reward_sum += reward
    reward_count += 1

    # One gradient step per episode; train_network returns None until memory holds batch_size items.
    loss = train_network(policy_net, target_net, optimizer, gamma, memory, batch_size)
    if loss is not None:
        loss_sum += loss
        loss_count += 1
        # NOTE(review): the StepLR `scheduler` is never stepped, so the LR stays fixed;
        # call scheduler.step() here if decay is intended.

    if epoch % LOG_EVERY == 0:
        # Average over what was actually accumulated; warm-up episodes produce no loss.
        avg_loss = loss_sum / loss_count if loss_count else float('nan')
        avg_reward = reward_sum / reward_count
        print(f"[EPOCH: {epoch}] [LOSS: {avg_loss}] [REWARD: {avg_reward}]")
        losses.append(avg_loss)
        rewards.append(avg_reward)
        loss_sum = 0
        reward_sum = 0
        loss_count = 0
        reward_count = 0
    if epoch % target_update == 0:
        target_net.load_state_dict(policy_net.state_dict())