In [1]:
from Sokoban_gym import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import numpy as np
import time
import random
from mlagents.envs import UnityEnvironment

In [3]:
class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3,32,kernel_size=8, stride=4, padding=1)
        self.conv2 = nn.Conv2d(32,64,kernel_size=4, stride=2, padding=0)
        self.conv3 = nn.Conv2d(64,64,kernel_size=3, stride=1, padding=0)
        self.fc1 = nn.Linear(3136,512)
        self.fc_q = nn.Linear(512,5)
        self.optimizer = optim.Adam(self.parameters(), lr=0.0003)
    
    def Q(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.reshape(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        q = self.fc_q(x)
        return q
    
    def train(self, loss):
        self.optimizer.zero_grad()
        loss.backward(retain_graph=True)
        self.optimizer.step()

In [4]:
env = gym.make('Sokoban')
model = DQN()
GAMMA = 0.98
EPSILON = 0.1
BATCH_SIZE = 32
N = 30000 ## reply memory size
replay_memory = []

In [5]:
def epsilon_greedy(Q, eps, action_dim):
    if random.random() < eps:
        action = random.randint(0, action_dim-1)
    else:
        action = torch.argmax(Q).item()
    return action

def store_transition(s, a, r, s_prime, done):
    if len(replay_memory) == N:
        del(replay_memory[0])
    replay_memory.append((s, a, r, s_prime, done))
    
def training():
    s_list, r_list, a_list, s_p_list, done_list = [], [], [], [], []
    mini_batch = random.sample(replay_memory, BATCH_SIZE)
    for sample in mini_batch:
        s_list.append(sample[0].unsqueeze(0))
        s_p_list.append(sample[3].unsqueeze(0))
        r_list.append([sample[2]])
        a_list.append([sample[1]])
        done_list.append([0]) if sample[-1] else done_list.append([1])
    s = torch.cat(s_list, dim=0)
    s_p = torch.cat(s_p_list, dim=0)
    a = torch.tensor(a_list).reshape(-1,1)
    r = torch.tensor(r_list, dtype=torch.float).reshape(-1,1)
    done_mask = torch.tensor(done_list, dtype=torch.float).reshape(-1,1)
    
    cur_Q = model.Q(s)
    next_Q = model.Q(s_p)
    td_target = r + GAMMA*torch.max(next_Q, dim=1)[0].reshape(-1,1)*done_mask
    loss = (td_target.detach() - cur_Q.gather(1,a)).pow(2).mean()
    model.train(loss)

def test_agent():
    reward_sum = 0.0
    for ep in range(10):
        observation = env.reset()
        while True:
            state = torch.tensor(observation, dtype=torch.float)
            action = torch.argmax(model.Q(state))
            observation, reward, done, _ = env.step(action.item())
            reward_sum += reward
            if done: break
    return reward_sum/10.0

In [6]:
reward_list = []

for ep in range(2000):
    observation = env.reset()
    while True:
        state = torch.tensor(observation, dtype=torch.float)
        action = epsilon_greedy(model.Q(state), EPSILON, 2)
        observation, reward, done, _ = env.step(action)
        next_state = torch.tensor(observation, dtype=torch.float)
        store_transition(state.squeeze(0), action, reward, next_state.squeeze(0), done)
        if len(replay_memory) > 1000:
            training()
        if done:
            break
            
    if ep % 20 == 19:
        test_reward = test_agent()
        print('Episode %d'%ep,', Greedy action reward : %f'%(test_reward))
        reward_list.append(test_reward)
        if test_reward > 470: break
env.close()

Episode 19 , Greedy action reward : -2.010000
Episode 39 , Greedy action reward : -3.759000
Episode 59 , Greedy action reward : -2.010000
Episode 79 , Greedy action reward : -1.909000
Episode 99 , Greedy action reward : -1.810000
Episode 119 , Greedy action reward : -1.910000
Episode 139 , Greedy action reward : -1.792000
Episode 159 , Greedy action reward : -1.809000
Episode 179 , Greedy action reward : -2.010000
Episode 199 , Greedy action reward : -1.800000
Episode 219 , Greedy action reward : -1.911000
Episode 239 , Greedy action reward : -2.010000
Episode 259 , Greedy action reward : -1.709000
Episode 279 , Greedy action reward : -1.813000
Episode 299 , Greedy action reward : -3.671000
Episode 319 , Greedy action reward : -1.811000
Episode 339 , Greedy action reward : -1.913000
Episode 359 , Greedy action reward : -1.910000
Episode 379 , Greedy action reward : -1.888000
Episode 399 , Greedy action reward : -3.790000
Episode 419 , Greedy action reward : -1.911000
Episode 439 , Gree

KeyboardInterrupt: 

In [None]:
plt.plot(reward_list)
plt.show()

In [None]:
env.close()

In [None]:
for ep in range(10):
    observation = env.reset()
    while True:
        time.sleep(0.05)
        env.render()

        state = torch.tensor(observation, dtype=torch.float)
        action = torch.argmax(model.Q(state))
        observation, reward, done, _ = env.step(action.item())
        if done: break
env.close()