In [None]:
"""Training PPO"""

In [None]:
from utils import *
import copy
import random
import pandas as pd
from torch import nn
import torch
from torch.optim import Adam
from torch.distributions.categorical import Categorical
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import os

In [48]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [49]:
torch.manual_seed(1)
random.seed(1)

# Model

## Gen 3 - 3.5

In [50]:
class ppo(nn.Module):
    def __init__(self, n_player, n_card = 33):
        super().__init__()
        self.n_action = 2
        self.n_card = n_card
        self.n_player = n_player
        self.n_param_per_player = self.n_card + 1 # 33 cards + 1 number of chips
        self.n_state_param = self.n_card*2 + 5 # 33 for flipped card, 33 for remain card, 1 for chip in pot, 1 for number of cards remained, 1 for good card self, 1 for good card other, 1 for chipinpot/current
        self.input_dim = self.n_player*self.n_param_per_player + self.n_state_param
        
        self.policy = nn.Sequential(
            nn.Linear(self.input_dim, 256),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(256, 128),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(128, self.n_action)
            )
        
        self.value = nn.Sequential(
            nn.Linear(self.input_dim, 256),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(256, 128),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(128, 1)
            )
    
    def get_policy(self, X, legal_move_mask):
        """Mask the legal output
        legal_move_mask: boolean tensor, True for masked"""
        logit = self.policy(X)
        logit_masked = logit.masked_fill(legal_move_mask, float('-inf'))
        return logit_masked

    def forward(self, X, legal_move_mask, action = None):
        """Get value, probability
        legal_move_mask: boolean tensor
        action: tensor(1) Integer. This is the old sampled action. If none will do sampling
        """
        logit = self.get_policy(X, legal_move_mask)
        prob = Categorical(logits = logit)
        if action == None:
            action = prob.sample() # sample the action
        log_prob = prob.log_prob(action) # this will be used for surrogate loss (log(a) - log(b) = log(a/b))
        value = self.value(X)

        return action, log_prob, prob.entropy(), value # sampled action, log probability of it, its entropy,value from value network



## Gen 4

In [51]:
# Gen 4 - collapsing opponents

class ppo_gen_4(ppo):
    def __init__(self, n_player, n_card = 33):
        super().__init__(n_player, n_card)
        # self.n_action = 2
        # self.n_card = n_card
        # self.n_player = n_player
        self.n_param_per_player = self.n_card + 1 # 33 cards + 1 number of chips
        self.n_state_param = self.n_card*2 + 5 # 33 for flipped card, 33 for remain card, 1 for chip in pot, 1 for number of cards remained, 1 for good card self, 1 for good card other, 1 for chipinpot/current
        self.input_dim = 2*self.n_param_per_player + self.n_state_param # 2 because 1 for self, 1 for opponents
        
        self.policy = nn.Sequential(
            nn.Linear(self.input_dim, 256),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(256, 128),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(128, self.n_action)
            )
        
        self.value = nn.Sequential(
            nn.Linear(self.input_dim, 256),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(256, 128),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(128, 1)
            )

## Gen 5

