# Five Crowns Deep-Q-Learning

### Import Libraries

In [10]:
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

from copy import deepcopy
from five_crowns import Game
from greedy import GreedyPlayer
from scoring import score_hand
from state import State

### Define the Network

In [11]:
class DQN(nn.Module):
    """Fully connected network that maps an encoded state to one value per action.

    Args:
        state_dim: Number of input features.
        action_dim: Number of outputs (one value per action index).
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_widths = (128, 64)

        layers = []
        in_features = state_dim
        for width in hidden_widths:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        layers.append(nn.Linear(in_features, action_dim))

        # The attribute name and layer order determine the state_dict keys
        # ("fc.0.weight", ...), so they are kept stable for saved weights.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for input ``x``."""
        q_values = self.fc(x)
        return q_values

### Define the Memory Buffer

In [12]:
class ReplayBuffer:
    """Bounded store of experiences; once full, the oldest entry is dropped on add."""

    def __init__(self, capacity):
        # deque(maxlen=...) evicts from the left when the capacity is exceeded.
        self.buffer = deque(maxlen=capacity)

    def add(self, experience):
        """Append one experience to the buffer."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Return ``batch_size`` stored experiences drawn at random without replacement."""
        picked = random.sample(self.buffer, batch_size)
        return picked

    def size(self):
        """Return the number of experiences currently stored."""
        stored = self.buffer
        return len(stored)

In [13]:
def payoff(game, current_player):
    """
    Computes the reward for ``current_player`` in the given game state

    Returns:
        int: 0 while ``game._go_out`` is falsy; otherwise 1 when the player's
        hand scores 0 and -1 for any other hand score
    """
    if game._go_out:
        player_hand = game.get_player_hand(current_player)
        if score_hand(player_hand, game) == 0:
            return 1
        return -1
    return 0

def successor(self, game, action, num_players):
    """
    Returns the successor state given the action from the current state

    ``game`` is deep-copied first, so the object passed in is not modified.

    Args:
        - self: Object exposing ``curr_player_id`` (the player whose hand is
          discarded from and scored)
        - game: The game to advance
        - action: Pair ``(first_action, second_action)``. ``first_action`` is
          "deck" or "discard" (where to draw from); ``second_action`` is the
          card to discard, or None to discard the card that was just drawn
        - num_players: Number of players, used to wrap the active player index

    Returns:
        - State: The new state after taking the action
    """
    # Make a copy for the successor state
    new_state = deepcopy(game)

    # Advance the turn on the copy only
    acting_player = game.get_active_player()
    new_state._active_player = (acting_player + 1) % num_players

    first_action = action[0]
    second_action = action[1]

    # Execute the first action: draw into the acting player's hand
    added_card = None
    if first_action == "deck":
        added_card = new_state.get_deck().draw()
        new_state.get_player_hand(acting_player).append(added_card)
    elif first_action == "discard":
        added_card = new_state.get_discard_pile().pop()
        new_state.get_player_hand(acting_player).append(added_card)

    # Execute the second action: discard either the drawn card or the chosen one.
    # Identity check on purpose: `== None` would go through the card's __eq__.
    # NOTE(review): the draw targets game.get_active_player() while the discard
    # and scoring use self.curr_player_id -- presumably the same player; confirm.
    discarded_card = added_card if second_action is None else second_action
    new_state.get_player_hand(self.curr_player_id).remove(discarded_card)
    new_state.get_discard_pile().append(discarded_card)

    # Check the hand score and update game ending conditions
    hand_score = score_hand(
        new_state.get_player_hand(self.curr_player_id), new_state
    )
    if hand_score == 0:
        new_state._go_out = True
        # Only the first player to go out is recorded as the go-out player
        if new_state._remaining_players == new_state.num_players():
            new_state._go_out_player = self.curr_player_id
        new_state._remaining_players -= 1
        if new_state._remaining_players == 0:
            new_state._game_over = True

    return State(new_state)

