In [1]:
import sys
sys.path.insert(0, '../../')
from src.questions_construction.questions import FLUENT_TYPES_LIST
from src.questions_construction.main import PLAN_LENGTHS, QUESTION_CATEGORIES
from src.questions_construction.domains import DOMAIN_NAMES
from src.analysis.model_performances import gather_questions
from src.common import *
import random
from collections import defaultdict
from copy import deepcopy
import itertools

In [2]:
def get_data(domain, dataset_dir, few_shot=1):
    data_all = []
    for i in range(1,11):
        path = os.path.join(dataset_dir, 'without_random_sub', 'without_ramifications', f'few_shot_{few_shot}', domain, f'Instance_{i}.jsonl')
        data_all.extend(open_jsonl(path))
    return data_all

In [3]:
questions_dir = f'{DATA_PATH}/questions_m1'
questions_by_id = gather_questions(questions_dir)

In [4]:
dataset_dir = f'{DATA_PATH}/data_for_evaluation'
data_by_domain = {domain: get_data(domain, dataset_dir) for domain in DOMAIN_NAMES}

In [5]:
for k,v in data_by_domain.items():
    print(k, len(v))

blocksworld 7279
depots 5999
driverlog 7127
goldminer 8742
grippers 7106
logistics 6100
miconic 8261
mystery 6100
npuzzle 7287
satellite 8517
spanner 8495
visitall 7166
zenotravel 7370


In [6]:
def is_good_qa(question_info, include_params, exclude_params):
    for k, v in exclude_params.items():
        if question_info[k] == v:
            return False
        
    for k, v in include_params.items():
        if question_info[k] != v:
            return False
    return True

def is_token_window_ok(data_dict, chars_per_token=4, max_tokens=4096):
    return (len(data_dict['prompt'] + data_dict['label'])/chars_per_token) < max_tokens

def sample_data(data_all, questions_by_id, selected_plan_length):
    selected_data = []
    exclude_params = {}
    include_params = {'plan_length': selected_plan_length}
    for d in data_all:
        question_id = d['id']
        if question_id not in questions_by_id:
            raise ValueError
        question_info = questions_by_id[question_id]
        if is_good_qa(question_info, include_params, exclude_params) and is_token_window_ok(d):
            selected_data.append(d)
    random.shuffle(selected_data)
    return selected_data

In [7]:
# domain = 'goldminer'
# selected_plan_length = 1
# sample = sample_data(data_by_domain[domain], questions_by_id, selected_plan_length)
# len(sample)

# Sample

In [8]:
def output_keys(questions_dict):
    key1 = questions_dict[OUT_OBJ_QUESTION_CATEGORY]
    if q[OUT_OBJ_ANSWER_TYPE] == TRUE_FALSE_ANSWER_TYPE:
        if q[OUT_OBJ_ANSWER] == 'True':
            key2 = 'true'
        else:
            key2 = 'false'
    else:
        key2 = 'free'
        
    key3 = q[OUT_OBJ_FLUENT_TYPE]
        
    if q[OUT_OBJ_IS_POS_FLUENT_QUESTION] is True:
        key4 = 'POS'
    elif q[OUT_OBJ_IS_POS_FLUENT_QUESTION] is False:
        key4 = 'NEG'
    else:
        key4 = 'None'
    return (key1, key2, key3, key4)

def is_restricted(key1, key2, key3, key4):
        # restrictions on keys
    if key1 == 'fluent_tracking' and (key3 == 'None' or key4 == 'None'):
        return True
        
    if key1 == 'object_tracking' and (key3 == 'None' or key4 == 'None'):
        return True
        
    if key1 == 'hallucination' and key4 == 'None':
        return True
        
    if key1 == 'effects' and key3 == 'None' and key4 == 'None':
        return True
    
    return False

In [9]:
selected_plan_length = 1
# for selected_plan_length in [1,5,19]:

# PREPARE DATA
data_domain_all = {domain: sample_data(data_by_domain[domain], questions_by_id, selected_plan_length) for domain in DOMAIN_NAMES}
min_samples = min([len(v) for v in data_domain_all.values() if len(v) > 100])
data_all = []
for k, v in data_domain_all.items():
    data_all.extend(v[:min_samples])
random.shuffle(data_all)

print(len(data_all))

14417


In [10]:
{k: len(v) for k, v in data_domain_all.items()}

{'blocksworld': 1460,
 'depots': 1109,
 'driverlog': 1254,
 'goldminer': 1569,
 'grippers': 1356,
 'logistics': 1205,
 'miconic': 1426,
 'mystery': 1207,
 'npuzzle': 1408,
 'satellite': 1603,
 'spanner': 1624,
 'visitall': 1218,
 'zenotravel': 1456}

In [60]:
MAX_PL_DOM_CAT_TYPE = 13
MAX_PER_CATEGORY = 100

MULTUPLICITY_BY_CATEGORY = {'action_executability': 3+0.3,
             'effects': 11+0.5,
             'fluent_tracking': 24,
             'hallucination': 4+0.5,
             'numerical_reasoning': 3+0.5,
             'object_tracking': 16,
             'state_tracking': 5+0.5}

