In [None]:
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Run on the GPU whenever CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
import import_ipynb
from qc_env_parity import qc

In [None]:
# One recorded environment step, in the order the training loop pushes it.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward', 'done'],
)


class ReplayMemory:
    """Bounded FIFO store of `Transition` records.

    Once `capacity` records are held, appending a new one drops the oldest
    (this is the behaviour of `deque` with `maxlen`).
    """

    def __init__(self, capacity):
        """Create an empty buffer that holds at most `capacity` records."""
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Build a `Transition` from the positional arguments and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored records chosen at random.

        Raises ValueError (from `random.sample`) when fewer than
        `batch_size` records are stored.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

In [None]:
class DQN(nn.Module):
    """Fully connected Q-network: flatten -> 256 -> 256 -> n_actions.

    Args:
        n_observations: number of input features per sample after flattening.
        n_actions: number of outputs (one value per action).
    """

    def __init__(self, n_observations, n_actions):
        super().__init__()
        hidden_width = 256  # widened from 128
        self.flatten = nn.Flatten()
        self.layer1 = nn.Linear(n_observations, hidden_width)
        self.layer2 = nn.Linear(hidden_width, hidden_width)
        self.layer3 = nn.Linear(hidden_width, n_actions)

    def forward(self, x):
        """Return one row of action values per sample in `x`.

        Everything after the first (batch) axis is flattened before the
        first linear layer; the two hidden layers use ReLU and the output
        layer is left linear.
        """
        hidden = self.flatten(x)
        for layer in (self.layer1, self.layer2):
            hidden = F.relu(layer(hidden))
        return self.layer3(hidden)

In [None]:
# --- Optimisation ---
BATCH_SIZE = 128   # transitions drawn from the replay memory per optimisation step
LR = 1e-3          # learning rate handed to the optimizer
GAMMA = 0.8        # discount applied to the next-state value in the TD target
TAU = 5e-3         # weight of the policy network in the soft target-network update

# --- Epsilon-greedy exploration ---
# threshold = EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)
EPS_START = 0.9    # threshold at step 0
EPS_END = 0.05     # value the threshold decays towards
EPS_DECAY = 200    # decay constant, measured in calls to select_action

In [None]:
# Build the environment provided by the qc_env_parity notebook.
env = qc()

# Number of actions as reported by the environment; becomes the Q-network's output width.
n_actions = env.act_space
# reset() is called before env.obs is read (presumably it populates the first observation — confirm in qc).
env.reset()
# Input width of the Q-network.
# NOTE(review): this is len() of `env.obs * 4`. For a plain Python list that is 4 * len(env.obs)
# (list repetition); for an array type `* 4` would scale the values and leave the length unchanged.
# Which one applies depends on the type of env.obs, which is not visible here — confirm that this
# matches the flattened observation size the network actually receives.
n_observations = len(env.obs * 4)  ## 128

In [None]:
# Online network: the one the optimizer updates and that picks greedy actions.
policy_net = DQN(n_observations, n_actions)
policy_net.to(device)

# Target network: same architecture, initialised as an exact copy of the online network.
target_net = DQN(n_observations, n_actions)
target_net.to(device)
target_net.load_state_dict(policy_net.state_dict())

# Only the online network's parameters are optimised.
optimizer = optim.AdamW(params=policy_net.parameters(), amsgrad=True, lr=LR)

# Replay buffer holding at most 10,000 transitions.
memory = ReplayMemory(capacity=10_000)

# Global counter of select_action calls; drives the epsilon decay.
steps_done = 0

In [None]:
def select_action(state):
    """Pick an action for `state` with an epsilon-greedy rule.

    The exploration threshold starts at EPS_START and decays exponentially
    towards EPS_END as the global `steps_done` counter grows; the counter is
    incremented on every call.

    Args:
        state: input tensor passed unchanged to `policy_net`.

    Returns:
        A (1, 1) long tensor holding the chosen action index: either the
        arg-max of the policy network's output along dim 1, or the value
        returned by `env.sample()` when exploring.
    """
    global steps_done
    roll = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    # Explore: defer to the environment's own sampler.
    if roll <= eps_threshold:
        return torch.tensor([[env.sample()]], device=device, dtype=torch.long)

    # Exploit: index of the largest predicted value, without tracking gradients.
    with torch.no_grad():
        best_action = policy_net(state).max(1).indices
    return best_action.view(1, 1)

