In [68]:
import argparse
import os
from copy import deepcopy
from typing import Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
from torch.distributions import Independent, Normal
from pettingzoo.classic import hanabi_v4
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import (
    PPOPolicy,
    MultiAgentPolicyManager)

from tianshou.trainer import OnpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous  import ActorProb, Critic
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [69]:
# Hyper-parameters and run settings read by the functions defined in the cells below.
p = dict(
    # Hidden layer widths of the actor and critic networks (get_agents).
    actor_hidden_layers=[256, 256],
    critic_hidden_layers=[256, 256],
    # Discount factor and Adam learning rate (get_agents).
    gamma=1,
    lr=1e-5,
    # Not read by any function visible in this notebook.
    target_update_freq=200,
    estimation_steps=1,
    # Number of parallel training / test environments (get_agents).
    num_train=50,
    num_test=50,
    # Replay buffer capacity (get_collectors) and the size handed to policy.update (train).
    buffer_size=50000,
    batch_size=32,
    # Per-epoch collection, update and evaluation budgets (train).
    steps_per_collect=10000,
    updates_per_train=500,
    test_steps=10000,
    # Loop length and how often (in epochs) to evaluate / checkpoint (train).
    epochs=1000,
    test_frequency=5,
    save_frequency=25,
    # Value passed to set_eps after each evaluation (train).
    eps=0.05,
)
# Directory that receives checkpoints, the reward history and the training curve.
path = 'saved_data/AC_ppo'

In [86]:
def get_env(render_mode=None):
    """Create one Hanabi environment wrapped in Tianshou's PettingZooEnv.

    The game is configured with 5 colours, 5 ranks, 2 players, a hand size of
    4, 5 information tokens, 2 life tokens and observation_type 1.

    Args:
        render_mode: Optional render mode. It is forwarded to
            ``hanabi_v4.env`` only when it is not ``None``, so the default
            call builds the environment with exactly the arguments used before.

    Returns:
        The ``PettingZooEnv`` wrapping the configured Hanabi environment.
    """
    game_config = dict(
        colors=5,
        ranks=5,
        players=2,
        hand_size=4,
        max_information_tokens=5,
        max_life_tokens=2,
        observation_type=1,
    )
    # Previously the argument was accepted but silently dropped.
    if render_mode is not None:
        game_config['render_mode'] = render_mode
    return PettingZooEnv(hanabi_v4.env(**game_config))

In [87]:
def get_agents(p):
    """Build the shared PPO policy, the agent names and the vectorised envs.

    One PPOPolicy instance is created and assigned to every agent listed in
    ``env.agents``, so all seats share the same actor, critic and optimiser.

    Args:
        p: Settings dict. Reads ``num_train``, ``num_test``,
            ``actor_hidden_layers``, ``critic_hidden_layers``, ``lr`` and
            ``gamma``.

    Returns:
        Tuple ``(policy, agents, train_envs, test_envs)`` where ``policy`` is
        the MultiAgentPolicyManager, ``agents`` is ``env.agents`` and the two
        env objects are DummyVectorEnv instances built from ``get_env``.
    """
    # Local import: only the continuous heads are imported at the top of the
    # notebook, but the action count below is read from a discrete space.
    from tianshou.utils.net.discrete import Actor as DiscreteActor
    from tianshou.utils.net.discrete import Critic as DiscreteCritic

    # A throw-away env instance, used only to read the spaces and agent names.
    env = get_env()
    if isinstance(env.observation_space, gym.spaces.Dict):
        observation_space = env.observation_space['observation']
    else:
        observation_space = env.observation_space
    state_shape = observation_space.shape or observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    train_envs = DummyVectorEnv([get_env for _ in range(p['num_train'])])
    test_envs = DummyVectorEnv([get_env for _ in range(p['num_test'])])

    # Separate feature extractors for the critic and the actor.
    net_critic = Net(state_shape, hidden_sizes=p['critic_hidden_layers'], device=device).to(device)
    net_actor = Net(state_shape, hidden_sizes=p['actor_hidden_layers'], device=device).to(device)

    # Discrete heads: the actor outputs one probability per action instead of
    # the mean/scale of a Gaussian, matching the Categorical distribution below.
    actor = DiscreteActor(net_actor, action_shape, device=device).to(device)
    critic = DiscreteCritic(net_critic, device=device).to(device)
    actor_critic = ActorCritic(actor, critic)

    # Orthogonal initialisation of every linear layer, biases set to zero.
    for module in actor_critic.modules():
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.orthogonal_(module.weight)
            torch.nn.init.zeros_(module.bias)

    optim = torch.optim.Adam(actor_critic.parameters(), lr=p['lr'])

    # NOTE(review): the observation space may be a Dict (see the check above);
    # presumably the per-step observation then carries more than the feature
    # vector (e.g. an action mask). This function adds no unwrapping or
    # masking around the networks -- confirm that the policy receives plain
    # observation arrays before relying on training.
    shared_policy = PPOPolicy(
        actor,
        critic,
        optim,
        torch.distributions.Categorical,
        discount_factor=p['gamma'],
        action_space=env.action_space,
    )

    # Same policy object for every seat (parameter sharing).
    policy = MultiAgentPolicyManager([shared_policy for _ in env.agents], env)

    return policy, env.agents, train_envs, test_envs

