In [1]:
import os.path
import sys
sys.path.insert(0, '../../')
from questions_construction.questions import FLUENT_TYPES_LIST
from questions_construction.main import PLAN_LENGTHS, QUESTION_CATEGORIES
from questions_construction.domains import DOMAIN_NAMES
from analysis.model_performances import * #gather_questions, TRANSPORTATION_DOMAINS, NON_TRANSPORTATION_DOMAINS
from common import *
import random
from collections import defaultdict
from copy import deepcopy
import itertools

import sentencepiece

In [2]:
with open('huggingface.token.key') as f:
    huggingface_key = f.read()
# print(huggingface_key)

# !huggingface-cli login

In [3]:
from transformers import AutoTokenizer

reberta_tokenizer = AutoTokenizer.from_pretrained('roberta-base')
llama_tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token=huggingface_key)

def token_length(text, tokenizer):
    tokens = tokenizer(text, add_special_tokens=False)
    return len(tokens['input_ids'])

In [4]:
def get_data(domain, dataset_dir, subs='without_random_sub', few_shot=1):
    data_all = []
    for i in range(1,11):
        path = os.path.join(dataset_dir, subs, 'without_ramifications', f'few_shot_{few_shot}', domain, f'Instance_{i}.jsonl')
        data_all.extend(open_jsonl(path))
    return data_all

In [7]:
questions_dir = f'{DATA_PATH}/questions_m1'
ids_file_name = 'dataset_ids.test.pruned'
selected_ids = open_jsonl(f'{DATA_PATH}/{ids_file_name}.jsonl')
questions_by_id = gather_questions(questions_dir, selected_ids)

questions gathered


In [17]:
max_question_plus_answer = []
for _id, d1 in questions_by_id.items():
    d2 = d1[WITHOUT_RANDOM_SUB]
    max_question_plus_answer.append(len(d2[OUT_OBJ_QUESTION]) + len(d2[OUT_OBJ_ANSWER]))

In [19]:
print(np.mean(max_question_plus_answer), np.std(max_question_plus_answer), np.max(max_question_plus_answer))
MAX_CHARS = 7858

1091.2592213114754 977.9991725839053 7858


In [20]:
# questions_dir = f'{DATA_PATH}/questions_m1'
composite_questions_by_id = gather_questions(f'{DATA_PATH}/questions.composite')

missing /Users/paveldolin/dev/research/current/reasoning_about_actions/pipeline/data/questions.composite/without_random_sub/mystery/Instance_7.jsonl
missing /Users/paveldolin/dev/research/current/reasoning_about_actions/pipeline/data/questions.composite/without_random_sub/mystery/Instance_8.jsonl
missing /Users/paveldolin/dev/research/current/reasoning_about_actions/pipeline/data/questions.composite/without_random_sub/mystery/Instance_9.jsonl
missing /Users/paveldolin/dev/research/current/reasoning_about_actions/pipeline/data/questions.composite/without_random_sub/mystery/Instance_10.jsonl
missing /Users/paveldolin/dev/research/current/reasoning_about_actions/pipeline/data/questions.composite/without_random_sub/npuzzle/Instance_1.jsonl
missing /Users/paveldolin/dev/research/current/reasoning_about_actions/pipeline/data/questions.composite/without_random_sub/npuzzle/Instance_2.jsonl
missing /Users/paveldolin/dev/research/current/reasoning_about_actions/pipeline/data/questions.composite/

In [8]:
# dataset_dir = f'{DATA_PATH}/data_for_evaluation'
dataset_dir = f'{DATA_PATH}/data_for_evaluation.composite'
subs = WITHOUT_RANDOM_SUB
few_shot = 5
data_by_domain = {domain: get_data(domain, dataset_dir, subs=subs, few_shot=few_shot) for domain in DOMAIN_NAMES}

In [9]:
for k,v in data_by_domain.items():
    print(k, len(v))

blocksworld 7279
depots 5999
driverlog 7083
goldminer 8746
grippers 7106
logistics 6100
miconic 8249
mystery 6100
npuzzle 7288
satellite 8516
spanner 8494
visitall 7165
zenotravel 7372


In [10]:
def is_good_qa(question_info, include_params, exclude_params):
    for k, v in exclude_params.items():
        if question_info[k] == v:
            return False
        
    for k, v in include_params.items():
        if question_info[k] != v:
            return False
    return True

def tokens_in_text(text, chars_per_token=4):
    return len(text)/chars_per_token

