In [None]:
from collections import defaultdict
import csv
from argparse import Namespace
from ast import literal_eval
import copy
import gzip
import itertools
import json
import math
import os
import pickle
import sys
import textwrap
import typing

import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)

import duckdb
from IPython.display import display, Markdown, HTML  # type: ignore
import matplotlib
import matplotlib.axes
import matplotlib.pyplot as plt
from Levenshtein import distance as _edit_distance
import numpy as np
import pandas as pd
import tabulate
import tatsu
import tatsu.ast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import tabulate
from tqdm.notebook import tqdm
from scipy import stats
from scipy.special import comb
import seaborn as sns
import sklearn
from sklearn.model_selection import GridSearchCV, train_test_split, KFold
from sklearn.pipeline import Pipeline
from tqdm import tqdm

sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src'))
from src.ast_utils import _extract_game_id, deepcopy_ast, replace_child
from src.ast_printer import ast_to_lines
from src import fitness_energy_utils as utils
from src.fitness_energy_utils import NON_FEATURE_COLUMNS, load_data
from src.fitness_features import *
from src.ast_counter_sampler import *
from src.evolutionary_sampler import *
from src import fitness_features_by_category, latest_model_paths

In [None]:
# Compile the DSL grammar and parse the human-authored games from interactive-beta.pddl into ASTs.
# Use a context manager so the grammar file handle is closed promptly.
with open('../dsl/dsl.ebnf') as grammar_file:
    grammar = grammar_file.read()
grammar_parser = tatsu.compile(grammar)
game_asts = list(cached_load_and_parse_games_from_file('../dsl/interactive-beta.pddl', grammar_parser, False, relative_path='..'))

In [None]:
# Load the latest MAP-Elites model and score every human game with it, recording the fitness
# and the archive key each game is mapped to (in both directions: index -> key, key -> indices).
model_key = 'max_exemplar_preferences_by_bcs_with_expected_values'
model_spec = latest_model_paths.MAP_ELITES_MODELS[model_key]
model = typing.cast(MAPElitesSampler, model_spec.load())

key_to_real_game_index = defaultdict(list)  # several human games may share one key
real_game_index_to_key = {}
real_game_fitness_scores = []
ALL_REAL_GAME_KEYS = []
for game_idx, game_ast in enumerate(game_asts):
    game_fitness, game_features = model._score_proposal(game_ast, return_features=True)  # type: ignore
    real_game_fitness_scores.append(game_fitness)

    game_key = model._features_to_key(game_ast, game_features)
    ALL_REAL_GAME_KEYS.append(game_key)
    real_game_index_to_key[game_idx] = game_key
    key_to_real_game_index[game_key].append(game_idx)

In [None]:
# Trace-filter results for the model's archive games, and the matching results for the human games.
# The human-game file is dated; keep the date in one named place instead of a placeholder-less f-string.
HUMAN_TRACE_FILTER_DATE = '2024_03_19'
trace_filter_results = model_spec.load_trace_filter_data()
human_games_trace_filter_data = load_data('', 'samples', f'/trace_filter_results_interactive-beta.pddl_{HUMAN_TRACE_FILTER_DATE}', relative_path='..')

In [None]:
# Sub-AST names recorded in the human trace-filter data for game index 10.
human_game_10_trace_info = human_games_trace_filter_data['full'][10]
human_game_10_trace_info.keys()

In [None]:
# Print human game #10. Index `game_asts` directly: `real_game_key_to_ast` is only built in a later
# cell (NameError on a fresh kernel) and only covers REAL_GAME_INDICES_TO_INCLUDE, which omits 10.
print("\n".join(ast_to_lines(game_asts[10])))

In [None]:
# Indices into `game_asts` of the human games kept for this analysis.
REAL_GAME_INDICES_TO_INCLUDE = [
    0, 4, 6, 7, 11,
    14, 17, 23, 26, 28,
    31, 32, 35, 37, 40,
    41, 42, 45, 49, 51,
    52, 55, 58, 59, 64,
    74, 88, 90, 94, 96,
]
# To use every human game instead: REAL_GAME_INDICES_TO_INCLUDE = list(range(98))

# Archive keys of the included games, in the same order as the indices above.
REAL_GAME_KEY_LIST = list(map(real_game_index_to_key.__getitem__, REAL_GAME_INDICES_TO_INCLUDE))
# key -> position within REAL_GAME_KEY_LIST (NOT the game index); a repeated key keeps its last position.
REAL_GAME_KEY_DICT = dict(zip(REAL_GAME_KEY_LIST, range(len(REAL_GAME_KEY_LIST))))
REAL_GAME_KEYS = set(REAL_GAME_KEY_LIST)