In [88]:
def get_collectors(policy, train_envs, test_envs, p):
    """Create the training and evaluation collectors for ``policy``.

    The training collector stores transitions in a VectorReplayBuffer of total
    size ``p['buffer_size']`` with one sub-buffer per training environment;
    the evaluation collector is built without an explicit buffer.

    Returns:
        Tuple ``(train_collector, test_collector)``.
    """
    replay_buffer = VectorReplayBuffer(p['buffer_size'], len(train_envs))
    return (
        Collector(policy, train_envs, replay_buffer),
        Collector(policy, test_envs),
    )

In [89]:
def save_policy(policy, agents, path):
    """Write each agent's policy weights to ``<path>/<agent>_params.pth``.

    Args:
        policy: Object exposing a ``policies`` mapping from agent name to a
            module with ``state_dict()``.
        agents: Iterable of agent names to save.
        path: Output directory; created (including parents) if missing.
    """
    # torch.save does not create missing directories, so do it up front.
    os.makedirs(path, exist_ok=True)
    for a in agents:
        torch.save(policy.policies[a].state_dict(), os.path.join(path, f'{a}_params.pth'))
        
def save_history(history, path):
    """Store the reward history as ``<path>/training_rewards.npy``.

    Args:
        history: Sequence of recorded rewards; converted with ``np.array``.
        path: Output directory; created (including parents) if missing.
    """
    # np.save fails on a missing directory, so make sure it exists first.
    os.makedirs(path, exist_ok=True)
    np.save(os.path.join(path, 'training_rewards.npy'), np.array(history))
    
def change_lr(optimizer, new_lr):
    """Set the learning rate of every parameter group of ``optimizer``.

    Args:
        optimizer: Object exposing ``param_groups``, a sequence of dict-like
            groups with an ``'lr'`` entry.
        new_lr: Learning rate written into each group.
    """
    for group in optimizer.param_groups:
        group['lr'] = new_lr
        
def plot_and_save(training_history, test_frequency, path, save = True):
    """Plot the evaluation scores against the epoch and optionally save the figure.

    Args:
        training_history: Sequence of scores, one per evaluation.
        test_frequency: Number of epochs between evaluations; used to scale
            the x axis.
        path: Directory that receives ``training_curve.png`` when ``save`` is
            true; created if missing.
        save: Whether to write the figure to disk.
    """
    # A new array instead of an in-place multiply, so a non-integer
    # test_frequency does not fail on the integer arange.
    epochs = np.arange(len(training_history)) * test_frequency

    # Reuse one named figure and clear it on every call: the training loop
    # calls this repeatedly, and drawing on the implicit current figure would
    # stack a new line on top of the old ones each time.
    fig = plt.figure('training_curve', clear=True)
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(epochs, training_history, linewidth=1)
    # NOTE(review): the labels mention a 2-colour game with a maximum of 10,
    # while get_env in this notebook builds 5 colours and 5 ranks -- confirm.
    ax.set_title('Combined Average Score (PPO Actor Critic, 2 Color game)')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Average Score (max 10)')
    if save:
        os.makedirs(path, exist_ok=True)
        fig.savefig(os.path.join(path, 'training_curve.png'))
        
