In [34]:
!pip install rlcard

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [21]:
import tensorflow as tf

In [166]:
from collections import namedtuple
import random
import numpy as np

Transition = namedtuple('Transition', ['state', 'action', 'reward', 'next_state', 'done'])

class ReplayMemory(object):
    ''' 
    Replay mem for saving transitions
    '''
    def __init__(self, cap, batch_size):
        ''' 
        Initialize ReplayMemory

        :param int cap: the size of the mem buffer
        :param int batch_size: the size of the batches
        '''
        self.cap = cap
        self.mem = []
        self.pos = 0

    def push(self, *args):
        '''
        Save a transition into mem
        '''
        if len(self.mem) < self.cap:
            self.mem.append(None)
        self.mem[self.pos] = Transition(*args)
        self.pos = (self.pos + 1) % self.cap

    def sample(self, batch_size):
        '''
        Choose random sample from the mem with size of the batch size
        '''
        items = random.choice(self.mem, batch_size)
        return map(np.array, zip(*items))


In [167]:
import numpy as np
import torch
import torch.nn as nn

class dqn_l(object):
    '''
    Deep Q-Network
    '''

    def __init__(self, num_states=36, num_actions=4, hid_layer=[64, 32], lr=0.001, dev=None):
      
        self.num_states = num_states
        self.num_actions = num_actions
        self.hid_layer = hid_layer
        self.lr=lr
        self.dev = dev

        # DQN network based on the layers
        layers = self.num_states + self.hid_layer
        dqn_l = [nn.Flatten()]
        dqn_l.append(nn.BatchNorm1d(layers[0]))
        for i in range(len(layers)-1):
            dqn_l.append(nn.Linear(layers[i], layers[i+1], bias=True))
            dqn_l.append(nn.Tanh())
        dqn_l.append(nn.Linear(layers[-1], self.num_actions, bias=True))
        dqn_l = nn.Sequential(*dqn_l)

        dqn_l = dqn_l.to(self.dev)
        self.dqn_l = dqn_l
        self.dqn_l.eval()

        # Initialize weights in the network
        for p in self.dqn_l.parameters():
            if len(p.data.shape) > 1:
                nn.init.xavier_uniform_(p.data)

        # Define loss function
        self.loss_func = nn.MSELoss(reduction='mean')

        # Define optimizer
        
        self.optimizer = torch.optim.RMSprop(self.dqn_l.parameters())


    def get_qvalue(self, nxt_state):
        
        # Disable gradient calculation
        with torch.no_grad():
            # Create torch tensor
            nxt_state = torch.from_numpy(nxt_state).float().to(self.dev)
            # Get Q values
            q_val = self.dqn_l(nxt_state).cpu().numpy()
        return q_val

    def update(self, state_batch, action_batch, target_batch):
        ''' 
        Update the policy network

        :param np.ndarray state_batch: Batch of states from replay memory
        :param np.ndarray action_batch: Batch of actions from replay memory
        :param np.ndarray target_batch: Batch of Q-values from the target policy, it used during the optimization step
        :return float batch_loss: The calculated loss on the batch       
        '''
        # Set the gradients to zero
        self.optimizer.zero_grad()

        # Set the network in training mode
        self.dqn_l.train()

        # Create torch tensors
        state_batch = torch.from_numpy(state_batch).float().to(self.dev)
        action_batch = torch.from_numpy(action_batch).long().to(self.dev)
        target_batch = torch.from_numpy(target_batch).float().to(self.dev)

        # Gather Q-values from network and replay memory actions
        q_val = torch.gather(self.dqn_l(state_batch), dim=-1, index=action_batch.unsqueeze(-1)).squeeze(-1)

        # Optimization step
        batch_loss = self.loss_func(q_val, target_batch)
        batch_loss.backward()
        self.optimizer.step()
        batch_loss = batch_loss.item()
        self.dqn_l.eval()
        return batch_loss

In [172]:
import numpy as np
import torch
import torch.nn as nn
from copy import deepcopy
import random

