In [None]:
%autoreload 2

In [None]:
from collections import defaultdict, Counter
from argparse import Namespace
from ast import literal_eval
import copy
import gzip
import itertools
import json
import math
import os
import pickle
import sys
import textwrap
import typing

import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)

import duckdb
from IPython.display import display, Markdown, HTML  # type: ignore
import matplotlib
import matplotlib.axes
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from Levenshtein import distance as _edit_distance
import networkx as nx
import numpy as np
import pandas as pd
import tabulate
import tatsu
import tatsu.ast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import tabulate
from tqdm.notebook import tqdm
from scipy import stats
import seaborn as sns
import sklearn
from sklearn.model_selection import GridSearchCV, train_test_split, KFold
from sklearn.pipeline import Pipeline

sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src'))
sys.path.append(os.path.abspath('../reward-machine'))
from src.ast_utils import _extract_game_id, deepcopy_ast, replace_child
from src.ast_printer import ast_to_lines
from src import fitness_energy_utils as utils
from src.fitness_energy_utils import NON_FEATURE_COLUMNS
from src.fitness_features import *
from src.ast_counter_sampler import *
from src.evolutionary_sampler import *
from src import fitness_features_by_category, latest_model_paths, ast_utils

In [None]:
# Read the DSL grammar through a context manager so the file handle is closed
# as soon as the text has been read (the bare `open(...).read()` leaked it).
with open('../dsl/dsl.ebnf') as grammar_file:
    grammar = grammar_file.read()

grammar_parser = tatsu.compile(grammar)

# Parse the real (human-authored) games and keep their printed text as well.
game_asts = list(cached_load_and_parse_games_from_file('../dsl/interactive-beta.pddl', grammar_parser, False, relative_path='..'))
real_game_texts = [ast_printer.ast_to_string(ast, '\n') for ast in game_asts]
# regrown_game_texts = list(load_games_from_file('../dsl/ast-real-regrowth-samples.pddl'))
# regrown_game_1024_texts = list(load_games_from_file('../dsl/ast-real-regrowth-samples-1024.pddl'))
# print(len(real_game_texts), len(regrown_game_texts), len(regrown_game_texts) / 98, len(regrown_game_1024_texts), len(regrown_game_1024_texts) / 98)

# Load the precomputed fitness features and show which source files they cover.
fitness_df = utils.load_fitness_data('../data/fitness_features_1024_regrowths.csv.gz')
print(fitness_df.src_file.unique())
fitness_df.head()

In [None]:
# Default keyword arguments passed to `plt.subplots_adjust` by the plotting helpers.
SUBPLOTS_ADJUST_PARAMS = dict(top=0.925)
# Archive-metric names that are skipped by default when plotting metric histories.
DEFAULT_IGNORE_METRICS = ['Timestamp']


# LaTeX snippet for a full-width figure; formatted with `save_path` and `label_name`.
# Doubled braces are literal braces after `str.format`.
FIGURE_TEMPLATE = r'''\begin{{figure}}[!htb]
% \vspace{{-0.225in}}
\centering
\includegraphics[width=\linewidth]{{figures/{save_path}}}
\caption{{ {{\bf FIGURE TITLE.}} FIGURE DESCRIPTION.}}
\label{{fig:{label_name}}}
% \vspace{{-0.2in}}
\end{{figure}}
'''
# LaTeX snippet for a right-aligned wrapfigure; formatted with the same two fields.
WRAPFIGURE_TEMPLATE = r'''\begin{{wrapfigure}}{{r}}{{0.5\linewidth}}
\vspace{{-.3in}}
\begin{{spacing}}{{1.0}}
\centering
\includegraphics[width=0.95\linewidth]{{figures/{save_path}}}
\caption{{ {{\bf FIGURE TITLE.}} FIGURE DESCRIPTION.}}
\label{{fig:{label_name}}}
\end{{spacing}}
% \vspace{{-.25in}}
\end{{wrapfigure}}'''

# Directory prefix that `save_plot` prepends to paths that do not already start with it.
SAVE_PATH_PREFIX = './figures'


def save_plot(save_path, bbox_inches='tight', should_print=False):
    """Save the current matplotlib figure under ``SAVE_PATH_PREFIX``.

    Args:
        save_path: Target file path; ``None`` makes the call a no-op. Paths that
            do not start with ``SAVE_PATH_PREFIX`` are placed beneath it.
        bbox_inches: Forwarded to ``plt.savefig``.
        should_print: When true, print LaTeX ``figure`` and ``wrapfigure``
            snippets that reference the (un-prefixed) ``save_path``.
    """
    if save_path is None:
        return

    if should_print:
        # The LaTeX label is the extension-less path with '/' and '_' turned into '-'.
        label_name = os.path.splitext(save_path)[0].replace('/', '-').replace('_', '-')
        print('Figure:\n')
        print(FIGURE_TEMPLATE.format(save_path=save_path, label_name=label_name))
        print('\nWrapfigure:\n')
        print(WRAPFIGURE_TEMPLATE.format(save_path=save_path, label_name=label_name))
        print('')

    if save_path.startswith(SAVE_PATH_PREFIX):
        target_path = save_path
    else:
        target_path = os.path.join(SAVE_PATH_PREFIX, save_path)

    target_path = os.path.abspath(target_path)
    # Make sure the destination folder exists before writing the image.
    os.makedirs(os.path.dirname(target_path), exist_ok=True)
    plt.savefig(target_path, bbox_inches=bbox_inches, facecolor=plt.gcf().get_facecolor(), edgecolor='none')


In [None]:
# trace_filter_results_path = '../samples/trace_filter_results_max_exemplar_preferences_by_bcs_with_expected_values_2023_11_29_2023_12_05_1.pkl.gz'
# Look up and load the baseline MAP-Elites sampler from the registry of latest model paths.
model_key = 'max_exemplar_preferences_by_bcs_with_expected_values'
model_spec = latest_model_paths.MAP_ELITES_MODELS[model_key]
baseline_model = typing.cast(MAPElitesSampler, model_spec.load())

# Score every real game with the baseline model and record the archive key it maps to:
#   key_to_real_game_index: key -> list of indices into `game_asts`
#   real_game_index_to_key: index into `game_asts` -> key
#   real_game_fitness_scores / ALL_REAL_GAME_KEYS: per-game lists in `game_asts` order
key_to_real_game_index = defaultdict(list)
real_game_index_to_key = {}
real_game_fitness_scores = []
ALL_REAL_GAME_KEYS = []
for i, ast in enumerate(game_asts):
    fitness_score, features = baseline_model._score_proposal(ast, return_features=True)  # type: ignore
    real_game_fitness_scores.append(fitness_score)
    key = baseline_model._features_to_key(ast, features)
    key_to_real_game_index[key].append(i)
    real_game_index_to_key[i] = key
    ALL_REAL_GAME_KEYS.append(key)

