In [None]:
%autoreload 2


In [None]:
from argparse import Namespace
from collections import defaultdict
import copy
import difflib
import gzip
import itertools
import os
import pickle
import re
import sys
import typing

import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)

from IPython.display import display, Markdown, HTML
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
import sklearn
from sklearn.model_selection import GridSearchCV, train_test_split, KFold
from sklearn.pipeline import Pipeline
import tabulate
import tatsu
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import tqdm.notebook as tqdm


sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src'))
from src import fitness_energy_utils as utils
from src.fitness_energy_utils import NON_FEATURE_COLUMNS
from src.ast_counter_sampler import *
from src.ast_utils import cached_load_and_parse_games_from_file, load_games_from_file, _extract_game_id
from src import ast_printer
from src.fitness_features_by_category import FEATURE_CATEGORIES, PREDICATE_ROLE_FILLER_PATTERN_DICT, COUNTING_FEATURES_PATTERN_DICT

In [None]:
# Read the DSL grammar with a context manager so the file handle is closed promptly
with open('../dsl/dsl.ebnf') as grammar_file:
    grammar = grammar_file.read()
grammar_parser = tatsu.compile(grammar)
game_asts = list(cached_load_and_parse_games_from_file('../dsl/interactive-beta.pddl', grammar_parser, False, relative_path='..'))
real_game_texts = [ast_printer.ast_to_string(ast, '\n') for ast in game_asts]
regrown_game_1024_texts = list(load_games_from_file('../dsl/ast-real-regrowth-samples-1024.pddl.gz'))
# Number of real games, number of regrown samples, and regrown samples per real game
# (divide by the actual real-game count rather than a hard-coded constant)
print(len(real_game_texts), len(regrown_game_1024_texts), len(regrown_game_1024_texts) / len(real_game_texts))


In [None]:
def extract_game_index(game_name: str):
    """Return the integer between the first and second '-' of `game_name`.

    If there is no second dash, everything after the first dash is used. If there is no dash
    at all, the whole string is parsed.
    """
    start = game_name.find('-') + 1
    end = game_name.find('-', start)
    index_text = game_name[start:] if end == -1 else game_name[start:end]
    return int(index_text)


def extract_negative_index(game_name: str):
    """Return the integer between the second and third '-' of `game_name`, or -1 if there is no second dash.

    Names are expected to look like `<prefix>-<game index>-<negative index>[-<more>]`.
    When there is no third dash, the negative index runs to the end of the name.
    """
    first_dash = game_name.find('-')
    second_dash = game_name.find('-', first_dash + 1)
    if second_dash == -1:
        return -1

    third_dash = game_name.find('-', second_dash + 1)
    # find() returns -1 when there is no third dash; slicing to -1 would silently drop the last digit
    end = third_dash if third_dash != -1 else len(game_name)
    return int(game_name[second_dash + 1:end])


fitness_df = utils.load_fitness_data('../data/fitness_features_1024_regrowths.csv.gz')

# Distinct values of the `src_file` column, then a preview of the table
print(fitness_df.src_file.unique())
fitness_df.head()

## Compare model weights/ranks from different models

In [None]:
# Import through the `src.` package path used by the imports cell. The bare
# `fitness_features_by_category` path would load a second, independent copy of the module.
from src.fitness_features_by_category import FEATURE_CATEGORIES

# Feature categories whose features are flagged as "Ignored" in the table below
ignore_categories = [
    "forall_less_important", "counting_less_important",
    "grammar_use_less_important", "predicate_under_modal",
    "predicate_role_filler", "compositionality"
]

# (saved model identifier, table display name) pairs. Keeping each name next to its model
# keeps the table headers in sync with the models actually loaded.
# Add further pairs here to compare more seeds/configurations.
fitness_models_and_names = [
    ('in_data_prop_L2_categories_minimal_counting_grammar_use_forall_seed_42_2023_12_22', 'L2 (seed 42)'),
]
fitness_model_paths = [model_path for model_path, _ in fitness_models_and_names]
table_names = [display_name for _, display_name in fitness_models_and_names]

loaded_models = [utils.load_model_and_feature_columns(model_path) for model_path in fitness_model_paths]
table_models, table_feature_names = zip(*loaded_models)

# Union of feature names used by any of the compared models
all_feature_names = set().union(*table_feature_names)

# Expand every ignored category into concrete feature names. Regex patterns are matched
# against the known feature names; plain strings are taken as-is.
all_ignore_features = set()
for category in ignore_categories:
    for feature_or_pattern in FEATURE_CATEGORIES[category]:
        if not isinstance(feature_or_pattern, re.Pattern):
            all_ignore_features.add(feature_or_pattern)
            continue
        all_ignore_features.update(name for name in all_feature_names if feature_or_pattern.match(name))

