In [None]:
%load_ext autoreload
%autoreload 2
from pcfg import PCFG
from pcsg import PCSG, from_pcsg_string
import numpy as np
import pickle
import os
import random
import math
import time


from utils_grammar import (
    get_grammar_string,
    compute_random_guess_metric,
    random_sentence_generator,
    generate_similar_sequences,
    get_subgrammar_string,
    get_grammatical_sentences,
    train_test_split,
    compute_terminal_freq,
    get_perturbed_grammar,
    to_latex_equation,
    get_nongrammatical_sentences_from_perturbed_grammar
)
from example_grammar import grammar_details_dict
from utils_plot_grammar import plot_nonterminal_map, plot_nonterminal_map_with_edit
from tqdm import tqdm


import sys
sys.path.append("../")
from read_output.utils_plot import save_image_template


def get_stat(sequences):
    """Print and return summary statistics for a collection of sequences.

    Args:
        sequences: a sized collection of sequences (e.g. tuples of tokens).
            Elements must be hashable, since uniqueness is counted via ``set``.

    Returns:
        A dict with keys "num sequences", "unique sequences", "max", "min",
        "mean" and "std" (the last four describe sequence lengths), or ``None``
        when ``sequences`` is empty or its elements are unsized/unhashable.
        The failure reason is printed in that case (best-effort contract).
    """
    try:
        lengths = [len(sequence) for sequence in sequences]
        num_sequences = len(sequences)
        num_unique = len(set(sequences))
    except TypeError as e:
        # e.g. unhashable elements or an unsized iterable.
        print(e)
        return None

    print(f"Number of Sequences: {num_sequences}")
    print(f"Unique Sequences: {num_unique}")
    if not lengths:
        # max()/min() are undefined on an empty collection.
        print("No sequences: length statistics are undefined")
        return None

    # Compute each statistic once and print from the result dict.
    result = {
        "num sequences": num_sequences,
        "unique sequences": num_unique,
        "max": max(lengths),
        "min": min(lengths),
        "mean": np.mean(lengths),
        "std": np.std(lengths),
    }
    print(f"Max length: {result['max']}")
    print(f"Min length: {result['min']}")
    print(f"Mean length: {result['mean']}")
    print(f"Std length: {result['std']}")
    return result


In [None]:
# Inline figure display toggle; a later cell sets it to True before the plotting cells.
show_image = False

# context-free grammar
# Select exactly one grammar_name; the commented-out names are alternative experiments.

# grammar_name = "pcfg_one_character_missing"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_latin"
grammar_name = "pcfg_cfg3b_disjoint_terminals"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_sensitivity"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_sensitivity_modification_2"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_sensitivity_modification_10_5"


# grammar_name = "pcfg_cfg3b_disjoint_terminals_one_rule_different"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_two_rules_different"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_three_rules_different"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_four_rules_different"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_five_rules_different"


# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_one_rule_different"
# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_two_rules_different"
# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_three_rules_different"
# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_four_rules_different"
# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_five_rules_different"



# grammar_name = "pcfg_cfg_extended_eq_len_skewed_prob"


# grammar_name = "pcfg_cfg3b_disjoint_terminals_leaf_0.55"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_leaf_0.60"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_leaf_0.70"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_leaf_0.80"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_leaf_0.90"


# grammar_name = "pcfg_cfg3b_disjoint_terminals_all_rules_0.55"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_all_rules_0.60"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_all_rules_0.70"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_all_rules_0.80"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_all_rules_0.90"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_all_rules_0.95"



# grammar_name = "pcfg_balanced_parenthesis"
# grammar_name = "pcfg_reverse_string"
# grammar_name = "pcfg_cfg3b_disjoint_terminals_skewed_prob"

# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_latin"
# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9"
# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_skewed_0.5"
# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_skewed_0.8"
# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_skewed_0.95"

# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_eq_len_uniform_prob"
# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_eq_len_skewed_prob"
# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_eq_len_skewed_prob_0.80"