selected_ids = []
data_by_category = defaultdict(list)
for d in data_all:
    question_id = d['id']
    q = questions_by_id[question_id]
    
    key1, key2, key3, key4 = output_keys(q)
    key_all = (key1, key2, key3, key4)
    if is_restricted(key1, key2, key3, key4):
        continue
                
    # max_samples = MAX_PL_DOM_CAT_TYPE
    # if key1 in ('state_tracking'):
    #     if key2 in ('true', 'false'):
    #         max_samples = MAX_PL_DOM_CAT_TYPE*3
    #     else:
    #         max_samples = MAX_PL_DOM_CAT_TYPE
    # if key1 in ('action_executability'):
    #     max_samples = int(MAX_PL_DOM_CAT_TYPE*3)
    #     
    # if key1 in ('fluent_tracking'):
    #     max_samples = int(MAX_PL_DOM_CAT_TYPE/2)
    # 
    # if key1 in ('object_tracking'):
    #     max_samples = int(MAX_PL_DOM_CAT_TYPE/1.5)
    # 
    # if key1 in ('numerical_reasoning'):
    #     max_samples = int(MAX_PL_DOM_CAT_TYPE*1.5)
    
    if len(data_by_category[key_all]) < int(MAX_PER_CATEGORY/MULTUPLICITY_BY_CATEGORY[key1]):
        data_by_category[key_all].append(question_id)
selected_ids.extend(list(itertools.chain.from_iterable(data_by_category.values())))

In [61]:
stats = defaultdict(int)
for id in selected_ids:
    q = questions_by_id[id]
    stats[output_keys(q)]+=1

count_by_cat = defaultdict(int)
for k, v in sorted(stats.items(), key=lambda x: x[0][0]):
    count_by_cat[k[0]] += v
    print(k, v) 
    
count_by_cat

('action_executability', 'true', None, 'None') 30
('action_executability', 'false', None, 'None') 30
('action_executability', 'free', None, 'None') 30
('effects', 'true', 'derived_fluents', 'None') 8
('effects', 'false', 'persistent_fluents', 'None') 8
('effects', 'true', 'persistent_fluents', 'None') 8
('effects', 'free', None, 'POS') 8
('effects', 'true', 'static_fluents', 'None') 8
('effects', 'false', 'derived_fluents', 'None') 8
('effects', 'false', 'base_fluents', 'None') 8
('effects', 'free', None, 'NEG') 8
('effects', 'true', 'base_fluents', 'None') 8
('effects', 'false', 'static_fluents', 'None') 8
('effects', 'free', None, 'None') 8
('fluent_tracking', 'false', 'persistent_fluents', 'POS') 4
('fluent_tracking', 'free', 'persistent_fluents', 'NEG') 4
('fluent_tracking', 'false', 'derived_fluents', 'POS') 4
('fluent_tracking', 'free', 'persistent_fluents', 'POS') 4
('fluent_tracking', 'false', 'base_fluents', 'NEG') 4
('fluent_tracking', 'true', 'base_fluents', 'POS') 4
('fluen

defaultdict(int,
            {'action_executability': 90,
             'effects': 88,
             'fluent_tracking': 96,
             'hallucination': 88,
             'numerical_reasoning': 84,
             'object_tracking': 92,
             'state_tracking': 90})

In [62]:
# count_by_cat = defaultdict(int)
# for k, v in sorted(stats.items(), key=lambda x: x[0][0]):
#     count_by_cat[k[0]] += 1
# count_by_cat

In [63]:
print(13*48-sum(count_by_cat.values()))

-4


In [64]:
13*48

624

In [65]:
# total = 0
# count_by_cat = defaultdict(int)
# for k, v in sorted(stats.items(), key=lambda x: x[0][0]):
#     count_by_cat[k[0]] += 1
# count_by_cat

In [68]:
save_jsonl(selected_ids, f'small_dataset_ids_pl.{selected_plan_length}.jsonl')

In [ ]:
# # things to consider:
# - balanced T/F sampling
# - balanced Q categories
# - take q + a with lenth + response <= 4000 tokens
# balanced fluent types
# balanced -+ fluents

# state tracking, no true false questions
# reduce object tracking qs
# boost action executability

# Check if Data is Present

In [74]:
# ids_by_prefix = {}
# for sub in [WITHOUT_RANDOM_SUB, WITH_RANDOM_SUB]:
#     for ram in RAMIFICATION_TYPES:
#         for few_shot in ['few_shot_1', 'few_shot_5']:
#             ids = set()
#             for domain in DOMAIN_NAMES:
#                 for i in range(1,11):
#                     path = os.path.join(DATA_PATH, 'data_for_evaluation', sub, ram, few_shot, domain, f'Instance_{i}.jsonl')
#                     if not os.path.exists(path):
#                         print(path)
#                     else:
#                         for d in open_jsonl(path):
#                             ids.add(d['id'])
#             ids_by_prefix[(sub,ram,few_shot)] = ids    