# Yatzy - Deep RL - Gen 8.2
##### Uses 3 networks as opposed to one; provides improvements over Gen 8.1

## Importing

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
from collections import namedtuple, deque
import random
import math
import itertools as it

In [3]:
from player import Player
from hand import YatzyHand

from ai_gen_8 import AIGenEightPointOne
from deeprlgame2 import DeepRLGame3DQN

In [4]:
# Run the networks on the GPU when CUDA is usable, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Transitions, Replay Memories and Networks

In [5]:
# Experience records stored in the replay memories, one type per network.
# Each typename matches the variable it is bound to, so repr() output and
# pickling resolve to the right class (the moves record used to be created
# under the mismatched name 'MoveTransition').

# A scored move: the state it was made in, the chosen category, the score.
MovesTransition = namedtuple('MovesTransition',
                             ('state', 'move', 'reward'))
# A reroll decision: state before, the decision, state after.
RerollTransition = namedtuple('RerollTransition',
                              ('state', 'reroll', 'next_state'))
# A choice of dice to reroll: state before, the chosen subset id, state after.
IndicesTransition = namedtuple('IndicesTransition',
                               ('state', 'indices', 'next_state'))

In [6]:
class ReplayMemory(object):
    """Bounded buffer of past transitions with uniform random sampling.

    Items live in ``self.memory``, a ``deque`` created with ``maxlen`` so
    that appending to a full buffer discards the oldest entry.
    """

    def __init__(self, capacity):
        """Create an empty buffer that keeps at most ``capacity`` items."""
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        """Return the number of items currently stored."""
        return len(self.memory)

    def sample(self, batch_size):
        """Return ``batch_size`` stored items drawn without replacement.

        ``random.sample`` raises ``ValueError`` when ``batch_size`` exceeds
        the number of stored items.
        """
        stored = self.memory
        return random.sample(stored, batch_size)

In [7]:
class MovesReplayMemory(ReplayMemory):
    """Replay buffer whose entries are ``MovesTransition`` records."""

    def push(self, *args):
        """Build a ``MovesTransition`` from ``args`` and append it."""
        record = MovesTransition(*args)
        self.memory.append(record)

class IndicesReplayMemory(ReplayMemory):
    """Replay buffer whose entries are ``IndicesTransition`` records."""

    def push(self, *args):
        """Build an ``IndicesTransition`` from ``args`` and append it."""
        record = IndicesTransition(*args)
        self.memory.append(record)

class RerollReplayMemory(ReplayMemory):
    """Replay buffer whose entries are ``RerollTransition`` records."""

    def push(self, *args):
        """Build a ``RerollTransition`` from ``args`` and append it."""
        record = RerollTransition(*args)
        self.memory.append(record)

In [8]:
class DQN(nn.Module):
    """Fully connected Q-network: 20 inputs -> 64 -> 32 -> ``output`` values.

    ReLU is applied after the first two layers; the last layer is linear.
    """

    def __init__(self, output):
        """Create the three linear layers; ``output`` is the final width."""
        super().__init__()
        first_hidden, second_hidden = 64, 32
        # Attribute names and creation order are part of the saved
        # state_dict layout and of the random initialisation sequence.
        self.layer1 = nn.Linear(20, first_hidden)
        self.layer2 = nn.Linear(first_hidden, second_hidden)
        self.layer3 = nn.Linear(second_hidden, output)

    def forward(self, x):
        """Return the network output for input ``x``."""
        hidden = x
        for layer in (self.layer1, self.layer2):
            hidden = F.relu(layer(hidden))
        return self.layer3(hidden)

## Helper Functions

In [9]:
def create_state(scoresheet, hand):
    """Encode a scoresheet and a hand of dice as one flat tensor.

    Scoresheet values come first, in the mapping's iteration order: an
    ``int`` value is stored multiplied by 5, anything else (e.g. ``None``)
    is stored as the marker ``1``. The dice follow, each converted to
    ``float`` and multiplied by 100.

    Returns:
        A 1-D ``torch.Tensor`` with ``len(scoresheet) + len(hand)`` entries.
    """
    sheet_features = [
        value * 5 if isinstance(value, int) else 1
        for value in scoresheet.values()
    ]
    dice_features = [float(die) * 100 for die in hand]
    return torch.tensor(sheet_features + dice_features)

In [10]:
def check_state_end(state):
    """Return True when no entry of ``state`` equals 1.

    ``create_state`` writes 1 for every scoresheet entry that is not an
    ``int``, so a state without any 1 has no such entry left.
    """
    return all(entry != 1 for entry in state)

