In [1]:
import argparse
import os
from copy import deepcopy
from typing import Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
from pettingzoo.classic import hanabi_v4
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import (
    BasePolicy,
    DQNPolicy,
    MultiAgentPolicyManager,
    RandomPolicy,
)
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [None]:
# Hyperparameters for the Hanabi DQN run; read by the setup and training cells below.
p = {
    # Network / optimisation
    'hidden_layers': [256, 256],
    'gamma': 0.99,               # discount factor passed to DQNPolicy
    'lr': 1e-4,                  # Adam learning rate
    'target_update_freq': 200,   # passed to DQNPolicy as target_update_freq
    'estimation_steps': 1,       # passed to DQNPolicy positionally after gamma
    # Environments
    'num_train': 50,             # number of training envs in the DummyVectorEnv
    'num_test': 50,              # number of test envs in the DummyVectorEnv
    # Replay buffer / collection
    'buffer_size': 50000,
    'batch_size': 32,
    'steps_per_collect': 10000,  # env steps collected per epoch
    'updates_per_train': 500,    # policy.update calls per epoch
    'test_steps': 50000,         # env steps per evaluation
    'eps': 0.1,                  # exploration rate (set_eps) used while collecting training data
    # Schedule
    'epochs': 1000,
    'test_frequency': 5,         # evaluate every N epochs
    'save_frequency': 25,        # checkpoint every N epochs
}
# Output directory for checkpoints, reward history and the training curve.
path = 'saved_data/training_group_3/'

In [2]:
def get_env(render_mode=None):
    """Create a single tianshou-wrapped PettingZoo Hanabi environment.

    The game is configured with 2 players, 2 colors, 5 ranks, 4-card hands,
    5 information tokens, 2 life tokens and observation_type=1.

    Args:
        render_mode: Currently ignored; it is not forwarded to ``hanabi_v4.env``.

    Returns:
        A ``PettingZooEnv`` wrapping the configured Hanabi environment.
    """
    hanabi_kwargs = dict(
        colors=2,
        ranks=5,
        players=2,
        hand_size=4,
        max_information_tokens=5,
        max_life_tokens=2,
        observation_type=1,
    )
    raw_env = hanabi_v4.env(**hanabi_kwargs)
    return PettingZooEnv(raw_env)

In [3]:
def get_agents(
    lr = 1e-4,
    hidden_layers = (256, 256),
    gamma = 0.99,
    target_update_freq = 200,
    estimation_steps = 1,
    num_train = 50,
    num_test = 50,
    share_parameters = True,
):
    """Build the multi-agent DQN policy and the vectorised train/test envs.

    Args:
        lr: Adam learning rate for each Q-network.
        hidden_layers: Hidden layer sizes passed to ``Net`` as ``hidden_sizes``.
        gamma: Discount factor passed to ``DQNPolicy``.
        target_update_freq: Passed to ``DQNPolicy`` as ``target_update_freq``.
        estimation_steps: Passed positionally to ``DQNPolicy`` right after ``gamma``.
        num_train: Number of environments in the training ``DummyVectorEnv``.
        num_test: Number of environments in the test ``DummyVectorEnv``.
        share_parameters: If True (default; this notebook's original behaviour),
            every player seat is driven by the same ``DQNPolicy`` object, so one
            network and one optimiser are shared. If False, each seat gets its
            own independently initialised network and optimiser.

    Returns:
        ``(policy, agents, train_envs, test_envs)``: the
        ``MultiAgentPolicyManager``, the agent names from ``env.agents``, and
        the two ``DummyVectorEnv`` instances.
    """
    env = get_env()
    # When the observation space is a Dict, only its 'observation' entry sizes
    # the network input.
    observation_space = (
        env.observation_space['observation']
        if isinstance(env.observation_space, gym.spaces.Dict)
        else env.observation_space
    )

    state_shape = observation_space.shape or observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def _make_dqn_agent():
        """Create one DQNPolicy with its own Net and Adam optimiser."""
        net = Net(
            state_shape,
            action_shape,
            hidden_sizes=hidden_layers,
            device=device,
        ).to(device)
        optim = torch.optim.Adam(net.parameters(), lr=lr)
        return DQNPolicy(
            net,
            optim,
            gamma,
            estimation_steps,
            target_update_freq=target_update_freq,
        )

    # One policy entry per seat; the manager pairs them with env.agents in order.
    n_seats = len(env.agents)
    if share_parameters:
        # Only build the network that is actually handed to the manager, rather
        # than allocating a second network/optimiser that is never used.
        shared_agent = _make_dqn_agent()
        seat_policies = [shared_agent] * n_seats
    else:
        seat_policies = [_make_dqn_agent() for _ in range(n_seats)]

    policy = MultiAgentPolicyManager(seat_policies, env)
    agents = env.agents

    train_envs = DummyVectorEnv([get_env for _ in range(num_train)])
    test_envs = DummyVectorEnv([get_env for _ in range(num_test)])

    return policy, agents, train_envs, test_envs