# Trace-filter results for the baseline model; later cells index its 'summary' and 'full' entries.
trace_filter_results = model_spec.load_trace_filter_data()
trace_filter_results.keys()

In [None]:
# Sentinel value for `plot_metrics` that selects the archive-occupancy panel.
ARCHIVE_OCCUPANCY = 'archive_occupancy'


def plot_sampler_fitness_trajectory(
        evo: PopulationBasedSampler, title: typing.Optional[typing.List[str]] = None, 
        axsize: typing.Tuple[int, int] = (8, 6),
        plot_metrics: typing.Optional[typing.Union[bool, str]] = None, 
        ignore_metrics: typing.Optional[typing.List[str]] = DEFAULT_IGNORE_METRICS,
        subplots_adjust_params: typing.Dict[str, float] = SUBPLOTS_ADJUST_PARAMS,
        min_real_game_fitness: typing.Optional[float] = None, 
        max_real_game_fitness: typing.Optional[float] = None,
        mean_real_game_fitness: typing.Optional[float] = None, 
        archive_occupancy_ignore_fake_bc: bool = True,
        fitness_left: bool = True,
        vertical: bool = False,
        fontsize: int = 16,
        save_path: typing.Optional[str] = None): 
    """Plot a sampler's fitness history and, optionally, a second metrics panel.

    Args:
        evo: Sampler whose `fitness_metrics_history` (list of dicts with 'mean',
            'max' and 'std') is plotted.
        title: Optional list of titles. The first entry titles the fitness panel,
            or the whole figure when a metrics panel is drawn and only one title
            is given; a second entry titles the metrics panel.
        axsize: (width, height) of a single panel; the figure size is scaled by
            the layout.
        plot_metrics: `None` auto-enables the metrics panel when `evo` has a
            non-empty `archive_metrics_history`. The string `ARCHIVE_OCCUPANCY`
            draws the cumulative archive-occupancy curve instead; any other
            truthy value plots `archive_metrics_history`.
        ignore_metrics: Keys of `archive_metrics_history` to leave out.
        subplots_adjust_params: Keyword arguments for `plt.subplots_adjust`
            (applied only when a metrics panel is drawn).
        min_real_game_fitness, max_real_game_fitness, mean_real_game_fitness:
            Reference lines. If min or max is `None`, all three are replaced by
            sign-flipped values from `evo.fitness_function.score_dict`.
        archive_occupancy_ignore_fake_bc: Keep only cells whose key starts with 1
            when computing archive occupancy.
        fitness_left: Put the fitness panel first (index 0) rather than second.
        vertical: Stack the two panels vertically instead of side by side.
        fontsize: Base font size; titles use `fontsize + 4`.
        save_path: If given, the figure is saved through `save_plot`.
    """
    
    if min_real_game_fitness is None or max_real_game_fitness is None:    
        # Signs are flipped, which is why the minimum comes from 'max' and vice versa.
        # Note that a caller-supplied mean is overwritten here as well.
        min_real_game_fitness =  -1 * evo.fitness_function.score_dict['max']
        max_real_game_fitness = -1 * evo.fitness_function.score_dict['min']
        mean_real_game_fitness = -1 * evo.fitness_function.score_dict['mean']

    if plot_metrics is None:
        plot_metrics = hasattr(evo, 'archive_metrics_history') and len(evo.archive_metrics_history) > 0  # type: ignore

    if ignore_metrics is None:
        ignore_metrics = []
    
    # (rows, columns) of the subplot grid.
    if not plot_metrics:
        layout = (1, 1)
    elif vertical:
        layout = (2, 1)
    else:
        layout = (1, 2)

    figsize = (axsize[0] * layout[1], axsize[1] * layout[0])
    title_fontsize = fontsize + 4

    fig, axes = plt.subplots(*layout, figsize=figsize)

    # Collect the per-generation fitness statistics.
    mean, max_fit, std = [], [], []
    for step_dict in evo.fitness_metrics_history:
        mean.append(step_dict['mean'])
        max_fit.append(step_dict['max'])
        std.append(step_dict['std'])

    mean = np.array(mean)
    max_fit = np.array(max_fit)
    std = np.array(std)
    
    fitness_ax_index = 0 if fitness_left else 1
    # With a single panel `plt.subplots` returns the Axes itself rather than an array.
    fitness_ax = typing.cast(matplotlib.axes.Axes, axes[fitness_ax_index] if plot_metrics else axes)

    fitness_ax.plot(mean, label='MAP-Elites fitness mean')
    fitness_ax.fill_between(np.arange(len(mean)), mean - std, mean + std, alpha=0.2, label='MAP-Elites fitness std')  # type: ignore
    fitness_ax.plot(max_fit, label='MAP-Elites fitness max')

    # Dashed lines mark the real-game fitness range; only one carries the legend label.
    fitness_ax.hlines(min_real_game_fitness, 0, len(mean), label='Real game fitness range', color='black', ls='--')
    fitness_ax.hlines(max_real_game_fitness, 0, len(mean), color='black', ls='--')

    if mean_real_game_fitness is not None:
        fitness_ax.hlines(mean_real_game_fitness, 0, len(mean), label='Real game fitness mean', color='black', ls=':')

    fitness_ax.set_xlabel('Generation (index)', fontsize=fontsize)
    fitness_ax.set_ylabel('Fitness (arbitrary units)', fontsize=fontsize)

    fitness_ax.legend(loc='best', fontsize=fontsize)
    fitness_ax.tick_params(axis='both', which='major', labelsize=fontsize)

    if title is not None:
        if len(title) > 1 or not plot_metrics:
            fitness_ax.set_title(title[0], fontsize=title_fontsize)
        else:
            plt.suptitle(title[0], fontsize=title_fontsize)
    
    if plot_metrics:
        metrics_ax = typing.cast(matplotlib.axes.Axes, axes[1 - fitness_ax_index])

        if plot_metrics == ARCHIVE_OCCUPANCY:
            relevant_first_occupancies = evo.archive_cell_first_occupied  # type: ignore
            if archive_occupancy_ignore_fake_bc:
                relevant_first_occupancies = {k: v for k, v in relevant_first_occupancies.items() if k[0] == 1}

            # Count how many cells were first occupied at each generation index.
            # NOTE(review): `max(...)` raises ValueError if no cell passes the filter.
            first_occupancy_counter = Counter(relevant_first_occupancies.values())
            first_occupancy_arr = np.zeros(max(first_occupancy_counter.keys()) + 1)

            for k, v in first_occupancy_counter.items():
                first_occupancy_arr[k] = v

            # Cumulative count normalised by its final value, so the curve ends at 1.0
            # (a fraction, although the axis label below says '%').
            first_occupancy_cumsum = np.cumsum(first_occupancy_arr)
            first_occupancy_cumsum /= first_occupancy_cumsum.max()

            metrics_ax.plot(first_occupancy_cumsum)
            metrics_ax.set_xlabel('Generation (index)', fontsize=fontsize)
            metrics_ax.set_ylabel('Archive occupancy (%)', fontsize=fontsize)

        else:
            # One line per archive metric, using the keys of the first history entry.
            metrics = {key: [] for key in evo.archive_metrics_history[0].keys() if key not in ignore_metrics}  # type: ignore
            for step_dict in evo.archive_metrics_history:  # type: ignore
                for key, value in step_dict.items():
                    if key in metrics:
                        metrics[key].append(value)

            
            for key, values in metrics.items():
                metrics_ax.plot(values, label=key.title())

            metrics_ax.set_xlabel('Generation', fontsize=fontsize)
            metrics_ax.set_ylabel('Number of games reaching threshold', fontsize=fontsize)

            metrics_ax.legend(loc='best', fontsize=fontsize)
        
        metrics_ax.tick_params(axis='both', which='major', labelsize=fontsize)

        if title is not None and len(title) > 1:
            metrics_ax.set_title(title[1], fontsize=title_fontsize)

        plt.subplots_adjust(**subplots_adjust_params)
        

    # Save before showing so the saved figure is the one that was just drawn.
    if save_path is not None:
        save_plot(save_path)

    plt.show()



