# Yatzy - Deep RL - Gen 8.1
##### Uses three separate networks (move selection, reroll decision, and dice-index selection) instead of a single one

## Importing

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
from collections import namedtuple, deque
import random
import math
import itertools as it

In [3]:
from player import Player
from hand import YatzyHand
from deep_rl import create_state, get_indices, check_end, first_turn, check_state_end

from ai_gen_8 import AIGenEightPointOne
from deeprlgame2 import DeepRLGame3DQN

In [4]:
# Run tensors and networks on the GPU when CUDA is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
# Replay records, one namedtuple type per network.
# The typename passed to namedtuple() matches the variable it is bound to, so
# reprs are accurate and instances can be pickled by qualified name.

# Scoring decision: the state, the chosen scoring category, and the score obtained.
MovesTransition = namedtuple('MovesTransition',
                            ('state', 'move', 'reward'))
# Reroll decision: the state, the reroll choice, and the state after rerolling.
RerollTransition = namedtuple('RerollTransition',
                            ('state', 'reroll', 'next_state'))
# Dice-subset decision: the state, the chosen subset index, and the resulting state.
IndicesTransition = namedtuple('IndicesTransition',
                            ('state', 'indices', 'next_state'))

## Replay Memories and Networks

In [6]:
class MovesReplayMemory(object):
    """Fixed-capacity replay buffer holding ``MovesTransition`` records."""

    def __init__(self, capacity):
        # A deque with maxlen discards its oldest entry once it is full.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Wrap ``args`` in a ``MovesTransition`` and append it to the buffer."""
        record = MovesTransition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored records chosen at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

In [7]:
class IndicesReplayMemory(object):
    """Fixed-capacity replay buffer holding ``IndicesTransition`` records."""

    def __init__(self, capacity):
        # A deque with maxlen discards its oldest entry once it is full.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Wrap ``args`` in an ``IndicesTransition`` and append it to the buffer."""
        record = IndicesTransition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored records chosen at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

In [8]:
class RerollReplayMemory(object):
    """Fixed-capacity replay buffer holding ``RerollTransition`` records."""

    def __init__(self, capacity):
        # A deque with maxlen discards its oldest entry once it is full.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Wrap ``args`` in a ``RerollTransition`` and append it to the buffer."""
        record = RerollTransition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored records chosen at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

In [9]:
class MovesDQN(nn.Module):
    """Network scoring the 15 scoring categories from a 20-feature state.

    Architecture: 20 -> 64 -> 32 -> 15 with ReLU activations and a softmax
    over the output.
    """

    def __init__(self):
        super().__init__()
        # Normalise over the last (action) axis. dim=0 was only correct for a
        # single 1-D state; for a (batch, 20) input it normalised each action
        # column across the batch, so per-sample outputs did not sum to 1.
        self.softmax = nn.Softmax(dim=-1)
        self.layer1 = nn.Linear(20, 64)
        self.layer2 = nn.Linear(64, 32)
        self.layer3 = nn.Linear(32, 15)

    def forward(self, x):
        """Return per-category softmax scores.

        Args:
            x: tensor whose last dimension has 20 features, either a single
               state ``(20,)`` or a batch ``(N, 20)``.

        Returns:
            Tensor with the same leading shape as ``x`` and 15 entries on the
            last axis, each row summing to 1.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.softmax(self.layer3(x))

In [10]:
class IndicesDQN(nn.Module):
    """Network scoring the 31 dice-subset actions from a 20-feature state.

    Architecture: 20 -> 64 -> 32 -> 31 with ReLU activations and a softmax
    over the output.
    """

    def __init__(self):
        super().__init__()
        # Normalise over the last (action) axis. dim=0 was only correct for a
        # single 1-D state; for a (batch, 20) input it normalised each action
        # column across the batch, so per-sample outputs did not sum to 1.
        self.softmax = nn.Softmax(dim=-1)
        self.layer1 = nn.Linear(20, 64)
        self.layer2 = nn.Linear(64, 32)
        self.layer3 = nn.Linear(32, 31)

    def forward(self, x):
        """Return per-subset softmax scores.

        Args:
            x: tensor whose last dimension has 20 features, either a single
               state ``(20,)`` or a batch ``(N, 20)``.

        Returns:
            Tensor with the same leading shape as ``x`` and 31 entries on the
            last axis, each row summing to 1.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.softmax(self.layer3(x))