def get_actions(self, game=None):
    """
    Returns all possible actions from the current state

    Args:
        - self: State-like object exposing ``is_root``, ``curr_player_hand``,
          ``root_card`` and ``discard_pile_card``
        - game: Game whose deck is checked for remaining cards. Defaults to
          ``self.game`` (assumes the state object keeps its game there --
          TODO confirm); only consulted for non-root states

    Returns:
        - list: ``(source, card)`` tuples. Root states yield ``("root", card)``
          for every hand card different from ``root_card``. Otherwise each
          hand card is paired with "discard" (when a discard-pile card exists)
          and with "deck" (when the deck is non-empty), plus ``("deck", None)``.
    """
    hand = self.curr_player_hand

    if self.is_root:
        return [("root", c) for c in hand if c != self.root_card]

    # The original body read an undefined name `game`; resolve it explicitly.
    if game is None:
        game = self.game

    actions = []
    if self.discard_pile_card:
        actions.extend(("discard", card) for card in hand)
    if game.get_deck():
        actions.append(("deck", None))
        actions.extend(("deck", card) for card in hand)

    return actions

In [14]:
import torch
import numpy as np
from deck import Card

def card_to_idx(suit, rank):
    """Map a card to its flat index: 11 slots per suit, then one joker slot.

    Args:
        suit: One of 'Clubs', 'Diamonds', 'Hearts', 'Spades', 'Stars' or 'J'
            (joker).
        rank: Card rank; ranks start at 3 for the five suits. Ignored for
            jokers.

    Returns:
        int: ``11 * suit_number + rank - 3`` for suited cards, and 55 for a
        joker.

    Raises:
        KeyError: If ``suit`` is not one of the known suit names.
    """
    suit_dict = {
        'Clubs': 0,
        'Diamonds': 1,
        'Hearts': 2,
        'Spades': 3,
        'Stars': 4,
        'J': 5
    }

    suit_offset = 11 * suit_dict[suit]
    if suit == "J":
        # Jokers have no rank offset. Subtracting 3 here would put them on
        # index 52, which belongs to Stars rank 11 and is not the slot that
        # idx_to_card decodes as a joker.
        return suit_offset
    return suit_offset + rank - 3

def idx_to_card(idx):
    """Build the Card for a flat card index (11 consecutive slots per suit).

    The suit comes from ``idx // 11``; the rank is ``idx % 11 + 3`` for the
    five suits and 50 for the joker suit 'J'.

    Raises:
        KeyError: If ``idx // 11`` is not a known suit number (0-5).
    """
    suit_names = {
        0: 'Clubs',
        1: 'Diamonds',
        2: 'Hearts',
        3: 'Spades',
        4: 'Stars',
        5: 'J'
    }

    suit_number, slot = divmod(idx, 11)
    suit = suit_names[suit_number]
    rank = 50 if suit == "J" else slot + 3

    return Card(rank, suit)

def encode_state(num_players, full_deck, player_deck, discard_card, gone_out_status):
    """Encode the observable state as a flat NumPy vector.

    With ``n`` the number of distinct cards in ``full_deck``, the layout is:
    ``n`` per-card counts for the player's hand, then an ``n``-long one-hot
    vector for the discard card (all zeros when there is none), then a single
    gone-out flag.

    Args:
        num_players: Currently unused; kept so existing callers keep working.
        full_deck: Iterable of hashable cards; only the number of distinct
            cards is used.
        player_deck: Cards in the player's hand; each must provide ``suit()``
            and ``rank()``.
        discard_card: Card to mark in the one-hot section, or None.
        gone_out_status: Value converted with ``int()`` into the final entry.

    Returns:
        numpy.ndarray: Vector of length ``2 * n + 1``.
    """
    num_card_slots = len(set(full_deck))

    # Count every hand card in its slot; duplicates in the hand add up.
    hand_counts = np.zeros(num_card_slots)
    for card in player_deck:
        hand_counts[card_to_idx(card.suit(), card.rank())] += 1

    # One-hot encode the discard card
    discard_one_hot = np.zeros(num_card_slots)
    if discard_card is not None:
        discard_one_hot[card_to_idx(discard_card.suit(), discard_card.rank())] = 1

    return np.concatenate([
        hand_counts,
        discard_one_hot,
        [int(gone_out_status)]
    ])

