In [1]:
%autoreload 2

In [2]:
from argparse import Namespace
from collections import defaultdict
import copy
from datetime import datetime
import difflib
import duckdb
from functools import reduce
import glob
import gzip
import itertools
import os
import pickle
import sys
import typing

from IPython.display import display, Markdown, HTML  # type: ignore
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import tatsu, tatsu.ast
import tqdm.notebook as tqdmn


sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src'))
sys.path.append(os.path.abspath('../reward-machine'))

import compile_predicate_statistics
import compile_predicate_statistics_split_args
from compile_predicate_statistics_split_args import *
from config import SPECIFIC_NAMED_OBJECTS_BY_ROOM
import config
from utils import get_object_assignments

import ast_printer
import ast_parser

2023-10-01 00:13:48 - ast_utils - DEBUG    - Using cache folder: /Users/guydavidson/tmp/game_generation_cache


In [None]:
# Load the cached predicate-statistics artifacts; all files below share the `028b3733` suffix.
cache_dir = compile_predicate_statistics.get_project_dir() + '/reward-machine/caches'

# Bitstring-interval table. Later cells overwrite this same file via `bitstings_df_path` (sic).
bitstings_df_path = os.path.join(cache_dir, 'predicate_statistics_bitstring_intervals_028b3733.pkl.gz')
bitstrings_df = pd.read_pickle(bitstings_df_path)
print(bitstrings_df.shape)
# Dict used later as trace_id -> (trace_length, domain).
# NOTE(review): the filename ends in `.pkl` (not `.pkl.gz`) but is opened with gzip -- confirm the
# file really is gzip-compressed. pickle.load can execute arbitrary code; only load trusted caches.
with gzip.open(os.path.join(cache_dir, 'trace_lengths_028b3733.pkl'), 'rb') as f:
    trace_lengths_and_domains = pickle.load(f)


# Predicate statistics with raw interval lists in the `intervals` column.
# NOTE(review): `split_args_stats` is only constructed in the commented-out lines below, yet later
# cells still reference it -- they fail on a fresh kernel.
# regular_df = pd.read_pickle(os.path.join(cache_dir, 'predicate_statistics.pkl'))
split_args_df = pd.read_pickle(os.path.join(cache_dir, 'predicate_statistics_028b3733.pkl.gz'))
# split_args_df = split_args_df[split_args_df['predicate'] != 'same_type']
# print(split_args_df.shape)
# # split_args_df = pd.read_pickle(os.path.join(cache_dir, 'predicate_statistics_4d5dd602.pkl.gz'))

# # stats = compile_predicate_statistics.CommonSensePredicateStatistics(cache_dir)
# split_args_stats = compile_predicate_statistics_split_args.CommonSensePredicateStatisticsSplitArgs(
#     # cache_dir, compile_predicate_statistics_split_args.CURRENT_TEST_TRACE_NAMES, overwrite=False
#     )

In [None]:
# drop_index = bitstrings_df[bitstrings_df.arg_1_type.isin(config.ON_EXCLUDED_OBJECT_TYPES) & (bitstrings_df.predicate == 'on')].index
# filtered_bitstrings_df = bitstrings_df.drop(drop_index)
# filtered_bitstrings_df = filtered_bitstrings_df.reset_index(drop=True)

# print(filtered_bitstrings_df.shape, bitstrings_df.shape, filtered_bitstrings_df.index.max())

# filtered_bitstrings_df.to_pickle(bitstings_df_path)

In [None]:
# Inspect the loaded bitstring-interval table.
bitstrings_df

In [None]:
# game_object_excluded_types = set(config.GAME_OBJECT_EXCLUDED_TYPES)
# bitstrings_df = bitstrings_df.assign(arg_1_is_game_object=~bitstrings_df.arg_1_type.isin(game_object_excluded_types), arg_2_is_game_object=~bitstrings_df.arg_2_type.isin(game_object_excluded_types))
# bitstrings_df.to_pickle(bitstings_df_path)

In [None]:

# Add boolean flags for argument types whose name contains 'block' (the SQL queries below filter on
# `arg_2_is_block IS TRUE`), then overwrite the cached pickle with the augmented table.
# `na=False`: a missing type would otherwise produce NaN, leaving an object column mixing
# True/False/NaN instead of a clean bool column. `regex=False`: 'block' is a literal substring.
bitstrings_df = bitstrings_df.assign(
    arg_1_is_block=bitstrings_df.arg_1_type.str.contains('block', regex=False, na=False),
    arg_2_is_block=bitstrings_df.arg_2_type.str.contains('block', regex=False, na=False),
)
bitstrings_df.to_pickle(bitstings_df_path)


In [None]:
b = bytes([0, 1])
# Byte value (0 or 1) -> its character; kept as a module-level name for cells that use it directly.
BYTE_MAPPING = {byte: str(i) for i, byte in enumerate(b)}


def row_to_string_intervals(row):
    """Render ``row['intervals']`` as a '0'/'1' string of length ``row['trace_length']``.

    Each ``interval`` sets positions ``interval[0]:interval[1]`` (numpy slice semantics:
    end-exclusive, silently clipped to the trace length) to '1'; every other position is '0'.

    ``trace_length`` is passed through ``int`` so a float-typed length (e.g. after a pandas
    join that introduced missing values elsewhere in the column) is still accepted.
    """
    value = np.zeros(int(row['trace_length']), dtype=np.uint8)
    for interval in row['intervals']:
        value[interval[0]:interval[1]] = 1

    # Shift 0/1 to the ASCII codes of '0'/'1' and decode the whole buffer at once,
    # rather than looking up every byte in a dict from Python.
    return (value + np.uint8(ord('0'))).tobytes().decode('ascii')


def create_bitstings_df(df, trace_lengths_and_domains_dict, output_path):
    """Replace interval lists in ``df`` with bitstrings and pickle the result.

    Args:
        df: frame with at least ``trace_id`` and ``intervals`` columns.
        trace_lengths_and_domains_dict: maps trace_id -> (trace_length, domain).
        output_path: destination passed to ``DataFrame.to_pickle``.

    Returns:
        The frame that was pickled (``intervals`` as bitstrings, no ``trace_length`` column).

    Raises:
        ValueError: if a ``trace_id`` in ``df`` has no entry in ``trace_lengths_and_domains_dict``.
    """
    trace_lengths_and_domains_df = pd.DataFrame(
        [(trace_id, *length_and_domain) for trace_id, length_and_domain in trace_lengths_and_domains_dict.items()],
        columns=['trace_id', 'trace_length', 'domain'],
    )

    with_trace_length_df = df.join(
        trace_lengths_and_domains_df.drop(columns=['domain']).set_index('trace_id'), on='trace_id')

    # Fail with a clear message instead of a cryptic numpy error on a NaN length.
    missing = with_trace_length_df['trace_length'].isna()
    if missing.any():
        missing_ids = sorted(with_trace_length_df.loc[missing, 'trace_id'].unique())
        raise ValueError(f'No trace length found for trace ids: {missing_ids}')

    # A list comprehension (unlike DataFrame.apply(axis=1)) also behaves on an empty frame.
    bitstrings = [
        row_to_string_intervals(row)
        for row in with_trace_length_df[['trace_length', 'intervals']].to_dict('records')
    ]
    result_df = with_trace_length_df.assign(intervals=bitstrings).drop(columns=['trace_length'])
    result_df.to_pickle(output_path)
    return result_df
    

In [None]:
# Flatten {trace_id: (trace_length, domain)} into one row per trace.
trace_lengths_and_domains_rows = [
    (trace_id, *length_and_domain)
    for trace_id, length_and_domain in trace_lengths_and_domains.items()
]
trace_lengths_and_domains_df = pd.DataFrame(
    trace_lengths_and_domains_rows,
    columns=['trace_id', 'trace_length', 'domain'],
)

# Attach each trace's length to the predicate rows (domain is not needed here).
split_args_with_trace_length_df = split_args_df.join(
    trace_lengths_and_domains_df[['trace_id', 'trace_length']].set_index('trace_id'),
    on='trace_id',
)

