In [2]:
import tianshou as ts
import numpy as np
from tianshou.policy import (
    BasePolicy,
    MultiAgentPolicyManager
)
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from pettingzoo.classic import hanabi_v4
import gymnasium as gym
from tqdm import tqdm

In [4]:
p = {
    'gamma': 0.99,
    'lr': 1e-4,
    'num_train':32,
    'num_test':32,
    'test_steps': 20000,
    'epochs':10**7,
    'test_frequency':20000,
    'eps':0.01
}
path = 'results/test_1/'

In [23]:
# Need to define a new policy class that is compatible with tabular q learning
""" 
Function to define: 
1) forward (takes in a Batch, output a Batch)
2) learn (takes in a Batch, outputs a Dict)
"""


class Tabular_QL(BasePolicy):
    
    def __init__(self, discount_factor = 1, learning_rate = 1e-5, **kwargs):
        super().__init__(**kwargs)
        self.discount_factor = discount_factor
        self.learning_rate = learning_rate
        self.Q_table = {}
        self.eps = 0
        
    def set_eps(self, eps):
        self.eps = eps
        
    def train(self, mode: bool):
        self.training = mode
        
    def forward(self):
        pass
    
    def process_fn(self, batch, buffer, indicies):
        pass
    
    def learn(self, batch):
        # Takes in a batch and updates the Q table.
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        print(q)

In [24]:
def get_env(render_mode=None):
    return PettingZooEnv(hanabi_v4.env(colors=2, ranks=5, players=2, hand_size=4, max_information_tokens=5,
max_life_tokens=2, observation_type=1))

In [25]:
def get_agents(
    lr = 1e-4, 
    hidden_layers = [256, 256], 
    gamma = 0.99,
    target_update_freq = 200, 
    estimation_steps = 1, 
    num_train = 50, 
    num_test = 50,
    atom_size = 51,
    noisy_std = 0.1):
    
    env = get_env()
    observation_space = env.observation_space['observation'] if isinstance(
    env.observation_space, gym.spaces.Dict
    ) else env.observation_space
    state_shape = observation_space.shape or observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    agent = Tabular_QL()
    agents = [agent, agent]
    policy = MultiAgentPolicyManager(agents, env)
    agents = env.agents
    train_envs = DummyVectorEnv([get_env for _ in range(num_train)])
    test_envs = DummyVectorEnv([get_env for _ in range(num_test)])
    return policy, agents, train_envs, test_envs

In [26]:
def get_collectors(
    policy,
    train_envs,
    test_envs,
    buffer_size
):
    
    # Get collectors
    train_collector = Collector(
    policy,
    train_envs,
    VectorReplayBuffer(buffer_size, len(train_envs), alpha = 0.6, beta = 0.4, weight_norm=True),
    exploration_noise=True)
    
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    
    return train_collector, test_collector

In [27]:
def initialize_buffer(
    train_collector,
    buffer_size,
    agents,
    policy,
    eps = 0.1
):
    for a in agents:
        policy.policies[a].set_eps(1)
    train_collector.collect(n_step = buffer_size)
    for a in agents:
        policy.policies[a].set_eps(eps)

In [28]:
def save_policy(policy, agents):
    for a in agents:
        torch.save(policy.policies[a].state_dict(), f'saved_data/training_group_2/{a}_params.pth')

def save_history(history):
    np.save(f'saved_data/training_group_2/training_rewards.npy', np.array(history))
    
def change_lr(optimizer, new_lr):
    # Run this to change the learning rate to 1e-5:
    for g in optimizer.param_groups:
        g['lr'] = new_lr

In [29]:
def train(
    policy,
    train_collector,
    test_collector,
    agents,
    epochs=1000,
    collection_steps_per_epoch=10000,
    updates_per_epoch=500,
    test_frequency=5,
    test_steps=50000,
    save_frequency = 50,
    batch_size = 32,
    eps = 0.1,
    training_history = []
):
    for i in tqdm(range(epochs)):
        
        # Collection step
        result = train_collector.collect(n_step = collection_steps_per_epoch)
        
        # Test Step
        if i%test_frequency == 0:
            for a in agents:
                policy.policies[a].set_eps(0)
            result = test_collector.collect(n_step = test_steps)
            mean_reward = result['rews'].mean()
            tqdm.write(str(mean_reward))
            training_history.append(mean_reward)
            for a in agents:
                policy.policies[a].set_eps(eps)
    
        #if i%save_frequency == 0:
            #save_policy(policy, agents)
            #save_history(training_history)
    
        # Update step (one epoch)
        for _ in range(updates_per_epoch): 
            losses = policy.update(batch_size, train_collector.buffer)
    
    plot_and_save(training_history, test_frequency)

In [32]:
policy, agents, train_envs, test_envs = get_agents()
train_collector, test_collector = get_collectors(policy, train_envs, test_envs, buffer_size = p['buffer_size'])

In [37]:
train(policy, train_collector, test_collector, agents)

  0%|                                                                                          | 0/1000 [00:00<?, ?it/s]


TypeError: forward() got an unexpected keyword argument 'batch'