In [1]:
from maze_env import MazeEnv_v0
from env_utils.PettingZooEnv_new import PettingZooEnv_new
import supersuit
import numpy as np
from tianshou.env.utils import PettingZooEnv
import tianshou as ts
from tianshou.utils.net.common import Net
import torch
import torch.nn as nn
import torch.nn.functional as F
import gymnasium as gym
import wandb

In [2]:
# --- exploration schedule ---
eps_train = 0.95  # exploration rate used while training
eps_test = 0.0    # exploration rate used while testing
eps_decay = 0.999 # exploration rate decay
eps_min = 0.15    # minimum exploration rate

# --- optimisation ---
lr = 5e-4         # learning rate
epochs = 150      # max epochs per new maze introduction
batch_size = 512  # update batch size

# --- DQN settings ---
gamma = 0.9               # gamma in the DQN formula
n_step = 3                # number of steps to look ahead
target_update_freq = 100  # number of update calls before updating the target network

# --- environments and replay buffer ---
train_num = 10       # number of simultaneous training environments
test_num = 1         # number of simultaneous testing environments
buffer_size = 30000  # buffer size

# --- collection schedule ---
step_per_epoch = 10000  # number of steps for each epoch
step_per_collect = 200  # number of steps to collect before updating
ep_per_collect = 1      # number of episodes before updating

maze_width = 6  # maze width (not incl. walls)

# --- run-state flags ---
high_eps_run = False  # for random high eps run
obs_train = True      # for interleaving (might be broken?)
passed_mazes = True   # whether the policies passed the mazes

# --- counters ---
steps_total = 0  # steps count total
steps_n = 0      # steps count within epoch

# Maze curriculum bookkeeping.
# NOTE(review): the original comment said "start with 3 (it will add one later)
# mazes initially to prevent single maze overfitting", but the value here is 0 -- confirm.
n_mazes = 0
total_mazes = 26  # total number of random mazes
# for the trivial maze, we use 36 (since it should be 'easier')

test_mazes = []      # for printing later
threshold_rew = 0.5  # threshold reward to consider a maze passed (tentative value)

"""
logger = ts.utils.WandbLogger(train_interval=1, update_interval=1)
writer = SummaryWriter('log/test_maze')
writer.add_text("run 2", "wandb")
logger.load(writer)
"""

'\nlogger = ts.utils.WandbLogger(train_interval=1, update_interval=1)\nwriter = SummaryWriter(\'log/test_maze\')\nwriter.add_text("run 2", "wandb")\nlogger.load(writer)\n'

In [3]:
# define some helper functions
def preprocess_maze_env(render_mode=None, size=maze_width):
    """Create one wrapped maze environment.

    The environment is built with ``MazeEnv_v0.env_single``, passed through
    supersuit's ``pad_observations_v0`` wrapper, and finally wrapped in
    ``PettingZooEnv_new``.

    Args:
        render_mode: forwarded to ``MazeEnv_v0.env_single``.
        size: forwarded to ``MazeEnv_v0.env_single``; defaults to the value
            ``maze_width`` had when this function was defined.

    Returns:
        The ``PettingZooEnv_new`` wrapper around the padded environment.
    """
    raw_env = MazeEnv_v0.env_single(render_mode=render_mode, size=size)
    padded_env = supersuit.multiagent_wrappers.pad_observations_v0(raw_env)
    return PettingZooEnv_new(padded_env)

"""
def preprocess_maze_env(render_mode=None, size=maze_width):
    env = MazeEnv_v0.env(render_mode=render_mode, size=size)
    env = supersuit.multiagent_wrappers.pad_observations_v0(env)
    env = PettingZooEnv_new(env)
    return env
"""

def interleave_training(obs_train):
    """Give exploration noise to one agent, make the other greedy, and flip the flag.

    When ``obs_train`` is truthy, ``agents[0]`` gets ``eps_train`` and
    ``agents[1]`` gets epsilon 0; otherwise the roles are swapped.

    The original version only rebound its local ``obs_train``, so the toggle
    never reached the caller. The toggled value is now returned; use it as
    ``obs_train = interleave_training(obs_train)``.

    NOTE(review): this relies on the module-level ``policy`` exposing a
    ``.policies`` mapping keyed by agent name. The setup cell in this notebook
    assigns a single ``DQNPolicy`` to ``policy`` -- confirm this helper is
    still usable in that configuration.

    Args:
        obs_train: flag selecting which agent explores on this call.

    Returns:
        bool: the flipped flag, to be passed in on the next call.
    """
    if obs_train:
        first_eps, second_eps = eps_train, 0
    else:
        first_eps, second_eps = 0, eps_train
    policy.policies[agents[0]].set_eps(first_eps)
    policy.policies[agents[1]].set_eps(second_eps)
    return not obs_train