In [None]:
b = bytes([0, 1])
# Byte value (0 or 1) -> its character; kept as a module-level name for cells that use it directly.
BYTE_MAPPING = {byte: str(i) for i, byte in enumerate(b)}


def row_to_string_intervals(row):
    """Render ``row['intervals']`` as a '0'/'1' string of length ``row['trace_length']``.

    Each ``interval`` sets positions ``interval[0]:interval[1]`` (numpy slice semantics:
    end-exclusive, silently clipped to the trace length) to '1'; every other position is '0'.
    ``trace_length`` is passed through ``int`` so a float-typed length is still accepted.
    """
    value = np.zeros(int(row['trace_length']), dtype=np.uint8)
    for interval in row['intervals']:
        value[interval[0]:interval[1]] = 1

    # Shift 0/1 to the ASCII codes of '0'/'1' and decode the whole buffer in one step.
    return (value + np.uint8(ord('0'))).tobytes().decode('ascii')




In [None]:
# Replace each row's interval list with its full-length '0'/'1' bitstring.
split_args_with_string_intervals_df = split_args_with_trace_length_df.assign(
    intervals=lambda frame: frame.apply(row_to_string_intervals, axis=1),
)
split_args_with_string_intervals_df

In [None]:
# Overwrite the bitstring cache file at `bitstings_df_path` with the freshly computed table.
split_args_with_string_intervals_df.to_pickle(bitstings_df_path)

In [None]:
# Spot-check: the arg_1_type entries that are exactly 'ball'.
split_args_with_string_intervals_df.loc[
    (split_args_with_string_intervals_df.arg_1_type == 'ball'), "arg_1_type"]

In [None]:
# Re-type the object ids listed in OBJECT_ID_TYPE_REMAP. Each id must mention 'window' or 'shelf'
# (case-insensitive; 'shelf' takes precedence if both appear). Only rows still carrying that generic
# type are rewritten, in both argument slots.
for arg_id, arg_type in compile_predicate_statistics_split_args.OBJECT_ID_TYPE_REMAP.items():
    original_type = 'shelf' if 'shelf' in arg_id.lower() else 'window' if 'window' in arg_id.lower() else None

    if original_type is None:
        raise ValueError(f'Could not find original type for {arg_id}')

    split_args_with_string_intervals_df.loc[
        (split_args_with_string_intervals_df.arg_1_id == arg_id)
        & (split_args_with_string_intervals_df.arg_1_type == original_type),
        "arg_1_type",
    ] = arg_type

    split_args_with_string_intervals_df.loc[
        (split_args_with_string_intervals_df.arg_2_id == arg_id)
        & (split_args_with_string_intervals_df.arg_2_type == original_type),
        "arg_2_type",
    ] = arg_type

In [None]:
from itertools import chain
from config import OBJECTS_BY_ROOM_AND_TYPE, SPECIFIC_NAMED_OBJECTS_BY_ROOM, META_TYPES, GAME_OBJECT, BUILDING

# Types that actually appear in either argument slot of the predicate table.
all_df_types = set(split_args_with_string_intervals_df.arg_1_type.unique()) | set(split_args_with_string_intervals_df.arg_2_type.unique())

# Every key of every per-room mapping in OBJECTS_BY_ROOM_AND_TYPE and SPECIFIC_NAMED_OBJECTS_BY_ROOM.
# chain.from_iterable replaces reduce(lambda x, y: x + y, ...), which concatenated lists
# quadratically and raised TypeError when there were no mappings at all.
computed_types = set(chain.from_iterable(
    room_mapping.keys()
    for room_mapping in chain(OBJECTS_BY_ROOM_AND_TYPE.values(), SPECIFIC_NAMED_OBJECTS_BY_ROOM.values())
))
computed_types.difference_update(META_TYPES.keys())
computed_types.remove(GAME_OBJECT)  # raises KeyError if GAME_OBJECT is not among the collected keys
computed_types.add(BUILDING)

In [None]:
# Types present in the data but not among the config-derived types.
all_df_types - computed_types

In [None]:
# Config-derived types that never appear in the data.
computed_types - all_df_types

In [None]:
# Colour-first type names to be rewritten in colour-last form
# (e.g. 'blue_cube_block' -> 'cube_block_blue').
RENAMED_TYPES = [
    'blue_cube_block',
    'tan_cube_block',
    'yellow_cube_block',
    'blue_pyramid_block',
    'red_pyramid_block',
    'yellow_pyramid_block',
    'blue_dodgeball',
    'red_dodgeball',
    'pink_dodgeball',
    'green_golfball',
    'green_triangular_ramp',
]

class DefaultValueDict(dict):
    """A ``dict`` whose ``[]`` lookup returns the key itself when the key is absent.

    This makes it an identity-by-default mapping, e.g. for ``Series.map`` so that values
    without an explicit entry pass through unchanged. Construction is plain ``dict``.
    """

    def __missing__(self, key):
        # dict.__getitem__ calls this on a miss; echo the key instead of raising KeyError.
        return key
    
# Map each RENAMED_TYPES entry to its colour-last form by moving the first '_'-separated token to
# the end (e.g. 'blue_cube_block' -> 'cube_block_blue'); any other type maps to itself.
arg_type_mapping = DefaultValueDict()
for renamed_type in RENAMED_TYPES:
    sp = renamed_type.split('_')
    new_name = '_'.join(sp[1:] + sp[:1])
    arg_type_mapping[renamed_type] =  new_name



# Write the renamed table back over the bitstring cache file.
# NOTE(review): only the pickle on disk is updated; the in-memory `bitstrings_df` keeps the old
# type names until it is reloaded.
bitstrings_df.assign(arg_1_type=bitstrings_df.arg_1_type.map(arg_type_mapping), 
                     arg_2_type=bitstrings_df.arg_2_type.map(arg_type_mapping),).to_pickle(bitstings_df_path)

In [None]:
# Encode `domain` as integers (few=0, medium=1, many=2) and save a modified copy.
# NOTE(review): this is a relative path, so the file lands in the current working directory,
# not in `cache_dir` like the other caches -- confirm that is intended.
bitstrings_df.replace(dict(domain=dict(few=0, medium=1, many=2))).to_pickle('predicate_statistics_modified_bitstring_intervals_028b3733.pkl.gz')

In [None]:
# Longest trace length across the loaded traces.
# `split_args_stats` is never constructed in this notebook (its construction is commented out in the
# loading cell), so read the lengths from `trace_lengths_and_domains`, whose values are
# (trace_length, domain) pairs.
# NOTE(review): assumes that dict covers the same traces split_args_stats would have -- confirm.
max(trace_length for trace_length, _ in trace_lengths_and_domains.values())

In [None]:
# Scratch: a 10000-long 0/1 array with positions 100..999 set.
a = np.zeros(10000, dtype=np.uint8)
a[100:1000] = 1


In [None]:
# Scratch: bytes of a small 0/1 array. b[0] == 0 and b[1] == 1, so `d` maps byte 0 -> '0', 1 -> '1'.
t = np.zeros(10, dtype=np.uint8)
t[1:4] = 1
b = t.tobytes()
d = {b[0]: '0', b[1]: '1'}

In [None]:
# Scratch: decode `b` to a '0'/'1' string (iterating bytes yields ints, hence int keys in `d`).
''.join(map(lambda x: d[x], b))

In [None]:
# Scratch: the two-byte sequence 0x00 0x01 used to build the byte -> character maps.
b = bytes([0, 1])
b

In [None]:
# Fixed width for every bitstring below: the longest trace.
# `split_args_stats` is never constructed in this notebook (commented out in the loading cell),
# so take the lengths from `trace_lengths_and_domains` (values are (trace_length, domain) pairs).
# NOTE(review): assumes that dict covers the same traces split_args_stats would have -- confirm.
MAX_LENGTH = max(trace_length for trace_length, _ in trace_lengths_and_domains.values())
print(MAX_LENGTH)
b = bytes([0, 1])
# Byte value (0 or 1) -> its character; kept as a module-level name for other cells.
BYTE_MAPPING = {byte: str(i) for i, byte in enumerate(b)}



