In [None]:
#!/usr/bin/env python
# coding: utf-8

In [None]:
import gym
import math
import random
import numpy as np
from gym import wrappers
from IPython import display
from collections import namedtuple, deque
from itertools import count
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

import matplotlib
import matplotlib.pyplot as plt
get_ipython().run_line_magic('matplotlib', 'inline')

from pyvirtualdisplay import Display

# Off-screen display (not visible) started before the environment is created;
# presumably needed so frames can be rendered on a headless machine.
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()


# Run tensors on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# `.unwrapped` drops the wrappers gym.make adds around the raw environment.
env = gym.make('MsPacman-ram-v0').unwrapped
# Alternative environment kept for quick experiments:
# env = gym.make('CartPole-v0').unwrapped
# print(env.reset())

In [None]:
# Flag value kept for compatibility; no other visible cell reads it.
DISPLAY = 1

In [None]:
# Number of transitions sampled from replay memory per optimisation step.
BATCH_SIZE = 128
# Discount factor applied to the bootstrapped next-state value.
GAMMA = 0.9
# Epsilon-greedy schedule read by select_action: the exploration probability
# decays exponentially from EPS_START towards EPS_END with time constant
# EPS_DECAY (measured in select_action calls). These must be defined, otherwise
# select_action raises NameError on a fresh kernel.
EPS_START = 0.95
EPS_END = 0.05
EPS_DECAY = 10000
# The target network is synchronised with the policy network every
# TARGET_UPDATE episodes.
TARGET_UPDATE = 10
# Maximum number of transitions kept in replay memory.
REPLAY_MEMORY_SIZE = 20000
# Step size handed to the Adam optimiser.
LEARNING_RATE = 0.0003
# Number of times the chosen action is repeated before the extra final step
# whose observation is used.
FRAME_SKIP_SIZE = 15

In [None]:
# One experience record: the state acted in, the action taken, the state that
# followed, and the reward received.
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])


class ReplayMemory:
    """Fixed-capacity store of Transition records with random sampling."""

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        # A deque with ``maxlen`` discards its oldest entry once it is full.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Wrap the positional fields in a Transition and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn without replacement."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:
# Global count of select_action calls; it drives the epsilon decay below.
steps_done = 0


def select_action(state):
    """Choose an action for ``state`` with an epsilon-greedy rule.

    The exploration probability decays exponentially from EPS_START towards
    EPS_END as ``steps_done`` grows. Every call increments ``steps_done``.

    Args:
        state: tensor passed to ``policy_net`` after being moved to ``device``.

    Returns:
        A 1x1 tensor holding the chosen action index: the arg-max of the
        policy network's output when exploiting, or a uniformly random index
        below ``n_actions`` when exploring.
    """
    global steps_done
    draw = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    epsilon = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    # Explore: pick any action index at random.
    if draw <= epsilon:
        random_index = random.randrange(n_actions)
        return torch.tensor([[random_index]], device=device, dtype=torch.long)

    # Exploit: take the action with the largest predicted value.
    with torch.no_grad():
        q_values = policy_net(state.to(device))
        return q_values.max(1)[1].view(1, 1)

In [None]:
class DQN(nn.Module):
    """Q-network: three convolutional layers followed by three linear layers.

    ``forward`` maps a batch of inputs to 4 values per sample, one per action.

    NOTE(review): ``fc1`` expects 32*5*5 flattened features, so the input must
    be a 6-channel image batch whose spatial size shrinks to 5x5 through the
    conv stack (for example 6x140x140). The notebook creates a '-ram-'
    environment elsewhere; confirm that observations are shaped accordingly
    before training.
    """

    def __init__(self):
        super().__init__()
        # Layers must be *assigned* as attributes; calling self.conv1(...)
        # before it exists raises AttributeError.
        # Conv2d arguments: (in_channels, out_channels, kernel_size, stride).
        self.conv1 = nn.Conv2d(6, 8, 8, 4)
        self.conv2 = nn.Conv2d(8, 16, 4, 3)
        self.conv3 = nn.Conv2d(16, 32, 3, 2)
        self.fc1 = nn.Linear(32*5*5, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 4)

    def forward(self, x):
        """Return the per-action values for a batch ``x``.

        Args:
            x: input batch fed to ``conv1``.

        Returns:
            Tensor with one row per sample and 4 columns.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # Keep the batch axis, flatten everything else for the linear layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the output layer: raw action values.
        x = self.fc3(x)
        return x

In [None]:
def optimize_model():
    """Run one gradient step on ``policy_net`` from a replay-memory batch.

    Does nothing until ``memory`` holds at least BATCH_SIZE transitions.
    Otherwise it samples a batch, builds the one-step targets
    ``reward + GAMMA * max_a target_net(next_state)`` (the bootstrap term is 0
    for transitions whose ``next_state`` is None), and minimises the
    smooth-L1 loss between those targets and the policy network's value for
    the action actually taken. Gradients are clamped to [-1, 1] element-wise.

    Returns:
        None.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)

    # Turn a list of Transitions into one Transition of per-field tuples.
    batch = Transition(*zip(*transitions))
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    non_final_next_states = [s for s in batch.next_state if s is not None]

    state_batch = torch.cat(batch.state).to(device)
    action_batch = torch.cat(batch.action).to(device)
    reward_batch = torch.cat(batch.reward).to(device)

    # Value predicted for the action that was actually taken in each state.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # torch.cat raises on an empty list, so skip the bootstrap entirely when
    # every sampled transition ended an episode.
    if non_final_next_states:
        with torch.no_grad():
            next_q = target_net(torch.cat(non_final_next_states).to(device))
        next_state_values[non_final_mask] = next_q.max(1)[0]

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Element-wise clamp of every gradient to [-1, 1]; parameters without a
    # gradient are skipped instead of raising.
    nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()