# Baseline figure: archive occupancy in the first panel, fitness in the second, saved to disk.
plot_sampler_fitness_trajectory(baseline_model, ['Fitness', 'Occupancy'], axsize=(8, 4),
                                plot_metrics=ARCHIVE_OCCUPANCY, fitness_left=False,
                                save_path='baseline_quantitative_results.png')

In [None]:
# Default styling for the significance bracket and its text label.
ANNOTATION_LINEPLOT_KWARGS = dict(lw=1.5, c='black')
ANNOTATION_TEXT_KWARGS = dict(fontsize=12, ha='center', va='bottom', weight='bold')


def annotate_significance(ax: plt.Axes, pair: typing.Tuple[int, int], data_df: pd.DataFrame, 
                          attribute: str = 'fitness', group_attribute: str = 'n_prefs',
                          y_increment: float = 0, y_margin: float = 0.2, 
                          bar_y: float = 0.2, x_margin: float = 0.025,
                          AST: str = '*', starts_only_y_dec: float = 0.1,
                          plot_kwargs: dict = ANNOTATION_LINEPLOT_KWARGS,
                          text_kwargs: dict = ANNOTATION_TEXT_KWARGS):
    """Draw a significance bracket between two categories of a categorical plot.

    The two groups are compared with an independent-samples t-test
    (`scipy.stats.ttest_ind`). The label is one to three `AST` markers for
    p < 0.05 / 0.01 / 0.001, or 'n.s.' otherwise.

    Args:
        ax: Axes whose x tick labels are integer category names.
        pair: The two category values to compare; both must appear as tick labels.
        data_df: Data frame holding `attribute` and `group_attribute` columns.
        attribute: Column with the values to compare.
        group_attribute: Column with the category of each row.
        y_increment: Extra vertical offset, used to stack several brackets.
        y_margin: Gap between the tallest data point and the bracket feet.
        bar_y: Height of the bracket.
        x_margin: Horizontal inset of the bracket from each category position.
        AST: Marker character repeated once per significance level reached.
        starts_only_y_dec: How far a marker-only label is lowered towards the
            bracket (no lowering is applied to the 'n.s.' label).
        plot_kwargs: Keyword arguments for the bracket line.
        text_kwargs: Keyword arguments for the label text.

    Raises:
        KeyError: If a category in `pair` has no matching x tick label.
    """
    # Use the public `get_position()` accessor rather than the private `_x` attribute.
    category_to_position = {}
    for tick_label in ax.get_xticklabels():
        category = int(tick_label.get_text())
        if category in pair:
            category_to_position[category] = tick_label.get_position()[0]

    first_data = data_df[data_df[group_attribute] == pair[0]][attribute]
    second_data = data_df[data_df[group_attribute] == pair[1]][attribute]
    result = stats.ttest_ind(first_data, second_data)
    p_value = result.pvalue

    n_levels = int(p_value < 0.05) + int(p_value < 0.01) + int(p_value < 0.001)
    stars = AST * n_levels
    # Track significance explicitly instead of searching the label for a literal
    # '*', so that a custom `AST` marker is positioned the same way as the default.
    is_marker_label = bool(stars)
    if not is_marker_label:
        stars = 'n.s.'

    y_max = max(first_data.max(), second_data.max())
    y_bar_start = y_max + y_margin + y_increment
    y_bar_end = y_bar_start + bar_y

    left_x = category_to_position[pair[0]] + x_margin
    right_x = category_to_position[pair[1]] - x_margin

    # Bracket outline: up from the first category, across, and down to the second.
    ax.plot(
        [left_x, left_x, right_x, right_x],
        [y_bar_start, y_bar_end, y_bar_end, y_bar_start],
        **plot_kwargs,
    )

    middle = (category_to_position[pair[0]] + category_to_position[pair[1]]) / 2
    text_height = y_bar_end - starts_only_y_dec if is_marker_label else y_bar_end
    ax.text(middle, text_height, stars, **text_kwargs)


In [None]:
# Mean fitness in the last recorded generation next to the sign-flipped 'mean' entry of the fitness function's score_dict.
baseline_model.fitness_metrics_history[-1]['mean'], -1 * baseline_model.fitness_function.score_dict['mean']

## Crossover ablation

In [None]:
# Ablation model registered under the "no custom ops" key.
ablation_no_custom_ops_key = 'ablation_max_exemplar_preferences_by_bcs_with_expected_values_no_custom_ops'
ablation_no_custom_ops_model_spec = latest_model_paths.MAP_ELITES_MODELS[ablation_no_custom_ops_key]
ablation_no_custom_ops_model = typing.cast(MAPElitesSampler, ablation_no_custom_ops_model_spec.load())

# Ablation model registered under the "no custom ops, no crossover" key.
ablation_no_custom_ops_no_crossover_key = 'ablation_max_exemplar_preferences_by_bcs_with_expected_values_no_custom_ops_no_crossover'
ablation_no_custom_ops_no_crossover_model_spec = latest_model_paths.MAP_ELITES_MODELS[ablation_no_custom_ops_no_crossover_key]
ablation_no_custom_ops_no_crossover_model = typing.cast(MAPElitesSampler, ablation_no_custom_ops_no_crossover_model_spec.load())

In [None]:
# Models compared in the crossover ablation, keyed by a short name.
NAME_TO_MODEL = {
    'baseline': baseline_model,
    'no_custom_ops': ablation_no_custom_ops_model,
    'no_custom_ops_no_crossover': ablation_no_custom_ops_no_crossover_model,
}

# One row per archive cell whose key starts with 1, for every model.
df_rows = []