def intervals_to_string(intervals, max_length: int = MAX_LENGTH):
    """Render ``intervals`` as a '0'/'1' string of exactly ``max_length`` characters.

    Each ``interval`` sets positions ``interval[0]:interval[1]`` (end-exclusive) to '1'.
    Intervals reaching past ``max_length`` are silently clipped by numpy slicing.
    """
    value = np.zeros(max_length, dtype=np.uint8)
    for interval in intervals:
        value[interval[0]:interval[1]] = 1

    # Shift 0/1 to the ASCII codes of '0'/'1' and decode in one vectorized step.
    return (value + np.uint8(ord('0'))).tobytes().decode('ascii')


intervals = split_args_df.intervals.apply(intervals_to_string)

# small_split_args_df = small_split_args_df.assign()

In [None]:
# Attach the fixed-width (MAX_LENGTH) bitstrings computed above as a new `string_intervals` column.
split_args_df = split_args_df.assign(string_intervals=intervals)
split_args_df.head()

In [None]:
# Full-database predicate statistics, pinned to the same `028b3733` hash as the cache files above.
import compile_predicate_statistics_full_database

stats = compile_predicate_statistics_full_database.CommonSensePredicateStatisticsFullDatabase(force_trace_names_hash='028b3733')

In [None]:
# NOTE(review): this imports `src.ast_utils`, while a later cell imports the same module as
# `ast_utils`; both resolve via sys.path but may load two separate module objects -- confirm.
from src.ast_utils import cached_load_and_parse_games_from_file

# Read the DSL grammar with a context manager so the file handle is closed promptly
# (a bare `open(...).read()` leaves closing to the garbage collector).
with open('../dsl/dsl.ebnf') as grammar_file:
    grammar = grammar_file.read()
grammar_parser = tatsu.compile(grammar)
game_asts = list(cached_load_and_parse_games_from_file('../dsl/interactive-beta.pddl', grammar_parser, False, relative_path='.'))

In [None]:
# Minimal one-preference test game; `test_pred` digs out its `(in ?v2 ?v1)` predicate
# (the first conjunct of the at-end body) and prints it.
test_game_str = """
(define (game test-game) (:domain many-objects-room-v1)
(:constraints (and
    (preference testPred
        (exists (?v1 - building ?v2 - block)
            (at-end (and
                (in ?v2 ?v1)
            ))
        )
    )
))
(:scoring (+
    (* (count testPred) 1)
)))
""".strip()
# test_game_str = """
# (define (game test-game) (:domain many-objects-room-v1)
# (:constraints (and
#     (preference testPred
#         (exists (?v0 - block ?v1 - wall)
#             (at-end (and 
#                 (on ?v1 ?v0)
#                 (in ?v0 ?v1)
#             ))
#         )
#     )
# ))
# (:scoring (+
#     (* (count testPred) 1)
# )))
# """.strip()

test_ast = grammar_parser.parse(test_game_str)
test_pred = test_ast[3][1].preferences[0].definition.pref_body.body.exists_args.at_end_pred.pred.and_args[0]
print(ast_printer.ast_section_to_string(test_pred, ast_parser.PREFERENCES))

In [None]:
# NOTE(review): the test game declares `?v1 - building ?v2 - block`, but this mapping binds
# ?v1 -> game_object and ?v2 -> building; confirm the mismatch is intentional.
stats.filter(test_pred, {'?v1': ['game_object'], '?v2': ['building']})
# df = stats.filter(test_pred, {'?v1': ['pillow'], '?v2': ['hexagonal_bin']}, last_interval_bit_set=True)
# df = stats.filter(test_pred, {'?v1': ['building'], '?v2': ['flat_block'], '?v3': ['block']}, last_interval_bit_set=True)
# df

