In [1]:
%autoreload 2

In [39]:
from collections import defaultdict
from argparse import Namespace
import copy
import gzip
import itertools
import os
import pickle
import sys
import typing

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tatsu
import tatsu.ast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import tqdm
import sklearn
from sklearn.model_selection import GridSearchCV, train_test_split, KFold
from sklearn.pipeline import Pipeline

sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src'))
from src import fitness_energy_utils as utils
from src.fitness_energy_utils import NON_FEATURE_COLUMNS
from src.fitness_function import build_fitness_featurizer, ASTFitnessFeaturizer
from src.ast_counter_sampler import *
from src.ast_counter_sampler import parse_or_load_counter

# Plan of attack
* Use the code in `ast_counter_sampler.py` to generate a sample from the MLE
* Score it with a fitness function adapted to working with a single example at a time
* At each subsequent iteration, regrow the game from a random node, and score the regrowth
* Since we have energy scores before and after, we can either accept greedily (only if $E_{new} < E_{old}$), or accept stochastically with probability $\min\left(1, \exp(\beta (E_{old} - E_{new}))\right)$, where $\beta$ is an inverse temperature.

In [61]:
# Seed used for DEFAULT_ARGS.random_seed.
# NOTE(review): the name is misspelled ("DEFUALT"); it is left unchanged here because
# DEFAULT_ARGS refers to it by this exact spelling.
DEFUALT_RANDOM_SEED = 33
# Relative path to the fitness model; it is read with gzip.open + pickle.load by
# _load_and_wrap_fitness_function.
DEFAULT_FITNESS_FUNCTION_PATH = '../models/cv_fitness_model_2022_12_05.pkl.gz'


def _load_and_wrap_fitness_function(fitness_function_path: str = DEFAULT_FITNESS_FUNCTION_PATH) -> typing.Callable[[torch.Tensor], float]:
    """Load a pickled fitness model and wrap it as a features -> scalar callable.

    Args:
        fitness_function_path: Path to a gzip-compressed pickle of an object
            exposing a ``transform`` method.

    Returns:
        A function that passes its argument to the model's ``transform`` and
        returns ``.item()`` of the result (so the model output must hold
        exactly one element).
    """
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file; only point this at model files from a trusted source.
    with gzip.open(fitness_function_path, 'rb') as model_file:
        fitness_model = pickle.load(model_file)

    def _wrap_fitness(features: torch.Tensor):
        # Score the features, then unwrap the single-element result to a Python scalar.
        model_output = fitness_model.transform(features)
        return model_output.item()

    return _wrap_fitness


# Default argument bundle handed to MCMCRegrowthSampler, which forwards it to
# parse_or_load_counter and build_fitness_featurizer and reads grammar_file / random_seed itself.
# NOTE(review): DEFAULT_GRAMMAR_FILE and DEFAULT_COUNTER_OUTPUT_PATH are not defined in this
# notebook; presumably they arrive via `from src.ast_counter_sampler import *` — confirm.
DEFAULT_ARGS = Namespace(
    # Both paths are prefixed with '..'; presumably the defaults are relative to the
    # repository root while this notebook runs one directory below it — TODO confirm.
    grammar_file=os.path.join('..', DEFAULT_GRAMMAR_FILE),
    parse_counter=False,
    counter_output_path=os.path.join('..', DEFAULT_COUNTER_OUTPUT_PATH),
    random_seed=DEFUALT_RANDOM_SEED,
)

