In [1]:
%autoreload 2

In [45]:
from argparse import Namespace
from collections import defaultdict
import copy
from datetime import datetime
import difflib
import gzip
import itertools
import os
import pickle
import sys
import typing

from IPython.display import display, Markdown, HTML  # type: ignore
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import tatsu
import tqdm.notebook as tqdmn


# Extend the module search path with the parent directory and its
# `reward-machine` subdirectory (resolved relative to the current working directory).
# NOTE(review): presumably this is what makes the `compile_predicate_statistics*`
# imports below resolvable — confirm the notebook is always launched from this folder.
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../reward-machine'))

import compile_predicate_statistics
import compile_predicate_statistics_split_args

In [41]:
# Directory that holds the cached predicate-statistics pickles.
cache_dir = compile_predicate_statistics.get_project_dir() + '/reward-machine/caches'

# Raw pandas views of both caches, loaded directly for ad-hoc inspection.
# NOTE(review): unpickling can execute arbitrary code — only load cache files you trust.
regular_pickle_path = os.path.join(cache_dir, 'predicate_statistics.pkl')
split_args_pickle_path = os.path.join(cache_dir, 'predicate_statistics_split_args.pkl')
regular_df = pd.read_pickle(regular_pickle_path)
split_args_df = pd.read_pickle(split_args_pickle_path)

# Statistics wrappers built from the same cache directory.
stats = compile_predicate_statistics.CommonSensePredicateStatistics(cache_dir)
split_args_stats = compile_predicate_statistics_split_args.CommonSensePredicateStatisticsSplitArgs(cache_dir)

Loaded data with shape (43521, 6) from /Users/guydavidson/projects/game-generation-modeling/reward-machine/caches/predicate_statistics.pkl
Loaded data with shape (43521, 8) from /Users/guydavidson/projects/game-generation-modeling/reward-machine/caches/predicate_statistics_split_args.pkl


Debugging `(and (not (in_motion ?b)) (in ?h ?b))`

In [26]:
# Identifiers of the single trace and the (bin, ball) object pair being inspected.
trace_id = '7r4cgxJHzLJooFaMG1Rd-gameplay-attempt-1-rerecorded'
ball_id = 'Dodgeball|+00.19|+01.13|-02.80'
bin_id = 'GarbageCan|+00.75|-00.03|-02.74'
arg_ids = (bin_id, ball_id)
arg_types = ('garbagecan', 'dodgeball')
mapping = {'?b': ['ball'], '?h': ['hexagonal_bin']}

# Rows of the un-split cache for predicate `in` with exactly these arguments in this trace.
regular_df.loc[
    (regular_df['trace_id'] == trace_id)
    & (regular_df['predicate'] == 'in')
    & (regular_df['arg_ids'] == arg_ids)
]

Unnamed: 0,predicate,arg_ids,arg_types,trace_id,domain,intervals
21566,in,"(GarbageCan|+00.75|-00.03|-02.74, Dodgeball|+0...","(garbagecan, dodgeball)",7r4cgxJHzLJooFaMG1Rd-gameplay-attempt-1-rereco...,many,"[[3304, 3640]]"


In [78]:
# `adjacent` rows whose two argument types are a desk followed by some kind of block.
regular_df[
    (regular_df['predicate'] == 'adjacent')
    & regular_df['arg_types'].apply(
        lambda types: len(types) == 2 and types[0] == 'desk' and 'block' in types[1]
    )
]