In [11]:
def create_action_dicts():
    """Build the lookup tables that give meaning to network output indices.

    Returns:
        A ``(moves, reroll, indices)`` tuple of dicts keyed by action id:

        * ``moves``   - 15 scoring category names, ids 0-14.
        * ``reroll``  - ``{0: 'True', 1: 'False'}`` (string values).
        * ``indices`` - every non-empty subset of dice positions 0-4 as a
          tuple, ordered by size and then lexicographically, ids 0-30.
    """
    category_names = (
        'ones', 'twos', 'threes', 'fours', 'fives', 'sixes',
        'one_pair', 'two_pair', 'three_kind', 'four_kind',
        'small_straight', 'large_straight', 'full_house', 'chance', 'yatzy',
    )
    moves = dict(enumerate(category_names))

    reroll = dict(enumerate(('True', 'False')))

    dice_positions = range(5)
    subsets = it.chain.from_iterable(
        it.combinations(dice_positions, size) for size in range(1, 6)
    )
    indices = dict(enumerate(subsets))

    return moves, reroll, indices

In [12]:
def available_moves(scoresheet):
    """Report which scoresheet positions are still open.

    Returns:
        A dict mapping each position (0-based, in the scoresheet's
        iteration order) to ``True`` when its value is ``None`` and
        ``False`` otherwise.
    """
    return {
        position: value is None
        for position, value in enumerate(scoresheet.values())
    }

In [13]:
def select_move(state, available_moves, episode_num, eps=True):
    """Pick a scoring category with an epsilon-greedy policy.

    The exploration threshold decays exponentially from ``eps_initial``
    towards ``eps_final`` as ``episode_num`` grows (time constant
    ``eps_decay``, a module global set by ``train``).

    Args:
        state: 1-D state tensor as produced by ``create_state``.
        available_moves: dict mapping move id -> bool, True when the
            category may still be chosen.
        episode_num: current episode index, used for the epsilon schedule.
        eps: when False, exploration is disabled and the greedy move is
            always returned.

    Returns:
        A tensor of shape ``(1,)`` on ``device`` holding the chosen move id.

    Side effects:
        Increments the global counter ``total_random`` or ``total_nonrand``.
    """
    global total_random, total_nonrand

    sample = random.random()
    eps_threshold = eps_final + (eps_initial - eps_final) * math.exp(-1.0 * episode_num / eps_decay)

    if eps and sample < eps_threshold:
        total_random += 1
        open_moves = [move for move, is_open in available_moves.items() if is_open]
        return torch.tensor(random.choice(open_moves)).view(1).to(device)

    total_nonrand += 1
    with torch.no_grad():
        # The network lives on `device`; move the state there before the
        # forward pass so this also works when `device` is 'cuda'.
        results = moves_target(state.to(device))
        # Mask taken categories with -inf rather than 0: with a 0 mask an
        # already-used category wins the argmax whenever every open
        # category has a negative predicted value.
        for move, is_open in available_moves.items():
            if not is_open:
                results[move] = float('-inf')
        return torch.argmax(results).view(1).to(device)

In [14]:
def select_reroll(state, available_rerolls, episode_num, eps=True):
    """Decide epsilon-greedily whether to reroll.

    Args:
        state: 1-D state tensor as produced by ``create_state``.
        available_rerolls: dict keyed by action id; ids whose value is
            truthy are eligible for random exploration. The table from
            ``create_action_dicts`` holds the non-empty strings ``'True'``
            and ``'False'``, so both ids are eligible.
        episode_num: current episode index, used for the epsilon schedule.
        eps: when False, exploration is disabled.

    Returns:
        A tensor of shape ``(1,)`` on ``device`` holding the chosen action
        id (the training loop treats 0 as "reroll").

    Side effects:
        Increments the global counter ``total_random`` or ``total_nonrand``.
    """
    global total_random, total_nonrand

    sample = random.random()
    eps_threshold = eps_final + (eps_initial - eps_final) * math.exp(-1.0 * episode_num / eps_decay)

    if eps and sample < eps_threshold:
        total_random += 1
        candidates = [action for action, flag in available_rerolls.items() if flag]
        return torch.tensor(random.choice(candidates)).view(1).to(device)

    total_nonrand += 1
    with torch.no_grad():
        # Move the state to the network's device first; a CPU state fed to
        # a CUDA model raises a device-mismatch error.
        results = reroll_target(state.to(device))
        return torch.argmax(results).view(1).to(device)