def is_token_window_ok(data_dict, max_tokens=4096, max_out_tokens=int(512/2)):
    # condition1 = tokens_in_text(data_dict['prompt'] + data_dict['label']) <= max_tokens
    # condition2 = tokens_in_text(data_dict['label']) <= max_out_tokens
    condition1 = token_length(data_dict['prompt'] + data_dict['label'], llama_tokenizer) <= max_tokens
    condition2 = token_length(data_dict['label'], reberta_tokenizer) <= max_out_tokens
    return condition1 and condition2


def filter_by_length(data_all):
    data_filtered = []
    for d in tqdm(data_all):
        try:
            if is_token_window_ok(d):
                data_filtered.append(d)
        except:
            pass
    return data_filtered

def sample_data(data_all, questions_by_id, selected_plan_length):
    selected_data = []
    exclude_params = {}
    include_params = {'plan_length': selected_plan_length}
    for d in data_all:
        question_id = d['id']
        if question_id not in questions_by_id:
            raise ValueError
        question_info = questions_by_id[question_id]
        if is_good_qa(question_info, include_params, exclude_params):
            selected_data.append(d)
    random.shuffle(selected_data)
    return selected_data

In [11]:
def output_keys(questions_dict):
    key1 = questions_dict[OUT_OBJ_QUESTION_CATEGORY]
    if q[OUT_OBJ_ANSWER_TYPE] == TRUE_FALSE_ANSWER_TYPE:
        if q[OUT_OBJ_ANSWER] == 'True':
            key2 = 'true'
        else:
            key2 = 'false'
    else:
        key2 = 'free'
        
    key3 = q[OUT_OBJ_FLUENT_TYPE]
        
    if q[OUT_OBJ_IS_POS_FLUENT_QUESTION] is POS_FLUENT_KEY:
        key4 = 'POS'
    elif q[OUT_OBJ_IS_POS_FLUENT_QUESTION] is NEG_FLUENT_KEY:
        key4 = 'NEG'
    else:
        key4 = 'None'
        
    if q[OUT_OBJ_DOMAIN_NAME] in TRANSPORTATION_DOMAINS:
        key5 = TRANSPORTATION_DOMAIN_KEY
    else:
        key5 = NON_TRANSPORTATION_DOMAIN_KEY
    return key1, key2, key3, key4, key5

def is_restricted(key1, key2, key3, key4):
    if key1 == 'fluent_tracking' and (key3 == 'None' or key4 == 'None'):
        return True
        
    if key1 == 'object_tracking' and key2!='free' and (key3 == 'None' or key4 == 'None'):
        return True
    
    if key1 == 'hallucination' and key4 == 'None':
        return True
    
    return False


# Filter By length

