In [84]:
import sys
sys.path.insert(0, '../../')
from src.questions_construction.main import PLAN_LENGTHS, QUESTION_CATEGORIES
from src.questions_construction.domains import DOMAIN_NAMES
from src.analysis.model_performances import gather_questions
from src.common import *
import random
from collections import defaultdict
from copy import deepcopy

In [85]:
questions_dir = f'{DATA_PATH}/questions_m1'
questions_by_id = gather_questions(questions_dir)

In [6]:
dataset_dir = '/Users/paveldolin/dev/research/current/reasoning_about_actions/pipeline/data/data_for_evaluation_small/without_random_sub'

In [52]:
MAX_QUESTIONS_PER_DOMAIN_PER_PLAN_LENGTH = 50

def is_ok_to_take_question(question_info, exclude_params, counter_by_plan_langth,
    max_questions_per_domain_per_plan_length=MAX_QUESTIONS_PER_DOMAIN_PER_PLAN_LENGTH):
    for k, v in exclude_params.items():
        if question_info[k] == v:
            return False
    
    if question_info['plan_length'] not in counter_by_plan_langth:
        return False
    
    if counter_by_plan_langth[question_info['plan_length']] < max_questions_per_domain_per_plan_length:
        return True
    else:
        return False


exclude_params = {'fluent_type': None, 'is_pos_fluent_question': None}

In [54]:
selected_question_ids = set()
selected_question_ids_by_domain = {dom: set() for dom in DOMAIN_NAMES}
counter_by_domain = {dom: deepcopy({1:0, 10:0, 19:0}) for dom in DOMAIN_NAMES}
while len(selected_question_ids) < len(DOMAIN_NAMES) * 3 * MAX_QUESTIONS_PER_DOMAIN_PER_PLAN_LENGTH:
    for domain in DOMAIN_NAMES:
        for i in range(1,11):
            path = os.path.join(dataset_dir, 'with_ramifications', 'few_shot_5', domain, f'Instance_{i}.jsonl')
            if not os.path.exists(path):
                continue
            data = open_jsonl(path)
            while sum(counter_by_domain[domain].values()) < 3 * MAX_QUESTIONS_PER_DOMAIN_PER_PLAN_LENGTH:
                d = random.choice(data)
                question_id = d['id']
                question_info = questions_by_id[question_id]
                if is_ok_to_take_question(question_info, exclude_params, counter_by_domain[domain]) and question_id not in selected_question_ids:
                    counter_by_domain[domain][question_info['plan_length']] += 1
                    selected_question_ids_by_domain[domain].add(question_id)
                    selected_question_ids.add(question_id)
                    break
    print(counter_by_domain)