# grammar_name = "pcfg_4_3_1_2_3_4_5_6_7_8_9_eq_len_all_rules_skewed_prob"

# grammar_name = "pcfg_cfg3b_eq_len_uniform_prob"
# grammar_name = "pcfg_cfg3b_eq_len_skewed_prob"
# grammar_name = "pcfg_cfg3b_eq_len_skewed_prob_0.75"
# grammar_name = "pcfg_cfg3b_eq_len_skewed_prob_0.90"
# grammar_name = "pcfg_cfg3b_eq_len_skewed_prob_0.98"


# grammar_name = "pcfg_example_regular"
# grammar_name = "pcfg_example_context_free"
# grammar_name = "pcfg_example_context_free"

# grammar_name = "pcfg_max-depth_3_max-breadth_3_rules_2_skewness_-1_alphabet_0-1-2-3-4-5-6-7-8-9"




# regular_grammar

# grammar_name = "preg_9_5_4_2_4_1_1"
# grammar_name = "preg_9_10_4_2_4_1_1"
# grammar_name = "preg_9_20_4_2_4_1_1"
# grammar_name = "preg_9_30_4_2_4_1_1"
# grammar_name = "preg_9_40_4_2_4_1_1"
# grammar_name = "preg_9_40_4_2_4_1_1_1"

# grammar_name = "preg_alphabet_2"
# grammar_name = "preg_alphabet_7"
# grammar_name = "preg_alphabet_26"

# grammar_name = "preg_numeral_2"
# grammar_name = "preg_numeral_7"
# grammar_name = "preg_numeral_10"


# grammar_name = "preg_alphabet_combined"
# grammar_name = "preg_alphabet_combined_skewed_prob"



# context-sensitive grammar

# grammar_name = "pcsg_csg3b_disjoint_terminals_A8_left"
# grammar_name = "pcsg_csg3b_disjoint_terminals_A8_right"



# grammar_name = "pcfg_double-branch_max-depth_4_max-breadth_3_alphabet_1-2-3-4-5-6-7-8-9"
# grammar_name = "pcfg_max-depth_16_max-breadth_2_rules_2_skewness_0_alphabet_0-1-2-3-4"



# standard_name

# grammar_name = "pcfg_double-branch_max-depth_4_max-breadth_3_rules_3_skewness_2_alphabet_1-2-3-4-5-6-7-8-9"
# grammar_name = "pfsa_states_2_symbols_3_index_0_alphabet_0-1"
# grammar_name = "pcfg_max-depth_3_max-breadth_3_rules_3_skewness_3_alphabet_0-1-2-3-4-5-6-7-8-9"



# import argparse
# parser = argparse.ArgumentParser()
# parser.add_argument("--grammar_name", type=str, default="pcfg_4_5_2_10_latin", help="Grammar name")
# args = parser.parse_args()
# grammar_name = args.grammar_name


# Resolve the grammar text. For "pfsa" grammars, get_grammar_string returns a dict
# holding the grammar text plus extra metadata (copied into grammar_meta_data later);
# for all other grammars it returns the grammar text directly.
pfsa_object = {}

if grammar_name.startswith("pfsa"):
    pfsa_object = get_grammar_string(grammar_name)
    if pfsa_object is None:
        raise ValueError(grammar_name)
    assert isinstance(pfsa_object, dict)
    grammar_string = pfsa_object['grammar_string']
    # Reject automata whose analytic entropy exceeds 1000.
    if pfsa_object['entropy_analytic'] > 1000:
        raise ValueError(grammar_name)
else:
    grammar_string = get_grammar_string(grammar_name)


print(grammar_string)
# pcfg / preg / pfsa grammars are parsed by the PCFG parser; pcsg has its own parser.
if grammar_name.startswith(("pcfg", "preg", "pfsa")):
    grammar = PCFG.fromstring(grammar_string)
