In [1]:
%autoreload 2

In [76]:
from argparse import Namespace
from collections import defaultdict
import copy
from datetime import datetime
import difflib
import gzip
import itertools
import os
import pickle
import sys
import typing

from IPython.display import display, Markdown, HTML  # type: ignore
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import tatsu
import tqdm.notebook as tqdmn


sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../reward-machine'))

import compile_predicate_statistics
import compile_predicate_statistics_split_args
from config import SPECIFIC_NAMED_OBJECTS_BY_ROOM

In [77]:
# Directory that holds the cached predicate-statistics pickles.
cache_dir = f"{compile_predicate_statistics.get_project_dir()}/reward-machine/caches"

# Raw split-args table, read directly from one specific cache file.
# Alternative inputs kept for reference (currently disabled):
#   regular_df = pd.read_pickle(os.path.join(cache_dir, 'predicate_statistics.pkl'))
#   split_args_df = pd.read_pickle(os.path.join(cache_dir, 'predicate_statistics_4d5dd602.pkl.gz'))
split_args_path = os.path.join(cache_dir, 'predicate_statistics_028b3733.pkl.gz')
split_args_df = pd.read_pickle(split_args_path)

# Statistics helper built over the same cache directory without overwriting existing caches.
# The non-split-args variant is disabled:
#   stats = compile_predicate_statistics.CommonSensePredicateStatistics(cache_dir)
# NOTE(review): the recorded log output for this cell shows the helper loading
# predicate_statistics_4d5dd602.pkl.gz, i.e. a different file than the one read into
# `split_args_df` above -- confirm that this mismatch is intended.
split_args_stats = compile_predicate_statistics_split_args.CommonSensePredicateStatisticsSplitArgs(
    cache_dir,
    compile_predicate_statistics_split_args.CURRENT_TEST_TRACE_NAMES,
    overwrite=False,
)

2023-08-07 12:22:38 - compile_predicate_statistics_split_args - INFO     - Loaded data with shape (416740, 8) from /Users/guydavidson/projects/game-generation-modeling/reward-machine/caches/predicate_statistics_4d5dd602.pkl.gz


In [81]:
# Every object type that is a key of any room's mapping in SPECIFIC_NAMED_OBJECTS_BY_ROOM.
all_specific_names = {
    object_type
    for type_to_ids in SPECIFIC_NAMED_OBJECTS_BY_ROOM.values()
    for object_type in type_to_ids.keys()
}

# A row counts when either of its two argument types is one of those names.
has_specific_name = (
    split_args_df['arg_1_type'].isin(all_specific_names)
    | split_args_df['arg_2_type'].isin(all_specific_names)
)
specific_name_rows = int(has_specific_name.sum())
total_rows = len(split_args_df)
print(f'{specific_name_rows} / {total_rows} ({specific_name_rows / total_rows * 100:.2f}%) rows have a specific name')

171832 / 2703800 (6.36%) rows have a specific name


In [87]:
# Share of the table taken up by `same_type` rows (`total_rows` comes from the earlier counting cell).
is_same_type = split_args_df['predicate'] == 'same_type'
same_type_rows = int(is_same_type.sum())
print(f'{same_type_rows} / {total_rows} ({same_type_rows / total_rows * 100:.2f}%) rows are for same_type')

2078294 / 2703800 (76.87%) rows are for same_type


In [66]:
# Show the index and contents of the first row only, to inspect the column layout.
for i, row in itertools.islice(split_args_df.iterrows(), 1):
    print(i, row.to_dict())

0 {'predicate': 'adjacent', 'arg_1_id': 'Floor|+00.00|+00.00|+00.00', 'arg_1_type': 'floor', 'arg_2_id': 'building_6', 'arg_2_type': 'building', 'trace_id': '1HOTuIZpRqk2u1nZI1v1-gameplay-attempt-1-rerecorded', 'domain': 'many', 'intervals': [[0, 921]]}


