In [None]:
import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from typing import List, Tuple, Dict, Callable, Optional
from gymnasium.spaces import MultiBinary
from agents.agent import Agent

from dl_algos.dqn import DQNetwork
from logging import Logger

# Index/id of the leader agent (first entry in the players/hunters lists and in joint actions).
LEADER_ID = 0
# Index/id of the (first) ToM agent; in Pursuit additional ToM hunters use TOM_ID + idx.
TOM_ID = 1
# Base RNG seed; each evaluation run uses RNG_SEED + run_nr.
RNG_SEED = 20240729
# Value passed as `conf` to the agents' action selection; in the TomAgent below it scales the
# Q-values inside the likelihood softmax.
CONF = 1.0

def is_deadlock(history: List, new_state: str, last_actions: Tuple) -> bool:
	"""Report whether the agents appear stuck revisiting the same state.

	A deadlock is declared once ``history`` has at least three entries and the
	newly reached state appears at least three times among them.

	:param history: recently visited state strings (the evaluation loops keep the last three).
	:param new_state: state string reached after the latest joint action.
	:param last_actions: joint action that produced ``new_state`` (currently not used).
	:return: True when a deadlock is detected, False otherwise.
	"""
	if len(history) < 3:
		return False

	# How many of the remembered states coincide with the one just reached
	repeats = sum(1 for past_state in history if new_state == past_state)
	return repeats >= 3


# LB-Foraging

In [None]:
from agents.tom_agent import TomAgent
from dl_envs.lb_foraging.lb_foraging_coop import FoodCOOPLBForaging
from dl_envs.lb_foraging.lb_foraging import Action, Direction

def coordinate_agents(env: FoodCOOPLBForaging, predict_task: str, actions: Tuple[int, int]) -> Tuple[int, int]:
	"""Resolve simple conflicts between the leader's and the ToM agent's actions.

	* When every player is adjacent to the objective food, both are told to LOAD
	  only if the ToM agent's prediction matches the objective; otherwise the
	  actions are returned unchanged.
	* Otherwise, the ToM agent is told to do nothing (NONE) when both agents would
	  step onto the same cell or every agent chose LOAD.

	:param env: LB-Foraging environment; ``env.obj_food`` is the current objective.
	:param predict_task: objective predicted by the ToM agent, compared with ``str(env.obj_food)``.
	:param actions: joint action (leader first); entries may be ints or ``Action`` members.
	:return: the (possibly modified) joint action.
	"""
	objective = env.obj_food
	player_pos = [player.position for player in env.players]
	objective_adj = env.get_adj_pos(objective[0], objective[1])

	if all(pos in objective_adj for pos in player_pos):
		if predict_task == str(objective):
			return Action.LOAD, Action.LOAD
		return actions

	leader_pos = player_pos[LEADER_ID]
	tom_pos = player_pos[TOM_ID]
	lead_direction = Direction[Action(actions[LEADER_ID]).name].value
	tom_direction = Direction[Action(actions[TOM_ID]).name].value
	next_lead_pos = (leader_pos[0] + lead_direction[0], leader_pos[1] + lead_direction[1])
	next_tom_pos = (tom_pos[0] + tom_direction[0], tom_pos[1] + tom_direction[1])
	# Normalise through Action(...) so raw int actions compare correctly with the enum member
	# (a plain int never equals a member of a non-Int Enum)
	all_load = all(Action(act) == Action.LOAD for act in actions)
	if next_lead_pos == next_tom_pos or all_load:
		return actions[LEADER_ID], Action.NONE.value
	return actions


def load_models(opt_models_dir: Path, leg_models_dir: Path, n_foods_spawn: int, food_locs: List[Tuple], foods_lvl: int, num_layers: int, act_function: Callable,
                layer_sizes: List[int], gamma: float, use_cnn: bool, use_dueling: bool, use_ddqn: bool, cnn_shape: Tuple, cnn_properties: Optional[List] = None) -> Tuple[Dict, Dict]:
	"""Load the optimal and legible DQN models for every food location.

	Models are read from ``<models_dir>/<n_foods_spawn>-foods_<foods_lvl>-food-level/best``.
	For each location ``loc`` the first file (in ``iterdir`` order) whose name
	contains ``"<x>x<y>"`` is loaded.

	:param opt_models_dir: root directory of the optimal models.
	:param leg_models_dir: root directory of the legible models.
	:param n_foods_spawn: number of foods the models were trained with.
	:param food_locs: food locations to load a model for.
	:param foods_lvl: food level the models were trained with.
	:param cnn_shape: observation shape forwarded to ``DQNetwork.load_model``.
	:return: ``(optim_models, leg_models)`` keyed by ``str(loc)``, or ``({}, {})``
		(after printing an error) if some location has no matching model file.
	"""
	level_dir = '%d-foods_%d-food-level' % (n_foods_spawn, foods_lvl)
	opt_dir = opt_models_dir / level_dir / 'best'
	leg_dir = leg_models_dir / level_dir / 'best'
	opt_model_names = [fname.name for fname in opt_dir.iterdir()]
	leg_model_names = [fname.name for fname in leg_dir.iterdir()]

	def find_model(names: List[str], key: str) -> str:
		# NOTE(review): plain substring match, so e.g. '1x1' also matches '11x12'; first hit wins
		return next((name for name in names if key in name), '')

	optim_models = {}
	leg_models = {}
	for loc in food_locs:
		key = '%sx%s' % (loc[0], loc[1])
		opt_name = find_model(opt_model_names, key)
		leg_name = find_model(leg_model_names, key)
		# Explicit check instead of `assert`, which is stripped under `python -O`
		if not opt_name or not leg_name:
			missing_dir = opt_dir if not opt_name else leg_dir
			print('[ERROR]: no model for food location %s found in %s' % (str(loc), missing_dir))
			return {}, {}

		opt_dqn = DQNetwork(len(Action), num_layers, act_function, layer_sizes, gamma, use_dueling, use_ddqn, use_cnn, cnn_properties)
		opt_dqn.load_model(opt_name, opt_dir, None, cnn_shape, True)
		optim_models[str(loc)] = opt_dqn

		leg_dqn = DQNetwork(len(Action), num_layers, act_function, layer_sizes, gamma, use_dueling, use_ddqn, use_cnn, cnn_properties)
		leg_dqn.load_model(leg_name, leg_dir, None, cnn_shape, True)
		leg_models[str(loc)] = leg_dqn

	return optim_models, leg_models