{'blocksworld': {1: 0, 10: 6, 19: 4}, 'depots': {1: 4, 10: 5, 19: 1}, 'driverlog': {1: 3, 10: 5, 19: 2}, 'goldminer': {1: 2, 10: 3, 19: 5}, 'grippers': {1: 2, 10: 5, 19: 3}, 'logistics': {1: 3, 10: 4, 19: 3}, 'miconic': {1: 5, 10: 2, 19: 3}, 'mystery': {1: 2, 10: 2, 19: 6}, 'npuzzle': {1: 2, 10: 2, 19: 6}, 'satellite': {1: 4, 10: 3, 19: 3}, 'spanner': {1: 2, 10: 4, 19: 4}, 'visitall': {1: 1, 10: 4, 19: 5}, 'zenotravel': {1: 3, 10: 4, 19: 3}}
{'blocksworld': {1: 2, 10: 12, 19: 6}, 'depots': {1: 6, 10: 9, 19: 5}, 'driverlog': {1: 6, 10: 10, 19: 4}, 'goldminer': {1: 4, 10: 9, 19: 7}, 'grippers': {1: 3, 10: 9, 19: 8}, 'logistics': {1: 7, 10: 7, 19: 6}, 'miconic': {1: 5, 10: 8, 19: 7}, 'mystery': {1: 7, 10: 5, 19: 8}, 'npuzzle': {1: 5, 10: 3, 19: 12}, 'satellite': {1: 5, 10: 5, 19: 10}, 'spanner': {1: 4, 10: 6, 19: 10}, 'visitall': {1: 2, 10: 9, 19: 9}, 'zenotravel': {1: 5, 10: 8, 19: 7}}
{'blocksworld': {1: 5, 10: 17, 19: 8}, 'depots': {1: 8, 10: 15, 19: 7}, 'driverlog': {1: 10, 10: 15, 19

In [80]:
import json

selected_question_ids_by_domain = {k: list(v) for k,v in selected_question_ids_by_domain.items()}
with open('small_dataset_ids_by_domain.json','w') as f:
    json.dump(selected_question_ids_by_domain, f)
    
selected_question_ids = set()
for domain, ids in selected_question_ids_by_domain.items():
    selected_question_ids.update(ids)
save_jsonl(list(selected_question_ids), 'small_dataset_ids.jsonl')

In [ ]:
# # things to consider:
# - balanced T/F sampling
# - take q + a with lenth <= 4000 tokens
# - balanced Q categories

# Check if Data is Present

In [74]:
ids_by_prefix = {}
for sub in [WITHOUT_RANDOM_SUB, WITH_RANDOM_SUB]:
    for ram in RAMIFICATION_TYPES:
        for few_shot in ['few_shot_1', 'few_shot_5']:
            ids = set()
            for domain in DOMAIN_NAMES:
                for i in range(1,11):
                    path = os.path.join(DATA_PATH, 'data_for_evaluation', sub, ram, few_shot, domain, f'Instance_{i}.jsonl')
                    if not os.path.exists(path):
                        print(path)
                    else:
                        for d in open_jsonl(path):
                            ids.add(d['id'])
            ids_by_prefix[(sub,ram,few_shot)] = ids    

In [75]:
missing_ids_by_prefix = {}

for id in selected_question_ids:
    for k,v in ids_by_prefix.items():
        if id not in v:
            missing_ids_by_prefix[k] = id

In [76]:
missing_ids_by_prefix

{('with_random_sub',
  'with_ramifications',
  'few_shot_1'): '81d5fb7e-0dc9-49e2-bd1d-d14d64297009',
 ('with_random_sub',
  'with_ramifications',
  'few_shot_5'): '81d5fb7e-0dc9-49e2-bd1d-d14d64297009',
 ('with_random_sub',
  'without_ramifications',
  'few_shot_1'): '81d5fb7e-0dc9-49e2-bd1d-d14d64297009',
 ('with_random_sub',
  'without_ramifications',
  'few_shot_5'): '81d5fb7e-0dc9-49e2-bd1d-d14d64297009',
 ('without_random_sub',
  'with_ramifications',
  'few_shot_1'): '38fb9521-3c95-476c-98fb-38bdf15fd82e',
 ('without_random_sub',
  'without_ramifications',
  'few_shot_1'): '38fb9521-3c95-476c-98fb-38bdf15fd82e',
 ('without_random_sub',
  'without_ramifications',
  'few_shot_5'): 'c83baaa6-8c60-477e-a832-0ca28910a9a6'}

In [83]:
for domain, ids in selected_question_ids_by_domain.items():
    for id in ids:
        if id == 'c83baaa6-8c60-477e-a832-0ca28910a9a6':
            print(domain)

zenotravel


# Stats on the Dataset

In [99]:
# T/F, length 1
stats = {f'{c}_true': 0 for c in QUESTION_CATEGORIES} | {f'{c}_false': 0 for c in QUESTION_CATEGORIES}
for id in selected_question_ids:
    q = questions_by_id[id]
    if q['plan_length'] == 1 and q[OUT_OBJ_ANSWER_TYPE] == TRUE_FALSE_ANSWER_TYPE:
        if q[OUT_OBJ_ANSWER] == 'True':
            stats[f'{q["question_category"]}_true'] += 1
        else:
            stats[f'{q["question_category"]}_false'] += 1
        

# stats_all = {k: stats[f'{q}_true']/(stats[f'{q}_false'] or 0.0001)for q in QUESTION_CATEGORIES}

In [101]:
stats

{'object_tracking_true': 107,
 'fluent_tracking_true': 215,
 'state_tracking_true': 0,
 'action_executability_true': 0,
 'effects_true': 0,
 'numerical_reasoning_true': 0,
 'hallucination_true': 0,
 'composite_questions_true': 0,
 'object_tracking_false': 82,
 'fluent_tracking_false': 139,
 'state_tracking_false': 0,
 'action_executability_false': 0,
 'effects_false': 0,
 'numerical_reasoning_false': 0,
 'hallucination_false': 0,
 'composite_questions_false': 0}