In [1]:
# Simplest option (recommended): run everything as a standalone script.
!python run_log.py

pygame 2.1.2 (SDL 2.0.16, Python 3.9.7)
Hello from the pygame community. https://www.pygame.org/contribute.html
from agents.random.submission import my_controller as m0
from agents.random.submission import my_controller as m1
step0
[[array([103.39993], dtype=float32), array([20.14982], dtype=float32)], [array([191.98918], dtype=float32), array([-1.028819], dtype=float32)]]
[[array([-23.471172], dtype=float32), array([-14.427856], dtype=float32)], [array([113.37428], dtype=float32), array([21.149973], dtype=float32)]]
[[array([53.208004], dtype=float32), array([24.749514], dtype=float32)], [array([123.110344], dtype=float32), array([-2.6488307], dtype=float32)]]
[[array([52.299805], dtype=float32), array([9.552677], dtype=float32)], [array([42.91374], dtype=float32), array([11.425793], dtype=float32)]]
[[array([121.579735], dtype=float32), array([4.398463], dtype=float32)], [array([-40.02675], dtype=float32), array([-17.334135], dtype=float32)]]
[[array([117.45216], dtype=float32), arra

In [1]:
# Alternatively, run the cells below step by step.
import importlib
import json
import os
import sys
import time

import numpy as np
import pygame

sys.path.append("./olympics_engine")

from env.chooseenv import make
from utils.get_logger import get_logger
from env.obs_interfaces.observation import obs_type


class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python types.

    Game logs may contain NumPy values (scalars or arrays), which the stock
    ``json`` encoder rejects with ``TypeError``.
    """

    def default(self, obj):
        # NumPy booleans are neither np.integer nor np.floating, so they need
        # their own branch (bool is checked first for clarity).
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Anything else: defer to the base class, which raises TypeError.
        return super().default(obj)


def get_players_and_action_space_list(g):
    """Assign contiguous player-id ranges to each policy and collect their action spaces.

    Player ids are handed out in order: the first ``g.agent_nums[0]`` ids go to
    policy 0, the next ``g.agent_nums[1]`` ids to policy 1, and so on.

    Args:
        g: game object exposing ``agent_nums``, ``n_player``, ``obs_type`` and
            ``get_single_action_space(player_id)``.

    Returns:
        tuple: ``(players_id, actions_space)`` where ``players_id[k]`` is a
        ``range`` of player ids for policy ``k`` and ``actions_space[k]`` is the
        list of those players' action spaces, in the same order.

    Raises:
        Exception: if ``sum(g.agent_nums)`` differs from ``g.n_player``.
    """
    if sum(g.agent_nums) != g.n_player:
        # Message: "agent number = %d is incorrect, does not match n_player = %d"
        raise Exception("agent number = %d 不正确，与n_player = %d 不匹配" % (sum(g.agent_nums), g.n_player))

    players_id, actions_space = [], []
    next_id = 0  # first player id not yet assigned to any policy
    for policy_idx in range(len(g.obs_type)):
        id_range = range(next_id, next_id + g.agent_nums[policy_idx])
        next_id = id_range.stop
        players_id.append(id_range)
        actions_space.append([g.get_single_action_space(pid) for pid in id_range])

    return players_id, actions_space


def get_joint_action_eval(game, multi_part_agent_ids, policy_list, actions_spaces, all_observes, verbose=True):
    """Query every loaded controller and assemble the joint action for one step.

    Controllers are looked up by name (``m0``, ``m1``, ...) in this module's
    global namespace, where ``run_game`` binds them before the episode starts.

    Args:
        game: environment object; reads ``agent_nums``, ``obs_type`` and
            ``is_act_continuous`` and calls ``is_single_valid_action``.
        multi_part_agent_ids: per-policy sequences of player ids.
        policy_list: agent folder names, one per policy.
        actions_spaces: per-policy lists of action spaces, aligned with the ids.
        all_observes: observations indexed by player id.
        verbose: if True (the default, matching the previous behaviour), print
            the joint action; pass False to avoid flooding the output every step.

    Returns:
        list: one action per player, in policy order.

    Raises:
        Exception: if the number of policies does not match ``game.agent_nums``
            or a policy's observation type is not supported.
        NameError: if the controller ``m<i>`` for a policy has not been loaded.
    """
    if len(policy_list) != len(game.agent_nums):
        # Message: "number of models %d does not match number of players %d!"
        error = "模型个数%d与玩家个数%d维度不正确！" % (len(policy_list), len(game.agent_nums))
        raise Exception(error)

    # Example joint action (presumably two players with one-hot discrete actions):
    # [[[0, 0, 0, 1]], [[0, 1, 0, 0]]]
    joint_action = []
    for policy_i in range(len(policy_list)):
        if game.obs_type[policy_i] not in obs_type:
            # Message: "available obs types: %s"
            raise Exception("可选obs类型：%s" % str(obs_type))

        agents_id_list = multi_part_agent_ids[policy_i]
        action_space_list = actions_spaces[policy_i]

        # Look the controller up directly in the module globals instead of
        # eval()-ing a constructed name.
        function_name = 'm%d' % policy_i
        controller = globals().get(function_name)
        if controller is None:
            raise NameError("name %r is not defined (controller not loaded)" % function_name)

        for i, agent_id in enumerate(agents_id_list):
            a_obs = all_observes[agent_id]
            each = controller(a_obs, action_space_list[i], game.is_act_continuous)
            game.is_single_valid_action(each, action_space_list[i], policy_i)
            joint_action.append(each)

    if verbose:
        print(joint_action)
    return joint_action


def set_seed(g, env_name):
    """Reset and re-seed the game, but only for ``magent`` environments.

    The environment family is the part of ``env_name`` before the first ``-``;
    for any family other than ``magent`` this is a no-op.

    Args:
        g: game object providing ``reset``, ``create_seed`` and ``set_seed``.
        env_name: full environment name such as ``"magent-battle"``.
    """
    family = env_name.split("-")[0]
    if family not in ['magent']:
        return
    g.reset()
    new_seed = g.create_seed()
    g.set_seed(new_seed)


def _timestamp():
    """Return the current local time in the format used throughout the game log."""
    return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))


def _load_controllers(policy_list):
    """Import ``my_controller`` from each agent's submission and bind it as global ``m<i>``."""
    # Hoisted: the set of agent folders does not change while loading.
    valid_agents = get_valid_agents()
    import_name = "my_controller"
    for i, policy_name in enumerate(policy_list):
        if policy_name not in valid_agents:
            raise Exception("agent {} not valid!".format(policy_name))

        file_path = os.path.join(os.getcwd(), "agents", policy_name, "submission.py")
        if not os.path.exists(file_path):
            raise Exception("file {} not exist!".format(file_path))

        import_path = '.'.join(['agents', policy_name, 'submission'])
        function_name = 'm%d' % i
        print("from %s import %s as %s" % (import_path, import_name, function_name))

        # importlib avoids exec() on a statement assembled from folder names.
        module = importlib.import_module(import_path)
        try:
            controller = getattr(module, import_name)
        except AttributeError as exc:
            raise ImportError("cannot import name %r from %r" % (import_name, import_path)) from exc
        # get_joint_action_eval looks controllers up by these global names.
        globals()[function_name] = controller


def _play_episode(g, env_name, multi_part_agent_ids, actions_spaces, policy_list, render_mode):
    """Step the game until ``g.is_terminal()`` and return the list of per-step info dicts."""
    is_magent = env_name.split("-")[0] in ["magent"]
    steps = []
    all_observes = g.all_observes
    while not g.is_terminal():
        if g.step_cnt % 10 == 0:
            print("step%d" % g.step_cnt)

        if render_mode and hasattr(g, "env_core") and hasattr(g.env_core, "render"):
            g.env_core.render()

        info_dict = {"time": _timestamp()}
        joint_act = get_joint_action_eval(g, multi_part_agent_ids, policy_list, actions_spaces, all_observes)
        all_observes, reward, _done, info_before, info_after = g.step(joint_act)
        if is_magent:
            info_dict["joint_action"] = g.decode(joint_act)
        if info_before:
            info_dict["info_before"] = info_before
        info_dict["reward"] = reward
        if info_after:
            info_dict["info_after"] = info_after
        steps.append(info_dict)
    return steps


def run_game(g, env_name, multi_part_agent_ids, actions_spaces, policy_list, render_mode):
    """Play one episode with the agents in ``policy_list`` and log it as JSON.

    The JSON log is intended for Vue rendering and is written through
    ``get_logger`` into ``./logs/``.

    Args:
        g: game object created by ``make``.
        env_name: environment name, e.g. ``"olympics-wrestling"``; a ``magent``
            prefix enables re-seeding and joint-action decoding in the log.
        multi_part_agent_ids: per-policy player-id ranges.
        actions_spaces: per-policy lists of action spaces.
        policy_list: agent folder names under ``./agents``, one per policy.
        render_mode: if truthy, render each step when the env supports it; also
            passed to ``get_logger`` as ``json_file``.

    Raises:
        Exception: if an agent folder is not valid or lacks ``submission.py``.
        ImportError: if a submission module or its ``my_controller`` cannot be imported.
    """
    # Kept as a string with a trailing slash, since get_logger receives it verbatim.
    log_path = os.getcwd() + '/logs/'
    os.makedirs(log_path, exist_ok=True)

    logger = get_logger(log_path, g.game_name, json_file=render_mode)
    set_seed(g, env_name)
    _load_controllers(policy_list)

    game_info = {"game_name": env_name,
                 "n_player": g.n_player,
                 "board_height": g.board_height if hasattr(g, "board_height") else None,
                 "board_width": g.board_width if hasattr(g, "board_width") else None,
                 "init_info": g.init_info,
                 "start_time": _timestamp(),
                 "mode": "terminal",
                 "seed": g.seed if hasattr(g, "seed") else None,
                 "map_size": g.map_size if hasattr(g, "map_size") else None}

    game_info["steps"] = _play_episode(g, env_name, multi_part_agent_ids, actions_spaces,
                                       policy_list, render_mode)
    game_info["winner"] = g.check_win()
    game_info["winner_information"] = g.won
    game_info["n_return"] = g.n_return
    game_info["end_time"] = _timestamp()
    logger.info(json.dumps(game_info, ensure_ascii=False, cls=NpEncoder))


def get_valid_agents():
    """Return the names of agent folders under ``./agents`` (relative to the cwd).

    Only sub-directories count, since each agent lives in its own folder
    (``agents/<name>/submission.py``). ``__pycache__`` and stray files such as
    ``__init__.py`` are skipped. Names are sorted so the result is deterministic.

    Raises:
        FileNotFoundError: if ``./agents`` does not exist.
    """
    dir_path = os.path.join(os.getcwd(), 'agents')
    return sorted(
        entry for entry in os.listdir(dir_path)
        if entry != "__pycache__" and os.path.isdir(os.path.join(dir_path, entry))
    )

pygame 2.1.2 (SDL 2.0.16, Python 3.9.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
env_type = "olympics-wrestling"
game = make(env_type, seed=None)

render_mode = True

# Agent folder names under ./agents, one entry per element of game.agent_nums.
# To use the same agent everywhere: policy_list = ["random"] * len(game.agent_nums)
policy_list = ["random", "random"]

multi_part_agent_ids, actions_space = get_players_and_action_space_list(game)

try:
    run_game(game, env_type, multi_part_agent_ids, actions_space, policy_list, render_mode)
finally:
    # Shut pygame (and its render window) down even if the episode raises.
    pygame.quit()

from agents.random.submission import my_controller as m0
from agents.random.submission import my_controller as m1
step0
[[array([129.1693], dtype=float32), array([29.711187], dtype=float32)], [array([18.770868], dtype=float32), array([5.283855], dtype=float32)]]
[[array([150.79143], dtype=float32), array([20.12188], dtype=float32)], [array([119.04926], dtype=float32), array([-4.659135], dtype=float32)]]
[[array([-61.504528], dtype=float32), array([-17.620546], dtype=float32)], [array([-90.61181], dtype=float32), array([29.53458], dtype=float32)]]
[[array([-10.639216], dtype=float32), array([4.2851415], dtype=float32)], [array([133.55753], dtype=float32), array([3.9955542], dtype=float32)]]
[[array([133.51227], dtype=float32), array([-26.343813], dtype=float32)], [array([35.412], dtype=float32), array([12.715875], dtype=float32)]]
[[array([1.3118306], dtype=float32), array([-26.600008], dtype=float32)], [array([-19.744585], dtype=float32), array([-12.073001], dtype=float32)]]
[[array([1