def eval_legibility(n_runs: int, test_mode: int, opt_models_dir: Path, leg_models_dir: Path, field_dims: Tuple[int, int], n_agents: int,
                    player_level: int, player_sight: int, max_foods: int, max_foods_spawn: int, food_locs: List[Tuple], foods_lvl: int, max_steps: int, gamma: float,
                    num_layers: int, act_function: Callable, layer_sizes: List[int], use_cnn: bool, use_dueling: bool, use_ddqn: bool,
                    cnn_properties: Optional[List] = None, use_render: bool = False, start_run: int = 0) -> Dict[int, Dict]:
	"""Run leader/ToM legibility experiments in cooperative LB-Foraging.

	Each run spawns ``max_foods_spawn`` foods chosen from ``food_locs`` and gives the
	leader a random objective food. Both agents then act until at most one food
	remains or the episode times out. After each pick-up the next objective is
	drawn from the remaining foods, and the models for the new food count are loaded.

	``test_mode`` selects (leader models, ToM goal models, ToM sample models):
	0 -> (optimal, optimal, optimal), 1 -> (optimal, legible, optimal),
	2 -> (legible, optimal, legible), anything else -> all legible.

	:return: dict mapping each run number to its metrics (``n_steps``, ``pred_steps``,
		``avg_pred_steps``, ``caught_foods``, ``steps_food``, ``deadlocks``).
	"""

	def pick_models(optim_models: Dict, leg_models: Dict) -> Tuple[Dict, Dict, Dict]:
		# (leader models, ToM goal models, ToM sample models) for the selected test mode
		if test_mode == 0:
			return optim_models, optim_models, optim_models
		if test_mode == 1:
			return optim_models, leg_models, optim_models
		if test_mode == 2:
			return leg_models, optim_models, leg_models
		return leg_models, leg_models, leg_models

	def get_cnn_shape(environment) -> Tuple:
		# Moves the leading (presumably channel) axis last for the CNN; (0,) when no CNN is used
		if isinstance(environment.observation_space, MultiBinary):
			obs_space = MultiBinary([*environment.observation_space.shape[1:]])
		else:
			obs_space = environment.observation_space[0]
		return (0,) if not use_cnn else (*obs_space.shape[1:], obs_space.shape[0])

	def split_obs(joint_obs, shape: Tuple):
		# Leader and ToM observations; CNN models take a batch of one observation
		if use_cnn:
			return joint_obs[0].reshape((1, *shape)), joint_obs[1].reshape((1, *shape))
		return joint_obs[0], joint_obs[1]

	def state_key(environment) -> str:
		# Concatenated player and food positions, used for deadlock detection
		return (''.join([''.join(str(x) for x in p.position) for p in environment.players]) +
		        ''.join([''.join(str(x) for x in f.position) for f in environment.foods]))

	# Probe environment, only used to derive the observation shape for the initial models
	probe_env = FoodCOOPLBForaging(n_agents, player_level, field_dims, max_foods, player_sight, max_steps, True, foods_lvl, RNG_SEED, food_locs, use_render=use_render,
	                               use_encoding=True, agent_center=True, grid_observation=use_cnn)
	cnn_shape = get_cnn_shape(probe_env)
	probe_env.close()

	start_optim_models, start_leg_models = load_models(opt_models_dir, leg_models_dir, max_foods_spawn, food_locs, foods_lvl, num_layers, act_function, layer_sizes, gamma, use_cnn,
	                                                   use_dueling, use_ddqn, cnn_shape, cnn_properties)
	results = {}
	for run_nr in range(start_run, n_runs):

		rng_seed = RNG_SEED + run_nr
		# Initialize the agents for the interaction
		leader_models, tom_goal_models, tom_sample_models = pick_models(start_optim_models, start_leg_models)
		leader_agent = Agent(LEADER_ID, leader_models, rng_seed)
		tom_agent = TomAgent(TOM_ID, tom_goal_models, tom_sample_models, rng_seed, 1)

		env = FoodCOOPLBForaging(n_agents, player_level, field_dims, max_foods, player_sight, max_steps, True, foods_lvl, rng_seed, food_locs, use_render=use_render,
		                         use_encoding=True, agent_center=True, grid_observation=use_cnn)
		it_results = {}
		rng_gen = np.random.default_rng(rng_seed)
		spawned_foods = [food_locs[idx] for idx in rng_gen.choice(max_foods, size=max_foods_spawn, replace=False)]
		foods_left = spawned_foods.copy()
		n_foods_left = max_foods_spawn
		start_obj = foods_left.pop(rng_gen.integers(max_foods_spawn))
		task = str(start_obj)

		# Setup agents for test
		tasks = sorted(str(food) for food in spawned_foods)
		leader_agent.init_interaction(tasks)
		tom_agent.init_interaction(tasks)

		# Setup environment for test
		env.food_spawn_pos = spawned_foods
		env.n_food_spawn = max_foods_spawn
		env.set_objective(start_obj)
		env.spawn_players()
		env.spawn_food(max_foods_spawn, foods_lvl)
		cnn_shape = get_cnn_shape(env)
		obs, *_ = env.reset()

		recent_states = [state_key(env)]
		leader_obs, tom_obs = split_obs(obs, cnn_shape)
		actions = (leader_agent.action(leader_obs, (leader_obs, Action.NONE), CONF, None, task),
		           tom_agent.action(tom_obs, (leader_obs, Action.NONE), CONF, None, tom_agent.predict_task))

		timeout = False
		n_steps = 0
		n_pred_steps = []
		steps_food = []
		deadlock_states = []
		n_deadlocks = 0
		act_try = 0
		later_error = 0
		later_food_step = 0

		if use_render:
			env.render()

		print('Started run number %d:' % (run_nr + 1))
		print(env.get_full_env_log())
		while n_foods_left > 1 and not timeout:
			print('Run number %d, step %d: remaining %d foods, predicted objective %s and real objective %s from ' % (run_nr + 1, n_steps + 1, n_foods_left, tom_agent.predict_task, task) + str(foods_left))
			n_steps += 1
			last_leader_sample = (leader_obs, actions[0])
			# Remember the last step at which the ToM agent still predicted the wrong objective
			if task != tom_agent.predict_task:
				later_error = n_steps
			obs, _, _, timeout, _ = env.step(actions)
			if use_render:
				env.render()
			current_food_count = np.sum([not food.picked for food in env.foods])
			leader_obs, tom_obs = split_obs(obs, cnn_shape)

			if timeout:
				n_pred_steps += [later_error - later_food_step]
				steps_food += [n_steps - later_food_step]
				break

			elif current_food_count < n_foods_left:
				n_foods_left = current_food_count
				n_pred_steps += [later_error - later_food_step]
				steps_food += [n_steps - later_food_step]
				later_food_step = n_steps
				later_error = n_steps

				if current_food_count > 0:
					# Update tasks remaining and samples
					tasks = sorted(str(food) for food in foods_left)
					tom_agent.reset_inference(tasks)
					last_leader_sample = (leader_obs, Action.NONE)
					recent_states = []

					# Update decision models for the new number of foods
					optim_models, leg_models = load_models(opt_models_dir, leg_models_dir, n_foods_left, food_locs, foods_lvl, num_layers, act_function, layer_sizes,
					                                       gamma, use_cnn, use_dueling, use_ddqn, cnn_shape, cnn_properties)
					leader_agent.goal_models, tom_agent.goal_models, tom_agent.sample_models = pick_models(optim_models, leg_models)

					# Get next objective
					next_obj = foods_left.pop(rng_gen.integers(n_foods_left))
					task = str(next_obj)
					env.set_objective(next_obj)

			current_state = state_key(env)
			if is_deadlock(recent_states, current_state, actions):
				n_deadlocks += 1
				if current_state not in deadlock_states:
					deadlock_states.append(current_state)
				act_try += 1
				actions = (leader_agent.sub_acting(leader_obs, None, act_try - 1, last_leader_sample, CONF, task),
				           tom_agent.sub_acting(tom_obs, None, act_try, last_leader_sample, CONF))
			else:
				act_try = 0
				actions = (leader_agent.action(leader_obs, last_leader_sample, CONF, None, task), tom_agent.action(tom_obs, last_leader_sample, CONF, None))

			actions = coordinate_agents(env, tom_agent.predict_task, actions)

			recent_states.append(current_state)
			if len(recent_states) > 3:
				recent_states.pop(0)

		env.close()
		print('Run Over!!')
		it_results['n_steps'] = n_steps
		it_results['pred_steps'] = n_pred_steps
		it_results['avg_pred_steps'] = np.mean(n_pred_steps) if len(n_pred_steps) > 0 else 0
		it_results['caught_foods'] = max_foods_spawn - n_foods_left
		it_results['steps_food'] = steps_food
		it_results['deadlocks'] = n_deadlocks
		results[run_nr] = it_results

	# The per-run metrics were previously computed and then discarded
	return results


