In [1]:
from Big2.environment import Big2Env
from Big2.setting import *
from Big2.observer import *
from utils import *

from policy import *
from dataset import PPODataset
from memory import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, List, Dict
from model import *
from torch.distributions.categorical import Categorical
from torch.utils.data import DataLoader
from tqdm import tqdm
from copy import deepcopy
import gc
import ray

# Start from a clean Ray runtime: shut down any instance left over from an earlier run of
# this cell so the cell can be re-run, then launch a fresh local instance. `ray.init()` is
# the cell's last expression, so Jupyter displays the returned context below.
ray.shutdown()
ray.init()

2024-02-23 02:08:43,053	INFO worker.py:1715 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265


Python version: 3.11.7
Ray version: 2.9.1
Dashboard: http://127.0.0.1:8265


In [2]:
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# --- PPO hyper-parameters (read as globals by agent_train and the training loop) ---
epoch = 5            # optimisation passes over each freshly collected dataset
batch_size = 128
n_epsiode = 20       # self-play games collected in parallel (Ray) per training iteration

c1 = 5e-1            # weight of the clipped-surrogate actor loss
c2 = 1e-3            # weight of the entropy bonus

save_period = 100    # NOTE(review): not referenced by any visible cell - checkpointing looks unimplemented
sync_preiod = 500    # iterations between refreshes of the opponents' weights

log_clip = np.log(1.1)     # PPO ratio clip in log space: ratio capped at 1.1 / floored at 1/1.1
logit_diff_clamp = 10.0    # clamp on the log-probability ratio before exponentiating
greedy_coef = 0.005        # weight of the bonus toward the critic's highest-valued legal action

model = AC()

# Seat 0 is the learner (the only policy passed to agent_train); seats 1-3 are opponents whose
# weights are overwritten at every sync. Each seat gets its OWN policy and memory object:
# `[obj] * 3` would put one shared instance in all three seats.
policies: List[ActorGreedyPolicy] = [ActorGreedyPolicy(model, 0.1)] + [
    ActorGreedyPolicy(AC(), 0.1) for _ in range(3)
]
memorys: List[PPOMemory] = [PPOMemory(
    discount_factor=0.99,
    gae_factor=0.95,
    gae=True,
    gae_target=True)] + [NoneMemory() for _ in range(3)]

optimizer = torch.optim.Adam(model.parameters(), 1e-5)

env = Big2Env(DefaultObserver())
# NOTE(review): these results are not used by the training loop; get_trajectory_ray resets its own copy.
next_observation, info = env.reset()
done = False

# Logging buffers.
chips = []
actor_losses = []
critic_losses = []
entropys = []
max_losses = []

# Weights the opponents receive at the next sync; starts as the learner's initial weights.
sync_target_model_param = deepcopy(model.state_dict())



In [3]:
def _settle_chips(remains, n_two):
    """Zero-sum chip settlement at the end of a game.

    For every ordered pair of seats, the seat holding more cards pays the difference in card
    count to the other, doubled once for every '2' card still in the payer's hand.

    Args:
        remains: number of cards left in each seat's hand, indexed by seat.
        n_two: number of '2' cards left in each seat's hand, indexed by seat.

    Returns:
        List with the net chip change of every seat (the entries sum to zero).
    """
    chips = [0] * len(remains)
    for payer, payer_cards in enumerate(remains):
        multiplier = 2 ** n_two[payer]
        for receiver, receiver_cards in enumerate(remains):
            amount = max(payer_cards - receiver_cards, 0) * multiplier
            chips[payer] -= amount
            chips[receiver] += amount
    return chips


@ray.remote
def get_trajectory_ray(env, memorys, policies, fixed_cards):
    """Play one self-play game inside a Ray worker and return seat 0's trajectory.

    Ray hands the task deserialised copies of its arguments, so resetting and filling the
    memories here does not touch the driver's objects.

    Args:
        env: Big2 environment; reset with `player_fix_cards=fixed_cards` before play.
        memorys: per-seat memory objects indexed by `env.turn`; `None` entries are skipped.
        policies: per-seat policies exposing `choose_action(observation)`.
        fixed_cards: forwarded unchanged to `env.reset`.

    Returns:
        `memorys[0].get_tensor()` with an extra 'chips' entry: seat 0's chip result as a
        1-element tensor.
    """
    for memory in memorys:
        if memory is not None:
            memory.clean()

    next_observation, info = env.reset(player_fix_cards=fixed_cards)
    done = False

    while not done:
        observation = next_observation
        turn = env.turn
        agent, memory = policies[turn], memorys[turn]
        log_prob, state_value, action = agent.choose_action(observation.cpu())
        next_observation, _, done, _, info = env.step(action)

        # Step rewards from the env are deliberately ignored: every transition is stored with
        # reward 0.0 and the only learning signal is the terminal chip result written below.
        if memory is not None:
            memory.append(observation.cpu(), log_prob, state_value, 0.0, action)

    # Score the finished game from the cards left in each hand.
    remains = [env.board.hands[seat].card_tensor.sum().item() for seat in range(players)]
    n_two = [env.board.hands[seat].card_tensor[:, :, num_to_int['2']].sum().item() for seat in range(players)]
    chips = _settle_chips(remains, n_two)

    # Credit each seat's chip result (scaled by 1/10) to its last recorded transition.
    for seat, memory in enumerate(memorys):
        if memory is not None:
            memory.rewards[-1] = chips[seat] / 10.0

    result = memorys[0].get_tensor()
    result['chips'] = torch.Tensor([chips[0]])
    return result

