In [1]:
%autoreload 2

In [2]:
from argparse import Namespace
from collections import defaultdict
import copy
from datetime import datetime
import difflib
import duckdb
from functools import reduce
import glob
import gzip
import itertools
import os
import pickle
import sys
import typing

from IPython.display import display, Markdown, HTML  # type: ignore
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import tatsu, tatsu.ast
import tqdm.notebook as tqdmn


sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src'))
sys.path.append(os.path.abspath('../reward-machine'))

import compile_predicate_statistics
import compile_predicate_statistics_split_args
from compile_predicate_statistics_split_args import *
from config import SPECIFIC_NAMED_OBJECTS_BY_ROOM
import config
from utils import get_object_assignments

import ast_printer
import ast_parser

2023-08-17 12:55:01 - ast_utils - DEBUG    - Using cache folder: /Users/guydavidson/tmp/game_generation_cache


In [3]:
# Directory holding the cached artifacts used throughout this notebook.
cache_dir = compile_predicate_statistics.get_project_dir() + '/reward-machine/caches'

# Predicate statistics table; pandas infers gzip compression from the .gz suffix.
no_strings_path = os.path.join(cache_dir, 'predicate_statistics_no_intervals_028b3733.pkl.gz')
no_strings_df = pd.read_pickle(no_strings_path)

# Opened through gzip although the filename has no .gz suffix -- presumably the
# file is gzip-compressed; verify. Later cells unpack this object as a mapping of
# trace_id -> (length, domain).
trace_lengths_path = os.path.join(cache_dir, 'trace_lengths_7511b7be.pkl')
with gzip.open(trace_lengths_path, 'rb') as trace_lengths_file:
    trace_lengths_and_domains = pickle.load(trace_lengths_file)


# regular_df = pd.read_pickle(os.path.join(cache_dir, 'predicate_statistics.pkl'))
# split_args_df = pd.read_pickle(os.path.join(cache_dir, 'predicate_statistics_028b3733.pkl.gz'))
# split_args_df = split_args_df[split_args_df['predicate'] != 'same_type']
# print(split_args_df.shape)
# # split_args_df = pd.read_pickle(os.path.join(cache_dir, 'predicate_statistics_4d5dd602.pkl.gz'))

# # stats = compile_predicate_statistics.CommonSensePredicateStatistics(cache_dir)
# split_args_stats = compile_predicate_statistics_split_args.CommonSensePredicateStatisticsSplitArgs(
#     # cache_dir, compile_predicate_statistics_split_args.CURRENT_TEST_TRACE_NAMES, overwrite=False
#     )

In [None]:
# NOTE(review): `split_args_df` is not created by any active cell (its loader is
# commented out in the loading cell), so this fails on a fresh kernel.
# Keep every predicate except 'same_type' and drop the string-encoded intervals.
is_not_same_type = split_args_df['predicate'] != 'same_type'
split_args_df_no_same_type = split_args_df[is_not_same_type].drop(columns=['string_intervals'])
split_args_df_no_same_type.shape

In [None]:
# Persist the filtered frame to the cache directory.
# NOTE(review): this is the same filename the commented-out loader for
# `split_args_df` reads from, so running this cell overwrites that cache file
# with the reduced frame -- confirm that is intended.
split_args_df_no_same_type.to_pickle(os.path.join(cache_dir, 'predicate_statistics_028b3733.pkl.gz'))

In [None]:
# Largest value in the 'trace_length' column, extracted as a scalar.
# NOTE(review): `split_args_stats` is only constructed in commented-out code, so
# this cell fails on a fresh kernel.
split_args_stats.trace_lengths_and_domains_df.select(pl.col('trace_length').max()).item()

In [None]:
# Scratch array: 10,000 uint8 flags with positions 100..999 set to 1.
positions = np.arange(10000)
a = ((positions >= 100) & (positions < 1000)).astype(np.uint8)


In [None]:
# Tiny example: ten uint8 flags with indices 1..3 set.
t = np.zeros(10, dtype=np.uint8)
t[[1, 2, 3]] = 1
b = t.tobytes()
# Indexing a bytes object yields ints, so the keys are the raw byte values
# found at positions 0 and 1 (here 0 and 1).
zero_byte, one_byte = b[0], b[1]
d = {zero_byte: '0', one_byte: '1'}

In [None]:
# Translate each byte of `b` through the lookup `d` and concatenate the results.
''.join(d[byte_value] for byte_value in b)

In [None]:
# Two-byte object holding the raw byte values 0 and 1.
b = b'\x00\x01'
b

In [None]:
# NOTE(review): `split_args_stats` is only constructed in commented-out code, so
# this cell fails on a fresh kernel.
# Longest trace length; used as the default width of the bit strings built below.
MAX_LENGTH = split_args_stats.trace_lengths_and_domains_df.select(pl.col('trace_length').max()).item()
print(MAX_LENGTH)
b = bytes([0, 1])
# Map the raw byte values 0 and 1 to the characters '0' and '1'.
BYTE_MAPPING = {byte_value: str(index) for index, byte_value in enumerate(b)}



def intervals_to_string(intervals, max_length: typing.Optional[int] = None):
    """Encode a collection of intervals as a fixed-width string of '0'/'1' characters.

    Args:
        intervals: iterable of indexable pairs; positions in the half-open range
            ``[interval[0], interval[1])`` are set to '1'. Slices reaching past
            ``max_length`` are clipped by numpy slicing.
        max_length: width of the returned string. When omitted, the module-level
            ``MAX_LENGTH`` is looked up at call time.

    Returns:
        A string of exactly ``max_length`` characters, each '0' or '1'.
    """
    if max_length is None:
        # Resolved at call time so a re-computed MAX_LENGTH is picked up without
        # having to re-define this function.
        max_length = MAX_LENGTH

    value = np.zeros(max_length, dtype=np.uint8)
    for interval in intervals:
        value[interval[0]:interval[1]] = 1

    # Shift 0/1 to the ASCII codes of '0'/'1' and decode in one vectorised step,
    # instead of mapping every byte through a Python-level lambda.
    return (value + ord('0')).tobytes().decode('ascii')