In [55]:
# Rows where the agent (first argument) is adjacent to a green golf ball (second argument).
agent_adjacent_to_green_golfball = (
    (split_args_df['predicate'] == 'adjacent')
    & (split_args_df['arg_1_type'] == 'agent')
    & (split_args_df['arg_2_type'] == 'green_golfball')
)
split_args_df[agent_adjacent_to_green_golfball]

Unnamed: 0,predicate,arg_1_id,arg_1_type,arg_2_id,arg_2_type,trace_id,domain,intervals
1542,adjacent,agent,agent,Golfball|+01.05|+01.04|-02.70,green_golfball,1HOTuIZpRqk2u1nZI1v1-gameplay-attempt-1-rereco...,many,"[[443, 459], [873, 900]]"
7877,adjacent,agent,agent,Golfball|+01.05|+01.04|-02.70,green_golfball,IvoZWi01FO2uiNpNHyci-createGame-rerecorded,many,"[[298, 314]]"
21171,adjacent,agent,agent,Golfball|+01.05|+01.04|-02.70,green_golfball,WtZpe3LQFZiztmh7pBBC-gameplay-attempt-1-rereco...,many,"[[47, 51], [53, 64], [1122, 1181]]"
28515,adjacent,agent,agent,Golfball|+01.05|+01.04|-02.70,green_golfball,FyGQn1qJCLTLU1hfQfZ2-preCreateGame-rerecorded,many,"[[229, 236], [534, 579], [711, 712], [713, 715..."
47304,adjacent,agent,agent,Golfball|+01.05|+01.04|-02.70,green_golfball,39PytL3fAMFkYXNoB5l6-gameplay-attempt-1-rereco...,many,"[[1066, 1067], [1174, 1208], [1225, 1230], [18..."
63520,adjacent,agent,agent,Golfball|+01.05|+01.04|-02.70,green_golfball,f2WUeVzu41E9Lmqmr2FJ-preCreateGame-rerecorded,many,"[[196, 203], [207, 208], [209, 210], [781, 788..."
80473,adjacent,agent,agent,Golfball|+01.05|+01.04|-02.70,green_golfball,IhOkh1l3TBY9JJVubzHx-gameplay-attempt-1-rereco...,many,"[[220, 225], [566, 573], [2511, 2516], [2533, ..."
135658,adjacent,agent,agent,Golfball|+01.05|+01.04|-02.70,green_golfball,FyGQn1qJCLTLU1hfQfZ2-freePlay-rerecorded,many,"[[1036, 1100], [1168, 1177]]"
144910,adjacent,agent,agent,Golfball|+01.05|+01.04|-02.70,green_golfball,IvoZWi01FO2uiNpNHyci-freePlay-rerecorded,many,"[[32, 43]]"
152502,adjacent,agent,agent,Golfball|+01.05|+01.04|-02.70,green_golfball,IvoZWi01FO2uiNpNHyci-preCreateGame-rerecorded,many,"[[592, 609]]"


In [37]:
# Display every row recorded for the `same_type` predicate.
split_args_df.loc[split_args_df['predicate'] == 'same_type']

