In [214]:
import sys
sys.path.insert(0, '../../')
from src.questions_construction.questions import FLUENT_TYPES_LIST
from src.questions_construction.main import PLAN_LENGTHS, QUESTION_CATEGORIES
from src.questions_construction.domains import DOMAIN_NAMES
from src.analysis.model_performances import * #gather_questions, TRANSPORTATION_DOMAINS, NON_TRANSPORTATION_DOMAINS
from src.common import *
import random
from collections import defaultdict
from copy import deepcopy
import itertools

In [73]:
def get_data(domain, dataset_dir, few_shot=1):
    """Collect the evaluation records of one domain across its ten instance files.

    Args:
        domain: name of the domain sub-directory to read.
        dataset_dir: root directory of the evaluation data.
        few_shot: number used to build the ``few_shot_<n>`` directory name.

    Returns:
        A single list with the records of ``Instance_1.jsonl`` through
        ``Instance_10.jsonl``, concatenated in that order.
    """
    # Only the "without_random_sub / without_ramifications" split is read here.
    domain_dir = os.path.join(
        dataset_dir,
        'without_random_sub',
        'without_ramifications',
        f'few_shot_{few_shot}',
        domain,
    )
    records = []
    for instance_idx in range(1, 11):
        instance_path = os.path.join(domain_dir, f'Instance_{instance_idx}.jsonl')
        records += open_jsonl(instance_path)
    return records

In [74]:
# Load the question metadata. Later cells look records up as
# `questions_by_id[question_id]`, so the result is used as an id -> record mapping.
questions_dir = f'{DATA_PATH}/questions_m1'
questions_by_id = gather_questions(questions_dir)

In [75]:
# Read the evaluation records of every domain (get_data is called with its
# default few_shot=1), keyed by domain name.
dataset_dir = f'{DATA_PATH}/data_for_evaluation'
data_by_domain = {domain: get_data(domain, dataset_dir) for domain in DOMAIN_NAMES}

In [76]:
# Sanity check: number of records loaded per domain.
for k,v in data_by_domain.items():
    print(k, len(v))

blocksworld 7279
depots 5999
driverlog 7127
goldminer 8742
grippers 7106
logistics 6100
miconic 8261
mystery 6100
npuzzle 7287
satellite 8517
spanner 8495
visitall 7166
zenotravel 7370


In [6]:
def is_good_qa(question_info, include_params, exclude_params):
    """Tell whether a question passes the include / exclude filters.

    Args:
        question_info: mapping of question attributes; it is indexed with
            every key of both filter dictionaries.
        include_params: attribute -> required value; all of them must match.
        exclude_params: attribute -> forbidden value; a single match rejects.

    Returns:
        True when no exclusion matches and every inclusion matches,
        False otherwise. Two empty filters accept everything.
    """
    # Exclusions are evaluated first, exactly as many lookups as needed.
    hits_exclusion = any(
        question_info[name] == banned for name, banned in exclude_params.items()
    )
    if hits_exclusion:
        return False

    misses_inclusion = any(
        question_info[name] != wanted for name, wanted in include_params.items()
    )
    return not misses_inclusion

def is_token_window_ok(data_dict, chars_per_token=4, max_tokens=4096):
    """Check that prompt plus label stay under the token budget.

    The token count is estimated as ``characters / chars_per_token``; the
    comparison is strict, so an estimate equal to ``max_tokens`` is rejected.

    Args:
        data_dict: record providing the ``'prompt'`` and ``'label'`` strings.
        chars_per_token: characters assumed per token for the estimate.
        max_tokens: exclusive upper bound on the estimated token count.

    Returns:
        True when the estimated token count is strictly below ``max_tokens``.
    """
    full_text = data_dict['prompt'] + data_dict['label']
    estimated_tokens = len(full_text) / chars_per_token
    return estimated_tokens < max_tokens

