In [214]:
import sys
sys.path.insert(0, '../../')
from src.questions_construction.questions import FLUENT_TYPES_LIST
from src.questions_construction.main import PLAN_LENGTHS, QUESTION_CATEGORIES
from src.questions_construction.domains import DOMAIN_NAMES
from src.analysis.model_performances import * #gather_questions, TRANSPORTATION_DOMAINS, NON_TRANSPORTATION_DOMAINS
from src.common import *
import random
from collections import defaultdict
from copy import deepcopy
import itertools

In [73]:
def get_data(domain, dataset_dir, few_shot=1):
    """Load and concatenate the evaluation records of ``domain`` for instances 1-10.

    Files are read from
    ``<dataset_dir>/without_random_sub/without_ramifications/few_shot_<few_shot>/<domain>/Instance_<i>.jsonl``
    for ``i`` in 1..10. Their records are appended in instance order.
    """
    domain_dir = os.path.join(
        dataset_dir, 'without_random_sub', 'without_ramifications', f'few_shot_{few_shot}', domain
    )
    records = []
    for instance_idx in range(1, 11):
        records += open_jsonl(os.path.join(domain_dir, f'Instance_{instance_idx}.jsonl'))
    return records

In [74]:
# Question metadata for every generated question, keyed by question id.
questions_dir = '{}/questions_m1'.format(DATA_PATH)
questions_by_id = gather_questions(questions_dir)

In [75]:
# Evaluation records (few_shot_1, no random substitution, no ramifications) per domain.
dataset_dir = '{}/data_for_evaluation'.format(DATA_PATH)
data_by_domain = {}
for domain_name in DOMAIN_NAMES:
    data_by_domain[domain_name] = get_data(domain_name, dataset_dir)

In [76]:
# Number of loaded evaluation records per domain.
for domain_name, domain_items in data_by_domain.items():
    print(domain_name, len(domain_items))

blocksworld 7279
depots 5999
driverlog 7127
goldminer 8742
grippers 7106
logistics 6100
miconic 8261
mystery 6100
npuzzle 7287
satellite 8517
spanner 8495
visitall 7166
zenotravel 7370


In [6]:
def is_good_qa(question_info, include_params, exclude_params):
    """Return True if ``question_info`` matches all filters.

    A question is rejected if any ``exclude_params`` field equals its excluded value,
    or if any ``include_params`` field differs from its required value.
    Exclusions are checked first.
    """
    if any(question_info[key] == value for key, value in exclude_params.items()):
        return False
    return all(question_info[key] == value for key, value in include_params.items())

def is_token_window_ok(data_dict, chars_per_token=4, max_tokens=4096):
    """Heuristically check that prompt + label fit within ``max_tokens``.

    The token count is estimated as total characters / ``chars_per_token``.
    The estimate must be strictly below ``max_tokens``.
    """
    combined_length = len(data_dict['prompt'] + data_dict['label'])
    estimated_tokens = combined_length / chars_per_token
    return estimated_tokens < max_tokens

def sample_data(data_all, questions_by_id, selected_plan_length, rng=None):
    """Select the evaluation items for one plan length and return them shuffled.

    An item is kept when its question metadata has ``plan_length == selected_plan_length``
    (via ``is_good_qa``) and it passes the ``is_token_window_ok`` length heuristic.

    Args:
        data_all: evaluation items. Each needs an ``'id'`` key; ``is_token_window_ok``
            also reads ``'prompt'`` and ``'label'``.
        questions_by_id: mapping from question id to question metadata.
        selected_plan_length: plan length to keep.
        rng: optional ``random.Random`` used for the shuffle, for reproducible sampling.
            Defaults to the module-level ``random`` functions (the global RNG).

    Returns:
        A new list of the selected items in shuffled order.

    Raises:
        ValueError: if an item's id is not present in ``questions_by_id``.
    """
    shuffle = random.shuffle if rng is None else rng.shuffle
    exclude_params = {}
    include_params = {'plan_length': selected_plan_length}
    selected_data = []
    for d in data_all:
        question_id = d['id']
        if question_id not in questions_by_id:
            # Every evaluation item must have metadata; a miss means the two sources disagree.
            raise ValueError(
                f'question id {question_id!r} from the evaluation data is missing from questions_by_id'
            )
        question_info = questions_by_id[question_id]
        if is_good_qa(question_info, include_params, exclude_params) and is_token_window_ok(d):
            selected_data.append(d)
    shuffle(selected_data)
    return selected_data

In [77]:
# domain = 'goldminer'
# selected_plan_length = 1
# sample = sample_data(data_by_domain[domain], questions_by_id, selected_plan_length)
# len(sample)