Unnamed: 0,predicate,arg_1_id,arg_1_type,arg_2_id,arg_2_type,trace_id,domain,intervals
1598,same_type,Beachball|+02.29|+00.19|-02.88,beachball,beachball,object_type,1HOTuIZpRqk2u1nZI1v1-gameplay-attempt-1-rereco...,many,"[[0, 921]]"
1599,same_type,Beachball|+02.29|+00.19|-02.88,beachball,ball,object_type,1HOTuIZpRqk2u1nZI1v1-gameplay-attempt-1-rereco...,many,"[[0, 921]]"
1600,same_type,BridgeBlock|-02.92|+00.09|-02.52,bridge_block,bridge_block,object_type,1HOTuIZpRqk2u1nZI1v1-gameplay-attempt-1-rereco...,many,"[[0, 921]]"
1601,same_type,BridgeBlock|-02.92|+00.09|-02.52,bridge_block,block,object_type,1HOTuIZpRqk2u1nZI1v1-gameplay-attempt-1-rereco...,many,"[[0, 921]]"
1602,same_type,BridgeBlock|-02.92|+00.26|-02.52,bridge_block,bridge_block,object_type,1HOTuIZpRqk2u1nZI1v1-gameplay-attempt-1-rereco...,many,"[[0, 921]]"
...,...,...,...,...,...,...,...,...
375707,same_type,building_19,building,building_14,building,7r4cgxJHzLJooFaMG1Rd-preCreateGame-rerecorded,many,"[[0, 923]]"
375708,same_type,building_19,building,building_15,building,7r4cgxJHzLJooFaMG1Rd-preCreateGame-rerecorded,many,"[[0, 923]]"
375709,same_type,building_19,building,building_16,building,7r4cgxJHzLJooFaMG1Rd-preCreateGame-rerecorded,many,"[[0, 923]]"
375710,same_type,building_19,building,building_17,building,7r4cgxJHzLJooFaMG1Rd-preCreateGame-rerecorded,many,"[[0, 923]]"


In [25]:
# For `in` rows whose first argument is a building, count the types of the second argument.
in_building = (split_args_df['predicate'] == 'in') & (split_args_df['arg_1_type'] == 'building')
split_args_df.loc[in_building, 'arg_2_type'].value_counts()

cube_block                160
dodgeball                 156
pink_dodgeball             66
tan_cube_block             58
hexagonal_bin              53
blue_dodgeball             52
blue_cube_block            51
yellow_cube_block          51
chair                      42
red_dodgeball              38
pillow                     31
golfball                   28
beachball                  26
laptop                     24
curved_wooden_ramp         21
book                       16
key_chain                  13
triangular_ramp            13
cd                         11
credit_card                10
watch                      10
cellphone                   9
pen                         9
green_golfball              8
green_triangular_ramp       6
doggie_bed                  6
basketball                  6
alarm_clock                 6
teddy_bear                  5
tall_rectangular_block      4
pencil                      3
tall_cylindrical_block      3
mug                         3
bridge_blo

In [None]:
# Rows in which a door appears as EITHER argument of the predicate.
# The original mask OR-ed `arg_1_type == 'door'` with itself, so rows where the door is the
# second argument were silently dropped; the second operand now checks `arg_2_type`.
door_df = split_args_df[(split_args_df.arg_1_type == 'door') | (split_args_df.arg_2_type == 'door')]

# How often each predicate involves a door.
print(door_df.predicate.value_counts())
print()

# For each such predicate, the ten most frequent second-argument types.
for predicate in door_df.predicate.unique():
    pred_df = door_df[door_df.predicate == predicate]
    print(f'For predicate {predicate}:')
    print(pred_df.arg_2_type.value_counts().iloc[:10])
    print()

In [None]:
# List the traces in which a door row has the `in_motion` predicate, one id per line.
in_motion_door_trace_ids = door_df.loc[door_df['predicate'] == 'in_motion', 'trace_id']
for tid in in_motion_door_trace_ids:
    print(tid)

In [None]:
# Count the first-argument types across all `in_motion` rows.
is_in_motion = split_args_df['predicate'] == 'in_motion'
split_args_df.loc[is_in_motion, 'arg_1_type'].value_counts()

In [None]:
# Legacy (non-split-args) check: take the first `in_motion` interval list recorded for `ball_id`
# in `trace_id` and pass it, together with that trace's length, to `stats._invert_intervals`
# (presumably yielding the intervals in which the ball is NOT in motion -- TODO confirm).
# NOTE(review): `stats` and `regular_df` are only created by lines that are commented out in the
# data-loading cell, and `trace_id` / `ball_id` are not assigned anywhere in the visible notebook,
# so this cell raises NameError on a fresh kernel unless they are defined first.
stats._invert_intervals(regular_df[(regular_df.trace_id == trace_id) & (regular_df.predicate == 'in_motion') & (regular_df.arg_ids == (ball_id,))].intervals.values[0],
                        stats.trace_lengths[trace_id])

