In [1]:
import argparse
import os
from copy import deepcopy
from typing import Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
from pettingzoo.classic import hanabi_v4
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import (
    BasePolicy,
    RainbowPolicy,
    MultiAgentPolicyManager,
    A2CPolicy
)
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net, ActorCritic
from tianshou.utils.net.discrete import Actor, Critic
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [2]:
# copy of rainbow used in deepmind paper
p = {
    'hidden_layers': [256,256],
    'gamma': 0.99,
    'lr': 1e-4,
    'target_update_freq': 500,
    'estimation_steps': 1,
    'num_train':32,
    'num_test':32,
    'buffer_size':50000,
    'vmax':25,
    'vmin':-25,
    'noisy_std':0.1,
    'atom_size':51,
    'minimum_replay_history':512,
    'batch_size':32,
    'steps_per_collect': 10016,
    'updates_per_train': 1563,
    'test_steps': 20000,
    'epochs':5000,
    'eps_decay_period': 100,
    'test_frequency': 3,
    'test_eps': 0,
    'save_frequency': 25,
    'eps_final':0.01,
    'adam_eps': 3.125e-5,
    'path': 'results/a2c/',
    'lr_scheduler_factor': 0.1,
    'lr_scheduler_patience': 20
}


In [3]:
def get_env(render_mode=None):
    return PettingZooEnv(hanabi_v4.env(colors=2, ranks=5, players=2, hand_size=3, max_information_tokens=3,
max_life_tokens=1, observation_type=1))

In [11]:
def get_agents(p):
    
    # Return Policy, Agents, Envs
    env = get_env()
    observation_space = env.observation_space['observation'] if isinstance(
    env.observation_space, gym.spaces.Dict
    ) else env.observation_space

    state_shape = observation_space.shape or observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = Net(
            state_shape,
            action_shape,
            hidden_sizes=p['hidden_layers'],
            device = device,
    ).to(device)
    actor = Actor(net, action_shape, device = device).to(device)
    critic = Critic(net, device = device).to(device)
    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(),
                             lr= p['lr'], eps=p['adam_eps'])
    
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode = 'max', factor = p['lr_scheduler_factor'],
                                                              patience = p['lr_scheduler_patience'])

    dist = torch.distributions.Categorical
    agent = A2CPolicy(
            actor,
            critic,
            optim,
            dist,
            discount_factor = p['gamma'])

    agents = [agent, agent]
    policy = MultiAgentPolicyManager(agents, env)
    agents = env.agents

    train_envs = DummyVectorEnv([get_env for _ in range(p['num_train'])])
    test_envs = DummyVectorEnv([get_env for _ in range(p['num_test'])])
    
    return policy, agents, train_envs, test_envs, lr_scheduler

In [12]:
def get_collectors(
    policy,
    train_envs,
    test_envs,
    p
):
    
    # Get collectors
    train_collector = Collector(
    policy,
    train_envs,
    VectorReplayBuffer(p['buffer_size'], len(train_envs)))
    
    test_collector = Collector(policy, test_envs)
    
    return train_collector, test_collector

In [18]:
def initialize_buffer(
    train_collector,
    agents,
    policy,
    p
):
    #for a in agents:
        #policy.policies[a].set_eps(1)
    train_collector.collect(n_step = p['minimum_replay_history'])

In [19]:
def save_policy(policy, agents, p):
    for a in agents:
        torch.save(policy.policies[a].state_dict(), f'{p["path"]}{a}_params.pth')

def save_history(history, p):
    np.save(f'{p["path"]}training_rewards.npy', np.array(history))
    
def change_lr(optimizer, new_lr):
    # Run this to change the learning rate to 1e-5:
    for g in optimizer.param_groups:
        g['lr'] = new_lr

In [20]:
def get_eps(iteration, p):
    if iteration > p['eps_decay_period']:
        return p['eps_final']
    else:
        gradient = (1 - p['eps_final'])/p['eps_decay_period']
        return 1 - gradient*iteration
        
def set_eps(policy, agents, new_eps):
    for a in agents:
        policy.policies[a].set_eps(new_eps)
        
def train(
    policy,
    train_collector,
    test_collector,
    agents,
    p,
    lr_scheduler,
    training_history = []
):

    for i in tqdm(range(p['epochs'])):
        
        # Collection step
        result = train_collector.collect(n_step = p['steps_per_collect'])
        
        # Test Step
        if i%p['test_frequency'] == 0:
            result = test_collector.collect(n_step = p['test_steps'])
            mean_reward = result['rews'].mean()
            tqdm.write(str(mean_reward))
            training_history.append(mean_reward)
            lr_scheduler.step(mean_reward)
    
        if i%p['save_frequency'] == 0:
            save_policy(policy, agents,p)
            save_history(training_history,p)
            plot_and_save(training_history, p['test_frequency'],p, show = False)
    
        # Update step (one epoch)
        for _ in range(p['updates_per_train']): 
            losses = policy.update(p['batch_size'], train_collector.buffer)
    
    plot_and_save(training_history, test_frequency)
        

In [21]:
def plot_and_save(training_history, test_frequency, p, save = True, show = True):
    x = np.arange(len(training_history))
    x *= test_frequency
    plt.plot(x, training_history)
    plt.title('Combined Average Score (A2C, 2 Color game, Handsize = 3)')
    plt.xlabel('Epoch')
    plt.ylabel('Average Score (max 10)')
    if save: plt.savefig(f'{p["path"]}training_curve.png')
    if show:
        plt.show()
    else:
        plt.close()
        
def load(policy, agents, p):
    for a in agents:
        policy.policies[a].load_state_dict(torch.load(f'{p["path"]}{a}_params.pth'))
    his = list(np.load(f'{p["path"]}training_rewards.npy'))
    return his

In [22]:
policy, agents, train_envs, test_envs, lr_scheduler = get_agents(p)
train_collector, test_collector = get_collectors(policy, train_envs, test_envs, p)
initialize_buffer(train_collector, agents, policy, p)

AttributeError: 'str' object has no attribute 'ndim'

In [23]:
#training_history = load(policy, agents,p)
training_history = []

In [24]:
train(policy, train_collector, test_collector, agents, p, lr_scheduler, training_history = training_history)

  0%|          | 0/5000 [00:00<?, ?it/s]

AttributeError: 'str' object has no attribute 'ndim'