class MCMCRegrowthSampler:
    """Generate games by MCMC over regrowths of a grammar-sampled AST.

    Each chain starts from a fresh sample drawn from ``ASTSampler``. At every
    step a new proposal is produced by ``RegrowthSampler`` and scored with the
    loaded fitness function. The score is treated as an energy (lower is
    better): in greedy mode a proposal is accepted only if its score is
    strictly lower than the current one; otherwise it is accepted with
    probability ``exp(-acceptance_temperature * (new - current))``.

    Finished chains are appended to ``self.samples`` as
    ``(final_ast, final_fitness)`` tuples.
    """

    def __init__(self,
        args: Namespace,
        fitness_function_path: str = DEFAULT_FITNESS_FUNCTION_PATH,
        plateau_patience_steps: int = 10,
        max_steps: int = 1000,
        greedy_acceptance: bool = False,
        acceptance_temperature: float = 1.0,
    ):
        """Build the samplers, featurizer and fitness function.

        Args:
            args: Namespace providing at least ``grammar_file`` and
                ``random_seed``; it is also forwarded unchanged to
                ``parse_or_load_counter`` and ``build_fitness_featurizer``.
            fitness_function_path: Path handed to
                ``_load_and_wrap_fitness_function``.
            plateau_patience_steps: A chain stops once more than this many
                steps have passed since the last accepted proposal.
            max_steps: Upper bound on the number of steps per chain.
            greedy_acceptance: If True, accept only strictly lower scores.
            acceptance_temperature: Multiplier on the score difference inside
                the acceptance exponential. Despite the name it scales the
                difference directly (it plays the role of an inverse
                temperature): larger values make worse proposals less likely
                to be accepted.
        """
        # Use a context manager so the grammar file handle is closed promptly.
        with open(args.grammar_file) as grammar_file:
            self.grammar = grammar_file.read()
        self.grammar_parser = tatsu.compile(self.grammar)
        self.counter = parse_or_load_counter(args, self.grammar_parser)
        self.sampler = ASTSampler(self.grammar_parser, self.counter, seed=args.random_seed)
        self.regrowth_sampler = RegrowthSampler(self.sampler, args.random_seed)
        self.fitness_featurizer = build_fitness_featurizer(args)
        # Separate generator used only for the stochastic accept/reject draw.
        self.rng = np.random.default_rng(args.random_seed)
        self.fitness_function = _load_and_wrap_fitness_function(fitness_function_path)
        self.plateau_patience_steps = plateau_patience_steps
        self.max_steps = max_steps
        self.greedy_acceptance = greedy_acceptance
        self.acceptance_temperature = acceptance_temperature

        # Index of the next chain (used in the generated game id) and the finished chains.
        self.sample_index = 0
        self.samples = []

    def multiple_samples(self, n_samples: int, should_tqdm: bool = False):
        """Run ``n_samples`` independent chains, optionally with a notebook progress bar."""
        if should_tqdm:
            # `import tqdm` alone does not load the `tqdm.notebook` submodule, so
            # import it explicitly rather than relying on another module having done so.
            from tqdm.notebook import trange
            sample_iter = trange(n_samples)
        else:
            sample_iter = range(n_samples)

        for _ in sample_iter:
            self.sample()

    def sample(self, verbose: bool = False):
        """Run one chain and append ``(final_ast, final_fitness)`` to ``self.samples``.

        The chain ends after ``max_steps`` steps, or earlier when more than
        ``plateau_patience_steps`` steps have elapsed since the last acceptance.

        Args:
            verbose: If True, print a line for every accepted step and when
                the chain plateaus.
        """
        current_proposal = self.sampler.sample(global_context=dict(original_game_id=f'mcmc-{self.sample_index}'))
        current_proposal_fitness = self._score_proposal(current_proposal)  # type: ignore

        last_accepted_step = 0
        for step in range(self.max_steps):
            current_proposal, current_proposal_fitness, accepted = self.mcmc_regrowth_step(
                current_proposal, current_proposal_fitness, step  # type: ignore
            )

            if accepted:
                last_accepted_step = step
                if verbose:
                    print(f'Accepted step {step} with fitness {current_proposal_fitness}')

            elif step - last_accepted_step > self.plateau_patience_steps:
                if verbose:
                    print(f'Plateaued at step {step} with fitness {current_proposal_fitness}')
                break

        self.samples.append((current_proposal, current_proposal_fitness))
        self.sample_index += 1

    def mcmc_regrowth_step(self,
        current_proposal: tatsu.ast.AST,
        current_proposal_fitness: float,
        step_index: int,
        ) -> typing.Tuple[tatsu.ast.AST, float, bool]:
        """Propose one regrowth of ``current_proposal`` and accept or reject it.

        Args:
            current_proposal: The AST currently held by the chain.
            current_proposal_fitness: Its score, as returned by ``_score_proposal``.
            step_index: Forwarded to ``regrowth_sampler.sample``.

        Returns:
            ``(proposal, fitness, accepted)``: the new state of the chain and
            whether the regrown proposal replaced the current one. If regrowth
            fails with ``RecursionError`` or ``SamplingException`` the current
            state is returned unchanged with ``accepted=False``.
        """
        if self.regrowth_sampler.source_ast != current_proposal:
            self.regrowth_sampler.set_source_ast(current_proposal)

        # Stays None unless a regrowth that differs from the current game is produced.
        new_proposal = None
        try:
            while new_proposal is None:
                candidate = self.regrowth_sampler.sample(step_index)
                if ast_printer.ast_to_string(candidate) == ast_printer.ast_to_string(current_proposal):  # type: ignore
                    print('Regrowth generated identical games, repeating')
                else:
                    new_proposal = candidate

        except RecursionError:
            print('Recursion error, skipping sample')

        except SamplingException:
            print('Sampling exception, skipping sample')

        if new_proposal is None:
            # Regrowth failed: there is nothing valid to score, so keep the current
            # state instead of passing None (or a stale identical game) to the featurizer.
            return current_proposal, current_proposal_fitness, False

        new_proposal_fitness = self._score_proposal(new_proposal)  # type: ignore

        if self.greedy_acceptance:
            accepted = new_proposal_fitness < current_proposal_fitness
        else:
            # Metropolis-style rule: a value >= 1 (new score not higher) always accepts.
            acceptance_probability = np.exp(-self.acceptance_temperature * (new_proposal_fitness - current_proposal_fitness))
            accepted = self.rng.uniform() < acceptance_probability

        if accepted:
            return new_proposal, new_proposal_fitness, True

        return current_proposal, current_proposal_fitness, False

    def _score_proposal(self, proposal: tatsu.ast.AST):
        """Featurize ``proposal`` and return the fitness function's score for it.

        Entries whose key is in ``NON_FEATURE_COLUMNS`` are dropped; the
        remaining values are packed, in the featurizer's iteration order, into
        a 1-D tensor.
        """
        # NOTE(review): the meaning of the 'mcmc' and True arguments is defined by the
        # featurizer's parse method, which is not visible here — confirm against it.
        proposal_features = self.fitness_featurizer.parse(proposal, 'mcmc', True)  # type: ignore
        proposal_tensor = torch.tensor([v for k, v in proposal_features.items() if k not in NON_FEATURE_COLUMNS])  # type: ignore
        proposal_fitness = self.fitness_function(proposal_tensor)
        return proposal_fitness