# Encode every row's `intervals` value as a '0'/'1' string using the default width.
# NOTE(review): `split_args_df` is not created by any active cell.
intervals = split_args_df.intervals.apply(intervals_to_string)

# small_split_args_df = small_split_args_df.assign()

In [None]:
# Attach the encoded strings as a `string_intervals` column. The frame is
# rebound under the same name; re-running only overwrites that column.
split_args_df = split_args_df.assign(string_intervals=intervals)
split_args_df.head()

In [None]:
# IF EXISTS keeps the cell runnable on a fresh kernel / empty database, where a
# plain DROP TABLE would raise because `data` does not exist yet.
duckdb.sql('DROP TABLE IF EXISTS data')
# DuckDB resolves `split_args_df` to the pandas DataFrame of that name in scope.
duckdb.sql('CREATE TABLE data AS SELECT * FROM split_args_df')
# Convert the '0'/'1' strings into DuckDB's native BITSTRING type.
duckdb.sql('ALTER TABLE data ALTER string_intervals TYPE BITSTRING')
duckdb.sql('DESCRIBE data').show()

# duckdb.sql("CREATE TYPE predicate AS ENUM (SELECT predicate FROM split_args_df);")
# duckdb.sql("CREATE TYPE domain AS ENUM (SELECT domain FROM split_args_df);")
# duckdb.sql("CREATE TYPE trace_id AS ENUM (SELECT trace_id FROM split_args_df);")
# all_types = tuple([t for t in set(split_args_df.arg_1_type.unique()) | set(split_args_df.arg_2_type.unique()) if isinstance(t, str) ])
# duckdb.sql(f"CREATE TYPE arg_type AS ENUM {all_types};")
# all_ids = tuple([t for t in set(split_args_df.arg_1_id.unique()) | set(split_args_df.arg_2_id.unique()) if isinstance(t, str)])
# duckdb.sql(f"CREATE TYPE arg_id AS ENUM {all_ids};")

# duckdb.sql('ALTER TABLE data ALTER predicate TYPE predicate')
# duckdb.sql('ALTER TABLE data ALTER domain TYPE domain')
# duckdb.sql('ALTER TABLE data ALTER trace_id TYPE trace_id')
# duckdb.sql('ALTER TABLE data ALTER arg_1_type TYPE arg_type')
# duckdb.sql('ALTER TABLE data ALTER arg_2_type TYPE arg_type')
# duckdb.sql('ALTER TABLE data ALTER arg_1_id TYPE arg_id')
# duckdb.sql('ALTER TABLE data ALTER arg_2_id TYPE arg_id')

In [None]:
# NOTE(review): `split_args_stats` is only constructed in commented-out code.
tld_df = split_args_stats.trace_lengths_and_domains_df
# IF EXISTS keeps the cell runnable when the table has not been created yet.
duckdb.sql('DROP TABLE IF EXISTS trace_lengths_domains')
duckdb.sql('CREATE TABLE trace_lengths_domains AS SELECT * FROM tld_df')
# NOTE(review): the `domain` and `trace_id` enum types must already exist; in
# this notebook they are only created in a later cell.
duckdb.sql('ALTER TABLE trace_lengths_domains ALTER domain TYPE domain')
duckdb.sql('ALTER TABLE trace_lengths_domains ALTER trace_id TYPE trace_id')
duckdb.sql('SELECT * FROM trace_lengths_domains').show()


In [None]:
# Probe: rows of `data` whose first argument type is exactly 'ball'.
duckdb.sql("SELECT arg_1_type FROM data WHERE arg_1_type='ball'")

In [None]:
# Probe: first-argument ids of 'dodgeball' rows, aliased to the variable-style
# column name '?b'.
duckdb.sql("SELECT arg_1_id as '?b' FROM data WHERE arg_1_type = 'dodgeball' ")

In [None]:
# Compile the DSL grammar. Files are read inside `with` blocks so the handles
# are closed deterministically instead of being left to the garbage collector.
DEFAULT_GRAMMAR_PATH = "../dsl/dsl.ebnf"
with open(DEFAULT_GRAMMAR_PATH) as grammar_file:
    grammar = grammar_file.read()
grammar_parser = typing.cast(tatsu.grammars.Grammar, tatsu.compile(grammar))

# Parse two example games from the reward-machine games directory.
with open(get_project_dir() + '/reward-machine/games/ball_to_bin_from_bed.txt') as game_file:
    game = game_file.read()
game_ast = grammar_parser.parse(game)  # type: ignore

with open(get_project_dir() + '/reward-machine/games/block_stacking.txt') as block_stacking_file:
    block_stacking_game = block_stacking_file.read()
block_stacking_game_ast = grammar_parser.parse(block_stacking_game)  # type: ignore

In [None]:
# test_pred_desk_and = block_stacking_game_ast[3][1]['preferences'][1]['definition']['pref_body']['body']['exists_args']['at_end_pred']
# ast_printer.ast_section_to_string(test_pred_desk_and, ast_parser.PREFERENCES)

