In [None]:
%autoreload 2

In [None]:
from collections import defaultdict
from argparse import Namespace
import copy
import gzip
import itertools
import os
import pickle
import sys
import typing

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tatsu
import tatsu.ast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import tqdm
import sklearn
from sklearn.model_selection import GridSearchCV, train_test_split, KFold
from sklearn.pipeline import Pipeline

sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src'))
from src import fitness_energy_utils as utils
from src.fitness_energy_utils import NON_FEATURE_COLUMNS
from src.fitness_features import build_fitness_featurizer, ASTFitnessFeaturizer
from src.ast_counter_sampler import *
from src.ast_counter_sampler import parse_or_load_counter

# Plan of attack
* Use the code in `ast_counter_sampler.py` to generate a sample from the MLE
* Score it with a fitness function adapted to working with a single example at a time
* At each subsequent iteration, regrow the game from a random node, and score the regrowth
* Since we have energy scores before and after, we can either accept greedily if $E_{new} < E_{old}$, or accept with probability $\min\left(1, \exp(\beta (E_{old} - E_{new}))\right)$ (the Metropolis rule, where $\beta$ is the inverse temperature)

In [None]:
# Random seed used to build the default sampler arguments.
DEFAULT_RANDOM_SEED = 33
# Backward-compatible alias: the constant was originally defined under this
# misspelled name, and existing code still refers to it.
DEFUALT_RANDOM_SEED = DEFAULT_RANDOM_SEED
# Gzipped pickle holding the fitness model (relative path).
DEFAULT_FITNESS_FUNCTION_PATH = '../models/cv_fitness_model_2022_12_24.pkl.gz'


def _load_and_wrap_fitness_function(fitness_function_path: str = DEFAULT_FITNESS_FUNCTION_PATH) -> typing.Callable[[torch.Tensor], float]:
    """Load a pickled fitness model and expose it as a tensor -> scalar callable.

    The file at ``fitness_function_path`` is opened with ``gzip`` and unpickled.
    The returned closure passes a feature tensor to the model's ``transform``
    method and converts the result to a Python scalar via ``.item()``.

    NOTE: unpickling executes arbitrary code; only load model files you trust.
    """
    with gzip.open(fitness_function_path, 'rb') as model_file:
        fitness_model = pickle.load(model_file)

    def _wrap_fitness(features: torch.Tensor):
        model_output = fitness_model.transform(features)
        return model_output.item()

    return _wrap_fitness


# Default arguments for the sampler. The grammar file and counter output paths
# are prefixed with '..' before being handed to the sampler.
DEFAULT_ARGS = Namespace(**{
    'grammar_file': os.path.join('..', DEFAULT_GRAMMAR_FILE),
    'parse_counter': False,
    'counter_output_path': os.path.join('..', DEFAULT_COUNTER_OUTPUT_PATH),
    'random_seed': DEFUALT_RANDOM_SEED,
})