In [52]:
class Flatten_custom(nn.Module):
    def __init__(self, start_dim_batch: int = 1, start_dim_unbatch: int = 0, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim_batch = start_dim_batch
        self.start_dim_unbatch = start_dim_unbatch
        self.end_dim = end_dim
    def forward(self, x):
        if len(x.shape) == 4:
            return x.flatten(self.start_dim_batch, self.end_dim)
        else:
            return x.flatten(self.start_dim_unbatch, self.end_dim)

flatten_custom  = Flatten_custom()

In [53]:
class ppo_gen_5(ppo):
    def __init__(self, n_player, n_card = 33):
        super().__init__(n_player, n_card)
        self.in_channel = n_player + 2 # 1 for flipped card, 1 for remaining cards
        self.out_channel = 16
        self.n_state_param = n_player + 5 # 1 for chip in pot, 1 for number of cards remained, 1 for good card self, 1 for good card other, 1 for chipinpot/current
        self.flatten_dimension = 512 # hard code
        self.gen = 5
        self.flatten_custom = Flatten_custom()

        
        self.cnn_policy = nn.Sequential(
            nn.Conv2d(
            in_channels = self.in_channel,
            out_channels = self.out_channel,
            kernel_size = (1, 3)
            ),
            nn.LeakyReLU(negative_slope=0.01),
            Flatten_custom()
        )

        self.linear_state_policy = nn.Sequential(
            nn.Linear(self.n_state_param, 16),
            nn.LeakyReLU(negative_slope=0.01),
        )
        
        self.ff_policy = nn.Sequential(
            nn.Linear(self.flatten_dimension, 256),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(256, 128),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(128, self.n_action)
            )
        
        self.cnn_value = nn.Sequential(
            nn.Conv2d(
            in_channels = self.in_channel,
            out_channels = self.out_channel,
            kernel_size = (1, 3)
            ),
            nn.LeakyReLU(negative_slope=0.01),
            Flatten_custom()
        )

        self.linear_state_value = nn.Linear(self.n_state_param, 16)

        self.ff_value = nn.Sequential(
            nn.Linear(self.flatten_dimension, 256),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(256, 128),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(128, 1)
            )
        

    def forward_concat(self, x_card, x_state, cnn_layer, linear_layer):
        x_card_flat = cnn_layer(x_card)
        x_state_flat = linear_layer(x_state)
        if len(x_card_flat.shape) == 1:
            dim = 0
        else:
            dim = 1
        return torch.cat([x_card_flat, x_state_flat], dim = dim)

    def get_policy(self, x_card, x_state, legal_move_mask):
        """Mask the legal output
        legal_move_mask: boolean tensor, True for masked"""
        policy_concat = self.forward_concat(x_card, x_state, self.cnn_policy, self.linear_state_policy) # flattened + concat
        logit = self.ff_policy(policy_concat)
        logit_masked = logit.masked_fill(legal_move_mask, float('-inf'))
        return logit_masked

    def get_value(self, x_card, x_state):
        value_concat = self.forward_concat(x_card, x_state, self.cnn_value, self.linear_state_value) # flattened + concat
        value = self.ff_value(value_concat)
        return value

    def forward(self, x_card, x_state, legal_move_mask, action = None):
        """Get value, probability
        legal_move_mask: boolean tensor
        action: tensor(1) Integer. This is the old sampled action. If none will do sampling
        """
        logit = self.get_policy(x_card, x_state, legal_move_mask)
        prob = Categorical(logits = logit)
        if action == None:
            action = prob.sample() # sample the action
        log_prob = prob.log_prob(action) # this will be used for surrogate loss (log(a) - log(b) = log(a/b))

        value = self.get_value(x_card, x_state)
        
        return action, log_prob, prob.entropy(), value # sampled action, log probability of it, its entropy,value from value network
        


# Rollout

In [54]:
class nothanks_ppo(game):
    def __init__(self, card = None):
        super().__init__(card)
        self.move_encode = {0: 'pass',
                            1: 'take'
                            }
        
    def rotate_player(self, turn):
        player_list = list(range(self.n_player))
        return player_list[turn:] + player_list[:turn]
        
        
    def get_state(self):
        # Get info about the state to save it
        player_info = []
        for player in self.players:
            player_info.append((player.card, player.chip))
        return player_info, self.turn, self.remain_card, self.chip_in_pot, self.current_card

    def get_state_gen_3_5(self):
        """rotate player"""
        player_info = []
        player_list = self.rotate_player(self.turn)
        for player_index in player_list:
            player = self.players[player_index]
            player_info.append((player.card, player.chip))
        return player_info, self.turn, self.remain_card, self.chip_in_pot, self.current_card

    def get_state_gen_4(self):
        """Rotate player + Collapse opponents"""
        player_info = []
        player_list = self.rotate_player(self.turn)

        # self
        player_index = player_list[0]
        player = self.players[player_index]
        player_info.append((player.card, player.chip))

        #opponent
        opponent_card = []
        min_chip = 100
        for player_index in player_list[1:]:
            player = self.players[player_index]
            opponent_card.extend(player.card)
            if player.chip < min_chip:
                min_chip = player.chip
        player_info.append((opponent_card, min_chip))
        return player_info, self.turn, self.remain_card, self.chip_in_pot, self.current_card
        
    
    def encode_card(self, card_list: list) -> list:
        """Encode the card list to binaries"""
        encode = [0]* len(self.full_deck)
        for card in card_list:
            encode[card - self.min_card] = 1
        return encode
    
    # def encode_turn(self, turn) -> list:
    #     return [1 if i == turn else 0 for i in range(self.n_player)]

    def check_favorable_self(self):
        player_tmp = self.players[self.turn]
        if any(abs(self.current_card - card_tmp) == 1 for card_tmp in player_tmp.card):
            return 1
        else:
            return 0

    def check_favorable_other(self):
        other_player = [player_tmp for index, player_tmp in enumerate(self.players) if index != self.turn]
        check = []
        for player_tmp in other_player:
            if any(abs(self.current_card - card_tmp) == 1 for card_tmp in player_tmp.card):
                check.append(1)
            else:
                check.append(0)
        return max(check)
        
    
    def encode_state(self, func):
        """Feature engineering here"""
        player_info, turn, remain_card, chip_in_pot, current_card = func()
        result = []
        for player_card, chip in player_info:
            chip_tmp = chip/max(self.full_deck)
            card_tmp = self.encode_card(player_card)
            
            result.extend(card_tmp)
            result.append(chip_tmp)
        
        # result.extend(self.encode_turn(self.turn))
        result.extend(self.encode_card([current_card]))
        result.append(chip_in_pot/max(self.full_deck))
        result.extend(self.encode_card(remain_card))
        result.append((len(self.remain_card) - self.n_remove_card)/(len(self.full_deck) - self.n_remove_card))
        result.append(self.check_favorable_self())

        #new
        result.append(self.check_favorable_other())
        result.append(chip_in_pot/self.current_card)
        # player_a, chip_a, ..., player_n, chip_n, turn, current_card, chip, remain_card, n_legal_remain_card, good card self, good card opponent, chip_in_pot/current_card
        return result

    def encode_state_gen_5(self):
        """Feature engineering here"""
        # player_info, turn, remain_card, chip_in_pot, current_card = self.get_state()
        player_info, turn, remain_card, chip_in_pot, current_card = self.get_state_gen_3_5()
        

        x_card = [self.encode_card(player_card) for player_card, _ in player_info]
        x_card.append(self.encode_card([current_card]))
        x_card.append(self.encode_card(remain_card))
        
        x_state = [chip/max(self.full_deck) for _, chip in player_info]
        x_state.append(chip_in_pot/max(self.full_deck))
        x_state.append((len(self.remain_card) - self.n_remove_card)/(len(self.full_deck) - self.n_remove_card))
        x_state.append(self.check_favorable_self())
        x_state.append(self.check_favorable_other())
        x_state.append(chip_in_pot/self.current_card)
        
        # player_a, chip_a, ..., player_n, chip_n, turn, current_card, chip, remain_card, n_legal_remain_card, good card self, good card opponent, chip_in_pot/current_card
        return x_card, x_state

    
    def calculate_reward_2(self, action):
        player_tmp = self.players[self.turn]
        
        if action == 'pass':
            # pass over half of the card value and the card is favorable is bad, punish for being too greedy
            if any(abs(self.current_card - card_tmp) == 1 for card_tmp in player_tmp.card):
                if self.chip_in_pot >= self.current_card//2:
                    return -3
            return -0.2  # light discouragement to avoid infinite pass
    
        if action == 'take':
            
            # Penalize taking too late or too early
            if player_tmp.chip == 0:
                return -2
            if self.chip_in_pot == 0:
                return -2
                
            reward = 0

            # Reward for taking early in the game
            if self.chip_in_pot < self.current_card and len(player_tmp.card) < 4:
                reward += (self.chip_in_pot / (self.current_card + 1)) * 3
    
            # Encourage sequential cards
            if any(abs(self.current_card - card_tmp) <= 3 for card_tmp in player_tmp.card):
                reward += 2
    
            # Penalty for taking later in the game
            distance_threshold = 4
            if len(player_tmp.card) > 4:
                dist = min(abs(self.current_card - c) for c in player_tmp.card)
                if dist > distance_threshold:
                    reward -= (dist - distance_threshold) * 0.5
            return reward
    
    def calculate_reward_3(self, action):
        return 0
                
    def reward_func(self, move):
        return self.calculate_reward_3(move)
                
    def rollout(self, model, func, n_game = None):
        """Play games, save state
        Need to get the turn
        """
        random_chance = 0.99
        # FIX: need to send 1 terminal state for each player:
        playing_buffer = {i: [] for i in range(self.n_player)}
        i = 0
        while self.is_continue:
            current_state = torch.tensor(self.encode_state(func)).to(device)
            legal_move = self.get_legal_action() # a list 
            legal_move_mask = torch.tensor([False if move in legal_move else True for move in self.move_encode.values()]).to(device)
            random_move = None
            if random.random() > random_chance:
                random_move = torch.tensor(random.choice([0 if move == 'pass' else 1 for move in legal_move])).to(device)
            with torch.no_grad():
                move_raw, log_prob, entropy, value = model.forward(current_state, legal_move_mask, random_move)
            move = self.move_encode.get(move_raw.item())
            reward = torch.tensor([self.reward_func(move)]).to(device)
            playing_buffer[self.turn].append([current_state, move_raw, legal_move_mask, log_prob, value, reward]) # if this change, need to change the hard code
            # move
            self.is_continue = self.action(move)
        final_reward = self.calculate_ranking()

        for player in range(self.n_player):
            playing_buffer[player].append([None, None, None, None, None, final_reward[player]])  # if this change, need to change the hard code
            # playing_buffer[player].append([None, None, None, None, None, -self.players[player].calculate_score() ])  # if this change, need to change the hard code
        return playing_buffer

    def rollout_gen_5(self, model, n_game = None):
        """Play games, save state
        Need to get the turn
        """
        random_chance = 0.99
        # FIX: need to send 1 terminal state for each player:
        playing_buffer = {i: [] for i in range(self.n_player)}
        i = 0
        while self.is_continue:
            
            x_card, x_state = self.encode_state_gen_5()
            # print(torch.Tensor(x_card).shape)
            x_card = torch.Tensor(x_card).unsqueeze(1).to(device) #1 for 5,33 -> 5,1,33 | 0 for 5,1,33 -> 1,5,1,33
            x_state = torch.tensor(x_state).to(device) # 8 to 1,8
            # print(x_card.shape)
            # print(x_state.shape)
            legal_move = self.get_legal_action() # a list 
            legal_move_mask = torch.tensor([False if move in legal_move else True for move in self.move_encode.values()]).to(device)
            # print(legal_move_mask.shape)
            random_move = None
            if random.random() > random_chance:
                random_move = torch.tensor(random.choice([0 if move == 'pass' else 1 for move in legal_move])).to(device)
            with torch.no_grad():
                move_raw, log_prob, entropy, value = model.forward(x_card, x_state, legal_move_mask, random_move)
            move = self.move_encode.get(move_raw.item())
            reward = torch.tensor([self.reward_func(move)]).to(device)
            playing_buffer[self.turn].append([x_card,x_state, move_raw, legal_move_mask, log_prob, value, reward]) # if this change, need to change the hard code
            # move
            self.is_continue = self.action(move)
        final_reward = self.calculate_ranking()

        for player in range(self.n_player):
            playing_buffer[player].append([None, None, None, None, None, None, final_reward[player]])  # if this change, need to change the hard code
            # playing_buffer[player].append([None, None, None, None, None, -self.players[player].calculate_score() ])  # if this change, need to change the hard code
        return playing_buffer

# Calculate targets & create training data

In [55]:
# Gen 4

def create_training_data(play_buffer, gamma = 0.99, reward_index = -1):
    """
    gamma: discount constant
    reward_index: index of reward returned in the play buffer
    """
    training_data = []
    for _, player_data in play_buffer.items():
        discounted_reward = player_data[-1][reward_index] # take the reward of the last stage. At the end of the game, all players receive 1 more step containing the final reward (final rank)
        for index in reversed(range(len(player_data) - 1)): # go from back to front, skip the final step

            state_tmp, move_tmp, legal_move_mask, log_prob_tmp, value_tmp, reward_tmp = player_data[index]
            
            # discounted reward
            discounted_reward = reward_tmp + discounted_reward*gamma

            #advantage
            advantage_tmp = discounted_reward - value_tmp
            # advantage_target = [advantage_tmp] + advantage_target # need to detach this at policy loss

            # policy_old = [policy_tmp] + policy_old # need to detach this at policy loss
            training_data.append([state_tmp, move_tmp, legal_move_mask, log_prob_tmp, advantage_tmp, discounted_reward]) #state, action, sampled action, advantage, discounted_reward (aka return)
    return training_data

In [56]:
# Gen 5
def create_training_data_gen_5(play_buffer, gamma = 0.99, reward_index = -1):
    """
    gamma: discount constant
    reward_index: index of reward returned in the play buffer
    """
    training_data = []
    for _, player_data in play_buffer.items():
        discounted_reward = player_data[-1][reward_index] # take the reward of the last stage. At the end of the game, all players receive 1 more step containing the final reward (final rank)
        for index in reversed(range(len(player_data) - 1)): # go from back to front, skip the final step

            x_card_tmp, x_state_tmp, move_tmp, legal_move_mask, log_prob_tmp, value_tmp, reward_tmp = player_data[index]
            
            # discounted reward
            discounted_reward = reward_tmp + discounted_reward*gamma

            #advantage
            advantage_tmp = discounted_reward - value_tmp
            # advantage_target = [advantage_tmp] + advantage_target # need to detach this at policy loss

            # policy_old = [policy_tmp] + policy_old # need to detach this at policy loss
            training_data.append([x_card_tmp, x_state_tmp, move_tmp, legal_move_mask, log_prob_tmp, advantage_tmp, discounted_reward]) #state, action, sampled action, advantage, discounted_reward (aka return)
    return training_data

# Data loader

In [57]:
class dataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]