# Take the preference at index 2 of the block-stacking game and pull out the
# predicate stored under its 'at_end_pred' key, then print it as DSL text.
block_stacking_preferences = block_stacking_game_ast[3][1]['preferences']
selected_preference = block_stacking_preferences[2]
test_pred_or = selected_preference['definition']['pref_body']['body']['exists_args']['at_end_pred']
print(ast_printer.ast_section_to_string(test_pred_or, ast_parser.PREFERENCES))

In [None]:
# Filter with the extracted predicate, mapping ?b to blocks and ?c to chairs.
# NOTE(review): `cspsd` is not defined in any visible cell -- presumably a
# statistics object from an earlier session; this fails on a fresh kernel.
cspsd.filter(test_pred_or, {"?b": ["block"], "?c": ["chair"]})

In [None]:
# Hand-written prototype SQL: temp_table_3 is the base relation and
# temp_table_1 / temp_table_2 are LEFT JOINed onto it; rows missing from a
# joined table contribute `empty_bitstring()` via COALESCE before the bitwise OR.
# NOTE(review): temp_table_1..3 and empty_bitstring() are not defined in any
# visible cell, and neither string is executed in this notebook.
q = """SELECT temp_table_3.trace_id, temp_table_3.domain, temp_table_3."top_drawer", temp_table_3."?g", 
(temp_table_3.string_intervals | COALESCE(temp_table_1.string_intervals, empty_bitstring()) | COALESCE(temp_table_2.string_intervals, empty_bitstring())) AS si,
FROM temp_table_3 
LEFT JOIN temp_table_1 ON temp_table_3.trace_id=temp_table_1.trace_id AND temp_table_1."top_drawer"=temp_table_3."top_drawer" AND temp_table_1."?g"=temp_table_3."?g" 
LEFT JOIN temp_table_2 ON temp_table_3.trace_id=temp_table_2.trace_id AND temp_table_2."top_drawer"=temp_table_3."top_drawer" 
"""

# temp_table_3.string_intervals as si3, COALESCE(temp_table_1.string_intervals, empty_bitstring()) as si1, COALESCE(temp_table_2.string_intervals, empty_bitstring()) AS si2,
# 
# NOTE(review): q2 appears to be the same query as q apart from trailing
# whitespace on its last line.
q2 = """SELECT temp_table_3.trace_id, temp_table_3.domain, temp_table_3."top_drawer", temp_table_3."?g", 
(temp_table_3.string_intervals | COALESCE(temp_table_1.string_intervals, empty_bitstring()) | COALESCE(temp_table_2.string_intervals, empty_bitstring())) AS si,
FROM temp_table_3 
LEFT JOIN temp_table_1 ON temp_table_3.trace_id=temp_table_1.trace_id AND temp_table_1."top_drawer"=temp_table_3."top_drawer" AND temp_table_1."?g"=temp_table_3."?g" 
LEFT JOIN temp_table_2 ON temp_table_3.trace_id=temp_table_2.trace_id AND temp_table_2."top_drawer"=temp_table_3."top_drawer"
"""

In [None]:
# Convert the pandas frame to a polars DataFrame.
# NOTE(review): `split_args_df` is not created by any active cell.
pldf = pl.from_dataframe(split_args_df)

In [None]:
# duckdb.sql("CREATE TYPE predicate AS ENUM (SELECT predicate FROM split_args_df);")
# duckdb.sql("CREATE TYPE domain AS ENUM (SELECT domain FROM split_args_df);")
# duckdb.sql("CREATE TYPE trace_id AS ENUM (SELECT trace_id FROM split_args_df);")
# all_types = tuple([t for t in set(split_args_df.arg_1_type.unique()) | set(split_args_df.arg_2_type.unique()) if isinstance(t, str) ])
# duckdb.sql(f"CREATE TYPE arg_type AS ENUM {all_types};")
# all_ids = tuple([t for t in set(split_args_df.arg_1_id.unique()) | set(split_args_df.arg_2_id.unique()) if isinstance(t, str)])
# duckdb.sql(f"CREATE TYPE arg_id AS ENUM {all_ids};")

# duckdb.sql('ALTER TABLE test ALTER predicate TYPE predicate')
# duckdb.sql('ALTER TABLE test ALTER domain TYPE domain')
# duckdb.sql('ALTER TABLE test ALTER trace_id TYPE trace_id')
# duckdb.sql('ALTER TABLE test ALTER arg_1_type TYPE arg_type')
# duckdb.sql('ALTER TABLE test ALTER arg_2_type TYPE arg_type')
# duckdb.sql('ALTER TABLE test ALTER arg_1_id TYPE arg_id')
# duckdb.sql('ALTER TABLE test ALTER arg_2_id TYPE arg_id')

In [None]:
# Show the column names and types of the `test` table.
# NOTE(review): no visible cell creates a table named `test`.
duckdb.sql('DESCRIBE test').show()

In [None]:
# Probe: number of bits set in the bitwise AND of two 'on' rows' string_intervals
# (the first such row and the one at offset 100).
# NOTE(review): neither subquery has an ORDER BY, so which rows are compared is
# not guaranteed; `test` is not created in any visible cell.
duckdb.sql("SELECT bit_count((SELECT string_intervals FROM test WHERE predicate='on' LIMIT 1) & (SELECT string_intervals FROM test WHERE predicate='on' OFFSET 100 LIMIT 1))")

In [None]:
# Probe: the `intervals` value of one 'on' row from `test`.
# NOTE(review): `test` is not created in any visible cell.
duckdb.sql("SELECT intervals FROM test WHERE predicate='on' LIMIT 1")