# When True, weights are compared by magnitude and rank 1 is the largest |weight|
use_absolute_values = True

weights_by_model = {}
weight_ranks_by_model = {}


for model, name, feature_names in zip(table_models, table_names, table_feature_names):
    fc1_weights = model.named_steps['fitness'].model.fc1.weight.data.detach().squeeze()  # type: ignore
    if use_absolute_values:
        fc1_weights = torch.abs(fc1_weights)

    # rankdata assigns rank 1 to the smallest value (ties averaged)
    ranks = stats.rankdata(fc1_weights.numpy())
    if use_absolute_values:
        # Flip so the largest magnitude gets rank 1
        ranks = len(ranks) - ranks + 1

    weights_by_model[name] = {feature: fc1_weights[i].item() for i, feature in enumerate(feature_names)}
    weight_ranks_by_model[name] = {feature: ranks[i] for i, feature in enumerate(feature_names)}


# Mean rank of each feature across models. nanmean means a feature missing from some models
# is averaged only over the models that contain it.
feature_mean_rank = {
    feature_name: np.nanmean([ranks.get(feature_name, np.nan) for ranks in weight_ranks_by_model.values()])
    for feature_name in all_feature_names
}


# Group features that differ only in a trailing numeric suffix (e.g. `foo_0`, `foo_1`, `foo_12`).
# A regex strips suffixes of any length; a fixed `[:-2]` slice would mangle multi-digit ones.
mean_mean_rank_by_feature_number = defaultdict(list)
for feature_name, mean_rank in feature_mean_rank.items():
    if feature_name[-1].isdigit():
        base_feature_name = re.sub(r'_?\d+$', '', feature_name)
        mean_mean_rank_by_feature_number[base_feature_name].append(mean_rank)


# Mean ranks of `pref_forall` features, bucketed by name suffix.
# 'incorrect' must be tested before 'correct', because every name ending in 'incorrect' also ends in 'correct'.
mean_rank_by_pref_forall_type = defaultdict(list)
for feature_name, mean_rank in feature_mean_rank.items():
    if 'pref_forall' not in feature_name:
        continue
    for suffix in ('incorrect', 'correct', 'incorrect_count'):
        if feature_name.endswith(suffix):
            mean_rank_by_pref_forall_type[suffix].append(mean_rank)
            break


# Features ordered by ascending mean rank
feature_names_by_mean_rank = sorted(feature_mean_rank, key=feature_mean_rank.get)

headers = ['Feature', 'Ignored', 'Mean Rank'] + table_names
rows = []
for feature_name in feature_names_by_mean_rank:
    row = [
        feature_name,
        'Yes' if feature_name in all_ignore_features else 'No',
        f'{feature_mean_rank[feature_name]:.3f}',
    ]
    # One "weight (rank)" cell per model; NaN / -1 when the model lacks the feature
    for name in table_names:
        row.append(f'{weights_by_model[name].get(feature_name, np.nan):.3f} ({int(weight_ranks_by_model[name].get(feature_name, -1))})')
    rows.append(row)

print(tabulate.tabulate(rows, headers, tablefmt='fancy_grid'))


In [None]:
# (mean of member mean ranks, group name, number of features) per numbered-feature group,
# lowest mean rank first
mean_mean_ranks = sorted(
    (np.mean(ranks), feature_name, len(ranks))
    for feature_name, ranks in mean_mean_rank_by_feature_number.items()
)

for mean_rank, feature_name, n_features in mean_mean_ranks:
    print(f'{feature_name}: {mean_rank:.3f} ({n_features})')

print()

# Aggregate one level further. Drop the last underscore-separated token of each group name, and
# average the groups' mean ranks weighted by the number of features in each group.
weighted_rank_sums = defaultdict(float)
higher_level_totals = defaultdict(int)
for mean_rank, feature_name, n_features in mean_mean_ranks:
    # rsplit keeps the whole name when it has no underscore; slicing at rfind('_') == -1 would drop the last character
    category_name = feature_name.rsplit('_', 1)[0]
    weighted_rank_sums[category_name] += mean_rank * n_features
    higher_level_totals[category_name] += n_features

higher_level_means = {
    category_name: rank_sum / higher_level_totals[category_name]
    for category_name, rank_sum in weighted_rank_sums.items()
}

for category_name, mean_rank in sorted(higher_level_means.items(), key=lambda item: item[1]):
    print(f'{category_name}: {mean_rank:.3f}')


print()

for pref_forall_type, mean_ranks in mean_rank_by_pref_forall_type.items():
    print(f'{pref_forall_type}: {np.mean(mean_ranks):.3f} ({len(mean_ranks)})')

## Older analyses