elif grammar_name.startswith("pcsg"):
    grammar = from_pcsg_string(grammar_string)
else:
    # Fail here instead of carrying grammar=None into later cells
    # (e.g. `grammar._lexical_index` would raise an opaque AttributeError).
    raise ValueError(f"Unsupported grammar family (expected pcfg/preg/pfsa/pcsg prefix): {grammar_name}")
print(f"Random guess loss: {compute_random_guess_metric(grammar)}")


# LaTeX print: colour per leading symbol, used by the (commented-out) to_latex_equation call.
# print(grammar_name.replace("_", " "))
color_dict = {
    "S" : "\\textcolor{red}",
    "A" : "\\textcolor{red}",
    "[" : "\\textcolor{blue}",
    "-" : "\\textcolor{black}",
    "B" : "\\textcolor{red}",
    "C" : "\\textcolor{red}",
    "E" : "\\textcolor{red}",
    "T" : "\\textcolor{red}",
}



In [None]:
# to_latex_equation(grammar_string, color_dict=color_dict, script_notation="_")
# to_latex_equation(grammar_string, script_notation="_")

## Generation

In [None]:
# Sample budget and RNG seed shared by every later cell.
num_samples = 10000
# num_samples = 10
# if grammar_name == "pcfg_cfg3b_disjoint_terminals_skewed_prob":
#     num_samples = 50000
seed = 5
(
    sequences,
    sequence_to_non_terminal_applied_position_map,
    sequence_freq,
    sequence_prob_dict,
) = get_grammatical_sentences(grammar, num_samples, seed)

tuple(map(len, (sequences, sequence_freq, sequence_prob_dict, sequence_to_non_terminal_applied_position_map)))

In [None]:
def _entropy_bits(weights):
    """Shannon entropy (base 2) of the distribution obtained by normalizing ``weights``.

    Zero weights contribute nothing (0 * log 0 := 0); an empty or all-zero
    input yields 0.0 instead of dividing by zero.
    """
    total = sum(weights)
    entropy = 0.0
    for weight in weights:
        if weight == 0:
            continue
        prob = weight / total
        entropy -= prob * math.log2(prob)
    return entropy


grammar_meta_data = {
    "grammar_name": grammar_name,
    "grammar_string": grammar_string,
    "terminals": list(grammar._lexical_index.keys()),
    "num_terminals": len(grammar._lexical_index),
    # PCFG: count distinct production left-hand sides; otherwise use the grammar's nonterminals.
    "num_nonterminals": len({production.lhs() for production in grammar.productions()}) if isinstance(grammar, PCFG) else len(grammar.nonterminals),
    # Mean length over non-empty samples only.
    "expected_length": np.mean([len(sequence) for sequence in sequences if len(sequence) > 0]),
}

# Carry over PFSA metadata (pfsa_object is an empty dict for other grammar families).
grammar_meta_data.update({key: value for key, value in pfsa_object.items() if key in ('rank', 'entropy_analytic')})

# Two empirical entropy estimates over the observed sentences, each renormalized:
# one from sample counts, one from generation probabilities.
grammar_meta_data['entropy_freq_approximation'] = _entropy_bits(list(sequence_freq.values()))
grammar_meta_data['entropy_prob_approximation'] = _entropy_bits(list(sequence_prob_dict.values()))


grammar_meta_data

In [None]:
# delete empty sequences
sequences = [sequence for sequence in sequences if len(sequence) > 0]
sequence_freq = {sequence: freq for sequence, freq in sequence_freq.items() if len(sequence) > 0}
sequence_prob_dict = {sequence: prob for sequence, prob in sequence_prob_dict.items() if len(sequence) > 0}
sequence_to_non_terminal_applied_position_map = {sequence: non_terminal_applied_position_map for sequence, non_terminal_applied_position_map in sequence_to_non_terminal_applied_position_map.items() if len(sequence) > 0}

len(sequences), len(sequence_freq), len(sequence_prob_dict), len(sequence_to_non_terminal_applied_position_map)

