In [1]:
import argparse
import os
from copy import deepcopy
from typing import Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
from pettingzoo.classic import hanabi_v4
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import (
    BasePolicy,
    RainbowPolicy,
    MultiAgentPolicyManager,
    RandomPolicy,
)
from tianshou.utils.net.discrete import NoisyLinear
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [2]:
# Hyperparameters for the Rainbow agent (original note: "copy of rainbow used
# in deepmind paper"). Every entry is read by key from the helpers below.
p = dict(
    # --- network / optimisation ---
    hidden_layers=[256, 256],        # hidden_sizes of the tianshou Net
    gamma=0.99,                      # discount factor
    lr=1e-4,                         # Adam learning rate
    target_update_freq=500,          # passed to RainbowPolicy(target_update_freq=...)
    estimation_steps=1,              # n-step return length (RainbowPolicy estimation_step)
    # --- environments ---
    num_train=32,                    # number of training envs in the DummyVectorEnv
    num_test=16,                     # number of evaluation envs
    # --- replay / distributional head ---
    buffer_size=50000,               # first argument of PrioritizedVectorReplayBuffer
    vmax=25,                         # upper bound of the return-distribution support
    vmin=-25,                        # lower bound of the return-distribution support
    noisy_std=0.1,                   # third argument to NoisyLinear
    atom_size=51,                    # num_atoms of the categorical distribution
    minimum_replay_history=500,      # random (eps=1) warm-up steps before training
    batch_size=32,                   # samples per policy.update call
    # --- training loop ---
    steps_per_collect=4,             # env steps collected per loop iteration
    updates_per_train=1,             # gradient updates per loop iteration
    test_steps=10000,                # env steps per evaluation
    epochs=10_000_000,               # number of loop iterations in train() (not data passes)
    eps_decay_period=1000,           # iterations over which epsilon decays linearly
    test_frequency=1000,             # evaluate every N iterations
    test_eps=0,                      # epsilon used during evaluation
    save_frequency=10000,            # checkpoint every N iterations
    eps_final=0.01,                  # epsilon floor after decay
    adam_eps=3.125e-5,               # Adam epsilon
    path='saved_data/rainbow_hanabi_small/',  # prefix for every saved file
    lr_scheduler_factor=0.5,         # ReduceLROnPlateau factor
)


In [3]:
def get_env(render_mode=None):
    """Create one small two-player Hanabi game wrapped as a Tianshou PettingZooEnv.

    The game uses 2 colours x 5 ranks, 2-card hands, 3 information tokens and
    a single life token.

    Args:
        render_mode: accepted for signature compatibility but currently NOT
            forwarded to the underlying environment.

    NOTE(review): confirm which Hanabi observation encoding
    ``observation_type=1`` selects.
    """
    hanabi_config = dict(
        colors=2,
        ranks=5,
        players=2,
        hand_size=2,
        max_information_tokens=3,
        max_life_tokens=1,
        observation_type=1,
    )
    raw_env = hanabi_v4.env(**hanabi_config)
    return PettingZooEnv(raw_env)

