In [1]:
%autoreload 2

In [14]:
from argparse import Namespace
from collections import defaultdict, OrderedDict
import copy
import dataclasses
import gzip
import itertools
import os
import pickle
import sys
import time
import typing
import re

from IPython.display import display, Markdown, Latex
import numpy as np
import tqdm.notebook as tqdm

sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src'))
from src.ast_utils import load_games_from_file, _extract_game_id
from ast_counter_sampler import *

In [3]:
RANDOM_SEED = 33
PRIOR_COUNT = 12

args = Namespace(
    parse_counter=False, 
    counter_output_path=os.path.join('..', DEFAULT_COUNTER_OUTPUT_PATH),
    grammar_file=os.path.join('..', DEFAULT_GRAMMAR_FILE),
    random_seed=RANDOM_SEED,
    test_files=['../dsl/interactive-beta.pddl'],
    dont_tqdm=False,
)

grammar = open(args.grammar_file).read()
grammar_parser = tatsu.compile(grammar)
counter = parse_or_load_counter(args, grammar_parser)

length_prior = {n: PRIOR_COUNT for n in LENGTH_PRIOR}
sampler = ASTSampler(grammar_parser, counter, seed=args.random_seed,
                     prior_rule_count=PRIOR_COUNT, prior_token_count=PRIOR_COUNT,
                     length_prior=length_prior)

PATTERN_RULES = {k: re.compile(v['pattern']) for k, v in sampler.rules.items() if 'pattern' in v}
PATTERN_RULES['predicate_name'] = PATTERN_RULES['name']
PATTERN_RULES['type_name'] = PATTERN_RULES['name']
PATTERN_RULES['location'] = PATTERN_RULES['name']



In [17]:
EXPANSION_STRING = '&rarr;'


@dataclasses.dataclass
class Expansion:
    rule: str
    expansions: typing.List[str] = dataclasses.field(default_factory=list)
    probability: float = 1.0
    expansion_string: str = EXPANSION_STRING

    def __str__(self):
        return f'{self.rule.upper()} {self.expansion_string} {" ".join(self.expansions)}    (P = {self.probability:.4f})'


def _add_to_expansions(expansions: typing.List[Expansion], addition: str, probability=None):
    for expansion in expansions:
        expansion.expansions.append(addition)

        if probability is not None:
            expansion.probability *= probability
    
    return expansions


def _multiple_add_to_components(expansions: typing.List[Expansion], additions_to_probabilities: typing.Dict[str, float]):
    if len(additions_to_probabilities) == 1:
        key = next(iter(additions_to_probabilities))
        return _add_to_expansions(expansions, key.upper(), additions_to_probabilities[key])

    original_length = len(expansions)
    temp = [copy.deepcopy(expansions) for _ in range(len(additions_to_probabilities))]
    expansions = [comp for component in temp for comp in component]

    for i, (addition, probability) in enumerate(additions_to_probabilities.items()):
        for j in range(original_length):
            current = expansions[i * original_length + j]
            current.expansions.append(addition.upper())
            if probability is not None:
                current.probability *= probability

    return expansions


