In [78]:
%autoreload 2

In [79]:
from argparse import Namespace
from collections import defaultdict
import copy
from datetime import datetime
import difflib
import gzip
import itertools
import os
import pickle
import sys
import typing

from IPython.display import display, Markdown, HTML  # type: ignore
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import tatsu
import tqdm.notebook as tqdmn


sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../reward-machine'))

import compile_predicate_statistics
import compile_predicate_statistics_split_args

In [80]:
cache_dir = compile_predicate_statistics.get_project_dir() + '/reward-machine/caches'

# regular_df = pd.read_pickle(os.path.join(cache_dir, 'predicate_statistics.pkl'))
split_args_df = pd.read_pickle(os.path.join(cache_dir, 'predicate_statistics_4d5dd602.pkl.gz'))

# stats = compile_predicate_statistics.CommonSensePredicateStatistics(cache_dir)
split_args_stats = compile_predicate_statistics_split_args.CommonSensePredicateStatisticsSplitArgs(
    cache_dir, compile_predicate_statistics_split_args.CURRENT_TEST_TRACE_NAMES, overwrite=False)

Loaded data with shape (309215, 8) from /Users/guydavidson/projects/game-generation-modeling/reward-machine/caches/predicate_statistics_4d5dd602.pkl.gz


In [81]:
split_args_df[split_args_df.predicate == 'object_orientation'].groupby(['arg_1_id', 'arg_2_id']).count()
split_args_df[split_args_df.predicate == 'object_orientation'].groupby(['arg_2_id']).count()

Unnamed: 0_level_0,predicate,arg_1_id,arg_1_type,arg_2_type,trace_id,domain,intervals
arg_2_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
diagonal,553,553,553,553,553,553,553
sideways,474,474,474,474,474,474,474
upright,5882,5882,5882,5882,5882,5882,5882
upside_down,358,358,358,358,358,358,358


In [83]:
split_args_df[split_args_df.predicate == 'broken']

Unnamed: 0,predicate,arg_1_id,arg_1_type,arg_2_id,arg_2_type,trace_id,domain,intervals
80780,broken,Mirror|+00.45|+01.49|+00.62,mirror,,,LTZh4k4THamxI5QJfVrk-preCreateGame-rerecorded,few,"[[950, 1069]]"
83586,broken,Window|+02.28|+00.93|-03.18,window,,,79X7tsrbEIu5ffDGnY8q-gameplay-attempt-1-rereco...,many,"[[1161, 2079]]"
84095,broken,sliding_door,sliding_door,,,79X7tsrbEIu5ffDGnY8q-gameplay-attempt-1-rereco...,many,"[[1161, 2079]]"
89393,broken,CellPhone|+02.96|+00.79|-00.93,cellphone,,,jCc0kkmGUg3xUmUSXg5w-gameplay-attempt-1-rereco...,few,"[[2049, 4588]]"
92344,broken,Window|+02.28|+00.93|-03.18,window,,,ktwB7wT09sh4ivNme3Dw-createGame-rerecorded,medium,"[[1054, 1182]]"
92651,broken,sliding_door,sliding_door,,,ktwB7wT09sh4ivNme3Dw-createGame-rerecorded,medium,"[[1054, 1182]]"
118644,broken,Mirror|+00.45|+01.49|+00.62,mirror,,,IvoZWi01FO2uiNpNHyci-freePlay-rerecorded,many,"[[586, 785]]"
202110,broken,Mirror|+00.45|+01.49|+00.62,mirror,,,Tcfpwc8v8HuKRyZr5Dyc-gameplay-attempt-1-rereco...,medium,"[[51, 5090]]"
202351,broken,Window|+02.28|+00.93|-03.18,window,,,Tcfpwc8v8HuKRyZr5Dyc-gameplay-attempt-1-rereco...,medium,"[[2942, 5090]]"
202754,broken,sliding_door,sliding_door,,,Tcfpwc8v8HuKRyZr5Dyc-gameplay-attempt-1-rereco...,medium,"[[2942, 5090]]"


In [33]:
door_df = split_args_df[(split_args_df.arg_1_type == 'door') | (split_args_df.arg_1_type == 'door')]
print(door_df.predicate.value_counts())
print()
for predicate in door_df.predicate.unique():
    pred_df = door_df[door_df.predicate == predicate]
    print(f'For predicate {predicate}:')
    print(pred_df.arg_2_type.value_counts().iloc[:10])
    print()

adjacent    338
touch        16
on           15
Name: predicate, dtype: int64