class MCMCRegrowthSampler:
    """Sample games with MCMC, proposing moves by regrowing the current game's AST.

    Each chain starts from a fresh draw of the AST sampler. Every step asks the
    regrowth sampler for a modified game, scores it with the fitness function,
    and accepts or rejects it; lower fitness values are preferred. A chain ends
    after ``max_steps`` steps, or once more than ``plateau_patience_steps``
    steps have passed since the last accepted proposal.

    The final state of every chain is appended to ``self.samples`` as an
    ``(ast, features, fitness)`` tuple.
    """

    def __init__(self,
        args: Namespace,
        fitness_function_path: str = DEFAULT_FITNESS_FUNCTION_PATH,
        plateau_patience_steps: int = 10,
        max_steps: int = 1000,
        greedy_acceptance: bool = False,
        acceptance_temperature: float = 1.0,
    ):
        """Build the grammar parser, samplers, featurizer, and fitness function.

        Args:
            args: Namespace providing at least ``grammar_file`` and
                ``random_seed``; it is also passed whole to
                ``parse_or_load_counter`` and ``build_fitness_featurizer``.
            fitness_function_path: Path handed to
                ``_load_and_wrap_fitness_function``.
            plateau_patience_steps: A chain stops once more than this many
                steps have elapsed since the last accepted proposal.
            max_steps: Upper bound on the number of regrowth steps per chain.
            greedy_acceptance: If True, accept a proposal only when its fitness
                is strictly lower than the current one; otherwise accept
                stochastically.
            acceptance_temperature: Multiplier applied to the fitness difference
                inside the exponential acceptance rule. Despite the name it acts
                as an inverse temperature: larger values make proposals with
                higher fitness less likely to be accepted.
        """
        # Read the grammar through a context manager so the file handle is closed.
        with open(args.grammar_file) as grammar_file:
            self.grammar = grammar_file.read()

        self.grammar_parser = tatsu.compile(self.grammar)
        self.counter = parse_or_load_counter(args, self.grammar_parser)
        self.sampler = ASTSampler(self.grammar_parser, self.counter, seed=args.random_seed)
        self.regrowth_sampler = RegrowthSampler(self.sampler, args.random_seed)
        self.fitness_featurizer = build_fitness_featurizer(args)
        self.rng = np.random.default_rng(args.random_seed)
        self.fitness_function = _load_and_wrap_fitness_function(fitness_function_path)
        self.plateau_patience_steps = plateau_patience_steps
        self.max_steps = max_steps
        self.greedy_acceptance = greedy_acceptance
        self.acceptance_temperature = acceptance_temperature

        # Index of the next chain (used in the sampled game's id) and the
        # accumulated (ast, features, fitness) results of finished chains.
        self.sample_index = 0
        self.samples = []

    def multiple_samples(self, n_samples: int, verbose: int = 0, should_tqdm: bool = False) -> None:
        """Run ``n_samples`` chains in sequence, optionally with a notebook progress bar."""
        if should_tqdm:
            # `import tqdm` is not guaranteed to load the `tqdm.notebook`
            # submodule, so import it explicitly instead of relying on
            # attribute access on the package.
            from tqdm.notebook import trange
            sample_iter = trange(n_samples)
        else:
            sample_iter = range(n_samples)

        for _ in sample_iter:
            self.sample(verbose)

    def sample(self, verbose: int = 0) -> None:
        """Run one MCMC chain and append its final state to ``self.samples``.

        The initial game is drawn from the AST sampler, retrying for as long as
        sampling raises ``RecursionError`` or ``SamplingException``.
        """
        current_proposal = None
        while current_proposal is None:
            try:
                current_proposal = self.sampler.sample(global_context=dict(original_game_id=f'mcmc-{self.sample_index}'))
            except RecursionError:
                if verbose >= 2: print('Recursion error, skipping sample')
            except SamplingException:
                if verbose >= 2: print('Sampling exception, skipping sample')

        current_proposal_features, current_proposal_fitness = self._score_proposal(current_proposal)  # type: ignore

        last_accepted_step = 0
        for step in range(self.max_steps):
            current_proposal, current_proposal_features, current_proposal_fitness, accepted = self.mcmc_regrowth_step(
                current_proposal, current_proposal_features, current_proposal_fitness, step, verbose  # type: ignore
            )

            if accepted:
                last_accepted_step = step
                if verbose:
                    print(f'Accepted step {step} with fitness {current_proposal_fitness}')

            elif step - last_accepted_step > self.plateau_patience_steps:
                # Too many consecutive rejections: treat the chain as converged.
                if verbose:
                    print(f'Plateaued at step {step} with fitness {current_proposal_fitness}')
                break

        self.samples.append((current_proposal, current_proposal_features, current_proposal_fitness))
        self.sample_index += 1

    def mcmc_regrowth_step(self,
        current_proposal: tatsu.ast.AST,
        current_proposal_features: typing.Dict[str, float],
        current_proposal_fitness: float,
        step_index: int, verbose: int,
        ):
        """Propose one regrowth of the current game and accept or reject it.

        Regrowth is retried until it yields a game whose printed form differs
        from the current one; ``RecursionError`` and ``SamplingException``
        raised while regrowing also trigger a retry.

        Returns:
            A ``(proposal, features, fitness, accepted)`` tuple holding the new
            state when the proposal was accepted and the unchanged current
            state otherwise.
        """
        if self.regrowth_sampler.source_ast != current_proposal:
            self.regrowth_sampler.set_source_ast(current_proposal)

        # The current game does not change across retries, so render it once
        # instead of on every loop iteration.
        current_proposal_str = ast_printer.ast_to_string(current_proposal)  # type: ignore

        new_proposal = None
        sample_generated = False

        while not sample_generated:
            try:
                new_proposal = self.regrowth_sampler.sample(step_index, update_game_id=False)
                if ast_printer.ast_to_string(new_proposal) == current_proposal_str:  # type: ignore
                    if verbose >= 2: print('Regrowth generated identical games, repeating')
                else:
                    sample_generated = True

            except RecursionError:
                if verbose >= 2: print('Recursion error, skipping sample')

            except SamplingException:
                if verbose >= 2: print('Sampling exception, skipping sample')

        new_proposal_features, new_proposal_fitness = self._score_proposal(new_proposal)  # type: ignore

        if self.greedy_acceptance:
            accept = new_proposal_fitness < current_proposal_fitness
        else:
            # Metropolis-style rule: when the new fitness is not higher, the
            # exponential is >= 1 and the uniform draw from [0, 1) always passes.
            acceptance_probability = np.exp(-self.acceptance_temperature * (new_proposal_fitness - current_proposal_fitness))
            accept = self.rng.uniform() < acceptance_probability

        if accept:
            return new_proposal, new_proposal_features, new_proposal_fitness, True

        return current_proposal, current_proposal_features, current_proposal_fitness, False

    def _score_proposal(self, proposal: tatsu.ast.AST):
        """Featurize ``proposal`` and return ``(features, fitness)``.

        Entries whose key is in ``NON_FEATURE_COLUMNS`` are dropped before the
        remaining values are packed, in dict order, into a float32 tensor for
        the fitness function.
        """
        proposal_features = self.fitness_featurizer.parse(proposal, 'mcmc', True)  # type: ignore
        proposal_tensor = torch.tensor([v for k, v in proposal_features.items() if k not in NON_FEATURE_COLUMNS],
            dtype=torch.float32)  # type: ignore
        proposal_fitness = self.fitness_function(proposal_tensor)
        return proposal_features, proposal_fitness




