In [1]:
import rlcard
import collections
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from rlcard.utils import Logger, tournament
from rlcard.agents import RandomAgent
import pickle

In [2]:
random.seed(42)
learning_rate = 0.0005
gamma = 0.98
buffer_limit = 50000
batch_size = 32

In [3]:
class ReplayBuffer():
    def __init__(self):
        self.buffer = collections.deque(maxlen=buffer_limit)    # double-ended queue
    
    def put(self, transition):
        self.buffer.append(transition)

    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done_mask = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])

        return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
               torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float), \
               torch.tensor(done_mask_lst)

    def size(self):
        return len(self.buffer)

In [4]:
class Qnet(nn.Module):
    def __init__(self):
        super(Qnet, self).__init__()
        self.fc1 = nn.Linear(36, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 4)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
      
    def sample_action(self, obs, epsilon):
        out = self.forward(obs)
        coin = random.random()
        if coin < epsilon:
            return random.randint(0,3)
        else : 
            return out.argmax().item()   

In [5]:
def train(q, q_target, memory, optimizer,loss_list):
    for i in range(10):
        s,a,r,s_prime,done_mask = memory.sample(batch_size)
        q_out = q(s)
        

        q_a = q_out.gather(1,a)
        # DQN
        max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)

        target = r + gamma * max_q_prime * done_mask
        target = target.type(torch.FloatTensor)
        # MSE Loss
        loss = F.mse_loss(q_a, target)
        loss_list.append(loss)
        #print(i,loss)
        #print('-------------------------------')
        # Smooth L1 Loss
        #loss = F.smooth_l1_loss(q_a, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss_list

In [6]:
env=rlcard.make('leduc-holdem')
s=env.reset()
print(s)

actions_lst=['call','raise','fold','check']
'''for i in range(20):
    s=env.reset()
    print(s)
    print('--------------------------------------------------')
    #print(s)
    while (env.is_over() is False):
        print(s[1],env.get_state(s[1])['raw_obs'])
        print(s[0]['obs'])
        a=s[0]['legal_actions']
        #print(s)
        a=random.choice(list(a.keys()))
        print(s[1],actions_lst[a])
        s=env.step(a)
        
        
    print(env.get_payoffs())'''

'''s1=env.step(1)
print(s1)
s2=env.step(0)
print(s2)'''

({'legal_actions': OrderedDict([(0, None), (1, None), (2, None)]), 'obs': array([1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0.]), 'raw_obs': {'hand': 'HJ', 'public_card': None, 'all_chips': [1, 2], 'my_chips': 1, 'legal_actions': ['call', 'raise', 'fold'], 'current_player': 0}, 'raw_legal_actions': ['call', 'raise', 'fold'], 'action_record': []}, 0)


's1=env.step(1)\nprint(s1)\ns2=env.step(0)\nprint(s2)'

In [7]:
def main():
    env = rlcard.make('leduc-holdem')
    env.set_agents([RandomAgent(num_actions=env.num_actions)])
    q = Qnet()
    q_target = Qnet()
    q_target.load_state_dict(q.state_dict())
    memory = ReplayBuffer()
    loss_list = []
    print_interval = 20
    score = 0.0
    optimizer = optim.Adam(q.parameters(), lr=learning_rate)
    before = 0
    after = 0
    for n_epi in range(4000):
        epsilon = max(0.01, 0.08 - 0.01*(n_epi/200)) #Linear annealing from 8% to 1%
        s,id = env.reset()
        episodes = [[] for _ in range(2)]

        while not env.is_over():
            episodes[id].append(s)
            a=q.sample_action(torch.from_numpy(s['obs']).float(),epsilon)
            if a not in s['legal_actions'].keys():
                a=2
            s_prime, id_prime = env.step(a)
            episodes[id].append(a)
            episodes[id].append(s_prime)
            if env.is_over()==True:
                episodes[id].append(0.0)
            else:
                episodes[id].append(1.0)
            id=id_prime
            s=s_prime
        r = env.get_payoffs()
        for i in range(0,len(episodes[0]),4):
            state = episodes[0][i]['obs']
            action = episodes[0][i+1]
            reward = r[0]
            next_state = episodes[0][i+2]['obs']
            done_mask = episodes[0][i+3]
            memory.put((state,action,reward,next_state,done_mask))
        for i in range(0,len(episodes[1]),4):
            state = episodes[1][i]['obs']
            action = episodes[1][i+1]
            reward = r[1]
            next_state = episodes[1][i+2]['obs']
            done_mask = episodes[1][i+3]   
            memory.put((state,action,reward,next_state,done_mask))

            

        if memory.size()>2000:
            loss_list=train(q, q_target, memory, optimizer,loss_list)

        if n_epi%print_interval==0 and n_epi!=0:
            q_target.load_state_dict(q.state_dict())

            pay_0 = 0
            pay_1 = 0
            for i in range(10):
                s,id = env.reset()
                while not env.is_over():
                    if id == 0:
                        a=q.sample_action(torch.from_numpy(s['obs']).float(),epsilon)
                    else:
                        a=random.randint(0,3)
                        
                    if a not in s['legal_actions'].keys():
                        a=2
                    s_prime, id_prime = env.step(a)
                    s=s_prime
                    id = id_prime
                payoffs = env.get_payoffs()
                pay_0 = pay_0+payoffs[0]
                pay_1 = pay_1+payoffs[1]
            pay_0 = pay_0/10
            pay_1 = pay_1/10
            print("n_episode :{}, score_id_0 : {:.1f}, score_id_1 : {:.1f}, n_buffer : {}, eps : {:.1f}%".format(
                                                            n_epi, pay_0, pay_1, memory.size(), epsilon*100))
            
            if n_epi<=2000:
                before+=pay_0
            else:
                after += pay_0
    print(before,after)
    print(len(loss_list))
    with open('logs/loss_dqn.txt','wb') as f:
        pickle.dump(loss_list,f)
    torch.save(q,'logs/dqn_model.pt')