# 30 archive keys (12-element int tuples) used to restrict `novel_archive_cell_game_texts` to these
# cells when UNMATCHED_ONLY_TOP_30 is True.
# NOTE(review): presumably the top-30 novel/unmatched archive cells under some ranking
# (e.g. fitness) -- confirm the selection criterion where this list was produced.
UNMATCHED_TOP_30_KEYS = [
    (1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0),
    (1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1),
    (1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0),
    (1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0),
    (1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0),
    (1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0),
    (1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0),
    (1, 1, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0),
    (1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 0, 0),
    (1, 1, 2, 0, 0, 1, 0, 0, 0, 0, 0, 1),
    (1, 0, 2, 0, 0, 0, 0, 1, 0, 0, 0, 0),
    (1, 0, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0),
    (1, 1, 2, 0, 0, 0, 0, 0, 0, 0, 1, 0),
    (1, 1, 2, 0, 0, 0, 1, 0, 1, 0, 0, 0),
    (1, 1, 2, 0, 1, 1, 0, 0, 0, 0, 0, 0),
    (1, 1, 3, 1, 0, 0, 1, 0, 0, 0, 1, 0),
    (1, 1, 3, 0, 0, 2, 0, 0, 0, 0, 0, 0),
    (1, 1, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0),
    (1, 0, 3, 0, 0, 0, 0, 0, 1, 0, 0, 0),
    (1, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (1, 1, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0),
    (1, 1, 3, 0, 0, 1, 0, 0, 1, 0, 0, 0),
    (1, 0, 4, 0, 1, 1, 0, 1, 0, 1, 0, 0),
    (1, 0, 4, 0, 0, 0, 0, 0, 3, 0, 0, 0),
    (1, 1, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0),
    (1, 1, 4, 0, 0, 1, 1, 1, 0, 1, 0, 0),
    (1, 0, 4, 2, 0, 0, 0, 0, 0, 0, 0, 1),
    (1, 1, 4, 0, 2, 0, 0, 0, 1, 0, 0, 0),
    (1, 1, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    (1, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0)
]

In [None]:
TRANSLATIONS_DIR = '../llm_tests/translations'
TRANSLATION_DATE = '2024_01_12'
UNMATCHED_ONLY_TOP_30 = True


def load_translation_texts(prefix):
    """Load ``{TRANSLATIONS_DIR}/{prefix}_translations_split_{TRANSLATION_DATE}.json``.

    JSON object keys are always strings, so each key is parsed back into a Python literal with
    ``literal_eval`` (literals only -- no arbitrary code execution). NOTE(review): keys are presumably
    stringified archive-key tuples so they line up with ``model.population`` keys -- confirm.

    Returns:
        dict mapping the parsed key to the stored translation value.
    """
    with open(f'{TRANSLATIONS_DIR}/{prefix}_translations_split_{TRANSLATION_DATE}.json') as translations_file:
        raw_texts = json.load(translations_file)
    return {literal_eval(k): v for k, v in raw_texts.items()}


human_game_texts = load_translation_texts('human_games')
human_cell_archive_game_texts = load_translation_texts('human_cell_archive_games')
novel_archive_cell_game_texts = load_translation_texts('novel_archive_cell_games')
if UNMATCHED_ONLY_TOP_30:
    # Set membership instead of scanning the 30-element list for every key.
    top_30_unmatched_keys = set(UNMATCHED_TOP_30_KEYS)
    novel_archive_cell_game_texts = {k: v for k, v in novel_archive_cell_game_texts.items() if k in top_30_unmatched_keys}

In [None]:
# key -> AST for the included human games. Build it from the game indices: REAL_GAME_KEY_DICT maps
# key -> position in REAL_GAME_KEY_LIST (0..29), so `game_asts[position]` picked the wrong game.
# If two included games share a key, the later index wins (same dict semantics as before).
real_game_key_to_ast = {real_game_index_to_key[i]: game_asts[i] for i in REAL_GAME_INDICES_TO_INCLUDE}
# Generated games from the model archive, for the matched and novel (unmatched) translated cells.
matched_game_key_to_ast = {key: model.population[key] for key in human_cell_archive_game_texts.keys()}
unmatched_game_key_to_ast = {key: model.population[key] for key in novel_archive_cell_game_texts.keys()}

In [None]:
def get_activating_traces(filter_info, key, exclude_setup=False):
    """Collect, for one game, the set of traces that activate each of its sub-ASTs.

    Args:
        filter_info: trace-filter results; ``filter_info['full'][key]`` must map each sub-AST
            name to a ``{trace: activation}`` mapping.
        key: identifier of the game inside ``filter_info['full']`` (an archive key for generated
            games, or a game index for the raw human-game data).
        exclude_setup: if True, drop every sub-AST whose name contains ``'setup'``.

    Returns:
        dict mapping each kept sub-AST to the set of traces with activation > 0, plus two
        aggregate entries: ``'all'`` (traces activating every kept sub-AST) and ``'any'``
        (traces activating at least one kept sub-AST). Both aggregates are empty sets when
        no sub-AST is kept.
    """
    sub_ast_to_activating_traces = {
        sub_ast: {trace for trace, activation in trace_activations.items() if activation > 0}
        for sub_ast, trace_activations in filter_info['full'][key].items()
        if not (exclude_setup and 'setup' in sub_ast)
    }

    # Aggregate over the per-sub-AST sets only, captured before 'all'/'any' are inserted.
    per_sub_ast_sets = list(sub_ast_to_activating_traces.values())
    # set.intersection()/set.union() with no arguments raise TypeError, so guard the empty case
    # (e.g. a game whose only sub-AST is its setup, with exclude_setup=True).
    sub_ast_to_activating_traces['all'] = set.intersection(*per_sub_ast_sets) if per_sub_ast_sets else set()
    sub_ast_to_activating_traces['any'] = set.union(*per_sub_ast_sets) if per_sub_ast_sets else set()

    return sub_ast_to_activating_traces

In [None]:
# Remap the human_games_trace_filter_data according to the key instead of the index
# Re-key the human-game trace-filter data by archive key instead of by game index, so it can be
# queried the same way as the generated games' results. When several human games share a key,
# the entry iterated last overwrites the earlier ones.
remapped_human_games_trace_filter_data = {
    "full": {
        real_game_index_to_key[real_game_idx]: sub_ast_activations
        for real_game_idx, sub_ast_activations in human_games_trace_filter_data['full'].items()
    }
}

## Analysis: trace-activation similarity between generated and human games

In [None]:
# Archive keys of the generated games to evaluate: 'matched' come from the human-cell archive
# translations, 'unmatched' from the novel archive cells.
generated_keys_mapping = dict(
    matched=list(matched_game_key_to_ast),
    unmatched=list(unmatched_game_key_to_ast),
)
# Live view over the key of every scored human game (one entry per game index; keys may repeat).
all_human_game_keys = real_game_index_to_key.values()

In [None]:
def jaccard(a, b, aggregation):
    """Jaccard similarity |A ∩ B| / |A ∪ B| between the ``aggregation`` sets of two trace dicts.

    ``a`` and ``b`` are mappings like those returned by ``get_activating_traces``; ``aggregation``
    selects which set to compare (e.g. ``'all'`` or ``'any'``). Two empty sets score 0 rather
    than raising ZeroDivisionError.
    """
    first, second = a[aggregation], b[aggregation]
    union_size = len(first.union(second))
    if union_size == 0:
        return 0
    return len(first.intersection(second)) / union_size

In [None]:
# For every generated game, find its most similar human game (Jaccard over activating-trace sets)
# and average that best similarity per (exclude_setup, aggregation, key_type).
data = []
for exclude_setup in [False, True]:
    # NOTE(review): assumes the human trace-filter data has one entry per parsed human game, keyed
    # 0..len(game_asts)-1 (previously hard-coded as range(98)) -- confirm if the game file changes.
    human_game_activating_traces = [
        get_activating_traces(human_games_trace_filter_data, idx, exclude_setup=exclude_setup)
        for idx in range(len(game_asts))
    ]

    for aggregation in ['all', 'any']:
        for key_type, generated_keys in generated_keys_mapping.items():
            closest_similarities = []
            for key in generated_keys:
                activating_traces = get_activating_traces(trace_filter_results, key, exclude_setup)
                similarities = [jaccard(activating_traces, human_game, aggregation) for human_game in human_game_activating_traces]
                closest_similarities.append(max(similarities))

            avg_closest_similarity = np.mean(closest_similarities)
            data.append({
                "exclude_setup": exclude_setup,
                "aggregation": aggregation,
                "key_type": key_type,
                "avg_closest_similarity": avg_closest_similarity
            })
            print(f"\n{key_type} games, '{aggregation}' aggregation, exclude_setup={exclude_setup}:")
            print(f"Average Jaccard similarity between generated game and closest human game: {avg_closest_similarity}")

In [None]:
# Summary table of average closest-human-game similarity; shown via rich display rather than print.
similarity_summary = pd.DataFrame(data)
similarity_summary

## Code for Supplemental Figure

In [None]:
# Archive keys of the *generated* games for which we want the most similar *human* game
# (by activating-trace similarity). The first group is from the matched archive cells, the
# second from novel (unmatched) cells.
MATCHED_TARGET_KEYS = [
    (1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0),  # matched 14
    (1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0),  # matched 31
    (1, 0, 2, 0, 0, 1, 0, 0, 0, 0, 1, 0),  # matched 40
]
UNMATCHED_TARGET_KEYS = [
    (1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 0, 0),  # unmatched (place the bin near the north wall...)
    (1, 1, 3, 1, 0, 0, 1, 0, 0, 0, 1, 0),  # unmatched (credit cards and CDs)
    (1, 1, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0),  # unmatched (block stacking)
]
TARGET_KEYS = MATCHED_TARGET_KEYS + UNMATCHED_TARGET_KEYS


In [None]:
# Number of scored human games vs. number of distinct keys among the included subset.
# Shown as one tuple: only a cell's last bare expression is displayed.
len(all_human_game_keys), len(real_game_key_to_ast)

In [None]:
# For each target generated game, find the human game with the most similar activating traces and
# write both games' text to one CSV per (exclude_setup, aggregation) for the supplemental figure.
os.makedirs('./temp_outputs', exist_ok=True)  # open(..., 'w') fails if the directory is missing

for exclude_setup in [False, True]:
    # NOTE(review): assumes the human trace-filter data has one entry per parsed human game, keyed
    # 0..len(game_asts)-1, so the argmax below is a valid index into `game_asts`.
    human_game_activating_traces = [
        get_activating_traces(human_games_trace_filter_data, idx, exclude_setup=exclude_setup)
        for idx in range(len(game_asts))
    ]

    for aggregation in ['all', 'any']:
        # Separate name so the summary records in `data` (previous analysis) are not clobbered.
        closest_game_rows = []
        for key in TARGET_KEYS:
            activating_traces = get_activating_traces(trace_filter_results, key, exclude_setup)
            similarities = [jaccard(activating_traces, human_game, aggregation) for human_game in human_game_activating_traces]

            closest_human_game_idx = np.argmax(similarities)
            closest_human_game = "\n".join(ast_to_lines(game_asts[closest_human_game_idx]))

            # Look the target up in whichever archive holds it instead of relying on its position in TARGET_KEYS.
            target_ast = matched_game_key_to_ast[key] if key in matched_game_key_to_ast else unmatched_game_key_to_ast[key]
            target_game = "\n".join(ast_to_lines(target_ast))

            closest_game_rows.append((key, closest_human_game_idx, target_game, closest_human_game, np.max(similarities)))

        filename = f"./temp_outputs/supplemental_figure_closest_games_{'exclude_setup' if exclude_setup else 'include_setup'}_aggregation-{aggregation}.csv"
        # newline='' is required by the csv module; the game texts contain embedded newlines.
        with open(filename, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["key", "closest_human_game_idx", "target_game", "closest_human_game", "similarity"])
            writer.writerows(closest_game_rows)

        print(f"\nexclude_setup={exclude_setup}, aggregation={aggregation}: {[row[1] for row in closest_game_rows]}")

In [None]:
# Drill into one unmatched target game ("credit cards and CDs") and a candidate closest human game.
# An alternative candidate that was also inspected: (1, 0, 2, 1, 0, 0, 0, 0, 0, 0, 1, 0).
TARGET = (1, 1, 3, 1, 0, 0, 1, 0, 0, 0, 1, 0)
CLOSEST = (1, 1, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0)

# t: traces for the generated target; c: traces for the human game stored under CLOSEST.
t, c = (
    get_activating_traces(filter_data, game_key, exclude_setup=False)
    for filter_data, game_key in (
        (trace_filter_results, TARGET),
        (remapped_human_games_trace_filter_data, CLOSEST),
    )
)

In [None]:
# Print the human game stored under key CLOSEST. Use `key_to_real_game_index` (built for every human
# game) because `real_game_key_to_ast` only covers REAL_GAME_INDICES_TO_INCLUDE and may raise KeyError.
# `[-1]` takes the last game with that key, which is presumably the one kept by
# `remapped_human_games_trace_filter_data` when keys collide (assumes it was filled in index order).
print("\n".join(ast_to_lines(game_asts[key_to_real_game_index[CLOSEST][-1]])))

In [None]:
# Number of activating traces per sub-AST (plus the 'all'/'any' aggregates) for both games.
for header, trace_sets in (("For target game:", t), ("For closest game:", c)):
    print(header)
    for sub_ast, traces in trace_sets.items():
        print(f"- {sub_ast} -> {len(traces)}")

In [None]:
# Raw intersection and union sizes behind the Jaccard score of TARGET vs. CLOSEST.
aggregation = 'any'
target_set, closest_set = t[aggregation], c[aggregation]
len(target_set.intersection(closest_set)), len(target_set.union(closest_set))

In [None]:
# Compact overview instead of rendering the whole nested {game: {sub_ast: {trace: activation}}} dict:
# number of sub-ASTs recorded per human game index.
{game_idx: len(sub_ast_activations) for game_idx, sub_ast_activations in human_games_trace_filter_data['full'].items()}