In [632]:
def output_keys(questions_dict):
    """Compute the 5-part stratification key used to bucket a question.

    Args:
        questions_dict: one question-metadata record (a value of ``questions_by_id``),
            indexed by the ``OUT_OBJ_*`` field-name constants.

    Returns:
        tuple ``(key1, key2, key3, key4, key5)``:
            key1 -- the question category;
            key2 -- 'true' / 'false' for true/false questions (by whether the answer is
                    the string 'True'), and 'free' for every other answer type;
            key3 -- the fluent type exactly as stored (may be ``None``);
            key4 -- 'POS' / 'NEG' when the positive-fluent flag is exactly True / False,
                    otherwise the string 'None';
            key5 -- TRANSPORTATION_DOMAIN_KEY if the domain is in TRANSPORTATION_DOMAINS,
                    otherwise NON_TRANSPORTATION_DOMAIN_KEY.
    """
    # Read every field from the argument. The previous body read all fields except
    # the category from a global ``q``. That was only correct when the caller had
    # bound ``q`` to this same record.
    q = questions_dict
    key1 = q[OUT_OBJ_QUESTION_CATEGORY]
    if q[OUT_OBJ_ANSWER_TYPE] == TRUE_FALSE_ANSWER_TYPE:
        if q[OUT_OBJ_ANSWER] == 'True':
            key2 = 'true'
        else:
            key2 = 'false'
    else:
        key2 = 'free'

    key3 = q[OUT_OBJ_FLUENT_TYPE]

    # Identity checks: only the actual booleans map to POS/NEG; anything else is 'None'.
    if q[OUT_OBJ_IS_POS_FLUENT_QUESTION] is True:
        key4 = 'POS'
    elif q[OUT_OBJ_IS_POS_FLUENT_QUESTION] is False:
        key4 = 'NEG'
    else:
        key4 = 'None'

    if q[OUT_OBJ_DOMAIN_NAME] in TRANSPORTATION_DOMAINS:
        key5 = TRANSPORTATION_DOMAIN_KEY
    else:
        key5 = NON_TRANSPORTATION_DOMAIN_KEY
    return key1, key2, key3, key4, key5

def is_restricted(key1, key2, key3, key4):
    """Return True if a question with these output keys must be excluded from sampling.

    The arguments are the first four elements of ``output_keys``:
    - question category;
    - answer key ('true' / 'false' / 'free');
    - fluent type;
    - polarity ('POS' / 'NEG' / 'None').

    Excluded combinations:
      * fluent_tracking questions with no fluent type or no polarity;
      * non-free-answer object_tracking questions with no fluent type or no polarity;
      * hallucination questions with no polarity.
    """
    # output_keys passes the fluent type through unchanged, so a missing one arrives as
    # Python ``None``. Only the polarity uses the string 'None'. Accept both spellings
    # for each, so the checks actually fire.
    missing_fluent_type = key3 is None or key3 == 'None'
    missing_polarity = key4 is None or key4 == 'None'

    if key1 == 'fluent_tracking' and (missing_fluent_type or missing_polarity):
        return True

    if key1 == 'object_tracking' and key2 != 'free' and (missing_fluent_type or missing_polarity):
        return True

    if key1 == 'hallucination' and missing_polarity:
        return True

    return False

# Sample

In [659]:
selected_plan_length = 19
# for selected_plan_length in [1,5,19]:

# Seed the global RNG so the shuffles in `sample_data` and below, and hence the saved
# id file, are reproducible across kernel restarts.
RANDOM_SEED = 0
random.seed(RANDOM_SEED)

# Domains with at most this many candidate questions are ignored when choosing the
# per-domain cap. They are still included, with all of their questions.
MIN_DOMAIN_SAMPLES = 100

# PREPARE DATA
data_domain_all = {domain: sample_data(data_by_domain[domain], questions_by_id, selected_plan_length) for domain in DOMAIN_NAMES}
large_domain_sizes = [len(v) for v in data_domain_all.values() if len(v) > MIN_DOMAIN_SAMPLES]
if not large_domain_sizes:
    raise ValueError(
        f'No domain has more than {MIN_DOMAIN_SAMPLES} questions for plan length {selected_plan_length}'
    )
# Cap every domain at the size of the smallest "large" domain to balance domains.
min_samples = min(large_domain_sizes)
data_all = []
for k, v in data_domain_all.items():
    data_all.extend(v[:min_samples])
random.shuffle(data_all)

# print(len(data_all))

In [660]:
MAX_PER_CATEGORY = 20  # earlier runs tried 12, 100 and 110

# Bucket question ids by their output key and skip restricted combinations.
# Keep at most MAX_PER_CATEGORY ids per bucket, taking them in `data_all` order.
selected_ids = []
data_by_category = defaultdict(list)
for d in data_all:
    question_id = d['id']
    # Keep the module-level name `q`: other notebook code may read this global.
    q = questions_by_id[question_id]
    key_all = output_keys(q)
    key1, key2, key3, key4, key5 = key_all
    if is_restricted(key1, key2, key3, key4):
        continue
    bucket = data_by_category[key_all]
    if len(bucket) < MAX_PER_CATEGORY:
        bucket.append(question_id)

for bucket_ids in data_by_category.values():
    selected_ids.extend(bucket_ids)