In [58]:
def collate_fn(batch_data):
    result = []
    for index, item in enumerate(zip(*batch_data)):
        # result.append(torch.stack(item).unsqueeze(dim = 1))
        result.append(torch.stack(item))
    return result

In [59]:
params = {'batch_size': 4,
          'shuffle': True,
          'collate_fn': collate_fn
          }

# Hyper-parameter tuning

In [60]:
N_TRIAL = 10
N_CYCLE = 20
N_PLAYER = 3
BATCH_SIZE = 64
DATA_LENGTH = 5000
N_EPOCH = 4
GEN = 5
LEARNING_RATE_list = [2e-4, 5e-4, 7e-4, 1e-3]
e_list = [0.1, 0.2, 0.3] #clipping constant
value_coef_list = [0.5, 0.6, 0.7]
entropy_coef_list = [0.01, 0.02, 0.03]
param_list = [LEARNING_RATE_list, e_list, value_coef_list, entropy_coef_list]
param_record = []
randomized_record = []

dataloader_params = {'batch_size': BATCH_SIZE,
          'shuffle': True,
          'drop_last': True, # drop the last batch where the size could be 1
          'collate_fn': collate_fn
          }

In [None]:
for trial in range(N_TRIAL):
    if GEN in [3, 3.5]:
        model = ppo(N_PLAYER).to(device)
    elif GEN == 4:
        model = ppo_gen_4(N_PLAYER).to(device)
    elif GEN == 5:
        model = ppo_gen_5(N_PLAYER).to(device)

    while True:
        random_param_index = [random.choice(range(len(a))) for a in param_list]
        if random_param_index not in randomized_record:
            break
    randomized_record.append(random_param_index)
    
    LEARNING_RATE, e, value_coef, entropy_coef = [a[b] for a,b in zip(param_list, random_param_index)]
    
    optimizer = Adam(model.parameters(), lr = LEARNING_RATE)
    param_record.append(
        {'learning_rate': LEARNING_RATE,
        'e': e,
        'value_coef': value_coef,
        'entropy_coef': entropy_coef,
       'index': trial
        }
    )
    
    print(f'learning_rate {LEARNING_RATE} | e: {e} | value_coef: {value_coef} | entropy_coef: {entropy_coef}')  
    for cycle in range(N_CYCLE):
        print(f'-------------CYCLE: {cycle}-------------')
    
        # Rollout
        training_data = []
        game_length = []
        while len(training_data) <= DATA_LENGTH:
            nothanks = nothanks_ppo()
            if GEN == 3:
                play_buffer = nothanks.rollout(model, nothanks.get_state)
                training_data_tmp = create_training_data(play_buffer)
            elif GEN == 3.5:
                play_buffer = nothanks.rollout(model, nothanks.get_state_gen_3_5)
                training_data_tmp = create_training_data(play_buffer)
            elif GEN == 5:
                play_buffer = nothanks.rollout_gen_5(model)
                training_data_tmp = create_training_data_gen_5(play_buffer)

            game_length.append(len(training_data_tmp))
            training_data.extend(training_data_tmp)
        print(f'game length: {round(np.mean(game_length),0)}')
        
        # create a new data loader
        train_data = dataset(training_data)
        dataloader = DataLoader(train_data, **dataloader_params)
        
        # Train
        for epoch in range(N_EPOCH):
            loss_record = 0
            for index, input_data in enumerate(dataloader):
                if GEN == 5:
                    x_card_tmp, x_state_tmp, move_tmp, legal_move_mask, log_prob_old, advantage_tmp, discounted_reward = input_data
                else:
                    state_tmp, move_tmp, legal_move_mask, log_prob_old, advantage_tmp, discounted_reward = input_data
                optimizer.zero_grad()
                if GEN == 5:
                    _, log_prob_new, entropy, value_new = model.forward(x_card_tmp, 
                                                                        x_state_tmp,
                                                                        legal_move_mask= legal_move_mask,
                                                                        action = move_tmp)
                else:
                    _, log_prob_new, entropy, value_new = model.forward(X = state_tmp, 
                                                            legal_move_mask= legal_move_mask,
                                                            action = move_tmp)
                # policy loss
                # advantage_norm_tmp = (advantage_tmp - advantage_tmp.mean())/(advantage_tmp.std() + 1e-8)
                ratio = torch.exp(log_prob_new - log_prob_old).unsqueeze(dim = 1) # pi_new/pi_old
                surrogate_1 = ratio*advantage_tmp
                ratio_clamp = torch.clamp(ratio, 1 - e, 1 + e) # clipped ratio
                surrogate_2 = ratio_clamp*advantage_tmp
                policy_loss = -torch.min(surrogate_1, surrogate_2).mean() #FIX: is this adv detached? yes

                # value loss
                value_loss = ((value_new - discounted_reward)**2).mean()
    
                # entropy loss: to encourage exploration
                entropy_loss = entropy.mean()
                loss = policy_loss + value_coef * value_loss - entropy_coef * entropy_loss
                loss_record += loss.item()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm = 0.5) # FIX: check this
                optimizer.step()
    
            print(f'Epoch: {epoch} - loss: {round(loss_record/index,2)}')
    torch.save(model.state_dict(), f'./ppo_weight/hyperparam_tuning_{trial}.pth')