def sample_data(data_all, questions_by_id, selected_plan_length, seed=None):
    """Select and shuffle the records of one plan length that fit the token window.

    Args:
        data_all: iterable of records; each must provide an ``'id'`` key and
            whatever ``is_token_window_ok`` reads from it.
        questions_by_id: mapping from question id to the question metadata
            passed to ``is_good_qa``.
        selected_plan_length: value the question's ``'plan_length'`` must equal.
        seed: optional seed for a private ``random.Random`` used for the
            shuffle, making the result reproducible. With the default ``None``
            the module-level ``random`` state is used, as before.

    Returns:
        A new, shuffled list of the records that passed both filters.

    Raises:
        ValueError: if a record's id is missing from ``questions_by_id``.
    """
    exclude_params = {}
    include_params = {'plan_length': selected_plan_length}

    selected_data = []
    for d in data_all:
        question_id = d['id']
        if question_id not in questions_by_id:
            # Name the offending id so a data/metadata mismatch can be traced.
            raise ValueError(
                f'question id {question_id!r} is not present in questions_by_id'
            )
        question_info = questions_by_id[question_id]
        if is_good_qa(question_info, include_params, exclude_params) and is_token_window_ok(d):
            selected_data.append(d)

    if seed is None:
        random.shuffle(selected_data)
    else:
        # A dedicated generator leaves the global random state untouched.
        random.Random(seed).shuffle(selected_data)
    return selected_data

In [77]:
# domain = 'goldminer'
# selected_plan_length = 1
# sample = sample_data(data_by_domain[domain], questions_by_id, selected_plan_length)
# len(sample)

# Sample

In [215]:
def output_keys(questions_dict):
    """Build the five-part stratification key of one question record.

    Every field is read from ``questions_dict`` itself, so the function does
    not depend on any module-level variable holding "the current question".

    Args:
        questions_dict: question record indexed with the ``OUT_OBJ_*`` keys.

    Returns:
        A tuple ``(key1, key2, key3, key4, key5)``:
          * key1 - the question category stored in the record;
          * key2 - ``'true'`` / ``'false'`` for true-false questions (the
            answer is compared with the string ``'True'``), ``'free'`` for
            every other answer type;
          * key3 - the record's fluent type, passed through unchanged
            (it can be ``None``);
          * key4 - ``'POS'`` / ``'NEG'`` when the positive-fluent flag is
            exactly ``True`` / ``False``, the string ``'None'`` otherwise;
          * key5 - the transportation or non-transportation domain key.
    """
    key1 = questions_dict[OUT_OBJ_QUESTION_CATEGORY]

    if questions_dict[OUT_OBJ_ANSWER_TYPE] == TRUE_FALSE_ANSWER_TYPE:
        if questions_dict[OUT_OBJ_ANSWER] == 'True':
            key2 = 'true'
        else:
            key2 = 'false'
    else:
        key2 = 'free'

    key3 = questions_dict[OUT_OBJ_FLUENT_TYPE]

    # Identity checks on purpose: only real booleans count as POS / NEG.
    is_pos_fluent = questions_dict[OUT_OBJ_IS_POS_FLUENT_QUESTION]
    if is_pos_fluent is True:
        key4 = 'POS'
    elif is_pos_fluent is False:
        key4 = 'NEG'
    else:
        key4 = 'None'

    if questions_dict[OUT_OBJ_DOMAIN_NAME] in TRANSPORTATION_DOMAINS:
        key5 = TRANSPORTATION_DOMAIN_KEY
    else:
        key5 = NON_TRANSPORTATION_DOMAIN_KEY
    return key1, key2, key3, key4, key5

def is_restricted(key1, key2, key3, key4):
    """Tell whether a stratification key must be left out of the sample.

    A fluent type or polarity counts as missing when it is either the object
    ``None`` or the string ``'None'``: ``output_keys`` forwards the raw fluent
    type (which can be ``None``) but encodes a missing polarity as ``'None'``.

    Args:
        key1: question category.
        key2: answer key; accepted for signature compatibility, not used.
        key3: fluent type.
        key4: fluent polarity key.

    Returns:
        True when the combination is restricted:
          * ``fluent_tracking`` / ``object_tracking`` without a fluent type
            or without a polarity;
          * ``hallucination`` without a polarity;
          * ``effects`` with neither a fluent type nor a polarity.
        False for everything else.
    """
    no_fluent_type = key3 in (None, 'None')
    no_polarity = key4 in (None, 'None')

    if key1 in ('fluent_tracking', 'object_tracking'):
        return no_fluent_type or no_polarity

    if key1 == 'hallucination':
        return no_polarity

    if key1 == 'effects':
        return no_fluent_type and no_polarity

    return False