In [661]:
# Summarise the selection in a single pass:
# - counts per output key, per domain and per fluent type;
# - then (category, fluent type) -> [true/false-answer count, free-answer count].
stats = defaultdict(int)
by_domain = defaultdict(int)
by_fluents = defaultdict(int)
for question_id in selected_ids:  # was `id`, which shadowed the builtin
    q = questions_by_id[question_id]
    stats[output_keys(q)] += 1
    by_domain[q[OUT_OBJ_DOMAIN_NAME]] += 1
    by_fluents[q[OUT_OBJ_FLUENT_TYPE]] += 1
print(by_domain, '\n')
print(by_fluents,'\n')


count_by_cat = defaultdict(lambda: [0, 0])
for k, v in sorted(stats.items(), key=lambda x: x[0][0]):
    # Slot 1 counts free-answer questions; slot 0 counts true/false ones.
    slot = 1 if k[1] == 'free' else 0
    count_by_cat[(k[0], k[2])][slot] += v
    print(k, v)
count_by_cat

defaultdict(<class 'int'>, {'satellite': 205, 'spanner': 230, 'goldminer': 198, 'blocksworld': 232, 'npuzzle': 207, 'visitall': 172, 'driverlog': 148, 'miconic': 212, 'depots': 156, 'grippers': 242, 'zenotravel': 215, 'mystery': 157, 'logistics': 172}) 

defaultdict(<class 'int'>, {'base_fluents': 470, 'derived_fluents': 412, None: 709, 'static_fluents': 475, 'persistent_fluents': 480}) 

('action_executability', 'false', None, 'None', 'NON_TRANSPORTATION') 20
('action_executability', 'free', None, 'None', 'TRANSPORTATION') 20
('action_executability', 'free', None, 'None', 'NON_TRANSPORTATION') 20
('action_executability', 'true', None, 'None', 'TRANSPORTATION') 20
('action_executability', 'true', None, 'None', 'NON_TRANSPORTATION') 20
('action_executability', 'false', None, 'None', 'TRANSPORTATION') 20
('effects', 'true', 'static_fluents', 'None', 'TRANSPORTATION') 20
('effects', 'false', 'derived_fluents', 'None', 'TRANSPORTATION') 20
('effects', 'false', 'persistent_fluents', 'None',

defaultdict(<function __main__.<lambda>()>,
            {('action_executability', None): [80, 40],
             ('effects', 'static_fluents'): [80, 0],
             ('effects', 'derived_fluents'): [80, 0],
             ('effects', 'persistent_fluents'): [80, 0],
             ('effects', 'base_fluents'): [80, 0],
             ('effects', None): [0, 104],
             ('fluent_tracking', 'base_fluents'): [160, 80],
             ('fluent_tracking', 'derived_fluents'): [160, 80],
             ('fluent_tracking', 'persistent_fluents'): [160, 80],
             ('fluent_tracking', 'static_fluents'): [155, 80],
             ('hallucination', None): [80, 80],
             ('numerical_reasoning', None): [80, 40],
             ('object_tracking', 'persistent_fluents'): [160, 0],
             ('object_tracking', 'static_fluents'): [160, 0],
             ('object_tracking', None): [0, 40],
             ('object_tracking', 'derived_fluents'): [92, 0],
             ('object_tracking', 'base_fluents')

In [662]:
# Total number of selected questions over all (category, fluent type) cells.
total = sum(map(sum, count_by_cat.values()))
total

2546

In [663]:
# NOTE(review): hard-coded target dataset size. 13 appears to match the number of domains
# in the per-domain counts above; the meaning of 48 and 2 is not stated here — TODO document.
13*48*2

1248

In [664]:
# Hard-coded target size (13*48*2) minus the number of selected questions.
# A negative value means the selection is larger than the target.
print(13*48*2-total)

-1298


In [665]:
# Persist the selected question ids; the file name records the per-category cap and plan length.
output_path = f'small_dataset_ids.{MAX_PER_CATEGORY}.pl-{selected_plan_length}.jsonl'
save_jsonl(selected_ids, output_path)

In [604]:
# total = 0
# count_by_cat = defaultdict(int)
# for k, v in sorted(stats.items(), key=lambda x: x[0][0]):
#     count_by_cat[k[0]] += 1
# count_by_cat

In [ ]:
# Things to consider:
# - balanced true/false sampling
# - balanced question categories
# - keep only questions whose prompt length + response is <= 4000 tokens
# - balanced fluent types
# - balanced positive/negative fluents

# state tracking, no true false questions
# reduce object tracking qs
# boost action executability

# Check if Data is Present

In [74]:
# ids_by_prefix = {}
# for sub in [WITHOUT_RANDOM_SUB, WITH_RANDOM_SUB]:
#     for ram in RAMIFICATION_TYPES:
#         for few_shot in ['few_shot_1', 'few_shot_5']:
#             ids = set()
#             for domain in DOMAIN_NAMES:
#                 for i in range(1,11):
#                     path = os.path.join(DATA_PATH, 'data_for_evaluation', sub, ram, few_shot, domain, f'Instance_{i}.jsonl')
#                     if not os.path.exists(path):
#                         print(path)
#                     else:
#                         for d in open_jsonl(path):
#                             ids.add(d['id'])
#             ids_by_prefix[(sub,ram,few_shot)] = ids    