In [11]:
class RerollDQN(nn.Module):
    """Network scoring the two reroll choices from a 20-feature state.

    Architecture: 20 -> 64 -> 32 -> 2 with ReLU activations and a softmax
    over the output.
    """

    def __init__(self):
        super().__init__()
        # Normalise over the last (action) axis. dim=0 was only correct for a
        # single 1-D state; for a (batch, 20) input it normalised each action
        # column across the batch, so per-sample outputs did not sum to 1.
        self.softmax = nn.Softmax(dim=-1)
        self.layer1 = nn.Linear(20, 64)
        self.layer2 = nn.Linear(64, 32)
        self.layer3 = nn.Linear(32, 2)

    def forward(self, x):
        """Return softmax scores for the two reroll choices.

        Args:
            x: tensor whose last dimension has 20 features, either a single
               state ``(20,)`` or a batch ``(N, 20)``.

        Returns:
            Tensor with the same leading shape as ``x`` and 2 entries on the
            last axis, each row summing to 1.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.softmax(self.layer3(x))

## Helper Functions

In [12]:
def create_action_dicts():
    """Build the lookup tables that translate network output indices to actions.

    Returns:
        A ``(moves, reroll, indices)`` tuple of dicts:
        * ``moves``   - index 0..14 -> scoring category name,
        * ``reroll``  - index 0/1 -> the strings ``'True'`` / ``'False'``,
        * ``indices`` - index 0..30 -> tuple of dice positions, listing every
          non-empty subset of positions 0..4 ordered by size and then
          lexicographically.
    """
    category_names = (
        'ones', 'twos', 'threes', 'fours', 'fives', 'sixes',
        'one_pair', 'two_pair', 'three_kind', 'four_kind',
        'small_straight', 'large_straight', 'full_house', 'chance', 'yatzy',
    )
    moves = dict(enumerate(category_names))

    reroll = dict(enumerate(('True', 'False')))

    # combinations() yields subsets of one size in lexicographic order;
    # chaining sizes 1..5 reproduces the 31-entry table.
    dice_subsets = it.chain.from_iterable(
        it.combinations(range(5), size) for size in range(1, 6)
    )
    indices = dict(enumerate(dice_subsets))

    return moves, reroll, indices

In [13]:
def available_moves(scoresheet):
    """Report which scoresheet slots are still open.

    Args:
        scoresheet: mapping whose values are ``None`` for unfilled slots.

    Returns:
        Dict from each slot's position in iteration order to ``True`` when its
        value is ``None`` and ``False`` otherwise.
    """
    return {
        position: recorded is None
        for position, (_, recorded) in enumerate(scoresheet.items())
    }

In [14]:
def available_rerolls():
    """Return a fresh dict marking both reroll choices (0 and 1) as allowed."""
    return dict.fromkeys((0, 1), True)

In [15]:
def available_indices():
    """Enumerate every non-empty subset of the five dice positions.

    Keys run consecutively from 0 to 30, ordered by subset size and then
    lexicographically. The previous version restarted the key counter for
    every subset size, so later sizes overwrote earlier keys and only ten
    entries survived.

    Returns:
        Dict from action index (0..30) to a tuple of dice positions. Every
        value is a non-empty tuple and therefore truthy.
    """
    subsets = it.chain.from_iterable(
        it.combinations(range(5), size) for size in range(1, 6)
    )
    return dict(enumerate(subsets))

In [16]:
def select_move(state, available_moves, episode_num, eps=True):
    """Pick a scoring category epsilon-greedily among the open ones.

    Args:
        state: input tensor for ``moves_target``.
        available_moves: dict from category index to a truthy value when the
            category is still open.
        episode_num: episode counter driving the exponential epsilon decay
            (module-level ``eps_initial``, ``eps_final``, ``eps_decay``).
        eps: when falsy, exploration is disabled and the greedy action is
            always returned.

    Returns:
        Tensor of shape ``(1,)`` on ``device`` holding the chosen index.
    """
    sample = random.random()
    eps_threshold = eps_final + (eps_initial - eps_final) * math.exp(-1.0 * episode_num / eps_decay)

    if sample < eps_threshold and eps:
        open_moves = [idx for idx in available_moves if available_moves[idx]]
        return torch.tensor(random.choice(open_moves)).view(1).to(device)

    with torch.no_grad():
        results = moves_target(state)
        # Mask filled categories with -inf, not 0. With a 0 mask, argmax could
        # return a filled category whenever the open ones were also 0
        # (e.g. softmax underflow).
        open_mask = torch.tensor(
            [bool(available_moves[i]) for i in range(results.shape[-1])],
            dtype=torch.bool,
            device=results.device,
        )
        masked = results.masked_fill(~open_mask, float('-inf'))
        return torch.argmax(masked).view(1).to(device)

In [17]:
def select_reroll(state, available_rerolls, episode_num, eps=True):
    """Choose the reroll action epsilon-greedily.

    Args:
        state: input tensor for ``reroll_target``.
        available_rerolls: dict from action index to a truthy value when allowed.
        episode_num: episode counter driving the exponential epsilon decay.
        eps: when falsy, exploration is disabled.

    Returns:
        Tensor of shape ``(1,)`` on ``device`` holding the chosen action index.
    """
    draw = random.random()
    decay_factor = math.exp(-1.0 * episode_num / eps_decay)
    eps_threshold = eps_final + (eps_initial - eps_final) * decay_factor

    explore = draw < eps_threshold and eps
    if not explore:
        # Greedy branch: highest-scoring action according to the target network.
        with torch.no_grad():
            scores = reroll_target(state)
            return torch.argmax(scores).view(1).to(device)

    allowed = [key for key in available_rerolls if available_rerolls[key]]
    picked = random.choice(allowed)
    return torch.tensor(picked).view(1).to(device)

In [18]:
def select_indices(state, available_indices, episode_num, eps=True):
    """Choose a dice-subset action epsilon-greedily.

    Args:
        state: input tensor for ``indices_target``.
        available_indices: dict from action index to a truthy value when allowed.
        episode_num: episode counter driving the exponential epsilon decay.
        eps: when falsy, exploration is disabled.

    Returns:
        Tensor of shape ``(1,)`` on ``device`` holding the chosen action index.
    """
    draw = random.random()
    decay_factor = math.exp(-1.0 * episode_num / eps_decay)
    eps_threshold = eps_final + (eps_initial - eps_final) * decay_factor

    explore = draw < eps_threshold and eps
    if not explore:
        # Greedy branch: highest-scoring action according to the target network.
        with torch.no_grad():
            scores = indices_target(state)
            return torch.argmax(scores).view(1).to(device)

    allowed = [key for key in available_indices if available_indices[key]]
    picked = random.choice(allowed)
    return torch.tensor(picked).view(1).to(device)

In [19]:
def optimize_moves():
    """Run one gradient step on ``moves_policy`` from a sampled replay batch.

    Does nothing until ``moves_memory`` holds at least ``bs`` transitions.
    The regression target for each sampled (state, move) pair is the stored
    reward; there is no bootstrapped next-state term.
    """
    if len(moves_memory) < bs:
        return
    transitions = moves_memory.sample(bs)

    # Turn a list of transitions into one transition of per-field tuples.
    batch = MovesTransition(*zip(*transitions))

    state_batch = torch.stack(batch.state)
    action_batch = torch.stack(batch.move)

    # squeeze(1) drops only the gathered action column, so the result keeps
    # shape (bs,) even when bs == 1 (a bare squeeze() would give a 0-d tensor).
    state_action_values = moves_policy(state_batch).gather(1, action_batch).squeeze(1)

    # Rewards are stored as torch.tensor([score]) on the CPU with whatever
    # dtype `score` implies; align them with the predictions before the loss.
    reward_batch = torch.cat(batch.reward).to(
        device=state_action_values.device, dtype=state_action_values.dtype
    )
    expected_state_action_values = reward_batch

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values)

    moves_opt.zero_grad()
    loss.backward()
    moves_opt.step()


In [20]:
def optimize_reroll():
    """Run one gradient step on ``reroll_policy`` from a sampled replay batch.

    Does nothing until ``reroll_memory`` holds at least ``bs`` transitions.
    The target for each sample is the maximum ``reroll_target`` output on the
    next state, or 0 when ``check_state_end`` flags the next state as final.
    """
    if len(reroll_memory) < bs:
        return
    transitions = reroll_memory.sample(bs)

    # Turn a list of transitions into one transition of per-field tuples.
    batch = RerollTransition(*zip(*transitions))

    # Evaluate check_state_end once per sample and reuse it for mask and filter.
    non_final_flags = [not check_state_end(s) for s in batch.next_state]
    non_final_mask = torch.tensor(non_final_flags, device=device, dtype=torch.bool)

    state_batch = torch.stack(batch.state)
    action_batch = torch.stack(batch.reroll)

    # squeeze(1) keeps shape (bs,) even when bs == 1.
    state_action_values = reroll_policy(state_batch).gather(1, action_batch).squeeze(1)

    next_state_values = torch.zeros(bs, device=device)
    # torch.stack raises on an empty list, so skip the bootstrap when every
    # sampled next state is final; their values stay at 0.
    if any(non_final_flags):
        non_final_next_states = torch.stack(
            [s for s, keep in zip(batch.next_state, non_final_flags) if keep]
        )
        next_state_values[non_final_mask] = reroll_target(non_final_next_states).max(1)[0].detach()

    # NOTE(review): the target has no reward or discount term - confirm this is intended.
    expected_state_action_values = next_state_values

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values)

    reroll_opt.zero_grad()
    loss.backward()
    reroll_opt.step()

In [21]:
def optimize_indices():
    """Run one gradient step on ``indices_policy`` from a sampled replay batch.

    Does nothing until ``indices_memory`` holds at least ``bs`` transitions.
    The target for each sample is the maximum ``indices_target`` output on the
    next state, or 0 when ``check_state_end`` flags the next state as final.
    """
    if len(indices_memory) < bs:
        return
    transitions = indices_memory.sample(bs)

    # Turn a list of transitions into one transition of per-field tuples.
    batch = IndicesTransition(*zip(*transitions))

    # Evaluate check_state_end once per sample and reuse it for mask and filter.
    non_final_flags = [not check_state_end(s) for s in batch.next_state]
    non_final_mask = torch.tensor(non_final_flags, device=device, dtype=torch.bool)

    state_batch = torch.stack(batch.state)
    action_batch = torch.stack(batch.indices)

    # squeeze(1) keeps shape (bs,) even when bs == 1.
    state_action_values = indices_policy(state_batch).gather(1, action_batch).squeeze(1)

    next_state_values = torch.zeros(bs, device=device)
    # torch.stack raises on an empty list, so skip the bootstrap when every
    # sampled next state is final; their values stay at 0.
    if any(non_final_flags):
        non_final_next_states = torch.stack(
            [s for s, keep in zip(batch.next_state, non_final_flags) if keep]
        )
        next_state_values[non_final_mask] = indices_target(non_final_next_states).max(1)[0].detach()

    # NOTE(review): the target has no reward or discount term - confirm this is intended.
    expected_state_action_values = next_state_values

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values)

    indices_opt.zero_grad()
    loss.backward()
    indices_opt.step()

In [22]:
def evaluate(num):
    """Have an ``AIGenEightPointOne`` player play ``num`` games and print the mean score."""
    agent = AIGenEightPointOne('karen')

    # The game object is built from the agent's generation identifier.
    results = DeepRLGame3DQN(agent.generation).evaluate(num)

    mean_score = sum(results) / len(results)
    print('Average score: {}'.format(mean_score))

## Training

In [23]:
# Epsilon-greedy schedule read by the select_* functions:
#   threshold = eps_final + (eps_initial - eps_final) * exp(-episode / eps_decay)
eps_initial = 0.9
eps_final = 0.05
eps_decay = 100

# Optimiser learning rate, replay batch size, and the number of episodes
# between copies of the policy weights into the target networks.
lr = 0.01
bs = 64
TARGET_UPDATE = 10

# Scoring-category network: policy/target pair, SGD optimiser, 10000-entry replay buffer.
moves_policy = MovesDQN().to(device)
moves_target = MovesDQN().to(device)
moves_target.load_state_dict(moves_policy.state_dict())
moves_target.eval()
moves_opt = optim.SGD(moves_policy.parameters(), lr=lr)
moves_memory = MovesReplayMemory(10000)

# Reroll-decision network: policy/target pair, SGD optimiser, 10000-entry replay buffer.
reroll_policy = RerollDQN().to(device)
reroll_target = RerollDQN().to(device)
reroll_target.load_state_dict(reroll_policy.state_dict())
reroll_target.eval()
reroll_opt = optim.SGD(reroll_policy.parameters(), lr=lr)
reroll_memory = RerollReplayMemory(10000)


# Dice-subset network: policy/target pair, SGD optimiser, 10000-entry replay buffer.
indices_policy = IndicesDQN().to(device)
indices_target = IndicesDQN().to(device)
indices_target.load_state_dict(indices_policy.state_dict())
indices_target.eval()
indices_opt = optim.SGD(indices_policy.parameters(), lr=lr)
indices_memory = IndicesReplayMemory(10000)

In [24]:
def _save_target_networks(gen, num):
    """Write the three target networks' weights under dqn_models/{gen}test3/{num}/."""
    torch.save(moves_target.state_dict(), f'dqn_models/{gen}test3/{num}/moves.pt')
    torch.save(reroll_target.state_dict(), f'dqn_models/{gen}test3/{num}/reroll.pt')
    torch.save(indices_target.state_dict(), f'dqn_models/{gen}test3/{num}/indices.pt')


