# Five Crowns Deep-Q-Learning

### Import Libraries

In [10]:
import copy
import torch
import random
import numpy as np
import torch.nn as nn
import torch.optim as optim

from collections import deque

from deck import Card
from player import Player
from five_crowns import Game
from greedy import GreedyPlayer
from constants import GET_DISCARD, DRAW_CARD
from scoring import score_hand, get_best_discard

### Define the Network

In [11]:
class DQN(nn.Module):
    """
    Deep Q Network
    
    Args:
        state_dim (int): dimension of the state space
        action_dim (int): dimension of the action space
        
    Attributes:
        fc (torch.nn.Sequential): fully connected layers

    Methods:
        forward: forward pass of the network
    """
    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )

    def forward(self, x):
        """
        Forward pass of the network
        
        Args:
            x (torch.Tensor): input tensor
        
        Returns:
            torch.Tensor: output tensor
        """
        return self.fc(x)

### Define the Memory Buffer

In [None]:
class ReplayBuffer:
    """
    Replay Buffer
    
    Args:
        capacity (int): capacity of the buffer
        
    Attributes:
        buffer (collections.deque): buffer to store experiences
        
    Methods:
        add: add an experience to the buffer
        sample: sample a batch of experiences from the buffer
        size: get the size of the buffer
    """
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, experience):
        """
        Add an experience to the buffer
        
        Args:
            experience (tuple): experience to add to the buffer
        """
        self.buffer.append(experience)

    def sample(self, batch_size):
        """
        Sample a batch of experiences from the buffer
        
        Args:
            batch_size (int): size of the batch to sample
        
        Returns:
            list: batch of experiences
        """
        return random.sample(self.buffer, batch_size)

    def size(self):
        """
        Get the size of the buffer
        
        Returns:
            int: size of the buffer
        """
        return len(self.buffer)

### Define Helper Functions

In [13]:
def payoff(game, current_player):
    """
    Calculates the reward for the player at this state

    Returns:
        int: The reward for the current player in the state
    """
    if not game._go_out:
        return 0
    
    # If the player went out, calculate the score
    hand_score = score_hand(
        game.get_player_hand(current_player), game
    )
    
    return 1 if hand_score == 0 else -1

In [14]:
def card_to_idx(suit, rank):
    """
    Convert a card to an index in the state vector

    Args:
        suit (str): suit of the card
        rank (int): rank of the card

    Returns:
        int: index of the card
    """
    suit_dict = {
        'Clubs': 0,
        'Diamonds': 1,
        'Hearts': 2,
        'Spades': 3,
        'Stars': 4,
        'J': 5
    }

    return 11 * suit_dict[suit] + (rank if suit != "J" else 0) - 3