def print_rule_cfg(rule: str, rules: typing.Dict[str, typing.Dict[str, typing.Any]],
    pattern_rules: typing.Dict[str, re.Pattern] = PATTERN_RULES, 
    expansion_string: str = EXPANSION_STRING
    ):

    if rule not in rules or rule == 'empty_closure':
        return [], [], {}

    rule_dict = rules[rule]

    if PRODUCTION not in rule_dict:
        raise ValueError(f'No production for {rule}')

    expansions = [Expansion(rule, expansion_string=expansion_string)]
    encountered_rules = []
    encountered_tokens = defaultdict(dict)

    for prod_type, prod_value in rule_dict[PRODUCTION]:
        if prod_type == RULE:
            expansions = _add_to_expansions(expansions, prod_value.upper())
            encountered_rules.append(prod_value)

        elif prod_type == TOKEN:
            if prod_value == SAMPLE:
                expansions = _multiple_add_to_components(expansions, rule_dict[TOKEN_POSTERIOR])

            elif prod_value != EOF:
                expansions = _add_to_expansions(expansions, prod_value)

        elif prod_type == NAMED:
            if prod_value not in rule_dict:
                raise ValueError(f'Unknown named {prod_value}')

            named_dict = rule_dict[prod_value]
            options = named_dict[OPTIONS]

            # Handle the rules that have a single named option that expands to a terminal
            if (rule, prod_value) in SPECIAL_RULE_FIELD_VALUE_TYPES:
                print(f'Found special rule {rule} {prod_value}')
                for t, p in named_dict[TOKEN_POSTERIOR].items():
                    encountered_tokens[rule][t] = p

                expansions = _add_to_expansions(expansions, '<TOKEN> (see below)')
            
            else:
                if isinstance(options, str):
                    out = options.upper()
                    if MIN_LENGTH in named_dict:
                        min_length = named_dict[MIN_LENGTH]
                        out = f'[ {out} ]{"*" if min_length == 0 else "+"}'

                    expansions = _add_to_expansions(expansions, out)
                    encountered_rules.append(options)
                    options = [options]

                else:
                    # if MIN_LENGTH in named_dict:
                        # continue
                        # raise ValueError(f'Min length not supported for multiple options {prod_value}')
                    expansions = _multiple_add_to_components(expansions, 
                        {opt: (named_dict[RULE_POSTERIOR][opt] if (RULE_POSTERIOR in named_dict and opt in named_dict[RULE_POSTERIOR]) else None) 
                        for opt in options}  # type: ignore
                    )
                    encountered_rules.extend(options)

                if TOKEN_POSTERIOR in named_dict:
                    for t, p in named_dict[TOKEN_POSTERIOR].items():
                        if t in rules: continue

                        if len(options) == 1:
                            encountered_tokens[options[0]][t] = p
                            continue

                        if isinstance(t, (int, float)):
                            encountered_tokens['number'][t] = p
                            continue

                        for opt in options:
                            if opt in pattern_rules and pattern_rules[opt].match(t):
                                encountered_tokens[opt][t] = p
                                break

        elif prod_type == PATTERN:
            expansions = _add_to_expansions(expansions, f'/{prod_value}/')

        else:
            raise ValueError(f'Unknown production type {prod_type}')

    if PRODUCTION_PROBABILITY in rule_dict:
        production_probability = rule_dict[PRODUCTION_PROBABILITY]
        for expansion in expansions:
            expansion.probability *= production_probability

        expansions.append(Expansion(rule, ['$\\epsilon$'], probability=1 - production_probability, expansion_string=expansion_string))

    sum_probs = sum(exp.probability for exp in expansions)
    for exp in expansions:
        exp.probability /= sum_probs

    return [str(exp) for exp in expansions], encountered_rules, encountered_tokens


def extract_cfg_rule_dict(rules: typing.Dict[str, typing.Dict[str, typing.Any]], start_rule: str = START, expansion_string: str = EXPANSION_STRING):
    rule_frontier = [start_rule]
    encountered_rules = set(rule_frontier)
    
    cfg_expansions_by_rule = OrderedDict()
    token_rule_expansions = defaultdict(dict)

    while rule_frontier:
        rule = rule_frontier.pop()
        cfg_expansions, new_rules, new_tokens = print_rule_cfg(rule, rules, expansion_string=expansion_string)
        cfg_expansions_by_rule[rule] = cfg_expansions

        for new_rule in reversed(new_rules):
            if new_rule not in encountered_rules:
                rule_frontier.append(new_rule)
                encountered_rules.add(new_rule)

        for rule, tokens in new_tokens.items():
            token_rule_expansions[rule].update(tokens)

    cfg_expansions_by_rule['token_rules'] = ['## Token rule expansions']
    for rule, tokens in token_rule_expansions.items():
        cfg_expansions_by_rule[rule] = [f'{rule.upper()} {expansion_string} {token}    (P = {prob:.4f})' for token, prob in tokens.items()]

    return cfg_expansions_by_rule


def print_rules_cfg(rules: typing.Dict[str, typing.Dict[str, typing.Any]], start_rule: str = START, expansion_string: str = EXPANSION_STRING):
    cfg_expansions_by_rule = extract_cfg_rule_dict(rules, start_rule, expansion_string=expansion_string)
    return '\n'.join(itertools.chain.from_iterable(cfg_expansions_by_rule[rule] for rule in cfg_expansions_by_rule))


cfg_expansions_by_rule = extract_cfg_rule_dict(sampler.rules, expansion_string='=>')
markdown_output = print_rules_cfg(sampler.rules, expansion_string='=>')
    

Found special rule object_type terminal
Found special rule color_type terminal
Found special rule color terminal
Found special rule orientation_type terminal
Found special rule orientation terminal
Found special rule side_type terminal
Found special rule side terminal
Found special rule predicate_or_function_term term
Found special rule predicate_or_function_side_term term
Found special rule predicate_or_function_orientation_term term
Found special rule predicate_or_function_color_term term
Found special rule predicate_or_function_type_term term
Found special rule object_name terminal
Found special rule object_type terminal
Found special rule color_type terminal
Found special rule color terminal
Found special rule orientation_type terminal
Found special rule orientation terminal
Found special rule side_type terminal
Found special rule side terminal
Found special rule predicate_or_function_term term
Found special rule predicate_or_function_side_term term
Found special rule predicate_or_