for name, model in NAME_TO_MODEL.items():
    for key, fitness in model.fitness_values.items():
        if key[0] == 1:
            df_rows.append([name, key, fitness])


ablation_fitness_df = pd.DataFrame(df_rows, columns=['model', 'key', 'fitness'])
# Standardise fitness across all models pooled together (z-score).
ablation_fitness_df = ablation_fitness_df.assign(standardized_fitness= (ablation_fitness_df.fitness - ablation_fitness_df.fitness.mean()) / ablation_fitness_df.fitness.std())
print(ablation_fitness_df.shape)
# Aggregate only the numeric columns: the 'key' column holds tuples, and asking
# for its mean/std raises a TypeError on pandas >= 2.0.
ablation_fitness_df.groupby('model')[['fitness', 'standardized_fitness']].agg(['mean', 'std'])

In [None]:
# Fitness distribution per ablation: violins, with the mean and 95% CI overlaid as black diamonds.
fig, ax = plt.subplots(figsize=(8, 6))

sns.violinplot(ax=ax, data=ablation_fitness_df, x='model', y='fitness', palette='tab10', hue='model', inner='quart', alpha=0.75, cut=1)
sns.pointplot(ax=ax, data=ablation_fitness_df, x='model', y='fitness', errorbar=('ci', 95), linestyle='none', color='black', markers='d', markersize=15)

ax.set_xlabel('Ablation', fontsize=16)
ax.set_ylabel('Fitness', fontsize=16)

# Human-readable tick labels, listed in the order the models appear in the data frame.
ax.set_xticklabels(['Full Model', 'No Custom Ops', 'No Custom Ops\nNo Crossover'])
ax.tick_params(axis='both', which='major', labelsize=12)

save_plot('ablation_fitness_violinplot.png')
plt.show()

In [None]:
# Paired t-test on standardized fitness for every pair of models, matched by archive key.
for first_model_name, second_model_name in itertools.combinations(NAME_TO_MODEL, 2):
    suffixes = (f'_{first_model_name}', f'_{second_model_name}')

    first_model_df = ablation_fitness_df[ablation_fitness_df.model == first_model_name]
    second_model_df = ablation_fitness_df[ablation_fitness_df.model == second_model_name]

    # The inner merge keeps only the cells that are present in both archives.
    merged_df = first_model_df.merge(second_model_df, on='key', suffixes=suffixes)

    result = stats.ttest_rel(
        merged_df[f'standardized_fitness_{first_model_name}'],
        merged_df[f'standardized_fitness_{second_model_name}'],
    )

    # One star per significance threshold that the p-value clears.
    n_stars = sum(int(result.pvalue < threshold) for threshold in (0.05, 0.01, 0.001))
    stars = '*' * n_stars

    print(f'{first_model_name} vs {second_model_name}:')
    print(f'\tt-statistic = {result.statistic}')
    print(f'\tp-value = {result.pvalue} {stars}')
    print()
    

In [None]:
# Minimum and median fitness of the real games under the baseline model.
min_human_fitness = min(real_game_fitness_scores)
median_human_fitness = np.median(real_game_fitness_scores)
min_human_fitness, median_human_fitness

In [None]:
# Fitness of every baseline archive cell whose key starts with 1.
relevant_fitness_values = [
    fitness
    for key, fitness in baseline_model.fitness_values.items()
    if key[0] == 1
]

# How many of those cells reach the minimum / median real-game fitness.
min_count = sum(1 for fitness in relevant_fitness_values if fitness >= min_human_fitness)
median_count = sum(1 for fitness in relevant_fitness_values if fitness >= median_human_fitness)

min_count, median_count

In [None]:
# Count the baseline cells (key starting with 1) with a positive summary trace-filter
# score, and those with any positive per-component count; print the keys that have none.
count_nonzero = 0
count_any_nonzero = 0

for key in baseline_model.population:
    if key[0] != 1:
        continue

    if trace_filter_results['summary'][key] > 0:
        count_nonzero += 1

    component_results = trace_filter_results['full'][key]
    has_positive_component = any(
        sum(component_counts.values()) > 0
        for component_counts in component_results.values()
    )

    if has_positive_component:
        count_any_nonzero += 1
    else:
        print(key)


count_nonzero, count_any_nonzero

In [None]:

# NOTE(review): hard-coded scratch arithmetic; presumably a total of 2000 cells minus a count of 1515 from a cell above -- confirm, and derive it from the variables instead.
2000 - 1515


In [None]:
# NOTE(review): `model` is not assigned in this cell; on a top-to-bottom run it is whatever an earlier loop left behind.
# Presumably `baseline_model` was intended (the key looks like one printed by the cell above) -- confirm.
print(ast_printer.ast_to_string(model.population[(1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0)], '\n'))

## Common sense ablation

In [None]:
# Load the ablation model registered under the "no play trace features" key, together with its trace-filter data.
ablation_no_play_trace_features_key = 'ablation_max_exemplar_preferences_by_bcs_with_expected_values_no_play_trace_features'
ablation_no_play_trace_features_model_spec = latest_model_paths.MAP_ELITES_MODELS[ablation_no_play_trace_features_key]
ablation_no_play_trace_features_model = typing.cast(MAPElitesSampler, ablation_no_play_trace_features_model_spec.load())
ablation_no_play_trace_features_trace_filter_data = ablation_no_play_trace_features_model_spec.load_trace_filter_data()


In [None]:
from tqdm import tqdm

# Re-score each game from the ablated archive (keys starting with 1) with the full baseline fitness model.
no_play_trace_feature_games_full_fitness_model_scores = {
    key: baseline_model._score_proposal(game)
    for key, game in tqdm(ablation_no_play_trace_features_model.population.items())
    if key[0] == 1
}



In [None]:
# Baseline fitness values restricted to keys starting with 1, then the mean baseline score next to the mean re-scored ablation score.
relevant_baseline_full_fitness_model_scores = {k: v for k, v in baseline_model.fitness_values.items() if k[0] == 1}
print(np.mean(list(relevant_baseline_full_fitness_model_scores.values())), np.mean(list(no_play_trace_feature_games_full_fitness_model_scores.values())))

In [None]:
# Pair the scores cell by cell, iterating over the baseline keys
# (a key missing from the ablated scores raises KeyError).
ablated = [
    no_play_trace_feature_games_full_fitness_model_scores[key]
    for key in relevant_baseline_full_fitness_model_scores
]
full = [
    relevant_baseline_full_fitness_model_scores[key]
    for key in relevant_baseline_full_fitness_model_scores
]

print(np.mean(ablated), np.mean(full))

stats.ttest_rel(ablated, full)

In [None]:
# Compare the summary trace-filter scores of the baseline and the ablation, cell by cell.
keys_with_trace_filter_difference = {}
ablated_scores = []
full_scores = []