In [None]:
# Enable inline display for the plotting cells that follow.
show_image = True

In [None]:
# lowest length index
index = 0
length = 1000000
for i in range(len(sequences)):
    if len(sequences[i]) < length:
        length = len(sequences[i])
        index = i

print(index, length)

In [None]:
# Annotated non-terminal map of the shortest sequence, exported as a PDF.
fig = plot_nonterminal_map(list(sequences[index]),
                        sequence_to_non_terminal_applied_position_map[sequences[index]],
                        # is_hierarchy=grammar_name in grammar_details_dict
                        is_hierarchy=True
                        )

fig.update_layout(
    width=600,
    height=200
)
fig.show()
# os.makedirs is portable and raises on failure, unlike an unchecked `mkdir -p` shell call.
os.makedirs("../read_output/figures/sentence_annotated", exist_ok=True)
store_filename = f"../read_output/figures/sentence_annotated/{grammar_name}_annotated_sentence_{index}.pdf"
fig.write_image(store_filename)
# Written twice with a pause; presumably a workaround for an incomplete first export — TODO confirm.
time.sleep(2)
fig.write_image(store_filename)

In [None]:
# Print the first few sequences and export annotated plots for the first six.
os.makedirs("../read_output/figures/sentence_annotated", exist_ok=True)
for i in range(min(10, len(sequences))):
    print(sequences[i])
    if show_image and i <= 5:
        fig = plot_nonterminal_map(list(sequences[i]),
                                   sequence_to_non_terminal_applied_position_map[sequences[i]],
                                   # is_hierarchy=grammar_name in grammar_details_dict
                                   # is_hierarchy=False
                                   is_hierarchy=True
                                   )
        fig.show()
        store_filename = f"../read_output/figures/sentence_annotated/{grammar_name}_annotated_sentence_{i}.pdf"
        fig.write_image(store_filename)
        # Written twice with a pause; presumably a workaround for an incomplete first export — TODO confirm.
        time.sleep(2)
        fig.write_image(store_filename)

# Per-sequence metadata persisted alongside the dataset.
meta_data = {
    "non_terminal_applied_position_map": sequence_to_non_terminal_applied_position_map,
    "sequence_freq": sequence_freq,
    "sequence_prob_dict": sequence_prob_dict
}
# NOTE(review): this key is the length of the *first token*, not of the sequence,
# so when tokens are single characters the sort is effectively a stable no-op.
# `len(x)` was probably intended, but changing it would alter the seeded shuffle,
# the train/test split and the backup comparison downstream, so it is kept as-is.
sequences = sorted(sequences, key=lambda x: len(x[0]))


### Plot length distribution

In [None]:
import plotly.express as px
import plotly.graph_objects as go


# Length of each sequence in the current `sequences` order (reused by later cells).
len_sequences = list(map(len, sequences))
fig = go.Figure(
    data=[go.Scatter(x=list(range(len(len_sequences))), y=len_sequences, mode='markers')]
)
fig.update_layout(
    title="Scatter plot of sequence lengths",
    xaxis_title="Sequence Index",
    yaxis_title="Length",
)
if show_image:
    fig.show()

In [None]:
# generate pdf of length distribution
os.system("mkdir -p ../read_output/figures/grammar")
fig = go.Figure()
fig.add_trace(go.Histogram(x=len_sequences, nbinsx=100, histnorm='probability density'))
fig.update_layout(xaxis_title="Length in Tokens",
                  yaxis_title="Probability")
fig = save_image_template(fig, height=200, width=300)
# if grammar_name.startswith("pcfg_cfg3b_disjoint_terminals"):
#     fig.update_yaxes(range=[0, 0.65])
if show_image:
    fig.show()
store_filename = f"../read_output/figures/grammar/{grammar_name}_length_distribution.pdf"
print(store_filename)
fig.write_image(store_filename)
time.sleep(2)
fig.write_image(store_filename)