# Pursuit-Prey

In [None]:
from dl_envs.pursuit.pursuit_env import TargetPursuitEnv, Action, ActionDirection
from dl_envs.pursuit.agents.random_prey import RandomPrey
from dl_envs.pursuit.agents.greedy_prey import GreedyPrey
from dl_envs.pursuit.agents.agent import Agent as PreyAgent

# Mapping from prey-behaviour names to integer codes.
# NOTE(review): not referenced in this notebook section; eval_legibility dispatches on the
# prey_type string ('random' / 'greedy' / anything else) instead.
PREY_TYPES = {'idle': 0, 'greedy': 1, 'random': 2}


class TomAgent(Agent):
	"""Follower agent that infers the leader's target task before acting.

	Every observed leader sample ``(per-task observations, action)`` is scored with a
	softmax over the Q-values of ``sample_models``, scaled by ``sign * conf``. The
	per-task likelihoods are stacked in ``interaction_likelihoods``. Their running
	product times the prior ``goal_prob`` gives the posterior over tasks. The
	posterior's arg-max (ties broken at random) becomes ``predict_task``. Actions are
	then chosen greedily, with random tie-breaking, from ``goal_models``.
	"""

	_goal_prob: jnp.ndarray
	_sample_models: Dict[str, DQNetwork]
	_interaction_likelihoods: jnp.ndarray
	_sign: float
	_predict_task: str

	def __init__(self, agent_id: int, goal_models: Dict[str, DQNetwork], sample_models: Dict[str, DQNetwork], rng_seed: int = 1234567890, sign: float = -1):
		"""
		:param agent_id: id of this agent.
		:param goal_models: models used to choose this agent's own actions.
		:param sample_models: models used to score the leader's observed actions.
		:param rng_seed: seed forwarded to ``Agent`` (presumably used to create ``_rng_key``).
		:param sign: multiplier of ``conf`` in the likelihood softmax; for positive ``conf``,
			1 makes high-Q actions likely and -1 (the default) low-Q actions.
		"""
		super().__init__(agent_id, goal_models, rng_seed)
		self._sample_models = sample_models
		self._goal_prob = jnp.array([])
		self._interaction_likelihoods = jnp.array([])
		self._sign = sign
		self._predict_task = ''

	@property
	def goal_prob(self) -> jnp.ndarray:
		return self._goal_prob

	@property
	def interaction_likelihoods(self) -> jnp.ndarray:
		return self._interaction_likelihoods

	@property
	def predict_task(self) -> str:
		return self._predict_task

	@property
	def sample_models(self) -> Dict[str, DQNetwork]:
		return self._sample_models

	@goal_prob.setter
	def goal_prob(self, goal_prob: jnp.ndarray) -> None:
		self._goal_prob = goal_prob

	@sample_models.setter
	def sample_models(self, sample_models: Dict[str, DQNetwork]) -> None:
		self._sample_models = sample_models

	def add_sample_model(self, task: str, model: DQNetwork) -> None:
		self._sample_models[task] = model

	def remove_sample_model(self, task: str) -> None:
		self._sample_models.pop(task)

	def init_interaction(self, interaction_tasks: List[str]):
		"""Start a new interaction over ``interaction_tasks`` with a uniform prior."""
		self._tasks = interaction_tasks.copy()
		self._n_tasks = len(interaction_tasks)
		self._goal_prob = jnp.ones(self._n_tasks) / self._n_tasks
		self._interaction_likelihoods = jnp.ones(self._n_tasks)
		self._predict_task = interaction_tasks[0]

	def reset_inference(self, tasks: List = None):
		"""Forget past samples, optionally switching to a new task list."""
		if tasks:
			self._tasks = tasks.copy()
			self._n_tasks = len(self._tasks)
		self._interaction_likelihoods = jnp.ones(self._n_tasks)
		self._goal_prob = jnp.ones(self._n_tasks) / self._n_tasks
		self._predict_task = self._tasks[0]

	def sample_probability(self, obs: List[jnp.ndarray], a: int, conf: float) -> jnp.ndarray:
		"""Return the likelihood of the leader's action ``a`` under each candidate task.

		``obs[task_idx]`` is expected to be the leader observation built for
		``self._tasks[task_idx]`` (NOTE(review): callers build it in
		``env.prey_alive_ids`` order while ``_tasks`` is sorted — confirm they agree).
		"""
		# Index the sample models with their own key rather than a goal-model key
		model_id = list(self._sample_models.keys())[0]
		model = self._sample_models[model_id]
		# Callers sometimes pass Action enum members (e.g. Action.STAY) instead of ints
		act_idx = int(getattr(a, 'value', a))

		goals_likelihood = []
		for task_idx in range(self._n_tasks):
			q = jax.device_get(model.q_network.apply(model.online_state.params, obs[task_idx])[0])
			# jax.nn.softmax shifts by the max *scaled* logit, so it stays stable for a
			# negative sign too (shifting by q.max() overflowed in that case)
			policy = jax.nn.softmax(self._sign * conf * jnp.asarray(q))
			goals_likelihood.append(policy[act_idx])

		return jnp.array(goals_likelihood)

	def _posterior(self) -> jnp.ndarray:
		"""Normalised posterior over ``self._tasks``; uniform if it collapses to zero."""
		if len(self._interaction_likelihoods) > 0:
			# atleast_2d keeps the product per task when only the initial 1-D row is present
			likelihood = jnp.cumprod(jnp.atleast_2d(self._interaction_likelihoods), axis=0)[-1]
		else:
			likelihood = jnp.zeros(self._n_tasks)
		goals_prob = self._goal_prob * likelihood
		goals_prob_sum = goals_prob.sum()
		if goals_prob_sum == 0:
			return jnp.ones(self._n_tasks) / self._n_tasks
		return goals_prob / goals_prob_sum

	def _sample_best_idx(self, posterior: jnp.ndarray) -> int:
		"""Index of a most probable task, breaking ties uniformly at random."""
		best = jnp.argwhere(posterior == jnp.amax(posterior)).ravel()
		self._rng_key, subkey = jax.random.split(self._rng_key)
		return int(jax.random.choice(subkey, best))

	def task_inference(self) -> str:
		"""Return the most probable task given the samples gathered so far."""
		if not self._tasks:
			print('[ERROR]: List of possible tasks not defined!!')
			return ''

		return self._tasks[self._sample_best_idx(self._posterior())]

	def bayesian_task_inference(self, sample: Tuple[List[jnp.ndarray], int], conf: float) -> Tuple[str, float]:
		"""Add a leader sample to the evidence and return ``(most probable task, its probability)``."""
		if not self._tasks:
			print('[ERROR]: List of possible tasks not defined!!')
			return '', -1

		states, action = sample
		sample_prob = self.sample_probability(states, action, conf)
		self._interaction_likelihoods = jnp.vstack((self._interaction_likelihoods, sample_prob))

		p_max = self._posterior()
		best_idx = self._sample_best_idx(p_max)
		return self._tasks[best_idx], float(p_max[best_idx])

	def get_actions(self, task_id: str, obs: jnp.ndarray) -> int:
		"""Pick uniformly among the greedy actions of the first goal model (``task_id`` is unused)."""
		model_id = list(self._goal_models.keys())[0]
		q = jax.device_get(self._goal_models[model_id].q_network.apply(self._goal_models[model_id].online_state.params, obs)[0])
		pol = jnp.isclose(q, q.max(), rtol=1e-10, atol=1e-10).astype(int)
		pol = pol / pol.sum()

		self._rng_key, subkey = jax.random.split(self._rng_key)
		return int(jax.random.choice(subkey, len(q), p=pol))

	def action(self, obs: jnp.ndarray, sample: Tuple[List[jnp.ndarray], int], conf: float, logger: Optional[Logger], task: str = '') -> int:
		"""Update the task prediction with ``sample`` and act greedily (``task`` is ignored)."""
		predict_task, _ = self.bayesian_task_inference(sample, conf)
		self._predict_task = predict_task
		return self.get_actions(self._predict_task, obs)

	def sub_acting(self, obs: jnp.ndarray, logger: Optional[Logger], act_try: int, sample: Tuple[List[jnp.ndarray], int], conf: float, task: str = '') -> int:
		"""Update the prediction, then defer to ``Agent.sub_acting`` with ``task`` or the prediction."""
		predict_task, _ = self.bayesian_task_inference(sample, conf)
		self._predict_task = predict_task
		return super().sub_acting(obs, logger, act_try, sample, conf, self._predict_task if task == '' else task)