For predicate adjacent:
main_light_switch    77
mirror               77
dodgeball            43
agent                36
pink_dodgeball       19
red_dodgeball        14
golfball             10
blue_dodgeball       10
beachball             5
cube_block            5
Name: arg_2_type, dtype: int64

For predicate touch:
chair              4
book               3
dodgeball          3
red_dodgeball      2
cd                 1
triangular_ramp    1
hexagonal_bin      1
pink_dodgeball     1
Name: arg_2_type, dtype: int64

For predicate on:
chair              4
dodgeball          3
book               2
red_dodgeball      2
cd                 1
triangular_ramp    1
hexagonal_bin      1
pink_dodgeball     1
Name: arg_2_type, dtype: int64



In [30]:
for tid in door_df[door_df.predicate == 'in_motion'].trace_id:
    print(tid)

jCc0kkmGUg3xUmUSXg5w-gameplay-attempt-1-rerecorded
jCc0kkmGUg3xUmUSXg5w-preCreateGame-rerecorded
NJUY0YT1Pq6dZXsmw0wE-createGame-rerecorded


In [32]:
split_args_df[split_args_df.predicate == 'in_motion'].arg_1_type.value_counts()

dodgeball                 101
cube_block                 97
key_chain                  63
agent                      58
hexagonal_bin              48
pillow                     45
pink_dodgeball             41
chair                      40
tan_cube_block             34
golfball                   33
beachball                  32
blue_dodgeball             32
blue_cube_block            32
yellow_cube_block          31
red_dodgeball              28
watch                      26
laptop                     24
book                       22
teddy_bear                 22
cellphone                  21
credit_card                18
desktop                    17
doggie_bed                 17
curved_wooden_ramp         16
bridge_block               16
pen                        15
triangular_ramp            14
drawer                     13
cd                         13
mug                        13
pencil                     12
green_golfball             10
lamp                       10
alarm_cloc

In [None]:
stats._invert_intervals(regular_df[(regular_df.trace_id == trace_id) & (regular_df.predicate == 'in_motion') & (regular_df.arg_ids == (ball_id,))].intervals.values[0],
                        stats.trace_lengths[trace_id])

In [None]:
split_args_df[(split_args_df.trace_id == trace_id) & (split_args_df.predicate == 'in') & (split_args_df.arg_1_id == bin_id) & (split_args_df.arg_2_id == ball_id)]

In [None]:
ball_not_in_motion_int = split_args_stats._invert_intervals(
    split_args_df[(split_args_df.trace_id == trace_id) & (split_args_df.predicate == 'in_motion') & (split_args_df.arg_1_id == ball_id)].intervals.values[0],
    split_args_stats.trace_lengths_and_domains[trace_id][0],
)

ball_not_in_motion_int

In [None]:
ball_in_bin_int = split_args_df[(split_args_df.trace_id == trace_id) & (split_args_df.predicate == 'in') & (split_args_df.arg_1_id == bin_id) & (split_args_df.arg_2_id == ball_id)].intervals.values[0]

split_args_stats._intersect_intervals(ball_in_bin_int, ball_not_in_motion_int)

In [None]:
single_trace_split_args_df = split_args_df[split_args_df['trace_id'] == trace_id]

single_trace_split_args_stats = compile_predicate_statistics_split_args.CommonSensePredicateStatisticsSplitArgs(cache_dir)
single_trace_split_args_stats.data = pl.from_pandas(single_trace_split_args_df)
single_trace_split_args_stats.data.shape

In [None]:
DEFAULT_GRAMMAR_PATH = "../dsl/dsl.ebnf"
grammar = open(DEFAULT_GRAMMAR_PATH).read()
grammar_parser = typing.cast(tatsu.grammars.Grammar, tatsu.compile(grammar))

game = open(compile_predicate_statistics.get_project_dir() + '/reward-machine/games/ball_to_bin_from_bed.txt').read()
game_ast = grammar_parser.parse(game) 

test_pred_1 = game_ast[4][1]['preferences'][0]['definition']['forall_pref']['preferences']['pref_body']['body']['exists_args']['then_funcs'][1]['seq_func']['hold_pred']

# should be: (and (not (in_motion ?b)) (in ?h ?b)))
test_pred_2 = game_ast[4][1]['preferences'][0]['definition']['forall_pref']['preferences']['pref_body']['body']['exists_args']['then_funcs'][2]['seq_func']['once_pred']

In [None]:
test_mapping = {"?b": ["ball"], "?h": ["hexagonal_bin"]}
test_out = single_trace_split_args_stats.filter(test_pred_2, test_mapping)
test_out