In [None]:
# Summary statistics (count, uniqueness, length range/mean/std) of the non-empty grammatical sequences.
get_stat(sequences)

## Training and test data

In [None]:
# Even split of the sample budget between train and test.
train_size = num_samples // 2
test_size = num_samples - train_size
ratio = train_size / (train_size + test_size)
# The random-token baseline below is sized to match the test split.
non_grammatical_sentence_size = test_size

# Seeded shuffle, then cap at the sample budget.
# NOTE(review): random.shuffle mutates `sequences` in place, so re-running only this
# cell starts from an already-shuffled list and yields a different split.
random.seed(seed)
random.shuffle(sequences)
sequences = sequences[:train_size + test_size]

train_sequences, test_sequences = train_test_split(sequences, seed, sequence_freq, ratio)
data = dict(
    train_sequences=train_sequences,
    test_sequences=test_sequences,
)

In [None]:
# Sample count of every unique sequence, per split.
train_sequence_freq_list = [sequence_freq[seq] for seq in set(train_sequences)]
test_sequence_freq_list = [sequence_freq[seq] for seq in set(test_sequences)]


# Overlaid histograms of how often each unique sequence was sampled (log-scaled y-axis).
fig = go.Figure()
for trace_name, counts in (("Train", train_sequence_freq_list), ("Test", test_sequence_freq_list)):
    fig.add_trace(go.Histogram(x=counts, name=trace_name))
fig.update_layout(
    title="Sequence frequency distribution",
    xaxis_title="Sequence Count",
    yaxis_title="# Occurrences",
)
fig.update_yaxes(type="log")
if show_image:
    fig.show()

In [None]:
# (#sequences with a non-terminal map, #unique train, #unique test, #unique sequences shared by train and test)
len(sequence_to_non_terminal_applied_position_map), len(set(train_sequences)), len(set(test_sequences)), len(set(train_sequences).intersection(set(test_sequences)))

In [None]:
# Optional: merge train and test into one deduplicated list for the
# "pcfg_cfg3b_disjoint_terminals_sensitivity_modification_*" grammars.
# Disabled by default. When enabled, test_sequences aliases train_sequences and
# grammar_name gains a "_deduplicated" suffix, which changes every output path below.
deduplicate_sensitivity_splits = False
if deduplicate_sensitivity_splits and grammar_name.startswith("pcfg_cfg3b_disjoint_terminals_sensitivity_modification_"):
    print("Only considering training sequences. Test sequences are dummy repeat")
    grammar_name = f"{grammar_name}_deduplicated"
    train_sequences = list(set(train_sequences + test_sequences))
    test_sequences = train_sequences
    data["train_sequences"] = train_sequences
    data["test_sequences"] = test_sequences

In [None]:
# Re-check split sizes/overlap after the optional deduplication step above.
len(sequence_to_non_terminal_applied_position_map), len(set(train_sequences)), len(set(test_sequences)), len(set(train_sequences).intersection(set(test_sequences)))

## Randomly sampled data

In [None]:
# Random-token baseline from random_sentence_generator: lengths bounded by the grammatical
# length range, terminal frequencies taken from train + test, and (presumably) excluding
# every sequence already sampled from the grammar. Flip to False to skip.
if True:
    data["non_grammatical_sequences"] = random_sentence_generator(
        grammar=grammar,
        num_samples=non_grammatical_sentence_size,
        # num_samples=100,
        min_length=min(len_sequences),
        max_length=max(len_sequences),
        sampled_sequences=set(sequence_freq.keys()),
        seed=seed,
        timeout=1000,
        terminal_freq=compute_terminal_freq(train_sequences + test_sequences),
    )
    non_grammatical_sequences = data["non_grammatical_sequences"]
    get_stat(non_grammatical_sequences)

## Grammar perturbation