In [None]:
from latest_model_paths import LATEST_FITNESS_FUNCTION_DATE_ID, LATEST_SPECIFIC_OBJECTS_FITNESS_FUNCTION_DATE_ID

# Identifier of the fitness model inspected in the rest of this section
# (LATEST_FITNESS_FUNCTION_DATE_ID is an alternative choice).
model_date_id = '1_4_regrowths_in_data_prop_L2_categories_minimal_counting_grammar_use_seed_42_2023_09_19'

# "1_4_regrowths" models are evaluated on their matching feature file; any other model reuses fitness_df
data_df = (
    utils.load_fitness_data('../data/fitness_features_1024_regrowths_1_4_regrowths.csv.gz')
    if '1_4_regrowths' in model_date_id
    else fitness_df
)

cv_energy_model, feature_columns = utils.load_model_and_feature_columns(model_date_id)
print(len(feature_columns))



In [None]:
full_tensor = utils.df_to_tensor(data_df, feature_columns)

# Switch the pipeline's 'wrapper' step (when present) to eval mode before scoring;
# presumably a torch module, given the .eval() call
pipeline_steps = cv_energy_model.named_steps
if 'wrapper' in pipeline_steps:
    pipeline_steps['wrapper'].eval()

full_tensor_scores = cv_energy_model.transform(full_tensor).detach()

In [None]:
# Column 0 holds each real game's score; the remaining columns hold its negatives' scores
real_game_scores = full_tensor_scores[:, 0]

print(f'Real game scores: {real_game_scores.mean():.4f} ± {real_game_scores.std():.4f}, min = {real_game_scores.min():.4f}, median = {torch.median(real_game_scores):.4f}, max = {real_game_scores.max():.4f}')

negatives_scores = full_tensor_scores[:, 1:]
# (The full decile table is computed in its own cell further below.)
print(f'30th percentile negative energy: {torch.quantile(negatives_scores.ravel(), 0.3)}')

In [None]:
# ECDF-based evaluation of the energy model on the full feature tensor
utils.evaluate_fitness_overall_ecdf(
    cv_energy_model,
    full_tensor,
)

In [None]:
from sklearn import metrics

positive_scores_numpy = real_game_scores.numpy().reshape(-1)
negative_scores_numpy = negatives_scores.numpy().reshape(-1)

n_positives = len(positive_scores_numpy)
n_negatives = len(negative_scores_numpy)

# Real games are the positive class. Lower energy should indicate a real game, so the
# energies are negated to act as classifier scores.
labels = np.concatenate([np.ones(n_positives), np.zeros(n_negatives)])
scores = np.concatenate([positive_scores_numpy, negative_scores_numpy]) * -1  # type: ignore

# Pass an explicit Axes: sklearn's Display helpers create a new figure when `ax` is None.
# A preceding plt.figure(figsize=...) would then be left blank and its size ignored.
fig, ax = plt.subplots(figsize=(4, 4))
metrics.PrecisionRecallDisplay.from_predictions(labels, scores, ax=ax)  # type: ignore
ax.legend(loc='best')
plt.show()

print(metrics.average_precision_score(labels, scores))


In [None]:
# Explicit Axes so the ROC curve is drawn on the 4x4 figure (sklearn's Display creates a new
# figure when `ax` is None, leaving a separate plt.figure blank)
fig, ax = plt.subplots(figsize=(4, 4))
metrics.RocCurveDisplay.from_predictions(labels, scores, ax=ax)  # type: ignore
ax.legend(loc='best')
plt.show()

In [None]:
positives_mean = positive_scores_numpy.mean()
positives_variance = positive_scores_numpy.var()
negatives_mean = negative_scores_numpy.mean()
negatives_variance = negative_scores_numpy.var()

# Cohen's d-style effect size using the average of the two population variances. The sign is
# arranged so that real games with lower energy than the negatives give a positive value.
pooled_std = np.sqrt(0.5 * (positives_variance + negatives_variance))
(negatives_mean - positives_mean) / pooled_std

In [None]:
steps = torch.linspace(0, 1, 11)  # 0.0, 0.1, ..., 1.0
flat_negative_energies = negatives_scores.ravel()
deciles = torch.quantile(flat_negative_energies, steps)
print(steps)
print(f'Energy deciles: {deciles}')

In [None]:
# First linear layer of the energy model; its weight row gives one weight per feature column
fitness_layer = cv_energy_model.named_steps['fitness'].model.fc1  # type: ignore
weights = fitness_layer.weight.data.detach().squeeze()
bias = fitness_layer.bias.data.detach().squeeze()
print(f'Weights mean: {weights.mean():.4f}, std: {weights.std():.4f}, bias: {bias:.4f}')