In [63]:
# Build a greedy sampler (accept only strictly lower fitness) that gives up on a
# chain after more than 20 steps without an acceptance, then run 100 chains.
mcmc = MCMCRegrowthSampler(
    args=DEFAULT_ARGS,
    fitness_function_path=DEFAULT_FITNESS_FUNCTION_PATH,
    plateau_patience_steps=20,
    greedy_acceptance=True,
)
mcmc.multiple_samples(n_samples=100, should_tqdm=True)

No counted data for setup_not.not_args
String token rule for any: {'token_posterior': {'(any)': 1}, 'production': [('token', '(any)')]}
No counted data for hold_for.hold_pred
No counted data for hold_to_end.hold_pred
No counted data for forall_seq.forall_seq_vars
No counted data for forall_seq.forall_seq_then
No counted data for terminal_not.not_args
No counted data for scoring_external_minimize.scoring_expr
No counted data for scoring_and.and_args
No counted data for scoring_or.or_args
No counted data for scoring_not.not_args
String token rule for total_time: {'token_posterior': {'(total-time)': 1}, 'production': [('token', '(total-time)')]}
String token rule for total_score: {'token_posterior': {'(total-score)': 1}, 'production': [('token', '(total-score)')]}
No counted data for scoring_equals_comp.expr
No counted data for count_longest.name_and_types
No counted data for count_shortest.name_and_types
No counted data for count_total.name_and_types
No counted data for count_increasing_

  0%|          | 0/100 [00:00<?, ?it/s]

IndexError: list assignment index out of range

In [53]:
# Show the first finished chain's game as text; each entry of mcmc.samples is an
# (ast, fitness) tuple, so [0][0] is the AST of the first sample.
# NOTE(review): this requires at least one chain to have completed in the cell above.
print(ast_printer.ast_to_string(mcmc.samples[0][0]))

(define (game mcmc-0-7-16-19-20-37-41-49-70-81-100-109-113-116-133) (:domain medium-objects-room-v1)(:setup  (game-conserved    (in_motion ?w)  ))(:constraints  (and    (forall (?m - (either cellphone cd) ?m - beachball)      (and        (preference preference2          (exists (?e - ball ?q - teddy_bear ?e - wall)            (then              (once (adjacent south_west_corner top_shelf) )              (hold (not (agent_holds ?e ?m) ) )              (hold (< (distance agent 7) (distance ?q ?e)) )            )          )        )      )    )    (preference preference2      (exists (?w - ball)        (at-end          (not            (agent_holds ?w)          )        )      )    )  ))(:terminal  (> (count-measure preference2:basketball) 4 ))(:scoring  20))