In [8]:
if __name__ == '__main__':
    main()

n_episode :20, score_id_0 : 0.2, score_id_1 : -0.2, n_buffer : 25, eps : 7.9%
n_episode :40, score_id_0 : -0.1, score_id_1 : 0.1, n_buffer : 46, eps : 7.8%
n_episode :60, score_id_0 : 0.3, score_id_1 : -0.3, n_buffer : 66, eps : 7.7%
n_episode :80, score_id_0 : -0.3, score_id_1 : 0.3, n_buffer : 86, eps : 7.6%
n_episode :100, score_id_0 : -0.1, score_id_1 : 0.1, n_buffer : 110, eps : 7.5%
n_episode :120, score_id_0 : -0.5, score_id_1 : 0.5, n_buffer : 133, eps : 7.4%
n_episode :140, score_id_0 : -0.2, score_id_1 : 0.2, n_buffer : 156, eps : 7.3%
n_episode :160, score_id_0 : -0.4, score_id_1 : 0.4, n_buffer : 176, eps : 7.2%
n_episode :180, score_id_0 : -0.1, score_id_1 : 0.1, n_buffer : 196, eps : 7.1%
n_episode :200, score_id_0 : 0.1, score_id_1 : -0.1, n_buffer : 217, eps : 7.0%
n_episode :220, score_id_0 : -0.6, score_id_1 : 0.6, n_buffer : 237, eps : 6.9%
n_episode :240, score_id_0 : -0.3, score_id_1 : 0.3, n_buffer : 257, eps : 6.8%
n_episode :260, score_id_0 : -0.2, score_id_1 : 

  return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \


n_episode :1960, score_id_0 : 0.8, score_id_1 : -0.8, n_buffer : 2047, eps : 1.0%
n_episode :1980, score_id_0 : 0.1, score_id_1 : -0.1, n_buffer : 2113, eps : 1.0%
n_episode :2000, score_id_0 : 0.5, score_id_1 : -0.5, n_buffer : 2173, eps : 1.0%
n_episode :2020, score_id_0 : 0.3, score_id_1 : -0.3, n_buffer : 2231, eps : 1.0%
n_episode :2040, score_id_0 : 0.2, score_id_1 : -0.2, n_buffer : 2299, eps : 1.0%
n_episode :2060, score_id_0 : 1.1, score_id_1 : -1.1, n_buffer : 2367, eps : 1.0%
n_episode :2080, score_id_0 : 0.3, score_id_1 : -0.3, n_buffer : 2424, eps : 1.0%
n_episode :2100, score_id_0 : 0.8, score_id_1 : -0.8, n_buffer : 2489, eps : 1.0%
n_episode :2120, score_id_0 : 0.8, score_id_1 : -0.8, n_buffer : 2534, eps : 1.0%
n_episode :2140, score_id_0 : 0.6, score_id_1 : -0.6, n_buffer : 2577, eps : 1.0%
n_episode :2160, score_id_0 : 0.6, score_id_1 : -0.6, n_buffer : 2630, eps : 1.0%
n_episode :2180, score_id_0 : 0.1, score_id_1 : -0.1, n_buffer : 2678, eps : 1.0%
n_episode :2200,

n_episode :3960, score_id_0 : 0.3, score_id_1 : -0.3, n_buffer : 6413, eps : 1.0%
n_episode :3980, score_id_0 : 0.5, score_id_1 : -0.5, n_buffer : 6453, eps : 1.0%
-16.750000000000004 60.39999999999997
20560
