In [3]:
%autoreload 2

In [25]:
from collections import defaultdict
import os
import sys
import typing

import numpy as np

sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src'))
from src.ast_utils import load_games_from_file

In [5]:
import openai
# Read the API key from the environment (never hard-code it); the value is None if the variable is unset.
openai.api_key = os.environ.get('OPENAI_API_KEY')

In [6]:
# Materialize every game yielded by load_games_from_file into a list.
# NOTE(review): presumably one DSL string per `(define ...)` form; confirm in src/ast_utils.
game_texts = [*load_games_from_file('../dsl/interactive-beta.pddl')]

In [24]:
INSERT_TAG = '[insert]'  # not referenced by the code shown here
# Headers used to locate each part of a game's DSL text.
GAME_START = '(define'
SETUP_SECTION = '(:setup'
PREFERENCES_SECTION = '(:constraints'
TERMINAL_SECTION = '(:terminal'
SCORING_SECTION = '(:scoring'
# Closes a section that was cut out; it is prepended to the rest of the game.
SECTION_SUFFIX = ')\n'
# Closes the last section plus the enclosing `(define` form.
GAME_END_SUFFIX = '))\n'


# Seed used whenever a function below is not given an explicit np.random.Generator.
DEFAULT_RANDOM_SEED = 33


def split_game(game_text: str, 
    start_section_str: str, end_section_str: typing.Union[None, str, typing.Sequence[str]],
    suffix: str = SECTION_SUFFIX, game_end_suffix: str = GAME_END_SUFFIX) -> typing.Tuple[str, str]:
    """Cut one section out of a game, returning the text before it and the text after it.

    The prompt ends with the section header (``start_section_str``) so a model can
    write the section body. The suffix closes the section and continues with the
    rest of the game.

    Args:
        game_text: full DSL text of a single game.
        start_section_str: header of the section to remove, e.g. ``'(:setup'``.
        end_section_str: header(s) of the section that follows. ``None`` means the
            removed section runs to the end of the game. For a sequence, the
            earliest header that occurs in ``game_text`` is used.
        suffix: text prepended to the remainder of the game to close the section.
        game_end_suffix: suffix returned when no following section is found.

    Returns:
        ``(prompt, suffix)``. If ``start_section_str`` is absent, the section is
        inserted where the end section starts. If neither header is found, the
        prompt is the whole text followed by ``start_section_str``.
    """
    start_index = game_text.find(start_section_str)

    if end_section_str is None:
        end_index = None
    else:
        # A str is itself a Sequence, so it must be special-cased before iterating.
        candidates = [end_section_str] if isinstance(end_section_str, str) else list(end_section_str)
        found = [i for i in (game_text.find(s) for s in candidates) if i >= 0]
        # None when no end header occurs (previously min([]) raised for sequences).
        end_index = min(found) if found else None

    if start_index == -1:
        # Missing section: insert it where the following section begins. Because
        # end_index was normalized above, this never slices with -1 (which used
        # to drop the final character of the game).
        start_index = end_index

    prompt = game_text[:start_index] + start_section_str
    if end_index is None:
        return prompt, game_end_suffix

    return prompt, suffix + game_text[end_index:]
        
        
def create_multi_game_prompt_suffix(game_texts: typing.List[str],
    n_games_before: int, n_games_after: int, 
    target_game_split_func: typing.Callable[..., typing.Tuple[str, str]],
    target_game_split_func_args: typing.Dict[str, typing.Any],
    game_filter_str: typing.Optional[str] = None,
    rng: typing.Optional[np.random.Generator] = None, random_seed: int = DEFAULT_RANDOM_SEED
    ) -> typing.Tuple[str, str]:
    """Build a few-shot (prompt, suffix) pair around a randomly drawn target game.

    ``n_games_before + n_games_after + 1`` distinct games are drawn without
    replacement (from only those containing ``game_filter_str``, if given). The
    first ``n_games_before`` go before the target, the next one is the target
    (split with ``target_game_split_func``), and the remainder follow it.
    A fresh generator seeded with ``random_seed`` is used when ``rng`` is None.
    """
    generator = np.random.default_rng(random_seed) if rng is None else rng
    pool = game_texts if game_filter_str is None else [text for text in game_texts if game_filter_str in text]

    draw_count = n_games_before + n_games_after + 1
    drawn = generator.choice(len(pool), draw_count, replace=False)

    return create_multi_game_prompt_suffix_from_indices(
        pool, target_game_split_func, target_game_split_func_args,
        drawn[:n_games_before], drawn[n_games_before], drawn[n_games_before + 1:])


