In [1]:
%autoreload 2

In [2]:
from argparse import Namespace
from collections import defaultdict
import dataclasses
import copy
import gzip
import os
import pickle
import sys
import time
import typing
import re

from IPython.display import display, Markdown, Latex
import numpy as np
import tqdm.notebook as tqdm

sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src'))
from src.ast_utils import load_games_from_file, _extract_game_id
from ast_counter_sampler import *

In [8]:
RANDOM_SEED = 33

args = Namespace(
    parse_counter=False, 
    counter_output_path=os.path.join('..', DEFAULT_COUNTER_OUTPUT_PATH),
    grammar_file=os.path.join('..', DEFAULT_GRAMMAR_FILE),
    random_seed=RANDOM_SEED,
    test_files=['../dsl/interactive-beta.pddl'],
    dont_tqdm=False,
)

grammar = open(args.grammar_file).read()
grammar_parser = tatsu.compile(grammar)
counter = parse_or_load_counter(args, grammar_parser)
sampler = ASTSampler(grammar_parser, counter, seed=args.random_seed)

PATTERN_RULES = {k: re.compile(v['pattern']) for k, v in sampler.rules.items() if 'pattern' in v}
PATTERN_RULES['predicate_name'] = PATTERN_RULES['name']
PATTERN_RULES['type_name'] = PATTERN_RULES['name']



No counted data for function_distance_side_4.arg_1
No counted data for function_distance_side_4.arg_2
No counted data for function_distance_side_4.arg_3
No counted data for function_distance_side_4.arg_4
No counted data for setup_not.not_args
String token rule for any: {'token_posterior': {'(any)': 1}, 'production': [('token', '(any)')]}
No counted data for hold_for.hold_pred
No counted data for hold_to_end.hold_pred
No counted data for forall_seq.forall_seq_vars
No counted data for forall_seq.forall_seq_then
No counted data for terminal_not.not_args
No counted data for scoring_external_minimize.scoring_expr
No counted data for scoring_and.and_args
No counted data for scoring_or.or_args
No counted data for scoring_not.not_args
String token rule for total_time: {'token_posterior': {'(total-time)': 1}, 'production': [('token', '(total-time)')]}
String token rule for total_score: {'token_posterior': {'(total-score)': 1}, 'production': [('token', '(total-score)')]}
No counted data for scor

In [13]:
@dataclasses.dataclass
class Expansion:
    rule: str
    expansions: typing.List[str] = dataclasses.field(default_factory=list)
    probability: float = 1.0

    def __str__(self):
        return f'{self.rule.upper()} &rarr; {" ".join(self.expansions)}    (P = {self.probability:.4f})'


def _add_to_expansions(expansions: typing.List[Expansion], addition: str, probability=None):
    for expansion in expansions:
        expansion.expansions.append(addition)

        if probability is not None:
            expansion.probability *= probability
    
    return expansions


def _multiple_add_to_components(expansions: typing.List[Expansion], additions_to_probabilities: typing.Dict[str, float]):
    if len(additions_to_probabilities) == 1:
        key = next(iter(additions_to_probabilities))
        return _add_to_expansions(expansions, key.upper(), additions_to_probabilities[key])

    original_length = len(expansions)
    temp = [copy.deepcopy(expansions) for _ in range(len(additions_to_probabilities))]
    expansions = [comp for component in temp for comp in component]

    for i, (addition, probability) in enumerate(additions_to_probabilities.items()):
        for j in range(original_length):
            current = expansions[i * original_length + j]
            current.expansions.append(addition.upper())
            if probability is not None:
                current.probability *= probability

    return expansions


def print_rule_cfg(rule: str, rules: typing.Dict[str, typing.Dict[str, typing.Any]],
    pattern_rules: typing.Dict[str, re.Pattern] = PATTERN_RULES):

    if rule not in rules:
        return [], [], {}

    rule_dict = rules[rule]

    if PRODUCTION not in rule_dict:
        raise ValueError(f'No production for {rule}')

    expansions = [Expansion(rule)]
    encountered_rules = []
    encountered_tokens = defaultdict(dict)

    for prod_type, prod_value in rule_dict[PRODUCTION]:
        if prod_type == RULE:
            expansions = _add_to_expansions(expansions, prod_value.upper())
            encountered_rules.append(prod_value)

        elif prod_type == TOKEN:
            if prod_value == SAMPLE:
                expansions = _multiple_add_to_components(expansions, rule_dict[TOKEN_POSTERIOR])

            elif prod_value != EOF:
                expansions = _add_to_expansions(expansions, prod_value)

        elif prod_type == NAMED:
            if prod_value not in rule_dict:
                raise ValueError(f'Unknown named {prod_value}')

            named_dict = rule_dict[prod_value]
            options = named_dict[OPTIONS]
            
            if isinstance(options, str):
                out = options.upper()
                if MIN_LENGTH in named_dict:
                    min_length = named_dict[MIN_LENGTH]
                    out = f'[ {out} ]{"*" if min_length == 0 else "+"}'

                expansions = _add_to_expansions(expansions, out)
                encountered_rules.append(options)
                options = [options]

            else:
                if MIN_LENGTH in named_dict:
                    raise ValueError(f'Min length not supported for multiple options {prod_value}')
                expansions = _multiple_add_to_components(expansions, 
                    {opt: (named_dict[RULE_POSTERIOR][opt] if (RULE_POSTERIOR in named_dict and opt in named_dict[RULE_POSTERIOR]) else None) 
                    for opt in options}  # type: ignore
                )
                encountered_rules.extend(options)

            if TOKEN_POSTERIOR in named_dict:
                for t, p in named_dict[TOKEN_POSTERIOR].items():
                    if t in rules: continue

                    if len(options) == 1:
                        encountered_tokens[options[0]][t] = p
                        continue

                    if isinstance(t, (int, float)):
                        encountered_tokens['number'][t] = p
                        continue

                    for opt in options:
                        if opt in pattern_rules and pattern_rules[opt].match(t):
                            encountered_tokens[opt][t] = p
                            break

        elif prod_type == PATTERN:
            expansions = _add_to_expansions(expansions, f'/{prod_value}/')

        else:
            raise ValueError(f'Unknown production type {prod_type}')

    sum_probs = sum(exp.probability for exp in expansions)
    for exp in expansions:
        exp.probability /= sum_probs

    return [str(exp) for exp in expansions], encountered_rules, encountered_tokens


