In [1]:
import argparse
import os
from copy import deepcopy
from typing import Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
from pettingzoo.classic import hanabi_v4
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import (
    BasePolicy,
    RainbowPolicy,
    MultiAgentPolicyManager,
    RandomPolicy,
)
from tianshou.utils.net.discrete import NoisyLinear
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from pettingzoo.classic.hanabi.hanabi import raw_env
from pettingzoo.classic.hanabi.hanabi import raw_env
from typing import Dict, List, Optional, Union

import gymnasium
import numpy as np
from gymnasium import spaces
from gymnasium.utils import EzPickle

from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector, wrappers
from hanabi_learning_environment.rl_env import HanabiEnv

In [2]:
# Hyperparameters for the whole notebook. Per the original author's note, these
# values are a copy of the Rainbow configuration used in the DeepMind paper.
p = {
    'hidden_layers': [128, 128],           # hidden_sizes of the Q-network (Net)
    'gamma': 0.99,                         # discount factor handed to RainbowPolicy
    'lr': 1e-4,                            # Adam learning rate
    'target_update_freq': 500,             # RainbowPolicy target_update_freq
    'estimation_steps': 1,                 # RainbowPolicy estimation_step
    'num_train': 32,                       # number of environments in train_envs
    'num_test': 32,                        # number of environments in test_envs
    'buffer_size': 50000,                  # size given to the prioritized replay buffer
    'vmax': 25,                            # RainbowPolicy v_max
    'vmin': -25,                           # RainbowPolicy v_min
    'noisy_std': 0.1,                      # third argument of every NoisyLinear layer
    'atom_size': 51,                       # num_atoms for both Net and RainbowPolicy
    'minimum_replay_history': 512,         # steps collected before the first update
    'batch_size': 32,                      # batch size of each policy.update call
    'steps_per_collect': 10016,            # environment steps collected per epoch
    'updates_per_train': 1563,             # policy.update calls per epoch
    'test_steps': 20000,                   # environment steps per evaluation
    'epochs': 5000,                        # iterations of the training loop
    'eps_decay_period': 200,               # epochs over which eps goes from 1 to eps_final
    'test_frequency': 3,                   # evaluate every this many epochs
    'test_eps': 0,                         # eps used while evaluating
    'save_frequency': 25,                  # save weights/history/plot every this many epochs
    'eps_final': 0.01,                     # eps once the decay period is over
    'adam_eps': 3.125e-5,                  # eps of the Adam optimiser
    'path': 'results_rainbow_infinite/',   # prefix prepended to every saved file name
    'lr_scheduler_factor': 0.5,            # ReduceLROnPlateau factor
    'lr_scheduler_patience': 5             # ReduceLROnPlateau patience
}


In [3]:
class raw_env_overwrite(raw_env):
    """Hanabi `raw_env` subclass whose constructor fully replaces the parent's.

    The parent `__init__` is never invoked: this constructor registers its
    arguments with `EzPickle`, builds the underlying `HanabiEnv` directly from
    them, resets the game and then derives the action/observation spaces from
    the resulting `HanabiEnv`. No validation of the arguments happens here.
    """

    def __init__(
        self,
        colors: int = 5,
        ranks: int = 5,
        players: int = 2,
        hand_size: int = 5,
        max_information_tokens: int = 8,
        max_life_tokens: int = 3,
        observation_type: int = 1,
        random_start_player: bool = False,
        render_mode: Optional[str] = None,
    ):
        """Create the Hanabi game and its per-agent spaces.

        Every argument except `render_mode` is forwarded to `HanabiEnv` under
        a config key of the same name; `render_mode` is only stored on the
        instance.
        """
        # Record the constructor arguments with EzPickle.
        EzPickle.__init__(
            self,
            colors,
            ranks,
            players,
            hand_size,
            max_information_tokens,
            max_life_tokens,
            observation_type,
            random_start_player,
            render_mode,
        )

        # NOTE: no range check is performed on these values in this constructor;
        # they are handed to HanabiEnv as-is.
        # NOTE(review): presumably the parent constructor validated them - confirm
        # that skipping the check is intentional.

        self._config = {
            "colors": colors,
            "ranks": ranks,
            "players": players,
            "hand_size": hand_size,
            "max_information_tokens": max_information_tokens,
            "max_life_tokens": max_life_tokens,
            "observation_type": observation_type,
            "random_start_player": random_start_player,
        }
        self.hanabi_env: HanabiEnv = HanabiEnv(config=self._config)

        # List of agent names
        self.agents = [f"player_{i}" for i in range(self.hanabi_env.players)]
        # Independent copy, so later changes to self.agents do not alter it.
        self.possible_agents = self.agents[:]

        # Type declaration only; no value is assigned on this line.
        self.agent_selection: str

        # Sets hanabi game to clean state and updates all internal dictionaries
        self.reset()

        # Set action_spaces and observation_spaces based on params in hanabi_env
        self.action_spaces = {
            name: spaces.Discrete(self.hanabi_env.num_moves()) for name in self.agents
        }
        # Each agent observes a dict with a float32 vector in [0, 1] and an int8
        # mask with one entry per move.
        self.observation_spaces = {
            player_name: spaces.Dict(
                {
                    "observation": spaces.Box(
                        low=0,
                        high=1,
                        shape=(self.hanabi_env.vectorized_observation_shape()[0],),
                        dtype=np.float32,
                    ),
                    "action_mask": spaces.Box(
                        low=0,
                        high=1,
                        shape=(self.hanabi_env.num_moves(),),
                        dtype=np.int8,
                    ),
                }
            )
            for player_name in self.agents
        }

        self.render_mode = render_mode
        