In [15]:
def select_indices(state, available_indices, episode_num, eps=True):
    """Choose epsilon-greedily which subset of dice to reroll.

    Args:
        state: 1-D state tensor as produced by ``create_state``.
        available_indices: dict keyed by action id; ids whose value is
            truthy are eligible for random exploration.
        episode_num: current episode index, used for the epsilon schedule.
        eps: when False, exploration is disabled.

    Returns:
        A tensor of shape ``(1,)`` on ``device`` holding the chosen action
        id. The id must still be translated to dice positions through the
        ``indices`` table from ``create_action_dicts``.

    Side effects:
        Increments the global counter ``total_random`` or ``total_nonrand``.
    """
    global total_random, total_nonrand

    sample = random.random()
    eps_threshold = eps_final + (eps_initial - eps_final) * math.exp(-1.0 * episode_num / eps_decay)

    if eps and sample < eps_threshold:
        total_random += 1
        candidates = [action for action, flag in available_indices.items() if flag]
        return torch.tensor(random.choice(candidates)).view(1).to(device)

    total_nonrand += 1
    with torch.no_grad():
        # Move the state to the network's device first; a CPU state fed to
        # a CUDA model raises a device-mismatch error.
        results = indices_target(state.to(device))
        return torch.argmax(results).view(1).to(device)

In [16]:
def optimize_moves():
    """Run one gradient step on ``moves_policy`` from a replayed minibatch.

    Does nothing until ``moves_memory`` holds at least ``bs`` transitions.
    The regression target for the chosen move is the immediate reward
    stored with the transition (no bootstrapping); the loss is Smooth L1.
    """
    if len(moves_memory) < bs:
        return
    transitions = moves_memory.sample(bs)

    # Transpose a list of transitions into one transition of tuples.
    batch = MovesTransition(*zip(*transitions))

    # Replayed tensors may have been created on the CPU; the network lives
    # on `device`.
    state_batch = torch.stack(batch.state).to(device)
    action_batch = torch.stack(batch.move).to(device)

    # squeeze(1) drops only the gather column, so the batch dimension
    # survives even when bs == 1.
    state_action_values = moves_policy(state_batch).gather(1, action_batch).squeeze(1)

    # Rewards can be integer scores; match the prediction's dtype and
    # device so the loss and its backward pass see consistent tensors.
    expected_state_action_values = torch.cat(batch.reward).to(
        device=device, dtype=state_action_values.dtype)

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values)

    moves_opt.zero_grad()
    loss.backward()
    moves_opt.step()


In [17]:
def optimize_reroll():
    """Run one gradient step on ``reroll_policy`` from a replayed minibatch.

    Does nothing until ``reroll_memory`` holds at least ``bs`` transitions.
    The target for the stored decision is the best move value that
    ``moves_target`` predicts for the next state, or 0 when
    ``check_state_end`` reports the next state as final. The loss is
    Smooth L1.
    """
    if len(reroll_memory) < bs:
        return
    transitions = reroll_memory.sample(bs)

    # Transpose a list of transitions into one transition of tuples.
    batch = RerollTransition(*zip(*transitions))

    # Evaluate the end-of-game test once per sample and reuse the result
    # for both the mask and the filtered batch.
    is_non_final = [not check_state_end(s) for s in batch.next_state]
    non_final_mask = torch.tensor(is_non_final, device=device, dtype=torch.bool)

    # Replayed tensors may have been created on the CPU; the networks live
    # on `device`.
    state_batch = torch.stack(batch.state).to(device)
    action_batch = torch.stack(batch.reroll).to(device)

    # squeeze(1) drops only the gather column, so the batch dimension
    # survives even when bs == 1.
    state_action_values = reroll_policy(state_batch).gather(1, action_batch).squeeze(1)

    next_state_values = torch.zeros(bs, device=device)
    # torch.stack raises on an empty list, so skip the lookup when every
    # sampled next state is final.
    if any(is_non_final):
        non_final_next_states = torch.stack(
            [s for s, keep in zip(batch.next_state, is_non_final) if keep]
        ).to(device)
        with torch.no_grad():
            next_state_values[non_final_mask] = moves_target(non_final_next_states).max(1)[0]

    expected_state_action_values = next_state_values

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values)

    reroll_opt.zero_grad()
    loss.backward()
    reroll_opt.step()