class DQN_agent(object):
    '''
    DQN agent
    '''
    def __init__(self,
                state_no,
                num_act,
                extra_act=0,
                replay_memory_capacity=20000,
                min_samp=1000,
                b_s=16,
                train_time=1,
                df=0.99,
                hidden_layers=[64, 32],
                learning_rate=0.0001,
                eps=20000,
                upd_target=1000, 
                device=None):

      
        
        self.min_samp = min_samp
        self.upd_target = upd_target
        self.df = df
        self.eps = eps
        self.b_s = b_s
        self.num_act = num_act
        self.train_time = train_time
        self.extra_act = extra_act

       
        if device is None:
            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # Create the replay memory
        self.memory = ReplayMemory(replay_memory_capacity, b_s)

        # Initialize current timestep and current training timestep
        self.current_timestep, self.current_training_timestep = 0, 0

        # Create array for the eps values during the eps decay 
        self.epsilons = np.linspace(1.0, 0.1, eps)

        # Create the policy and the target network
        self.policy_dqn = dqn_l(num_act=num_act, learning_rate=learning_rate, state_no=state_no, hidden_layers=hidden_layers, device=self.device)
        self.target_dqn = dqn_l(num_act=num_act, learning_rate=learning_rate, state_no=state_no, hidden_layers=hidden_layers, device=self.device)

        # Set use_raw value for the RLCard environment
        self.use_raw = False

    def store_and_train(self, transition):
       
        self.current_timestep += 1
        # Train the agent if the replay memory has data already and agent reached the next training period
        time_between = self.current_timestep - self.min_samp
        if time_between>=0 and time_between%self.train_time == 0:
            self.train()

    def discard_invalid_actions(self, action_probs, valid_actions):
        ''' 
        Remove invalid actions and normalize the probabilities.

        :param numpy.array[float] action_probs: Probabilities of all action
        :param list[int] valid_actions: Valid actions in the current state
        :return numpy.array[float] normalised_probs: Probabilities of valid actions
        '''
        valid_actions=[0,1,2]
        # Initialize new array
        normalised_probs = np.zeros(action_probs.shape[0])
        # Add probability values of valid actions to the array
        normalised_probs[valid_actions] = action_probs[valid_actions]
        # Normalize probabilities
        normalised_probs[valid_actions] = 1 / len(valid_actions)
        return normalised_probs

    def predict(self, state):
       
        eps = self.epsilons[min(self.current_timestep, self.eps-1)]
        actions = np.ones(self.num_act, dtype=float) * eps / self.num_act
        q_values = self.policy_dqn.get_qvalue(np.expand_dims(state, 0))[0]
        best_action = np.argmax(q_values)
        actions[best_action] += (1.0 - eps)
        return actions

    def step(self, state):
      
        actions = self.predict(state['obs'])
        normalised_probs = self.discard_invalid_actions(actions, state['legal_actions'])
        action = np.random.choice(np.arange(len(actions)), p=normalised_probs)
        return action


    def eval_step(self, state):
       
        q_values = self.policy_dqn.get_qvalue(np.expand_dims(state['obs'], 0))[0]
        normalised_probs = self.discard_invalid_actions(np.exp(q_values), state['legal_actions'])
        # Check version of choosing action
        if self.extra_act == 1:
          # If Raise (1) is a valid action and the best action is Call (0)
          if 1 in state['legal_actions'] and np.argmax(normalised_probs)==0:
            best_action = 1
          else:
            best_action = np.argmax(normalised_probs)
        elif self.extra_act == 2:
          # If Raise (1) is a valid action and the best action is Check (3)
          if 1 in state['legal_actions'] and np.argmax(normalised_probs)==3:
            best_action = 1
          else:
            best_action = np.argmax(normalised_probs)
        elif self.extra_act == 3:
          # If Raise (1) is a valid action and the best action is Fold (2)
          if 1 in state['legal_actions'] and np.argmax(normalised_probs)==2:
            best_action = 1
          else:
            best_action = np.argmax(normalised_probs)
        else:
          best_action = np.argmax(normalised_probs)
        return best_action, normalised_probs

    
    def train(self):
       
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(self.b_s)

        # Get best next action using the policy network
        q_values_next = self.policy_dqn.get_qvalue(next_state_batch)
        best_actions = np.argmax(q_values_next, axis=1)

        # Calculate Q values from the target policy
        q_values_next_target = self.target_dqn.get_qvalue(next_state_batch)
        target_batch = reward_batch + np.invert(done_batch).astype(np.float32) * self.df * q_values_next_target[np.arange(self.b_s), best_actions]

        # Update policy network
        state_batch = np.array(state_batch)
        loss = self.policy_dqn.update(state_batch, action_batch, target_batch)

        # Update target network based on the target update period
        if self.current_training_timestep % self.upd_target == 0:
            self.target_dqn = deepcopy(self.policy_dqn)

        self.current_training_timestep += 1


    def get_state_dict(self):
       
        model_dict = {'policy_network': self.policy_dqn.dqn_l.state_dict(), 'target_network': self.target_dqn.dqn_l.state_dict()}
        return model_dict

    def load_networks(self, checkpoint):
       
        self.policy_dqn.dqn_l.load_state_dict(checkpoint['policy_network'])
        self.target_dqn.dqn_l.load_state_dict(checkpoint['target_network'])