# The histogram shows signed weight values, not magnitudes
fig, ax = plt.subplots()
ax.hist(weights, bins=30)
ax.set(title='Energy model weights', xlabel='Weight value', ylabel='Count')
plt.show()

In [None]:
# Feature columns whose name contains 'n_5'
[column for column in feature_columns if 'n_5' in column]

In [None]:
# Share of the total absolute weight carried by the single largest-magnitude weight
abs_model_weights = weights.abs()
abs_model_weights.max() / abs_model_weights.sum()

In [None]:
K = 20
# torch.topk raises if k exceeds the number of weights, so clamp for small models
k = min(K, len(weights))
top_features = torch.topk(weights, k)
bottom_features = torch.topk(weights, k, largest=False)

lines = ['### Features with largest negative weights (most predictive of real games):']
for rank, (feature_index, weight) in enumerate(zip(bottom_features.indices.tolist(), bottom_features.values.tolist()), start=1):
    lines.append(f'{rank}. {feature_columns[feature_index]} ({weight:.4f})')

lines.append('### Features with largest positive weights (most predictive of fake games):')
for rank, (feature_index, weight) in enumerate(zip(top_features.indices.tolist(), top_features.values.tolist()), start=1):
    lines.append(f'{rank}. {feature_columns[feature_index]} ({weight:.4f})')

display(Markdown('\n'.join(lines)))


In [None]:
quantiles = [0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]

abs_weights = weights.abs()
n_weights = len(weights)

# Sweep 5000 thresholds from 0 to max |weight|. For each target fraction (in order), report
# the first remaining threshold at which at least that fraction of |weights| lies strictly
# below it. After a hit, the next target starts at the following threshold.
threshold_iter = iter(torch.linspace(0, abs_weights.max(), 5000))
for target_fraction in quantiles:
    for magnitude in threshold_iter:
        n = torch.sum(abs_weights < magnitude).item()
        if n / n_weights >= target_fraction:
            print(f'Approximately {target_fraction * 100}% ({n}, {n / n_weights * 100:.2f}%) of the weights have magnitude < {magnitude:.4f}')
            break
    else:
        # Thresholds exhausted before this target was reached
        break

In [None]:
from src.fitness_features_by_category import *


def print_weights_summary_by_category(model_weights: torch.Tensor, all_feature_columns: typing.List[str], return_lines: bool = False,
                                      feature_categories: typing.Optional[typing.Dict[str, typing.List[typing.Any]]] = None):
    """Summarize absolute model weights per feature category.

    For every category, reports:
    - how many of the model's features belong to it;
    - their mean and summed absolute weight;
    - their mean position when all model features are sorted by descending absolute weight
      (position 0 = largest magnitude).

    Args:
        model_weights: 1-D weight tensor aligned with `all_feature_columns`.
        all_feature_columns: the model's feature names, in weight order.
        return_lines: if True, collect and return the summary lines instead of printing them.
        feature_categories: mapping from category name to a list of feature names and/or
            compiled regex patterns. Defaults to FEATURE_CATEGORIES.

    Returns:
        The list of summary lines when `return_lines` is True, otherwise None.
    """
    if feature_categories is None:
        feature_categories = FEATURE_CATEGORIES

    abs_weights = model_weights.abs().tolist()

    # Column index and sorted position per feature, via dicts instead of repeated list.index() scans
    feature_index = {}
    for i, feature in enumerate(all_feature_columns):
        feature_index.setdefault(feature, i)
    descending_order = sorted(range(len(all_feature_columns)), key=lambda i: abs_weights[i], reverse=True)
    sorted_position = {}
    for position, i in enumerate(descending_order):
        sorted_position.setdefault(all_feature_columns[i], position)

    lines = []
    all_assigned_features = set()

    def _emit(line: str):
        # Honor return_lines uniformly for every output line
        if return_lines:
            lines.append(line)
        else:
            print(line)

    for category, features in feature_categories.items():
        category_feature_list = []
        for feature in features:
            if isinstance(feature, re.Pattern):
                category_feature_list.extend(f for f in all_feature_columns if feature.match(f))
            else:
                category_feature_list.append(feature)

        all_assigned_features.update(category_feature_list)

        # Named features this model does not have are skipped instead of raising ValueError
        present_features = [f for f in category_feature_list if f in feature_index]
        if present_features:
            category_abs_weights = [abs_weights[feature_index[f]] for f in present_features]
            sum_abs_weight = sum(category_abs_weights)
            mean_abs_weight = sum_abs_weight / len(category_abs_weights)
            mean_sorted_index = sum(sorted_position[f] for f in present_features) / len(present_features)
        else:
            sum_abs_weight = 0.0
            mean_abs_weight = mean_sorted_index = float('nan')

        prefix = f'For category {category} with {len(present_features)} features'
        _emit(f'{prefix:54} | mean abs weight is {mean_abs_weight:5.2f} | sum abs weight is {sum_abs_weight:6.2f} | mean sorted index is {mean_sorted_index:6.2f}')

    unassigned_features = [f for f in all_feature_columns if f not in all_assigned_features]
    if unassigned_features:
        _emit(f'Unassigned features: {unassigned_features}')

    if return_lines:
        return lines