In [None]:
# Run ten chains with stochastic (non-greedy) acceptance, then print the final
# fitness of each chain.
mcmc = MCMCRegrowthSampler(
    DEFAULT_ARGS,
    fitness_function_path=DEFAULT_FITNESS_FUNCTION_PATH,
    plateau_patience_steps=50,
    max_steps=3000,
    greedy_acceptance=False,
    acceptance_temperature=10.0,
)
mcmc.multiple_samples(10, verbose=0, should_tqdm=True)
print([fitness for _, _, fitness in mcmc.samples])

In [None]:
# Print the AST of the second stored sample (index 1) as text; '\n' is passed
# as the second argument to ast_to_string.
print(ast_printer.ast_to_string(mcmc.samples[1][0], '\n'))

In [None]:
# Re-featurize the second stored sample with a fresh featurizer and display
# only the features whose name contains 'variables'.
fitness_featurizer = build_fitness_featurizer(DEFAULT_ARGS)
features = fitness_featurizer.parse(mcmc.samples[1][0], 'mcmc', True)
{k: v for k, v in features.items() if 'variables' in k}
# features  (uncomment to display the full feature dict instead)

In [None]:
# Load a list of previously saved samples from disk.
# NOTE(review): this is a hard-coded absolute path on one user's machine, so the
# cell will fail elsewhere. Unpickling also executes arbitrary code; only load
# files you trust.
with open('/Users/guydavidson/Downloads/samples.pkl', 'rb') as f:
    samples = pickle.load(f)

# Print the third element of every loaded entry (presumably the fitness, if the
# file was saved from MCMCRegrowthSampler.samples; TODO confirm).
print([x[2] for x in samples])

In [None]:
# Print the first element (the AST) of the loaded sample at index 8 as text.
print(ast_printer.ast_to_string(samples[8][0], '\n'))

In [None]:
# Featurize the loaded sample at index 7 and display a subset of its features.
fitness_featurizer = build_fitness_featurizer(DEFAULT_ARGS)
features = fitness_featurizer.parse(samples[7][0], 'mcmc', True)
# NOTE(review): these are substring tests, so 'no' also matches any feature
# name that merely contains the letters "no" (e.g. "not", "node").
{k: v for k, v in features.items() if 'correctly' in k or 'forall' in k or 'no' in k}

In [None]:
# Show the documentation for tatsu's ParseInfo. The trailing `?` is IPython
# help syntax and is not valid in plain Python.
from tatsu.infos import ParseInfo
ParseInfo?

In [None]:
# Load the raw pickled fitness model object (not the wrapped callable) into the
# notebook namespace.
# NOTE: unpickling executes arbitrary code; only load model files you trust.
with gzip.open(DEFAULT_FITNESS_FUNCTION_PATH, 'rb') as f:
    cv_fitness_model = pickle.load(f)

In [None]:
# Pull the weight and bias of the `fc1` layer from the model held by the
# 'fitness' step, convert them to NumPy arrays, and drop singleton dimensions.
# (`named_steps` suggests cv_fitness_model is an sklearn Pipeline; TODO confirm.)
weights = cv_fitness_model.named_steps['fitness'].model.fc1.weight.data.detach().numpy().squeeze()
bias = cv_fitness_model.named_steps['fitness'].model.fc1.bias.data.detach().numpy().squeeze()
# Print the mean weight alongside the bias.
print(weights.mean(), bias)

In [None]:
# NOTE(review): unfinished cell. The assignment below has no right-hand side
# and raises SyntaxError when run; complete it or delete the cell.
s = 

In [None]:
# Histogram of the fitness model's weights using 100 bins.
plt.hist(weights, bins=100)
plt.title('Fitness Model Weights')
# The plotted values are the raw signed weights, not their absolute values, so
# label the axis as the weight value rather than its magnitude.
plt.xlabel('Weight value')
plt.ylabel('Count')