for key, baseline_trace_filter_score in trace_filter_results['summary'].items():
    if key[0] == 1:
        # Cells without an ablation result are printed and left out of the comparison.
        if key not in ablation_no_play_trace_features_trace_filter_data['summary']:
            print(key)
            continue

        ablation_trace_filter_score = ablation_no_play_trace_features_trace_filter_data['summary'][key]
        if baseline_trace_filter_score != ablation_trace_filter_score:
            keys_with_trace_filter_difference[key] = (baseline_trace_filter_score, ablation_trace_filter_score)

        # Binarise each score: 1 if positive, else 0.
        ablated_scores.append(int(ablation_trace_filter_score > 0))
        full_scores.append(int(baseline_trace_filter_score > 0))

# Number of differing cells, then those positive only in the baseline, then those positive only in the ablation.
print(len(keys_with_trace_filter_difference))
print(len([k for k, v in keys_with_trace_filter_difference.items() if v[0] > 0 and v[1] == 0]))
print(len([k for k, v in keys_with_trace_filter_difference.items() if v[0] == 0 and v[1] > 0]))
print(sum(ablated_scores), sum(full_scores))

In [None]:
# Paired t-test on the binarised (0/1) trace-filter indicators built in the previous cell.
stats.ttest_rel(ablated_scores, full_scores)

## Coherence Ablation

In [None]:
# Load the ablation model registered under the "no coherence features" key, together with its trace-filter data.
ablation_no_coherence_features_key = 'ablation_max_exemplar_preferences_by_bcs_with_expected_values_no_coherence_features'
ablation_no_coherence_features_model_spec = latest_model_paths.MAP_ELITES_MODELS[ablation_no_coherence_features_key]
ablation_no_coherence_features_model = typing.cast(MAPElitesSampler, ablation_no_coherence_features_model_spec.load())
ablation_no_coherence_features_trace_filter_data = ablation_no_coherence_features_model_spec.load_trace_filter_data()

In [None]:
from tqdm import tqdm

# Re-score each game from the no-coherence archive (keys starting with 1) with the full baseline fitness model.
no_coherence_feature_games_full_fitness_model_scores = {
    key: baseline_model._score_proposal(game)
    for key, game in tqdm(ablation_no_coherence_features_model.population.items())
    if key[0] == 1
}



In [None]:
# Baseline fitness values restricted to keys starting with 1, then the mean baseline score next to the mean re-scored ablation score.
relevant_baseline_full_fitness_model_scores = {k: v for k, v in baseline_model.fitness_values.items() if k[0] == 1}
print(np.mean(list(relevant_baseline_full_fitness_model_scores.values())), np.mean(list(no_coherence_feature_games_full_fitness_model_scores.values())))

In [None]:
# Pair the scores cell by cell, iterating over the baseline keys
# (a key missing from the ablated scores raises KeyError).
ablated = [
    no_coherence_feature_games_full_fitness_model_scores[key]
    for key in relevant_baseline_full_fitness_model_scores
]
full = [
    relevant_baseline_full_fitness_model_scores[key]
    for key in relevant_baseline_full_fitness_model_scores
]

stats.ttest_rel(ablated, full)

In [None]:
# When True, cells without an ablation trace-filter result are scored as 0; when False they are skipped.
MISSING_AS_ZERO = True

keys_with_trace_filter_difference = {}
ablated_scores = []
full_scores = []
missing_count = 0

for key, baseline_trace_filter_score in trace_filter_results['summary'].items():
    if key[0] == 1:
        if key not in ablation_no_coherence_features_trace_filter_data['summary']:
            missing_count += 1
            if not MISSING_AS_ZERO:
                continue
        
        # `.get(key, 0)` implements the "missing counts as zero" policy.
        ablation_trace_filter_score = ablation_no_coherence_features_trace_filter_data['summary'].get(key, 0)

        if baseline_trace_filter_score != ablation_trace_filter_score:
            keys_with_trace_filter_difference[key] = (baseline_trace_filter_score, ablation_trace_filter_score)

        # Binarise each score: 1 if positive, else 0.
        ablated_scores.append(int(ablation_trace_filter_score > 0))
        full_scores.append(int(baseline_trace_filter_score > 0))

# Differing cells and missing cells, then those positive only in the baseline, then those positive only in the ablation.
print(len(keys_with_trace_filter_difference), missing_count)
print(len([k for k, v in keys_with_trace_filter_difference.items() if v[0] > 0 and v[1] == 0]))
print(len([k for k, v in keys_with_trace_filter_difference.items() if v[0] == 0 and v[1] > 0]))
print(sum(ablated_scores), sum(full_scores))
# Paired t-test on the binarised (0/1) indicators.
stats.ttest_rel(ablated_scores, full_scores)

## Creativity/Complexity ablations

In [None]:
# Feature names of the form '<prefix><n>' flag how many preferences a game defines.
NUM_PREFERENCES_PREFIX = 'num_preferences_defined_'

n_prefs_per_game = []

for game in game_asts:
    features = baseline_model._proposal_to_features(game)
    active_n_prefs_keys = [
        feature_name
        for feature_name, feature_value in features.items()
        if feature_name.startswith(NUM_PREFERENCES_PREFIX) and feature_value
    ]
    # Take the first active indicator (IndexError if a game has none).
    n_prefs_key = active_n_prefs_keys[0]
    n_prefs = int(n_prefs_key.removeprefix(NUM_PREFERENCES_PREFIX))
    n_prefs_per_game.append(n_prefs)

np.mean(n_prefs_per_game), np.median(n_prefs_per_game), np.std(n_prefs_per_game)


In [None]:
# Number of real games for each preference count.
Counter(n_prefs_per_game)

In [None]:
# Fitness by number of preferences; later lines read its 'n_prefs', 'fitness' and 'source' columns.
fitness_by_prefs_df = pd.read_csv('./human_evals_data/fitness_by_prefs.csv')

# Vertical spacing between stacked significance brackets, and the shared font size.
ANNOTATION_INCREMENT = 0.3
FONTSIZE = 16




# Two panels: a histogram on the left and a violin plot twice as wide on the right.
fig = plt.figure(figsize=(16, 6))
gs = gridspec.GridSpec(1, 2, width_ratios=[1, 2])

# Left panel: histogram of the number of preferences per real game (integer bins 1..7, left-aligned).
n_prefs_ax = plt.subplot(gs[0])
n_prefs_ax.hist(n_prefs_per_game, bins=range(1, 8), align='left')
n_prefs_ax.set_xlabel('Number of preferences defined', fontsize=FONTSIZE)
n_prefs_ax.set_ylabel('Number of games', fontsize=FONTSIZE)
n_prefs_ax.tick_params(axis='both', which='major', labelsize=FONTSIZE)
n_prefs_ax.set_xticks(range(1, 7))
n_prefs_ax.set_title('Number of preferences in real games', fontsize=FONTSIZE)