In [None]:
# Grammar-perturbation negatives (PCFG grammars only): sample sentences from
# perturbed copies of the grammar. Each (edit_level, edit, forced_action)
# configuration is tried with `total_runs` seeds; successful runs are pooled into
# `data` and recorded in `grammar_edit_dict`.
grammar_edit = False
if grammar_name.startswith("pcfg"):
    grammar_edit = True
    total_runs = 5  # construct 5 different perturbed grammars per configuration
    num_grammar_edit_sample_per_run = 200
    verbose = False
    grammar_edit_dict = {}

    edit_levels = [1, 2, 3, 4, 5, 6, 7]
    edits = [1]               # full sweep: [1, 2, 3, 4]
    forced_actions = ["all"]  # full sweep: ["insert", "delete", "replace", "all"]

    for edit_level in edit_levels:
        for edit in edits:
            for forced_action in forced_actions:
                config_suffix = f"{edit_level}_{edit}_{forced_action}"
                for run_ids in range(total_runs):
                    print(f"Edit level: {edit_level}, Edit: {edit}, Forced action: {forced_action}")
                    try:
                        grammar_perturbed, perturbation_result = get_perturbed_grammar(
                            grammar,
                            grammar_name,
                            level=edit_level,
                            edit=edit,
                            # "all" means: do not force a particular edit action.
                            forced_action=forced_action if forced_action != "all" else None,
                            seed=seed + run_ids,
                            verbose=verbose,
                        )

                        (grammar_edit_sequences,
                         grammar_edit_sequence_to_non_terminal_applied_position_map,
                         grammar_edit_sequence_freq,
                         grammar_edit_sequence_prob_dict) = get_nongrammatical_sentences_from_perturbed_grammar(
                            base_grammar=grammar,
                            perturbed_grammar=grammar_perturbed,
                            num_samples=num_grammar_edit_sample_per_run,
                            seed=seed)

                        if len(grammar_edit_sequences) == 0:
                            continue

                        data.setdefault(f"non_grammatical_test_sequences_grammar_edit_{config_suffix}", []).extend(grammar_edit_sequences)
                        grammar_edit_dict.setdefault(f"grammar_edit_{config_suffix}", []).append({
                            "base_grammar": grammar,
                            "perturbed_grammar": grammar_perturbed,
                            "perturbation_result": perturbation_result,
                            "sequences": grammar_edit_sequences,
                            "non_terminal_applied_position_map": grammar_edit_sequence_to_non_terminal_applied_position_map,
                            "sequence_freq": grammar_edit_sequence_freq,
                            "sequence_prob_dict": grammar_edit_sequence_prob_dict,
                            "edit_level": edit_level,
                            "edit": edit,
                            "forced_action": forced_action
                        })

                        print(f"Generated sequences: {len(grammar_edit_sequences)}")

                        if verbose:
                            for nonterminal in perturbation_result:
                                print(f"{nonterminal}:")
                                for i, (production_before, production_after) in enumerate(zip(grammar.productions(nonterminal), grammar_perturbed.productions(nonterminal))):
                                    if i not in perturbation_result[nonterminal]:
                                        continue
                                    print(f"{production_before.rhs()} => {production_after.rhs()}")
                                    print(perturbation_result[nonterminal][i])
                                    print()

                            if show_image:
                                # Bound by the perturbed-grammar samples being plotted (not `sequences`);
                                # a dedicated loop name avoids clobbering the global `index`.
                                for plot_index in range(min(1, len(grammar_edit_sequences))):
                                    plot_nonterminal_map_with_edit(
                                        token_sequence=list(grammar_edit_sequences[plot_index]),
                                        nonterminal_applied_position_map=grammar_edit_sequence_to_non_terminal_applied_position_map[grammar_edit_sequences[plot_index]],
                                        grammar_perturbed=grammar_perturbed,
                                        perturbation_result=perturbation_result,
                                        edit_level=edit_level-1,
                                        verbose=False,
                                    ).show()
                    except Exception as e:
                        # Best effort: report and skip a failed run. Unlike a bare `except:`,
                        # this lets KeyboardInterrupt stop the sweep.
                        print(f"Run {run_ids} for grammar_edit_{config_suffix} failed: {type(e).__name__}: {e}")