In [4]:
def get_agents(p):
    """Build the Rainbow self-play policy together with its environments.

    Args:
        p: hyperparameter dict (reads hidden_layers, atom_size, noisy_std, lr,
            adam_eps, lr_scheduler_factor, gamma, vmin, vmax, estimation_steps,
            target_update_freq, num_train, num_test).

    Returns:
        ``(policy, agents, train_envs, test_envs, lr_scheduler)``:
        ``policy`` is a MultiAgentPolicyManager in which both seats use the
        *same* RainbowPolicy object; ``agents`` is ``env.agents``; the two
        env groups are DummyVectorEnv instances of ``get_env``; and
        ``lr_scheduler`` is a ReduceLROnPlateau (mode='max') on the shared
        optimiser.
    """
    # Factory handed to Net so both dueling heads are built from NoisyLinear.
    def noisy_linear(x, y):
        return NoisyLinear(x, y, p['noisy_std'])

    env = get_env()
    if isinstance(env.observation_space, gym.spaces.Dict):
        # Presumably observation + action mask; size the net on the observation part.
        observation_space = env.observation_space['observation']
    else:
        observation_space = env.observation_space

    state_shape = observation_space.shape or observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Two independent kwargs dicts: (Q-head, V-head) of the dueling architecture.
    dueling_heads = tuple({'linear_layer': noisy_linear} for _ in range(2))
    net = Net(
        state_shape,
        action_shape,
        hidden_sizes=p['hidden_layers'],
        device=device,
        softmax=True,
        num_atoms=p['atom_size'],
        dueling_param=dueling_heads,
    )

    optim = torch.optim.Adam(net.parameters(), lr=p['lr'], eps=p['adam_eps'])
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim,
        mode='max',
        factor=p['lr_scheduler_factor'],
        patience=10,
    )

    rainbow = RainbowPolicy(
        net,
        optim,
        p['gamma'],
        num_atoms=p['atom_size'],
        v_min=p['vmin'],
        v_max=p['vmax'],
        estimation_step=p['estimation_steps'],
        target_update_freq=p['target_update_freq'],
    ).to(device)

    # Both seats are driven by one policy object -> shared weights (self-play).
    policy = MultiAgentPolicyManager([rainbow, rainbow], env)
    agent_ids = env.agents

    def make_vector_env(count):
        return DummyVectorEnv([get_env for _ in range(count)])

    train_envs = make_vector_env(p['num_train'])
    test_envs = make_vector_env(p['num_test'])

    return policy, agent_ids, train_envs, test_envs, lr_scheduler