def set_eps(eps1, eps2=None, single = False):
    """Set the exploration rate on the module-level ``policy``.

    Args:
        eps1: epsilon for the whole policy when ``single`` is true, otherwise
            epsilon for ``agents[0]``.
        eps2: epsilon for ``agents[1]``; only used when ``single`` is false
            (it is passed through unchanged, including the default ``None``).
        single: when true, call ``policy.set_eps`` directly instead of going
            through ``policy.policies``.
    """
    if not single:
        # per-agent case: the first agent gets eps1, the second gets eps2
        policy.policies[agents[0]].set_eps(eps1)
        policy.policies[agents[1]].set_eps(eps2)
        return
    policy.set_eps(eps1)

# create a CNN for the observer
class CNN(nn.Module):
    """Convolutional Q-network for the observer.

    Three unpadded 3x3 convolutions followed by three linear layers. The first
    linear layer is sized for a square input of side ``2 * size + 1`` (a maze
    of width ``size`` plus its walls, e.g. 13x13 for a 6x6 maze).

    Args:
        size: maze width (not incl. walls). ``None`` (the default) uses the
            module-level ``maze_width``, which is what ``CNN()`` always did.
        in_channels: number of channels in the observation (default 3).
        n_actions: number of output logits (default 5).

    Raises:
        ValueError: if ``size`` is too small for the three convolutions.
    """

    def __init__(self, size=None, in_channels=3, n_actions=5):
        super().__init__()

        if size is None:
            size = maze_width
        # each unpadded 3x3 convolution shrinks the side length by 2
        lin_size = (size * 2 + 1) - 3 * (3 - 1)
        if lin_size <= 0:
            raise ValueError(
                f"maze size {size} is too small for three 3x3 convolutions"
            )
        self.model = nn.Sequential(
            # for a 6x6 maze (13x13 with walls): 13 -> 11 -> 9 -> 7
            nn.Conv2d(in_channels, 16, 3), nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3), nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3), nn.ReLU(inplace=True),
            nn.Flatten(), nn.Linear(64*lin_size*lin_size, 128), nn.ReLU(inplace=True),
            nn.Linear(128, 64), nn.ReLU(inplace=True),
            nn.Linear(64, n_actions)
        )

    def forward(self, obs, state=None, info=None):
        """Compute logits for a batch of observations.

        Args:
            obs: a ``torch.Tensor`` or array-like; non-tensors are converted to
                a float tensor on the same device as the network's parameters.
            state: returned unchanged.
            info: accepted but not used by this network (default ``None``
                instead of a shared mutable ``{}``).

        Returns:
            tuple: ``(logits, state)``.
        """
        if not isinstance(obs, torch.Tensor):
            # build the tensor on the model's device so a network moved off the CPU still works
            device = next(self.parameters()).device
            obs = torch.as_tensor(obs, dtype=torch.float, device=device)
        # kept from the original: the size of the first dimension of the last input
        self.batch = obs.shape[0]
        logits = self.model(obs)
        return logits, state

def watch(gym_reset_kwargs):
    """Render one episode with the module-level ``human_collector``.

    The module-level ``policy`` is switched to eval mode for the episode and
    is switched back to training mode afterwards, even if resetting or
    collecting raises (previously an exception left the policy in eval mode).

    Args:
        gym_reset_kwargs: reset kwargs (i.e. options) forwarded to both
            ``human_collector.reset_env`` and ``human_collector.collect``.

    Raises:
        AssertionError: if ``gym_reset_kwargs`` is ``None``.
    """
    assert gym_reset_kwargs is not None, "Please input reset kwargs i.e. options"
    # set policy to eval mode
    policy.eval()
    try:
        human_collector.reset_env(gym_reset_kwargs=gym_reset_kwargs)
        #np.random.seed()
        human_collector.collect(n_episode=1, render=1/120, gym_reset_kwargs=gym_reset_kwargs)
    finally:
        # reset back to training mode
        policy.train()