In [None]:
# Perturbation returned by the most recent get_perturbed_grammar call in the cell above
# (NameError if that cell never assigned it, e.g. for non-pcfg grammars).
perturbation_result

## Edit distance (lexer)

In [None]:
# Lexer-level negatives: perturb grammatical sequences with 1-3 token edits
# (via generate_similar_sequences), first for the test split, then for the train split.
edit_distance_lexer = True
if edit_distance_lexer:
    edit_distance_perturb_dict = {}
    edit_distance_non_terminal_mapping = {}

    # Earlier experiments used fixed length buckets, e.g.
    # [(0, 25), (25, 50), (50, 200)] or [(0, 40), (40, 100), (100, 1000)].
    # Now a single bucket spans each split's observed length range.
    for split_name, split_sequences in [("test", test_sequences), ("train", train_sequences)]:
        stat = get_stat(split_sequences)
        if stat is None:
            raise ValueError(f"Cannot compute length statistics for the {split_name} split")
        perturb_start_index, perturb_end_index = stat['min'], stat['max']
        if perturb_start_index == perturb_end_index:
            print("Min and max length are same")
            perturb_start_index = 1
        for edit_distance in [1, 2, 3]:
            perturbed_sequences, perturb_position_dict, perturbed_sequence_to_non_terminal_applied_position_map = \
                    generate_similar_sequences(grammar,
                                            split_sequences,
                                            sequence_to_non_terminal_applied_position_map,
                                            edit_distance,
                                            seed,
                                            # Fresh set per call, in case the helper mutates it — TODO confirm.
                                            sampled_sequences=set(sequence_freq.keys()),
                                            perturb_start_index=perturb_start_index,
                                            perturb_end_index=perturb_end_index)
            result_key = f"non_grammatical_{split_name}_sequences_edit_distance_{edit_distance}_{perturb_start_index}_{perturb_end_index}"
            data[result_key] = perturbed_sequences
            edit_distance_perturb_dict[result_key] = perturb_position_dict
            edit_distance_non_terminal_mapping[result_key] = perturbed_sequence_to_non_terminal_applied_position_map

    meta_data_edit_distance = {
        "non_terminal_applied_position_map": edit_distance_non_terminal_mapping,
        "perturbation_result": edit_distance_perturb_dict
    }

In [None]:
# One line per dataset split: name, total count, unique count.
for split_name, split_sequences in data.items():
    print(split_name, len(split_sequences), len(set(split_sequences)))


In [None]:
# Regression check: if a backup of this exact dataset exists, every split stored in
# the backup must be reproduced identically (same length and same elements in order).
g2 = grammar_name
backup_filename = f"../data_backup/{g2}/sequences_w_edit_distance_{g2}_{num_samples}_{seed}.pkl"
if os.path.isfile(backup_filename):
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted backups.
    with open(backup_filename, 'rb') as f:
        data_g2 = pickle.load(f)
    for key in data_g2:
        assert key in data, f"Backup key missing from regenerated data: {key}"
        print(key, len(data_g2[key]), len(set(data_g2[key])))
        # Without this, extra regenerated elements (e.g. from re-running an accumulating cell) would pass silently.
        assert len(data_g2[key]) == len(data[key]), \
            f"Length mismatch for {key}: backup {len(data_g2[key])} vs regenerated {len(data[key])}"
        for i, sequence in enumerate(data_g2[key]):
            assert sequence == data[key][i], f"Mismatch in {key} at position {i}"
            

## Store

In [None]:
# Persist all generated splits (grammatical and non-grammatical) as one pickle.
store_path = "../data"
grammar_store_dir = f"{store_path}/{grammar_name}"
os.makedirs(grammar_store_dir, exist_ok=True)
total_size = train_size + test_size
filename = f"{grammar_store_dir}/sequences_w_edit_distance_{grammar_name}_{total_size}_{seed}.pkl"
with open(filename, mode='wb') as f:
    pickle.dump(data, f)