In [None]:
# Count rows of `data` in which either argument has the given type.
# The type name is interpolated directly into the SQL text, so it must not
# contain a single quote.
arg_type = 'basketball'
duckdb.sql(f"SELECT count(*) FROM data WHERE (arg_1_type='{arg_type}' OR arg_2_type='{arg_type}')").fetchone()[0]

In [None]:
# List every value of the `arg_id` enum type, one per row.
# NOTE(review): requires the `arg_id` enum, which is created in a later cell.
duckdb.sql("SELECT val FROM (SELECT unnest(enum_range(NULL::arg_id)) as val) ")  # WHERE val='Shelf|-02.97|+01.53|-01.72'

In [None]:
# Probe: element-wise AND of the 'bits_1' and 'bits_2' columns in polars.
# NOTE(review): `bits_df` is not defined in any visible cell.
bits_df.select(pl.col('bits_1') & pl.col('bits_2'))

In [None]:
# Collect every object type that has room-specific named entries, across rooms.
all_specific_names = {
    object_type
    for type_to_ids in SPECIFIC_NAMED_OBJECTS_BY_ROOM.values()
    for object_type in type_to_ids.keys()
}

# Rows in which either argument is one of those specifically named types.
# NOTE(review): `split_args_df` is not created by any active cell.
has_specific_name = split_args_df.arg_1_type.isin(all_specific_names) | split_args_df.arg_2_type.isin(all_specific_names)
specific_name_rows = split_args_df[has_specific_name].shape[0]
total_rows = split_args_df.shape[0]
print(f'{specific_name_rows} / {total_rows} ({specific_name_rows / total_rows * 100:.2f}%) rows have a specific name')

In [None]:
# Share of rows belonging to the 'same_type' predicate.
# Relies on `total_rows` computed by the preceding cell.
is_same_type = split_args_df.predicate == 'same_type'
same_type_rows = split_args_df[is_same_type].shape[0]
print(f'{same_type_rows} / {total_rows} ({same_type_rows / total_rows * 100:.2f}%) rows are for same_type')

In [None]:
# Peek at the first row only: print its index and a column -> value dict.
for i, row in itertools.islice(split_args_df.iterrows(), 1):
    print(i, row.to_dict())

In [None]:
# Rows for `adjacent` between the agent (first argument) and a green golfball
# (second argument).
split_args_df[(split_args_df.predicate == 'adjacent') & (split_args_df.arg_1_type == 'agent') & (split_args_df.arg_2_type == 'green_golfball')]

In [None]:
# All rows belonging to the 'same_type' predicate.
split_args_df[(split_args_df.predicate == 'same_type')]

# Experiment with moving everything to a database

In [None]:
# Distinct argument types, predicates and argument ids observed in the cached
# statistics. The isinstance filter drops non-string entries (presumably
# missing values -- verify).
all_types = {
    t for t in set(no_strings_df.arg_1_type.unique()) | set(no_strings_df.arg_2_type.unique())
    if isinstance(t, str)
}
all_predicates = {p for p in no_strings_df.predicate.unique() if isinstance(p, str)}
all_arg_ids = {
    arg_id for arg_id in set(no_strings_df.arg_1_id.unique()) | set(no_strings_df.arg_2_id.unique())
    if isinstance(arg_id, str)
}

# Add ids known from the configuration even if they never occur in the data.
all_arg_ids.update(config.UNITY_PSEUDO_OBJECTS.keys())
# chain.from_iterable flattens the per-type id collections in linear time and,
# unlike reduce(+) over lists, does not raise when there is nothing to flatten.
all_arg_ids.update(itertools.chain.from_iterable(
    object_ids
    for room_types in config.OBJECTS_BY_ROOM_AND_TYPE.values()
    for object_ids in room_types.values()
))

In [None]:
# Every object type named in either per-room configuration mapping.
# chain.from_iterable replaces reduce(+) over lists: linear instead of
# quadratic, and it does not raise on an empty configuration.
room_type_dicts = itertools.chain(
    config.OBJECTS_BY_ROOM_AND_TYPE.values(),
    config.SPECIFIC_NAMED_OBJECTS_BY_ROOM.values(),
)
config_types = set(itertools.chain.from_iterable(type_dict.keys() for type_dict in room_type_dicts))
# config_types.difference_update(config.META_TYPES.keys())
# remove() raises KeyError if the generic game-object type is absent (kept as
# in the original so a changed configuration is noticed).
config_types.remove(config.GAME_OBJECT)

# First element of each entry in the predicate/function table.
config_predicates = {entry[0] for entry in compile_predicate_statistics_split_args.COMMON_SENSE_PREDICATES_AND_FUNCTIONS}

# Every object id listed per room and type, plus the Unity pseudo-objects.
config_arg_ids = set(itertools.chain.from_iterable(
    object_ids
    for room_types in config.OBJECTS_BY_ROOM_AND_TYPE.values()
    for object_ids in room_types.values()
))
config_arg_ids.update(config.UNITY_PSEUDO_OBJECTS.keys())



In [None]:
# Recreate the enum types from scratch. IF EXISTS keeps the cell runnable on a
# fresh database, where a plain DROP TYPE would raise because the types do not
# exist yet; CASCADE is kept from the original statement.
for enum_name in ('domain', 'trace_id', 'predicate', 'arg_type', 'arg_id'):
    duckdb.sql(f'DROP TYPE IF EXISTS {enum_name} CASCADE;')