def coordinate_agents(env: TargetPursuitEnv, predict_task: 'str | List[str]', actions: Tuple[int], n_tom_agents: int) -> Tuple[int]:
	"""Resolve simple conflicts between the hunters' actions in the pursuit task.

	* If at least ``env.n_catch`` hunters are adjacent to the target prey and every
	  ToM prediction matches the target, all hunters are told to STAY.
	* Otherwise a ToM hunter is told to STAY when its next cell would be the leader's
	  next cell or the next cell of an earlier ToM hunter. A hunter that is stopped
	  stays on its current cell.

	:param env: pursuit environment; ``env.target`` is the current target prey id.
	:param predict_task: a single predicted target, or one prediction per ToM hunter.
	:param actions: hunter actions, leader first, then the ToM hunters.
	:param n_tom_agents: number of ToM hunters following the leader in ``actions``.
	:return: the (possibly modified) hunter actions as a tuple.
	"""
	objective = env.target
	hunter_pos = [env.agents[h_id].pos for h_id in env.hunter_ids]
	objective_adj = env.adj_pos(env.agents[objective].pos)

	if sum(pos in objective_adj for pos in hunter_pos) >= env.n_catch:
		# The evaluation loop passes one prediction per ToM hunter; comparing that list
		# directly with a string never matched, so the hold branch could not trigger
		predictions = [predict_task] if isinstance(predict_task, str) else list(predict_task)
		if predictions and all(pred == str(objective) for pred in predictions):
			return tuple([Action.STAY.value] * env.n_hunters)
		return actions

	leader_pos = hunter_pos[LEADER_ID]
	lead_direction = ActionDirection[Action(actions[LEADER_ID]).name].value
	# Cells hunters will occupy after this step; the leader's action is never changed
	claimed = [(leader_pos[0] + lead_direction[0], leader_pos[1] + lead_direction[1])]

	coord_acts = tuple(actions)
	for idx in range(n_tom_agents):
		agent_idx = TOM_ID + idx
		tom_pos = hunter_pos[agent_idx]
		tom_direction = ActionDirection[Action(actions[agent_idx]).name].value
		next_pos = (tom_pos[0] + tom_direction[0], tom_pos[1] + tom_direction[1])
		if next_pos in claimed:
			coord_acts = coord_acts[:agent_idx] + (Action.STAY.value, ) + coord_acts[agent_idx + 1:]
			next_pos = (tom_pos[0], tom_pos[1])
		claimed.append(next_pos)

	return coord_acts


