In [1]:
import pandas as pd
import datetime

In [2]:
pd.options.display.max_columns=500
pd.options.display.max_colwidth = None
pd.options.display.max_rows = None

In [3]:
import time
import threading

import pandas as pd

In [4]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

  from .autonotebook import tqdm as notebook_tqdm


device(type='cuda')

### CODE

In [5]:
import gym

import random
import collections

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


HyperParameter = collections.namedtuple('HyperParameter',
                                        ['batch_size', 'gamma', 'learning_rate', 'buffer_limit'])


class Qnet(nn.Module):

    def __init__(self):
        super(Qnet, self).__init__()
        self._conv1 = nn.Conv2d(in_channels=3,
                                out_channels=16,
                                kernel_size=8,
                                stride=4,
                                device=device)

        self._bn1 = nn.BatchNorm2d(16, device=device)

        self._conv2 = nn.Conv2d(in_channels=16,
                                out_channels=32,
                                kernel_size=4,
                                stride=2,
                                device=device)

        self._bn2 = nn.BatchNorm2d(32,
                                   device=device)

        self._ln1 = nn.Linear(2592, 256,
                              device=device)
        self._ln2 = nn.Linear(256, 9,
                              device=device)

    def forward(self, x):
        x = x.to(device)
        x = F.relu(self._bn1(self._conv1(x)))
        x = F.relu(self._bn2(self._conv2(x)))

        x = x.view(-1) if x.dim() == 3 else x.view(x.shape[0], -1)

        x = F.relu(self._ln1(x))
        x = self._ln2(x)

        return x


class ReplayBuffer:

    def __init__(self, buffer_limit):
        self._buffer = collections.deque(maxlen=buffer_limit)

    @property
    def size(self):
        return len(self._buffer)

    def put(self, state, state_prime, action, reward, done):
        self._buffer.append((state, state_prime, action, reward, done))

    def sample(self, n):
        mini_batch = random.sample(self._buffer, n)
        state_list, action_list, reward_list, state_prime_list, done_mask_list = [], [], [], [], []

        for transition in mini_batch:
            state, state_prime, action, reward, done_mask = transition

            state_list.append(state)
            action_list.append([action])
            reward_list.append([reward])
            state_prime_list.append(state_prime)
            done_mask_list.append([done_mask])

        return torch.stack(state_list), torch.tensor(action_list), \
               torch.tensor(reward_list), torch.stack(state_prime_list), \
               torch.tensor(done_mask_list)

    def reset(self):
        self._buffer.clear()


class DQNAgent:

    def __init__(self, param: HyperParameter, path=None):
        self._PARAMETER = param

        self._memory = ReplayBuffer(param.buffer_limit)
        self._policy_network = None
        self._target_network = None
        
        
        if path:
            self.load(path)
        else:
            self._policy_network = Qnet()
            self._target_network = Qnet()

            self.update_network()

        self._optimizer = optim.Adam(self._policy_network.parameters(), lr=param.learning_rate)

    def update_network(self):
        self._target_network.load_state_dict(self._policy_network.state_dict())

    def predict(self, state, epsilon):

        out = self._policy_network(state.unsqueeze(0))
        r = random.random()

        # epsilon greedy
        if r < epsilon:
            return random.randint(0, 8)
        else:
            return out.argmax().item()

    def step(self, env, state, action):

        state_prime, reward, done, info = env.step(action)

        self._memory.put(
            state=state,
            state_prime=state_prime,
            action=action,
            reward=reward,
            done=done
        )

        return state_prime, reward, done, info

    def train(self):
        state_list, action_list, reward_list, state_prime_list, \
        done_mask_list = self._memory.sample(self._PARAMETER.batch_size)

        output = self._policy_network(state_list)
        q_action = output.gather(1, action_list)

        max_q_prime = self._target_network(state_prime_list).max(1)[0].unsqueeze(1)
        target = reward_list + self._PARAMETER.gamma * max_q_prime * done_mask_list

        loss = F.smooth_l1_loss(q_action, target)

        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()

    def save(self, path):
        torch.save(self._policy_network.state_dict(), path)

    def load(self, path):
        self._policy_network = Qnet()
        self._target_network = Qnet()

        self._policy_network.load_state_dict(torch.load(path))
        self._policy_network.eval()

        self._memory.reset()

        self.update_network()