def idx_to_card(idx):
    """
    Convert an index to a card in the state vector
    
    Args:
        idx (int): index of the card
        
    Returns:
        Card: card corresponding to the index
    """
    suit_dict = {
        0: 'Clubs',
        1: 'Diamonds',
        2: 'Hearts',
        3: 'Spades',
        4: 'Stars',
        5: 'J'
    }

    suit = suit_dict[idx // 11]
    if suit == "J":
        rank = 50
    else:
        rank = (idx % 11) + 3

    return Card(rank, suit)

def encode_state(num_players, full_deck, player_deck, discard_card, gone_out_status):
    """
    Encode the state of the game into a vector
    
    Args:
        num_players (int): number of players in the game
        full_deck (list): full deck of cards
        player_deck (list): deck of the player
        discard_card (Card): card that was discarded
        gone_out_status (bool): whether the player has gone out
        
    Returns:
        np.array: encoded state of the game
    """
    num_players = num_players

    deck = set(full_deck)

    encoded_deck = np.zeros(len(deck))
    for card in cards:
        card_idx = card_to_idx(card.suit(), card.rank())
        encoded_deck[card_idx] += 1

    # Encode discard card as (rank, suit)
    discard_card_encoded = np.zeros(len(deck))
    if discard_card is not None:
        discard_idx = card_to_idx(discard_card.suit(), discard_card.rank())
        discard_card_encoded[discard_idx] = 1

    # Gone out status
    gone_out_status_encoded = int(gone_out_status)

    return np.concatenate([
        encoded_deck,
        discard_card_encoded,
        [gone_out_status_encoded]
    ])

def inference(game, hand, discard_card, policy_net):
    """
    Inference function for the policy network
    
    Args:
        game (Game): game object
        hand (list): hand of the player
        discard_card (Card): card that was discarded
        policy_net (DQN): policy network
        
    Returns:
        Card: card to play
    """
    encoded_state = encode_state(game.num_players(), game.get_full_deck()._cards, hand, discard_card, game._go_out)
    model = policy_net(encoded_state.shape[0], 56)
    model.eval()

    with torch.no_grad():
        output = model(encoded_state).item()

        sorted_list = [(output[i], idx_to_card(i)) for i in range(len(output))]
        sorted_list.sort(key=lambda x: x[0], reverse=True)


        for i in range(len(sorted_list)):
            if sorted_list[i][1] in hand:
                return sorted_list[i][1]

### Define the DQNPlayer Class

In [15]:
class DQNPlayer(Player):
    """
    DQN player always takes action that minimize score for turn

    Args:
        player_id (int): player id
        policy_net (DQN): policy network

    Attributes:
        prev_discard (Card): previous discard card
        epsilon (float): epsilon for epsilon-greedy policy
        epsilon_decay (float): decay rate for epsilon
        min_epsilon (float): minimum epsilon
        prev_action (int): previous action
        policy_net (DQN): policy network

    Methods:
        draw_phase: draw phase of the player
        discard_phase: discard phase of the player
    """
    def __init__(self, player_id, policy_net):
        super().__init__(player_id)
        self.prev_discard = None
        self.epsilon = 1.0
        self.epsilon_decay = 0.995
        self.min_epsilon = 0.1
        self.prev_action = None
        self.policy_net = policy_net

    def draw_phase(self, game):
        """
        Draw phase of the player
        
        Args:
            game (Game): game object
            
        Returns:
            int: action to take
        """
        # Get best score if we take discard
        new_card = game.get_discard_pile()[-1]
        temp_hand = self.hand + [new_card]
        _, discard_score = get_best_discard(temp_hand,game,excluded_discard=new_card)

        # Get best expected score if we draw random
        remaining_deck = copy.deepcopy(game.get_full_deck().get_cards())
        draw_scores = []
        for card in game.get_discard_pile() + self.hand:
            remaining_deck.remove(card)
        for card in remaining_deck:
            temp_hand = self.hand + [card]
            _, draw_score = get_best_discard(temp_hand, game)
            draw_scores.append(draw_score)
        expected_draw_score = sum(draw_scores)/len(draw_scores)

        # Take action with better expected score
        if discard_score < expected_draw_score:
            self.prev_discard = game.get_discard_pile()[-1]
            return GET_DISCARD
        self.prev_discard = None
        return DRAW_CARD

    def discard_phase(self, game):
        """
        Discard phase of the player
        
        Args:
            game (Game): game object
            
        Returns:
            Card: card to discard
        """
        if random.random() < self.epsilon:
            action = random.choice([card for card in self.hand if card != self.prev_discard])
        else:
            with torch.no_grad():
                action = inference(game, self.hand, self.prev_discard, self.policy_net)

        encoded_action = np.zeros(56)
        idx = card_to_idx(action.suit(), action.rank())
        encoded_action[idx] = 1
        self.prev_action = idx
        return action


### Set Up the Environment and Hyperparameters

In [16]:
# Set up the environment
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
agents = 4
state_dim = 113
action_dim = 56

policy_net = DQN(state_dim, action_dim).to(device)
target_net = DQN(state_dim, action_dim).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Set up the players and game
dqn_player = DQNPlayer(0, policy_net)
players = [dqn_player] + [GreedyPlayer(i) for i in range(1, agents)]

env = Game(players)
env.initialize_game()

state_encoded = encode_state(agents, env.get_full_deck()._cards, env._players[0].hand, env._discard_pile[-1], 0)

# Set up the optimizer, buffer, and hyperparameters
optimizer = optim.Adam(policy_net.parameters(), lr=1e-4)
buffer = ReplayBuffer(10000)
batch_size = 64
gamma = 0.99
target_update_freq = 50

### Training Loop

In [17]:
# Train the model
for episode in range(1000):
    
    # Initialize the environment
    players = [dqn_player] + [GreedyPlayer(i) for i in range(1, agents)]
    env = Game(players)
    env.initialize_game()
    state = encode_state(agents, env.get_full_deck()._cards, env._players[0].hand, env._discard_pile[-1], 0)
    done = False

    # Play the game while it is not over
    while not done:
        for _ in range(agents):
          env.play_round()
          if env.is_game_over():
            reward = payoff(env, 0)
            done = True
          else:
            reward = 0
            done = False

        # Get the next state
        next_state = encode_state(agents, env.get_full_deck()._cards, env._players[0].hand, env._discard_pile[-1], 0)

        buffer.add((state, env._players[0].prev_action, reward, next_state, done))
        state = next_state

        # Sample a batch of experiences from the buffer
        if buffer.size() >= batch_size:
            batch = buffer.sample(batch_size)
            states, actions, rewards, next_states, dones = zip(*batch)

            states = torch.tensor(states, dtype=torch.float32).to(device)
            actions = torch.tensor(actions, dtype=torch.long).to(device)
            rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
            next_states = torch.tensor(next_states, dtype=torch.float32).to(device)
            dones = torch.tensor(dones, dtype=torch.float32).to(device)

            # Calculate the q-values and loss
            q_values = policy_net(states)
            q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze()
            next_q_values = target_net(next_states).max(1)[0]
            target = rewards + (gamma * next_q_values * (1 - dones))

            loss = nn.MSELoss()(q_values, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if episode % 50 == 0:
              print(f"Episode {episode + 1}, Loss: {loss}")

    if episode % target_update_freq == 0:
        target_net.load_state_dict(policy_net.state_dict())
print("Ending loss: ", loss.item())

Episode 51, Loss: 0.38696005940437317
Episode 101, Loss: 0.3487950563430786
Episode 101, Loss: 0.4406189024448395
Episode 151, Loss: 0.39293473958969116
Episode 151, Loss: 0.45359545946121216
Episode 201, Loss: 0.325755774974823
Episode 201, Loss: 0.33951467275619507
Episode 251, Loss: 0.3827046751976013
Episode 251, Loss: 0.44889533519744873
Episode 251, Loss: 0.39683130383491516
Episode 301, Loss: 0.3639007806777954
Episode 301, Loss: 0.3315916657447815
Episode 351, Loss: 0.37709981203079224
Episode 351, Loss: 0.3358755111694336
Episode 401, Loss: 0.3760087490081787
Episode 401, Loss: 0.3506568968296051
Episode 451, Loss: 0.4215092658996582
Episode 451, Loss: 0.32699063420295715
Episode 501, Loss: 0.46768128871917725
Episode 501, Loss: 0.33663687109947205
Episode 501, Loss: 0.44248345494270325
Episode 551, Loss: 0.427690327167511
Episode 551, Loss: 0.4267599582672119
Episode 601, Loss: 0.4447503983974457
Episode 601, Loss: 0.4428597688674927
Episode 601, Loss: 0.4026695489883423
Epis

In [18]:
# Save the model
torch.save(policy_net.state_dict(), "five_crowns_dqn.pth")