def load_models(opt_models_dir: Path, leg_models_dir: Path, n_hunters: int, prey_type: str, n_preys_alive: int, num_layers: int, act_function: Callable,
                layer_sizes: List[int], gamma: float, use_cnn: bool, use_dueling: bool, use_ddqn: bool, cnn_shape: Tuple, cnn_properties: Optional[List] = None) -> Tuple[Dict, Dict]:
	"""Load the optimal and legible DQN models for the current number of alive preys.

	Models are read from ``<models_dir>/<n_hunters>-hunters/<prey_type>-prey/best``.
	The first file (in ``iterdir`` order) whose name contains ``n_preys_alive`` is used.

	:param opt_models_dir: root directory of the optimal models.
	:param leg_models_dir: root directory of the legible models.
	:param n_hunters: number of hunters the models were trained with.
	:param prey_type: prey behaviour the models were trained against.
	:param n_preys_alive: number of alive preys the models must handle.
	:param cnn_shape: observation shape forwarded to ``DQNetwork.load_model``.
	:return: ``({'p<n>': optimal}, {'p<n>': legible})``, or ``({}, {})`` (after printing
		an error) if either directory has no matching model file.
	"""
	opt_dir = opt_models_dir / ('%d-hunters' % n_hunters) / ('%s-prey' % prey_type) / 'best'
	leg_dir = leg_models_dir / ('%d-hunters' % n_hunters) / ('%s-prey' % prey_type) / 'best'
	opt_model_names = [fname.name for fname in opt_dir.iterdir()]
	leg_model_names = [fname.name for fname in leg_dir.iterdir()]

	# NOTE(review): plain substring match, so e.g. '1' also matches a name containing '10'
	key = '%d' % n_preys_alive
	opt_name = next((name for name in opt_model_names if key in name), '')
	leg_name = next((name for name in leg_model_names if key in name), '')
	# Explicit check instead of `assert`, which is stripped under `python -O`
	if not opt_name or not leg_name:
		missing_dir = opt_dir if not opt_name else leg_dir
		print('[ERROR]: no model for %d preys alive found in %s' % (n_preys_alive, missing_dir))
		return {}, {}

	model_key = 'p%d' % n_preys_alive
	opt_dqn = DQNetwork(len(Action), num_layers, act_function, layer_sizes, gamma, use_dueling, use_ddqn, use_cnn, cnn_properties)
	opt_dqn.load_model(opt_name, opt_dir, None, cnn_shape, True)
	leg_dqn = DQNetwork(len(Action), num_layers, act_function, layer_sizes, gamma, use_dueling, use_ddqn, use_cnn, cnn_properties)
	leg_dqn.load_model(leg_name, leg_dir, None, cnn_shape, True)

	return {model_key: opt_dqn}, {model_key: leg_dqn}