In [None]:
# Debugging: run a (presumably generated) filter query directly against duckdb's default connection.
# It needs the `data`, `trace_length_and_domains`, `empty_bitstrings` and `object_type_to_id` tables
# and the `arg_type` type to already exist there.
# NOTE(review): presumably these are registered by `stats` above -- confirm.
duckdb.sql("""
WITH t149 AS (WITH t146 AS (SELECT trace_length_and_domains.trace_id as trace_id, trace_length_and_domains.domain as domain, empty_bitstrings.intervals as intervals, object_assignments."?v4" as "?v4", object_assignments."?v6" as "?v6", object_assignments."?v2" as "?v2", object_assignments."?v5" as "?v5" FROM trace_length_and_domains JOIN (WITH t142 AS (SELECT domain, object_id AS "?v4" FROM object_type_to_id WHERE type NOT IN ('upright'::arg_type, 'yellow'::arg_type, 'wall'::arg_type, 'mirror'::arg_type, 'diagonal'::arg_type, 'front'::arg_type, 'rug'::arg_type, 'bed'::arg_type, 'desktop'::arg_type, 'back'::arg_type, 'pink'::arg_type, 'white'::arg_type, 'red'::arg_type, 'floor'::arg_type, 'poster'::arg_type, 'bottom_shelf'::arg_type, 'green'::arg_type, 'sideways'::arg_type, 'left'::arg_type, 'south_wall'::arg_type, 'blinds'::arg_type, 'west_wall'::arg_type, 'tan'::arg_type, 'agent'::arg_type, 'brown'::arg_type, 'room_center'::arg_type, 'top_shelf'::arg_type, 'main_light_switch'::arg_type, 'east_wall'::arg_type, 'orange'::arg_type, 'right'::arg_type, 'building'::arg_type, 'gray'::arg_type, 'shelf'::arg_type, 'upside_down'::arg_type, 'west_sliding_door'::arg_type, 'north_wall'::arg_type, 'purple'::arg_type, 'shelf_desk'::arg_type, 'east_sliding_door'::arg_type, 'door'::arg_type, 'side_table'::arg_type, 'sliding_door'::arg_type, 'blue'::arg_type, 'desk'::arg_type)), t143 AS (SELECT domain, object_id AS "?v6" FROM object_type_to_id WHERE type = 'cube_block_blue'), t144 AS (SELECT domain, object_id AS "?v2" FROM object_type_to_id WHERE type NOT IN ('upright'::arg_type, 'yellow'::arg_type, 'wall'::arg_type, 'mirror'::arg_type, 'diagonal'::arg_type, 'front'::arg_type, 'rug'::arg_type, 'bed'::arg_type, 'desktop'::arg_type, 'back'::arg_type, 'pink'::arg_type, 'white'::arg_type, 'red'::arg_type, 'floor'::arg_type, 'poster'::arg_type, 'bottom_shelf'::arg_type, 'green'::arg_type, 'sideways'::arg_type, 'left'::arg_type, 'south_wall'::arg_type, 'blinds'::arg_type, 
'west_wall'::arg_type, 'tan'::arg_type, 'agent'::arg_type, 'brown'::arg_type, 'room_center'::arg_type, 'top_shelf'::arg_type, 'main_light_switch'::arg_type, 'east_wall'::arg_type, 'orange'::arg_type, 'right'::arg_type, 'building'::arg_type, 'gray'::arg_type, 'shelf'::arg_type, 'upside_down'::arg_type, 'west_sliding_door'::arg_type, 'north_wall'::arg_type, 'purple'::arg_type, 'shelf_desk'::arg_type, 'east_sliding_door'::arg_type, 'door'::arg_type, 'side_table'::arg_type, 'sliding_door'::arg_type, 'blue'::arg_type, 'desk'::arg_type)), t145 AS (SELECT domain, object_id AS "?v5" FROM object_type_to_id WHERE type IN ('bridge_block'::arg_type, 'cube_block'::arg_type, 'cylindrical_block'::arg_type, 'flat_block'::arg_type, 'pyramid_block'::arg_type, 'tall_cylindrical_block'::arg_type, 'tall_rectangular_block'::arg_type, 'triangle_block'::arg_type))
SELECT t142.domain, t142."?v4" AS "?v4", t143."?v6" AS "?v6", t144."?v2" AS "?v2", t145."?v5" AS "?v5" FROM t142
JOIN t143 ON (t142.domain = t143.domain) AND (t142."?v4" != t143."?v6") JOIN t144 ON (t142.domain = t144.domain) AND (t142."?v4" != t144."?v2") AND (t143."?v6" != t144."?v2") JOIN t145 ON (t142.domain = t145.domain) AND (t142."?v4" != t145."?v5") AND (t143."?v6" != t145."?v5") AND (t144."?v2" != t145."?v5")
) AS object_assignments ON (trace_length_and_domains.domain = object_assignments.domain) JOIN empty_bitstrings ON (trace_length_and_domains.trace_id = empty_bitstrings.trace_id)), t147 AS (SELECT trace_id, domain, intervals, arg_1_id AS "?v6", arg_2_id AS "?v2" FROM data WHERE predicate='above' AND (arg_1_type='cube_block_blue') AND arg_2_is_game_object IS TRUE), t148 AS (WITH t141 AS (WITH t139 AS (SELECT trace_length_and_domains.trace_id as trace_id, trace_length_and_domains.domain as domain, empty_bitstrings.intervals as intervals, object_assignments."?v4" as "?v4", object_assignments."?v5" as "?v5" FROM trace_length_and_domains JOIN (WITH t137 AS (SELECT domain, object_id AS "?v4" FROM object_type_to_id WHERE type NOT IN ('upright'::arg_type, 'yellow'::arg_type, 'wall'::arg_type, 'mirror'::arg_type, 'diagonal'::arg_type, 'front'::arg_type, 'rug'::arg_type, 'bed'::arg_type, 'desktop'::arg_type, 'back'::arg_type, 'pink'::arg_type, 'white'::arg_type, 'red'::arg_type, 'floor'::arg_type, 'poster'::arg_type, 'bottom_shelf'::arg_type, 'green'::arg_type, 'sideways'::arg_type, 'left'::arg_type, 'south_wall'::arg_type, 'blinds'::arg_type, 'west_wall'::arg_type, 'tan'::arg_type, 'agent'::arg_type, 'brown'::arg_type, 'room_center'::arg_type, 'top_shelf'::arg_type, 'main_light_switch'::arg_type, 'east_wall'::arg_type, 'orange'::arg_type, 'right'::arg_type, 'building'::arg_type, 'gray'::arg_type, 'shelf'::arg_type, 'upside_down'::arg_type, 'west_sliding_door'::arg_type, 'north_wall'::arg_type, 'purple'::arg_type, 'shelf_desk'::arg_type, 'east_sliding_door'::arg_type, 'door'::arg_type, 'side_table'::arg_type, 'sliding_door'::arg_type, 'blue'::arg_type, 'desk'::arg_type)), t138 AS (SELECT domain, object_id AS "?v5" FROM object_type_to_id WHERE type IN ('bridge_block'::arg_type, 'cube_block'::arg_type, 'cylindrical_block'::arg_type, 'flat_block'::arg_type, 'pyramid_block'::arg_type, 'tall_cylindrical_block'::arg_type, 'tall_rectangular_block'::arg_type, 
'triangle_block'::arg_type))
SELECT t137.domain, t137."?v4" AS "?v4", t138."?v5" AS "?v5" FROM t137
JOIN t138 ON (t137.domain = t138.domain) AND (t137."?v4" != t138."?v5")
) AS object_assignments ON (trace_length_and_domains.domain = object_assignments.domain) JOIN empty_bitstrings ON (trace_length_and_domains.trace_id = empty_bitstrings.trace_id)), t140 AS (SELECT trace_id, domain, intervals, arg_1_id AS "?v4", arg_2_id AS "?v5" FROM data WHERE predicate='adjacent' AND arg_1_is_game_object IS TRUE AND arg_2_is_block IS TRUE)
        SELECT t139.trace_id as trace_id, t139.domain as domain, t139."?v4" as "?v4", t139."?v5" as "?v5", (~( t139.intervals | COALESCE(t140.intervals, t139.intervals) )) AS intervals FROM t139 LEFT JOIN t140 ON t139."trace_id"=t140."trace_id" AND t139."?v4"=t140."?v4" AND t139."?v5"=t140."?v5"
        ) SELECT * FROM t141 WHERE bit_count(intervals) != 0) SELECT t146.trace_id, t146.domain, t146."?v4", t146."?v6", t146."?v2", t146."?v5", (t146.intervals | COALESCE(t147.intervals, t146.intervals) | COALESCE(t148.intervals, t146.intervals)) AS intervals FROM t146 LEFT JOIN t147 ON (t146.trace_id=t147.trace_id) AND (t146."?v6"=t147."?v6") AND (t146."?v2"=t147."?v2") LEFT JOIN t148 ON (t146.trace_id=t148.trace_id) AND (t146."?v4"=t148."?v4") AND (t146."?v5"=t148."?v5")) SELECT * FROM t149 WHERE bit_count(intervals) != 0
""").df()

In [None]:
# Reciprocal `on` pairs: (a on b) and (b on a) in the same trace with overlapping true bits,
# excluding buildings on both sides. The self-join is symmetric, so every pair appears twice
# (once per direction). d1_count / d2_count are each direction's total set bits.
# NOTE(review): the next cell reads `row.remove_d1`, which only exists if the commented-out
# `assign` line below is run.
q = """
SELECT trace_id, arg_1_id, arg_2_id, overlap, d1_count, d2_count
FROM (WITH 
d1 AS (SELECT trace_id, arg_1_id, arg_1_type, arg_2_id, arg_2_type, intervals FROM data WHERE predicate='on' AND arg_1_type!='building' AND arg_2_type!='building'), 
d2 AS (SELECT trace_id, arg_1_id, arg_2_id, intervals FROM data WHERE predicate='on' AND arg_1_type!='building' AND arg_2_type!='building')
SELECT d1.trace_id, d1.arg_1_id, d1.arg_1_type, d1.arg_2_id, d1.arg_2_type, bit_count(d1.intervals & d2.intervals) as overlap, bit_count(d1.intervals) as d1_count, bit_count(d2.intervals) as d2_count
FROM d1
INNER JOIN d2 ON d1.trace_id = d2.trace_id AND d1.arg_1_id = d2.arg_2_id AND d1.arg_2_id = d2.arg_1_id)
WHERE overlap > 0
"""
reciprocal_on_df = duckdb.sql(q).df()  # .to_csv('temp_outputs/a_on_b_and_b_on_a.csv')
reciprocal_on_df
# reciprocal_on_df = reciprocal_on_df.assign(remove_d1=reciprocal_on_df.d1_count <= reciprocal_on_df.d2_count)
# reciprocal_on_df


In [None]:
# Index `data` on arg_1_type. NOTE: not idempotent -- re-running will likely fail because an
# index with this name already exists.
duckdb.sql('CREATE INDEX idx_data_arg_1_type ON data (arg_1_type)')

In [None]:
# For every reciprocal pair (a on b AND b on a overlapping in the same trace), collect the
# bitstrings_df rows of the weaker direction of `on` so they can be dropped.
#
# `reciprocal_on_df` has no `remove_d1` column (it is only added by a commented-out line), so the
# decision is computed here. The query's self-join lists each pair twice, as (a, b) and (b, a).
# Comparing counts with a plain `<=` made both rows remove their own direction on a tie, dropping
# the relation entirely. Breaking ties by argument id makes both rows agree on one direction.
predicate_on_rows = bitstrings_df.predicate == 'on'
indices_to_remove = []

for _, row in reciprocal_on_df.iterrows():
    # d1 is `arg_1 on arg_2`, d2 the reverse; remove whichever has fewer set bits
    # (on a tie, the direction whose first argument id sorts lower).
    remove_d1 = (row.d1_count, row.arg_1_id) < (row.d2_count, row.arg_2_id)

    row_filter = predicate_on_rows & (bitstrings_df.trace_id == row.trace_id)
    if remove_d1:
        row_filter &= (bitstrings_df.arg_1_id == row.arg_1_id) & (bitstrings_df.arg_2_id == row.arg_2_id)
    else:
        row_filter &= (bitstrings_df.arg_1_id == row.arg_2_id) & (bitstrings_df.arg_2_id == row.arg_1_id)

    indices_to_remove.extend(bitstrings_df[row_filter].index)