# Each enum's value list is the Python tuple repr interpolated into the SQL,
# so the values must not contain single quotes.
duckdb.sql(f"CREATE TYPE domain AS ENUM {tuple(config.ROOMS)};")
# Trace ids are the participant trace file names without their extension.
all_trace_ids = [os.path.splitext(os.path.basename(p))[0] for p in glob.glob('../reward-machine/traces/participant-traces/*.json')]
duckdb.sql(f"CREATE TYPE trace_id AS ENUM {tuple(all_trace_ids)};")
duckdb.sql(f"CREATE TYPE predicate AS ENUM {tuple(sorted(config_predicates))};")
duckdb.sql(f"CREATE TYPE arg_type AS ENUM {tuple(sorted(config_types))};")
duckdb.sql(f"CREATE TYPE arg_id AS ENUM {tuple(sorted(config_arg_ids))};")

# Trace lengths and domains

In [None]:
# (Re)create the per-trace table of domain and length.
duckdb.sql('DROP TABLE IF EXISTS trace_length_and_domains;')
duckdb.sql('CREATE TABLE trace_length_and_domains(trace_id trace_id, domain domain, length INTEGER);')

# The cached mapping is trace_id -> (length, domain); reorder each entry to the
# table's column order (trace_id, domain, length).
trace_length_and_domain_rows = []
for trace_id, (length, domain) in trace_lengths_and_domains.items():
    trace_length_and_domain_rows.append((trace_id, domain, length))

# The repr of a tuple of row tuples, minus its outer parentheses, is used as
# the VALUES list.
values_clause = str(tuple(trace_length_and_domain_rows))[1:-1]
duckdb.sql(f'INSERT INTO trace_length_and_domains VALUES {values_clause}')

# One all-zero bitstring per trace, as wide as that trace's length.
duckdb.sql('DROP TABLE IF EXISTS empty_bitstrings;')
duckdb.sql("CREATE TABLE empty_bitstrings AS (SELECT trace_id, BITSTRING('0', length) as intervals FROM trace_length_and_domains)")

# Meta-types

In [None]:
# (Re)create the table relating each meta-type to its concrete sub-types.
duckdb.sql('DROP TABLE IF EXISTS meta_types;')
duckdb.sql('CREATE TABLE meta_types(meta_type arg_type, type arg_type);')

# Flatten the config mapping into (meta_type, sub_type) pairs, then insert one
# row per pair.
meta_type_pairs = [
    (meta_type, sub_type)
    for meta_type, meta_type_sub_types in config.META_TYPES.items()
    for sub_type in meta_type_sub_types
]
for meta_type, sub_type in meta_type_pairs:
    duckdb.sql(f"INSERT INTO meta_types VALUES ('{meta_type}', '{sub_type}');")

# Object assignments

In [None]:
# (Re)create the lookup from (domain, object type) to object id.
duckdb.sql('DROP TABLE IF EXISTS object_type_to_id;')
duckdb.sql('CREATE TABLE object_type_to_id(domain domain, type arg_type, object_id arg_id);')

# One row per object id, drawn from both per-room mappings and restricted to
# the types collected in `config_types`.
data_rows = [
    (domain, object_type, object_id)
    for domain in config.ROOMS
    for object_dict in (config.OBJECTS_BY_ROOM_AND_TYPE[domain], config.SPECIFIC_NAMED_OBJECTS_BY_ROOM[domain])
    for object_type, object_ids in object_dict.items()
    if object_type in config_types
    for object_id in object_ids
]

# The repr of a tuple of row tuples, minus its outer parentheses, is used as
# the VALUES list.
values_clause = str(tuple(data_rows))[1:-1]
duckdb.sql(f'INSERT INTO object_type_to_id VALUES {values_clause}')


In [None]:
# Probe: result of OR-ing a bitstring with NULL -- presumably checking whether
# NULL propagates through the bitwise OR.
duckdb.sql('SELECT intervals | NULL from empty_bitstrings LIMIT 1')

In [None]:
# import cachetools
# import operator

# from config import GAME_OBJECT, GAME_OBJECT_EXCLUDED_TYPES, META_TYPES

# from ast_parser import PREFERENCES
# from ast_printer import ast_section_to_string


# DEBUG = True

# class PredicateNotImplementedException(Exception):
#     pass


# class MissingVariableException(Exception):
#     pass


# class CommonSensePredicateStatisticsFullDatabse():

#     def __init__(self): 
#         self.predicates = config_predicates
#         self.cache = cachetools.LRUCache(maxsize=10000)
#         self.temp_table_index = -1
#         self.temp_table_prefix = 't'

#     def _table_name(self, index: int):
#         return f"{self.temp_table_prefix}{index}"

#     def _next_temp_table_index(self):
#         self.temp_table_index += 1
#         return self.temp_table_index

#     def _next_temp_table_name(self):
#         return self._table_name(self._next_temp_table_index())

#     def filter(self, predicate: tatsu.ast.AST, mapping: typing.Dict[str, typing.Union[str, typing.List[str]]], **kwargs):
#         try:
#             self.temp_table_index = -1
#             result_query, _ = self._inner_filter(predicate, mapping, **kwargs)
#             return result_query
        
#         except PredicateNotImplementedException as e:
#             # Pass the exception through and let the caller handle it
#             raise e

#     def _predicate_and_mapping_cache_key(self, predicate: tatsu.ast.AST, mapping: typing.Dict[str, typing.Union[str, typing.List[str]]], *args, **kwargs) -> str:
#         '''
#         Returns a string that uniquely identifies the predicate and mapping
#         '''
#         return ast_section_to_string(predicate, PREFERENCES) + "_" + str(mapping)