def load(policy, path, agents):
    """Restore agent weights and the reward history written by the save helpers.

    Args:
        policy: Object exposing a ``policies`` mapping from agent name to a
            module with ``load_state_dict``.
        path: Directory containing ``<agent>_params.pth`` files and
            ``training_rewards.npy``; works with or without a trailing
            separator.
        agents: Iterable of agent names to restore.

    Returns:
        The stored reward history as a list.
    """
    for a in agents:
        # os.path.join supplies the separator that plain string concatenation
        # dropped, so the file name matches the one used when saving.
        # map_location='cpu' lets checkpoints written on a GPU machine load
        # anywhere; load_state_dict copies the values into the existing
        # parameters.
        weights = torch.load(os.path.join(path, f'{a}_params.pth'), map_location='cpu')
        policy.policies[a].load_state_dict(weights)
    his = list(np.load(os.path.join(path, 'training_rewards.npy')))
    return his

In [90]:
def train(policy, train_collector, test_collector, p, agents, path, training_history=None):
    """Run the collect / evaluate / checkpoint / update loop for ``p['epochs']`` epochs.

    Each epoch collects ``p['steps_per_collect']`` training steps. Every
    ``p['test_frequency']`` epochs the policy is evaluated for
    ``p['test_steps']`` steps and the mean reward is appended to the history.
    Every ``p['save_frequency']`` epochs the weights, history and curve are
    written to ``path``. The epoch ends with ``p['updates_per_train']`` calls
    to ``policy.update``.

    Args:
        policy: Multi-agent policy exposing ``policies`` and ``update``.
        train_collector: Collector used for training data; its ``buffer`` is
            passed to ``policy.update``.
        test_collector: Collector used for evaluation.
        p: Settings dict (keys listed above, plus ``batch_size`` and ``eps``).
        agents: Agent names used to index ``policy.policies``.
        path: Output directory for checkpoints and plots.
        training_history: Optional list of earlier scores to extend in place;
            a fresh list is created when omitted.

    Returns:
        The list of evaluation scores (the object passed in, if any).
    """
    # A list literal as the default would be shared by every call of train().
    if training_history is None:
        training_history = []

    def _set_eps(value):
        # Only policies that define set_eps take an exploration rate; others
        # are skipped instead of raising AttributeError.
        for a in agents:
            agent_policy = policy.policies[a]
            if hasattr(agent_policy, 'set_eps'):
                agent_policy.set_eps(value)

    for i in tqdm(range(p['epochs'])):

        # Collection step
        train_collector.collect(n_step=p['steps_per_collect'])

        # Test step: exploration off while evaluating, restored afterwards.
        if i % p['test_frequency'] == 0:
            _set_eps(0)
            result = test_collector.collect(n_step=p['test_steps'])
            mean_reward = result['rews'].mean()
            tqdm.write(str(mean_reward))
            training_history.append(mean_reward)
            _set_eps(p['eps'])

        # Checkpoint step
        if i % p['save_frequency'] == 0:
            save_policy(policy, agents, path)
            save_history(training_history, path)
            plot_and_save(training_history, p['test_frequency'], path)

        # Update step (one epoch)
        # NOTE(review): this passes only a sample size and the buffer. Confirm
        # against the installed Tianshou version that the policy's learn()
        # needs no further keyword arguments (e.g. batch_size / repeat) and
        # that the buffer should not be reset after an on-policy update.
        for _ in range(p['updates_per_train']):
            policy.update(p['batch_size'], train_collector.buffer)

    plot_and_save(training_history, p['test_frequency'], path)
    return training_history

In [91]:
# Build the policy, agent names and vectorised environments from the settings.
policy, agents, train_envs, test_envs = get_agents(p)

# Wrap them in collectors; the training collector owns the replay buffer.
train_collector, test_collector = get_collectors(
    policy, train_envs, test_envs, p
)

# Collect 100 episodes with the training collector before the training loop.
train_collector.collect(n_episode=100)

AttributeError: 'str' object has no attribute 'ndim'

In [83]:
# Run the full training loop with the objects created in the previous cell.
train(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    p=p,
    agents=agents,
    path=path,
)

  0%|          | 0/1000 [00:00<?, ?it/s]

AttributeError: 'str' object has no attribute 'ndim'