In [None]:
# This filename uses num_samples, so it must agree with the split sizes of the dataset file.
assert num_samples == train_size + test_size
store_dir = f"{store_path}/{grammar_name}"
filename = f"{store_dir}/meta_data_{grammar_name}_{num_samples}_{seed}.pkl"
with open(filename, mode='wb') as f:
    pickle.dump(meta_data, f)

In [None]:
# Perturbation positions / non-terminal maps of the lexer edit-distance negatives.
if edit_distance_lexer:
    store_dir = f"{store_path}/{grammar_name}"
    filename = f"{store_dir}/meta_data_lexer_edit_{grammar_name}_{num_samples}_{seed}.pkl"
    with open(filename, mode='wb') as f:
        pickle.dump(meta_data_edit_distance, f)

In [None]:
# Per-configuration records of the perturbed-grammar runs (only produced for pcfg grammars).
if grammar_edit:
    store_dir = f"{store_path}/{grammar_name}"
    filename = f"{store_dir}/meta_data_grammar_edit_{grammar_name}_{num_samples}_{seed}.pkl"
    with open(filename, mode='wb') as f:
        pickle.dump(grammar_edit_dict, f)

In [None]:
# Grammar-level summary (terminals, nonterminal count, expected length, entropy estimates).
store_dir = f"{store_path}/{grammar_name}"
filename = f"{store_dir}/grammar_meta_data_{grammar_name}.pkl"
with open(filename, mode='wb') as f:
    pickle.dump(grammar_meta_data, f)

In [None]:
# Active grammar name (may carry a "_deduplicated" suffix if the optional deduplication branch ran).
grammar_name

In [None]:
# Scratch cell: intentionally left without code so Restart & Run All does not stop here with a NameError.

In [None]:
# Keys of the per-sequence metadata pickled alongside the dataset.
meta_data.keys()

In [None]:
# Sanity check: reload the dataset pickled above into a separate name, so the
# in-memory `data` is not silently replaced by whatever the file contains.
import pickle
filename = f"{store_path}/{grammar_name}/sequences_w_edit_distance_{grammar_name}_{num_samples}_{seed}.pkl"
with open(filename, 'rb') as f:
    reloaded_data = pickle.load(f)

reloaded_data.keys()

In [None]:
# Print backslash-continued, quoted PFSA grammar names for a shell list, e.g.
# "pfsa_states_4_symbols_4_index_3_alphabet_0-1-2"
for num_states in [2, 4, 5, 8, 12, 16]:
    for num_symbols in [2, 4, 6, 8, 10]:
        alphabet = [str(x) for x in range(num_symbols)]
        max_index = min(num_states - 1, num_symbols)
        for index in range(20):
            if index <= max_index:
                print(f"\"pfsa_states_{num_states}_symbols_{num_symbols+1}_index_{index}_alphabet_{'-'.join(alphabet)}\" \\")

In [None]:
# Print backslash-continued, quoted PCFG grammar names over a parameter sweep, e.g.
# "pcfg_max-depth_3_max-breadth_3_rules_3_skewness_3_alphabet_0-1-2-3-4-5-6-7-8-9"
terminal_alphabets = {size: [str(x) for x in range(size)] for size in (5, 10)}
for max_depth in [2, 4, 8, 16, 32]:
    for max_breadth in [2, 4, 8, 16]:
        for production_per_non_terminal in [2, 3, 4]:
            for skewness in [0, 1, 2]:
                for num_terminals, alphabet in terminal_alphabets.items():
                    print(f"\"pcfg_max-depth_{max_depth}_max-breadth_{max_breadth}_rules_{production_per_non_terminal}_skewness_{skewness}_alphabet_{'-'.join(alphabet)}\" \\")
                    