def create_multi_game_prompt_suffix_from_indices(game_texts: typing.List[str], 
    target_game_split_func: typing.Callable[..., typing.Tuple[str, str]], 
    target_game_split_func_args: typing.Dict[str, typing.Any], 
    before_indices: typing.Sequence[int], 
    target_index: int, 
    after_indices: typing.Sequence[int]):
    """Assemble a (prompt, suffix) pair from explicitly chosen context and target games.

    The prompt is the games at ``before_indices`` followed by the first half of the
    split target game. The suffix is the second half of the split target followed
    by the games at ``after_indices``. Pieces are joined with newlines; returns a
    tuple of two strings.
    """
    context_before = [game_texts[i] for i in before_indices]
    target_head, target_tail = target_game_split_func(game_texts[target_index], **target_game_split_func_args)
    context_after = [game_texts[i] for i in after_indices]

    prompt = '\n'.join(context_before + [target_head])
    suffix = '\n'.join([target_tail] + context_after)
    return prompt, suffix


# Indices into the notebook-level `game_texts` of games that contain each optional section.
# The lists are built directly in index order: list(set(...)) would follow hash-table
# order, so seeded rng.choice over it would depend on set internals.
GAMES_WITH_SETUP_LIST = [i for i, g in enumerate(game_texts) if SETUP_SECTION in g]
GAMES_WITH_SETUP = set(GAMES_WITH_SETUP_LIST)
GAMES_WITH_TERMINAL_LIST = [i for i, g in enumerate(game_texts) if TERMINAL_SECTION in g]
GAMES_WITH_TERMINAL = set(GAMES_WITH_TERMINAL_LIST)


def create_all_prompts_for_game(game_texts: typing.List[str], 
    game_index: int, n_games_before: int, n_games_after: int, 
    rng: typing.Optional[np.random.Generator] = None, random_seed: int = DEFAULT_RANDOM_SEED):
    """Build one (prompt, suffix) pair per section of ``game_texts[game_index]``.

    For each section, ``n_games_before + n_games_after`` other games are drawn without
    replacement as context. Setup and terminal contexts come only from games that
    contain that section; preference and scoring contexts come from all other games.
    Draws happen in the order setup, preferences, terminal, scoring.

    Returns:
        dict mapping SETUP_SECTION, PREFERENCES_SECTION, TERMINAL_SECTION and
        SCORING_SECTION to a ``(prompt, suffix)`` tuple.

    Raises:
        ValueError: if ``game_index`` is out of range, or (from numpy) if a section has
            fewer candidate games than the requested number of context games.
    """
    if not 0 <= game_index < len(game_texts):
        raise ValueError(f'game_index {game_index} out of range for {len(game_texts)} games')

    if rng is None:
        rng = np.random.default_rng(random_seed)

    n_context = n_games_before + n_games_after
    # Candidates are derived from the `game_texts` argument (not module-level globals)
    # so that the indices always refer to the list that was passed in.
    preferences_and_scoring_games = [i for i in range(len(game_texts)) if i != game_index]
    setup_games = [i for i in preferences_and_scoring_games if SETUP_SECTION in game_texts[i]]
    terminal_games = [i for i in preferences_and_scoring_games if TERMINAL_SECTION in game_texts[i]]

    # (section header, candidate context games, header(s) that end the section)
    section_specs = (
        (SETUP_SECTION, setup_games, PREFERENCES_SECTION),
        (PREFERENCES_SECTION, preferences_and_scoring_games, (TERMINAL_SECTION, SCORING_SECTION)),
        (TERMINAL_SECTION, terminal_games, SCORING_SECTION),
        (SCORING_SECTION, preferences_and_scoring_games, None),
    )

    prompts_by_section = {}
    for section, candidates, end_section in section_specs:
        context_indices = rng.choice(candidates, n_context, replace=False)
        prompts_by_section[section] = create_multi_game_prompt_suffix_from_indices(
            game_texts, split_game,
            {'start_section_str': section, 'end_section_str': end_section},
            context_indices[:n_games_before], game_index, context_indices[n_games_before:])  # type: ignore

    return prompts_by_section


In [21]:
# Show each section's prompt and suffix, separated by a short divider, with a long divider between sections.
for section_name, (section_prompt, section_suffix) in create_all_prompts_for_game(game_texts, 0, 2, 2).items():
    print(f'{section_name}:', section_prompt, '*' * 20, section_suffix, '=' * 80, sep='\n')