# Both rows of a pair select the same index; keep each one once, in first-seen order.
indices_to_remove = list(dict.fromkeys(indices_to_remove))
print(indices_to_remove)



In [None]:
# Number of distinct bitstrings_df rows selected for removal.
len(set(indices_to_remove))

In [None]:
# filtered_bitstrings_df = bitstrings_df.drop(index=indices_to_remove)
# filtered_bitstrings_df = filtered_bitstrings_df.reset_index(drop=True)

# print(filtered_bitstrings_df.shape, bitstrings_df.shape, filtered_bitstrings_df.index.max())

# filtered_bitstrings_df.to_pickle(bitstings_df_path)

In [None]:
# Sanity check: traces where the agent holds an object while that same object is `adjacent` to the
# agent (overlapping agent_holds / adjacent bits).
# NOTE: the commented-out CSV filename is copied from another cell and does not match this query.
q = """
SELECT trace_id, arg_1_id, arg_2_id, overlap  
FROM (WITH 
d1 AS (SELECT trace_id, arg_1_id, arg_1_type, intervals FROM data WHERE predicate='agent_holds'), 
d2 AS (SELECT trace_id, arg_1_id, arg_2_id, intervals FROM data WHERE predicate='adjacent' AND (arg_1_type='agent' OR arg_2_type='agent'))
SELECT d2.trace_id, d2.arg_1_id, d2.arg_2_id, bit_count(d1.intervals & d2.intervals) as overlap
FROM d2
INNER JOIN d1 ON d1.trace_id = d2.trace_id AND (d1.arg_1_id = d2.arg_1_id OR d1.arg_1_id = d2.arg_2_id)
)
WHERE overlap > 0
"""
duckdb.sql(q).df()  # .to_csv('temp_outputs/a_on_b_and_b_on_a.csv')


In [None]:
# Number of agent_holds rows in `data` (displayed as a duckdb relation).
q = """
SELECT count(*) FROM data
WHERE predicate='agent_holds'
"""
duckdb.sql(q)  # .df()  # .to_csv('temp_outputs/a_on_b_and_b_on_a.csv')

In [None]:
# Distinct (arg_1_type, arg_2_type) pairs observed for `in`, plus the first/second-slot types.
in_arg_types = [
    tuple(arg_type_pair)
    for arg_type_pair in duckdb.sql("SELECT arg_types.* FROM (SELECT DISTINCT(arg_1_type, arg_2_type) as arg_types FROM data WHERE predicate='in');").fetchall()
]
first_arg_types, second_arg_types = zip(*in_arg_types)

In [None]:
# `in` type pairs whose first argument is a mug.
[t for t in in_arg_types if t[0] == 'mug']

In [None]:
# Distinct (arg_1_type, arg_2_type) pairs observed for `on`, plus the first/second-slot types.
on_arg_types = [
    tuple(arg_type_pair)
    for arg_type_pair in duckdb.sql("SELECT arg_types.* FROM (SELECT DISTINCT(arg_1_type, arg_2_type) as arg_types FROM data WHERE predicate='on');").fetchall()
]
on_first_arg_types, on_second_arg_types = zip(*on_arg_types)

In [None]:
# `on` type pairs whose second argument is a desk.
[t for t in on_arg_types if t[1] == 'desk']

In [None]:
# Group the types seen as the first argument of `on` by category (raises KeyError for a type
# missing from TYPES_TO_CATEGORIES).
# NOTE: the loop variable `t` leaks into the notebook namespace; later scratch cells rebind `t`.
import room_and_object_types

on_types_by_category = defaultdict(set)
for t in on_first_arg_types:
    on_types_by_category[room_and_object_types.TYPES_TO_CATEGORIES[t]].add(t)


for cat in on_types_by_category:
    print(cat)
    print(on_types_by_category[cat])
    print()

In [None]:
# Export reciprocal `on` pairs (no buildings) whose overlap exceeds 100 bit positions.
# Assumes the `temp_outputs/` directory already exists (to_csv does not create it).
q = """
SELECT * FROM (WITH 
d1 AS (SELECT trace_id, arg_1_id, arg_2_id, intervals FROM data WHERE predicate='on' AND arg_1_type!='building' AND arg_2_type!='building'), 
d2 AS (SELECT trace_id, arg_1_id, arg_2_id, intervals FROM data WHERE predicate='on' AND arg_1_type!='building' AND arg_2_type!='building')
SELECT d1.trace_id, d1.arg_1_id, d1.arg_2_id, bit_count(d1.intervals & d2.intervals) as overlap
FROM d1
INNER JOIN d2 ON d1.trace_id = d2.trace_id AND d1.arg_1_id = d2.arg_2_id AND d1.arg_2_id = d2.arg_1_id)
WHERE overlap > 100
"""
duckdb.sql(q).df().to_csv('temp_outputs/a_on_b_and_b_on_a.csv')

# duckdb.sql("SELECT * FROM data")

In [None]:
# Export cases where `a on b` overlaps `b in a` for more than 100 bit positions (no buildings).
# Assumes the `temp_outputs/` directory already exists (to_csv does not create it).
q = """
SELECT * FROM (WITH 
d1 AS (SELECT trace_id, arg_1_id, arg_2_id, intervals FROM data WHERE predicate='on' AND arg_1_type!='building' AND arg_2_type!='building'), 
d2 AS (SELECT trace_id, arg_1_id, arg_2_id, intervals FROM data WHERE predicate='in' AND arg_1_type!='building' AND arg_2_type!='building')
SELECT d1.trace_id, d1.arg_1_id, d1.arg_2_id, bit_count(d1.intervals & d2.intervals) as overlap
FROM d1
INNER JOIN d2 ON d1.trace_id = d2.trace_id AND d1.arg_1_id = d2.arg_2_id AND d1.arg_2_id = d2.arg_1_id)
WHERE overlap > 100
"""
duckdb.sql(q).df().to_csv('temp_outputs/a_on_b_and_b_in_a.csv')

# duckdb.sql("SELECT * FROM data")

In [None]:
# Reciprocal `on` pairs involving a building, overlapping for more than 100 bit positions,
# with the first set position of the joint intervals.
q = """
SELECT trace_id, domain, arg_1_id, arg_2_id, bit_position('1'::BIT, joint_intervals) as "first_index" FROM (WITH 
d1 AS (SELECT trace_id, domain, arg_1_id, arg_2_id, intervals FROM data WHERE predicate='on' AND (arg_1_type='building' OR arg_2_type='building')), 
d2 AS (SELECT trace_id, arg_1_id, arg_2_id, intervals FROM data WHERE predicate='on' AND (arg_1_type='building' OR arg_2_type='building'))
SELECT d1.trace_id, d1.domain, d1.arg_1_id, d1.arg_2_id, d1.intervals & d2.intervals as joint_intervals
FROM d1
INNER JOIN d2 ON d1.trace_id = d2.trace_id AND d1.arg_1_id = d2.arg_2_id AND d1.arg_2_id = d2.arg_1_id)
WHERE bit_count(joint_intervals) > 100
"""
# Keep the frame and write the CSV as a separate step: `DataFrame.to_csv(path)` returns None,
# so chaining it into the assignment left `buildings_df` as None.
# Assumes the `temp_outputs/` directory already exists (to_csv does not create it).
buildings_df = duckdb.sql(q).df()
buildings_df.to_csv('temp_outputs/a_on_b_and_b_on_a_buildings.csv')
buildings_df
# buildings_df[(buildings_df.trace_id == 'Q6a8AbiIdcLA9tJzAu14-createGame-rerecorded') & (buildings_df.arg_2_id == 'SmallSlide|-00.81|+00.14|-03.10')]