def eval_legibility(n_runs: int, test_mode: int, opt_models_dir: Path, leg_models_dir: Path, field_dims: Tuple[int, int], hunters: List[Tuple[str, int]],
                    preys: List[Tuple[str, int]], player_sight: int, prey_ids: List[str], prey_type: str, require_catch: bool, catch_reward: float, max_steps: int, gamma: float,
                    num_layers: int, act_function: Callable, layer_sizes: List[int], use_cnn: bool, use_dueling: bool, use_ddqn: bool,
                    cnn_properties: Optional[List] = None, use_render: bool = False, start_run: int = 0) -> Dict[int, Dict]:
	"""Run leader/ToM legibility experiments in the target pursuit environment.

	Each run gives the leader a random target prey. The leader and the ToM hunters
	act until at most one prey is alive or the episode times out. After every capture,
	a new target is drawn among the alive preys and the models for the new prey count
	are loaded.

	``test_mode`` selects (leader models, ToM goal models, ToM sample models):
	0 -> (optimal, optimal, optimal), 1 -> (optimal, legible, optimal),
	2 -> (legible, optimal, legible), anything else -> all legible.

	:return: dict mapping each run number to its metrics (``n_steps``, ``pred_steps``,
		``avg_pred_steps``, ``preys_captured``, ``steps_capture``, ``deadlocks``).
	"""

	def pick_models(optim_models: Dict, leg_models: Dict) -> Tuple[Dict, Dict, Dict]:
		# (leader models, ToM goal models, ToM sample models) for the selected test mode
		if test_mode == 0:
			return optim_models, optim_models, optim_models
		if test_mode == 1:
			return optim_models, leg_models, optim_models
		if test_mode == 2:
			return leg_models, optim_models, leg_models
		return leg_models, leg_models, leg_models

	def get_cnn_shape(environment) -> Tuple:
		# Moves the leading (presumably channel) axis last for the CNN; (0,) when no CNN is used
		if isinstance(environment.observation_space, MultiBinary):
			obs_space = MultiBinary([*environment.observation_space.shape[1:]])
		else:
			obs_space = environment.observation_space[0]
		return (0,) if not use_cnn else (*obs_space.shape[1:], obs_space.shape[0])

	def format_obs(raw_obs, shape: Tuple):
		# CNN models take a batch of one observation; other models take it unchanged
		return raw_obs.reshape((1, *shape)) if use_cnn else raw_obs

	def leader_samples(environment, shape: Tuple) -> List:
		# Leader observation computed for each alive prey, in environment.prey_alive_ids order
		return [format_obs(environment.make_target_grid_obs(prey)[LEADER_ID], shape) for prey in environment.prey_alive_ids]

	def state_key(environment) -> str:
		# Concatenated positions of all alive agents, used for deadlock detection
		return ''.join([''.join(str(x) for x in environment.agents[a_id].pos) for a_id in environment.agents.keys() if environment.agents[a_id].alive])

	# Probe environment, only used for the observation shape and agent counts of the initial models
	probe_env = TargetPursuitEnv(hunters, preys, field_dims, player_sight, prey_ids[0], require_catch, max_steps, use_layer_obs=True, agent_centered=True, catch_reward=catch_reward)
	cnn_shape = get_cnn_shape(probe_env)
	start_optim_models, start_leg_models = load_models(opt_models_dir, leg_models_dir, probe_env.n_hunters, prey_type, probe_env.n_preys, num_layers, act_function, layer_sizes,
	                                                   gamma, use_cnn, use_dueling, use_ddqn, cnn_shape, cnn_properties)
	probe_env.close()

	n_hunters = len(hunters)
	n_tom_hunters = n_hunters - 1
	n_preys = len(preys)
	results = {}
	for run_nr in range(start_run, n_runs):

		rng_seed = RNG_SEED + run_nr
		# Initialize the agents for the interaction
		leader_models, tom_goal_models, tom_sample_models = pick_models(start_optim_models, start_leg_models)
		leader_agent = Agent(LEADER_ID, leader_models, rng_seed)
		tom_agents = [TomAgent(TOM_ID + idx, tom_goal_models, tom_sample_models, rng_seed, 1) for idx in range(n_tom_hunters)]

		env = TargetPursuitEnv(hunters, preys, field_dims, player_sight, prey_ids[0], require_catch, max_steps, use_layer_obs=True, agent_centered=True, catch_reward=catch_reward)
		env.seed(rng_seed)
		it_results = {}
		rng_gen = np.random.default_rng(rng_seed)

		# Setup agents for test
		preys_left = prey_ids.copy()
		task = preys_left.pop(rng_gen.integers(n_preys))
		tasks = sorted(prey_ids)
		leader_agent.init_interaction(tasks)
		for tom in tom_agents:
			tom.init_interaction(tasks)
		prey_agents = {}
		for i, prey_id in enumerate(prey_ids):
			if prey_type == 'random':
				prey_agents[prey_id] = RandomPrey(prey_id, 2, 0, rng_seed + i)
			elif prey_type == 'greedy':
				prey_agents[prey_id] = GreedyPrey(prey_id, 2, 0, rng_seed + i)
			else:
				prey_agents[prey_id] = PreyAgent(prey_id, 2, 0, rng_seed + i)

		# Setup environment for test
		env.reset_init_pos()
		env.target = task
		cnn_shape = get_cnn_shape(env)
		n_preys_alive = n_preys
		obs, *_ = env.reset()

		recent_states = [state_key(env)]
		leader_obs = format_obs(obs[LEADER_ID], cnn_shape)
		leader_sample = leader_samples(env, cnn_shape)
		# Each ToM hunter observes the environment as seen for its currently predicted target
		tom_obs = [[format_obs(env.make_target_grid_obs(prey)[TOM_ID + idx], cnn_shape) for prey in env.prey_alive_ids if prey == tom_agents[idx].predict_task][0]
		           for idx in range(n_tom_hunters)]
		model_key = 'p%d' % n_preys_alive
		actions = (leader_agent.action(leader_obs, (leader_sample, Action.STAY), CONF, None, model_key),
		           *[tom_agents[idx].action(tom_obs[idx], (leader_sample, Action.STAY), CONF, None, model_key) for idx in range(n_tom_hunters)])

		for prey_id in env.prey_alive_ids:
			actions += (prey_agents[prey_id].act(env), )

		timeout = False
		n_steps = 0
		n_pred_steps = []
		steps_capture = []
		deadlock_states = []
		n_deadlocks = 0
		act_try = 0
		later_error = 0
		last_capture_step = 0

		if use_render:
			env.render()

		print('Started run number %d:' % (run_nr + 1))
		print(env.get_full_env_log())
		while n_preys_alive > 1 and not timeout:
			predicted_objectives = ','.join(['%s for tom agent %d' % (tom_agents[idx].predict_task, tom_agents[idx].agent_id) for idx in range(n_tom_hunters)])
			print('Run number %d, step %d: remaining %d preys, predicted objective %s and real objective %s from ' % (run_nr + 1, n_steps + 1,
			                                                                                                          env.n_preys_alive, predicted_objectives, task) + ', '.join(env.prey_alive_ids))
			n_steps += 1
			last_leader_sample = (leader_samples(env, cnn_shape), actions[LEADER_ID])
			# Remember the last step at which some ToM hunter still predicted the wrong target
			if any(task != tom.predict_task for tom in tom_agents):
				later_error = n_steps
			print('Actions: %s' % ', '.join([str(Action(action).name) for action in actions]))
			obs, _, _, timeout, _ = env.step(actions)
			if use_render:
				env.render()

			if timeout:
				n_pred_steps += [later_error - last_capture_step]
				steps_capture += [n_steps - last_capture_step]
				break

			elif env.n_preys_alive < n_preys_alive:
				n_preys_alive = env.n_preys_alive
				n_pred_steps += [later_error - last_capture_step]
				steps_capture += [n_steps - last_capture_step]
				last_capture_step = n_steps
				later_error = n_steps

				if env.n_preys_alive > 0:
					# Update tasks remaining and samples
					tasks = env.prey_alive_ids.copy()
					tasks.sort()
					for tom in tom_agents:
						tom.init_interaction(tasks)
					last_leader_sample = (leader_samples(env, cnn_shape), Action.STAY)
					recent_states = []

					# Update decision models for the new number of alive preys
					optim_models, leg_models = load_models(opt_models_dir, leg_models_dir, n_hunters, prey_type, n_preys_alive, num_layers, act_function, layer_sizes,
					                                       gamma, use_cnn, use_dueling, use_ddqn, cnn_shape, cnn_properties)
					leader_models, tom_goal_models, tom_sample_models = pick_models(optim_models, leg_models)
					leader_agent.goal_models = leader_models
					for tom in tom_agents:
						tom.goal_models = tom_goal_models
						tom.sample_models = tom_sample_models

					# Get next objective
					preys_left = env.prey_alive_ids.copy()
					task = preys_left.pop(rng_gen.integers(n_preys_alive))
					env.target = task

			# Update leader and ToM agents' observations
			print('Preys alive: %s' % ', '.join([str(prey) for prey in env.prey_alive_ids]))
			leader_obs = format_obs(obs[LEADER_ID], cnn_shape)
			for idx, tom in enumerate(tom_agents):
				for prey in env.prey_alive_ids:
					if prey == tom.predict_task:
						# Non-CNN observations were previously reshaped to (1, 0), which crashed
						tom_obs[idx] = format_obs(env.make_target_grid_obs(prey)[TOM_ID + idx], cnn_shape)

			current_state = state_key(env)
			model_key = 'p%d' % n_preys_alive
			if is_deadlock(recent_states, current_state, actions):
				n_deadlocks += 1
				if current_state not in deadlock_states:
					deadlock_states.append(current_state)
				act_try += 1
				actions = (leader_agent.sub_acting(leader_obs, None, act_try - 1, last_leader_sample, CONF, model_key),
				           *[tom_agents[idx].sub_acting(tom_obs[idx], None, act_try, last_leader_sample, CONF, model_key) for idx in range(n_tom_hunters)])
			else:
				act_try = 0
				actions = (leader_agent.action(leader_obs, last_leader_sample, CONF, None, model_key),
				           *[tom_agents[idx].action(tom_obs[idx], last_leader_sample, CONF, None, model_key) for idx in range(n_tom_hunters)])

			actions = coordinate_agents(env, [tom.predict_task for tom in tom_agents], actions, n_tom_hunters)
			# Preys act after the hunters; captured preys are given STAY
			for prey_id in env.prey_ids:
				actions += (prey_agents[prey_id].act(env) if prey_id in env.prey_alive_ids else Action.STAY.value, )

			recent_states.append(current_state)
			if len(recent_states) > 3:
				recent_states.pop(0)

		env.close()
		print('Run Over!!')
		it_results['n_steps'] = n_steps
		it_results['pred_steps'] = n_pred_steps
		it_results['avg_pred_steps'] = np.mean(n_pred_steps) if len(n_pred_steps) > 0 else 0
		it_results['preys_captured'] = n_preys - n_preys_alive
		it_results['steps_capture'] = steps_capture
		it_results['deadlocks'] = n_deadlocks
		results[run_nr] = it_results

	# The per-run metrics were previously computed and then discarded
	return results