class Environment(gym.Wrapper):
    move = 0
    eat = 50
    death = -1000

    def __init__(self):
        super(Environment, self).__init__(gym.make('MsPacman-v0'))

        self._move_reward = Environment.move
        self._eat_reward = Environment.eat
        self._death_reward = Environment.death

        self._metadata = None

    def reset(self,
              reward_move: int = move,
              reward_eat: int = eat,
              reward_death: int = death,
              **kwargs):

        self._move_reward = reward_move
        self._eat_reward = reward_eat
        self._death_reward = reward_death

        state = super(Environment, self).reset(**kwargs)
        return self.observation(state)

    def step(self, action):
        state_prime, reward, done, info = super(Environment, self).step(action)

        state_prime = self.observation(state_prime)
        reward = self.reward(reward, info)

        self._metadata = info

        return state_prime, reward, done, info

    def reward(self, reward, info):

        new_reward = 0

        # move
        if reward == 0:
            new_reward = self._move_reward
        # eat
        elif reward == 10:
            new_reward = self._eat_reward

        if self._metadata and self._metadata['lives'] > info['lives']:
            new_reward -= 1000

        return new_reward

    def observation(self, observation):
        observation = observation[1:172, 1:160]

        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((84, 84)),
            transforms.ToTensor()
        ])

        return transform(observation)



savefile = '20220516121000.pt'

def main():

    parameter = HyperParameter(
        batch_size=32,
        buffer_limit=50000,
        gamma=0.98,
        learning_rate=0.1
    )
    
    
    env = Environment()
    agent = DQNAgent(param=parameter, path=savefile)

    print_interval = 20
    score = 0.0

    for n_epi in range(100000):
        epsilon = 0.5
        
        state = env.reset()
        done = False
        
        episode_score = 0
        
        while not done:
            action = agent.predict(state, epsilon)
            state_prime, reward, done, info = agent.step(env, state, action)
            print(action)
            state = state_prime
            
            score += reward   
            episode_score += reward
            
            import time
            
            time.sleep(0.01)
            
            env.render()
            
        print(episode_score)
            
        if n_epi % print_interval == 0 and n_epi != 0:
            agent.update_network()

            print("n_episode :{}, score : {:.1f}, eps : {:.1f}%".format(
                n_epi, score / print_interval, epsilon * 100))
            score = 0.0

main()       

8


  "We strongly suggest supplying `render_mode` when "


8
3
3
3
4
3
1
3
7
6
2
8
6
8
3
3
8
3
6
3
3
3
2
8
3
3
3
3
3
7
8
3
3
3
3
3
7
3
1
3
2
3
3
3
1
3
3
3
3
0
3
8
2
3
6
3
7
3
3
1
3
3
5
3
8
3
3
2
7
3
7
3
3
5
3
3
7
7
5
6
4
0
3
0
5
0
3
6
1
3
8
3
4
6
3
7
3
6
3
3
3
8
4
3
2
5
6
8
5
5
8
4
3
1
6
6
3
3
5
6
5
5
3
5
3
3
3
5
5
0
5
5
6
5
8
6
3
5
5
3
3
5
5
1
4
3
3
4
5
3
1
3
1
3
3
7
3
3
3
3
2
3
0
3
6
6
7
3
5
2
5
6
6
3
3
2
3
3
3
6
6
5
5
6
6
0
3
6
2
5
6
3
2
0
3
4
3
6
5
5
4
2
3
3
5
5
6
0
3
0
5
6
3
0
6
3
8
3
3
6
3
3
1
5
2
0
8
5
5
8
4
3
3
3
5
5
3
3
3
3
3
5
1
3
3
3
7
3
4
3
7
1
3
6
1
6
1
1
6
4
5
6
1
6
7
0
5
3
3
4
0
3
3
5
4
3
7
5
3
2
3
5
5
3
5
5
5
1
3
6
4
3
4
3
3
3
2
6
5
4
5
5
3
3
3
5
8
3
3
6
8
5
3
3
5
3
2
8
3
3
3
3
8
5
3
0
6
6
6
3
3
1
0
5
4
8
3
6
3
3
2
3
3
6
6
3
6
3
1
6
2
1
8
3
3
3
3
6
3
3
4
5
5
7
3
3
3
3
3
3
7
4
3
8
2
3
3
3
5
8
0
0
3
3
5
3
3
3
3
3
3
4
7
1
3
2
3
5
7
4
3
1
2
3
3
8
3
3
0
5
8
5
5
4
5
5
5
3
8
6
3
5
7
8
6
5
3
5
6
5
5
0
5
5
4
0
3
0
4
5
0
2
6
5
5
3
8
4
1
1
4
3
6
7
3
8
3
3
4
3
1
1
5
4
0
3
4
5
3
7
5
5
2
0
6
6
5
7
3
3
8
0
3
3
4
3
3
3
4
7
3
3
3
3
3
3
5
1
3
3


KeyboardInterrupt: 