<a href="https://colab.research.google.com/github/datvodinh10/Truly-Proximal-Policy-Optimization/blob/main/Agent_PPO_VIS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Game to train on; passed to `make()` and embedded in every output file name.
game_name = "Catan"

In [None]:
# Output directory prefix for per-files, state dicts and plots. File names are built as
# f"{PATH}...", so the value must end with a path separator.
PATH = "./"

# On Colab, to persist results on Google Drive use this instead:
# from google.colab import drive
# drive.mount('/content/gdrive')
# PATH = f"/content/gdrive/MyDrive/Data 12 Hour/{game_name}/"

Mounted at /content/gdrive


In [None]:
# Log the CPU model so the RUN/TRAIN timings printed during training can be put in context.
!lscpu | grep 'Model name'

Model name:                      Intel(R) Xeon(R) CPU @ 2.20GHz


In [None]:
# Fetch the game-environment repository (presumably the source of `setup.make` imported below)
# and make it the working directory so its modules are importable.
# NOTE(review): not idempotent — on a re-run `git clone` fails because ENV already exists and
# `%cd ENV` would try to enter ENV/ENV. Consider pinning a commit for reproducibility.
!git clone https://github.com/ngoxuanphong/ENV.git
%cd ENV

Cloning into 'ENV'...
remote: Enumerating objects: 4910, done.[K
remote: Counting objects: 100% (962/962), done.[K
remote: Compressing objects: 100% (500/500), done.[K
remote: Total 4910 (delta 408), reused 935 (delta 393), pack-reused 3948[K
Receiving objects: 100% (4910/4910), 276.01 MiB | 34.32 MiB/s, done.
Resolving deltas: 100% (2070/2070), done.
Updating files: 100% (1202/1202), done.
/content/ENV


## Import

In [None]:
import warnings 
warnings.filterwarnings('ignore')
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning,NumbaExperimentalFeatureWarning, NumbaWarning
warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaExperimentalFeatureWarning)
warnings.simplefilter('ignore', category=NumbaWarning)

from numba import njit
from numba.typed import List
import numba
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical
from torch.distributions.kl import kl_divergence
import numba
from numba import njit,jit
from numba.typed import List
torch.manual_seed(0)
np.random.seed(0)
import time
import matplotlib.pyplot as plt
import seaborn as sns

## Setup

In [None]:
# Build the environment for the selected game. `env` provides the sizes (getStateSize,
# getActionSize, getAgentSize), getValidActions, getReward and the numba game runner
# (numba_main_2) used throughout the notebook.
from setup import make
env = make(game_name)

In [None]:
# Everything runs on the CPU. NOTE(review): `device` is not referenced anywhere below.
device = torch.device("cpu")

In [None]:
# Target win rate: 1 / n_agents plus 0.01 per agent (presumably chance level plus a
# margin — TODO confirm intent). Compared against the best moving-average win rate.
KPI = 1 / env.getAgentSize() + 0.01 * env.getAgentSize()
# Wall-clock training budget in seconds (12 hours).
TIME = 12 * 3600

## Agent

In [None]:
def StandardScaler(X):
    """Return Data with Standard Scaling.

    Standardises a 2-D tensor ``X`` of shape (n_samples, n_features) column by column:
    ``(X - mean) / (std + 1e-8)``. Statistics are taken over rows (dim 0) using the
    unbiased standard deviation (``torch.std``'s default). They are computed in X's own
    dtype, so float64 inputs are not rounded through float32.

    With fewer than two rows the unbiased std is undefined (NaN). In that case it is
    treated as 0, so the result is the centred (all-zero) data rather than NaN.
    """
    mean = X.mean(dim=0, keepdim=True)
    if X.shape[0] > 1:
        std = X.std(dim=0, keepdim=True)
    else:
        std = torch.zeros_like(mean)
    return (X - mean) / (std + 1e-8)

In [None]:
def LayerInit(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight matrix gets an orthogonal initialisation scaled by the gain ``std``, and
    every bias entry is set to ``bias_const``. Hidden layers in this notebook use the
    default gain; output layers pass a smaller one (0.01 for the actor, 1 for the critic).
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer

In [None]:
@njit
def MonteCarloRewards(rewards,is_terminals,rank,gamma):
    """Reward-to-go, accumulated backwards through the collected steps.

    ``rewards`` and ``is_terminals`` are aligned 1-D arrays in collection order. Walking
    from the last step to the first, the running return is reset to 0 at every step whose
    terminal flag equals 1, so returns do not leak across episode boundaries.

    Returns a typed list of float32 in *reverse* step order; the caller flips it.
    ``rank`` is accepted for interface compatibility but unused.
    """
    rtgs = List.empty_list(numba.float32)
    running_return = 0
    for t in range(len(rewards) - 1, -1, -1):
        if is_terminals[t] == 1:
            running_return = 0
        running_return = rewards[t] + gamma * running_return
        rtgs.append(running_return)
    return rtgs

In [None]:
@njit
def RandomChoiceWithProb(arr, prob):
    """Choice with given Probability: return ``arr[k]`` with probability ``prob[k]``.

    Uses inverse-CDF sampling. The uniform draw is scaled by the final cumulative sum, so
    floating-point round-off (``prob`` summing to slightly less than 1) cannot push the
    search past the last element. With ``side="right"`` this also never selects a
    trailing zero-probability entry. As a last resort the index is clamped. Numba does
    not bounds-check array indexing by default, so an index of ``len(arr)`` would
    otherwise read arbitrary memory.
    """
    cdf = np.cumsum(prob)
    u = np.random.rand() * cdf[-1]
    k = np.searchsorted(cdf, u, side="right")
    return arr[min(k, len(arr) - 1)]
@njit
def StableSoftmax(x):
    """Return Softmax of the output.

    ``max(x)`` is subtracted before exponentiating so that ``exp`` cannot overflow;
    the shifted exponentials are computed once and reused for the normaliser.
    """
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)

In [None]:
@njit
def ChooseAction(state,per,prob=True):
    """Return Action with given State and Per.

    Evaluates the tanh MLP actor whose weights sit in ``per[-6:]`` (W1, b1, W2, b2, W3,
    b3), each wrapped as element 0 of its own typed list (see ``Agent.UpdatePer``).
    Only the logits of currently valid actions are considered. With ``prob=True`` an
    action is sampled from the softmax over those logits; otherwise the valid action
    with the largest logit is returned.
    """
    w_in, b_in = per[-6][0], per[-5][0]
    w_mid, b_mid = per[-4][0], per[-3][0]
    w_out, b_out = per[-2][0], per[-1][0]
    legal = np.where(env.getValidActions(state) == 1)[0]
    x = state.reshape(1, -1).astype(np.float32)
    hidden1 = np.tanh(np.dot(x, w_in) + b_in)
    hidden2 = np.tanh(np.dot(hidden1, w_mid) + b_mid)
    logits = np.dot(hidden2, w_out) + b_out
    legal_logits = logits[0][legal]
    if not prob:
        return legal[np.argmax(legal_logits)]
    return RandomChoiceWithProb(legal, StableSoftmax(legal_logits))

In [None]:
@njit
def Play(state,per):
    """Agent to get Data: choose an action for ``state`` and log the step into ``per``.

    For every call one float32 2-D row is appended to each of
    ``per[0]`` (action), ``per[1]`` (state), ``per[2]`` (reward),
    ``per[3]`` (terminal flag) and ``per[4]`` (valid-action mask).

    When ``env.getReward(state) == -1`` (presumably "game not finished yet" — TODO
    confirm against the env), the step is stored with a -0.001 reward and terminal
    flag 0. Otherwise the env reward is stored and the step is flagged terminal (1).

    Returns:
        (action, per)
    """
    action = ChooseAction(state,per)
    # Query the env once per step; both values are reused below.
    valid_action = env.getValidActions(state)
    if valid_action[action] != 1: # Guard against a sampled action that is not legal
        print(action,valid_action)
        action = np.random.choice(np.where(valid_action==1)[0])

    reward = env.getReward(state)
    if reward == -1:
        step_reward = -0.001
        is_terminal = 0.
    else:
        step_reward = reward * 1.0
        is_terminal = 1.

    per[0].append(np.array([[action]],dtype=np.float32))              # action
    per[1].append(state.reshape(1,-1).astype(np.float32))             # state
    per[2].append(np.array([[step_reward]],dtype=np.float32))         # reward
    per[3].append(np.array([[is_terminal]],dtype=np.float32))         # is_terminals
    per[4].append(valid_action.reshape(1,-1).astype(np.float32))      # action masking
    return action,per

In [None]:
class ActorModel(nn.Module):
    """Actor Model: state -> [Linear(64), tanh] x 2 -> Linear(n_actions) logits.

    Only the ``actor`` Sequential is used by ``Agent``. Its Linear layers sit at
    indices 0, 2 and 4, which is what ``Agent.UpdatePer`` reads. The output layer
    uses gain 0.01, so the initial logits are small.
    """
    def __init__(self):
        super().__init__()
        widths = (env.getStateSize(), 64, 64)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers += [LayerInit(nn.Linear(n_in, n_out)), nn.Tanh()]
        layers.append(LayerInit(nn.Linear(64, env.getActionSize()), std=0.01))
        self.actor = nn.Sequential(*layers)

class CriticModel(nn.Module):
    """Critic Model: state -> [Linear(128), tanh] x 2 -> Linear(1) state value.

    Only the ``critic`` Sequential is used by ``Agent``. The output layer uses gain 1.
    """
    def __init__(self):
        super().__init__()
        widths = (env.getStateSize(), 128, 128)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers += [LayerInit(nn.Linear(n_in, n_out)), nn.Tanh()]
        layers.append(LayerInit(nn.Linear(128, 1), std=1))
        self.critic = nn.Sequential(*layers)

In [None]:
class Agent(nn.Module):
    """Truly-PPO actor-critic agent trained on batches of numba-simulated games.

    The actor and critic are separate MLPs (``ActorModel().actor`` / ``CriticModel().critic``)
    optimised jointly by a single Adam optimizer with one parameter group each
    (group 0 = actor, group 1 = critic). Rollouts are produced by ``env.numba_main_2``
    calling ``Play`` with the actor weights exported by ``UpdatePer``.
    """
    def __init__(self,num_games = 10,num_game_per_batch=128,n_iter=5,lr=1e-3,batch_size=2048,lr_decay=1,entropy_coef=0,critic_coef = 1,gamma=0.995,value_clip = 0.2,policy_params=20,policy_kl_range=0.0008):
        super().__init__()
        self.actor                  = ActorModel().actor
        self.critic                 = CriticModel().critic

        self.num_game_per_batch     = num_game_per_batch
        self.n_iters                = n_iter
        self.lr                     = lr
        self.lr_decay               = lr_decay
        self.batch_size             = batch_size
        self.optimizer              = torch.optim.Adam([
                                        {'params': self.actor.parameters(), 'lr': self.lr},
                                        {'params': self.critic.parameters(), 'lr': self.lr}
                                    ])
        self.critic_coef            = critic_coef
        self.entropy_coef           = entropy_coef
        self.gamma                  = gamma
        self.num_games              = num_games
        self.value_clip             = value_clip
        self.policy_kl_range        = policy_kl_range   # KL threshold of the trust region
        self.policy_params          = policy_params     # weight of the KL rollback penalty
        # Independent copies: a bare state_dict() would alias the live weights.
        self.best_actor_state_dict  = self._CloneStateDict(self.actor)
        self.best_critic_state_dict = self._CloneStateDict(self.critic)

        self.entropy_data           = []
        self.mean_win_data          = []
        self.time_data              = []

    @staticmethod
    def _CloneStateDict(module):
        """Return a deep copy of ``module.state_dict()``.

        The tensors returned by ``state_dict()`` share storage with the module's
        parameters, so a reference kept without copying would follow every later
        optimizer step.
        """
        import copy
        return copy.deepcopy(module.state_dict())

    def Forward(self,state):
        """Return (actor logits, critic values) for a batch of states."""
        return self.actor(state),self.critic(state)

    def GetPolicy(self,state):
        """Return the actor logits for a batch of states."""
        return self.actor(state)

    def CalculateTrulyLoss(self,value,value_new,entropy,log_prob,log_prob_new,rtgs,old_logits=None,new_logits=None):
        """Calculate Model Loss (Truly PPO with KL rollback plus a clipped value loss).

        Args:
            value:        critic values recorded before the update, shape (B,).
            value_new:    current critic values, shape (B,).
            entropy:      mean entropy of the current policy (scalar tensor).
            log_prob:     log-prob of the taken actions under the rollout policy, shape (B,).
            log_prob_new: log-prob of the taken actions under the current policy, shape (B,).
            rtgs:         reward-to-go targets, shape (B,).
            old_logits:   optional full (masked) logits of the rollout policy, shape (B, A).
            new_logits:   optional full (masked) logits of the current policy, shape (B, A).

        When both ``old_logits`` and ``new_logits`` are given, the KL term is the per-state
        KL(pi_old(.|s) || pi(.|s)). Otherwise the legacy behaviour is kept: the vectors of
        taken-action log-probs are used as the logits of a single distribution over the
        batch, which yields one scalar KL shared by every sample.

        Returns:
            actor_loss + critic_coef * critic_loss - entropy_coef * entropy (scalar tensor).
        """
        advantage       = rtgs - value.detach()
        # Clamp the log-ratio before exp() so the ratio cannot overflow to inf.
        ratios          = torch.exp(torch.clamp(log_prob_new-log_prob.detach(),min=-20.,max=5.))
        if old_logits is not None and new_logits is not None:
            Kl          = kl_divergence(Categorical(logits=old_logits.detach()), Categorical(logits=new_logits))
        else:
            Kl          = kl_divergence(Categorical(logits=log_prob), Categorical(logits=log_prob_new))

        # Rollback: samples that already left the KL trust region while the ratio is >= 1
        # get a KL penalty instead of the plain surrogate.
        actor_loss      = -torch.where(
                            (Kl >= self.policy_kl_range) & (ratios >= 1),
                            ratios * advantage - self.policy_params * Kl,
                            ratios * advantage
                        ).mean()
        # Clipped value loss: take the worse of the raw and clipped squared errors.
        value_clipped   = value + torch.clamp(value_new - value, -self.value_clip, self.value_clip)

        critic_loss     = 0.5 * torch.max((rtgs-value_new)**2,(rtgs-value_clipped)**2).mean()
        total_loss      = actor_loss + self.critic_coef * critic_loss - self.entropy_coef * entropy
        return total_loss

    def UpdatePer(self,actor_state_dict):
        """Update per file: pack actor weights into the numba ``per`` structure.

        Layout: indices 0-6 are empty typed lists of float32 2-D C-contiguous arrays.
        ``Play`` fills 0-4 with rollout data (action, state, reward, terminal flag,
        valid-action mask). These are followed by six single-element typed lists holding
        W1, b1, W2, b2, W3, b3. Weights are transposed to (in, out) and biases reshaped to
        (1, out), so a row-vector forward pass is ``x @ W + b``.
        """
        perx = [List.empty_list(numba.types.Array(dtype=numba.float32,ndim=2,layout='C')),
                List.empty_list(numba.types.Array(dtype=numba.float32,ndim=2,layout='C')),
                List.empty_list(numba.types.Array(dtype=numba.float32,ndim=2,layout='C')),
                List.empty_list(numba.types.Array(dtype=numba.float32,ndim=2,layout='C')),
                List.empty_list(numba.types.Array(dtype=numba.float32,ndim=2,layout='C')),
                List.empty_list(numba.types.Array(dtype=numba.float32,ndim=2,layout='C')),
                List.empty_list(numba.types.Array(dtype=numba.float32,ndim=2,layout='C'))]

        param0 = actor_state_dict['0.weight'].detach().numpy().T
        param1 = actor_state_dict['0.bias'].detach().numpy().reshape(1,-1)
        param2 = actor_state_dict['2.weight'].detach().numpy().T
        param3 = actor_state_dict['2.bias'].detach().numpy().reshape(1,-1)
        param4 = actor_state_dict['4.weight'].detach().numpy().T
        param5 = actor_state_dict['4.bias'].detach().numpy().reshape(1,-1)
        params = [param0,param1,param2,param3,param4,param5]
        for param in params:
            perx.append(List([np.array(param,ndmin=2,order='C')]))
        return perx

    def _PerArrays(self,actor_state_dict):
        """Return the six actor matrices of ``UpdatePer`` as a 1-D object array.

        The matrices have different shapes. Building the array element by element avoids
        the error ``np.array`` raises on ragged input in NumPy >= 1.24.
        """
        params = [param[0] for param in self.UpdatePer(actor_state_dict)[-6:]]
        packed = np.empty(len(params), dtype=object)
        for k, param in enumerate(params):
            packed[k] = param
        return packed

    def _SaveCheckpoint(self,plot_graph=False):
        """Save the best actor (as a per-file and a state dict), the best critic, and the plot."""
        np.save(f'{PATH}per_{game_name}.npy',self._PerArrays(self.best_actor_state_dict))
        torch.save(self.best_actor_state_dict,f'{PATH}actor_state_dict_{game_name}.pt')
        torch.save(self.best_critic_state_dict,f'{PATH}critic_state_dict_{game_name}.pt')
        self.plot(plot_graph=plot_graph)

    def plot(self,plot_graph=True):
        """Plot entropy and moving-average win rate against wall-clock time.

        The figure is written to ``{PATH}{game_name}-Plot.png`` *before* being shown,
        because ``plt.show()`` can leave the figure empty on inline backends. It is then
        closed so that repeated checkpoints do not accumulate open figures.
        """
        sns.set_style('darkgrid') # darkgrid, white grid, dark, white and ticks
        plt.rc('axes', titlesize=18)     # fontsize of the axes title
        plt.rc('axes', labelsize=14)    # fontsize of the x and y labels
        plt.rc('xtick', labelsize=13)    # fontsize of the tick labels
        plt.rc('ytick', labelsize=13)    # fontsize of the tick labels
        plt.rc('legend', fontsize=13)    # legend fontsize
        plt.rc('font', size=13)
        fig, (ax_entropy, ax_win) = plt.subplots(1, 2, figsize=(12,4))
        ax_entropy.plot(self.time_data,self.entropy_data)
        ax_entropy.set_title('Entropy')
        ax_entropy.set_xlabel('Time(s)')

        ax_win.plot(self.time_data,self.mean_win_data)
        ax_win.set_title('Avg Win Rate(%)')
        ax_win.set_xlabel('Time(s)')
        fig.savefig(f'{PATH}{game_name}-Plot.png')
        if plot_graph:
            plt.show()
        plt.close(fig)

    def TrainModel(self,num_games=100,lr_decay=False,entropy_decay=True,level=1,save_every_epochs = 100,print_every_epochs=100,time_kpi = 3600):
        """Training model: alternate rollout collection and Truly-PPO updates.

        Each iteration plays ``num_game_per_batch`` games with the current actor via
        ``env.numba_main_2(Play, ...)`` and builds reward-to-go targets. It then runs
        ``n_iters`` epochs of shuffled minibatch updates.

        Args:
            num_games: total games; the loop runs ``num_games // num_game_per_batch`` iterations.
            lr_decay: if True, multiply the actor learning rate (param group 0) by 0.2 at iteration 2000.
            entropy_decay: if True, multiply ``entropy_coef`` by 0.1 at iteration 1000.
            level: forwarded to ``env.numba_main_2``.
            save_every_epochs: checkpoint the best weights and the plot every this many iterations.
            print_every_epochs: progress-log frequency, in iterations.
            time_kpi: wall-clock budget in seconds; when exceeded, the best weights are saved and training stops.

        Returns:
            The last ``per`` built by ``UpdatePer``. It is built from the best actor when
            training stopped on the time budget.
        """
        global perx
        self.num_games  = num_games // self.num_game_per_batch
        list_rank       = []
        # Moving window (in batches) covering ~10k games.
        NUM_MEAN        = 10_000 / self.num_game_per_batch
        best_win_rate   = -100

        start = time.time()

        pass_kpi = False
        # Defined up front so a valid ``per`` is returned even if no iteration runs.
        perx = self.UpdatePer(self.actor.state_dict())
        for i in range(self.num_games):
            perx = self.UpdatePer(self.actor.state_dict())
            s1 = time.time()
            rank = env.numba_main_2(Play,self.num_game_per_batch,perx,level)[0] / self.num_game_per_batch

            if i%print_every_epochs==0:
                print(f"| BATCH: {i:>5} | RUN: {str(f'{time.time()-s1:.2f}')+'s':>7} | WIN RATE: {str(f'{rank*100:.2f}')+'%':>7} |",end=" " )

            batch_actions           = torch.as_tensor(np.array(perx[0]).squeeze(),dtype=torch.float32).reshape(-1).detach()
            batch_states            = torch.as_tensor(np.array(perx[1]).squeeze(),dtype=torch.float32).detach()
            batch_rtgs              = np.flip(MonteCarloRewards(np.array(perx[2]).squeeze(),np.array(perx[3]).squeeze(),rank,self.gamma),axis=0)
            batch_rtgs              = torch.as_tensor(batch_rtgs.copy(),dtype=torch.float32).detach()
            batch_mask              = torch.as_tensor(np.array(perx[4]).squeeze(),dtype=torch.float32).detach()
            policy_old,batch_values = self.Forward(batch_states)
            batch_values            = batch_values.detach()
            # Masked logits of the rollout policy: log(0) = -inf removes invalid actions.
            batch_logits_old        = (policy_old + torch.log(batch_mask)).detach()
            prob_old                = Categorical(logits=batch_logits_old)
            batch_logprobs          = prob_old.log_prob(batch_actions.view(1,-1)).detach().squeeze(0)
            old_entropy             = prob_old.entropy().detach().mean().item()
            # The rollout buffers are now copied into tensors; release them.
            for ix in range(6):
                perx[ix].clear()
            if i%print_every_epochs==0:
                print(f'ENTROPY: {old_entropy:.4f} |',end=" ")

            s2 = time.time()
            n_samples = batch_states.shape[0]
            for _ in range(self.n_iters):
                index = torch.randperm(n_samples)
                states,actions,probs    = batch_states[index],batch_actions[index],batch_logprobs[index]
                rtgss,masks,values      = batch_rtgs[index],batch_mask[index],batch_values[index]
                logits_old              = batch_logits_old[index]
                for begin in range(0, n_samples, self.batch_size):
                    end = min(begin + self.batch_size, n_samples)
                    # Skip a trailing minibatch that is more than 256 samples short of batch_size.
                    if begin + self.batch_size > n_samples + 256:
                        continue
                    state,action,log_prob   = states[begin:end],actions[begin:end],probs[begin:end]
                    rtgs,mask,value         = rtgss[begin:end],masks[begin:end],values[begin:end]
                    old_logits              = logits_old[begin:end]

                    policy,value_new        = self.Forward(state)
                    value_new               = value_new.squeeze(1)
                    value                   = value.squeeze(1)
                    new_logits              = policy + torch.log(mask)
                    prob1                   = Categorical(logits=new_logits)
                    log_prob_new            = prob1.log_prob(action.view(1,-1)).squeeze(0)

                    entropy                 = prob1.entropy().mean()
                    total_loss              = self.CalculateTrulyLoss(value,value_new,entropy,log_prob,log_prob_new,rtgs,
                                                                      old_logits=old_logits,new_logits=new_logits)

                    if not torch.isnan(total_loss).any():
                        self.optimizer.zero_grad()
                        total_loss.backward()
                        nn.utils.clip_grad_norm_(self.parameters(),0.5)
                        self.optimizer.step()

            perx = self.UpdatePer(self.actor.state_dict())

            list_rank.append(rank)
            if len(list_rank)>NUM_MEAN:
                list_rank.pop(0)
            # Divides by the full window size, so the average is biased low until the window fills.
            win_rate_new = sum(list_rank) / NUM_MEAN

            if i % print_every_epochs==0:
                print(f"TRAIN: {time.time()-s2:.2f}s | MEAN WIN RATE: {sum(list_rank) / NUM_MEAN * 100:.1f} % |")

            if lr_decay:
                if i==2000:
                    self.optimizer.param_groups[0]['lr'] = self.optimizer.param_groups[0]['lr'] * 0.2
                    print(f"LEARNING RATE CHANGE TO: {self.optimizer.param_groups[0]['lr']}")
            if entropy_decay:
                if i==1000:
                    self.entropy_coef = self.entropy_coef * 0.1
                    print(f"ENTROPY COEFFICIENT CHANGE TO: {self.entropy_coef}")

            #save data plot
            if i%5==0:
                self.entropy_data.append(old_entropy)
                self.mean_win_data.append(win_rate_new*100)
                self.time_data.append(time.time()-start)

            if win_rate_new >= best_win_rate:
                self.best_critic_state_dict = self._CloneStateDict(self.critic)
                self.best_actor_state_dict  = self._CloneStateDict(self.actor)
                best_win_rate               = win_rate_new

            # Roll back to the best snapshot when the moving win rate drops more than
            # 5 points below it. The optimizer moments are not reset.
            if best_win_rate - win_rate_new > 0.05:
                self.actor.load_state_dict(self.best_actor_state_dict)
                self.critic.load_state_dict(self.best_critic_state_dict)

            if best_win_rate >= KPI and not pass_kpi:
                print(f"TIME: {time.time()-start} s | PASS KPI! ")
                pass_kpi = True
            if time.time() - start >= time_kpi:
                print('TRAINING 12 HOUR COMPLETED!')
                print(f'TOTAL TIME: {time.time() - start:.2f} s')
                perx    = self.UpdatePer(self.best_actor_state_dict)
                self._SaveCheckpoint(plot_graph=True)
                break
            if i % save_every_epochs==0:
                self._SaveCheckpoint(plot_graph=False)
        return perx


## Train

In [None]:
# The three configurations differ only in games per batch, minibatch size, entropy bonus
# and checkpoint frequency; pick them from the game's state/action sizes.
if env.getStateSize() > 450: #prevent out of memory
    train_cfg = dict(num_game_per_batch=50, batch_size=512, entropy_coef=0.01, save_every_epochs=25)
elif env.getActionSize() < 16:
    train_cfg = dict(num_game_per_batch=200, batch_size=1024, entropy_coef=0, save_every_epochs=200)
else:
    train_cfg = dict(num_game_per_batch=200, batch_size=1024, entropy_coef=0.001, save_every_epochs=100)

agent  = Agent(num_games=1,num_game_per_batch=train_cfg['num_game_per_batch'],n_iter=2,lr=1e-3,
               batch_size=train_cfg['batch_size'],gamma=1,entropy_coef=train_cfg['entropy_coef'],value_clip=0.2)
perx = agent.TrainModel(num_games=100_000_000,level=1,time_kpi = TIME,
                        save_every_epochs=train_cfg['save_every_epochs'],lr_decay=False)

| BATCH:     0 | RUN: 109.65s | WIN RATE:   0.00% | ENTROPY: 1.1997 | TRAIN: 1.50s | MEAN WIN RATE: 0.0 % |
| BATCH:   100 | RUN:  13.07s | WIN RATE:  29.50% | ENTROPY: 1.0445 | TRAIN: 0.90s | MEAN WIN RATE: 23.8 % |
TIME: 2108.5097312927246 s | PASS KPI! 
| BATCH:   200 | RUN:  13.51s | WIN RATE:  34.00% | ENTROPY: 0.9980 | TRAIN: 1.03s | MEAN WIN RATE: 34.7 % |
| BATCH:   300 | RUN:  13.60s | WIN RATE:  44.00% | ENTROPY: 0.9481 | TRAIN: 1.16s | MEAN WIN RATE: 39.9 % |
| BATCH:   400 | RUN:  13.78s | WIN RATE:  46.00% | ENTROPY: 0.8925 | TRAIN: 1.25s | MEAN WIN RATE: 48.0 % |
| BATCH:   500 | RUN:  14.20s | WIN RATE:  61.00% | ENTROPY: 0.8468 | TRAIN: 1.57s | MEAN WIN RATE: 58.5 % |
| BATCH:   600 | RUN:  13.55s | WIN RATE:  73.50% | ENTROPY: 0.7484 | TRAIN: 1.48s | MEAN WIN RATE: 68.6 % |
| BATCH:   700 | RUN:  13.66s | WIN RATE:  75.00% | ENTROPY: 0.6972 | TRAIN: 1.53s | MEAN WIN RATE: 69.5 % |
| BATCH:   800 | RUN:  13.66s | WIN RATE:  72.00% | ENTROPY: 0.6602 | TRAIN: 1.56s | MEAN

## Test

In [None]:
@njit
def Test_per(state,per):
    """Evaluation policy: sample a valid action from actor weights loaded from disk.

    Unlike ``ChooseAction``, ``per[-6:]`` are the bare (W1, b1, W2, b2, W3, b3) arrays
    rather than single-element lists. Returns ``(action, per)``.
    """
    w1, b1, w2, b2, w3, b3 = per[-6], per[-5], per[-4], per[-3], per[-2], per[-1]
    legal = np.where(env.getValidActions(state) == 1)[0]
    h1 = np.tanh(np.dot(state.reshape(1, -1).astype(np.float32), w1) + b1)
    h2 = np.tanh(np.dot(h1, w2) + b2)
    logits = (np.dot(h2, w3) + b3)[0]
    return RandomChoiceWithProb(legal, StableSoftmax(logits[legal])), per

# Evaluate the actor weights last written by TrainModel over `num_test` games and compare
# the resulting win rate with the KPI.
# allow_pickle=True is needed because the weights are stored as an object array of matrices
# with different shapes. Only load files produced by this notebook (unpickling can run code).
per_file = list(np.load(f'{PATH}per_{game_name}.npy',allow_pickle=True))
num_test = 10_000
win = env.numba_main_2(Test_per,num_test,per_file,1)[0]   # win count; divided by num_test below
print(f'| GAME: {game_name:<18} | WIN RATE: {win / num_test * 100:.2f} %{"":>2} vs KPI: {KPI*100:.1f} % |')