In [1]:
# Reload edited modules (e.g. the local src.* packages) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import itertools
from pathlib import Path
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
import seaborn as sns
from tqdm import tqdm

from src.analysis import analogy
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechHiddenStateDataset


In [None]:
# Run configuration. NOTE(review): looks like a papermill-style parameters cell whose
# values are overridden per run — confirm before renaming anything here.
base_model = "w2v2_8"

model_class = "discrim-rnn_32-mAP1"
model_name = "word_broad_10frames_fixedlen25"

inflection_results_path = "inflection_results.parquet"
all_cross_instances_path = "all_cross_instances.parquet"

train_dataset = "librispeech-train-clean-100"
hidden_states_path = f"outputs/hidden_states/{base_model}/{train_dataset}.h5"
state_space_specs_path = "state_space_spec.h5"
# Set to the sentinel "ID" to analyze the raw hidden states (hidden_states_path)
# instead of loading learned model embeddings from this .npy path.
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}/{train_dataset}.npy"

output_dir = "."

# NOTE(review): not referenced by any cell shown here — confirm whether still needed.
pos_counts_path = "data/pos_counts.pkl"

seed = 42

# NOTE(review): not referenced by any cell shown here — confirm whether still needed.
metric = "cosine"

# NOTE(review): not referenced below; the trajectory cell hard-codes agg_fn_spec="mean".
agg_fns = [
    "mean",
]

## Prepare model representations

In [None]:
# The sentinel embeddings_path == "ID" selects the raw hidden states stored at
# hidden_states_path; any other value is read as a saved NumPy array of embeddings.
if embeddings_path != "ID":
    with open(embeddings_path, "rb") as embeddings_file:
        model_representations: np.ndarray = np.load(embeddings_file)
else:
    model_representations = SpeechHiddenStateDataset.from_hdf5(hidden_states_path).states

state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path)
# Fail fast unless the spec reports it is compatible with the loaded representations.
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Aggregate the representations of each state-space instance with a mean
# (presumably along axis 1 of each instance's frames — confirm in prepare_state_trajectory).
# NOTE(review): "mean" is hard-coded here rather than taken from the agg_fns parameter.
trajectory_agg = prepare_state_trajectory(
    model_representations,
    state_space_spec,
    agg_fn_spec="mean",
    agg_fn_dimension=1,
)

In [6]:
# Flatten the aggregated trajectory. Presumably `agg` holds one row per flattened state and
# `agg_src` records where each row came from — confirm in src.analysis.state_space.
(agg, agg_src) = flatten_trajectory(trajectory_agg)

## Prepare metadata

In [7]:
# Phoneme-level cuts, re-indexed by (label, instance_idx, frame_idx):
#   label_idx — position of the row's label in state_space_spec.labels (NaN if absent)
#   frame_idx — ordinal of each phoneme cut within its (label, instance_idx) group,
#               following the original row order of the spec
cuts_df = (
    state_space_spec.cuts
    .xs("phoneme", level="level")
    .drop(columns=["onset_frame_idx", "offset_frame_idx"])
    .assign(
        label_idx=lambda cuts: cuts.index.get_level_values("label").map(
            {label: position for position, label in enumerate(state_space_spec.labels)}),
        frame_idx=lambda cuts: cuts.groupby(["label", "instance_idx"]).cumcount(),
    )
    .reset_index()
    .set_index(["label", "instance_idx", "frame_idx"])
    .sort_index()
)

In [8]:
# Phonemic form of each (label, instance_idx): the phoneme descriptions joined by spaces,
# in frame_idx order because cuts_df is sorted on its index.
cut_phonemic_forms = (
    cuts_df
    .groupby(level=["label", "instance_idx"])["description"]
    .agg(" ".join)
)

In [9]:
word_freq_df = pd.read_csv("data/WorldLex_Eng_US.Freq.2.txt", sep="\t", index_col="Word")
# Compute a weighted average frequency across domains. Each domain's counts are first
# normalized to relative frequencies so that larger corpora do not dominate the average;
# the mean relative frequency is then rescaled by the mean domain size to return to count
# units. (Previously the *_rel columns were computed but unused, and Freq was the mean raw
# count multiplied by the mean corpus size.)
freq_domains = ["BlogFreq", "TwitterFreq", "NewsFreq"]
for domain in freq_domains:
    word_freq_df[f"{domain}_rel"] = word_freq_df[domain] / word_freq_df[domain].sum()
word_freq_df["Freq"] = word_freq_df[[f"{d}_rel" for d in freq_domains]].mean(axis=1) \
    * word_freq_df[freq_domains].sum().mean()
# Words with zero count in every domain get LogFreq = -inf.
word_freq_df["LogFreq"] = np.log10(word_freq_df.Freq)

In [10]:
# Cross-instance table consumed by the analogy experiments below.
all_cross_instances = pd.read_parquet(path=all_cross_instances_path)

In [11]:
# NOTE(review): loaded but not referenced by any cell shown here — confirm whether still needed.
inflection_results_df = pd.read_parquet(path=inflection_results_path)

## Behavioral tests

In [33]:
# Experiment name -> config passed to analogy.run_experiment_equiv_level.
# The query strings presumably filter all_cross_instances on its morph_related column
# (e.g. via DataFrame.query) — confirm in src.analysis.analogy.
experiments = dict(
    morph_related=dict(group_by=["morph_related"]),
    morph_to_non=dict(
        base_query="morph_related",
        inflected_query="not morph_related",
    ),
    non_to_morph=dict(
        base_query="not morph_related",
        inflected_query="morph_related",
    ),
)

In [None]:
# Run every configured experiment and stack the result frames under an outer
# "experiment" index level.
# NOTE(review): num_samples and device are hard-coded; device="cuda" needs a GPU.
per_experiment_results = {}
for experiment, config in tqdm(experiments.items(), unit="experiment"):
    per_experiment_results[experiment] = analogy.run_experiment_equiv_level(
        experiment,
        config,
        state_space_spec,
        all_cross_instances,
        agg,
        agg_src,
        num_samples=5000,
        seed=seed,
        device="cuda",
    )
experiment_results = pd.concat(per_experiment_results, names=["experiment"])
# A row is correct when the predicted label matches the ground-truth label.
experiment_results["correct"] = experiment_results["predicted_label"] == experiment_results["gt_label"]
experiment_results

### Save

In [35]:
# Write experiment_results (including the derived "correct" column) as CSV under output_dir.
# DataFrame.to_csv does not create missing directories, so create output_dir first;
# otherwise a non-default output_dir would raise only after the experiments have run.
results_dir = Path(output_dir)
results_dir.mkdir(parents=True, exist_ok=True)
experiment_results.to_csv(results_dir / "experiment_results.csv")