In [1]:
import os.path
import sys
sys.path.insert(0, '../../')
from src.questions_construction.questions import FLUENT_TYPES_LIST
from src.questions_construction.main import PLAN_LENGTHS, QUESTION_CATEGORIES
from src.questions_construction.domains import DOMAIN_NAMES
from src.analysis.model_performances import * #gather_questions, TRANSPORTATION_DOMAINS, NON_TRANSPORTATION_DOMAINS
from src.common import *
import random
from collections import defaultdict
from copy import deepcopy
import itertools

import sentencepiece

In [2]:
# Read the Hugging Face access token from a local file that must stay out of
# version control.
with open('huggingface.token.key') as f:
    # Strip surrounding whitespace: key files usually end with a newline, and a
    # token carrying a trailing '\n' is not the same credential string.
    huggingface_key = f.read().strip()
# WARNING: do not re-enable the print below -- cell outputs are saved with the
# notebook, so printing the token leaks it to anyone who can read the file.
# print(huggingface_key)

# !huggingface-cli login

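[REDACTED: a Hugging Face access token was saved in this cell output. Revoke and rotate that token, and clear outputs before sharing the notebook.]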


In [5]:
from transformers import AutoTokenizer

# Tokenizer loaded from the 'roberta-base' checkpoint. The variable is spelled
# `reberta_tokenizer`; other cells reference it under this exact name.
reberta_tokenizer = AutoTokenizer.from_pretrained('roberta-base')
# Tokenizer for the Llama-2 7B chat checkpoint; the access token read earlier is
# passed for authentication (presumably because the repo is gated -- confirm).
llama_tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token=huggingface_key)

def token_length(text, tokenizer):
    """Return how many token ids `tokenizer` produces for `text`.

    The tokenizer is invoked with ``add_special_tokens=False`` and the length
    of the ``'input_ids'`` entry of its result is returned.
    """
    encoded = tokenizer(text, add_special_tokens=False)
    input_ids = encoded['input_ids']
    return len(input_ids)

In [11]:
def get_data(domain, dataset_dir, few_shot=1):
    """Load and concatenate the records of Instance_1 .. Instance_10 for `domain`.

    Files are read from
    ``<dataset_dir>/without_random_sub/without_ramifications/few_shot_<few_shot>/<domain>/``
    via ``open_jsonl`` and appended in instance order.

    Returns:
        A single list with the records of all ten instance files.
    """
    domain_dir = os.path.join(
        dataset_dir, 'without_random_sub', 'without_ramifications', f'few_shot_{few_shot}', domain
    )
    records = []
    for instance_no in range(1, 11):
        instance_path = os.path.join(domain_dir, f'Instance_{instance_no}.jsonl')
        records += open_jsonl(instance_path)
    return records

In [12]:
# Directory under DATA_PATH that is handed to gather_questions.
questions_dir = f'{DATA_PATH}/questions_m1'
# Question metadata; later cells index this by a record's 'id' value.
questions_by_id = gather_questions(questions_dir)

In [13]:
# Root of the evaluation dataset that get_data reads from.
dataset_dir = f'{DATA_PATH}/data_for_evaluation'
# One list of records per domain name (get_data is called with its default few_shot).
data_by_domain = {domain: get_data(domain, dataset_dir) for domain in DOMAIN_NAMES}

In [14]:
# Report how many records were loaded for each domain.
for domain_name in data_by_domain:
    print(domain_name, len(data_by_domain[domain_name]))

blocksworld 7279
depots 5999
driverlog 7127
goldminer 8742
grippers 7106
logistics 6100
miconic 8261
mystery 6100
npuzzle 7287
satellite 8517
spanner 8495
visitall 7166
zenotravel 7370


In [26]:
def is_good_qa(question_info, include_params, exclude_params):
    """Decide whether a question passes the include/exclude filters.

    Args:
        question_info: mapping of question attributes.
        include_params: every ``key -> value`` pair must match ``question_info``.
        exclude_params: any matching ``key -> value`` pair rejects the question.

    Returns:
        True when no exclude pair matches and all include pairs match.
    """
    # Exclusions are checked first, exactly as in a reject-then-accept filter.
    if any(question_info[key] == banned for key, banned in exclude_params.items()):
        return False
    mismatched = any(question_info[key] != wanted for key, wanted in include_params.items())
    return not mismatched

def tokens_in_text(text, chars_per_token=4):
    """Estimate the token count of `text` as its character count divided by `chars_per_token`."""
    n_chars = len(text)
    return n_chars / chars_per_token

def is_token_window_ok(data_dict, max_tokens=4096, max_out_tokens=int(512/2)):
    """Check that a record fits both the total and the output token budgets.

    The concatenated ``'prompt'`` + ``'label'`` text is measured with
    ``llama_tokenizer`` against `max_tokens`; the ``'label'`` alone is measured
    with ``reberta_tokenizer`` against `max_out_tokens`.

    Returns:
        True only when both lengths are within their limits.
    """
    prompt = data_dict['prompt']
    label = data_dict['label']
    # Both lengths are always computed before the comparison.
    total_len = token_length(prompt + label, llama_tokenizer)
    label_len = token_length(label, reberta_tokenizer)
    fits_window = total_len <= max_tokens
    fits_output = label_len <= max_out_tokens
    return fits_window and fits_output


def filter_by_length(data_all):
    """Return the records of `data_all` that pass ``is_token_window_ok``.

    Iteration is wrapped in ``tqdm`` so progress is displayed; input order is kept.
    """
    return [record for record in tqdm(data_all) if is_token_window_ok(record)]

def sample_data(data_all, questions_by_id, selected_plan_length, seed=None):
    """Select the records whose question has the requested plan length, shuffled.

    Args:
        data_all: iterable of records; each record's ``'id'`` is looked up in
            `questions_by_id`.
        questions_by_id: mapping from question id to question metadata.
        selected_plan_length: required value of the metadata's ``'plan_length'``.
        seed: optional seed for a private RNG so the shuffle is reproducible.
            With the default ``None`` the module-level ``random`` state is used,
            as before.

    Returns:
        A new, shuffled list of the matching records.

    Raises:
        ValueError: if a record's id is not present in `questions_by_id`.
    """
    exclude_params = {}
    include_params = {'plan_length': selected_plan_length}

    selected_data = []
    for d in data_all:
        question_id = d['id']
        if question_id not in questions_by_id:
            # Name the offending id so a data/metadata mismatch can be traced.
            raise ValueError(f'question id {question_id!r} is missing from questions_by_id')
        if is_good_qa(questions_by_id[question_id], include_params, exclude_params):
            selected_data.append(d)

    if seed is None:
        random.shuffle(selected_data)
    else:
        # A dedicated generator leaves the global random state untouched.
        random.Random(seed).shuffle(selected_data)
    return selected_data

In [27]:
# domain = 'goldminer'
# selected_plan_length = 1
# sample = sample_data(data_by_domain[domain], questions_by_id, selected_plan_length)
# len(sample)

In [28]:
def output_keys(questions_dict):
    """Build the 5-part stratification key for one question.

    Every field is read from `questions_dict`. Previously all fields except the
    first were read from a module-level variable ``q``, so the function only
    worked when the caller happened to bind that global to the same question.

    Returns:
        A tuple ``(key1, key2, key3, key4, key5)``:
        key1 -- value stored under OUT_OBJ_QUESTION_CATEGORY;
        key2 -- 'true' / 'false' for TRUE_FALSE_ANSWER_TYPE questions (the
                answer is compared with the string 'True'), otherwise 'free';
        key3 -- value stored under OUT_OBJ_FLUENT_TYPE, passed through unchanged;
        key4 -- 'POS' / 'NEG' when OUT_OBJ_IS_POS_FLUENT_QUESTION is exactly
                True / False, otherwise the string 'None';
        key5 -- TRANSPORTATION_DOMAIN_KEY or NON_TRANSPORTATION_DOMAIN_KEY,
                depending on membership of the domain in TRANSPORTATION_DOMAINS.
    """
    key1 = questions_dict[OUT_OBJ_QUESTION_CATEGORY]

    if questions_dict[OUT_OBJ_ANSWER_TYPE] == TRUE_FALSE_ANSWER_TYPE:
        if questions_dict[OUT_OBJ_ANSWER] == 'True':
            key2 = 'true'
        else:
            key2 = 'false'
    else:
        key2 = 'free'

    key3 = questions_dict[OUT_OBJ_FLUENT_TYPE]

    # Identity checks on purpose: any value other than True/False maps to 'None'.
    is_positive = questions_dict[OUT_OBJ_IS_POS_FLUENT_QUESTION]
    if is_positive is True:
        key4 = 'POS'
    elif is_positive is False:
        key4 = 'NEG'
    else:
        key4 = 'None'

    if questions_dict[OUT_OBJ_DOMAIN_NAME] in TRANSPORTATION_DOMAINS:
        key5 = TRANSPORTATION_DOMAIN_KEY
    else:
        key5 = NON_TRANSPORTATION_DOMAIN_KEY
    return key1, key2, key3, key4, key5

def is_restricted(key1, key2, key3, key4):
    """Tell whether a question with these stratification keys must be skipped.

    Args:
        key1: question category.
        key2: answer kind ('true', 'false' or 'free').
        key3: fluent type; absent values may arrive as the ``None`` object or
            as the string 'None'.
        key4: fluent polarity ('POS', 'NEG', or 'None' when absent).

    Returns:
        True for
        * 'fluent_tracking' questions lacking a fluent type or a polarity,
        * non-free 'object_tracking' questions lacking a fluent type or a polarity,
        * 'hallucination' questions lacking a polarity;
        False otherwise.
    """
    # The original compared only against the string 'None', which never matches
    # a fluent type that is the None object; accept both spellings.
    no_fluent_type = key3 is None or key3 == 'None'
    no_polarity = key4 is None or key4 == 'None'

    if key1 == 'fluent_tracking' and (no_fluent_type or no_polarity):
        return True

    if key1 == 'object_tracking' and key2 != 'free' and (no_fluent_type or no_polarity):
        return True

    if key1 == 'hallucination' and no_polarity:
        return True

    return False

# Filter By length

In [29]:
# Keep, per domain, only the records that pass the token-window check.
# Slow cell: filter_by_length tokenizes every record of every domain.
data_domain_all = {domain: filter_by_length(data_by_domain[domain]) for domain in DOMAIN_NAMES}

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 7279/7279 [00:15<00:00, 475.23it/s]
100%|██████████| 5999/5999 [00:51<00:00, 116.64it/s]
100%|██████████| 7127/7127 [00:44<00:00, 158.80it/s]
100%|██████████| 8742/8742 [00:55<00:00, 157.09it/s]
100%|██████████| 7106/7106 [00:13<00:00, 515.85it/s]
100%|██████████| 6100/6100 [00:18<00:00, 335.05it/s]
100%|██████████| 8261/8261 [00:33<00:00, 244.58it/s]
100%|██████████| 6100/6100 [00:25<00:00, 239.08it/s]
100%|██████████| 7287/7287 [00:28<00:00, 258.68it/s]
100%|██████████| 8517/8517 [00:43<00:00, 196.72it/s]
100%|██████████| 8495/8495 [00:29<00:00, 290.75it/s]
100%|██████████| 7166/7166 [00:53<00:00, 134.88it/s]
100%|██████████| 7370/7370 [00:56<00:00, 130.04it/s]


In [32]:
# Persist the length-filtered records so the slow filtering step need not be repeated.
with open('data_domains.roberta.llama.length_fix.json', 'w') as out_file:
    json.dump(data_domain_all, out_file)

# Report how many records survived the filter in each domain.
for domain_name, kept_records in data_domain_all.items():
    print(domain_name, len(kept_records))

blocksworld 7029
depots 5194
driverlog 5966
goldminer 6982
grippers 6841
logistics 5742
miconic 7362
mystery 5679
npuzzle 6604
satellite 8005
spanner 8089
visitall 5110
zenotravel 7073


# Sample

In [43]:
# Plan length that every sampled question must have.
selected_plan_length = 1

# PREPARE DATA
# Per domain: the shuffled records whose question matches the plan length.
data_domain_sample = {domain: sample_data(data_domain_all[domain], questions_by_id, selected_plan_length) for domain in DOMAIN_NAMES}
for k,v in data_domain_sample.items():
    print(k, len(v))
# Smallest sample size among domains with more than 100 records. Domains at or
# below 100 are ignored when computing the minimum but still contribute whatever
# they have via the slice below. min() raises ValueError if no domain qualifies.
min_samples = min([len(v) for v in data_domain_sample.values() if len(v) > 100])
data_all = []
for k, v in data_domain_sample.items():
    # Take at most min_samples records per domain so domains are balanced.
    data_all.extend(v[:min_samples])
# NOTE(review): no random seed is set in the visible cells, so this shuffle (and
# therefore the selection below) differs between runs -- confirm whether intended.
random.shuffle(data_all)

print(len(data_all))


# Cap on the number of question ids kept per stratification key.
MAX_PER_CATEGORY = 20 # previously tried values: 12, 110, 100

selected_ids = []
# Maps the 5-part key from output_keys to the question ids collected for it.
data_by_category = defaultdict(list)
for d in data_all:
    question_id = d['id']
    # `q` is a module-level name here; it stays bound after the loop finishes.
    q = questions_by_id[question_id]
    
    key1, key2, key3, key4, key5 = output_keys(q)
    key_all = (key1, key2, key3, key4, key5)
    # Skip key combinations that is_restricted rejects.
    if is_restricted(key1, key2, key3, key4):
        continue
    max_per_category = MAX_PER_CATEGORY
        
    # First come, first served: ids are kept in shuffled order until the cap is hit.
    if len(data_by_category[key_all]) < max_per_category:
        data_by_category[key_all].append(question_id)
# Flatten the per-category id lists into one list of selected question ids.
selected_ids.extend(list(itertools.chain.from_iterable(data_by_category.values())))

blocksworld 1410
depots 1069
driverlog 1203
goldminer 1487
grippers 1311
logistics 1158
miconic 1393
mystery 1146
npuzzle 1349
satellite 1598
spanner 1608
visitall 1093
zenotravel 1416
13897


In [44]:
# Count the selected questions per stratification key, per domain and per
# fluent type in a single pass. The loop variable is no longer called `id`,
# which shadowed the builtin for the rest of the session.
stats = defaultdict(int)
by_domain = defaultdict(int)
by_fluents = defaultdict(int)
for selected_id in selected_ids:
    q = questions_by_id[selected_id]
    stats[output_keys(q)] += 1
    by_domain[q[OUT_OBJ_DOMAIN_NAME]] += 1
    by_fluents[q[OUT_OBJ_FLUENT_TYPE]] += 1
print(by_domain, '\n')

print(by_fluents,'\n')


# Per (question category, fluent type): [true/false count, free-answer count].
count_by_cat = defaultdict(lambda: [0,0])
for k, v in sorted(stats.items(), key=lambda x: x[0][0]):
    # k[1] is the answer kind; 'free' goes to slot 1, 'true'/'false' to slot 0.
    slot = 1 if k[1] == 'free' else 0
    count_by_cat[(k[0],k[2])][slot] += v
    print(k, v)
count_by_cat

defaultdict(<class 'int'>, {'satellite': 224, 'goldminer': 197, 'npuzzle': 186, 'blocksworld': 204, 'visitall': 150, 'spanner': 224, 'mystery': 159, 'zenotravel': 202, 'logistics': 145, 'miconic': 191, 'depots': 144, 'grippers': 184, 'driverlog': 157}) 

defaultdict(<class 'int'>, {None: 574, 'static_fluents': 480, 'persistent_fluents': 480, 'derived_fluents': 405, 'base_fluents': 428}) 

('action_executability', 'free', None, 'None', 'TRANSPORTATION') 20
('action_executability', 'free', None, 'None', 'NON_TRANSPORTATION') 20
('action_executability', 'false', None, 'None', 'TRANSPORTATION') 20
('action_executability', 'true', None, 'None', 'NON_TRANSPORTATION') 20
('action_executability', 'true', None, 'None', 'TRANSPORTATION') 20
('action_executability', 'false', None, 'None', 'NON_TRANSPORTATION') 20
('effects', 'free', None, 'POS', 'NON_TRANSPORTATION') 20
('effects', 'true', 'derived_fluents', 'None', 'TRANSPORTATION') 20
('effects', 'true', 'derived_fluents', 'None', 'NON_TRANSPOR

defaultdict(<function __main__.<lambda>()>,
            {('action_executability', None): [80, 40],
             ('effects', None): [0, 40],
             ('effects', 'derived_fluents'): [80, 0],
             ('effects', 'persistent_fluents'): [80, 0],
             ('effects', 'static_fluents'): [80, 0],
             ('effects', 'base_fluents'): [80, 0],
             ('fluent_tracking', 'static_fluents'): [160, 80],
             ('fluent_tracking', 'persistent_fluents'): [160, 80],
             ('fluent_tracking', 'derived_fluents'): [157, 80],
             ('fluent_tracking', 'base_fluents'): [150, 80],
             ('hallucination', None): [80, 80],
             ('numerical_reasoning', None): [80, 40],
             ('object_tracking', 'persistent_fluents'): [160, 0],
             ('object_tracking', None): [0, 40],
             ('object_tracking', 'static_fluents'): [160, 0],
             ('object_tracking', 'base_fluents'): [118, 0],
             ('object_tracking', 'derived_fluents')

In [45]:
# Totals over all categories: every value of count_by_cat is a
# [true/false count, free-answer count] pair.
total = sum(n_tf + n_free for n_tf, n_free in count_by_cat.values())
tf = sum(n_tf for n_tf, n_free in count_by_cat.values())
free = sum(n_free for n_tf, n_free in count_by_cat.values())
print(total, tf, free)
# Share of true/false vs. free-answer questions in the selection.
print(tf/total, free/total)

2367 1768 599
0.7469370511195607 0.2530629488804394


In [46]:
# Write the selected question ids to the current working directory; the file
# name encodes the per-category cap and the plan length.
# NOTE(review): a later cell reads 'small_dataset_ids.20.jsonl' from DATA_PATH,
# which is a different name and location -- confirm how that file is produced.
save_jsonl(selected_ids, f'small_dataset_ids.{MAX_PER_CATEGORY}.pl-{selected_plan_length}.jsonl')

In [604]:
# total = 0
# count_by_cat = defaultdict(int)
# for k, v in sorted(stats.items(), key=lambda x: x[0][0]):
#     count_by_cat[k[0]] += 1
# count_by_cat

In [ ]:
# # things to consider:
# - balanced T/F sampling
# - balanced Q categories
# - take q + a with length + response <= 4000 tokens
# balanced fluent types
# balanced -+ fluents

# state tracking, no true false questions
# reduce object tracking qs
# boost action executability

# Train Data

In [48]:
# Ids held out for testing, read from a file under DATA_PATH.
# NOTE(review): the save cell above writes 'small_dataset_ids.<cap>.pl-<len>.jsonl'
# to the working directory, not this path -- confirm this is the intended file.
test_ids = open_jsonl(os.path.join(DATA_PATH, 'small_dataset_ids.20.jsonl'))

In [52]:
# Every question id that is not held out for testing becomes a training id.
# Membership is tested against a set: the original scanned the `test_ids` list
# once per question, which is a linear search repeated for every question.
# Assumes the ids are hashable (e.g. strings) -- TODO confirm what open_jsonl
# returns for this file.
test_id_set = set(test_ids)
train_ids = [q_id for q_id in tqdm(questions_by_id) if q_id not in test_id_set]

100%|██████████| 191102/191102 [00:19<00:00, 9912.70it/s] 


In [55]:
# Persist the training ids under DATA_PATH, in the same directory the test ids
# were read from.
save_jsonl(train_ids, os.path.join(DATA_PATH, 'small_dataset_ids.20.train.jsonl'))