In [5]:
def get_collectors(
    policy,
    train_envs,
    test_envs,
    p
):
    """Create the training and evaluation collectors for ``policy``.

    The training collector writes into a PrioritizedVectorReplayBuffer of size
    ``p['buffer_size']`` with one sub-buffer per training env. The test
    collector has no buffer. Both collectors are created with
    ``exploration_noise=True``.

    Returns:
        ``(train_collector, test_collector)``.
    """
    # alpha: prioritisation exponent; beta: importance-sampling exponent.
    replay_buffer = PrioritizedVectorReplayBuffer(
        p['buffer_size'],
        len(train_envs),
        alpha=0.6,
        beta=0.4,
        weight_norm=True,
    )
    train_collector = Collector(policy, train_envs, replay_buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    return train_collector, test_collector

In [6]:
def initialize_buffer(
    train_collector,
    agents,
    policy,
    p
):
    """Warm up the replay buffer with purely random play.

    Sets epsilon to 1 for every agent's policy, then collects
    ``p['minimum_replay_history']`` environment steps through
    ``train_collector``. The collector's result is discarded.
    """
    for agent_id in agents:
        sub_policy = policy.policies[agent_id]
        sub_policy.set_eps(1)  # eps=1 -> every action is random
    warmup_steps = p['minimum_replay_history']
    train_collector.collect(n_step=warmup_steps)

In [7]:
def save_policy(policy, agents, p):
    """Save every agent's policy weights to ``{p['path']}{agent}_params.pth``.

    Args:
        policy: object whose ``policies[agent]`` exposes ``state_dict()``.
        agents: iterable of agent ids to save.
        p: config dict; ``p['path']`` is used as a filename prefix.

    The directory part of the target path is created if it does not exist yet,
    because ``torch.save`` does not create parent directories. If several agents
    share one policy object, the same weights are written once per agent.
    """
    for a in agents:
        file_path = f'{p["path"]}{a}_params.pth'
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(policy.policies[a].state_dict(), file_path)

def save_history(history, p):
    """Save the evaluation-score history to ``{p['path']}training_rewards.npy``.

    Args:
        history: sequence of mean test rewards, converted with ``np.array``.
        p: config dict; ``p['path']`` is used as a filename prefix.

    The directory part of the target path is created if it is missing, because
    ``np.save`` does not create parent directories.
    """
    file_path = f'{p["path"]}training_rewards.npy'
    directory = os.path.dirname(file_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    np.save(file_path, np.array(history))
    
def change_lr(optimizer, new_lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``new_lr``.

    Useful for manually lowering the LR mid-run (e.g. ``change_lr(optim, 1e-5)``).
    Other per-group settings are left untouched.
    """
    for param_group in optimizer.param_groups:
        param_group.update(lr=new_lr)

In [8]:
def get_eps(iteration, p):
    """Return the exploration epsilon for training iteration ``iteration``.

    Epsilon falls linearly from 1 at iteration 0 to ``p['eps_final']`` at
    iteration ``p['eps_decay_period']``, and stays at ``p['eps_final']`` after that.
    """
    period = p['eps_decay_period']
    # Past the decay window, epsilon stays at its floor.
    if iteration > period:
        return p['eps_final']
    slope = (1 - p['eps_final']) / period
    return 1 - slope * iteration
        
def set_eps(policy, agents, new_eps):
    """Apply ``new_eps`` as the exploration epsilon of each listed agent's policy."""
    for agent_id in agents:
        agent_policy = policy.policies[agent_id]
        agent_policy.set_eps(new_eps)
        
def train(
    policy,
    train_collector,
    test_collector,
    agents,
    p,
    lr_scheduler,
    training_history = None
):
    """Run the hand-written Rainbow training loop for ``p['epochs']`` iterations.

    Each iteration does the following:
      1. Sets the linearly decayed epsilon (``get_eps``) on every agent.
      2. Collects ``p['steps_per_collect']`` env steps into the replay buffer.
      3. Every ``p['test_frequency']`` iterations, evaluates for
         ``p['test_steps']`` steps with ``p['test_eps']``. It then appends the
         mean episode reward to ``training_history`` and steps
         ``lr_scheduler`` with that reward.
      4. Every ``p['save_frequency']`` iterations, checkpoints the weights and
         history and writes the learning-curve PNG.
      5. Runs ``p['updates_per_train']`` gradient updates of
         ``p['batch_size']`` samples.

    Args:
        training_history: list to extend with evaluation scores (e.g. one
            restored by ``load`` to resume a run). A fresh list is used when omitted.
    """
    # A mutable default ([]) would be shared between calls and would silently
    # carry scores over from earlier runs, so a new list is created per call.
    if training_history is None:
        training_history = []

    for i in tqdm(range(p['epochs'])):

        eps = get_eps(i, p)
        set_eps(policy, agents, eps)

        # Collection step
        train_collector.collect(n_step=p['steps_per_collect'])

        # Evaluation step: run greedily-ish at test_eps, then restore training eps.
        if i % p['test_frequency'] == 0:
            set_eps(policy, agents, p['test_eps'])
            test_result = test_collector.collect(n_step=p['test_steps'])
            mean_reward = test_result['rews'].mean()
            tqdm.write(str(mean_reward))
            training_history.append(mean_reward)
            set_eps(policy, agents, eps)
            lr_scheduler.step(mean_reward)

        if i % p['save_frequency'] == 0:
            save_policy(policy, agents, p)
            save_history(training_history, p)
            plot_and_save(training_history, p['test_frequency'], p, show=False)

        # Update step
        for _ in range(p['updates_per_train']):
            policy.update(p['batch_size'], train_collector.buffer)

    # The original call used an undefined name `test_frequency` and omitted `p`,
    # which raised NameError once training finished.
    plot_and_save(training_history, p['test_frequency'], p)
        

In [9]:
def plot_and_save(training_history, test_frequency, p, save = True, show = True):
    """Plot the evaluation-score curve, optionally saving and/or displaying it.

    Args:
        training_history: mean test rewards, one entry per evaluation.
        test_frequency: iterations between consecutive evaluations; scales the x-axis.
        p: config dict; the figure is written to ``{p['path']}training_curve.png``.
        save: write the PNG (its directory is created if missing).
        show: display the figure with ``plt.show()``; otherwise it is closed.
    """
    # Not in-place: `x *= test_frequency` on an int array fails for a float frequency.
    x = np.arange(len(training_history)) * test_frequency

    # An explicit figure avoids drawing onto whatever pyplot figure happens to be current.
    fig, ax = plt.subplots()
    ax.plot(x, training_history)
    ax.set_title('Combined Average Score (Rainbow, 2 Color game)')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Average Score (max 10)')

    if save:
        file_path = f'{p["path"]}training_curve.png'
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        fig.savefig(file_path)
    if show:
        plt.show()
    else:
        plt.close(fig)
        
def load(policy, agents, p):
    """Restore saved policy weights and the evaluation history.

    Reads ``{p['path']}{agent}_params.pth`` for each agent into
    ``policy.policies[agent]``. Returns ``{p['path']}training_rewards.npy``
    as a list.

    Checkpoints are loaded onto the CPU first (``map_location='cpu'``), so
    weights saved on a GPU machine can be restored on a CPU-only one.
    ``load_state_dict`` then copies them onto each parameter's own device.
    """
    for a in agents:
        state = torch.load(f'{p["path"]}{a}_params.pth', map_location='cpu')
        policy.policies[a].load_state_dict(state)
    his = list(np.load(f'{p["path"]}training_rewards.npy'))
    return his

In [10]:
# Build the shared Rainbow policy, the vectorised envs and the collectors, then
# pre-fill the replay buffer with random (eps=1) transitions so the first
# gradient updates have data to sample from.
policy, agents, train_envs, test_envs, lr_scheduler = get_agents(p)
train_collector, test_collector = get_collectors(policy, train_envs, test_envs, p)
initialize_buffer(train_collector, agents, policy, p)



In [11]:
# Fresh run: start with an empty evaluation history.
# To resume instead, uncomment the next line: it restores the saved weights
# from p['path'] and returns the stored evaluation history.
#training_history = load(policy, agents,p)
training_history = []

In [12]:
# Main training loop (p['epochs'] iterations). Scores are printed every
# p['test_frequency'] iterations; checkpoints and the learning curve are
# written to p['path'] every p['save_frequency'] iterations.
train(policy, train_collector, test_collector, agents, p, lr_scheduler, training_history = training_history)

  0%|          | 0/10000000 [00:00<?, ?it/s]



0.21396396396396397




0.5458333333333333
0.7407407407407407
1.0170068027210883
0.8825396825396825
1.2673267326732673
1.2452229299363058
1.2453416149068324
0.9369085173501577
1.2535612535612535
1.160458452722063
1.1913439635535308
1.1717171717171717
1.2
1.2222222222222223
1.2729591836734695
1.4436860068259385
1.0769230769230769
1.255813953488372
1.4625407166123778
1.3647798742138364
1.5906040268456376
1.5394321766561514
1.763157894736842
1.7922077922077921
1.8417721518987342
1.7258566978193146
1.7350993377483444
1.950657894736842
1.7130681818181819
1.8123324396782843
1.957516339869281
1.891025641025641
2.1024844720496896
1.9516616314199395
2.2315112540192925
2.5728155339805827
2.084084084084084
2.11
2.3854748603351954
2.2036474164133737
2.6018808777429467
2.5047619047619047
2.2507552870090635
2.3084415584415585
2.3669467787114846
2.2605042016806722
2.1171171171171173
2.211038961038961
2.462025316455696
2.6807909604519775
2.2493074792243766
2.4684931506849317
2.710691823899371
2.453731343283582
2.832826747720

4.875
4.833846153846154
4.809968847352025
4.8246153846153845
4.798136645962733
4.853211009174312
4.872670807453416
4.806153846153846
4.735202492211838
4.790123456790123
4.8580246913580245
4.773291925465839
4.758513931888545
4.700934579439252
4.806153846153846
4.827160493827161
4.830246913580247
4.7894736842105265
4.814814814814815
4.876160990712075
4.814241486068111
4.711180124223603
4.763803680981595
4.766355140186916
4.782208588957055
4.793846153846154
4.773291925465839
4.805555555555555
4.838006230529595
4.796296296296297
4.71076923076923
4.782608695652174
4.787037037037037
4.777089783281734
4.66358024691358
4.809968847352025
4.739938080495356
4.733746130030959
4.790123456790123
4.756172839506172
4.745341614906832
4.768518518518518
4.795031055900621
4.833846153846154
4.770186335403727
4.814241486068111
4.720496894409938
4.698757763975156
4.79375
4.846625766871166
4.843076923076923
4.804347826086956
4.7407407407407405
4.8478260869565215
4.808049535603715
4.875776397515528
4.919753086

KeyboardInterrupt: 