# Right panel: fitness per preference count, split by source, with the mean and 95% CI overlaid.
fitness_ax = plt.subplot(gs[1])
sns.violinplot(ax=fitness_ax, data=fitness_by_prefs_df, x='n_prefs', y='fitness', palette='tab10', hue='source', inner='quart', alpha=0.75, cut=1)
# sns.swarmplot(data=fitness_by_prefs_df,  x='n_prefs', y='fitness', color='white', dodge=False, size=10, alpha=0.25)
sns.pointplot(ax=fitness_ax, data=fitness_by_prefs_df,  x='n_prefs', y='fitness', errorbar=('ci', 95), linestyle='none', color='black', markers='d', markersize=15)


# Category pairs to annotate; iterated in reverse so the last pair listed gets the lowest bracket.
pairs = [(1, 4), (1, 3), (2, 4), (1, 2), (2, 3), (3, 4)]

for i, (low, high) in enumerate(pairs[::-1]):
    annotate_significance(fitness_ax, (low, high), fitness_by_prefs_df, y_increment=i * ANNOTATION_INCREMENT + 0.1)


fitness_ax.set_xlabel('Number of preferences defined', fontsize=FONTSIZE)
fitness_ax.set_ylabel('Fitness Score (arbitrary units)', fontsize=FONTSIZE)
fitness_ax.tick_params(axis='both', which='major', labelsize=FONTSIZE)
fitness_ax.legend(fontsize=14)
# Fix the lower y limit at 33 and keep the automatically chosen upper limit.
fitness_ax.set_ylim(33, fitness_ax.get_ylim()[1])
fitness_ax.set_title('Fitness by number of preferences defined', fontsize=FONTSIZE)
save_plot('complexity_n_prefs_combined.png')
plt.show()

In [None]:
# Behavioural-characteristics ablation registered under the "pref_count" key, with its trace-filter data.
bcs_ablation_with_pref_count_key = 'bcs_ablation_predicate_and_object_groups_setup_at_end_pref_count_expected_values'
bcs_ablation_with_pref_count_model_spec = latest_model_paths.MAP_ELITES_MODELS[bcs_ablation_with_pref_count_key]
bcs_ablation_with_pref_count_model = typing.cast(MAPElitesSampler, bcs_ablation_with_pref_count_model_spec.load())
bcs_ablation_with_pref_count_trace_filter_data = bcs_ablation_with_pref_count_model_spec.load_trace_filter_data()


# Counterpart ablation used as the "no pref count" variant, with its trace-filter data.
bcs_ablation_no_pref_count_key = 'bcs_ablation_latest_at_end_no_game_object_expected_values'
bcs_ablation_no_pref_count_model_spec = latest_model_paths.MAP_ELITES_MODELS[bcs_ablation_no_pref_count_key]
bcs_ablation_no_pref_count_model = typing.cast(MAPElitesSampler, bcs_ablation_no_pref_count_model_spec.load())
bcs_ablation_no_pref_count_trace_filter_data = bcs_ablation_no_pref_count_model_spec.load_trace_filter_data()




In [None]:
# Print the game stored under one hard-coded archive key of the "no pref count" ablation.
print(ast_printer.ast_to_string(bcs_ablation_no_pref_count_model.population[(0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0)], '\n'))

In [None]:
# Hand-copied list of `(on a b)` predicates; each one becomes a directed edge a -> b.
s = """(on ?v1 ?v8)
            (on ?v4 ?v7)
            (on ?v6 ?v9)
            (on ?v9 ?v5)
            (on ?v6 ?v7)
            (on ?v4 ?v5)
            (on ?v5 ?v8)
            (on ?v2 ?v6)
            (on ?v6 ?v4)
            (on ?v2 ?v9)
            (on ?v4 ?v1)
            (on ?v7 ?v1)
            (on ?v1 ?v6)
            (on ?v1 ?v9)
            (on ?v2 ?v8)
            (on ?v6 ?v8)
            (on ?v5 ?v7)
            (on ?v3 ?v6)
            (on ?v3 ?v8)
            (on ?v7 ?v9)
            (on ?v3 ?v4)
            (on ?v2 ?v7)
            (on ?v8 ?v7)"""


# Strip the predicate syntax so that each line reads "<source> <target>".
edge_list = [l.replace('(on', '').replace(')', '').strip() for l in s.split('\n')]
G_directed = nx.parse_edgelist(edge_list, nodetype=str, create_using=nx.DiGraph())
# A directed graph has a cycle exactly when it is not a DAG; this avoids
# enumerating every simple cycle (potentially exponentially many) just to test
# whether one exists.
cycle_found = not nx.is_directed_acyclic_graph(G_directed)
G_undirected = G_directed.to_undirected()
disconnected = not nx.is_connected(G_undirected)

cycle_found, disconnected

In [None]:
import reward_machine_trace_filter


def load_trace_filter_compute_keys_with_traces(map_elites_key: str):
    """Build a trace evaluator for a MAP-Elites model and list the keys still needing traces.

    The evaluator is configured through the command-line parser of
    `reward_machine_trace_filter`. Keys that already appear in
    `result_summary_by_key` are skipped; each remaining key is passed to
    `handle_single_game_cache_no_traces` and kept when the first element of its
    return value is falsy.

    Args:
        map_elites_key: Model name passed as `--map-elites-model-name`.

    Returns:
        A `(trace_filter, keys_with_traces)` tuple.
    """
    args_str = f"""   --tqdm 
      --max-traces-per-game 400 
      --n-workers 11 
      --chunksize 1 
      --save-interval 1 
      --dont-sort-keys-by-traces 
      --use-only-database-nonconfirmed-traces 
      --worker-tqdm 
      --map-elites-model-name {map_elites_key}
      --relative-path ..
   """

    cli_tokens = args_str.split()
    args = reward_machine_trace_filter.parser.parse_args(cli_tokens)
    trace_filter = reward_machine_trace_filter.build_trace_evaluator(args)

    # The population size returned alongside the iterator is not needed here.
    key_iter, _ = trace_filter._build_key_iter()

    keys_with_traces = []
    for key in key_iter:
        already_summarized = key in trace_filter.result_summary_by_key
        # Short-circuiting keeps the handler from being called for summarized keys.
        if not already_summarized and not trace_filter.handle_single_game_cache_no_traces(key)[0]:
            keys_with_traces.append(key)

    return trace_filter, keys_with_traces

In [None]:


# Parsing context handed to BOOLEAN_PARSER: variables ?v0..?v19, each declared
# with the single type 'block', inside the preferences section.
DEFAULT_CONTEXT = {
    VARIABLES_CONTEXT_KEY: {
        f'?v{i}': VariableDefinition([f'?v{i}'], ['block'], None)
        for i in range(20)
    },
    SECTION_CONTEXT_KEY: PREFERENCES
}