In [4]:
def get_collectors(
    policy,
    train_envs,
    test_envs,
    buffer_size=50000
):
    """Create the training and evaluation collectors for ``policy``.

    The training collector writes into a ``PrioritizedVectorReplayBuffer`` of
    total size ``buffer_size``, split over ``len(train_envs)`` sub-buffers
    (alpha=0.6, beta=0.4, weight_norm=True). The test collector is given no
    explicit buffer. Both are created with ``exploration_noise=True``, so the
    policy's exploration noise (epsilon for DQN) is applied while collecting.
    Callers must set eps to 0 themselves for greedy evaluation.

    Returns:
        ``(train_collector, test_collector)``.
    """
    replay_buffer = PrioritizedVectorReplayBuffer(
        buffer_size,
        len(train_envs),
        alpha=0.6,
        beta=0.4,
        weight_norm=True,
    )
    train_collector = Collector(policy, train_envs, replay_buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    return train_collector, test_collector

In [5]:
def initialize_buffer(
    train_collector,
    buffer_size,
    agents,
    policy,
    eps = 0.1
):
    """Pre-fill the replay buffer with fully random (eps=1) transitions.

    Args:
        train_collector: Collector to fill; ``collect(n_step=buffer_size)`` is
            called on it once.
        buffer_size: Number of environment steps to collect.
        agents: Agent names used to index ``policy.policies``.
        policy: Multi-agent policy manager whose sub-policies expose ``set_eps``.
        eps: Exploration rate set on every sub-policy afterwards.
    """
    for name in agents:
        policy.policies[name].set_eps(1)
    try:
        train_collector.collect(n_step=buffer_size)
    finally:
        # Restore the training epsilon even if collection is interrupted
        # (e.g. KeyboardInterrupt in the notebook). Otherwise later cells would
        # silently keep acting fully at random.
        for name in agents:
            policy.policies[name].set_eps(eps)

In [6]:
def save_policy(policy, agents, path='saved_data/training_group_3/'):
    """Save each agent's sub-policy ``state_dict`` to ``<path>/<agent>_params.pth``.

    Args:
        policy: Multi-agent policy manager; ``policy.policies[a]`` must provide
            ``state_dict()``.
        agents: Agent names to save.
        path: Output directory, created if missing. The default is the
            directory this notebook has always written to.

    If several agents share one policy object, each file holds the same weights.
    """
    # torch.save does not create parent directories, so a fresh checkout would fail.
    os.makedirs(path, exist_ok=True)
    for a in agents:
        torch.save(policy.policies[a].state_dict(), os.path.join(path, f'{a}_params.pth'))

def save_history(history, path='saved_data/training_group_3/'):
    """Save the evaluation reward history to ``<path>/training_rewards.npy``.

    Args:
        history: Sequence of evaluation rewards; stored via ``np.array``.
        path: Output directory, created if missing. The default is the
            directory this notebook has always written to.
    """
    # np.save does not create parent directories, so a fresh checkout would fail.
    os.makedirs(path, exist_ok=True)
    np.save(os.path.join(path, 'training_rewards.npy'), np.array(history))
    
def change_lr(optimizer, new_lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``new_lr``.

    The optimizer's ``param_groups`` are mutated in place, e.g.
    ``change_lr(optim, 1e-5)`` to drop to 1e-5.
    """
    for param_group in optimizer.param_groups:
        param_group.update(lr=new_lr)

In [7]:
def train(
    policy,
    train_collector,
    test_collector,
    agents,
    epochs=1000,
    collection_steps_per_epoch=10000,
    updates_per_epoch=500,
    test_frequency=5,
    test_steps=50000,
    save_frequency = 50,
    batch_size = 32,
    eps = 0.1,
    training_history = None
):
    """Run the collect / evaluate / checkpoint / update loop.

    Each epoch ``i``:
      1. collects ``collection_steps_per_epoch`` env steps with the train collector;
      2. if ``i % test_frequency == 0``, evaluates with eps=0 for ``test_steps``
         env steps and appends the mean episode reward to ``training_history``;
      3. if ``i % save_frequency == 0``, saves the policy and the history;
      4. calls ``policy.update(batch_size, buffer)`` ``updates_per_epoch`` times.
    After the loop, the final policy and history are saved and the curve is plotted.

    Args:
        policy: Multi-agent policy manager; ``policy.policies[a]`` expose ``set_eps``.
        train_collector: Collector whose ``buffer`` feeds the updates.
        test_collector: Collector used for evaluation.
        agents: Agent names used to index ``policy.policies``.
        eps: Exploration rate restored after each evaluation.
        training_history: Existing list to append to (e.g. to resume a run);
            a new list is created when None.

    Returns:
        The ``training_history`` list.
    """
    # A mutable default ([]) would be shared between calls, so a second call
    # would silently continue the previous run's reward history.
    if training_history is None:
        training_history = []

    def set_eps_all(value):
        for a in agents:
            policy.policies[a].set_eps(value)

    for i in tqdm(range(epochs)):

        # Collection step
        train_collector.collect(n_step=collection_steps_per_epoch)

        # Test step: greedy actions, starting from freshly reset test envs so
        # episodes begun during an earlier evaluation (with older weights) are
        # not counted here.
        if i % test_frequency == 0:
            set_eps_all(0)
            try:
                test_collector.reset()
                result = test_collector.collect(n_step=test_steps)
            finally:
                # Never leave the policies greedy if evaluation is interrupted.
                set_eps_all(eps)
            rews = result['rews']
            # An empty result (no finished episode) yields nan, as before, but
            # without numpy's "mean of empty slice" warning.
            mean_reward = float(np.mean(rews)) if len(rews) else float('nan')
            tqdm.write(str(mean_reward))
            training_history.append(mean_reward)

        if i % save_frequency == 0:
            save_policy(policy, agents)
            save_history(training_history)

        # Update step (one epoch)
        for _ in range(updates_per_epoch):
            policy.update(batch_size, train_collector.buffer)

    # Persist the final state too; otherwise the last checkpoint can lag the
    # trained policy by up to save_frequency - 1 epochs.
    save_policy(policy, agents)
    save_history(training_history)
    plot_and_save(training_history, test_frequency)
    return training_history

In [8]:
def plot_and_save(training_history, test_frequency, save = True, path = 'saved_data/training_group_3/'):
    """Plot evaluation reward against epoch and optionally save the figure.

    ``training_history[k]`` is plotted at epoch ``k * test_frequency``, which
    matches a loop that evaluates whenever ``epoch % test_frequency == 0``.

    Args:
        training_history: Sequence of evaluation rewards.
        test_frequency: Epochs between consecutive evaluations.
        save: If True, write ``<path>/training_curve.png``.
        path: Output directory, created if missing. The default is the
            directory this notebook has always written to.
    """
    # Multiply out-of-place so a float test_frequency does not hit numpy's
    # int-array in-place casting error.
    epochs = np.arange(len(training_history)) * test_frequency
    # Draw on a dedicated figure; the pyplot state machine would otherwise
    # overlay this curve onto whatever figure happens to be current.
    fig, ax = plt.subplots()
    ax.plot(epochs, training_history)
    ax.set_title('Combined Average Score (2 player, 2 colors, 5 ranks)')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Average Score (max 10)')
    if save:
        os.makedirs(path, exist_ok=True)
        fig.savefig(os.path.join(path, 'training_curve.png'))

In [9]:
# Build the policy and environments from the config cell, wire up the
# collectors, and pre-fill the replay buffer with eps=1 transitions.
agent_kwargs = {
    key: p[key]
    for key in ('lr', 'hidden_layers', 'gamma', 'target_update_freq',
                'estimation_steps', 'num_train', 'num_test')
}
policy, agents, train_envs, test_envs = get_agents(**agent_kwargs)
train_collector, test_collector = get_collectors(
    policy, train_envs, test_envs, buffer_size=p['buffer_size']
)
initialize_buffer(train_collector, p['buffer_size'], agents, policy, eps=p['eps'])



In [None]:
# Run the full training loop with the hyperparameters from the config cell.
# The config key is 'save_frequency'; the misspelt key raised a KeyError.
training_history = train(
    policy, train_collector, test_collector, agents,
    epochs=p['epochs'],
    collection_steps_per_epoch=p['steps_per_collect'],
    updates_per_epoch=p['updates_per_train'],
    test_frequency=p['test_frequency'],
    test_steps=p['test_steps'],
    save_frequency=p['save_frequency'],
    batch_size=p['batch_size'],
    eps=p['eps'],
)

  0%|          | 0/1000 [00:00<?, ?it/s]

0.0




1.6433915211970074
1.6549057515708072
2.225836431226766
2.5367577756833177
2.7200563644903712
2.9105188005711566
3.007718282682103
2.8875
2.928603302097278
3.145325653522375
3.088510638297872
2.91169671752196
3.0460193281178096
3.03096274794388
3.4030261348005504
3.4640858208955225
3.5244239631336405
3.8306878306878307
3.666076957098629
4.049978270317253
3.9185119574844998
4.019715726730857
4.066666666666666
4.172336615935541
4.328077091546211
4.170105263157895
4.40356698465367
4.399200710479573
4.526777875329236
3.8417782026768643