In [4]:
def agent_train(policy: ActorGreedyPolicy, optimizer: torch.optim.Optimizer, data_loader: DataLoader):
    """Run `epoch` passes of PPO-style updates on `policy.model` over `data_loader`.

    The model is switched to train mode and moved to DEVICE for the update, then put back in
    eval mode on the CPU. `Module.to`/`.cpu` move parameters in place, so `optimizer` keeps
    pointing at the same parameter objects.

    Loss minimised per mini-batch:
        critic  : MSE between Q(s, a_taken) and `target_action_values` (every sample)
        actor   : -c1 * mean clipped-ratio surrogate        (states with > 1 legal action)
        entropy : -c2 * mean entropy of the masked policy     (same subset)
        greedy  : -greedy_coef * mean log-prob of the critic's best legal action (same subset)

    Args:
        policy: holder of `.model`, which maps states viewed as (-1, 21, 13) to
            `(action_values, logits)`; both are flattened to one score per action.
        optimizer: optimizer over `policy.model`'s parameters.
        data_loader: yields `(target_action_values, advantages, old_log_probs, states,
            actions, masking)`; `masking` is used as a per-action legality mask.
    """
    model = policy.model.train().to(DEVICE)
    for _ in range(epoch):
        for target_action_values, advantages, old_log_probs, states, actions, masking in data_loader:
            n_sample = states.size(0)
            actions = actions.long()
            masking = masking.bool()

            action_values, logits = model(states.view(-1, 21, 13))
            action_values = action_values.view(n_sample, -1)
            logits = logits.view(n_sample, -1)

            # Critic: regress the value of the action actually taken onto its target.
            new_values = torch.gather(action_values, dim=1, index=actions.unsqueeze(dim=-1)).flatten()
            loss = F.mse_loss(new_values, target_action_values)

            # Actor terms only use states with a real choice (a lone legal action has probability 1
            # and carries no policy gradient). If the whole batch is forced moves, skip them rather
            # than averaging empty tensors into NaN loss terms.
            target = masking.sum(dim=-1) > 1
            if target.any():
                legal = masking[target]
                # Illegal actions get a huge negative logit, i.e. (numerically) zero probability.
                log_probs = torch.log_softmax(torch.where(legal, logits[target], -1e38), dim=-1)
                new_log_probs = torch.gather(log_probs, dim=1, index=actions[target].unsqueeze(dim=-1)).flatten()

                # Clipped surrogate in log space: the ratio is capped at exp(log_clip) when the
                # advantage is positive and floored at exp(-log_clip) otherwise.
                adv = advantages[target]
                log_ratio = torch.clamp(new_log_probs - old_log_probs[target], -logit_diff_clamp, logit_diff_clamp)
                clipped_log_ratio = torch.where(adv > 0,
                                                torch.clamp(log_ratio, max=log_clip),
                                                torch.clamp(log_ratio, min=-log_clip))
                actor_loss = -(torch.exp(clipped_log_ratio) * adv).mean() * c1

                # Entropy bonus over the whole masked distribution; masked entries have exp(lp) == 0
                # and a finite lp, so they contribute exactly 0 (no NaN).
                entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
                entropy_penalty = -entropy.mean() * c2

                # Greedy bonus: raise the log-probability of the critic's highest-valued legal action.
                greedy_actions = torch.where(legal, action_values[target], -1e38).argmax(dim=1)
                greedy_log_probs = torch.gather(log_probs, dim=1, index=greedy_actions.unsqueeze(dim=-1)).flatten()
                max_action_penalty = (-greedy_coef * greedy_log_probs).mean()

                loss = loss + actor_loss + entropy_penalty + max_action_penalty

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    torch.cuda.empty_cache()
    model.eval().cpu()




In [None]:

# Main loop. Every `sync_preiod` iterations the opponents (seats 1-3) receive the snapshot
# taken at the previous sync and a new snapshot of the learner is taken; every iteration
# collects `n_epsiode` games in parallel with Ray and runs one PPO update on seat 0.
n_iteration = 100000
# NOTE(review): presumably "no fixed cards" for any of the 4 seats - confirm against Big2Env.reset.
fixed_cards = np.zeros((4, 52))

for i in tqdm(range(n_iteration)):
    if i % sync_preiod == 0:
        # Separate index name so the outer iteration counter `i` is not clobbered.
        for seat in range(1, 4):
            policies[seat].model.load_state_dict(sync_target_model_param)
        sync_target_model_param = deepcopy(policies[0].model.state_dict())

    # The policies carry every network's weights: serialise them into the object store once
    # per iteration instead of once per remote call (Ray resolves the ref for each task).
    policies_ref = ray.put(policies)
    trajectories = ray.get([
        get_trajectory_ray.remote(env, memorys, policies_ref, fixed_cards)
        for _ in range(n_epsiode)
    ])
    train_tensors = concat_dict_tensors(trajectories)

    # Seat 0's per-game chip results are for monitoring only, not part of the training data.
    chips.append(train_tensors.pop('chips'))

    ds = PPODataset(train_tensors, DEVICE)
    dl = DataLoader(ds, batch_size, shuffle=True)
    agent_train(policies[0], optimizer, dl)

    # TODO(review): `save_period` is defined in the config cell but no checkpoint is ever written.
    del dl, ds, trajectories, train_tensors
    gc.collect()
    

    
            