In [None]:
# Rows of the selected trace recording `in` with `bin_id` as first and `ball_id` as second argument.
# NOTE(review): `trace_id`, `bin_id` and `ball_id` must already be defined; no visible cell assigns them.
ball_in_bin_rows = (
    (split_args_df['trace_id'] == trace_id)
    & (split_args_df['predicate'] == 'in')
    & (split_args_df['arg_1_id'] == bin_id)
    & (split_args_df['arg_2_id'] == ball_id)
)
split_args_df[ball_in_bin_rows]

In [None]:
# First `in_motion` interval list recorded for `ball_id` in the selected trace.
ball_in_motion_rows = split_args_df[
    (split_args_df['trace_id'] == trace_id)
    & (split_args_df['predicate'] == 'in_motion')
    & (split_args_df['arg_1_id'] == ball_id)
]
ball_in_motion_int = ball_in_motion_rows['intervals'].values[0]

# First element of the trace's entry in `trace_lengths_and_domains`
# (presumably the trace length -- TODO confirm).
trace_length = split_args_stats.trace_lengths_and_domains[trace_id][0]

ball_not_in_motion_int = split_args_stats._invert_intervals(ball_in_motion_int, trace_length)

ball_not_in_motion_int

In [None]:
# First `in` interval list for (`bin_id`, `ball_id`) in the selected trace.
ball_in_bin_selector = (
    (split_args_df['trace_id'] == trace_id)
    & (split_args_df['predicate'] == 'in')
    & (split_args_df['arg_1_id'] == bin_id)
    & (split_args_df['arg_2_id'] == ball_id)
)
ball_in_bin_int = split_args_df[ball_in_bin_selector]['intervals'].values[0]

# Combine with the previously computed `ball_not_in_motion_int` via the stats helper.
split_args_stats._intersect_intervals(ball_in_bin_int, ball_not_in_motion_int)

In [None]:
# Keep only the rows that belong to the selected trace.
rows_in_trace = split_args_df['trace_id'] == trace_id
single_trace_split_args_df = split_args_df.loc[rows_in_trace]

# Build a second statistics helper and replace its data with the single-trace rows,
# converted from pandas to a polars DataFrame.
single_trace_split_args_stats = compile_predicate_statistics_split_args.CommonSensePredicateStatisticsSplitArgs(cache_dir)
single_trace_split_args_stats.data = pl.from_pandas(single_trace_split_args_df)
single_trace_split_args_stats.data.shape

In [None]:
# Compile the DSL grammar. The file is read inside a `with` block so the handle is closed
# deterministically (the previous bare `open(...).read()` left it to the garbage collector).
DEFAULT_GRAMMAR_PATH = "../dsl/dsl.ebnf"
with open(DEFAULT_GRAMMAR_PATH) as grammar_file:
    grammar = grammar_file.read()
grammar_parser = typing.cast(tatsu.grammars.Grammar, tatsu.compile(grammar))

# Parse the example game used to exercise the predicate filter.
game_path = compile_predicate_statistics.get_project_dir() + '/reward-machine/games/ball_to_bin_from_bed.txt'
with open(game_path) as game_file:
    game = game_file.read()
game_ast = grammar_parser.parse(game)

# Both test predicates live under the `then_funcs` list of the first preference; look it up once.
then_funcs = game_ast[4][1]['preferences'][0]['definition']['forall_pref']['preferences']['pref_body']['body']['exists_args']['then_funcs']

test_pred_1 = then_funcs[1]['seq_func']['hold_pred']

# should be: (and (not (in_motion ?b)) (in ?h ?b))
test_pred_2 = then_funcs[2]['seq_func']['once_pred']

In [None]:
# Variable-to-type mapping passed to the filter alongside `test_pred_2`.
test_mapping = {
    "?b": ["ball"],
    "?h": ["hexagonal_bin"],
}

# Run the filter against the single-trace statistics object and display the result.
test_out = single_trace_split_args_stats.filter(test_pred_2, test_mapping)
test_out