class HanabiScorePenalty:
    """Number-like object whose float value is minus the current Hanabi score.

    The score is read from `env.hanabi_env.state` each time the object is
    converted with `float()`, not when the object is created.
    """

    def __init__(self, env):
        # Keep a handle on the environment; the score is looked up lazily.
        self.env = env

    def __float__(self):
        current_score = self.env.hanabi_env.state.score()
        return -float(current_score)

def env(**kwargs):
    """Build a wrapped `raw_env_overwrite` from keyword arguments.

    A requested render mode of "ansi" is rewritten to "human" and the game is
    wrapped so that its stdout is captured. The game is then wrapped, in this
    order, by the illegal-move terminator (whose illegal reward is a
    `HanabiScorePenalty` bound to the game), the out-of-bounds assertion
    wrapper and the order-enforcing wrapper.
    """
    capture_stdout = kwargs.get("render_mode") == "ansi"
    if capture_stdout:
        kwargs["render_mode"] = "human"

    game = raw_env_overwrite(**kwargs)
    if capture_stdout:
        game = wrappers.CaptureStdoutWrapper(game)

    # The penalty object is bound to the game as it exists before this wrapper.
    game = wrappers.TerminateIllegalWrapper(game, illegal_reward=HanabiScorePenalty(game))
    for wrap in (wrappers.AssertOutOfBoundsWrapper, wrappers.OrderEnforcingWrapper):
        game = wrap(game)
    return game

In [4]:
def get_env(render_mode=None):
    """Create the small Hanabi game used in this notebook, wrapped for Tianshou.

    The game has 1 color, 5 ranks, 2 players, 2 cards per hand, 10 information
    tokens, 1 life token and observation type 1.

    Args:
        render_mode: forwarded to the environment factory. Previously this
            argument was accepted but silently ignored. The default `None`
            keeps the earlier behaviour.

    Returns:
        A `PettingZooEnv` around the wrapped Hanabi environment.
    """
    return PettingZooEnv(
        env(
            colors=1,
            ranks=5,
            players=2,
            hand_size=2,
            max_information_tokens=10,
            max_life_tokens=1,
            observation_type=1,
            render_mode=render_mode,
        )
    )