learning_rate 0.0005 | e: 0.3 | value_coef: 0.5 | entropy_coef: 0.02
-------------CYCLE: 0-------------
game length: 45.0
tensor(-2.6808, grad_fn=<NegBackward0>) tensor(216.7564, grad_fn=<MeanBackward0>) tensor(0.6900, grad_fn=<MeanBackward0>)
tensor(0.9866, grad_fn=<NegBackward0>) tensor(205.5308, grad_fn=<MeanBackward0>) tensor(0.6918, grad_fn=<MeanBackward0>)
tensor(3.0439, grad_fn=<NegBackward0>) tensor(208.8417, grad_fn=<MeanBackward0>) tensor(0.6927, grad_fn=<MeanBackward0>)
tensor(3.2691, grad_fn=<NegBackward0>) tensor(226.3994, grad_fn=<MeanBackward0>) tensor(0.6930, grad_fn=<MeanBackward0>)
tensor(1.3386, grad_fn=<NegBackward0>) tensor(248.2862, grad_fn=<MeanBackward0>) tensor(0.6931, grad_fn=<MeanBackward0>)
tensor(0.9901, grad_fn=<NegBackward0>) tensor(227.9572, grad_fn=<MeanBackward0>) tensor(0.6930, grad_fn=<MeanBackward0>)
tensor(4.2138, grad_fn=<NegBackward0>) tensor(204.7685, grad_fn=<MeanBackward0>) tensor(0.6925, grad_fn=<MeanBackward0>)
tensor(1.4693, grad_fn=<NegBac

KeyboardInterrupt: 

# Training

In [62]:
BATCH_SIZE = 64
LEARNING_RATE = 5e-4
N_PLAYER = 3
DATA_LENGTH = 5000
N_CYCLE = 100
N_EPOCH = 4
GEN = 5
e = 0.3 #clipping constant
value_coef = 0.7
entropy_coef = 0.03

if GEN in [3, 3.5]:
    model = ppo(N_PLAYER).to(device)
elif GEN == 4:
    model = ppo_gen_4(N_PLAYER).to(device)
elif GEN == 5:
    model = ppo_gen_5(N_PLAYER).to(device)

# model.load_state_dict(torch.load('./weight/hyperparam_tuning_0.pth'))

optimizer = Adam(model.parameters(), lr = LEARNING_RATE)

In [63]:
dataloader_params = {'batch_size': BATCH_SIZE,
          'shuffle': True,
          'drop_last': True, # drop the last batch where the size could be 1
          'collate_fn': collate_fn
          }

In [64]:
loss_list = []
game_length_list = []

In [65]:
train_flag = 1

for cycle in range(N_CYCLE):
    print(f'-------------CYCLE: {cycle}-------------')

    # Rollout
    training_data = []
    game_length = []
    while len(training_data) <= DATA_LENGTH:
        nothanks = nothanks_ppo()

        if GEN == 3:
            play_buffer = nothanks.rollout(model, nothanks.get_state)
            training_data_tmp = create_training_data(play_buffer)
        elif GEN == 3.5:
            play_buffer = nothanks.rollout(model, nothanks.get_state_gen_3_5)
            training_data_tmp = create_training_data(play_buffer)
        elif GEN == 5:
            play_buffer = nothanks.rollout_gen_5(model)
            training_data_tmp = create_training_data_gen_5(play_buffer)

        game_length.append(len(training_data_tmp))
        training_data.extend(training_data_tmp)
    print(f'game length: {round(np.mean(game_length),0)}')
    game_length_list.append(round(np.mean(game_length),0))
    
    # create a new data loader
    train_data = dataset(training_data)
    dataloader = DataLoader(train_data, **dataloader_params)
    
    # Train
    for epoch in range(N_EPOCH):
        loss_record = 0
        # for index, (state_tmp, move_tmp, legal_move_mask, log_prob_old, advantage_tmp, discounted_reward) in enumerate(dataloader):
        for index, input_data in enumerate(dataloader):
            if GEN == 5:
                x_card_tmp, x_state_tmp, move_tmp, legal_move_mask, log_prob_old, advantage_tmp, discounted_reward = input_data
            else:
                state_tmp, move_tmp, legal_move_mask, log_prob_old, advantage_tmp, discounted_reward = input_data
            optimizer.zero_grad()
            if GEN == 5:
                _, log_prob_new, entropy, value_new = model.forward(x_card_tmp, 
                                                                    x_state_tmp,
                                                                    legal_move_mask= legal_move_mask,
                                                                    action = move_tmp)
            else:
                _, log_prob_new, entropy, value_new = model.forward(X = state_tmp, 
                                                        legal_move_mask= legal_move_mask,
                                                        action = move_tmp)
            
            # policy loss
            # advantage_norm_tmp = (advantage_tmp - advantage_tmp.mean())/(advantage_tmp.std() + 1e-8)
            ratio = torch.exp(log_prob_new - log_prob_old).unsqueeze(dim = 1) # pi_new/pi_old
            surrogate_1 = ratio*advantage_tmp
            ratio_clamp = torch.clamp(ratio, 1 - e, 1 + e) # clipped ratio
            surrogate_2 = ratio_clamp*advantage_tmp
            policy_loss = -torch.min(surrogate_1, surrogate_2).mean() #FIX: is this adv detached? yes

            # value loss
            value_loss = ((value_new - discounted_reward)**2).mean()

            # entropy loss: to encourage exploration
            entropy_loss = entropy.mean()

            loss = policy_loss + value_coef * value_loss - entropy_coef * entropy_loss
            loss_record += loss.item()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm = 0.5) # FIX: check this
            optimizer.step()

        print(f'Epoch: {epoch} - loss: {round(loss_record/index,2)}')
        loss_list.append(loss_record//index)
    torch.save(model.state_dict(), './ppo_weight/model_state_tmp.pth')
    if cycle % 10 == 0:
        torch.save(model.state_dict(), f'./ppo_weight/model_local_gen_{GEN}_default_rwd_{cycle}_iter.pth')

-------------CYCLE: 0-------------
game length: 51.0
Epoch: 0 - loss: 152.15
Epoch: 1 - loss: 106.49
Epoch: 2 - loss: 86.66
Epoch: 3 - loss: 80.26
-------------CYCLE: 1-------------


KeyboardInterrupt: 

In [None]:
fig, ax = plt.subplots(1,1, figsize = (20,5))
sns.lineplot([i.item() for i in loss_list])

In [None]:
fig, ax = plt.subplots(1,1, figsize = (20,5))
sns.lineplot([i.item() for i in game_length_list])

# Deploy

In [66]:
import os

In [67]:
path = f'./ppo_weight/trained_model/model_gen_5_default_rwd_57_iter.pth'



if 'gen_3' in path:
    model = ppo(N_PLAYER).to(device)
    GEN = 3
    if 'gen_3_5' in path:
        GEN = 3.5
elif 'gen_4' in path:
    model = ppo_gen_4(N_PLAYER).to(device)
    GEN = 4
elif 'gen_5' in path:
    model = ppo_gen_5(N_PLAYER).to(device)
    GEN = 5

model.load_state_dict(torch.load(path, map_location = torch.device(device)))

<All keys matched successfully>

## Deploy

In [68]:
move_encode = {"0": "pass",
                "1": "take"}
nothanks = nothanks_ppo()
human_index = 5

while nothanks.is_continue:
    print('------------------------------')
    print(f'''Card: {nothanks.current_card} | Chip in pot: {nothanks.chip_in_pot} | Player: {nothanks.turn} - {nothanks.players[nothanks.turn]}\n'''
)
    print('------------------------------')
    if nothanks.turn == human_index:
        move = move_encode.get(input("""Your turn:
0: pass
1: take
Enter here: """))
    else:
        with torch.no_grad():
            if GEN != 5:
                if GEN == 3:
                    current_state = torch.tensor(nothanks.encode_state(nothanks.get_state)).to(device)
                elif GEN in (3.5, 4):
                    current_state = torch.tensor(nothanks.encode_state(nothanks.get_state_gen_3_5)).to(device)                    
                legal_move = nothanks.get_legal_action() # a list 
                legal_move_mask = torch.tensor([False if move in legal_move else True for move in nothanks.move_encode.values()]).to(device)
                move_raw, log_prob, entropy, value = model.forward(current_state, legal_move_mask)
            else:
                x_card, x_state = nothanks.encode_state_gen_5()
                x_card = torch.Tensor(x_card).unsqueeze(1).to(device) #1 for 5,33 -> 5,1,33 | 0 for 5,1,33 -> 1,5,1,33
                x_state = torch.tensor(x_state).to(device) # 8 to 1,8
                legal_move = nothanks.get_legal_action() # a list 
                legal_move_mask = torch.tensor([False if move in legal_move else True for move in nothanks.move_encode.values()]).to(device)
                move_raw, log_prob, entropy, value = model.forward(x_card, x_state, legal_move_mask)
            move = nothanks.move_encode.get(move_raw.item())

    print(f'Move taken: {move}')
    nothanks.action(move)

------------------------------
Card: 29 | Chip in pot: 0 | Player: 0 - Chip: 11 | Card owned: []

------------------------------
Move taken: pass
------------------------------
Card: 29 | Chip in pot: 1 | Player: 1 - Chip: 11 | Card owned: []

------------------------------
Move taken: pass
------------------------------
Card: 29 | Chip in pot: 2 | Player: 2 - Chip: 11 | Card owned: []

------------------------------
Move taken: pass
------------------------------
Card: 29 | Chip in pot: 3 | Player: 0 - Chip: 10 | Card owned: []

------------------------------
Move taken: pass
------------------------------
Card: 29 | Chip in pot: 4 | Player: 1 - Chip: 10 | Card owned: []

------------------------------
Move taken: pass
------------------------------
Card: 29 | Chip in pot: 5 | Player: 2 - Chip: 10 | Card owned: []

------------------------------
Move taken: pass
------------------------------
Card: 29 | Chip in pot: 6 | Player: 0 - Chip: 9 | Card owned: []

---------------------------

In [69]:
# End-game hand
for player_tmp in nothanks.players:
    print(player_tmp.card)

[29, 28, 17, 34, 27, 33, 30, 16, 24, 15, 18, 25, 23, 14]
[19, 20, 21]
[10, 11, 5, 12, 3, 35, 6]


In [70]:
# Score
for player_tmp in nothanks.players:
    print(player_tmp.calculate_score())

-66
-19
-51


In [71]:
# Ranking - endgame reward
nothanks.calculate_ranking()

[np.float64(-20.0), np.float64(20.0), np.float64(0.0)]

# Pitting

In [72]:
model_prefix = 'model_gen_5_5_default_rwd'
# model_prefix = 'hyperparam_tuning_'
model_list = [i for i in os.listdir('./ppo_weight/trained_model/') if i.startswith(model_prefix)]
# model_list = ['model_gen_3_default_rwd_50_iter.pth',
#               'model_gen_3_default_rwd_60_iter.pth',
#               'model_gen_3_default_rwd_80_iter.pth',
#              ]
model_name_dict = {a:b for a, b in enumerate(model_list)}
n_model = len(model_name_dict)

In [73]:
select_record = {i:0 for i in range(n_model)}
win_record = {i:0 for i in range(n_model)}
move_encode = {"0": "pass",
                "1": "take"}

n_match = 10

## Pitting

In [74]:
for _ in tqdm(range(n_match)):
    
    model_index = random.sample(range(n_model), k = 3)        
    for index in model_index:
        select_record[index] += 1
    model_list = []
    for index in model_index:
        path = f'./ppo_weight/trained_model/{model_name_dict.get(index)}'
        if 'gen_3' in path:
            model = ppo(N_PLAYER).to(device)
            GEN = 3
            if 'gen_3_5' in path:
                GEN = 3.5
        elif 'gen_4' in path:
            model = ppo_gen_4(N_PLAYER).to(device)
            GEN = 4
        elif 'gen_5' in path:
            model = ppo_gen_5(N_PLAYER).to(device)
            GEN = 5
        
        model.load_state_dict(torch.load(path, map_location = torch.device(device)))
        model_list.append(model)
    
    
    nothanks = nothanks_ppo()
    while nothanks.is_continue:
        with torch.no_grad():
            if GEN != 5:
                if GEN == 3:
                    current_state = torch.tensor(nothanks.encode_state(nothanks.get_state)).to(device)
                elif GEN in (3.5, 4):
                    current_state = torch.tensor(nothanks.encode_state(nothanks.get_state_gen_3_5)).to(device)                    
                legal_move = nothanks.get_legal_action() # a list 
                legal_move_mask = torch.tensor([False if move in legal_move else True for move in nothanks.move_encode.values()]).to(device)
                move_raw, log_prob, entropy, value = model.forward(current_state, legal_move_mask)
            else:
                x_card, x_state = nothanks.encode_state_gen_5()
                x_card = torch.Tensor(x_card).unsqueeze(1).to(device) #1 for 5,33 -> 5,1,33 | 0 for 5,1,33 -> 1,5,1,33
                x_state = torch.tensor(x_state).to(device) # 8 to 1,8
                legal_move = nothanks.get_legal_action() # a list 
                legal_move_mask = torch.tensor([False if move in legal_move else True for move in nothanks.move_encode.values()]).to(device)
                move_raw, log_prob, entropy, value = model.forward(x_card, x_state, legal_move_mask)
            move = nothanks.move_encode.get(move_raw.item())
        nothanks.action(move)    
    
    winner_index = np.argmax(nothanks.calculate_ranking())
    win_record[model_index[winner_index]] += 1

100%|██████████| 10/10 [00:01<00:00,  7.26it/s]


In [75]:
pd.DataFrame([select_record, win_record]).T\
.rename(columns = {0: 'total',
         1: 'win'
        })\
.assign(win_pct = lambda df: df['win']/df['total'])\
.sort_values('win_pct', ascending = False)\
.reset_index()\
.assign(model_name = lambda df: df['index'].apply(lambda x: model_name_dict.get(x)))

Unnamed: 0,index,total,win,win_pct,model_name
0,0,10,4,0.4,model_gen_5_5_default_rwd_37_iter.pth
1,2,10,4,0.4,model_gen_5_5_default_rwd_58_iter.pth
2,1,10,2,0.2,model_gen_5_5_default_rwd_38_iter.pth
