In [None]:
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
import IPython as ip

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import yacht_main as yacht
from yacht_test import create_train_set

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Uncomment to force CPU execution:
#device = torch.device("cpu")
# Last expression of the cell: displays whether CUDA is available.
torch.cuda.is_available()

In [None]:
# One recorded environment step: (state, action, next_state, reward).
Transition = namedtuple('Transition', 'state action next_state reward')


class ReplayMemory:
    """Fixed-capacity ring buffer of ``Transition`` records.

    Attributes are plain and public:
        capacity: maximum number of transitions kept.
        memory:   list holding the stored transitions.
        position: index of the slot the next ``push`` writes to.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Store a transition; once full, the oldest slot is overwritten."""
        slot = self.position
        still_growing = len(self.memory) < self.capacity
        if still_growing:
            # Reserve the slot that is filled in just below.
            self.memory.append(None)
        self.memory[slot] = Transition(*args)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        population = self.memory
        return random.sample(population, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:

class DQN(nn.Module):
    """Fully connected Q-network: input -> 300 -> 300 -> 50 -> output.

    The three hidden layers are followed by a sigmoid; the output layer is
    linear and yields one value per action.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        wide, narrow = 300, 50
        self.fc1 = nn.Linear(input_size, wide)
        self.fc2 = nn.Linear(wide, wide)
        self.fc3 = nn.Linear(wide, narrow)
        # The attribute name fc6 determines the state_dict keys, so it is
        # kept as is for compatibility with saved checkpoints.
        self.fc6 = nn.Linear(narrow, output_size)

    def forward(self, x):
        """Map a batch of states ``x`` to a batch of per-action values."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = torch.sigmoid(hidden(x))
        return self.fc6(x)

def init_weights(m):
    """Fill the weight matrix of an ``nn.Linear`` module with the constant 0.1.

    Only the module passed in is inspected; modules of any other type
    (including containers that merely hold linear layers) are left untouched.
    This makes the function usable as a callback for ``nn.Module.apply``.
    The bias is not modified.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, nn.Linear):
        # constant_ fills a *tensor* in place, so it must be handed the weight
        # tensor rather than the module itself.
        torch.nn.init.constant_(m.weight, 0.1)

In [None]:
# Epsilon-greedy exploration schedule read by select_action(): epsilon decays
# exponentially from EPS_START towards EPS_END, with EPS_DECAY as the time
# constant measured in select_action() calls.
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 500
# Number of select_action() calls made so far; drives the epsilon decay.
steps_done = 0
# Placeholder sizes of the network input and output layers; the real values
# are supplied to init_net().
INPUT_SIZE = 0
OUTPUT_SIZE = 0

def select_action(state):
    """Choose an action for ``state`` with an epsilon-greedy rule.

    Epsilon decays exponentially with the module-level ``steps_done`` counter,
    which is incremented on every call. With probability epsilon a uniformly
    random action index below ``OUTPUT_SIZE`` is returned; otherwise the index
    of the largest ``policy_net`` output.

    Returns:
        A ``(1, 1)`` long tensor holding the chosen action index.
    """
    global steps_done
    draw = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if draw <= eps_threshold:
        # Explore: uniformly random action.
        return torch.tensor([[random.randrange(OUTPUT_SIZE)]], device=device, dtype=torch.long)

    # Exploit: greedy action of the current policy network.
    with torch.no_grad():
        return policy_net(state).max(1)[1].view(1, 1)


# Final score of every finished episode, in order; read by plot_scores().
episode_scores = []


def plot_scores():
    """Redraw the training curve from the module-level ``episode_scores``.

    Plots the raw per-episode scores on figure 2 and, once at least 50
    episodes are recorded, a 50-episode moving average (left-padded with
    zeros so both curves share the x axis). The previous cell output is
    cleared before the figure is shown.
    """
    window = 50

    plt.figure(2)
    plt.clf()
    history = torch.tensor(episode_scores, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Score')
    plt.plot(history.numpy())

    if len(history) >= window:
        rolling = history.unfold(0, window, 1).mean(1).view(-1)
        padded = torch.cat((torch.zeros(window - 1), rolling))
        plt.plot(padded.numpy())

    ip.display.clear_output(wait=True)
    # Short pause so the figure gets rendered.
    plt.pause(0.001)

In [None]:
def init_net(isize, osize, msize):
    """(Re)initialise all module-level training state for a fresh run.

    Builds the policy and target networks, the optimizer and the replay
    memory, and resets the exploration schedule and the score history. All of
    these are published as module-level names because the other functions in
    this notebook (select_action, optimize_model, add_trainset, the training
    loops, plot_scores) read them as globals.

    Args:
        isize: number of network inputs (length of a state vector).
        osize: number of network outputs (number of actions).
        msize: capacity of the replay memory.
    """
    # Without these declarations every assignment below would create a local
    # that vanishes when the function returns.
    global INPUT_SIZE, OUTPUT_SIZE, MEMORY_SIZE, TARGET_UPDATE, TRAINSET_PERIOD
    global policy_net, target_net, optimizer, memory
    global EPS_START, EPS_END, EPS_DECAY, steps_done, episode_scores

    INPUT_SIZE = isize
    OUTPUT_SIZE = osize

    MEMORY_SIZE = msize

    # The training loops copy policy_net into target_net every TARGET_UPDATE episodes.
    TARGET_UPDATE = 5
    TRAINSET_PERIOD = 1000

    # The networks must exist before the optimizer can be given their parameters.
    policy_net = DQN(INPUT_SIZE, OUTPUT_SIZE).to(device)
    target_net = DQN(INPUT_SIZE, OUTPUT_SIZE).to(device)

    # NOTE(review): init_weights only inspects the module it is handed, and
    # policy_net is a DQN rather than an nn.Linear, so this call leaves the
    # default PyTorch initialisation in place. Kept as in the original;
    # confirm whether a constant initialisation was really intended.
    init_weights(policy_net)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.RMSprop(policy_net.parameters())
    memory = ReplayMemory(MEMORY_SIZE)

    EPS_START = 0.9
    EPS_END = 0.05
    EPS_DECAY = 500
    steps_done = 0

    episode_scores = []

In [None]:
# Number of transitions optimize_model() samples from replay memory per step.
BATCH_SIZE = 128
# Discount factor applied to the target network's next-state value.
GAMMA = 0.8

def optimize_model():
    """Run one optimisation step of ``policy_net`` on a sampled minibatch.

    Does nothing until the replay memory holds at least ``BATCH_SIZE``
    transitions. Otherwise it samples a batch, computes the Huber loss between
    Q(s, a) from ``policy_net`` and the target
    ``reward + GAMMA * max_a' Q_target(s', a')`` (the second term is 0 for
    terminal transitions, i.e. those whose ``next_state`` is ``None``), clamps
    every gradient element to [-1, 1] and applies one optimizer step.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose: list of Transitions -> one Transition of per-field tuples.
    batch = Transition(*zip(*transitions))

    non_final_next_states = [s for s in batch.next_state if s is not None]
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s, a) for the action that was actually taken in each transition.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # torch.cat rejects an empty list, so skip the target-network pass when
    # every sampled transition is terminal; the values then stay at zero.
    if non_final_next_states:
        with torch.no_grad():
            next_state_values[non_final_mask] = \
                target_net(torch.cat(non_final_next_states)).max(1)[0]

    # The target is a constant for this step: detach it so that no gradient
    # flows into whatever produced the stored rewards.
    expected_state_action_values = ((next_state_values * GAMMA) + reward_batch).detach()
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        # A parameter that did not take part in the loss has no gradient.
        if param.grad is not None:
            param.grad.data.clamp_(-1, 1)
    optimizer.step()