In [None]:
def print_mean_weight_by_category_pattern(model_weights: torch.Tensor, model_feature_columns: typing.List[str], 
                                          patterns: typing.Dict[str, re.Pattern], top_k: int = 3, prefix_width: int = 40,
                                          sort_by: str = 'mean_abs_weight'):
    """Print absolute-weight statistics for each named regex pattern of feature names.

    For each pattern matching at least one feature, one line is printed with:
    - the number of matching features;
    - their mean absolute weight;
    - the mean of their top-k absolute weights (k capped at the match count);
    - their maximum absolute weight.
    Lines are sorted in descending order of the statistic named by `sort_by`. Valid values are
    'mean_abs_weight', 'mean_top_k_abs_weight' and 'max_abs_weight'; any other value raises
    ValueError once a pattern matches.
    """
    weight_magnitudes = model_weights.abs()
    summaries = []
    for pattern_name, pattern in patterns.items():
        matching = [idx for idx, column in enumerate(model_feature_columns) if pattern.match(column)]
        if not matching:
            continue

        matched_magnitudes = weight_magnitudes[matching]
        effective_k = min(top_k, len(matching))
        stat_values = {
            'mean_abs_weight': matched_magnitudes.mean().item(),
            'max_abs_weight': matched_magnitudes.max().item(),
        }
        stat_values['mean_top_k_abs_weight'] = matched_magnitudes.topk(effective_k).values.mean().item()

        header = f'For "{pattern_name}" features'.ljust(prefix_width)
        text = (f'{header} | {len(matching):2} features'
                f' | mean weight = {stat_values["mean_abs_weight"]:2.3f}'
                f' | mean top {effective_k} weights = {stat_values["mean_top_k_abs_weight"]:2.3f}'
                f' | max weight = {stat_values["max_abs_weight"]:2.3f}')

        if sort_by not in stat_values:
            raise ValueError(f'Unknown sort_by value {sort_by}')

        summaries.append((text, stat_values[sort_by]))

    summaries.sort(key=lambda pair: pair[1], reverse=True)
    for text, _ in summaries:
        print(text)



def print_top_features_by_category(model_weights: torch.Tensor, model_feature_columns: typing.List[str], category: str, k: int = 10):
    """Display (as Markdown) the k most negative and k most positive weights among one category's features.

    Args:
        model_weights: 1-D weight tensor aligned with `model_feature_columns`.
        model_feature_columns: the model's feature names, in weight order.
        category: key of FEATURE_CATEGORIES (matched case-insensitively).
        k: number of features to list at each end; clamped to the number of category features
            present in the model.

    Raises:
        ValueError: if `category` is not a key of FEATURE_CATEGORIES.
    """
    category = category.lower()
    if category not in FEATURE_CATEGORIES:
        raise ValueError(f'Category {category} not found')

    # A set makes the membership filter below O(1) per feature
    category_features = set()
    for category_feature_or_pattern in FEATURE_CATEGORIES[category]:
        if isinstance(category_feature_or_pattern, re.Pattern):
            category_features.update(f for f in model_feature_columns if category_feature_or_pattern.match(f))
        else:
            category_features.add(category_feature_or_pattern)

    # (weight, feature) pairs for this category's features, from most negative to most positive
    sorted_category_features = sorted(
        ((model_weights[i].item(), feature) for i, feature in enumerate(model_feature_columns) if feature in category_features),
        key=lambda pair: pair[0],
    )

    # Without clamping, indexing below raises IndexError when the category has fewer than k features
    k = min(k, len(sorted_category_features))

    lines = [f'### {category.capitalize()} features with largest negative weights:']
    for i in range(k):
        weight, feature = sorted_category_features[i]
        lines.append(f'{i+1}. {feature} ({weight:.4f})')

    lines.append(f'### {category.capitalize()} features with largest positive weights:')
    for i in range(k):
        weight, feature = sorted_category_features[-(i + 1)]
        lines.append(f'{i+1}. {feature} ({weight:.4f})')

    display(Markdown('\n'.join(lines)))