#     @cachetools.cachedmethod(operator.attrgetter('cache'), key=_predicate_and_mapping_cache_key)
#     def _handle_predicate(self, predicate: tatsu.ast.AST, mapping: typing.Dict[str, typing.Union[str, typing.List[str]]], return_trace_ids: bool = False, **kwargs) -> typing.Tuple[str, typing.Set[str]]:
#         predicate_name = extract_predicate_function_name(predicate)  # type: ignore

#         if predicate_name not in self.predicates:
#             raise PredicateNotImplementedException(predicate_name)

#         variables = extract_variables(predicate)  # type: ignore
#         used_variables = set(variables)

#         # Restrict the mapping to just the referenced variables and expand meta-types
#         relevant_arg_mapping = {}
#         for var in variables:
#             if var in mapping:
#                 relevant_arg_mapping[var] = sum([META_TYPES.get(arg_type, [arg_type]) for arg_type in mapping[var]], [])

#             # This handles variables which are referenced directly, like the desk and bed
#             elif not var.startswith("?"):
#                 relevant_arg_mapping[var] = [var]

#             else:
#                 raise MissingVariableException(f"Variable {var} is not in the mapping")

#         select_items = ["trace_id", "domain", "string_intervals"]
#         where_items = [f"predicate='{predicate_name}'"]

#         for i, (arg_var, arg_types) in enumerate(relevant_arg_mapping.items()):
#             # if it can be the generic object type, we filter for it specifically
#             if GAME_OBJECT in arg_types:
#                 where_items.append(f"arg_{i + 1}_type NOT IN {tuple(GAME_OBJECT_EXCLUDED_TYPES)}")

#             else:
#                 if len(arg_types) == 1:
#                     where_items.append(f"arg_{i + 1}_type='{arg_types[0]}'")
#                 else:
#                     where_items.append(f"arg_{i + 1}_type IN {tuple(arg_types)}")

#             select_items.append(f"arg_{i + 1}_id as '{arg_var}'")

#         query = f"SELECT {select_items} FROM data WHERE {' AND '.join(where_items)};"
#         return query, used_variables

#     @cachetools.cachedmethod(operator.attrgetter('cache'), key=_predicate_and_mapping_cache_key)
#     def _handle_and(self, predicate: tatsu.ast.AST, mapping: typing.Dict[str, typing.Union[str, typing.List[str]]], **kwargs) -> typing.Tuple[str, typing.Set[str]]:
#         and_args = predicate["and_args"]
#         if not isinstance(and_args, list):
#             and_args = [and_args]

#         sub_queries = []
#         used_variables_by_child = []
#         all_used_variables = set()

#         for and_arg in and_args:  # type: ignore
#             try:
#                 sub_query, sub_used_variables = self._inner_filter(and_arg, mapping)  # type: ignore
#                 sub_queries.append(sub_query)
#                 used_variables_by_child.append(sub_used_variables)
#                 all_used_variables |= sub_used_variables

#             except PredicateNotImplementedException as e:
#                 continue

#         if len(sub_queries) == 0:
#             raise PredicateNotImplementedException("All sub-predicates of the and were not implemented")

#         if len(sub_queries) == 1:
#             return sub_queries[0], used_variables_by_child[0]
        
#         subquery_table_names = [self._next_temp_table_name() for _ in range(len(sub_queries))]

#         select_items = [f"{subquery_table_names[0]}.trace_id", f"{subquery_table_names[0]}.domain"]
#         selected_variables = set()
#         intervals = []
#         join_clauses = []

#         for i, (sub_query, table_name, sub_used_variables) in enumerate(zip(sub_queries, subquery_table_names, used_variables_by_child)):
#             intervals.append(f"{sub_query}.intervals")

#             for variable in sub_used_variables:
#                 if variable not in selected_variables:
#                     select_items.append(f'{sub_query}."{variable}"')
#                     selected_variables.add(variable)

#             if i > 0:
#                 join_parts = [f"INNER JOIN ({sub_query}) AS {table_name} ON ({subquery_table_names[0]}.trace_id={table_name}.trace_id)"]

#                 for j, (prev_table_name, prev_used_variables) in enumerate(zip(subquery_table_names[:i], used_variables_by_child[:i])):
#                     shared_variables = sub_used_variables & prev_used_variables
#                     join_parts.extend([f'({table_name}."{v}"={prev_table_name}."{v}")' for v in shared_variables])

#                 join_clauses.append(" AND ".join(join_parts))


#         select_items.append(f'({" & ".join(intervals)}) AS intervals')

#         table_name = self._next_temp_table_name()
#         query = f"WITH TABLE {table_name} AS (SELECT {', '.join(select_items)} FROM {subquery_table_names[0]} {' '.join(join_clauses)}) SELECT * FROM {table_name} WHERE bit_count(intervals) != 0;"
#         if DEBUG: print(query)
#         return query, all_used_variables


#     @cachetools.cachedmethod(operator.attrgetter('cache'), key=_predicate_and_mapping_cache_key)
#     def _handle_or(self, predicate: tatsu.ast.AST, mapping: typing.Dict[str, typing.Union[str, typing.List[str]]], **kwargs) -> typing.Tuple[str, typing.Set[str]]:
#         or_args = predicate["or_args"]
#         if not isinstance(or_args, list):
#             or_args = [or_args]

#         sub_queries = []
#         used_variables_by_child = []
#         all_used_variables = set()

#         for or_arg in or_args:  # type: ignore
#             try:
#                 subquery, sub_used_variables = self._inner_filter(or_arg, mapping)  # type: ignore
#                 sub_queries.append(subquery)
#                 used_variables_by_child.append(sub_used_variables)
#                 all_used_variables |= sub_used_variables

#             except PredicateNotImplementedException as e:
#                 continue