In [18]:
def optimize_indices():
    """Run one gradient step on ``indices_policy`` from a replayed minibatch.

    Does nothing until ``indices_memory`` holds at least ``bs``
    transitions. The target for the stored dice-subset choice is the best
    move value that ``moves_target`` predicts for the next state, or 0
    when ``check_state_end`` reports the next state as final. The loss is
    Smooth L1.
    """
    if len(indices_memory) < bs:
        return
    transitions = indices_memory.sample(bs)

    # Transpose a list of transitions into one transition of tuples.
    batch = IndicesTransition(*zip(*transitions))

    # Evaluate the end-of-game test once per sample and reuse the result
    # for both the mask and the filtered batch.
    is_non_final = [not check_state_end(s) for s in batch.next_state]
    non_final_mask = torch.tensor(is_non_final, device=device, dtype=torch.bool)

    # Replayed tensors may have been created on the CPU; the networks live
    # on `device`.
    state_batch = torch.stack(batch.state).to(device)
    action_batch = torch.stack(batch.indices).to(device)

    # squeeze(1) drops only the gather column, so the batch dimension
    # survives even when bs == 1.
    state_action_values = indices_policy(state_batch).gather(1, action_batch).squeeze(1)

    next_state_values = torch.zeros(bs, device=device)
    # torch.stack raises on an empty list, so skip the lookup when every
    # sampled next state is final.
    if any(is_non_final):
        non_final_next_states = torch.stack(
            [s for s, keep in zip(batch.next_state, is_non_final) if keep]
        ).to(device)
        with torch.no_grad():
            next_state_values[non_final_mask] = moves_target(non_final_next_states).max(1)[0]

    expected_state_action_values = next_state_values

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values)

    indices_opt.zero_grad()
    loss.backward()
    indices_opt.step()

In [19]:
def evaluate(num):
    """Have the model play ``num`` games and print the mean score.

    Builds an ``AIGenEightPointOne`` player, hands its ``generation`` to a
    ``DeepRLGame3DQN`` and averages the scores that the game's
    ``evaluate(num)`` returns. Nothing is returned.
    """
    agent = AIGenEightPointOne('karen')
    runner = DeepRLGame3DQN(agent.generation)

    results = runner.evaluate(num)
    mean_score = sum(results) / len(results)

    print('Average score: {}'.format(mean_score))

## Training

In [20]:
# Epsilon-greedy schedule used by the select_* helpers: the exploration
# threshold starts near eps_initial and decays towards eps_final.
eps_initial = 0.9
eps_final = 0.05

lr = 0.01  # learning rate of the three SGD optimisers below
bs = 64  # minibatch size sampled from each replay memory
TARGET_UPDATE = 10  # train() syncs target nets from policy nets every this many episodes

# Counters incremented by the select_* helpers on each random / greedy choice.
total_random = 0
total_nonrand = 0

# Move selection: 15 outputs, one per scoring category.
# Each target net starts as a copy of its policy net and is kept in eval mode.
moves_policy = DQN(15).to(device)
moves_target = DQN(15).to(device)
moves_target.load_state_dict(moves_policy.state_dict())
moves_target.eval()
moves_opt = optim.SGD(moves_policy.parameters(), lr=lr)
moves_memory = MovesReplayMemory(10000)

# Reroll decision: 2 outputs (train() treats action 0 as "reroll").
reroll_policy = DQN(2).to(device)
reroll_target = DQN(2).to(device)
reroll_target.load_state_dict(reroll_policy.state_dict())
reroll_target.eval()
reroll_opt = optim.SGD(reroll_policy.parameters(), lr=lr)
reroll_memory = RerollReplayMemory(10000)


# Dice-subset choice: 31 outputs, one per entry of the indices action table.
indices_policy = DQN(31).to(device)
indices_target = DQN(31).to(device)
indices_target.load_state_dict(indices_policy.state_dict())
indices_target.eval()
indices_opt = optim.SGD(indices_policy.parameters(), lr=lr)
indices_memory = IndicesReplayMemory(10000)