def inference(game, hand, discard_card, policy_net):
    encoded_state = encode_state(game.num_players(), game.get_full_deck()._cards, hand, discard_card, game._go_out)
    model = policy_net(encoded_state.shape[0], 56)
    model.eval()

    with torch.no_grad():
        output = model(encoded_state).item()

        sorted_list = [(output[i], idx_to_card(i)) for i in range(len(output))]
        sorted_list.sort(key=lambda x: x[0], reverse=True)


        for i in range(len(sorted_list)):
            if sorted_list[i][1] in hand:
                return sorted_list[i][1]

In [15]:
from player import Player
from scoring import get_best_discard
from constants import GET_DISCARD, DRAW_CARD
import copy

class DQNPlayer(Player):
    """
    DQN player always takes action that minimize score for turn
    """
    def __init__(self, player_id, policy_net):
        super().__init__(player_id)
        self.prev_discard = None
        self.epsilon = 1.0
        self.epsilon_decay = 0.995
        self.min_epsilon = 0.1
        self.prev_action = None
        self.policy_net = policy_net

    def draw_phase(self, game):
        # Get best score if we take discard
        new_card = game.get_discard_pile()[-1]
        temp_hand = self.hand + [new_card]
        _, discard_score = get_best_discard(temp_hand,game,excluded_discard=new_card)

        # Get best expected score if we draw random
        remaining_deck = copy.deepcopy(game.get_full_deck().get_cards())
        draw_scores = []
        for card in game.get_discard_pile() + self.hand:
            remaining_deck.remove(card)
        for card in remaining_deck:
            temp_hand = self.hand + [card]
            _, draw_score = get_best_discard(temp_hand, game)
            draw_scores.append(draw_score)
        expected_draw_score = sum(draw_scores)/len(draw_scores)

        # Take action with better expected score
        if discard_score < expected_draw_score:
            self.prev_discard = game.get_discard_pile()[-1]
            return GET_DISCARD
        self.prev_discard = None
        return DRAW_CARD

    def discard_phase(self, game):
        if random.random() < self.epsilon:
            action = random.choice([card for card in self.hand if card != self.prev_discard])
        else:
            with torch.no_grad():
                action = inference(game, self.hand, self.prev_discard, self.policy_net)

        epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)

        encoded_action = np.zeros(56)
        idx = card_to_idx(action.suit(), action.rank())
        encoded_action[idx] = 1
        self.prev_action = idx
        return action


In [16]:
# Pick the compute device: Apple MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

agents = 4
# presumably 2 * 56 card slots + 1 gone-out flag, matching encode_state -- TODO confirm
state_dim = 113
action_dim = 56

# The target network starts as a frozen copy of the policy network.
policy_net = DQN(state_dim, action_dim).to(device)
target_net = DQN(state_dim, action_dim).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Seat 0 is the learning player; every other seat is a greedy player.
dqn_player = DQNPlayer(0, policy_net)
players = [dqn_player, *(GreedyPlayer(i) for i in range(1, agents))]

env = Game(players)
env.initialize_game()

state_encoded = encode_state(
    agents,
    env.get_full_deck()._cards,
    env._players[0].hand,
    env._discard_pile[-1],
    0,
)

# Optimisation and replay hyperparameters
optimizer = optim.Adam(params=policy_net.parameters(), lr=1e-4)
buffer = ReplayBuffer(capacity=10000)
batch_size = 64
gamma = 0.99
target_update_freq = 50