In [433]:
selected_plan_length = 1
# for selected_plan_length in [1,5,19]:

# PREPARE DATA
# Sample each domain on its own for the chosen plan length.
data_domain_all = {}
for domain_name in DOMAIN_NAMES:
    data_domain_all[domain_name] = sample_data(
        data_by_domain[domain_name], questions_by_id, selected_plan_length
    )

# Every domain contributes at most as many records as the smallest domain
# that still has more than 100 of them (smaller domains are ignored here).
sizes_above_threshold = [
    len(samples) for samples in data_domain_all.values() if len(samples) > 100
]
min_samples = min(sizes_above_threshold)

# Concatenate the truncated per-domain samples, then mix the domains together.
data_all = list(
    itertools.chain.from_iterable(
        samples[:min_samples] for samples in data_domain_all.values()
    )
)
random.shuffle(data_all)

print(len(data_all))

14417


In [474]:
# Upper bound on the number of question ids kept per stratification bucket
# (a bucket is one distinct `key_all` tuple, see the loop below).
MAX_PER_CATEGORY = 50 #12 #110 # 100 #100

# NOTE(review): within this cell the table is only referenced by the
# commented-out alternative cap on the `if len(...)` line below; the active
# code never reads it. The meaning of the fractional parts is not recorded.
MULTUPLICITY_BY_CATEGORY = {'action_executability': 3+0.3,
             'effects': 11+0.5,
             'fluent_tracking': 24,
             'hallucination': 4+0.5,
             'numerical_reasoning': 3+0.5,
             'object_tracking': 16,
             'state_tracking': 5+0.1}

selected_ids = []
# bucket key -> question ids accepted for that bucket, in data_all order
data_by_category = defaultdict(list)
for d in data_all:
    question_id = d['id']
    # NOTE(review): `q` is a module-level name, so helpers called inside this
    # loop can see it; check output_keys before renaming this variable.
    q = questions_by_id[question_id]
    
    key1, key2, key3, key4, key5 = output_keys(q)
    key_all = (key1, key2, key3, key4, key5)
    # key5 (the domain group) takes part in the bucket but not in the restriction check.
    if is_restricted(key1, key2, key3, key4):
        continue
        
    # First come, first served: once a bucket is full, later records are dropped.
    if len(data_by_category[key_all]) < MAX_PER_CATEGORY: #max(int(MAX_PER_CATEGORY/MULTUPLICITY_BY_CATEGORY[key1]), 1):
        data_by_category[key_all].append(question_id)
# Flatten the buckets into one id list (bucket by bucket, in first-seen order).
selected_ids.extend(list(itertools.chain.from_iterable(data_by_category.values())))

In [475]:
# {k: len(v) for k, v in data_domain_all.items()}

In [476]:
# Tally the selection three ways in a single pass over the ids:
#   stats      - per full stratification key (output_keys tuple)
#   by_domain  - per domain name
#   by_fluents - per fluent type
# The loop variable is deliberately not called `id`, which would shadow the
# builtin for the rest of the kernel session.
stats = defaultdict(int)
by_domain = defaultdict(int)
by_fluents = defaultdict(int)
for question_id in selected_ids:
    # The record stays bound to the module-level name `q`, as in the selection cell.
    q = questions_by_id[question_id]
    stats[output_keys(q)] += 1
    by_domain[q[OUT_OBJ_DOMAIN_NAME]] += 1
    by_fluents[q[OUT_OBJ_FLUENT_TYPE]] += 1

print(by_domain, '\n')
print(by_fluents, '\n')

