In [1]:
import rlcard
import collections
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from rlcard.utils import Logger, tournament
from rlcard.agents import RandomAgent
import pickle

In [2]:
# Reproducibility: seed every RNG this notebook draws from. Python's `random`
# drives epsilon-greedy exploration and replay sampling; torch's generator
# drives nn.Linear weight initialisation, so it must be seeded as well.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# DQN hyperparameters.
learning_rate = 0.0005   # presumably the optimizer step size (optimizer is built outside the visible cells)
gamma = 0.98             # discount factor applied to the bootstrapped Q-value in train()
buffer_limit = 50000     # maximum number of transitions ReplayBuffer keeps
batch_size = 32          # transitions sampled per gradient step in train()

In [3]:
class ReplayBuffer():
    """Bounded FIFO store of transitions with uniform random mini-batch sampling."""

    def __init__(self):
        # A deque with maxlen drops its oldest entry automatically once
        # `buffer_limit` transitions are stored.
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        """Append one ``(s, a, r, s_prime, done_mask)`` tuple to the buffer."""
        self.buffer.append(transition)

    def sample(self, n):
        """Draw ``n`` distinct stored transitions and return them column-wise.

        Returns a 5-tuple of tensors ``(s, a, r, s_prime, done_mask)``.
        States are cast to float; actions, rewards and done masks are each
        wrapped in a one-element list so they come out with a trailing
        dimension of size 1, and their dtype is inferred by ``torch.tensor``.

        Raises ``ValueError`` (from ``random.sample``) when ``n`` exceeds the
        number of stored transitions.
        """
        picked = random.sample(self.buffer, n)

        # Each comprehension unpacks all five fields, so a malformed
        # transition still fails loudly instead of being silently truncated.
        states = [s for s, _, _, _, _ in picked]
        actions = [[a] for _, a, _, _, _ in picked]
        rewards = [[r] for _, _, r, _, _ in picked]
        next_states = [s_prime for _, _, _, s_prime, _ in picked]
        done_masks = [[done_mask] for _, _, _, _, done_mask in picked]

        return (
            torch.tensor(states, dtype=torch.float),
            torch.tensor(actions),
            torch.tensor(rewards),
            torch.tensor(next_states, dtype=torch.float),
            torch.tensor(done_masks),
        )

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