def print_weights_in_category(model_weights: torch.Tensor, model_feature_columns: typing.List[str], category: str, model_note: typing.Optional[str] = None, 
                              small_threshold: float = 1e-2, large_threshold: float = 1e-1):
    """Display (as Markdown) every feature of `category` with its weight, sorted from most negative to most positive.

    Weights with |w| > large_threshold are marked '**', and those with |w| > small_threshold
    are marked '*'. `model_note`, if given, is appended to the heading in parentheses.

    Raises:
        ValueError: if `category` (lower-cased) is not a key of FEATURE_CATEGORIES.
    """
    category = category.lower()
    if category not in FEATURE_CATEGORIES:
        raise ValueError(f'Category {category} not found')

    members = []
    for entry in FEATURE_CATEGORIES[category]:
        if isinstance(entry, re.Pattern):
            members += [column for column in model_feature_columns if entry.match(column)]
        else:
            members.append(entry)

    weight_feature_pairs = sorted(
        ((model_weights[idx], column) for idx, column in enumerate(model_feature_columns)),
        key=lambda pair: pair[0],
    )
    in_category = [pair for pair in weight_feature_pairs if pair[1] in members]

    note_suffix = '' if model_note is None else f' ({model_note})'
    lines = [f'### {category.capitalize()} features by weight{note_suffix}:']
    for position, (weight, feature) in enumerate(in_category, start=1):
        magnitude = weight.abs()
        if magnitude > large_threshold:
            marker = "**"
        elif magnitude > small_threshold:
            marker = "*"
        else:
            marker = ""
        lines.append(f'{position}. {feature} ({weight:.4f}){marker}')

    display(Markdown('\n'.join(lines)))


def get_model_weights(model):
    """Return the detached, squeezed weight tensor of `model.named_steps['fitness'].model.fc1`."""
    fitness_step = model.named_steps['fitness']
    first_layer = fitness_step.model.fc1
    return first_layer.weight.data.detach().squeeze()


In [None]:
# Grammar-use feature weights of the currently loaded energy model
print_weights_in_category(
    get_model_weights(cv_energy_model),
    feature_columns,
    'grammar_use',
)

In [None]:
# CV-sweep (date_id, sweep_name) identifiers and the matching final-model ids (seeds 33/42/66)
dates_and_names = [
    ('2023_09_13', 'fitness_sweep_1_4_regrowths_no_in_data_all_L1_categories_minimal_counting_seed_33'),
    ('2023_09_13', 'fitness_sweep_1_4_regrowths_no_in_data_all_L1_categories_minimal_counting_seed_42'),
    ('2023_09_13', 'fitness_sweep_1_4_regrowths_no_in_data_all_L1_categories_minimal_counting_seed_66'),
]

final_model_names = [
    '1_4_regrowths_no_in_data_all_L1_categories_minimal_counting_seed_33_2023_09_13',
    '1_4_regrowths_no_in_data_all_L1_categories_minimal_counting_seed_42_2023_09_13',
    '1_4_regrowths_no_in_data_all_L1_categories_minimal_counting_seed_66_2023_09_13',
]


# Iterate over `dates_and_names` instead to inspect the CV-sweep versions (those also report the chosen L1 weight)
for model_identifier in final_model_names:
    current_cv_data = None
    if not isinstance(model_identifier, tuple):
        current_model, current_feature_columns = utils.load_model_and_feature_columns(model_identifier)
    else:
        date_id, name = model_identifier
        current_cv_data = utils.load_data(date_id, 'data/fitness_cv', name)
        current_model = current_cv_data['cv'].best_estimator_
        current_feature_columns = current_cv_data['feature_columns']

    # Only CV-sweep data carries the selected regularization weight
    model_note = (
        f"L1 = {current_cv_data['cv'].best_params_['fitness__regularization_weight']:1.2f}"
        if current_cv_data is not None
        else None
    )

    print_weights_in_category(get_model_weights(current_model), current_feature_columns, 'grammar_use',
                              model_note=model_note)

In [None]:
# CV-sweep (date_id, sweep_name) identifiers and the matching final-model ids (seeds 33/42/66)
dates_and_names = [
    ('2023_09_13', 'fitness_sweep_no_in_data_all_L1_categories_minimal_counting_seed_33'),
    ('2023_09_13', 'fitness_sweep_no_in_data_all_L1_categories_minimal_counting_seed_42'),
    ('2023_09_13', 'fitness_sweep_no_in_data_all_L1_categories_minimal_counting_seed_66'),
]

final_model_names = [
    'no_in_data_all_L1_categories_minimal_counting_seed_33_2023_09_13',
    'no_in_data_all_L1_categories_minimal_counting_seed_42_2023_09_13',
    'no_in_data_all_L1_categories_minimal_counting_seed_66_2023_09_13',
]