In [21]:
def train(n, gen, num):
    """Train the move, reroll and dice-subset networks by self-play.

    Each episode is one 15-turn game. Within a turn the agent may reroll
    up to twice, then scores one category. Transitions are pushed to the
    three replay memories during play. After every episode each network
    takes one optimisation step.

    Args:
        n: number of episodes to play; also used as the epsilon decay
            constant (stored in the global ``eps_decay``).
        gen: generation label used in the checkpoint path.
        num: run label used in the checkpoint path.

    Side effects:
        Sets the global ``eps_decay``, updates the policy/target networks
        and replay memories, prints progress every 1000 episodes and
        writes checkpoints to ``dqn_models/{gen}test0/{num}/``.
    """
    import os  # local import: only needed to create the checkpoint directory

    moves_dict, reroll_dict, indices_dict = create_action_dicts()

    global eps_decay
    eps_decay = n

    checkpoint_dir = f'dqn_models/{gen}test0/{num}'

    def sync_targets():
        # Copy the current policy weights into the target networks.
        moves_target.load_state_dict(moves_policy.state_dict())
        reroll_target.load_state_dict(reroll_policy.state_dict())
        indices_target.load_state_dict(indices_policy.state_dict())

    def save_checkpoint():
        # torch.save does not create missing parent directories.
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(moves_target.state_dict(), f'{checkpoint_dir}/moves.pt')
        torch.save(reroll_target.state_dict(), f'{checkpoint_dir}/reroll.pt')
        torch.save(indices_target.state_dict(), f'{checkpoint_dir}/indices.pt')

    for i in range(n):
        player = Player('karen')
        scoresheet = player.scoresheet

        for _ in range(15):
            hand = YatzyHand()
            # States are kept on `device` so they can be fed to the
            # networks both now and when replayed.
            state = create_state(scoresheet, hand).to(device)

            rerolls = 0
            reroll_action = select_reroll(state, reroll_dict, i)

            # `None` marks a stop that was forced by the two-reroll limit
            # rather than chosen by the agent.
            while reroll_action is not None and reroll_action.item() == 0:
                indices_action = select_indices(state, indices_dict, i)
                # Translate the action id into the dice positions it
                # stands for; the raw id is not itself a list of positions.
                dice_to_reroll = list(indices_dict[indices_action.item()])
                hand = hand.reroll(dice_to_reroll)
                rerolls += 1
                new_state = create_state(scoresheet, hand).to(device)

                reroll_memory.push(state, reroll_action, new_state)
                indices_memory.push(state, indices_action, new_state)

                state = new_state
                reroll_action = select_reroll(state, reroll_dict, i) if rerolls < 2 else None

            if reroll_action is not None:
                # The agent chose to stop: record that decision too, with
                # the unchanged state as its successor, so the "do not
                # reroll" output of the reroll network is trained as well.
                reroll_memory.push(state, reroll_action, state)

            move_action = select_move(state, available_moves(scoresheet), i)
            action_name = moves_dict[move_action.item()]
            score = getattr(hand, action_name)()
            player.update_scoresheet(action_name, score)

            # Store the reward as a float tensor on `device` so it matches
            # the network output in the loss.
            reward = torch.tensor([score], dtype=torch.float32, device=device)
            moves_memory.push(state, move_action, reward)

        optimize_moves()
        optimize_reroll()
        optimize_indices()

        if i % TARGET_UPDATE == 0:
            sync_targets()

        if i % 1000 == 0:
            print('Run {} done'.format(i))
            save_checkpoint()
            print('Total random: ', total_random)
            print('Total nonrandom: ', total_nonrand)

    if n > 0:
        # The periodic checkpoint misses everything after the last multiple
        # of 1000, so persist the final weights once training ends.
        sync_targets()
        save_checkpoint()

In [27]:
# Train for 20000 episodes; checkpoints are written under dqn_models/8.2test0/1/.
train(20000, '8.2', 1)

Run 0 done
Total random:  168822
Total nonrandom:  124776
Run 1000 done
Total random:  214178
Total nonrandom:  130841
Run 2000 done
Total random:  258159
Total nonrandom:  139288
Run 3000 done
Total random:  300977
Total nonrandom:  150057
Run 4000 done
Total random:  342611
Total nonrandom:  162841
Run 5000 done
Total random:  382844
Total nonrandom:  178040
Run 6000 done
Total random:  421867
Total nonrandom:  194945
Run 7000 done
Total random:  459512
Total nonrandom:  213953
Run 8000 done
Total random:  495892
Total nonrandom:  235115
Run 9000 done
Total random:  531355
Total nonrandom:  258125
Run 10000 done
Total random:  565458
Total nonrandom:  283061
Run 11000 done
Total random:  598519
Total nonrandom:  309722
Run 12000 done
Total random:  630381
Total nonrandom:  338141
Run 13000 done
Total random:  661046
Total nonrandom:  368019
Run 14000 done
Total random:  690643
Total nonrandom:  399832
Run 15000 done
Total random:  719217
Total nonrandom:  433064
Run 16000 done
Total 