In [None]:
def add_trainset():
    """Fill the replay memory with transitions from ``create_train_set``.

    Every ``(state, action, new_state, step_reward)`` tuple is converted to
    tensors on ``device`` (states reshaped to ``(1, INPUT_SIZE)``) and pushed
    into the module-level ``memory``.

    NOTE(review): the size is written as ``3000 // 5`` and reported back
    multiplied by 5, presumably because ``create_train_set`` yields five
    transitions per requested unit — confirm against yacht_test.
    """
    def as_row(values):
        # Float state vector shaped as a single-row batch.
        vec = torch.tensor(values, dtype=torch.float, device=device, requires_grad=False)
        return vec.reshape(1, INPUT_SIZE)

    print("Creating train set...")
    train_set_size = 3000 // 5
    for state, action, new_state, step_reward in create_train_set(train_set_size):
        memory.push(
            as_row(state),
            torch.tensor([[action]], device=device, dtype=torch.long),
            as_row(new_state),
            torch.tensor([step_reward], device=device, requires_grad=False),
        )
    print("Created", train_set_size * 5, "train set")

In [None]:
def net1_main():
    """Train the 12-action network ("net1") by playing Yacht episodes.

    Initialises a 43-input / 12-output network via ``init_net``, then for each
    episode plays until the environment reports ``done``, storing every step
    in replay memory and calling ``optimize_model`` after each step. The
    target network is synchronised every ``TARGET_UPDATE`` episodes, the score
    curve is redrawn every 200 episodes, and a checkpoint is written to
    ``./data/net1/net_<k>`` every 10000 episodes (k = episode // 10000).
    """
    import os

    init_net(43, 12, 10000)

    # Offset added to the network's action index before it is handed to the
    # environment. NOTE(review): presumably the environment numbers these 12
    # actions 31..42 — confirm against yacht_main.
    ACTION_OFFSET = 31
    SAVE_DIR = './data/net1'
    # torch.save does not create missing directories.
    os.makedirs(SAVE_DIR, exist_ok=True)

    num_episodes = 500000
    made_prob = 0.66
    # Reset the module-level history in place: plot_scores() reads the global
    # list, so a local list of the same name would never be plotted.
    episode_scores.clear()

    for i_episode in range(num_episodes):
        yacht.reset_game()
        state, score, _, _ = yacht.get_yacht_output()
        state = torch.tensor(state, dtype=torch.float, device=device, requires_grad = False)
        for t in count():
            # Keep the raw network index for replay memory: optimize_model()
            # uses it to gather from the 12 network outputs, so the offset
            # version would index out of range.
            action = select_action(state.reshape(1,INPUT_SIZE))
            reward = yacht.update(action + ACTION_OFFSET)

            new_state, _, done, _ = yacht.get_yacht_output()
            step_reward = torch.tensor([reward], device=device, requires_grad = False)

            # NOTE(review): made_dice() runs after new_state was read, so the
            # stored next state does not reflect whatever it changes — confirm
            # that this ordering is intended.
            if random.random() < made_prob:
                yacht.made_dice()

            if not done:
                new_state = torch.tensor(new_state, dtype=torch.float, device=device, requires_grad = False)
                memory.push(state.reshape(1,INPUT_SIZE), action, new_state.reshape(1,INPUT_SIZE), \
                        step_reward)
            else:
                # Terminal transition: next_state is None.
                new_state = None
                memory.push(state.reshape(1,INPUT_SIZE), action, None, \
                        step_reward)

            state = new_state

            optimize_model()
            if done:
                state, score, _, _ = yacht.get_yacht_output()
                episode_scores.append(score)
                #print("{0}) {1}\tscore : {2}, turns = {3}".format(i_episode, state[:12], score, t+1))

                if i_episode % 200 == 0:
                    plot_scores()

                break

        if i_episode % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())
        if i_episode % 10000 == 0:
            torch.save(policy_net.state_dict(), SAVE_DIR + '/net_' + str(i_episode//10000))

    print('Complete')

In [None]:
def net2_main(net1_name):
    """Train the 32-action network ("net2") with a trained net1 as helper.

    A frozen copy of net1 is loaded from ``./data/net1/<net1_name>``. At every
    step its largest output for the current state is used as the reward, and
    when net2 picks action 31 the move actually played is net1's greedy
    action. Checkpoints are written to ``./data/net2/net_<k>`` every 10000
    episodes (k = episode // 10000).

    Args:
        net1_name: file name of the net1 checkpoint inside ``./data/net1/``.
    """
    import os

    init_net(43, 32, 10000)

    # net1 was trained with 12 outputs (see net1_main), not with net2's 32.
    NET1_OUTPUT_SIZE = 12
    # Offset net1_main adds to a net1 index before passing it to yacht.update.
    NET1_ACTION_OFFSET = 31
    # net2 action meaning "let net1 choose".
    DELEGATE_ACTION = 31
    SAVE_DIR = './data/net2'
    # torch.save does not create missing directories.
    os.makedirs(SAVE_DIR, exist_ok=True)

    reward_net = DQN(INPUT_SIZE, NET1_OUTPUT_SIZE).to(device)
    reward_net.load_state_dict(torch.load('./data/net1/' + net1_name, map_location=device))
    reward_net.eval()
    # Freeze the parameters; assigning a plain `requires_grad` attribute on
    # the module would not affect them.
    reward_net.requires_grad_(False)

    num_episodes = 500000
    # Reset the module-level history in place: plot_scores() reads the global
    # list, so a local list of the same name would never be plotted.
    episode_scores.clear()

    for i_episode in range(num_episodes):
        yacht.reset_game()
        state, score, _, _ = yacht.get_yacht_output()
        state = torch.tensor(state, dtype=torch.float, device=device, requires_grad = False)
        for t in count():
            state_row = state.reshape(1,INPUT_SIZE)
            action = select_action(state_row)

            # no_grad keeps the stored reward free of an autograd graph;
            # otherwise optimize_model() would backpropagate through
            # reward_net every time the transition is sampled.
            with torch.no_grad():
                net1_values = reward_net(state_row)
            # NOTE(review): as in the original, the reward is net1's best
            # value for the state *before* the action — confirm that the
            # pre-action state (rather than new_state) is intended.
            step_reward = net1_values.max(1)[0]

            if action == DELEGATE_ACTION:
                # Play net1's greedy move, but keep `action` (31) as the
                # action recorded for net2.
                env_action = net1_values.max(1)[1].view(1, 1) + NET1_ACTION_OFFSET
            else:
                env_action = action
            yacht.update(env_action)

            new_state, _, done, _ = yacht.get_yacht_output()

            if not done:
                new_state = torch.tensor(new_state, dtype=torch.float, device=device, requires_grad = False)
                memory.push(state_row, action, new_state.reshape(1,INPUT_SIZE), \
                        step_reward)
            else:
                # Terminal transition: next_state is None.
                new_state = None
                memory.push(state_row, action, None, \
                        step_reward)

            state = new_state

            optimize_model()
            if done:
                state, score, _, _ = yacht.get_yacht_output()
                episode_scores.append(score)
                if i_episode % 200 == 0:
                    plot_scores()

                break

        if i_episode % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())
        if i_episode % 10000 == 0:
            torch.save(policy_net.state_dict(), SAVE_DIR + '/net_' + str(i_episode//10000))

    print('Complete')

In [None]:
def step_by_step():
    """Play one game with ``select_action``, printing every step.

    Prints the initial state, then after each action the chosen action index,
    the resulting state, the score and the number of turns played, until the
    environment reports ``done``.
    """
    yacht.reset_game()

    # get_yacht_output() is unpacked into four values everywhere else in this
    # notebook; the fourth one is not needed here.
    state, _, _, _ = yacht.get_yacht_output()
    print("{0}\tscore : {1}, turns = {2}".format(state, 0, 0))
    state = torch.tensor(state, dtype=torch.float, device=device, requires_grad = False)

    for t in count():
        action = select_action(state.reshape(1,INPUT_SIZE))
        print("\nActions: {0}".format(action[0][0] ))
        yacht.update(action)

        state, score, done, _ = yacht.get_yacht_output()
        print("{0}\tscore : {1}, turns = {2}".format(state, score, t+1))
        state = torch.tensor(state, dtype=torch.float, device=device, requires_grad = False)

        if done:
            break
# NOTE(review): step_by_step() goes through select_action(), which reads
# policy_net; that network is only created by init_net(), so this call needs
# a training setup to have run first.
step_by_step()

In [None]:
def print_memory(mem):
    """Print one replay-memory transition in readable form.

    Args:
        mem: a Transition-like record with tensor fields ``state``,
            ``action`` and ``reward`` and a ``next_state`` that is either a
            tensor or ``None`` (terminal transition).
    """
    print("\n\nState: {0}".format(mem.state[0].int().tolist() ))
    # The action belongs to the transition being printed, not to a global.
    print("Action: {0}".format(mem.action[0][0] ))
    if mem.next_state is None:
        print("State: None")
    else:
        print("State: {0}".format(mem.next_state[0].int().tolist() ))
    # Rewards are stored as 1-D tensors by the training loops; flatten() makes
    # this work for any shape instead of indexing twice.
    print("Reward: {0}".format(mem.reward.flatten()[0].int() ))

In [None]:
# Show the first ten stored transitions.
# NOTE(review): requires the replay memory created by init_net() to exist and
# to hold at least ten entries.
for i in range(10):
    print_memory(memory.memory[i])

In [None]:
# Module-level networks read by duel_action().
# NOTE(review): they are built with whatever INPUT_SIZE / OUTPUT_SIZE hold
# when this cell runs and carry freshly initialised weights — confirm that
# trained checkpoints are meant to be loaded into them before use.
net1 = DQN(INPUT_SIZE, OUTPUT_SIZE)
net2 = DQN(INPUT_SIZE, OUTPUT_SIZE)


def duel_action(state):
    """Pick the greedy move using the module-level ``net2`` and ``net1``.

    ``net2`` chooses first; when it picks action 31 the decision is handed to
    ``net1``, whose greedy index is shifted by 31 — the same offset
    ``net1_main`` applies before calling ``yacht.update``.

    Args:
        state: batch of one state, i.e. a tensor of shape (1, input size).

    Returns:
        A one-element long tensor holding the action for the environment.
    """
    with torch.no_grad():
        action = net2(state).max(1)[1]

        if action == 31:
            # Evaluate net1 on the same state that was handed to net2.
            action = net1(state).max(1)[1] + 31
    return action

In [None]:
def test_model(net1,net2):
    """Play one greedy game with saved checkpoints, printing states and actions.

    Args:
        net1: file name of the 12-output checkpoint inside ``./data/net1/``.
        net2: file name of the 32-output checkpoint inside ``./data/net2/``.
    """
    # Distinct local names: the parameters are file names and must not be
    # overwritten before they are used to build the checkpoint paths.
    net1_model = DQN(43, 12).to(device)
    net1_model.load_state_dict(torch.load('./data/net1/' + net1, map_location=device))
    net1_model.eval()
    net2_model = DQN(43, 32).to(device)
    net2_model.load_state_dict(torch.load('./data/net2/' + net2, map_location=device))
    net2_model.eval()

    yacht.reset_game()
    while True:
        state,reward,done,_ = yacht.get_yacht_output()
        print(state)
        if done:
            # Do not act on a finished game.
            break

        # The networks expect a float batch of shape (1, 43), not the raw
        # state returned by the environment.
        state_row = torch.tensor(state, dtype=torch.float, device=device).reshape(1, 43)
        with torch.no_grad():
            action = net2_model(state_row).max(1)[1].view(1, 1)
            if action == 31:
                # net2 delegates to net1; shift by the same offset net1_main
                # applies before calling yacht.update.
                action = net1_model(state_row).max(1)[1].view(1, 1) + 31
        yacht.update(action)
        print(action)

In [None]:
# Train net1. Long-running: net1_main() loops over 500000 episodes.
net1_main()

In [None]:
# Train net2 using a saved net1 checkpoint.
# NOTE(review): net1_main() saves checkpoints as 'net_' followed by an index
# (episode // 10000); 'net_' has no index — confirm that this file exists.
net2_main('net_')

In [None]:
# Play one game with saved checkpoints.
# NOTE(review): the training loops save checkpoints as 'net_' followed by an
# index (episode // 10000); 'net_' has no index — confirm these files exist.
test_model('net_', 'net_')