In [5]:
# get the vectorized training/testing environments (train_num / test_num copies)
train_envs = ts.env.DummyVectorEnv([preprocess_maze_env for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([preprocess_maze_env for _ in range(test_num)])

# set up training with no render environment
env = preprocess_maze_env()

# set up human render environment, wrapped in a single-worker vector env.
# The factory builds the environment itself instead of closing over the name
# `env_human`: that name is rebound to the vector env by this very statement,
# so a late call of the old `lambda: env_human` would have returned the
# vector env rather than the raw environment.
env_human = ts.env.DummyVectorEnv([lambda: preprocess_maze_env(render_mode="human")])

# get agent names
agents = env.agents

# observation spaces/action spaces for the two agents
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n

# Q-network for the observer: the CNN class, optimised with Adam at rate `lr`.
# Earlier linear-network variants, kept for reference:
#net_obs = Net(state_shape, action_shape, [128,128,128])
#net_obs = Net(state_shape, action_shape, [512, 512, 512])
net_obs = CNN()
optim_obs = torch.optim.Adam(net_obs.parameters(), lr=lr)

#net_exp = Net(state_shape, action_shape, [8])
#optim_exp = torch.optim.Adam(net_exp.parameters(), lr=lr)

# DQN policy for the observer, configured from the hyperparameter cell
agent_observer = ts.policy.DQNPolicy(
    net_obs,
    optim_obs,
    discount_factor=gamma,
    estimation_step=n_step,
    target_update_freq=target_update_freq,
)
#agent_explorer = ts.policy.DQNPolicy(net_exp, optim_exp, gamma, n_step, target_update_freq)
#agent_policies = [agent_observer, agent_explorer]
#agent_policies = [ts.policy.RandomPolicy(), ts.policy.RandomPolicy()] # baseline testing
#policy = ts.policy.MultiAgentPolicyManager(agent_policies, env)

# the observer's policy is used directly as the training policy
policy = agent_observer

# Training collector: runs `policy` in `train_envs` with exploration noise on,
# storing transitions in a vector replay buffer with one sub-buffer per env.
train_collector = ts.data.Collector(
    policy,
    train_envs,
    ts.data.VectorReplayBuffer(total_size=buffer_size, buffer_num=train_num),
    exploration_noise=True,
)

# two additional stand-alone replay buffers, 100000 transitions each
abstraction_buffer = ts.data.ReplayBuffer(size=100000)
episode_buffer = ts.data.ReplayBuffer(size=100000)

since Python 3.9 and will be removed in a subsequent version. The only 
supported seed types are: None, int, float, str, bytes, and bytearray.
  random.seed(seed)


In [6]:
def _random_maze_reset_kwargs(n):
    """Return fresh reset kwargs with maze_type "random", n_mazes=n and random=True."""
    return {"options": {"maze_type": "random", "n_mazes": n, "random": True}}


mazes = 1
# epsilon 1 for this collection run
policy.set_eps(1)

# reset the training environments, then collect `ep_per_collect` episode(s)
# using the same reset options (a new dict is built for each call)
train_collector.reset_env(gym_reset_kwargs=_random_maze_reset_kwargs(mazes))
result = train_collector.collect(
    n_episode=ep_per_collect,
    gym_reset_kwargs=_random_maze_reset_kwargs(mazes),
)
print(result)

since Python 3.9 and will be removed in a subsequent version. The only 
supported seed types are: None, int, float, str, bytes, and bytearray.
  random.seed(seed)


In [42]:
# testing
# Copy every successful episode (reward > 0) from the training buffer into
# `episode_buffer`, one transition at a time.
# The per-episode arrays 'rews' / 'idxs' / 'lens' are used instead of 'len'
# (the mean episode length in Tianshou's collect result), so the copy is also
# correct when more than one episode was collected.
for ep_rew, ep_start, ep_len in zip(
    np.atleast_1d(result['rews']),
    np.atleast_1d(result['idxs']),
    np.atleast_1d(result['lens']),
):
    # check if the episode was successful or not
    if ep_rew > 0:
        # sequentially add the episode's transitions into the episode buffer
        for i in range(int(ep_start), int(ep_start) + int(ep_len)):
            episode_buffer.add(train_collector.buffer[i])

In [33]:
# show the stored transitions of the collected episode: from its start index
# ('idxs') up to, but not including, start + length ('lens')
train_collector.buffer[
    slice(int(result['idxs']), int(result['idxs'] + result['lens']))
]

Batch(
    obs: Batch(
             agent_id: array(['observer', 'observer', 'observer', 'observer', 'observer',
                              'observer', 'observer', 'observer', 'observer', 'observer',
                              'observer', 'observer', 'observer', 'observer', 'observer',
                              'observer', 'observer', 'observer', 'observer', 'observer',
                              'observer', 'observer', 'observer', 'observer', 'observer',
                              'observer', 'observer', 'observer', 'observer', 'observer',
                              'observer', 'observer', 'observer', 'observer', 'observer',
                              'observer', 'observer', 'observer', 'observer', 'observer',
                              'observer', 'observer', 'observer', 'observer', 'observer',
                              'observer', 'observer', 'observer', 'observer', 'observer',
                              'observer', 'observer', 'observer', 'observer',