In [None]:
import json
import os
from typing import Any, Dict, List, Tuple

from IPython.display import display, HTML

from og_marl.vault_utils.download_vault import *
from og_marl.vault_utils.analyse_vault import *

In [None]:
# Query which vaults can be downloaded. The result is used further down as a
# nested mapping: available_vaults[source][env][task]["url"].
# NOTE(review): judging by its name, the helper presumably also prints the
# options for the reader — confirm in og_marl.vault_utils.download_vault.
available_vaults = print_download_options()

In [None]:
def make_datacard_info(
    vault_name: str,
    rel_dir: str = "vaults",
    done_flags: tuple = ("terminals",),
) -> Tuple[Dict[str, Dict[str, Any]], List[str]]:
    """Compute data-card statistics for every dataset (uid) stored in one vault.

    Args:
        vault_name: File name of the vault, including its ``.vlt`` extension.
        rel_dir: Directory, relative to the working directory, that contains the vault.
        done_flags: Done-flag keys forwarded unchanged to
            ``get_episode_return_descriptors`` and ``get_structure_descriptors``.

    Returns:
        A pair ``(out_info, vault_uids)``. ``out_info`` maps each uid to a dict with
        the keys "Mean episode return", "Standard deviation episode return",
        "Min return", "Max return", "Transitions", "Trajectories", "Joint SACo"
        and "Structure". ``vault_uids`` is what ``get_available_uids`` returned
        for the vault.
    """
    out_info: Dict[str, Dict[str, Any]] = {}

    vault_uids = get_available_uids(f"./{rel_dir}/{vault_name}")

    for uid in vault_uids:
        vlt = Vault(vault_name=vault_name, rel_dir=rel_dir, vault_uid=uid)
        exp = vlt.read().experience

        saco, _, _ = get_saco(exp)
        mean, stddev, max_ret, min_ret, _ = get_episode_return_descriptors(exp, done_flags)
        # NOTE(review): assumes experience arrays are laid out (batch, time, ...)
        # so that shape[1] counts stored timesteps — confirm against the Vault layout.
        n_trans = exp["actions"].shape[1]
        struct, _, n_traj = get_structure_descriptors(exp, 0, done_flags)

        # Returns are cast to float so they can be JSON-encoded later.
        # NOTE(review): saco, n_traj and struct are stored as returned; if any is a
        # numpy/jax scalar, the later json.dumps will reject it — confirm their types.
        out_info[uid] = {
            "Mean episode return": float(mean),
            "Standard deviation episode return": float(stddev),
            "Min return": float(min_ret),
            "Max return": float(max_ret),
            "Transitions": n_trans,
            "Trajectories": n_traj,
            "Joint SACo": saco,
            "Structure": struct,
        }

    return out_info, vault_uids


In [None]:

# Load the existing data cards; sources that are not regenerated further down
# keep the entries read here.
# The write cell stores the cards as a JSON *string* that itself contains JSON
# (double encoding), so json.load normally yields a str needing json.loads.
with open("../docs/dataset_cards/datacard_info.json", encoding="utf-8") as current_info:
    datacard_str = json.load(current_info)

# Decode the inner layer when present, but also accept a plain JSON object so
# this cell keeps working if the file is ever written single-encoded.
datacard_dict = json.loads(datacard_str) if isinstance(datacard_str, str) else datacard_str
assert isinstance(datacard_dict, dict), "datacard_info.json must decode to a JSON object"

In [None]:

# Display metadata for each environment key: the pretty name printed on the
# data card and the type of its action space.
env_add_info = {
    "mamujoco": {"pretty_name": "MAMuJoCo", "action_space": "Continuous"},
    "smac_v1": {"pretty_name": "SMAC (v1)", "action_space": "Discrete"},
    "smac_v2": {"pretty_name": "SMAC (v2)", "action_space": "Discrete"},
    "mpe": {"pretty_name": "MPE", "action_space": "Continuous"},
    "rware": {"pretty_name": "RWARE", "action_space": "Discrete"},
}