# Iterate over `final_model_names` instead to inspect the final models (no L1 note is shown for those)
for model_identifier in dates_and_names:
    current_cv_data = None
    if not isinstance(model_identifier, tuple):
        current_model, current_feature_columns = utils.load_model_and_feature_columns(model_identifier)
    else:
        date_id, name = model_identifier
        current_cv_data = utils.load_data(date_id, 'data/fitness_cv', name)
        current_model = current_cv_data['cv'].best_estimator_
        current_feature_columns = current_cv_data['feature_columns']

    # Only CV-sweep data carries the selected regularization weight
    model_note = (
        f"L1 = {current_cv_data['cv'].best_params_['fitness__regularization_weight']:1.2f}"
        if current_cv_data is not None
        else None
    )

    print_weights_in_category(get_model_weights(current_model), current_feature_columns, 'grammar_use',
                              model_note=model_note)

In [None]:
# NOTE: per the original author's note ("I mucked up the names here"), the 2023_09_07 runs were
# saved under the other feature set's name. The 'role_filler'-named entries are the ones used
# here as counting-feature models.
with_counting_dates_and_names = [
    ('2023_09_05', 'fitness_sweep_full_features_no_in_data_all_L1_categories_with_counting'),
    ('2023_09_07_1', 'fitness_sweep_full_features_no_in_data_all_L1_categories_with_role_filler'),  # misnamed run
    ('2023_09_07', 'fitness_sweep_full_features_no_in_data_all_L1_categories_with_role_filler'),  # misnamed run
]

with_counting_final_model_date_ids = [
    'full_features_no_in_data_all_L1_categories_with_counting_2023_09_05',
    'full_features_no_in_data_all_L1_categories_with_role_filler_seed_42_2023_09_07',  # misnamed run
    'full_features_no_in_data_all_L1_categories_with_role_filler_seed_66_2023_09_07',  # misnamed run
]

# Iterate over `with_counting_dates_and_names` instead to inspect the CV-sweep versions
for model_identifier in with_counting_final_model_date_ids:
    if not isinstance(model_identifier, tuple):
        current_model, current_feature_columns = utils.load_model_and_feature_columns(model_identifier)
    else:
        date_id, name = model_identifier
        current_cv_data = utils.load_data(date_id, 'data/fitness_cv', name)
        current_model = current_cv_data['cv'].best_estimator_
        current_feature_columns = current_cv_data['feature_columns']

    print_mean_weight_by_category_pattern(
        get_model_weights(current_model),
        current_feature_columns,
        COUNTING_FEATURES_PATTERN_DICT,
        sort_by='mean_abs_weight',
    )
    print()

In [None]:
# NOTE: per the original author's note ("I mucked up the names here"), the 2023_09_07 runs were
# saved under the other feature set's name. The 'counting'-named entries are the ones used
# here as role-filler models.
with_role_filler_dates_and_names = [
    ('2023_09_05', 'fitness_sweep_full_features_no_in_data_all_L1_categories_with_role_filler'),
    ('2023_09_07_1', 'fitness_sweep_full_features_no_in_data_all_L1_categories_with_counting'),  # misnamed run
    ('2023_09_07', 'fitness_sweep_full_features_no_in_data_all_L1_categories_with_counting'),  # misnamed run
]

with_role_filler_final_model_date_ids = [
    'full_features_no_in_data_all_L1_categories_with_role_filler_2023_09_05',
    'full_features_no_in_data_all_L1_categories_with_counting_seed_42_2023_09_07',  # misnamed run
    'full_features_no_in_data_all_L1_categories_with_counting_seed_66_2023_09_07',  # misnamed run
]

# Iterate over `with_role_filler_dates_and_names` instead to inspect the CV-sweep versions
for model_identifier in with_role_filler_final_model_date_ids:
    if not isinstance(model_identifier, tuple):
        current_model, current_feature_columns = utils.load_model_and_feature_columns(model_identifier)
    else:
        date_id, name = model_identifier
        current_cv_data = utils.load_data(date_id, 'data/fitness_cv', name)
        current_model = current_cv_data['cv'].best_estimator_
        current_feature_columns = current_cv_data['feature_columns']

    print_mean_weight_by_category_pattern(
        get_model_weights(current_model),
        current_feature_columns,
        PREDICATE_ROLE_FILLER_PATTERN_DICT,
        sort_by='mean_abs_weight',
    )
    print()

## Comparing the models on the test-set negatives