# Print every bucket ordered by question category and sum the bucket sizes
# per category; the final expression displays those per-category totals.
count_by_cat = defaultdict(int)
for k, v in sorted(stats.items(), key=lambda x: x[0][0]):
    count_by_cat[k[0]] += v
    print(k, v)
count_by_cat

defaultdict(<class 'int'>, {'satellite': 446, 'goldminer': 422, 'spanner': 448, 'blocksworld': 433, 'visitall': 366, 'npuzzle': 438, 'miconic': 371, 'grippers': 383, 'zenotravel': 434, 'logistics': 368, 'depots': 351, 'driverlog': 352, 'mystery': 378}) 

defaultdict(<class 'int'>, {'persistent_fluents': 1143, None: 1462, 'static_fluents': 1012, 'base_fluents': 726, 'derived_fluents': 847}) 

('action_executability', 'true', None, 'None', 'NON_TRANSPORTATION') 50
('action_executability', 'free', None, 'None', 'NON_TRANSPORTATION') 50
('action_executability', 'true', None, 'None', 'TRANSPORTATION') 50
('action_executability', 'free', None, 'None', 'TRANSPORTATION') 50
('action_executability', 'false', None, 'None', 'TRANSPORTATION') 50
('action_executability', 'false', None, 'None', 'NON_TRANSPORTATION') 50
('effects', 'true', 'base_fluents', 'None', 'TRANSPORTATION') 50
('effects', 'true', 'static_fluents', 'None', 'TRANSPORTATION') 50
('effects', 'false', 'derived_fluents', 'None', 'NO

defaultdict(int,
            {'action_executability': 300,
             'effects': 995,
             'fluent_tracking': 1997,
             'hallucination': 400,
             'numerical_reasoning': 300,
             'object_tracking': 934,
             'state_tracking': 264})

In [477]:
# count_by_cat = defaultdict(int)
# for k, v in sorted(stats.items(), key=lambda x: x[0][0]):
#     count_by_cat[k[0]] += 1
# count_by_cat

In [478]:
# Total number of selected ids (sum of the per-category counts).
sum(count_by_cat.values())

5190

In [468]:
# NOTE(review): scratch arithmetic. 13 presumably is the number of domains;
# the meaning of 48 and 2 is not recorded - confirm or replace with named constants.
13*48*2

1248

In [448]:
# Difference between the scratch target 13*48*2 and the current selection size;
# a negative value means more ids were selected than that target.
# NOTE(review): the factors 13, 48 and 2 are undocumented - confirm their meaning.
print(13*48*2-sum(count_by_cat.values()))

-9778


In [449]:
# Persist the selected question ids. The file name encodes the per-bucket cap
# and the plan length; the path is relative, so it depends on the working directory.
save_jsonl(selected_ids, f'small_dataset_ids.{MAX_PER_CATEGORY}.new2.pl-{selected_plan_length}.new.jsonl')

In [134]:
# total = 0
# count_by_cat = defaultdict(int)
# for k, v in sorted(stats.items(), key=lambda x: x[0][0]):
#     count_by_cat[k[0]] += 1
# count_by_cat

In [ ]:
# # things to consider:
# - balanced T/F sampling
# - balanced Q categories
# - take q + a with length + response <= 4000 tokens
# balanced fluent types
# balanced -+ fluents

# state tracking, no true false questions
# reduce object tracking qs
# boost action executability

# Check if Data is Present

In [74]:
# ids_by_prefix = {}
# for sub in [WITHOUT_RANDOM_SUB, WITH_RANDOM_SUB]:
#     for ram in RAMIFICATION_TYPES:
#         for few_shot in ['few_shot_1', 'few_shot_5']:
#             ids = set()
#             for domain in DOMAIN_NAMES:
#                 for i in range(1,11):
#                     path = os.path.join(DATA_PATH, 'data_for_evaluation', sub, ram, few_shot, domain, f'Instance_{i}.jsonl')
#                     if not os.path.exists(path):
#                         print(path)
#                     else:
#                         for d in open_jsonl(path):
#                             ids.add(d['id'])
#             ids_by_prefix[(sub,ram,few_shot)] = ids    