Unnamed: 0,predicate,arg_ids,arg_types,trace_id,domain,intervals
3869,adjacent,"(Desk|+03.14|00.00|-01.41, CubeBlock|-00.02|+0...","(desk, cubeblock)",6ZjBeRCvHxG05ORmhInj-gameplay-attempt-1-rereco...,few,"[[1597, 2479]]"
3878,adjacent,"(Desk|+03.14|00.00|-01.41, CubeBlock|-00.02|+0...","(desk, cubeblock)",6ZjBeRCvHxG05ORmhInj-gameplay-attempt-1-rereco...,few,"[[1655, 1658]]"
3884,adjacent,"(Desk|+03.14|00.00|-01.41, CubeBlock|-00.24|+0...","(desk, cubeblock)",6ZjBeRCvHxG05ORmhInj-gameplay-attempt-1-rereco...,few,"[[1711, 2479]]"
3904,adjacent,"(Desk|+03.14|00.00|-01.41, CubeBlock|+00.20|+0...","(desk, cubeblock)",6ZjBeRCvHxG05ORmhInj-gameplay-attempt-1-rereco...,few,"[[1766, 2479]]"
3924,adjacent,"(Desk|+03.14|00.00|-01.41, CubeBlock|+00.20|+0...","(desk, cubeblock)",6ZjBeRCvHxG05ORmhInj-gameplay-attempt-1-rereco...,few,"[[1844, 2479]]"
9252,adjacent,"(Desk|+03.14|00.00|-01.41, TallRectBlock|-02.9...","(desk, tallrectblock)",IhOkh1l3TBY9JJVubzHx-gameplay-attempt-1-rereco...,many,"[[1348, 1377]]"
10436,adjacent,"(Desk|+03.14|00.00|-01.41, CubeBlock|-00.24|+0...","(desk, cubeblock)",vfh1MTEQorWXKy8jOP1x-gameplay-attempt-2-rereco...,few,"[[1117, 1121]]"
18300,adjacent,"(Desk|+03.14|00.00|-01.41, CubeBlock|-02.97|+0...","(desk, cubeblock)",IvoZWi01FO2uiNpNHyci-preCreateGame-rerecorded,many,"[[1351, 1923]]"
18306,adjacent,"(Desk|+03.14|00.00|-01.41, LongCylinderBlock|-...","(desk, longcylinderblock)",IvoZWi01FO2uiNpNHyci-preCreateGame-rerecorded,many,"[[1418, 1442]]"
39066,adjacent,"(Desk|+03.14|00.00|-01.41, CubeBlock|-00.02|+0...","(desk, cubeblock)",NJUY0YT1Pq6dZXsmw0wE-createGame-rerecorded,few,"[[1821, 1841], [1842, 1853], [1879, 4464]]"


In [79]:
# Show the column names of the split-args statistics table (one id/type column pair per argument).
split_args_stats.data.columns

['predicate',
 'arg_1_id',
 'arg_1_type',
 'arg_2_id',
 'arg_2_type',
 'trace_id',
 'domain',
 'intervals']

In [171]:

# NOTE(review): `names` is not assigned in any cell visible here — presumably leftover
# kernel state from a deleted or out-of-order cell; confirm this survives Restart & Run All.
names

['?a', '?b']

In [178]:
# Distinct domain values, flattened out of the single-column row tuples.
domains = list(itertools.chain.from_iterable(split_args_stats.data.select('domain').unique().rows()))
variable_types = (('ball', 'dodgeball'),)  #  ('chair', 'bed'))
# One placeholder name per variable: '?a', '?b', ...
field_names = [f'?{chr(ord("a") + index)}' for index, _ in enumerate(variable_types)]
possible_arg_assignments = [
    split_args_stats._object_assignments(domain, variable_types)
    for domain in domains
]

# One row per domain: its candidate assignments plus an (initially empty) interval list.
assignments_df = pl.DataFrame(
    {
        'domain': domains,
        'assignments': possible_arg_assignments,
        'intervals': [[] for _ in domains],
    },
    schema={
        'domain': None,
        'assignments': None,
        'intervals': pl.List(pl.List(pl.Int64)),
    },
)
print(assignments_df.shape)

# Attach the per-domain assignments to every (domain, trace_id) row, then expand each
# assignment into one row with one column per variable.
tdf = split_args_stats.data.select('domain', 'trace_id').join(assignments_df, on='domain')
(
    tdf
    .explode('assignments')
    .select('domain', 'trace_id', pl.col("assignments").list.to_struct(fields=field_names), 'intervals')
    .unnest('assignments')
)

(3, 3)


domain,trace_id,?a,intervals
str,str,str,list[list[i64]]
"""many""","""1HOTuIZpRqk2u1…","""Beachball|+02.…",[]
"""many""","""1HOTuIZpRqk2u1…","""Dodgeball|+00.…",[]
"""many""","""1HOTuIZpRqk2u1…","""Dodgeball|+00.…",[]
"""many""","""1HOTuIZpRqk2u1…","""Dodgeball|+00.…",[]
"""many""","""1HOTuIZpRqk2u1…","""Golfball|+00.9…",[]
"""many""","""1HOTuIZpRqk2u1…","""Golfball|+01.0…",[]
"""many""","""1HOTuIZpRqk2u1…","""Golfball|+01.1…",[]
"""many""","""1HOTuIZpRqk2u1…","""Dodgeball|+00.…",[]
"""many""","""1HOTuIZpRqk2u1…","""Dodgeball|+00.…",[]
"""many""","""1HOTuIZpRqk2u1…","""Dodgeball|+00.…",[]


In [159]:
# Small single-row frame used to try out the explode -> struct -> unnest expansion.
assignment = possible_arg_assignments[0]
var_keys = ['?a', '?b']

test_df = pl.DataFrame(dict(domain=['few'], trace_id=['foo'], assignments=[assignment]))

# Use `var_keys` for the struct field names instead of repeating the literal, so the
# variable names are defined in exactly one place.
# NOTE(review): assumes every entry of `assignment` has len(var_keys) elements, i.e. that
# `possible_arg_assignments` was built with two variable types — confirm before re-running.
test_df.explode('assignments').select('domain', 'trace_id', pl.col("assignments").list.to_struct(fields=var_keys)).unnest('assignments')