# Algorithm used to collect the og_marl datasets in each environment.
ogmarl_collecting_policy = {
    "smac_v1": "QMIX",
    "smac_v2": "QMIX",
    "mamujoco": "MATD3",
}

# How exploration noise was added during collection, keyed by action-space type.
add_diversity_explanation = {
    "Continuous": "Gaussian noise with standard deviation of 0.2 was added to the action selection.",
    "Discrete": "An epsilon greedy policy with eps=0.05 was used.",
}

# Files in Hugging Face dataset repos are served from
# huggingface.co/datasets/<repo>/resolve/<revision>/<path>.
# NOTE(review): these URLs previously used raw.huggingface.co, which does not
# appear to be a Hugging Face host.
HF_DATASET_URL = "https://huggingface.co/datasets/InstaDeepAI/og-marl/resolve/main"

# Environments listed by available_vaults that get an empty entry but no cards.
SKIPPED_ENVS = {"gymnasium_mamujoco"}

for source, envs in available_vaults.items():
    # Every source listed in available_vaults is regenerated from scratch;
    # sources not listed keep the entries loaded from the existing JSON file.
    datacard_dict[source] = {}
    for env, tasks in envs.items():
        datacard_dict[source][env] = {}
        if env in SKIPPED_ENVS:
            continue

        # Resolve per-env metadata before downloading anything, so a missing
        # table entry fails fast instead of after a potentially long download.
        if env not in env_add_info:
            raise KeyError(f"No display metadata for env {env!r}; add it to env_add_info.")
        env_meta = env_add_info[env]
        diversity = add_diversity_explanation[env_meta["action_space"]]
        if source == "og_marl":
            if env not in ogmarl_collecting_policy:
                raise KeyError(
                    f"No collecting policy for og_marl env {env!r}; add it to ogmarl_collecting_policy."
                )
            generation_prefix = (
                f"A {ogmarl_collecting_policy[env]} system was trained to target level of performance. "
                "The learnt policy was then rolled out to collect approximately "
            )

        for task, task_info in tasks.items():
            download_and_unzip_vault(source, env, task)

            # Map the source/env/task naming convention onto the Vault's
            # rel_dir/vault_name arguments; a vault name is only the .vlt file name.
            vault_rel_dir = f"vaults/{source}/{env}"
            vault_name = f"{task}.vlt"
            out_info, vault_uids = make_datacard_info(vault_name=vault_name, rel_dir=vault_rel_dir)
            datacard_dict[source][env][task] = out_info

            for uid in vault_uids:
                card = out_info[uid]
                card["Download link"] = task_info["url"]
                card["Scenario name"] = task
                card["Dataset name"] = uid
                card["Environment name"] = env_meta["pretty_name"]
                card["Action space"] = env_meta["action_space"]
                if source != "og_marl":
                    card["Motivation"] = "Existing dataset from the literature converted to Vault format for accessibility."
                    card["Generation procedure"] = f"Converted from {source} format to a Vault."
                    card["Histogram download url"] = f"{HF_DATASET_URL}/prior_work/{source}/{env}/{task}_histogram.pdf"
                else:
                    card["Histogram download url"] = f"{HF_DATASET_URL}/core/{env}/{task}_histogram.pdf"
                    if uid == "Replay":
                        card["Generation procedure"] = f"{generation_prefix}1m transitions. {diversity}"
                    else:
                        card["Generation procedure"] = (
                            f"{generation_prefix}250k transitions. {diversity} "
                            "This procedure was repeated 4 times and the data was combined."
                        )





In [None]:

# Serialise BEFORE opening the file: open(..., "w") truncates immediately, so a
# json.dumps failure (e.g. a value json cannot encode) would otherwise leave an
# empty datacard_info.json behind and destroy the previous cards.
# The indented JSON text is stored as a single JSON string (double encoding),
# matching the load cell, which decodes with json.load followed by json.loads.
datacard_str = json.dumps(datacard_dict, indent=4)

with open("../docs/dataset_cards/datacard_info.json", "w", encoding="utf-8") as current_info:
    json.dump(datacard_str, current_info)