# Filled in by `print_game_and_remaining_keys`: archive key -> preference names left unaccounted for.
KEY_TO_MISSING_PREF_KEYS = {}


def find_game_preferences(game_ast):
    """Return the preferences of the first preferences section of a game, or None.

    Entries of `game_ast` are scanned from index 3 onwards. An entry counts as
    the preferences section when its first element equals
    `ast_parser.PREFERENCES`; the `.preferences` attribute of its second
    element is returned.
    """
    section_index = 3
    while section_index < len(game_ast):
        section = game_ast[section_index]
        if section[0] == ast_parser.PREFERENCES:
            return section[1].preferences
        section_index += 1

    return None


def find_pref_by_name(preferences, pref_name):
    """Return the single preference whose `definition.pref_name` equals `pref_name`.

    Args:
        preferences: Iterable of preference nodes exposing `definition.pref_name`.
        pref_name: Name of the preference to look up.

    Returns:
        The unique matching preference node.

    Raises:
        ValueError: If zero or more than one preference has that name.
    """
    key_preferences = [p for p in preferences if p.definition.pref_name == pref_name]
    if len(key_preferences) != 1:
        # Report the name that was searched for. The original message read an
        # unrelated global `key`, giving a NameError or a misleading message.
        raise ValueError(f'Found {len(key_preferences)} preferences for preference name {pref_name!r}')

    return key_preferences[0]


def print_game_and_remaining_keys(model, key, trace_filter):
    """Inspect the sections of one archive game that the trace filter leaves unaccounted for.

    "Remaining" keys are the expected keys reported by
    `trace_filter._find_key_traces(key)` that are absent from its
    `database_keys_to_traces` mapping. A remaining setup section is printed and
    dropped. Each remaining preference whose printed text contains 'block',
    '(at-end' and '(and' is analysed: the length of its `and` is measured, it
    is evaluated with `BOOLEAN_PARSER`, and a directed graph is built from its
    `on` predicates to test for cycles and disconnection. Any other remaining
    preference is printed as being "of new form".

    Side effects: prints to stdout and records the remaining preference keys
    in the module-level `KEY_TO_MISSING_PREF_KEYS`.

    Args:
        model: Object whose `population[key]` is the game AST.
        key: Archive key of the game.
        trace_filter: Trace evaluator providing the two private lookups used below.

    Returns:
        `(max_at_end_and_length, cycle_found, disconnected)`; `(0, False, False)`
        when nothing remains after the setup section has been handled.

    Raises:
        ValueError: If the game has no preferences section, or (via
            `find_pref_by_name`) a remaining key does not match exactly one preference.
    """
    # NOTE(review): none of these five results are used below; confirm whether the call is needed for its side effects.
    all_traces, non_database_confirmed_traces, counts_by_trace_and_key, stop_count_by_key, total_count_by_key = trace_filter._process_key_traces_databse_results(key)
    _, _, _, expected_keys, database_keys_to_traces = trace_filter._find_key_traces(key)
    ignore_keys = list(database_keys_to_traces.keys())
    remaining_keys = [k for k in expected_keys if k not in ignore_keys]

    game_ast = model.population[key]

    if '(:setup' in remaining_keys:
        print(f'Found remaining setup section for {key}:')
        print(ast_printer.ast_section_to_string(game_ast[3], ast_parser.SETUP, '\n'))
        print()
        remaining_keys.remove('(:setup')

    if len(remaining_keys) == 0:
        return 0, False, False

    preferences = find_game_preferences(game_ast)
    max_at_end_and_length = 0

    if preferences is None:
        print('Failed to find preferences section in game:')
        print(ast_printer.ast_to_string(game_ast, '\n'))
        raise ValueError('No preferences section found')
    
    KEY_TO_MISSING_PREF_KEYS[key] = remaining_keys
    cycle_found = False
    disconnected = False
    
    for pref_key in remaining_keys:    
        pref = find_pref_by_name(preferences, pref_key)
        pref_str = ast_printer.ast_section_to_string(pref, ast_parser.PREFERENCES, '\n')
        # Text-based check for an at-end preference over blocks that contains a conjunction.
        if 'block' in pref_str and '(at-end' in pref_str and '(and' in pref_str:
            at_end_and = pref.definition.pref_body.body.exists_args.at_end_pred.pred
            at_end_and_length = len(at_end_and.and_args)
            max_at_end_and_length = max(at_end_and_length, max_at_end_and_length)

            # Evaluate on a copy of the default context (produced by `simplified_context_deepcopy`).
            logical_expr = BOOLEAN_PARSER(at_end_and, **simplified_context_deepcopy(DEFAULT_CONTEXT))
            logical_evaluation = BOOLEAN_PARSER.evaluate_unnecessary_detailed_return_value(logical_expr)  # type: ignore

            if logical_evaluation != 0:
                print(f'Found non-zero logical evaluation for {key}-{pref_key}')


            # edges = [l.strip() for l in ast_printer.ast_section_to_string(pred, PREFERENCES, '\n').replace('(and', '').replace('(on', '').replace(')', '').strip().split('\n')]
            # print(edges)

            # edges = defaultdict(list)
            edge_list = []

            # print(ast_printer.ast_section_to_string(at_end_and, ast_parser.PREFERENCES, '\n'))
            for p in at_end_and.and_args:
                # Only plain predicates are considered; other conjuncts are skipped.
                if p.pred.parseinfo.rule != 'predicate':
                    continue
                
                pred_type = p.pred.pred.parseinfo.rule.replace('predicate_', '')
                if pred_type == 'on':
                    # edges[p.pred.pred.arg_1.term].append(p.pred.pred.arg_2.term)
                    # A term is either a plain value or an AST node wrapping a `terminal`.
                    arg_1 = p.pred.pred.arg_1.term
                    if isinstance(arg_1, tatsu.ast.AST):
                        arg_1 = arg_1.terminal

                    arg_2 = p.pred.pred.arg_2.term
                    if isinstance(arg_2, tatsu.ast.AST):
                        arg_2 = arg_2.terminal

                    # Edge "arg_1 arg_2" in the whitespace-separated format read by `nx.parse_edgelist`.
                    e = f'{arg_1} {arg_2}'
                    edge_list.append(e)
            
            if edge_list:
                # NOTE(review): both flags are overwritten for every preference that has edges,
                # so the returned values describe only the last such preference -- confirm this is intended.
                G_d = nx.parse_edgelist(edge_list, nodetype=str, create_using=nx.DiGraph())
                cycle_found = len(list(nx.simple_cycles(G_d))) > 0
                G_ud = G_d.to_undirected()
                disconnected = not nx.is_connected(G_ud) 
                
            # ts = graphlib.TopologicalSorter(edges)
            # try:
            #     ts.prepare()
            # except graphlib.CycleError:
            #     cycle_found = True

            
        else:
            print(f'Found remaining preference of new form for {key}:')
            print(pref_str)
            print()

    return max_at_end_and_length, cycle_found, disconnected
    
    