In [None]:
def _summarize_test_negatives(sweep_data):
    """Score a CV sweep's test tensor with its best estimator and rank the negatives.

    Returns (test_scores, real_scores, negative_scores, ascending_order, rank_of_negative,
    better_negative_indices):
      - column 0 of test_scores holds the real games' scores; columns 1: hold the negatives';
      - ascending_order is the argsort of the flattened negative scores;
      - rank_of_negative[i] is the ascending rank of flattened negative i (the inverse
        permutation of ascending_order);
      - better_negative_indices holds the flattened indices of negatives scoring strictly
        below their real game.
    """
    test_scores = sweep_data['cv'].best_estimator_.transform(sweep_data['test_tensor'])
    real_scores = test_scores[:, 0]
    negative_scores = test_scores[:, 1:]

    ascending_order = torch.argsort(negative_scores.ravel())
    rank_of_negative = torch.zeros_like(ascending_order)
    rank_of_negative[ascending_order] = torch.arange(len(ascending_order))

    # NOTE(review): the [:, None, :] indexing assumes real_scores is 2-D (i.e. test_scores is 3-D) — confirm
    lower_than_real = (negative_scores < real_scores[:, None, :]).ravel()
    better_negative_indices = set(torch.where(lower_than_real)[0].numpy())
    return test_scores, real_scores, negative_scores, ascending_order, rank_of_negative, better_negative_indices


no_reg_sweep_data = utils.load_data('2023_08_29', 'data/fitness_cv', 'fitness_sweep_full_features_no_in_data_all')
print(no_reg_sweep_data.keys())
# NOTE: despite the name, *_position_to_index maps each flattened negative to its ascending rank
(no_reg_test_scores, no_reg_test_real_game_scores, no_reg_test_negatives_scores,
 no_reg_argsort, no_reg_position_to_index, no_reg_indices_with_better_negative) = _summarize_test_negatives(no_reg_sweep_data)
print(len(no_reg_indices_with_better_negative))


l1_sweep_data = utils.load_data('2023_08_29', 'data/fitness_cv', 'fitness_sweep_full_features_no_in_data_all_L1')
(l1_test_scores, l1_test_real_game_scores, l1_test_negatives_scores,
 l1_argsort, l1_position_to_index, l1_indices_with_better_negative) = _summarize_test_negatives(l1_sweep_data)
print(len(l1_indices_with_better_negative))

In [None]:
# Train/test energy histograms for the unregularized and the L1-regularized sweeps
for sweep_data, title_note in ((no_reg_sweep_data, 'no regularization'), (l1_sweep_data, 'L1 regularization')):
    utils.plot_energy_histogram(
        sweep_data['cv'],
        sweep_data['train_tensor'],
        sweep_data['test_tensor'],
        histogram_title_note=title_note)

In [None]:
# Overlap of "negative beats its real game" cases between the two models
better_negative_both = no_reg_indices_with_better_negative & l1_indices_with_better_negative
better_negative_l1_only = l1_indices_with_better_negative - no_reg_indices_with_better_negative
better_negative_no_reg_only = no_reg_indices_with_better_negative - l1_indices_with_better_negative

print(f'Both: {len(better_negative_both)}, L1 only: {len(better_negative_l1_only)}, No reg only: {len(better_negative_no_reg_only)}')

In [None]:
# Per-negative rank difference (no-reg rank minus L1 rank). The most negative values are the
# negatives that the unregularized model ranks much lower (in energy) than the L1 model does.
index_diffs = no_reg_position_to_index - l1_position_to_index
torch.topk(index_diffs, 10, largest=False)

In [None]:
# Per-feature energy contributions of the main model.
# NOTE(review): 987 is presumably the index of the item to inspect — confirm against utils.evaluate_energy_contributions
utils.evaluate_energy_contributions(
    cv_energy_model, full_tensor, 987,
    feature_columns, full_tensor, real_game_texts, regrown_game_1024_texts,
    display_features_diff=False, min_display_threshold=0.001,
)
    


In [None]:
# `l1_model` is not defined anywhere else in this notebook (NameError on a fresh kernel).
# Take the best estimator of the L1 sweep loaded above.
# NOTE(review): this reuses `feature_columns`/`full_tensor` from the main model — confirm the L1 sweep
# was trained on the same feature columns.
l1_model = l1_sweep_data['cv'].best_estimator_
utils.evaluate_energy_contributions(l1_model, full_tensor, 987, 
        feature_columns, full_tensor, real_game_texts, regrown_game_1024_texts, display_features_diff=False, min_display_threshold=0.001)
    


In [None]:
# CV sweep with per-category L1 regularization (2023_09_01)
l1_categories_sweep_data = utils.load_data(
    '2023_09_01',
    'data/fitness_cv',
    'fitness_sweep_full_features_no_in_data_all_L1_categories',
)

In [None]:
# CV outputs for the per-category L1 sweep: passes the fitted CV object, then the train and test tensors
cv_result, train_tensor_l1_categories, test_tensor_l1_categories = (
    l1_categories_sweep_data[key] for key in ('cv', 'train_tensor', 'test_tensor')
)
utils.visualize_cv_outputs(cv_result, train_tensor_l1_categories, test_tensor_l1_categories)