### Training Loop

In [17]:
# Build the loss module once instead of on every optimisation step.
loss_fn = nn.MSELoss()

for episode in range(1000):
    # Fresh game per episode; the learning player (seat 0) is reused.
    players = [dqn_player] + [GreedyPlayer(i) for i in range(1, agents)]
    env = Game(players)
    env.initialize_game()
    # Encode the real gone-out flag so training states match what inference() encodes.
    state = encode_state(agents, env.get_full_deck()._cards, env._players[0].hand, env._discard_pile[-1], env._go_out)
    done = False
    while not done:
        reward = 0
        for _ in range(agents):
            env.play_round()
            if env.is_game_over():
                reward = payoff(env, 0)
                done = True
                # Stop here: nothing should be played once the game is over.
                break

        next_state = encode_state(agents, env.get_full_deck()._cards, env._players[0].hand, env._discard_pile[-1], env._go_out)

        # prev_action stays None until the learning player has discarded once;
        # such a transition has no action to learn from and would break the
        # integer action tensor below.
        action = env._players[0].prev_action
        if action is not None:
            buffer.add((state, action, reward, next_state, done))
        state = next_state

        if buffer.size() >= batch_size:
            batch = buffer.sample(batch_size)
            states, actions, rewards, next_states, dones = zip(*batch)

            # Stack the per-sample arrays first; building a tensor straight
            # from a tuple of ndarrays is slow.
            states = torch.tensor(np.array(states), dtype=torch.float32).to(device)
            actions = torch.tensor(actions, dtype=torch.long).to(device)
            rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
            next_states = torch.tensor(np.array(next_states), dtype=torch.float32).to(device)
            dones = torch.tensor(dones, dtype=torch.float32).to(device)

            # Value of the action actually taken; squeeze only the gathered column.
            q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

            # The bootstrap target is a constant: no graph through target_net.
            with torch.no_grad():
                next_q_values = target_net(next_states).max(1)[0]
            target = rewards + (gamma * next_q_values * (1 - dones))

            loss = loss_fn(q_values, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if episode % 50 == 0:
                print(f"Episode {episode + 1}, Loss: {loss.item()}")

    # Periodically sync the target network with the policy network.
    if episode % target_update_freq == 0:
        target_net.load_state_dict(policy_net.state_dict())
print("Ending loss: ", loss.item())

Episode 51, Loss: 0.38696005940437317
Episode 101, Loss: 0.3487950563430786
Episode 101, Loss: 0.4406189024448395
Episode 151, Loss: 0.39293473958969116
Episode 151, Loss: 0.45359545946121216
Episode 201, Loss: 0.325755774974823
Episode 201, Loss: 0.33951467275619507
Episode 251, Loss: 0.3827046751976013
Episode 251, Loss: 0.44889533519744873
Episode 251, Loss: 0.39683130383491516
Episode 301, Loss: 0.3639007806777954
Episode 301, Loss: 0.3315916657447815
Episode 351, Loss: 0.37709981203079224
Episode 351, Loss: 0.3358755111694336
Episode 401, Loss: 0.3760087490081787
Episode 401, Loss: 0.3506568968296051
Episode 451, Loss: 0.4215092658996582
Episode 451, Loss: 0.32699063420295715
Episode 501, Loss: 0.46768128871917725
Episode 501, Loss: 0.33663687109947205
Episode 501, Loss: 0.44248345494270325
Episode 551, Loss: 0.427690327167511
Episode 551, Loss: 0.4267599582672119
Episode 601, Loss: 0.4447503983974457
Episode 601, Loss: 0.4428597688674927
Episode 601, Loss: 0.4026695489883423
Epis

In [18]:
# Persist only the policy network's parameters (its state_dict).
torch.save(obj=policy_net.state_dict(), f="five_crowns_dqn.pth")