In [None]:
# Choose the model analysed by the following cells, which read the globals `model` and `trace_filter`.
map_elites_key = bcs_ablation_with_pref_count_key
model = bcs_ablation_with_pref_count_model
trace_filter, keys_with_traces = load_trace_filter_compute_keys_with_traces(map_elites_key)

In [None]:
# Group keys by the largest at-end `and` length found among their remaining preferences,
# with separate groupings for keys flagged as cyclic or disconnected.
keys_by_max_length = defaultdict(list)
keys_by_max_length_with_cycles = defaultdict(list)
keys_by_max_length_with_disconnection = defaultdict(list)



for key in keys_with_traces:
    # ablation_no_coherence_features_model
    max_length, cycle_found, disconnected = print_game_and_remaining_keys(model, key, trace_filter)
    keys_by_max_length[max_length].append(key)
    if cycle_found:
        keys_by_max_length_with_cycles[max_length].append(key)
    if disconnected:
        keys_by_max_length_with_disconnection[max_length].append(key)


# Number of keys per max length: overall, with cycles, and with disconnection.
print({k: len(v) for k, v in keys_by_max_length.items()})
print({k: len(v) for k, v in keys_by_max_length_with_cycles.items()})
print({k: len(v) for k, v in keys_by_max_length_with_disconnection.items()})


In [None]:
# Hand-written game with a single preference, used to exercise the trace finder below.
test_game_str = """
(define (game test) (:domain few-objects-room-v1) 
(:constraints (and
    (preference preference0
    (exists (?v1 - cube_block ?v2 - block ?v3 - pyramid_block_red)
        (then
        (once (game_start))
        (hold (and (not (same_object ?v1 ?v3)) (adjacent ?v1 ?v2)))
        (once (and (not (same_object ?v2 ?v1)) (same_type ?v2 ?v3) (touch ?v1 ?v3) ))
    )
    )
    )
))
(:scoring
    (count-once-per-objects preference0)
))
"""
# Parse it with the project's NoParseinfoTokenizerModelContext as the parsing context.
config = grammar_parser.config.replace_config(None)
ctx = NoParseinfoTokenizerModelContext(grammar_parser.rules, config=config)
test_game_ast = grammar_parser.parse(test_game_str, config=config, ctx=ctx)


In [None]:
# key = (0, 1, 3, 0, 1, 1, 0, 1, 0, 0, 0)
# prefs = find_game_preferences(model.population[key])
# pref = find_pref_by_name(prefs, KEY_TO_MISSING_PREF_KEYS[key][0])

# Use the single preference of the hand-written test game.
prefs = find_game_preferences(test_game_ast)
pref = find_pref_by_name(prefs, 'preference0')


# NOTE(review): the test game uses a `then` body, so `at_end_pred` may not be populated here -- confirm.
pred = pref.definition.pref_body.body.exists_args.at_end_pred
pred = ast_utils.deepcopy_ast(pred, ast_utils.ASTCopyType.NODE)
# pred.pred.and_args.pop(-1)
print(ast_printer.ast_section_to_string(pref, PREFERENCES, '\n'))

# mapping = {f'?v{i}': ['block'] for i in range(10)}
# `mapping` is only referenced by the commented-out `filter` call below.
mapping = {f'?v{i}': ['game_object'] for i in range(10)}

# Override two settings of the predicate data estimator for this exploration.
trace_filter.trace_finder.predicate_data_estimator.max_child_args = 8
trace_filter.trace_finder.predicate_data_estimator.query_timeout = 120

# result = trace_filter.trace_finder.predicate_data_estimator.filter(pred, mapping, return_trace_ids=True)
# result.head()

local_context = dict(mapping={})

trace_filter.trace_finder._handle_ast(pref, section=PREFERENCES, local_context=local_context)

# Per entry: the single value when there is exactly one, otherwise the number of values.
print({k: list(v)[0] if len(v) == 1 else len(v) for k, v in trace_filter.trace_finder.databse_confirmed_traces_by_preference_or_section.items()})

In [None]:
# Dump the full mapping that the previous cell only summarised.
print(trace_filter.trace_finder.databse_confirmed_traces_by_preference_or_section.items())

In [None]:
# NOTE(review): `setup` is only assigned in the NEXT cell, so this cell fails on a fresh top-to-bottom run.
# Presumably it was run after that cell to edit the setup AST before re-running the trace finder -- confirm and reorder.
replace_child(setup.setup.forall_vars.variables[0].var_type.type, 'terminal', 'block')
replace_child(setup.setup.forall_args.setup.statement.conserved_pred.pred.pred.arg_2.term, 'terminal', 'pyramid_block')

In [None]:
# Take the setup section (entry 3 of the game AST) of one hard-coded archive game and print it.
key = (0, 0, 2, 0, 1, 1, 0, 1, 0, 0, 0)
setup = model.population[key][3][1]
print(ast_printer.ast_section_to_string(setup, ast_parser.SETUP, '\n'))

# mapping = {f'?v{i}': ['block'] for i in range(10)}
# `mapping` is only referenced by the commented-out `filter` call below.
mapping = {f'?v{i}': ['game_object'] for i in range(10)}

# Override two settings of the predicate data estimator for this exploration.
trace_filter.trace_finder.predicate_data_estimator.max_child_args = 8
trace_filter.trace_finder.predicate_data_estimator.query_timeout = 120

# result = trace_filter.trace_finder.predicate_data_estimator.filter(pred, mapping, return_trace_ids=True)
# result.head()

local_context = dict(mapping={})
# Reset the partial results so the final print shows only what this call produced.
trace_filter.trace_finder.setup_partial_results = []
retval = trace_filter.trace_finder._handle_ast(setup, section=ast_parser.SETUP, local_context=local_context)
print(retval)
print({k: len(v) for k, v in trace_filter.trace_finder.databse_confirmed_traces_by_preference_or_section.items()})
print(trace_filter.trace_finder.setup_partial_results)



In [None]:
import graphlib

# Small dependency graph (node -> predecessors) for trying out graphlib's cycle
# detection. It contains a cycle through ?v2, ?v5 and ?v3.
edges = {
    '?v2': ['?v5', '?v1'],
    '?v5': ['?v0', '?v3'],
    '?v3': ['?v2'],
    '?v0': ['?v1', '?v4'],
}

ts = graphlib.TopologicalSorter(edges)
try:
    ts.prepare()
    graphlib_cycle = None
except graphlib.CycleError as cycle_error:
    # `prepare()` raises CycleError on cyclic graphs; the detected cycle is the
    # second element of `args`. Catching it keeps the notebook runnable end to
    # end while still displaying the result.
    graphlib_cycle = cycle_error.args[1]

graphlib_cycle