In [1]:
import argparse
import os
from copy import deepcopy
from typing import Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
from pettingzoo.classic import hanabi_v4
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import (
    BasePolicy,
    RainbowPolicy,
    MultiAgentPolicyManager,
    RandomPolicy,
)
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [2]:
def get_env(render_mode=None):
    return PettingZooEnv(hanabi_v4.env(colors=2, ranks=5, players=2, hand_size=4, max_information_tokens=5,
max_life_tokens=2, observation_type=1))

In [3]:
def get_agents(
    lr = 1e-4, 
    hidden_layers = [256, 256], 
    gamma = 0.99,
    target_update_freq = 200, 
    estimation_steps = 3, 
    num_train = 20, 
    num_test = 20
):
    
    # Return Policy, Agents, Envs
    env = get_env()
    observation_space = env.observation_space['observation'] if isinstance(
    env.observation_space, gym.spaces.Dict
    ) else env.observation_space

    state_shape = observation_space.shape or observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net1 = Net(
            state_shape,
            51,
            hidden_sizes=hidden_layers,
            device = device).to(device)

    net2 = Net(
        state_shape,
        51,
        hidden_sizes=hidden_layers,
        device = device).to(device)

    optim1 = torch.optim.Adam(net1.parameters(), lr= lr)
    optim2 = torch.optim.Adam(net2.parameters(), lr = lr)

    agent1 = RainbowPolicy(
            net1,
            optim1,
            gamma,
            num_atoms = 51,
            v_min = -10,
            v_max = 10,
            estimation_step = estimation_steps,
            target_update_freq=target_update_freq
        )

    agent2 = RainbowPolicy(
            net2,
            optim2,
            gamma,
            num_atoms = 51,
            v_min = -10,
            v_max = 10,
            estimation_step = estimation_steps,
            target_update_freq=target_update_freq
        )

    agents = [agent1, agent2]
    policy = MultiAgentPolicyManager(agents, env)
    agents = env.agents

    train_envs = DummyVectorEnv([get_env for _ in range(num_train)])
    test_envs = DummyVectorEnv([get_env for _ in range(num_test)])
    
    return policy, agents, train_envs, test_envs

In [4]:
def get_collectors(
    policy,
    train_envs,
    test_envs,
    buffer_size
):
    
    # Get collectors
    train_collector = Collector(
    policy,
    train_envs,
    PrioritizedVectorReplayBuffer(buffer_size, len(train_envs), alpha = 0.6, beta = 0.4),
    exploration_noise=True)
    
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    
    return train_collector, test_collector

In [5]:
def initialize_buffer(
    train_collector,
    buffer_size,
    agents,
    policy,
    eps = 0.1
):
    for a in agents:
        policy.policies[a].set_eps(1)
    train_collector.collect(n_step = buffer_size)
    for a in agents:
        policy.policies[a].set_eps(eps)

In [6]:
def save_policy(policy, agents):
    for a in agents:
        torch.save(policy.policies[a].state_dict(), f'saved_data/training_group_2/{a}_params.pth')

def save_history(history):
    np.save(f'saved_data/training_group_2/training_rewards.npy', np.array(history))
    
def change_lr(optimizer, new_lr):
    # Run this to change the learning rate to 1e-5:
    for g in optimizer.param_groups:
        g['lr'] = new_lr

In [11]:
def train(
    policy,
    train_collector,
    test_collector,
    agents,
    epochs=2000,
    collection_steps_per_epoch=10000,
    updates_per_epoch=5000,
    test_frequency=5,
    test_steps=50000,
    save_frequency = 50,
    batch_size = 32,
    eps = 0.1,
    training_history = []
):
    for i in tqdm(range(epochs)):
        
        # Collection step
        result = train_collector.collect(n_step = collection_steps_per_epoch)
        
        # Test Step
        if i%test_frequency == 0:
            for a in agents:
                policy.policies[a].set_eps(0)
            result = test_collector.collect(n_step = test_steps)
            mean_reward = result['rews'].mean()
            tqdm.write(str(mean_reward))
            training_history.append(mean_reward)
            for a in agents:
                policy.policies[a].set_eps(eps)
    
        if i%save_frequency == 0:
            save_policy(policy, agents)
            save_progress(training_history)
    
        # Update step (one epoch)
        for _ in range(updates_per_epoch): losses = policy.update(batch_size, train_collector.buffer)
    
    plot_and_save(training_history, test_frequency)

In [12]:
def plot_and_save(training_history, test_frequency, save = True):
    x = np.arange(len(training_history))
    x *= test_frequency
    plt.plot(x, training_history)
    plt.title('Combined Average Score (2 player, 2 colors, 5 ranks)')
    plt.xlabel('Epoch')
    plt.ylabel('Average Score (max 10)')
    if save: plt.savefig(f'saved_data/training_group_2/training_curve.png')

In [9]:
policy, agents, train_envs, test_envs = get_agents(lr = 1e-5)
train_collector, test_collector = get_collectors(policy, train_envs, test_envs, buffer_size = 100000)

In [10]:
train(policy, train_collector, test_collector, agents)

  0%|          | 0/2000 [00:00<?, ?it/s]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!