setup:
(define (game 61254c5a6facc8ed023a64de-48) (:domain medium-objects-room-v1)  
(:setup (and 
    (exists (?b - building ?h - hexagonal_bin) (game-conserved (and 
        (in ?b ?h)
        (>= (building_size ?b) 4) 
        (not (exists (?g - game_object) (and (in ?b ?g) (on ?h ?g))))
        (< (distance ?b room_center) 1)
    )))
))
(:constraints (and 
    (forall (?d - (either dodgeball basketball beachball))
        (preference ballThrownToBin (exists (?b - building ?h - hexagonal_bin)
            (then
                (once (agent_holds ?d))
                (hold (and (in_motion ?d) (not (agent_holds ?d))))
                (once (and (not (in_motion ?d)) (or (in ?h ?d) (on ?h ?d)) (or (in ?b ?h) (on ?b ?h))))
            )
        ))
    )
    (preference itemsHidingScreens 
        (exists (?s - (either desktop laptop) ?o - (either pillow doggie_bed teddy_bear)) 
            (at-end (on ?s ?o))    
        )
    )
    (preference objectsHidden
        (exists (?o - (either 

In [26]:
# NOTE(review): code-davinci-002 and `openai.Completion` belong to the legacy (pre-1.0)
# OpenAI API; confirm the model is still available before re-running.
DEFAULT_CODEX_MODEL = "code-davinci-002"
DEFAULT_TEMPERATURE = 0.67
DEFAULT_MAX_TOKENS = 512
# Generation stops if the model starts a new game or another section header.
DEFAULT_STOP_SEQUENCES = [GAME_START, PREFERENCES_SECTION, TERMINAL_SECTION, SCORING_SECTION]
MAX_N = 10  # upper bound on completions requested per call


DEFAULT_COMPLETION_KWARGS = dict(
    model=DEFAULT_CODEX_MODEL,
    temperature=DEFAULT_TEMPERATURE,
    max_tokens=DEFAULT_MAX_TOKENS,
    stop=DEFAULT_STOP_SEQUENCES,
)


def generate_codex_completions(prompt: str, suffix: str, n: int, 
    completion_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ):
    """Request ``n`` insert-style completions between ``prompt`` and ``suffix``.

    Args:
        prompt: text before the insertion point.
        suffix: text after the insertion point.
        n: number of completions; must be in ``[1, MAX_N]``.
        completion_kwargs: overrides merged on top of DEFAULT_COMPLETION_KWARGS.
            Neither this dict nor the defaults are modified.

    Returns:
        The response of ``openai.Completion.create``; each entry of its ``.choices``
        exposes ``.text`` and ``.finish_reason``.

    Raises:
        ValueError: if ``n`` is outside ``[1, MAX_N]``.
    """
    if n > MAX_N:
        raise ValueError(f'n must be <= {MAX_N}')
    if n < 1:
        raise ValueError('n must be >= 1')

    # Always build a fresh dict: the previous None branch wrote prompt/suffix/n
    # directly into the shared module-level DEFAULT_COMPLETION_KWARGS.
    request_kwargs = dict(DEFAULT_COMPLETION_KWARGS)
    if completion_kwargs is not None:
        request_kwargs.update(completion_kwargs)

    request_kwargs['prompt'] = prompt
    request_kwargs['suffix'] = suffix
    request_kwargs['n'] = n

    return openai.Completion.create(**request_kwargs)
    

# Default number of completions requested for each section of a game.
N_AUGMENTATIONS_PER_SECTION = {
    SETUP_SECTION: 3,
    PREFERENCES_SECTION: 6,
    TERMINAL_SECTION: 3,
    SCORING_SECTION: 3,
}


def generate_single_game_augmentations(game_texts: typing.List[str], 
    game_index: int, n_games_before: int, n_games_after: int, 
    rng: typing.Optional[np.random.Generator] = None, random_seed: int = DEFAULT_RANDOM_SEED,
    completion_kwargs_by_section: typing.Optional[typing.Dict[str, typing.Dict[str, typing.Any]]] = None,
    n_augmentations_per_section: typing.Optional[typing.Dict[str, int]] = None,
    ):
    """Request model completions for every section of ``game_texts[game_index]``.

    Args:
        game_texts, game_index, n_games_before, n_games_after, rng, random_seed:
            forwarded to ``create_all_prompts_for_game``.
        completion_kwargs_by_section: optional per-section overrides for the
            completion request; sections that are missing use the defaults.
        n_augmentations_per_section: optional per-section completion counts that
            override N_AUGMENTATIONS_PER_SECTION; missing sections use the defaults.

    Returns:
        dict mapping each section header to the response of
        ``generate_codex_completions`` for that section.
    """
    if completion_kwargs_by_section is None:
        completion_kwargs_by_section = {}
    counts_by_section = {**N_AUGMENTATIONS_PER_SECTION, **(n_augmentations_per_section or {})}

    prompts_by_section = create_all_prompts_for_game(game_texts, game_index, n_games_before, n_games_after, rng, random_seed)
    results_by_section = {}
    for section, (prompt, suffix) in prompts_by_section.items():
        # .get() so a plain dict that omits a section does not raise KeyError.
        section_kwargs = completion_kwargs_by_section.get(section, {})
        results_by_section[section] = generate_codex_completions(prompt, suffix, counts_by_section[section], section_kwargs)

    return results_by_section

In [27]:
# Issues one completion request per section for game 0, with 2 context games before and 2 after.
results_by_section = generate_single_game_augmentations(game_texts, game_index=0, n_games_before=2, n_games_after=2)

In [31]:
# Print every completion with its finish reason, one divider per section.
for section, response in results_by_section.items():
    for idx, choice in enumerate(response.choices):
        print(f'\t\t{section} {idx}: [{choice.finish_reason}] \n{choice.text}')

    print('=' * 80)

		(:setup 0: [stop] 
 (and 
    (exists (?w - wall) (game-conserved (and
        (>= (wall_size ?w) 4)
        (not (exists (?g - game_object) (and (touch ?w ?g) (on ?g ?w))))
    )))
    (exists (?d - doggie_bed) (game-conserved 
        (on ?d ?w)
    ))
))

		(:setup 1: [stop] 
 (and 
    (exists (?d - doggie_bed) (game-conserved (< (distance room_center ?d) 1)))
))

		(:setup 2: [stop] 
 (and 
    (exists (?r - triangular_ramp ?h - hexagonal_bin ?b - ball) (game-conserved (and
        (in ?r ?h) 
        (in ?h ?b) 
        (object_orientation ?h upright) 
    )))
))

		(:constraints 0: [stop] 
 (and 
    (forall (?b - ball)
        (preference ballKnocked (exists (?r - triangular_ramp ?b2 - ball)
            (then
                (once (and (agent_holds ?b) (on ?r agent)))
                (hold-while (and (not (agent_holds ?b)) (in_motion ?b))
                    (touch ?b ?b2)
                    (in_motion ?b2)
                )
            )
        ))
    )
    (forall (?b - b

In [34]:
# Opening vs. closing parenthesis counts per completion, as a quick balance check.
for section, response in results_by_section.items():
    for idx, choice in enumerate(response.choices):
        n_open, n_close = choice.text.count('('), choice.text.count(')')
        print(f'{section} {idx}: [{choice.finish_reason}] (: {n_open} ): {n_close}')

(:setup 0: [stop] (: 17 ): 18
(:setup 1: [stop] (: 6 ): 7
(:setup 2: [stop] (: 8 ): 9
(:constraints 0: [stop] (: 36 ): 37
(:constraints 1: [stop] (: 37 ): 38
(:constraints 2: [length] (: 89 ): 80
(:constraints 3: [stop] (: 44 ): 44
(:constraints 4: [stop] (: 42 ): 42
(:constraints 5: [stop] (: 25 ): 25
(:terminal 0: [stop] (: 2 ): 3
(:terminal 1: [stop] (: 2 ): 2
(:terminal 2: [stop] (: 2 ): 2
(:scoring 0: [stop] (: 3 ): 5
(:scoring 1: [stop] (: 6 ): 6
(:scoring 2: [stop] (: 6 ): 6


# Remaining steps between these completions and a working augmented dataset:

1. Minor sanity checks, e.g., that the numbers of opening and closing parentheses match (otherwise, remove trailing closing parentheses)
2. Decide what to do about productions truncated for length (probably remove last partial preference)
3. Decide on a procedure to create games (sample from within the productions for a given game? Across productions for different games?)
4. Generate a dataset with the procedure I have in mind, verify that the ASTs parse
5. Run it through the fitness function, compare the results to the real data
6. Create regrown corruptions for it, and repeat the previous synthetic experiments.