In [None]:
def optimize_model():
    """Run one optimisation step of `policy_net` on a random replay batch.

    Does nothing until the replay memory holds at least BATCH_SIZE
    transitions. Otherwise it samples BATCH_SIZE transitions, builds the
    one-step TD target

        reward + GAMMA * max_a target_net(next_state)[a]

    and minimises the Smooth-L1 (Huber) loss between that target and
    `policy_net`'s value for the action actually taken. Gradients are
    clipped element-wise to [-100, 100] before the optimizer step.

    Transitions whose stored `done` flag is truthy contribute no
    bootstrapped value: their target is the reward alone. Previously the
    flag was stored but never read, so those transitions still bootstrapped
    from the target network. This assumes a truthy `done` marks a transition
    with no successor state to bootstrap from — confirm against the
    environment.

    Returns:
        None.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose: list of Transition records -> one Transition of per-field tuples.
    batch = Transition(*zip(*transitions))

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    next_state_batch = torch.cat(batch.next_state)
    # bool() normalises whatever flag type the environment produced.
    done_batch = torch.tensor([bool(flag) for flag in batch.done],
                              dtype=torch.bool, device=device)

    # Value predicted by the online network for each action that was taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Best next-state value according to the target network, zeroed where the
    # transition is flagged done.
    with torch.no_grad():
        next_state_values = target_net(next_state_batch).max(1)[0]
        next_state_values = next_state_values.masked_fill(done_batch, 0.0)

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()

    # Element-wise gradient clipping keeps single updates bounded.
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

In [None]:
# Per-episode logs, one entry appended for each completed episode:
#   reward_ep_list      - largest single-step reward of the episode
#   reward_sum_ep_list  - sum of the episode's step rewards
#   obs_ep_list         - env.draw recorded at the step with the largest reward
#   outs_ep_list        - env.outs recorded at that same step
reward_ep_list, reward_sum_ep_list, obs_ep_list, outs_ep_list = [], [], [], []

In [None]:
# Number of training episodes to run.
num_episodes = 5_000

In [None]:
%%time
# Training loop: play `num_episodes` episodes, running one optimisation step and one soft
# target-network update after every environment step.
for i_episode in range(num_episodes):
    # Fresh episode: the first observation becomes a float32 tensor with a leading batch axis.
    env.reset()
    state = env.obs
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

    # Step-by-step logs for the current episode only.
    reward_list = []; obs_list = []; outs_list = []
    
    for t in count():
        action = select_action(state)
        # A falsy return value from env.step is treated as a failed step.
        truncated = not env.step(action.item())

        # On a failed step the episode is abandoned: nothing is pushed to memory and
        # nothing is appended to the per-episode logs.
        if truncated:
            print('truncated error')
            break
        
        # Results of the step are read back from attributes on the environment.
        observation = env.obs
        reward = env.reward
        # NOTE(review): `terminated` (env.term) ends the episode below, while `done` (env.done)
        # is what gets stored with the transition. How the two differ is not visible here —
        # confirm against the environment.
        terminated = env.term
        done = env.done

        # From here on `reward` is a 1-element tensor on `device` (dtype inferred from env.reward).
        reward = torch.tensor([reward], device=device)
        next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
        memory.push(state, action, next_state, reward, done)
        state = next_state

        optimize_model()

        # Soft update of the target network: target <- TAU * policy + (1 - TAU) * target.
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)
        
        # The logged reward is the tensor built above, not the raw env.reward value.
        # env.draw and env.outs are recorded as attribute values (they are not called here).
        reward_list.append(reward)
        obs_list.append(env.draw)
        outs_list.append(env.outs)
        
        if terminated:
            # Keep the best single-step reward of the episode together with the draw/outs
            # recorded at that step (index() returns the first step that reached the maximum),
            # plus the episode's reward sum.
            # NOTE(review): max_value and the sum are tensors on `device`, so the per-episode
            # logs hold tensors rather than plain numbers.
            max_value = max(reward_list)
            max_index = reward_list.index(max_value)
            reward_ep_list.append(max_value)
            reward_sum_ep_list.append(sum(reward_list))
            obs_ep_list.append(obs_list[max_index])
            outs_ep_list.append(outs_list[max_index])
            # Prints the 1-based episode number and the number of steps taken.
            print("Episode complete : ", i_episode+1,"(", t+1, ")")
            break

In [None]:
import pickle

In [None]:
def _to_host(entry):
    """Return `entry` moved to CPU memory if it is a tensor; anything else is returned unchanged."""
    return entry.cpu() if torch.is_tensor(entry) else entry


# Output file name -> per-episode log written to it.
_logs_to_save = {
    'reward_ep_list.pkl': reward_ep_list,
    'reward_sum_ep_list.pkl': reward_sum_ep_list,
    'obs_ep_list.pkl': obs_ep_list,
    'outs_ep_list.pkl': outs_ep_list,
}

# The reward logs hold tensors created on `device`. Pickling them straight from a CUDA
# device produces files that plain pickle.load cannot open on a CPU-only machine, so
# top-level tensor entries are moved to host memory first. On a CPU run this changes nothing.
for _file_name, _log in _logs_to_save.items():
    with open(_file_name, 'wb') as file:
        pickle.dump([_to_host(entry) for entry in _log], file)