In [4]:
class Qnet(nn.Module):
    """Fully connected Q-network: 36 input features -> 128 -> 128 -> 4 action values."""

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is: they define the
        # state_dict keys and the order in which weights are initialised.
        self.fc1 = nn.Linear(36, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 4)

    def forward(self, x):
        """Return one Q-value per action (last dimension of size 4)."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        return self.fc3(hidden)

    def sample_action(self, obs, epsilon):
        """Pick an action index epsilon-greedily.

        With probability ``epsilon`` a uniform random index in ``[0, 3]`` is
        returned; otherwise the argmax over the network output for ``obs``.
        """
        q_values = self.forward(obs)
        explore = random.random() < epsilon
        return random.randint(0, 3) if explore else q_values.argmax().item()

In [5]:
def train(q, q_target, memory, optimizer, loss_list, n_updates=10):
    """Run ``n_updates`` DQN gradient steps on ``q`` using batches from ``memory``.

    Args:
        q: online Q-network being optimised; called on the state batch.
        q_target: network used to evaluate next states for the bootstrap
            target. It is only read here, never updated.
        memory: object whose ``sample(batch_size)`` returns the 5-tuple
            ``(s, a, r, s_prime, done_mask)`` of tensors. ``a`` is used as a
            ``gather`` index along dim 1, so it must be an int64 tensor of
            shape (batch, 1).
        optimizer: torch optimizer stepped once per update (presumably built
            over ``q``'s parameters -- it is constructed by the caller).
        loss_list: list that receives one loss tensor per update.
        n_updates: number of mini-batch updates to perform (default 10,
            the value that was previously hard-coded).

    Returns:
        The same ``loss_list`` object, extended in place.
    """
    for _ in range(n_updates):
        s, a, r, s_prime, done_mask = memory.sample(batch_size)

        # Q(s, a) for the action actually taken in each sampled transition.
        q_a = q(s).gather(1, a)

        # The bootstrap target is a constant of the optimisation. Building it
        # under no_grad avoids keeping a graph through q_target and stops
        # backward() from accumulating gradients in q_target's parameters.
        with torch.no_grad():
            max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
            # done_mask scales the bootstrap term (presumably 0 on terminal
            # transitions -- set by the caller). .float() matches q_a's dtype
            # without forcing the tensor onto the CPU.
            target = (r + gamma * max_q_prime * done_mask).float()

        loss = F.mse_loss(q_a, target)
        # Record a detached tensor: keeping the live loss would retain the
        # autograd graph of every update for as long as loss_list exists.
        loss_list.append(loss.detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss_list

In [35]:
class DuelingQnet(nn.Module):
    """Dueling Q-network: Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a').

    A shared 36 -> 128 layer feeds two streams: a value stream ending in a
    single scalar and an advantage stream ending in 4 per-action outputs.
    """

    def __init__(self):
        super(DuelingQnet, self).__init__()
        self.fc1 = nn.Linear(36, 128)
        self.fc_value = nn.Linear(128, 128)
        self.fc_adv = nn.Linear(128, 128)
        self.value = nn.Linear(128, 1)
        self.adv = nn.Linear(128, 4)

    def forward(self, x):
        """Return Q-values with a last dimension of size 4.

        Accepts a single observation of shape (36,) or a batch of shape
        (batch, 36); the output has the matching leading dimensions.
        """
        x = F.relu(self.fc1(x))
        v = self.value(F.relu(self.fc_value(x)))
        a = self.adv(F.relu(self.fc_adv(x)))
        # Centre the advantages per observation, i.e. over the action axis
        # only. A mean over the whole tensor would also average across the
        # batch, making one sample's Q-values depend on the other samples.
        # For a single unbatched observation the two are identical.
        a_avg = a.mean(dim=-1, keepdim=True)
        q = v + a - a_avg
        return q

    def sample_action(self, obs, epsilon):
        """Pick an action index epsilon-greedily.

        With probability ``epsilon`` a uniform random index in ``[0, 3]`` is
        returned; otherwise the argmax over the network output for ``obs``.
        """
        out = self.forward(obs)
        coin = random.random()
        if coin < epsilon:
            return random.randint(0,3)
        else :
            return out.argmax().item()


In [38]:
def main():
    """Evaluate a saved model as player 0 of 'leduc-holdem' and print payoffs.

    Runs 10 rounds of 100 episodes each. Player 0 acts through
    ``model.sample_action`` with a small epsilon; any other player's move is
    drawn with ``random.randint(0, 3)``. The loop drives ``env.step``
    directly rather than going through the agent registered with
    ``set_agents``. An action that is not among the current legal actions is
    replaced by action id 2 (presumably "fold" in RLCard's Leduc encoding --
    confirm). After each round the summed payoffs of both seats are printed.
    """
    env = rlcard.make('leduc-holdem')
    env.set_agents([RandomAgent(num_actions=env.num_actions)])

    # Available checkpoints: logs/double_dqn_model.pt, logs/dueling_dqn_model.pt
    model = torch.load('logs/dueling_dqn_model.pt')
    model.eval()

    for _ in range(10):
        totals = [0, 0]
        for episode in range(100):
            # Over 100 episodes this stays between roughly 0.075 and 0.08,
            # so the 0.01 floor is never reached.
            epsilon = max(0.01, 0.08 - 0.01*(episode/200))
            state, player = env.reset()
            while not env.is_over():
                if player == 0:
                    obs = torch.from_numpy(state['obs']).float()
                    action = model.sample_action(obs, epsilon)
                else:
                    action = random.randint(0,3)

                legal = state['legal_actions'].keys()
                if action not in legal:
                    action = 2
                state, player = env.step(action)

            payoffs = env.get_payoffs()
            totals[0] += payoffs[0]
            totals[1] += payoffs[1]

        print("score_id_0 : {:.1f}, score_id_1 : {:.1f}".format(totals[0], totals[1]))



In [39]:
# Entry point: run the evaluation only when this code is executed directly
# (script or notebook cell), not when it is imported as a module.
if __name__ == '__main__':
    main()

score_id_0 : 22.0, score_id_1 : -22.0
score_id_0 : 11.0, score_id_1 : -11.0
score_id_0 : 28.0, score_id_1 : -28.0
score_id_0 : 25.5, score_id_1 : -25.5
score_id_0 : 29.0, score_id_1 : -29.0
score_id_0 : 12.5, score_id_1 : -12.5
score_id_0 : 33.0, score_id_1 : -33.0
score_id_0 : 37.5, score_id_1 : -37.5
score_id_0 : 13.0, score_id_1 : -13.0
score_id_0 : 0.0, score_id_1 : 0.0