# duckdb.sql("SELECT * FROM data")

In [None]:
# Inspect the `on` rows for one specific trace / building / slide, and check whether the first two
# returned rows have identical intervals.
# NOTE: `d.loc[1, ...]` raises KeyError if the query returns fewer than two rows.
q = """
SELECT * FROM data
WHERE trace_id='Q6a8AbiIdcLA9tJzAu14-createGame-rerecorded' AND arg_1_id='building_1' and arg_2_id='SmallSlide|-00.81|+00.14|-03.10' AND predicate='on'
"""

d = duckdb.sql(q).df()
print(d.loc[0, 'intervals'] == d.loc[1, 'intervals'])
d

In [None]:
# Export `on` rows whose second argument is a bed or desk and whose first argument is neither
# floor nor rug, with the position of the first set bit.
# NOTE(review): the alias is single-quoted here but double-quoted in a similar query elsewhere;
# confirm duckdb accepts both. Assumes `temp_outputs/` already exists.
q = """
SELECT trace_id, domain, arg_1_id, arg_2_id, bit_position('1'::BIT, intervals) as 'first_index' FROM data
WHERE predicate='on' and arg_2_type in ('bed', 'desk') and arg_1_type NOT IN ('floor', 'rug')
"""
duckdb.sql(q).df().to_csv('temp_outputs/bed_or_desk_on_object_that_is_not_floor_or_rug.csv')

# duckdb.sql("SELECT * FROM data")

In [None]:
# Pull the `once` predicate of the first `then` step of game_asts[1]'s first (forall) preference,
# then show its keys and printed form.
p = game_asts[1][4][1].preferences[0].definition.forall_pref.preferences.pref_body.body.then_funcs[0].seq_func.once_pred
print(p.keys())
ast_printer.ast_section_to_string(p, ast_parser.PREFERENCES)

In [None]:
# Filter `p` with ?b bound to balls and ?t to hexagonal bins. NOTE: `q` is rebound in later cells.
q = stats.filter(p, {"?b": ["ball"], "?t": ["hexagonal_bin"]})
print(q)

In [None]:
# Session settings for the profiling cell below: prefer index joins, and show only the optimized
# plan in EXPLAIN output.
duckdb.sql('PRAGMA force_index_join;')
duckdb.sql("PRAGMA explain_output='OPTIMIZED_ONLY';")

In [None]:
# Profile a self-join of `data`: ball-holding intervals ANDed with door/agent adjacency intervals
# in the same trace. The commented-out variant expresses the same join with subqueries.
# q = """
#     SELECT t0.trace_id, t0.domain, t0."?b", t1."door", t1."agent", (t0.intervals & t1.intervals) AS intervals
#     FROM (SELECT trace_id, domain, intervals, arg_1_id AS "?b" FROM data WHERE predicate='agent_holds' AND (arg_1_type IN ('beachball'::arg_type, 'basketball'::arg_type, 'dodgeball'::arg_type, 'golfball'::arg_type))) as t0
#     INNER JOIN (SELECT trace_id, domain, intervals, arg_1_id AS "door", arg_2_id AS "agent" FROM data WHERE predicate='adjacent' AND (arg_1_type='door') AND (arg_2_type='agent')) as t1
#     ON (t0.trace_id=t1.trace_id)
# """

q = """
    SELECT t0.trace_id, t0.domain, t0.arg_1_id AS "?b", t1.arg_1_id AS "door", t1.arg_2_id AS "agent", (t0.intervals & t1.intervals) AS intervals
    FROM data AS t0
    INNER JOIN data AS t1
    ON (t0.trace_id=t1.trace_id)
    WHERE t0.predicate='agent_holds' AND (t0.arg_1_type IN ('beachball'::arg_type, 'basketball'::arg_type, 'dodgeball'::arg_type, 'golfball'::arg_type)) AND
    t1.predicate='adjacent' AND (t1.arg_1_type='door') AND (t1.arg_2_type='agent')

"""

# Print the plan text (second column of the single EXPLAIN ANALYZE result row).
print(duckdb.sql(f"EXPLAIN ANALYZE ({q})").fetchone()[1])

In [None]:
# Populate `domains` with the three domain names. NOTE: not idempotent -- re-running inserts the
# rows again unless the table enforces uniqueness.
duckdb.sql("INSERT INTO domains VALUES ('few'), ('medium'), ('many')")

In [None]:
# Materialize the current `q` as a pandas DataFrame.
df = duckdb.sql(q).fetchdf()

In [None]:
# Index `data` on predicate. NOTE: not idempotent -- re-running will likely fail because an index
# with this name already exists.
duckdb.sql('CREATE INDEX idx_data_predicate ON data (predicate)')

In [None]:
# Scratch (duplicate of an earlier scratch cell): b[0] == 0 and b[1] == 1, so `d` maps
# byte 0 -> '0' and byte 1 -> '1'. NOTE: rebinds `t`, `b` and `d`.
t = np.zeros(10, dtype=np.uint8)
t[1:4] = 1
b = t.tobytes()
d = {b[0]: '0', b[1]: '1'}

In [None]:
# Rows of `df` for one specific trace.
df[df.trace_id == '1El1CmicSoKZKTLe8NpP-preCreateGame-rerecorded']

In [None]:
# Unpack an interval's raw bytes into individual bits.
# NOTE(review): `interval_from_df` is not defined anywhere in this notebook; this only runs with
# leftover kernel state.
bits_from_df = np.unpackbits(np.frombuffer(interval_from_df, dtype=np.uint8))
len(bits_from_df), bits_from_df

In [None]:
# Cross-check two result formats of the same query `q`: the last field of each fetchall() row
# (iterated as '0'/'1' characters) against the raw interval bytes in `df` (from fetchdf()).
# NOTE(review): this filters `df` on a '?t' column, but the `q` defined in the profiling cell above
# selects "?b", "door" and "agent" -- as written this raises KeyError unless `q`/`df` come from a
# different cell. Confirm which query it was written for.
results = duckdb.sql(q).fetchall()

for tup in results:
    # Exactly one DataFrame row should match this result row's key.
    sub_df = df[(tup[0] == df['trace_id']) & (tup[2] == df['?b']) & (tup[3] == df['?t'])]
    if len(sub_df) != 1:
        print(f'Error: {tup[:-1]}')
        print(len(sub_df))
        break

    # trace_id is interpolated into SQL directly; acceptable for ids read back from our own table.
    expected_length = duckdb.sql(f"SELECT length from trace_length_and_domains WHERE trace_id='{tup[0]}'").fetchone()[0]

    # Only the trailing `expected_length` unpacked bits are compared below (presumably the leading
    # bits of the raw bytes are padding -- TODO confirm).
    bits_from_df = np.unpackbits(np.frombuffer(sub_df.intervals.item(), dtype=np.uint8))
    bits_from_db = np.fromiter(map(int, tup[-1]), dtype=np.uint8)

    if len(bits_from_db) != expected_length:
        print(f'Error: {tup[:-1]}')
        print(len(bits_from_db))
        print(expected_length)
        print(len(bits_from_df))
        break

    if not np.all(bits_from_df[-expected_length:] == bits_from_db):
        print(f'Error: {tup[:-1]}')
        print(np.where(bits_from_df[-expected_length:] != bits_from_db))


In [None]:
# Key fields (everything but the intervals) of the last row the loop above examined.
tup[:-1]

In [None]:
# Scratch: length of the last 550 DataFrame bits.
len(bits_from_df[-550:])

In [None]:
# NOTE(review): `t` is whatever was last bound to `t` (most recently a numpy array in a scratch cell,
# where `t[-1]` is a scalar and `map(int, ...)` fails); this probably expects a result row such as
# `tup` -- confirm.
bits_from_db = np.fromiter(map(int, t[-1]), dtype=np.uint8)

In [None]:
# NOTE(review): like the previous cell, this treats `t` as a result row (`t[0]` as a trace id) -- confirm.
bits_from_df = np.unpackbits(np.frombuffer(df[df.trace_id == t[0]].intervals.item(), dtype=np.uint8))