In [182]:
import rlcard
from rlcard import models
from rlcard.agents import RandomAgent
from rlcard.utils import seeding, tournament
from rlcard.utils import Logger
import torch
import os

# Create environments
env = rlcard.make('limit-holdem', config={'seed': 0})
eval_env = rlcard.make('limit-holdem', config={'seed': 0})

# Set a global seed
seeding.create_seed(0)

# Play agressive game based on the version of choosing actual action
# Action with maximum value: 0
# Raise action instead of Call if possible: 1
# Raise action instead of Check if possible: 2
# Raise action instead of Fold if possible: 3
extra_action_version=1



# Create DQN agent
agent = DQN_agent(state_no=[72],
                  act_no=4, 
                  replay_memory_min_sample=1000,
                  training_period=10,
                  hidden_layers=[128, 128],
                  device=torch.device('cpu'),
                  extra_action_version=extra_action_version)

# Create random opponent agent
random_agent = RandomAgent(num_actions=len(eval_env.actions))

# Add the agent to the environments
env.set_agents([agent, random_agent])
eval_env.set_agents([agent, random_agent])



# Number of episodes, number of games during evaluation and evaluation in every N steps
episode_no, evaluate_games, evaluate_period = 100, 50, 10

for episode in range(episode_no):
    # Generate data from the environment
    trajectories, _ = env.run(is_training=True)

    # Feed transitions into agent memory, and train the agent
    for ts in trajectories[0]:
        agent.store_and_train(ts)
    reward=tournament(eval_env,evaluate_games)
    print(reward)
    

[3.67, -3.67]
[2.67, -2.67]
[2.59, -2.59]
[3.76, -3.76]
[3.23, -3.23]
[3.04, -3.04]
[2.77, -2.77]
[3.57, -3.57]
[3.51, -3.51]
[2.96, -2.96]
[2.5, -2.5]
[3.24, -3.24]
[2.55, -2.55]
[4.31, -4.31]
[2.42, -2.42]
[2.19, -2.19]
[3.04, -3.04]
[2.5, -2.5]
[2.83, -2.83]
[3.6, -3.6]
[2.69, -2.69]
[3.54, -3.54]
[2.96, -2.96]
[2.9, -2.9]
[2.65, -2.65]
[3.48, -3.48]
[2.75, -2.75]
[2.06, -2.06]
[3.1, -3.1]
[2.23, -2.23]
[3.11, -3.11]
[3.16, -3.16]
[2.14, -2.14]
[2.56, -2.56]
[1.94, -1.94]
[2.94, -2.94]
[3.51, -3.51]
[2.98, -2.98]
[3.42, -3.42]
[3.33, -3.33]
[3.24, -3.24]
[3.05, -3.05]
[2.27, -2.27]
[2.24, -2.24]
[1.53, -1.53]
[2.97, -2.97]
[3.51, -3.51]
[1.45, -1.45]
[1.96, -1.96]
[3.7, -3.7]
[3.18, -3.18]
[2.51, -2.51]
[3.13, -3.13]
[4.88, -4.88]
[3.73, -3.73]
[2.81, -2.81]
[2.62, -2.62]
[2.45, -2.45]
[3.05, -3.05]
[2.34, -2.34]
[2.74, -2.74]
[2.32, -2.32]
[3.28, -3.28]
[2.99, -2.99]
[3.45, -3.45]
[3.94, -3.94]
[2.8, -2.8]
[2.34, -2.34]
[4.12, -4.12]
[3.21, -3.21]
[3.43, -3.43]
[4.22, -4.22]
[3.31,