#         if len(sub_queries) == 0:
#             raise PredicateNotImplementedException("All sub-predicates of the por were not implemented")

#         if len(sub_queries) == 1:
#             return sub_queries[0], used_variables_by_child[0]

#         sub_queries.insert(0, self._build_potential_missing_values_query(mapping, list(all_used_variables)))
#         used_variables_by_child.insert(0, all_used_variables)

#         subquery_table_names = [self._next_temp_table_name() for _ in range(len(sub_queries))]

#         select_items = [f"{subquery_table_names[0]}.trace_id", f"{subquery_table_names[0]}.domain"]
#         selected_variables = set()
#         intervals = []
#         join_clauses = []

#         for i, (subquery, sub_table_name, sub_used_variables) in enumerate(zip(sub_queries, subquery_table_names, used_variables_by_child)):
#             intervals.append(f"{sub_table_name}.intervals")

#             for variable in sub_used_variables:
#                 if variable not in selected_variables:
#                     select_items.append(f'{sub_table_name}."{variable}"')
#                     selected_variables.add(variable)

#             if i > 0:
#                 join_parts = [f"LEFT JOIN {sub_table_name} ON ({subquery_table_names[0]}.trace_id={sub_table_name}.trace_id)"]

#                 shared_variables = sub_used_variables & all_used_variables
#                 join_parts.extend([f'({subquery}."{v}"={subquery_table_names[0]}."{v}")' for v in shared_variables])

#                 join_clauses.append(" AND ".join(join_parts))

#         intervals_coalesce = [f"COALESCE({intervals_select}, {intervals[0]})" if i > 0 else intervals_select for i, intervals_select in enumerate(intervals)]
#         select_items.append(f'({" | ".join(intervals_coalesce)}) AS intervals')

#         table_name = self._next_temp_table_name()
#         query = f"WITH TABLE {table_name} AS (SELECT {', '.join(select_items)} FROM {subquery_table_names[0]} {' '.join(join_clauses)}) SELECT * FROM {table_name} WHERE bit_count(intervals) != 0;"
#         if DEBUG: print(query)
#         return query, all_used_variables

#     def _build_object_assignment_cte(self, object_types: typing.Union[str, typing.List[str]]):
#         if isinstance(object_types, str) or len(object_types) == 1:
#             where_clause = f"type = '{object_types[0]}'"
#         else:
#             where_clause = f"type IN IN {tuple(object_types)}"
#         return f"SELECT domain, object_id FROM object_type_to_id WHERE {where_clause}"

#     def object_assignments_query(self, mapping: typing.Dict[str, typing.Union[str, typing.List[str]]]):
#         if len(mapping) == 0:
#             return []
        
#         if len(mapping) == 1:
#             query = self._build_object_assignment_cte(mapping[list(mapping.keys())[0]])

#         else:
#             object_id_selects = []
#             ctes = []
#             join_statements = []
#             for i, (var, var_types) in enumerate(mapping.items()):
#                 ctes.append(f"t{i} AS ({self._build_object_assignment_cte(var_types)})")
#                 object_id_selects.append(f"t{i}.object_id AS '{var}'")
#                 if i > 0:
#                     join_clauses = []
#                     join_clauses.append(f"(t0.domain = t{i}.domain)")
#                     for j in range(i):
#                         join_clauses.append(f"(t{j}.object_id != t{i}.object_id)")

#                     join_statements.append(f"JOIN t{i} ON {' AND '.join(join_clauses)}")

#             query = f"""WITH {', '.join(ctes)}
# SELECT t0.domain, {', '.join(object_id_selects)} FROM t0
# {' '.join(join_statements)}
# """

#         return query

#     def _build_potential_missing_values_query(self, mapping: typing.Dict[str, typing.Union[str, typing.List[str]]], relevant_vars: typing.List[str]):
#         # For each trace ID, and each assignment of the vars that exist in the sub_predicate_df so far:
#         relevant_var_mapping = {var: mapping[var] if var.startswith("?") else [var] for var in relevant_vars}
        
#         object_assignments_query = self.object_assignments_query(relevant_var_mapping)

#         select_variables = ', '.join(f'object_assignments."{var}" as "{var}"' for var in relevant_vars)
#         query = f"""SELECT trace_length_and_domains.trace_id as trace_id, trace_length_and_domains.domain as domain, {select_variables}, empty_bitstrings.intervals as intervals
#         FROM trace_length_and_domains
#         JOIN ({object_assignments_query}) AS object_assignments ON (trace_length_and_domains.domain = object_assignments.domain))
#         JOIN empty_bitstrings ON (trace_length_and_domains.trace_id = empty_bitstrings.trace_id)
#         """
#         if DEBUG: print(query)

#         return query

#     @cachetools.cachedmethod(operator.attrgetter('cache'), key=_predicate_and_mapping_cache_key)
#     def _handle_not(self, predicate: tatsu.ast.AST, mapping: typing.Dict[str, typing.Union[str, typing.List[str]]], **kwargs) -> typing.Tuple[str, typing.Set[str]]:
#         try:
#             inner_query, used_variables = self._inner_filter(predicate["not_args"], mapping)  # type: ignore
#         except PredicateNotImplementedException as e:
#             raise PredicateNotImplementedException(f"Sub-predicate of the not ({e.args}) was not implemented")


#         relevant_vars = list(used_variables)
#         potential_missing_values_query = self._build_potential_missing_values_query(mapping, relevant_vars)
#         potential_missing_values_table_name = self._next_temp_table_name()
#         inner_table_name = self._next_temp_table_name()