def print_rules_cfg(rules: typing.Dict[str, typing.Dict[str, typing.Any]], start_rule: str = START):
    rule_frontier = [start_rule]
    encountered_rules = set(rule_frontier)
    
    cfg_rules = ['## CFG-like grammar printout']
    token_rule_expansions = defaultdict(dict)

    while rule_frontier:
        rule = rule_frontier.pop()
        cfgs, new_rules, new_tokens = print_rule_cfg(rule, rules)
        cfg_rules.extend(cfgs)

        for new_rule in reversed(new_rules):
            if new_rule not in encountered_rules:
                rule_frontier.append(new_rule)
                encountered_rules.add(new_rule)

        for rule, tokens in new_tokens.items():
            token_rule_expansions[rule].update(tokens)

    cfg_rules.append('## Token rule expansions')
    for rule, tokens in token_rule_expansions.items():
        for token, prob in tokens.items():
            cfg_rules.append(f'{rule.upper()} &rarr; {token}    (P = {prob:.4f})')

    return '\n\n'.join(cfg_rules)

markdown_output = print_rules_cfg(sampler.rules)
    

In [14]:
print(markdown_output)

## CFG-like grammar printout

START &rarr; (define GAME_DEF DOMAIN_DEF SETUP_DEF CONSTRAINTS_DEF TERMINAL_DEF SCORING_DEF )    (P = 1.0000)

GAME_DEF &rarr; (game ID )    (P = 1.0000)

ID &rarr; /[a-z0-9][a-z0-9\-]+/    (P = 1.0000)

DOMAIN_DEF &rarr; (:domain ID )    (P = 1.0000)

SETUP_DEF &rarr; (:setup SETUP )    (P = 1.0000)

SETUP &rarr; SETUP_AND    (P = 0.2464)

SETUP &rarr; SETUP_OR    (P = 0.0214)

SETUP &rarr; SETUP_NOT    (P = 0.0179)

SETUP &rarr; SETUP_EXISTS    (P = 0.2107)

SETUP &rarr; SETUP_FORALL    (P = 0.1357)

SETUP &rarr; SETUP_STATEMENT    (P = 0.3679)

SETUP_AND &rarr; (and [ SETUP ]+ )    (P = 1.0000)

SETUP_OR &rarr; (or [ SETUP ]+ )    (P = 1.0000)

SETUP_NOT &rarr; (not SETUP )    (P = 1.0000)

SETUP_EXISTS &rarr; (exists VARIABLE_LIST SETUP )    (P = 1.0000)

VARIABLE_LIST &rarr; ( [ VARIABLE_TYPE_DEF ]+ )    (P = 1.0000)

VARIABLE_TYPE_DEF &rarr; [ VARIABLE ]+ - TYPE_DEFINITION    (P = 1.0000)

VARIABLE &rarr; /\?[a-z][a-z0-9]*/    (P = 1.0000)

TYPE_DEFI

In [None]:
sampler.rules['game_def']

In [11]:
game_texts = list(load_games_from_file('../dsl/interactive-beta.pddl'))

In [24]:
sampler.rules['predicate_or_function_term']['term']['token_posterior']

{'agent': 0.053822937625754526,
 'upright': 0.005030181086519115,
 'top_drawer': 0.002012072434607646,
 'door': 0.007545271629778672,
 'bed': 0.025150905432595575,
 'desk': 0.03219315895372234,
 'floor': 0.011569416498993963,
 'front': 0.005030181086519115,
 'back': 0.0025150905432595573,
 'block': 0.0015090543259557343,
 'room_center': 0.01710261569416499,
 'rug': 0.01710261569416499,
 'upside_down': 0.001006036217303823,
 'left': 0.001006036217303823,
 'right': 0.001006036217303823,
 'main_light_switch': 0.0005030181086519115,
 'desktop': 0.0005030181086519115,
 'west_wall': 0.001006036217303823,
 'south_wall': 0.0005030181086519115,
 'pillow': 0.0015090543259557343,
 'side_table': 0.002012072434607646,
 'bridge_block': 0.001006036217303823,
 'pink_dodgeball': 0.006036217303822937,
 'ball': 0.0005030181086519115,
 'drawer': 0.0005030181086519115,
 'blinds': 0.0005030181086519115,
 'green_golfball': 0.002012072434607646,
 'bottom_shelf': 0.0005030181086519115,
 'top_shelf': 0.00251509