In [23]:
SECTION_TEMPLATE = """
\\subsection{{{title}}}
{prefix_text}

\\begin{{lstlisting}}
{rules_text}
\\end{{lstlisting}}

{suffix_text}

""".strip()


@dataclasses.dataclass
class PCFGSectionDefinition:
    title: str
    rules: typing.List[str]
    rule_comments: typing.Dict[str, str] = dataclasses.field(default_factory=dict)
    prefix_text: str = ''
    suffix_text: str = ''

    def format(self, rule_expansions: typing.Dict[str, typing.List[typing.Union[Expansion, str]]]):
        rule_texts = []
        for rule in self.rules:
            if rule not in rule_expansions:
                raise ValueError(f'Unknown rule {rule}')
            
            expansions = [str(exp) for exp in rule_expansions[rule]]

            if rule in self.rule_comments:
                expansions.insert(0, f'# {self.rule_comments[rule]}')

            rule_texts.append('\n'.join(expansions))

        rules_text = '\n'.join(rule_texts)
        return SECTION_TEMPLATE.format(title=self.title, prefix_text=self.prefix_text, rules_text=rules_text, suffix_text=self.suffix_text)




PCFG_SECTIONS = [
    PCFGSectionDefinition('Overall Game Definition', ['start', 'game_def', 'domain_def', 'id'],
                          dict(id='An ID is a string of letters, numbers, and dashes that starts with a letter or number.')),
]


for pcfg_section in PCFG_SECTIONS:
    print(pcfg_section.format(cfg_expansions_by_rule))

\subsection{Overall Game Definition}


\begin{lstlisting}
START => (define GAME_DEF DOMAIN_DEF SETUP_DEF CONSTRAINTS_DEF TERMINAL_DEF SCORING_DEF )    (P = 1.0000)
GAME_DEF => (game ID )    (P = 1.0000)
DOMAIN_DEF => (:domain ID )    (P = 1.0000)
# An ID is a string of letters, numbers, and dashes that starts with a letter or number.
ID => /[a-z0-9][a-z0-9\-]+/    (P = 1.0000)
\end{lstlisting}




In [16]:
print(markdown_output)

START => (define GAME_DEF DOMAIN_DEF SETUP_DEF CONSTRAINTS_DEF TERMINAL_DEF SCORING_DEF )    (P = 1.0000)
GAME_DEF => (game ID )    (P = 1.0000)
ID => /[a-z0-9][a-z0-9\-]+/    (P = 1.0000)
DOMAIN_DEF => (:domain ID )    (P = 1.0000)
SETUP_DEF => (:setup SETUP )    (P = 0.6020)
SETUP_DEF => $\epsilon$    (P = 0.3980)
SETUP => SETUP_AND    (P = 0.2322)
SETUP => SETUP_OR    (P = 0.0372)
SETUP => SETUP_NOT    (P = 0.0372)
SETUP => SETUP_EXISTS    (P = 0.2074)
SETUP => SETUP_FORALL    (P = 0.1455)
SETUP => SETUP_STATEMENT    (P = 0.3406)
SETUP_AND => (and [ SETUP ]+ )    (P = 1.0000)
SETUP_OR => (or [ SETUP ]+ )    (P = 1.0000)
SETUP_NOT => (not SETUP )    (P = 1.0000)
SETUP_EXISTS => (exists VARIABLE_LIST SETUP )    (P = 1.0000)
VARIABLE_LIST => ( VARIABLE_TYPE_DEF )    (P = 0.9252)
VARIABLE_LIST => ( COLOR_VARIABLE_TYPE_DEF )    (P = 0.0340)
VARIABLE_LIST => ( ORIENTATION_VARIABLE_TYPE_DEF )    (P = 0.0204)
VARIABLE_LIST => ( SIDE_VARIABLE_TYPE_DEF )    (P = 0.0204)
VARIABLE_TYPE_DEF => [

## Manual edits to make
* Change the way VARIABLE_LIST is written to expan to some generic list of variable definitions, which expands to each of the variable type defs

In [None]:
sampler.rules['game_def']

In [None]:
game_texts = list(load_games_from_file('../dsl/interactive-beta.pddl'))

In [None]:
sampler.rules['predicate_or_function_term']['term']['token_posterior']