#         # Now, for each possible combination of args on each trace / domain, 'intervals' will contain the truth intervals if
#         # they exist and null otherwise, and 'intervals_right' will contain the empty interval'
#         join_columns = ["trace_id"] + relevant_vars

#         select_items = [f"{potential_missing_values_table_name}.trace_id as trace_id", f"{potential_missing_values_table_name}.domain as domain"]
#         select_items.extend(f'{potential_missing_values_table_name}."{var}" as "{var}"' for var in relevant_vars)
#         select_items.append(f"(~( {potential_missing_values_table_name}.intervals | COALESCE({inner_table_name}.intervals, {potential_missing_values_table_name}.intervals) )) AS intervals")

#         join_items = [f'{potential_missing_values_table_name}."{column}"={inner_table_name}."{column}"'  for column in join_columns]

#         not_query = f"""WITH {potential_missing_values_table_name} AS ({potential_missing_values_query}), {inner_table_name} AS ({inner_query})
#         SELECT {', '.join(select_items)} FROM {potential_missing_values_table_name} LEFT JOIN {inner_table_name} ON {' AND '.join(join_items)};
#         """

#         table_name = self._next_temp_table_name()
#         query = f"WITH {table_name} AS ({not_query}) SELECT * FROM {table_name} WHERE bit_count(intervals) != 0;"
#         if DEBUG: print(query)
#         return query, used_variables


#     def _inner_filter(self, predicate: tatsu.ast.AST, mapping: typing.Dict[str, typing.Union[str, typing.List[str]]], **kwargs) -> typing.Tuple[str, typing.Set[str]]:
#         '''
#         Filters the data by the given predicate and mapping, returning a list of intervals in which the predicate is true
#         for each processed trace

#         Returns a dictionary mapping from the trace ID to a dictionary that maps from the set of arguments to a list of
#         intervals in which the predicate is true for that set of arguments
#         '''

#         predicate_rule = predicate.parseinfo.rule  # type: ignore

#         if predicate_rule == "predicate":
#             return self._handle_predicate(predicate, mapping, **kwargs)

#         elif predicate_rule == "super_predicate":
#             return self._inner_filter(predicate["pred"], mapping, **kwargs)  # type: ignore

#         elif predicate_rule == "super_predicate_and":
#             return self._handle_and(predicate, mapping, **kwargs)

#         elif predicate_rule == "super_predicate_or":
#             return self._handle_or(predicate, mapping, **kwargs)

#         elif predicate_rule == "super_predicate_not":
#             return self._handle_not(predicate, mapping, **kwargs)

#         elif predicate_rule in ["super_predicate_exists", "super_predicate_forall", "function_comparison"]:
#             raise PredicateNotImplementedException(predicate_rule)

#         else:
#             raise ValueError(f"Error: Unknown rule '{predicate_rule}'")

In [4]:
# Imported here rather than with the other imports at the top of the notebook.
import compile_predicate_statistics_full_database

# Build the database-backed statistics object for the '028b3733' trace-names hash.
# Per the log captured from this cell's run, construction creates a DuckDB table
# and loads the data (1707407 rows in that run).
stats = compile_predicate_statistics_full_database.CommonSensePredicateStatisticsFullDatabase(force_trace_names_hash='028b3733')

2023-08-17 12:55:28 - compile_predicate_statistics_full_database - INFO     - Creating DuckDB table...
2023-08-17 12:55:30 - compile_predicate_statistics_full_database - INFO     - Loaded data, found 1707407 rows


In [7]:
# Set DuckDB's temp_directory setting to /tmp/duckdb.
# NOTE(review): hardcoded POSIX path; adjust when running on another platform.
duckdb.sql("set temp_directory='/tmp/duckdb'")

In [None]:
# Preview the `data` table.
duckdb.sql('SELECT * from data')

In [5]:
# Make the cell re-runnable: remove marker rows from a previous run so the
# INSERTs below never produce duplicates.
duckdb.sql("DELETE FROM data WHERE predicate IN ('game_start', 'game_end')")

# DuckDB indexes bits from the left (index 0 is the leftmost bit), and
# bitstring('1', length) LEFT-pads with zeros, i.e. it sets only the LAST bit.
# The original statements therefore marked the last position for 'game_start'
# and the first for 'game_end'. The positions are written explicitly here.
# NOTE(review): assumes bit 0 is the first state of a trace, as in the
# intervals-to-string encoding earlier in this notebook -- confirm for `data`.

# game_start: only the first position of each trace is set.
duckdb.sql("INSERT INTO data (predicate, trace_id, domain, intervals) SELECT 'game_start' as predicate, trace_id, domain, set_bit(bitstring('0', length), 0, 1) as intervals FROM trace_length_and_domains")
# game_end: only the last position of each trace is set.
duckdb.sql("INSERT INTO data (predicate, trace_id, domain, intervals) SELECT 'game_end' as predicate, trace_id, domain, set_bit(bitstring('0', length), length - 1, 1) as intervals FROM trace_length_and_domains")

In [6]:
# This cell repeats the 'game_end' insert of the preceding cell; running both
# used to store every 'game_end' row twice. The NOT EXISTS guard turns it into
# a no-op when such rows are already present. When it does insert, only the
# last position of each trace is set (DuckDB bit index 0 is the leftmost bit).
# NOTE(review): assumes the leftmost bit is the first state of a trace -- confirm.
duckdb.sql("INSERT INTO data (predicate, trace_id, domain, intervals) SELECT 'game_end' as predicate, trace_id, domain, set_bit(bitstring('0', length), length - 1, 1) as intervals FROM trace_length_and_domains WHERE NOT EXISTS (SELECT 1 FROM data WHERE predicate = 'game_end')")