In [25]:
import json
import numpy as np

import sys

sys.path.insert(0, '../')
from src.common import *
from src.questions_construction.domains import DOMAIN_NAMES, ALL_DOMAIN_CLASSES_BY_NAME

In [33]:
# Keys of the count dict returned by num_fluents(); the aggregation cell iterates
# over them to build one statistics entry per key.
FLUENT_TYPES_KEYS = ['base', 'derived', 'persistent', 'static', 'total']

def is_fluent_part_of_type(fluent, fluent_prefixes):
    """Tell whether ``fluent`` begins with at least one of the given prefixes.

    Args:
        fluent: Fluent string to classify.
        fluent_prefixes: Iterable of prefix strings to test ``fluent`` against.

    Returns:
        True as soon as one prefix matches the start of ``fluent``; False when
        none does (which includes an empty ``fluent_prefixes``).
    """
    return any(fluent.startswith(candidate) for candidate in fluent_prefixes)

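# Superseded draft, kept for reference only: it summed the lengths of the positive
# and negative fluent lists of the first record, with no breakdown by fluent type.
# def num_fluents(data):
#     return len(data[0][INIT_ACTION_KEY][FLUENTS_KEY]) + len(data[0][INIT_ACTION_KEY][NEG_FLUENTS_KEY])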


def num_fluents(data, domain_name):
    """Count the fluents of an instance's first record, broken down by fluent type.

    Only ``data[0][INIT_ACTION_KEY][FLUENTS_KEY]`` is read. A fluent belongs to a
    type when it starts with one of the prefixes listed in the matching
    ``*_POS_FLUENTS`` attribute of the domain class registered under
    ``domain_name`` in ``ALL_DOMAIN_CLASSES_BY_NAME``.

    Every count is doubled.
    NOTE(review): the factor of 2 presumably accounts for a negated counterpart
    of each listed fluent -- confirm against the data format.

    Args:
        data: Sequence of records for one instance; only the first is used.
        domain_name: Key into ``ALL_DOMAIN_CLASSES_BY_NAME``.

    Returns:
        Dict with the keys 'base', 'derived', 'persistent', 'static' and
        'total' (in that order), each mapped to an int.
    """
    init_fluents = data[0][INIT_ACTION_KEY][FLUENTS_KEY]

    def doubled_matches(prefixes_attr):
        # The domain class is resolved per fluent, so an empty fluent list never
        # touches ALL_DOMAIN_CLASSES_BY_NAME at all.
        matched = 0
        for fluent in init_fluents:
            prefixes = getattr(ALL_DOMAIN_CLASSES_BY_NAME[domain_name], prefixes_attr)
            if is_fluent_part_of_type(fluent, prefixes):
                matched += 1
        return 2 * matched

    counts = {}
    counts['base'] = doubled_matches('BASE_POS_FLUENTS')
    counts['derived'] = doubled_matches('DERIVED_POS_FLUENTS')
    counts['persistent'] = doubled_matches('PERSISTENT_POS_FLUENTS')
    counts['static'] = doubled_matches('STATIC_POS_FLUENTS')
    counts['total'] = 2 * len(init_fluents)
    return counts

In [34]:
# Aggregate the per-instance fluent counts of every domain into summary statistics
# and persist them as JSON.
NUM_INSTANCES = 10  # instance files are named Instance_1 .. Instance_<NUM_INSTANCES>

stats_by_domain = {}
for domain_name in sorted(DOMAIN_NAMES):
    # Count dict from num_fluents() for each instance of this domain.
    by_instance = {}
    for i in range(1, NUM_INSTANCES + 1):
        instance_name = f'Instance_{i}'
        data = open_jsonl(f'{DATA_PATH}/states_actions/{domain_name}/{instance_name}.jsonl')
        by_instance[instance_name] = num_fluents(data, domain_name)

    stats_by_fluent_type = {}
    for fluent_type in FLUENT_TYPES_KEYS:
        # One count per instance, in instance order.
        fluents_ls = [counts[fluent_type] for counts in by_instance.values()]
        # Results are converted to builtin int/float because NumPy integer scalars
        # are not JSON-serializable.
        stats_by_fluent_type[fluent_type] = {
            'mean': float(np.mean(fluents_ls)),
            # np.std defaults to ddof=0, i.e. the population standard deviation.
            'std': float(np.std(fluents_ls)),
            'min': int(np.min(fluents_ls)),
            'max': int(np.max(fluents_ls)),
            'count': fluents_ls,
        }
    stats_by_domain[domain_name] = stats_by_fluent_type

# The path is relative, so the file lands in the notebook's working directory.
with open('stats_by_domain.json', 'w') as f:
    json.dump(stats_by_domain, f)

In [35]:
# NOTE(review): the disabled sort below is stale -- each value of stats_by_domain is
# keyed by fluent type first ('base', 'derived', ...), so x[1]['mean'] would raise
# KeyError; a key such as x[1]['total']['mean'] would be needed instead.
# sorted(stats_by_domain.items(), key=lambda x: x[1]['mean'])
stats_by_domain

{'blocksworld': {'base': {'mean': 5.4,
   'std': 0.9165151389911681,
   'min': 4,
   'max': 6,
   'count': [6, 6, 4, 6, 6, 4, 6, 4, 6, 6]},
  'derived': {'mean': 7.4,
   'std': 0.9165151389911681,
   'min': 6,
   'max': 8,
   'count': [8, 8, 6, 8, 8, 6, 8, 6, 8, 8]},
  'persistent': {'mean': 10.8,
   'std': 1.32664991614216,
   'min': 8,
   'max': 12,
   'count': [8, 12, 10, 10, 12, 12, 12, 10, 12, 10]},
  'static': {'mean': 0.0,
   'std': 0.0,
   'min': 0,
   'max': 0,
   'count': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]},
  'total': {'mean': 23.6,
   'std': 2.33238075793812,
   'min': 20,
   'max': 26,
   'count': [22, 26, 20, 24, 26, 22, 26, 20, 26, 24]}},
 'depots': {'base': {'mean': 0.0,
   'std': 0.0,
   'min': 0,
   'max': 0,
   'count': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]},
  'derived': {'mean': 24.4,
   'std': 3.32264954516723,
   'min': 20,
   'max': 28,
   'count': [24, 28, 28, 24, 24, 28, 28, 20, 20, 20]},
  'persistent': {'mean': 47.4,
   'std': 2.2,
   'min': 44,
   'max': 50,
   'count