domain,trace_id,?a,?b
str,str,str,str
"""few""","""foo""","""Dodgeball|-02.…","""Chair|+02.73|0…"
"""few""","""foo""","""Dodgeball|-02.…","""Chair|+02.83|0…"
"""few""","""foo""","""Dodgeball|-02.…","""Bed|-02.46|00.…"
"""few""","""foo""","""Dodgeball|-02.…","""Chair|+02.73|0…"
"""few""","""foo""","""Dodgeball|-02.…","""Chair|+02.83|0…"
"""few""","""foo""","""Dodgeball|-02.…","""Bed|-02.46|00.…"
"""few""","""foo""","""Dodgeball|-02.…","""Chair|+02.73|0…"
"""few""","""foo""","""Dodgeball|-02.…","""Chair|+02.83|0…"
"""few""","""foo""","""Dodgeball|-02.…","""Bed|-02.46|00.…"
"""few""","""foo""","""Dodgeball|-02.…","""Chair|+02.73|0…"


In [181]:
# Every distinct trace id present in the split-args cache (long output).
list(split_args_df['trace_id'].unique())

['1HOTuIZpRqk2u1nZI1v1-gameplay-attempt-1-rerecorded',
 'IvoZWi01FO2uiNpNHyci-createGame-rerecorded',
 '4WUtnD8W6PGVy0WBtVm4-gameplay-attempt-1-rerecorded',
 'LTZh4k4THamxI5QJfVrk-gameplay-attempt-1-rerecorded',
 'WtZpe3LQFZiztmh7pBBC-gameplay-attempt-1-rerecorded',
 'FyGQn1qJCLTLU1hfQfZ2-preCreateGame-rerecorded',
 '6ZjBeRCvHxG05ORmhInj-gameplay-attempt-1-rerecorded',
 'Tcfpwc8v8HuKRyZr5Dyc-gameplay-attempt-2-rerecorded',
 '4WUtnD8W6PGVy0WBtVm4-createGame-rerecorded',
 '39PytL3fAMFkYXNoB5l6-gameplay-attempt-1-rerecorded',
 '5lTRHBueXsaOu9yhvOQo-gameplay-attempt-1-rerecorded',
 'SQErBa5s5TPVxmm8R6ks-freePlay-rerecorded',
 '9C0wMm4lzrJ5JeP0irIu-preCreateGame-rerecorded',
 'f2WUeVzu41E9Lmqmr2FJ-preCreateGame-rerecorded',
 '6XD5S6MnfzAPQlsP7k30-gameplay-attempt-2-rerecorded',
 'xMUrxzK3fXjgitdzPKsm-freePlay-rerecorded',
 'IhOkh1l3TBY9JJVubzHx-gameplay-attempt-1-rerecorded',
 'WtZpe3LQFZiztmh7pBBC-createGame-rerecorded',
 'vfh1MTEQorWXKy8jOP1x-gameplay-attempt-2-rerecorded',
 'LTZh4k4THamx

In [34]:
# Invert the ball's `in_motion` intervals against this trace's length, using the un-split
# statistics (presumably yielding the spans where the ball is NOT in motion).
stats._invert_intervals(
    regular_df.loc[
        (regular_df['trace_id'] == trace_id)
        & (regular_df['predicate'] == 'in_motion')
        & (regular_df['arg_ids'] == (ball_id,)),
        'intervals',
    ].values[0],
    stats.trace_lengths[trace_id],
)

[[0, 519],
 [559, 598],
 [624, 726],
 [730, 1190],
 [1191, 1623],
 [1679, 1884],
 [1924, 2081],
 [2090, 2312],
 [2367, 2474],
 [2568, 2946],
 [2985, 3054],
 [3118, 3145],
 [3176, 3200],
 [3256, 3289],
 [3311, 3312],
 [3327, 3537],
 [3538, 3640]]

In [36]:
# Split-args cache rows for predicate `in` with the bin as first and the ball as second argument.
split_args_df.loc[
    (split_args_df['trace_id'] == trace_id)
    & (split_args_df['predicate'] == 'in')
    & (split_args_df['arg_1_id'] == bin_id)
    & (split_args_df['arg_2_id'] == ball_id)
]

Unnamed: 0,predicate,arg_1_id,arg_1_type,arg_2_id,arg_2_type,trace_id,domain,intervals
21566,in,GarbageCan|+00.75|-00.03|-02.74,garbagecan,Dodgeball|+00.19|+01.13|-02.80,dodgeball,7r4cgxJHzLJooFaMG1Rd-gameplay-attempt-1-rereco...,many,"[[3304, 3640]]"


In [43]:
# Invert the ball's `in_motion` intervals using the split-args statistics; the first
# element of `trace_lengths_and_domains[trace_id]` is passed as the trace length.
ball_not_in_motion_int = split_args_stats._invert_intervals(
    split_args_df.loc[
        (split_args_df['trace_id'] == trace_id)
        & (split_args_df['predicate'] == 'in_motion')
        & (split_args_df['arg_1_id'] == ball_id),
        'intervals',
    ].values[0],
    split_args_stats.trace_lengths_and_domains[trace_id][0],
)

ball_not_in_motion_int

[[0, 519],
 [559, 598],
 [624, 726],
 [730, 1190],
 [1191, 1623],
 [1679, 1884],
 [1924, 2081],
 [2090, 2312],
 [2367, 2474],
 [2568, 2946],
 [2985, 3054],
 [3118, 3145],
 [3176, 3200],
 [3256, 3289],
 [3311, 3312],
 [3327, 3537],
 [3538, 3640]]

In [44]:
# Intervals during which the ball is recorded as `in` the bin for this trace.
ball_in_bin_int = split_args_df.loc[
    (split_args_df['trace_id'] == trace_id)
    & (split_args_df['predicate'] == 'in')
    & (split_args_df['arg_1_id'] == bin_id)
    & (split_args_df['arg_2_id'] == ball_id),
    'intervals',
].values[0]

# Manual evaluation of the conjunction: intersect "in bin" with the inverted `in_motion` intervals.
split_args_stats._intersect_intervals(ball_in_bin_int, ball_not_in_motion_int)

[[3311, 3312], [3327, 3537], [3538, 3640]]

In [47]:
# Keep only the rows belonging to the trace under investigation.
single_trace_split_args_df = split_args_df.loc[split_args_df.trace_id == trace_id]

# Build a fresh statistics object, then replace its table with the single-trace subset
# so that queries on it only see this trace.
single_trace_split_args_stats = compile_predicate_statistics_split_args.CommonSensePredicateStatisticsSplitArgs(cache_dir)
single_trace_split_args_stats.data = pl.from_pandas(single_trace_split_args_df)
single_trace_split_args_stats.data.shape

Loaded data with shape (43521, 8) from /Users/guydavidson/projects/game-generation-modeling/reward-machine/caches/predicate_statistics_split_args.pkl


(901, 8)

In [49]:
DEFAULT_GRAMMAR_PATH = "../dsl/dsl.ebnf"

# Read the grammar inside a context manager so the file handle is closed
# deterministically instead of being left for the garbage collector.
with open(DEFAULT_GRAMMAR_PATH) as grammar_file:
    grammar = grammar_file.read()
grammar_parser = typing.cast(tatsu.grammars.Grammar, tatsu.compile(grammar))

# Same for the game definition that is parsed with the compiled grammar.
with open(compile_predicate_statistics.get_project_dir() + '/reward-machine/games/ball_to_bin_from_bed.txt') as game_file:
    game = game_file.read()
game_ast = grammar_parser.parse(game)

# Both test predicates live under the same `then_funcs` list of the first preference,
# so walk the shared AST path once instead of repeating it.
then_funcs = game_ast[4][1]['preferences'][0]['definition']['forall_pref']['preferences']['pref_body']['body']['exists_args']['then_funcs']

test_pred_1 = then_funcs[1]['seq_func']['hold_pred']

# should be: (and (not (in_motion ?b)) (in ?h ?b))
test_pred_2 = then_funcs[2]['seq_func']['once_pred']

In [55]:
# Evaluate the parsed predicate `test_pred_2` against the single-trace statistics, with
# ?b restricted to type `ball` and ?h to type `hexagonal_bin`.
test_mapping = {"?b": ["ball"], "?h": ["hexagonal_bin"]}
test_out = single_trace_split_args_stats.filter(test_pred_2, test_mapping)
# NOTE(review): in the recorded output, the Dodgeball|+00.19 / GarbageCan entry lacks the
# [3311, 3312] interval that the manual `_intersect_intervals` cell produced for the same
# pair — presumably the discrepancy this notebook is debugging; confirm.
test_out

{('7r4cgxJHzLJooFaMG1Rd-gameplay-attempt-1-rerecorded',
  ('?b->Dodgeball|+00.19|+01.13|-02.80',
   '?h->GarbageCan|+00.75|-00.03|-02.74')): [[3327, 3537], [3538, 3640]],
 ('7r4cgxJHzLJooFaMG1Rd-gameplay-attempt-1-rerecorded',
  ('?b->Dodgeball|+00.70|+01.11|-02.80',
   '?h->GarbageCan|+00.75|-00.03|-02.74')): [[2530, 3205],
  [3207, 3295],
  [3311, 3314],
  [3327, 3537],
  [3547, 3640]]}