In [5]:
def get_agents(p):
    """Build the Rainbow policy, the agent names and the vectorised environments.

    One `RainbowPolicy` instance is created and used for both seats of the
    `MultiAgentPolicyManager`, so the two players share the same network.

    Args:
        p: hyperparameter dict (network, optimiser, scheduler and env counts).

    Returns:
        Tuple `(policy, agents, train_envs, test_envs, lr_scheduler)` where
        `agents` is the agent list reported by the environment.
    """
    def noisy_linear(in_dim, out_dim):
        return NoisyLinear(in_dim, out_dim, p['noisy_std'])

    # A throwaway environment, used only to read the spaces and agent names.
    probe_env = get_env()
    if isinstance(probe_env.observation_space, gym.spaces.Dict):
        obs_space = probe_env.observation_space['observation']
    else:
        obs_space = probe_env.observation_space
    act_space = probe_env.action_space

    state_shape = obs_space.shape or obs_space.n
    action_shape = act_space.shape or act_space.n
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Both dueling heads are built from noisy linear layers.
    q_head_cfg = {'linear_layer': noisy_linear}
    v_head_cfg = {'linear_layer': noisy_linear}
    net = Net(
        state_shape,
        action_shape,
        hidden_sizes=p['hidden_layers'],
        device=device,
        softmax=True,
        num_atoms=p['atom_size'],
        dueling_param=(q_head_cfg, v_head_cfg),
    )

    optim = torch.optim.Adam(net.parameters(), lr=p['lr'], eps=p['adam_eps'])
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim,
        mode='max',
        factor=p['lr_scheduler_factor'],
        patience=p['lr_scheduler_patience'],
    )

    shared_agent = RainbowPolicy(
        net,
        optim,
        p['gamma'],
        num_atoms=p['atom_size'],
        v_min=p['vmin'],
        v_max=p['vmax'],
        estimation_step=p['estimation_steps'],
        target_update_freq=p['target_update_freq'],
    ).to(device)

    policy = MultiAgentPolicyManager([shared_agent, shared_agent], probe_env)
    agent_names = probe_env.agents

    train_envs = DummyVectorEnv([get_env for _ in range(p['num_train'])])
    test_envs = DummyVectorEnv([get_env for _ in range(p['num_test'])])

    return policy, agent_names, train_envs, test_envs, lr_scheduler