In [None]:
# Compare the bit counts from the two sources.
print(len(bits_from_db), len(bits_from_df))

In [None]:
# Scratch arithmetic while investigating the padding offset.
1264 % 16

In [None]:
# Positions that differ once the DataFrame bits are shifted by a hard-coded 14 (presumably the
# padding found for this particular trace).
print(len(bits_from_db), len(bits_from_df))
np.where(bits_from_db[:] != bits_from_df[14:])

In [None]:
# Spot-check the same window of bits from both sources.
bits_from_db[500:540], bits_from_df[500:540]

In [None]:
# NOTE(review): `min_length` is not defined anywhere in this notebook.
(bits_from_db[:min_length] != bits_from_df[:min_length])[:10]

## Gameplan from here

* For (setup, each preference):
* Inspect it to make sure all of its predicates are implemented in the cache, that it doesn't use a once-measure (and probably also not a hold-while, at least at first), etc.
    * otherwise default to the basic implementation
* If we're going to run on a particular thing:
    * If it's a setup, it's trivial
    * If it's a preference, check if it's an at-end or then
        * If it's an at-end, it's trivial (we can probably fold this into the query by doing `get_bit(intervals, length - 1)` if we join on the trace lengths table)
        * If it's a then, enumerate over the predicates of each modal, and query for them
            * For each one, fetch the df for the query, use the trace lengths to transform the intervals to the expected format
            * Add the index of the modal to the df
            * Enumerate through all trace ids and assignments, and for each assignment where we have all modals represented: 
                * Create the joint state interval
                * Iterate through the joint using the state machine logic
                * Count satisfactions
                * ...
                * Profit!

In [None]:
import tatsu.ast
import tatsu.grammars
from ast_parser import ASTParser, SECTION_CONTEXT_KEY, VARIABLES_CONTEXT_KEY
from ast_utils import simplified_context_deepcopy, deepcopy_ast, ASTCopyType, replace_child


# Grammar rule names (matched against `ast.parseinfo.rule`) that the cached trace filter cannot
# evaluate. Any setup/preference containing one of these is reported as unsupported by
# MixedTraceFilterGameParser.
DEFAULT_UNSUPPORTED_RULES = [
    'function_comparison',
    'function_eval',
    'predicate_adjacent_side_3',
    'predicate_adjacent_side_4',
    'predicate_between',
    'predicate_faces',
    'predicate_is_setup_object',
    'predicate_opposite',
    'predicate_rug_color_under',
    'predicate_same_color',
    'predicate_same_object',
    'predicate_same_type',
    'super_predicate_exists',
    'super_predicate_forall',
    'once_measure',
    'while_hold',
]


def _pref_forall_pos_to_key(pos: int):
    return f'pref_forall_{pos}'


class MixedTraceFilterGameParser(ASTParser):
    """Classify a game's setup/preferences by whether they use rules the trace filter can't handle.

    Calling an instance on a game AST walks it and returns ``(unsupported_keys, expected_keys)``:

    * ``expected_keys`` -- the key of every section/preference in which a node was handled:
      the setup section key (``ast_parser.SETUP``), ``pref_forall_<pos>`` for nodes inside a
      ``pref_forall`` (``<pos>`` is that node's ``parseinfo.pos``), or the preference name for
      nodes inside any other ``preference``.
    * ``unsupported_keys`` -- the subset of ``expected_keys`` under which some node's grammar
      rule is listed in ``unsupported_rules``.
    """

    unsupported_rules: typing.Set[str]
    expected_keys: typing.Set[str]
    unsupported_keys: typing.Set[str]

    def __init__(self, unsupported_rules: typing.Sequence[str] = DEFAULT_UNSUPPORTED_RULES):
        """
        Args:
            unsupported_rules: grammar rule names whose presence marks the enclosing key unsupported.
                Copied into a set, so the caller's sequence (including the module default) is not mutated.
        """
        super().__init__()
        self.expected_keys = set()
        # Also initialized here (not only in __call__) so the attribute always exists.
        self.unsupported_keys = set()
        self.unsupported_rules = set(unsupported_rules)

    def __call__(self, ast, **kwargs):
        """Walk ``ast``; the outermost call returns ``(unsupported_keys, expected_keys)``.

        Recursive calls (made from ``_handle_ast`` with ``inner_call=True`` already in kwargs)
        return whatever the base ``ASTParser.__call__`` returns.
        """
        initial_call = 'inner_call' not in kwargs or not kwargs['inner_call']
        if initial_call:
            kwargs['inner_call'] = True
            kwargs['local_context'] = {'mapping': {VARIABLES_CONTEXT_KEY: {}}}
            kwargs['global_context'] = {}
            # Reset per-game results so one instance can be reused across games.
            self.expected_keys = set()
            self.unsupported_keys = set()

        retval = super().__call__(ast, **kwargs)

        if initial_call:
            return self.unsupported_keys, self.expected_keys
        else:
            return retval

    def _current_ast_to_contexts_hook(self, ast: tatsu.ast.AST, kwargs: typing.Dict[str, typing.Any]):
        """Record the enclosing ``pref_forall`` position / preference name in the local context.

        Presumably invoked by the base class's ``_current_ast_to_contexts``.
        """
        rule = typing.cast(str, ast.parseinfo.rule)  # type: ignore

        if rule == 'pref_forall':
            kwargs['local_context']['current_pref_forall_index'] = ast.parseinfo.pos

        if rule == 'preference':
            kwargs['local_context']['current_preference_name'] = ast.pref_name

    def _handle_ast(self, ast: tatsu.ast.AST, **kwargs):
        """Attribute this node to its enclosing key, flag unsupported rules, then recurse.

        The key is chosen in priority order: the setup section (when ``kwargs`` carries
        ``SECTION_CONTEXT_KEY == ast_parser.SETUP``), then the enclosing ``pref_forall``
        position, then the enclosing preference name. Nodes with none of these add no key.
        """
        self._current_ast_to_contexts(ast, **kwargs)
        kwargs['local_context']['mapping'] = ast_parser.update_context_variables(ast, kwargs['local_context']['mapping'])

        current_key = None
        if SECTION_CONTEXT_KEY in kwargs and kwargs[SECTION_CONTEXT_KEY] == ast_parser.SETUP:
            current_key = kwargs[SECTION_CONTEXT_KEY]
        elif 'current_pref_forall_index' in kwargs['local_context']:
            current_key = _pref_forall_pos_to_key(kwargs['local_context']['current_pref_forall_index'])
        elif 'current_preference_name' in kwargs['local_context']:
            current_key = kwargs['local_context']['current_preference_name']

        if current_key is not None:
            self.expected_keys.add(current_key)

            if ast.parseinfo.rule in self.unsupported_rules:
                self.unsupported_keys.add(current_key)

        # Recurse into every child field except the parse metadata; each child gets its own
        # copy of the context, and its return value is merged back via the base-class helper.
        for key in ast:
            if key != 'parseinfo':
                child_kwargs = simplified_context_deepcopy(kwargs)
                retval = self(ast[key], **child_kwargs)
                self._update_contexts_from_retval(kwargs, retval)
            
            


In [None]:
# For each parsed game, report which setup/preference keys the cached trace filter supports.
game_parser = MixedTraceFilterGameParser()
for ast in game_asts:
    unsupported, expected = game_parser(ast)
    supported = expected - unsupported
    print(f'Game {ast[1].game_name} has supported keys: {list(supported)} and unsupported keys: {list(unsupported)}')

In [None]:



# Minimal game whose only preference is `dummyPreference` (`at-end (game-over)`); its constraints
# section (`[3]` of the parsed AST) is spliced into a game when only the setup is being kept.
DUMMY_PREFERENCE_GAME = """(define (game dummy-preference-game) (:domain many-objects-room-v1)
(:constraints (and
    (preference dummyPreference
            (at-end (game-over))
    )
))
(:scoring (count dummyPreference)
))
"""