def train(n, gen, num):
    """Play ``n`` self-play games, training the three networks after each one.

    Args:
        n: number of games (episodes) to play.
        gen: generation label used in the checkpoint directory name.
        num: run identifier used in the checkpoint directory name.

    Side effects:
        Fills the three replay memories, updates policy and target networks,
        rebinds the module-level ``eps_decay`` to ``n / 2``, prints progress
        every 1000 episodes, and writes checkpoints every 1000 episodes and
        once more after the last episode.
    """
    # The select_* functions read the module-level eps_decay. Without `global`
    # the assignment below only bound a local and the schedule never changed.
    global eps_decay

    moves_dict, _, indices_dict = create_action_dicts()

    reroll_options = available_rerolls()
    indices_options = available_indices()

    if n > 0:  # avoid leaving a zero divisor behind for later select_* calls
        eps_decay = n / 2

    for i in range(n):
        player = Player('karen')
        scoresheet = player.scoresheet

        while not check_end(scoresheet):
            hand = YatzyHand()
            state = create_state(scoresheet, hand)

            rerolls = 0

            reroll_action = select_reroll(state, reroll_options, i)

            while reroll_action == 0:
                indices_action = select_indices(state, indices_options, i)
                # Translate the action index into dice positions. Previously
                # list(indices_action) passed a one-element list holding the
                # raw action tensor, and indices_dict was never used.
                # NOTE(review): assumes YatzyHand.reroll takes a list of dice
                # positions - confirm against hand.py.
                dice_positions = list(indices_dict[indices_action.item()])
                hand = hand.reroll(dice_positions)
                rerolls += 1
                new_state = create_state(scoresheet, hand)

                reroll_memory.push(state, reroll_action, new_state)
                indices_memory.push(state, indices_action, new_state)

                # Two rerolls have been used once rerolls reaches 2: force the stop action.
                reroll_action = select_reroll(new_state, reroll_options, i) if rerolls < 2 else 1
                state = new_state

            move_action = select_move(state, available_moves(scoresheet), i)
            action_name = moves_dict[move_action.item()]
            score = getattr(hand, action_name)()
            player.update_scoresheet(action_name, score)

            # Store the reward as float32 on `device` so it matches the network
            # output in the loss regardless of the Python type of `score`.
            reward = torch.tensor([score], dtype=torch.float32, device=device)
            moves_memory.push(state, move_action, reward)

        optimize_moves()
        optimize_reroll()
        optimize_indices()

        if i % TARGET_UPDATE == 0:
            moves_target.load_state_dict(moves_policy.state_dict())
            reroll_target.load_state_dict(reroll_policy.state_dict())
            indices_target.load_state_dict(indices_policy.state_dict())

        if i % 1000 == 0:
            print('Run {} done'.format(i))
            _save_target_networks(gen, num)

    # The periodic checkpoint last fires at the final multiple of 1000 (for
    # example episode 49000 of 50000), so save once more when training ends.
    if n > 0:
        _save_target_networks(gen, num)

In [25]:
train(50000, '8.1', 1)

Run 0 done
Run 1000 done
Run 2000 done
Run 3000 done
Run 4000 done
Run 5000 done
Run 6000 done
Run 7000 done
Run 8000 done
Run 9000 done
Run 10000 done
Run 11000 done
Run 12000 done
Run 13000 done
Run 14000 done
Run 15000 done
Run 16000 done
Run 17000 done
Run 18000 done
Run 19000 done
Run 20000 done
Run 21000 done
Run 22000 done
Run 23000 done
Run 24000 done
Run 25000 done
Run 26000 done
Run 27000 done
Run 28000 done
Run 29000 done
Run 30000 done
Run 31000 done
Run 32000 done
Run 33000 done
Run 34000 done
Run 35000 done
Run 36000 done
Run 37000 done
Run 38000 done
Run 39000 done
Run 40000 done
Run 41000 done
Run 42000 done
Run 43000 done
Run 44000 done
Run 45000 done
Run 46000 done
Run 47000 done
Run 48000 done
Run 49000 done