In [None]:
# Number of action indices select_action may draw at random; DQN's output
# layer is 4 wide as well.
n_actions = 4
# Restart the exploration schedule whenever the networks are rebuilt.
steps_done = 0

# The policy network is the one being trained.
policy_net = DQN().to(device).train()

# The target network starts as an exact copy and stays in eval mode.
target_net = DQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

memory = ReplayMemory(REPLAY_MEMORY_SIZE)
optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)

In [None]:
def _to_state_tensor(raw_observation):
    """Convert one raw observation to a float32 tensor with a leading batch axis."""
    return torch.as_tensor(raw_observation, dtype=torch.float32).unsqueeze(0)


# Training loop: play num_episodes episodes, storing every transition in
# replay memory and running one optimisation step per agent decision.
episode_durations = []
num_episodes = 2000
time_start = time.time()
for i_episode in range(num_episodes):
    raw_observation = env.reset()
    # While observation[10] == 88 the loop just sends action 0; presumably
    # this skips a phase in which the game ignores input (TODO confirm what
    # value 88 encodes). The raw observation is kept un-wrapped until the wait
    # is over so the tensor conversion below always sees an array.
    while raw_observation[10] == 88:
        raw_observation, _, _, _ = env.step(0)
    observation = _to_state_tensor(raw_observation)
    pre_lives = 3
    for t in count():
        action = select_action(observation)
        # Network outputs 0..3 are shifted to environment actions 1..4,
        # skipping action 0.
        env_action = action.item() + 1

        # Repeat the action for FRAME_SKIP_SIZE frames plus the final observed
        # frame. Rewards from every frame are summed (not dropped) and the
        # environment is not stepped again once the episode has ended.
        reward = 0.0
        for _ in range(FRAME_SKIP_SIZE + 1):
            raw_next, step_reward, done, info = env.step(env_action)
            reward += step_reward
            if done:
                break

        cur_lives = info['lives']
        if cur_lives != pre_lives:
            # Penalise losing a life.
            reward -= 100
            if cur_lives != 0:
                pre_lives = cur_lives
                if not done:
                    # Wait out the same observation[10] == 88 phase as after
                    # reset, and continue from the observation that follows it
                    # so the agent's next state matches the environment. Stop
                    # waiting if the episode ends in the meantime.
                    raw_next, _, done, _ = env.step(0)
                    while raw_next[10] == 88 and not done:
                        raw_next, _, done, _ = env.step(0)

        # A terminal transition is stored with next_state None.
        next_observation = None if done else _to_state_tensor(raw_next)
        reward = torch.tensor([reward], device=device, dtype=torch.float32)
        memory.push(observation, action, next_observation, reward)

        observation = next_observation

        optimize_model()
        if done:
            episode_durations.append(t + 1)
            break
    # Periodically copy the trained weights into the target network.
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

    time_end = time.time()
    if i_episode % 50 == 0:
        print("ep ", i_episode, "finished ", time_end - time_start)

print('Complete')

In [None]:
# Play one episode greedily with target_net and render every frame inline.
#
# Environment action ids:
#   0: none
#   1: up
#   2: right
#   3: left
#   4: down
#   5: rightup
#   6: leftup
#   7: rightdown
#   8: leftdown
# The network emits indices 0..3, which are shifted by one to ids 1..4.

# Reset first so the initial frame drawn belongs to the episode being played.
state = torch.FloatTensor([env.reset()])
img = plt.imshow(env.render(mode='rgb_array'))  # create the image artist once
print(state)
pre_lives = 3
for _ in count():
    img.set_data(env.render(mode='rgb_array'))  # only swap the pixel data
    display.display(plt.gcf())
    display.clear_output(wait=True)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        action = target_net(state.to(device)).max(1)[1].view(1, 1)

    state, _, done, _ = env.step(action.item() + 1)
    state = torch.FloatTensor([state])

    if done:
        break
env.close()

In [None]:

# Sanity check: feed the reset observation through the policy network and show
# the raw outputs and the index of the largest one.
state = torch.FloatTensor([env.reset()]).to(device)
print(state)

res = policy_net(state)
print(res)

print(policy_net(state).max(1).indices.view(1, 1))

In [None]:
# One entry per finished training episode: the number of agent decisions
# (t + 1) taken before the episode ended.
print(episode_durations)

In [None]:
# Serialises the whole target_net module object (not only its state_dict) to
# the file "./2000".
torch.save(target_net, "./2000")

In [None]:
# Fresh DQN instance with newly initialised weights, placed on `device`.
meow = DQN().to(device)

In [None]:
# nn.Module has no `load` method; weights are restored with load_state_dict.
# The file may hold either a whole pickled module (which is what
# torch.save(model, path) writes) or a bare state_dict, so accept both.
# SECURITY: torch.load unpickles the file; only load checkpoints you trust.
# NOTE(review): on PyTorch versions where torch.load defaults to
# weights_only=True, a whole-module checkpoint additionally needs
# weights_only=False.
checkpoint = torch.load("./1500", map_location=device)
if isinstance(checkpoint, nn.Module):
    checkpoint = checkpoint.state_dict()
meow.load_state_dict(checkpoint)

In [None]:
# Number of select_action calls since steps_done was last reset to 0.
print(steps_done)