In [6]:
def get_collectors(policy, train_envs, test_envs, p):
    """Create the training and evaluation collectors for `policy`.

    The training collector stores into a prioritized vector replay buffer of
    total size `p['buffer_size']` with one sub-buffer per training
    environment (alpha 0.6, beta 0.4, weight normalisation on). The test
    collector has no explicit buffer. Both collectors are created with
    `exploration_noise=True`.

    Returns:
        Tuple `(train_collector, test_collector)`.
    """
    replay_buffer = PrioritizedVectorReplayBuffer(
        p['buffer_size'],
        len(train_envs),
        alpha=0.6,
        beta=0.4,
        weight_norm=True,
    )
    train_collector = Collector(policy, train_envs, replay_buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    return train_collector, test_collector

In [7]:
def initialize_buffer(train_collector, agents, policy, p):
    """Warm up the replay buffer before training starts.

    Sets eps to 1 on the policy of every agent in `agents`, then collects
    `p['minimum_replay_history']` steps with the training collector.
    """
    for agent_id in agents:
        agent_policy = policy.policies[agent_id]
        agent_policy.set_eps(1)

    warmup_steps = p['minimum_replay_history']
    train_collector.collect(n_step=warmup_steps)

In [8]:
def save_policy(policy, agents, p):
    """Write the state dict of each agent's policy to `<p['path']><agent>_params.pth`.

    `p['path']` is used as a plain string prefix, exactly as before. The
    directory part of each target file is now created when it does not exist,
    so the first save of a fresh run no longer fails on a missing directory.

    Args:
        policy: object exposing a `policies` mapping from agent name to policy.
        agents: iterable of agent names to save.
        p: hyperparameter dict; only `p['path']` is read.
    """
    for a in agents:
        target = f'{p["path"]}{a}_params.pth'
        parent_dir = os.path.dirname(target)
        if parent_dir:  # empty when saving into the current working directory
            os.makedirs(parent_dir, exist_ok=True)
        torch.save(policy.policies[a].state_dict(), target)

def save_history(history, p):
    """Save `history` as a NumPy array to `<p['path']>training_rewards.npy`.

    `p['path']` is used as a plain string prefix, exactly as before. The
    directory part of the target file is now created when it does not exist,
    instead of the save failing on a missing directory.

    Args:
        history: sequence of recorded rewards.
        p: hyperparameter dict; only `p['path']` is read.
    """
    target = f'{p["path"]}training_rewards.npy'
    parent_dir = os.path.dirname(target)
    if parent_dir:  # empty when saving into the current working directory
        os.makedirs(parent_dir, exist_ok=True)
    np.save(target, np.array(history))
    
def change_lr(optimizer, new_lr):
    """Set the learning rate of every parameter group of `optimizer` to `new_lr`."""
    for group in optimizer.param_groups:
        group.update(lr=new_lr)

In [9]:
def get_eps(iteration, p):
    """Return the exploration rate for training epoch `iteration`.

    The rate decays linearly from 1 at iteration 0 to `p['eps_final']` at
    iteration `p['eps_decay_period']` and stays at `p['eps_final']` after
    that. A decay period of 0 or less means "no decay" and yields
    `p['eps_final']` immediately; a period of 0 previously raised
    ZeroDivisionError.

    Args:
        iteration: zero-based epoch index.
        p: hyperparameter dict; reads `eps_final` and `eps_decay_period`.

    Returns:
        The epsilon to use for this epoch.
    """
    eps_final = p['eps_final']
    decay_period = p['eps_decay_period']
    if decay_period <= 0 or iteration > decay_period:
        return eps_final
    gradient = (1 - eps_final) / decay_period
    return 1 - gradient * iteration
        
def set_eps(policy, agents, new_eps):
    """Apply the exploration rate `new_eps` to the policy of every agent in `agents`."""
    for agent_id in agents:
        agent_policy = policy.policies[agent_id]
        agent_policy.set_eps(new_eps)
        
def train(
    policy,
    train_collector,
    test_collector,
    agents,
    p,
    lr_scheduler,
    training_history = None
):
    """Run the collect / evaluate / save / update loop for `p['epochs']` epochs.

    Each epoch:
      1. sets the exploration rate from `get_eps` and collects
         `p['steps_per_collect']` training steps;
      2. every `p['test_frequency']` epochs, evaluates with eps
         `p['test_eps']` for `p['test_steps']` steps, prints and records the
         mean reward, restores the training eps and steps the LR scheduler on
         that mean reward;
      3. every `p['save_frequency']` epochs, saves the weights, the history
         and the training curve;
      4. performs `p['updates_per_train']` policy updates on the training
         collector's buffer.

    After the last epoch the training curve is plotted and saved once more.

    Args:
        policy: multi-agent policy manager being trained.
        train_collector: collector that fills the replay buffer.
        test_collector: collector used for evaluation.
        agents: agent names whose eps is adjusted.
        p: hyperparameter dict.
        lr_scheduler: scheduler stepped with the mean test reward.
        training_history: list that mean test rewards are appended to, in
            place. When omitted, a new list is created for this call; the
            previous `[]` default was one list shared by every call.
    """
    if training_history is None:
        training_history = []

    for i in tqdm(range(p['epochs'])):

        eps = get_eps(i, p)
        set_eps(policy, agents, eps)

        # Collection step
        train_collector.collect(n_step = p['steps_per_collect'])

        # Test step
        if i % p['test_frequency'] == 0:
            set_eps(policy, agents, p['test_eps'])
            result = test_collector.collect(n_step = p['test_steps'])
            mean_reward = result['rews'].mean()
            tqdm.write(str(mean_reward))
            training_history.append(mean_reward)
            # Back to the training exploration rate for the rest of the epoch.
            set_eps(policy, agents, eps)
            lr_scheduler.step(mean_reward)

        # Periodic checkpoint (also fires at epoch 0)
        if i % p['save_frequency'] == 0:
            save_policy(policy, agents, p)
            save_history(training_history, p)
            plot_and_save(training_history, p['test_frequency'], p, show = False)

        # Update step (one epoch)
        for _ in range(p['updates_per_train']):
            policy.update(p['batch_size'], train_collector.buffer)

    # The final plot used an undefined `test_frequency` name and omitted `p`.
    plot_and_save(training_history, p['test_frequency'], p)
        

In [10]:
def plot_and_save(training_history, test_frequency, p, save = True, show = True):
    """Plot the recorded mean test rewards against the epoch they were measured at.

    Entry `k` of `training_history` is drawn at epoch `k * test_frequency`.

    Args:
        training_history: sequence of mean test rewards.
        test_frequency: number of epochs between two evaluations. The scaling
            is no longer done in place on an integer array, so a non-integer
            value is accepted too.
        p: hyperparameter dict; `p['path']` is the prefix of the saved image.
        save: when true, write `<p['path']>training_curve.png`, creating the
            directory part of that path if it is missing.
        show: when true, display the figure; otherwise close it.
    """
    epochs = np.arange(len(training_history)) * test_frequency

    # A dedicated figure, so nothing is drawn onto a leftover pyplot figure.
    fig, ax = plt.subplots()
    ax.plot(epochs, training_history)
    # NOTE(review): this title mentions "Cheating DQN" although the notebook
    # trains a RainbowPolicy - confirm whether the wording is still intended.
    ax.set_title('Average Score (Cheating DQN, 1 Color game)')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Average Score (max 5)')

    if save:
        target = f'{p["path"]}training_curve.png'
        parent_dir = os.path.dirname(target)
        if parent_dir:  # empty when saving into the current working directory
            os.makedirs(parent_dir, exist_ok=True)
        fig.savefig(target)

    if show:
        plt.show()
    else:
        plt.close(fig)
        
def load(policy, agents, p):
    """Restore saved agent weights and return the saved reward history.

    For every agent in `agents`, loads `<p['path']><agent>_params.pth` into
    that agent's policy. Then reads `<p['path']>training_rewards.npy`.

    Returns:
        The stored rewards as a Python list.
    """
    for agent_id in agents:
        agent_policy = policy.policies[agent_id]
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        agent_policy.load_state_dict(torch.load(f'{p["path"]}{agent_id}_params.pth'))

    stored_rewards = np.load(f'{p["path"]}training_rewards.npy')
    return list(stored_rewards)

In [11]:
# Build the policy, the environments and the two collectors, then pre-fill the
# replay buffer: initialize_buffer sets eps to 1 for every agent and collects
# p['minimum_replay_history'] steps before any update is made.
policy, agents, train_envs, test_envs, lr_scheduler = get_agents(p)
train_collector, test_collector = get_collectors(policy, train_envs, test_envs, p)
initialize_buffer(train_collector, agents, policy, p)



In [12]:
# Start with an empty reward history. To resume from saved weights and history
# instead, swap the assignment below for the commented-out load call.
#training_history = load(policy, agents,p)
training_history = []

In [13]:
# Run the training loop; mean test rewards are appended to training_history in place.
# NOTE(review): the recorded output of this cell ends with
# "ValueError: shape mismatch: value array of shape (32,7) ..."; the failing call
# is not visible in the captured output and still needs to be tracked down.
train(policy, train_collector, test_collector, agents, p, lr_scheduler, training_history = training_history)

  0%|          | 0/5000 [00:00<?, ?it/s]

0.153427144802542




1.0284463894967177
1.553030303030303
1.9469214437367304
1.9474248927038627
2.074235807860262
2.189044038668099
2.6094736842105264
2.6102564102564103
2.029835390946502
3.011178861788618
2.8982528263103804
2.994908350305499
3.302938196555218
3.2436289500509683
3.6535044422507403
3.688311688311688
3.6726190476190474


ValueError: shape mismatch: value array of shape (32,7) could not be broadcast to indexing result of shape (32,)

In [None]:
# Evaluation: rebuild a fresh policy, environments and collectors with the same
# two calls used for the training setup. The replay buffer is not pre-filled here.
policy, agents, train_envs, test_envs, lr_scheduler = get_agents(p)
train_collector, test_collector = get_collectors(policy, train_envs, test_envs, p)

In [None]:
# Load the saved weights into the fresh policy; the returned reward history is
# not needed here and is discarded.
_ = load(policy, agents, p)

In [None]:
# Evaluate with eps = 0 for every agent, over 200000 collected steps.
set_eps(policy, agents, 0)
result = test_collector.collect(n_step = 200000)

In [None]:
# Keep column 1 of the collected rewards, one entry per row of result['rews'].
# NOTE(review): presumably rows are finished games and column 1 is the second
# agent's return - confirm against the collector's output format.
rewards = result['rews'][:,1]
print(len(rewards))

In [None]:
# Share of games (in percent) that ended with each score from 0 to 5.
games_played = len(rewards)
x = [str(score) for score in range(6)]
heights = np.array([np.sum(rewards == score) for score in range(6)], dtype=float)
percentages = heights*100/games_played
# Sanity check: prints 100 when every reward is one of the scores 0..5.
print(sum(percentages))

In [None]:
# Bar chart of the score distribution, one annotated bar per final score.
# The image goes to a fixed 'results' directory (not p['path']); create it
# first so savefig does not fail on a fresh checkout.
os.makedirs('results', exist_ok=True)

fig, ax = plt.subplots(figsize=(5,5))
# NOTE(review): the legend label says 'DQN' while the title says Rainbow -
# confirm which wording is intended.
ax.bar(x, percentages, label = 'DQN', alpha = 0.7)
for r,f in zip(x, percentages):
    # Write the percentage just above each bar.
    ax.annotate(f'{round(f,2)}%', (r,f), ha='center', va='bottom')
ax.set_ylabel('% of games')
ax.set_xlabel('Score of games')
ax.set_title('Score breakdown of Rainbow')
ax.legend()
fig.savefig('results/rainbow_score_breakdown.png')
plt.show()