In [None]:
data_domain_all = {domain: filter_by_length(data_by_domain[domain]) for domain in DOMAIN_NAMES}

  4%|▎         | 258/7279 [00:01<00:42, 163.59it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (577 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 7279/7279 [00:38<00:00, 189.26it/s]
 63%|██████▎   | 3760/5999 [01:23<00:33, 66.57it/s] 

In [None]:
with open(f'{few_shot}.data_domains.roberta.llama.length_fix.json', 'w') as f:
    json.dump(data_domain_all, f)
for k,v in data_domain_all.items():
    print(k, len(v))

# Sample

In [31]:
selected_plan_length = 19

# PREPARE DATA
data_domain_sample = {domain: sample_data(data_domain_all[domain], questions_by_id, selected_plan_length) for domain in DOMAIN_NAMES}
for k,v in data_domain_sample.items():
    print(k, len(v))
min_samples = min([len(v) for v in data_domain_sample.values() if len(v) > 100])
data_all = []
for k, v in data_domain_sample.items():
    data_all.extend(v[:min_samples])
random.shuffle(data_all)

print(len(data_all))


MAX_PER_CATEGORY = 20 #12 #110 # 100 #100

selected_ids = []
data_by_category = defaultdict(list)
for d in data_all:
    question_id = d['id']
    q = questions_by_id[question_id]
    
    key1, key2, key3, key4, key5 = output_keys(q)
    key_all = (key1, key2, key3, key4, key5)
    if is_restricted(key1, key2, key3, key4):
        continue
    max_per_category = MAX_PER_CATEGORY
        
    if len(data_by_category[key_all]) < max_per_category:
        data_by_category[key_all].append(question_id)
selected_ids.extend(list(itertools.chain.from_iterable(data_by_category.values())))

blocksworld 1216
depots 947
driverlog 963
goldminer 0
grippers 1371
logistics 820
miconic 401
mystery 451
npuzzle 1175
satellite 1431
spanner 1518
visitall 0
zenotravel 1383
4411


In [32]:
stats = defaultdict(int)
for id in selected_ids:
    q = questions_by_id[id]
    stats[output_keys(q)]+=1

by_domain = defaultdict(int)
for id in selected_ids:
    q = questions_by_id[id]
    by_domain[q[OUT_OBJ_DOMAIN_NAME]]+=1
print(by_domain, '\n')

by_fluents = defaultdict(int)
for id in selected_ids:
    q = questions_by_id[id]
    by_fluents[q[OUT_OBJ_FLUENT_TYPE]]+=1
print(by_fluents,'\n')


count_by_cat = defaultdict(lambda: [0,0])
for k, v in sorted(stats.items(), key=lambda x: x[0][0]):
    if k[1] == 'free':
        count_by_cat[(k[0],k[2])][1] += v
    else:
        count_by_cat[(k[0],k[2])][0] += v
    print(k, v)
count_by_cat

defaultdict(<class 'int'>, {'zenotravel': 148, 'logistics': 105, 'miconic': 143, 'driverlog': 125, 'mystery': 71, 'spanner': 210, 'npuzzle': 182, 'blocksworld': 179, 'satellite': 185, 'grippers': 161, 'depots': 124}) 

defaultdict(<class 'int'>, {'static_fluents': 235, None: 448, 'derived_fluents': 315, 'persistent_fluents': 355, 'base_fluents': 280}) 

('action_executability', 'true', None, 'None', 'TRANSPORTATION') 20
('action_executability', 'free', None, 'None', 'TRANSPORTATION') 20
('action_executability', 'false', None, 'None', 'TRANSPORTATION') 20
('action_executability', 'true', None, 'None', 'NON_TRANSPORTATION') 20
('action_executability', 'free', None, 'None', 'NON_TRANSPORTATION') 20
('action_executability', 'false', None, 'None', 'NON_TRANSPORTATION') 18
('effects', 'false', 'static_fluents', 'None', 'TRANSPORTATION') 20
('effects', 'false', 'persistent_fluents', 'None', 'TRANSPORTATION') 20
('effects', 'true', 'base_fluents', 'None', 'NON_TRANSPORTATION') 20
('effects', '

defaultdict(<function __main__.<lambda>()>,
            {('action_executability', None): [78, 40],
             ('effects', 'static_fluents'): [78, 0],
             ('effects', 'persistent_fluents'): [80, 0],
             ('effects', 'base_fluents'): [75, 0],
             ('effects', 'derived_fluents'): [80, 0],
             ('effects', None): [0, 4],
             ('fluent_tracking', 'derived_fluents'): [140, 51],
             ('fluent_tracking', 'persistent_fluents'): [106, 63],
             ('fluent_tracking', 'base_fluents'): [104, 43],
             ('fluent_tracking', 'static_fluents'): [69, 30],
             ('hallucination', None): [80, 80],
             ('numerical_reasoning', None): [80, 40],
             ('object_tracking', 'derived_fluents'): [44, 0],
             ('object_tracking', 'persistent_fluents'): [106, 0],
             ('object_tracking', 'base_fluents'): [58, 0],
             ('object_tracking', None): [0, 39],
             ('object_tracking', 'static_fluents'): [5

In [33]:
total, tf, free = sum(v1+v2 for v1,v2 in count_by_cat.values()), sum(v1 for v1, v2 in count_by_cat.values()), sum(v2 for v1, v2 in count_by_cat.values())
print(total, tf, free)
print(tf/total, free/total)

1633 1239 394
0.7587262706674831 0.24127372933251684


In [34]:
save_jsonl(selected_ids, f'{subs}.small_dataset_ids.{MAX_PER_CATEGORY}.pl-{selected_plan_length}.jsonl')

In [604]:
# total = 0
# count_by_cat = defaultdict(int)
# for k, v in sorted(stats.items(), key=lambda x: x[0][0]):
#     count_by_cat[k[0]] += 1
# count_by_cat

In [ ]:
# # things to consider:
# - balanced T/F sampling
# - balanced Q categories
# - take q + a with lenth + response <= 4000 tokens
# balanced fluent types
# balanced -+ fluents

# state tracking, no true false questions
# reduce object tracking qs
# boost action executability

# Train Data

In [48]:
test_ids = open_jsonl(os.path.join(DATA_PATH, 'small_dataset_ids.20.jsonl'))

In [52]:
train_ids = []
for q_id, _question in tqdm(questions_by_id.items()):
    if q_id not in test_ids:
        train_ids.append(q_id)

100%|██████████| 191102/191102 [00:19<00:00, 9912.70it/s] 


In [55]:
save_jsonl(train_ids, os.path.join(DATA_PATH, 'small_dataset_ids.20.train.jsonl'))