class ASTTraceFilterSplitter(ast_parser.ASTParser):
    """Produce a copy of a game AST restricted to a subset of its trace-filter keys.

    A key is ``ast_parser.SETUP`` for the setup section, a preference's ``pref_name``, or
    ``_pref_forall_pos_to_key(pos)`` for a ``pref_forall``. Calling the instance with
    ``keep_keys`` and ``remove_keys`` (both required and non-empty) returns a deep copy of
    the game with the ``remove_keys`` parts removed; the input AST is left untouched.
    """
    keep_keys: typing.Set[str]
    remove_keys: typing.Set[str]

    def __init__(self, grammar_parser: tatsu.grammars.Grammar):
        # Only used to parse DUMMY_PREFERENCE_GAME when the setup is the sole key kept.
        self.grammar_parser = grammar_parser

    def __call__(self, ast, **kwargs):
        """Split ``ast``; on the top-level call, return the filtered game tuple.

        Args:
            ast: on the top-level call, the full game tuple (the setup section, when present,
                is at index 3); on recursive calls, any sub-node.
            keep_keys: keys to keep (required and non-empty on the top-level call).
            remove_keys: keys to remove (required and non-empty on the top-level call).

        Returns:
            The filtered deep copy on the top-level call; ``None`` on recursive calls, which
            edit that copy in place.

        Raises:
            ValueError: if ``keep_keys`` or ``remove_keys`` is missing or empty.
        """
        initial_call = not kwargs.get('inner_call', False)
        if initial_call:
            kwargs['inner_call'] = True

            if 'remove_keys' not in kwargs:
                raise ValueError('remove_keys must be specified')
            self.remove_keys = kwargs['remove_keys']
            if len(self.remove_keys) == 0:
                raise ValueError('remove_keys must be non-empty')

            if 'keep_keys' not in kwargs:
                raise ValueError('keep_keys must be specified')
            self.keep_keys = kwargs['keep_keys']
            if len(self.keep_keys) == 0:
                raise ValueError('keep_keys must be non-empty')

            # Work on a copy so the caller's AST is never modified.
            ast = deepcopy_ast(ast)

            # Handle the setup right here and now, if we're removing it (it sits at index 3)
            if ast_parser.SETUP in self.remove_keys:
                ast = (*ast[:3], *ast[4:])
                # If the only thing we're removing is the setup, we're done
                if len(self.remove_keys) == 1:
                    return ast

            # If only the setup is kept, the preferences are presumably all in remove_keys.
            # Swap the original constraints section (index 4, right after the setup) for the
            # dummy game's (index 3, as the dummy game has no setup). Inserting it alongside the
            # original would leave two constraints sections and keep the removed preferences.
            # Later sections are left as-is, as in the preference-removal path of _handle_ast.
            if len(self.keep_keys) == 1 and ast_parser.SETUP in self.keep_keys:
                dummy_preference_game = self.grammar_parser.parse(DUMMY_PREFERENCE_GAME)
                return (*ast[:4], dummy_preference_game[3], *ast[5:])

        super().__call__(ast, **kwargs)

        if initial_call:
            return ast

    def _handle_ast(self, ast: tatsu.ast.AST, **kwargs):
        """Drop removed preferences / pref_foralls from a 'preferences' node, else recurse.

        Raises:
            ValueError: if the 'preferences' node holds a single AST rather than a list; the
                initial call is expected to have handled that case already.
        """
        rule = ast.parseinfo.rule

        if rule != 'preferences':
            for key in ast:
                if key != 'parseinfo':
                    self(ast[key], **kwargs)
            return

        if isinstance(ast.preferences, tatsu.ast.AST):
            raise ValueError('If removing a single preference, the initial call should handle it, so this should never occur')

        children = typing.cast(typing.List[tatsu.ast.AST], deepcopy_ast(ast.preferences, ASTCopyType.NODE))
        # Single pass that keeps children by position (list.remove would drop the first *equal* child).
        kept_children = []
        for child in children:
            child_rule = child.parseinfo.rule
            if child_rule == 'preference' and child.pref_name in self.remove_keys:
                print(f'Removing preference {child.pref_name}')
            elif child_rule == 'pref_forall' and _pref_forall_pos_to_key(child.parseinfo.pos) in self.remove_keys:
                print(f'Removing pref_forall {_pref_forall_pos_to_key(child.parseinfo.pos)}')
            else:
                kept_children.append(child)

        replace_child(ast, 'preferences', kept_children)

            
        


In [None]:
# Split the first game into a supported-only and an unsupported-only variant.
game_parser = MixedTraceFilterGameParser()
game_splitter = ASTTraceFilterSplitter(grammar_parser)  # type: ignore
ast = game_asts[0]
unsupported, expected = game_parser(ast)
supported = expected - unsupported

# The splitter rejects empty keep/remove sets, so a split only makes sense when both are non-empty.
if supported and unsupported:
    print(f'Game {ast[1].game_name} has supported keys: {list(supported)} and unsupported keys: {list(unsupported)}')
    supported_only = game_splitter(ast, keep_keys=supported, remove_keys=unsupported)
    unsupported_only = game_splitter(ast, keep_keys=unsupported, remove_keys=supported)

    print('=' * 80)
    print(ast_printer.ast_to_string(supported_only, '\n'))
    print('=' * 80)
    print(ast_printer.ast_to_string(unsupported_only, '\n'))

In [None]:
from itertools import chain

# Ordered categorical dtype whose categories are every value of the DuckDB `trace_id` enum, sorted.
categorical_type = pd.CategoricalDtype(
    categories=sorted(chain.from_iterable(duckdb.sql("SELECT enum_range(NULL::trace_id)").fetchone())),
    ordered=True,
)

In [None]:
# Per-trace lengths from DuckDB: drop `domain` and expose the length as `trace_length`.
trace_id_to_length_df = (
    duckdb.sql('SELECT * FROM trace_length_and_domains')
    .fetchdf()
    .drop(columns=['domain'])
    .rename(columns={'length': 'trace_length'})
)
trace_id_to_length_df.head()

In [None]:
def _df_intervals_to_array(row):
    return np.unpackbits(np.frombuffer(row['intervals'], dtype=np.uint8))[-row['trace_length']:]


In [None]:
# `astype` never converts in place (not even with copy=False); it returns a new frame.
# Rebind `df` so `trace_id` actually becomes the ordered categorical over the DuckDB enum.
df = df.astype(dict(trace_id=categorical_type), copy=False)
df

In [None]:
# df.join(trace_id_to_length_df, on=['trace_id'], how='outer', rsuffix='_r')
# Attach each trace's length to every row of `df`; the left join keeps all of df's rows.
merged_df = pd.merge(df, trace_id_to_length_df, how='left', on='trace_id')
merged_df

In [None]:
# `merge` returns a fresh RangeIndex, so re-label the decoded intervals with df's own index
# before assigning. Otherwise `assign` aligns on mismatched labels whenever df's index is not
# 0..n-1. A left merge keeps df's row order; set_axis raises if the row counts differ
# (e.g. if trace_id were not unique in trace_id_to_length_df).
decoded_intervals = merged_df.apply(_df_intervals_to_array, axis=1).set_axis(df.index)
assigned_df = df.assign(intervals=decoded_intervals)
assigned_df

In [None]:
# Per trace_id, OR together the interval arrays of all its rows into one boolean array.
series_1 = (
    assigned_df
    .groupby('trace_id', as_index=True, observed=True)['intervals']
    .agg(lambda trace_intervals: reduce(np.logical_or, trace_intervals.to_numpy()).astype(bool))
)

In [None]:
# Small subset (first 50 traces) to intersect with the full series below.
series_2 = series_1.head(50)

In [None]:
# Inner-join the two per-trace series on their index; both are named `intervals`, so the
# columns come out as `intervals_x` / `intervals_y`. Then AND each pair row by row.
merged = pd.merge(series_1, series_2, how='inner', left_index=True, right_index=True, suffixes=('_x', '_y'))
s = merged.apply(lambda pair: np.logical_and(pair['intervals_x'], pair['intervals_y']), axis=1)