# Probe研究

主要研究自监督（SSL）语音模型在不同层中学到的表征。

In [4]:
# 导入库
# 标准库
import glob
import os
import random
import time
from pathlib import Path

# 其他常用库
import textgrids
from tqdm import tqdm
import torch
import numpy as np
import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt
from tqdm import tqdm
import IPython.display as ipd

# Transformer库
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from transformers import WhisperFeatureExtractor, WhisperModel
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Datasets库
from datasets import load_dataset, load_from_disk

from jiwer import wer

import cca_core

In [3]:
# Load the LibriSpeech test set that was previously saved from the Hugging Face
# hub to local disk, keeping only its "test.clean" split.
LIBRISPEECH_DIR = "/data/chenhonghua/datasets/librispeech_test"
# Small dummy dataset, handy for quick debugging:
# ds = load_from_disk("/data/chenhonghua/datasets/librispeech_asr_dummy")
ds = load_from_disk(LIBRISPEECH_DIR)["test.clean"]

In [8]:
# Load Whisper-large and its processor. Use the GPU when one is available and
# fall back to CPU otherwise (a hard-coded .to("cuda") fails on CPU-only hosts).
WHISPER_CKPT = "openai/whisper-large"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = WhisperForConditionalGeneration.from_pretrained(WHISPER_CKPT).to(device)
processor = WhisperProcessor.from_pretrained(WHISPER_CKPT)

### layer-wise analysis代码
1. 将 LibriSpeech 音频与强制对齐（TextGrid）结果进行匹配：*read_librispeech_alignments.py*
2. 采集数据样本：*create_data_samples.py*，可以采集多组数据样本
3. 提取模型的表征：*extract_rep.py*，表征的粒度由 `span`（frame / phone / word）指定，表征类型由 `rep_type`（local / quantized / contextualized）指定。
4. 与上下文无关的（静态）词向量：先用 *prepare_wordsim_data.py* 准备数据，再用 *extract_static_word_embed.py* 提取
5. 评估逐层特征：用 *save_embeddings.py* 准备并保存嵌入数据，用 *get_scores.py* 计算 CCA、MI 和 WordSim 分数

In [6]:
# read_librispeech_alignments
#from utils import save_dct, write_to_file, load_dct, read_lst

class LibrispeechAlign:
    """Turn LibriSpeech forced-alignment TextGrids into token-level alignment files.

    Call ``save_data``; it reads every ``<data_dir>/<dataset_split>/*/*/*.TextGrid``
    and writes, for both "phone" and "word" tokens:

    - ``alignment_<token>_<split>.tsv``: one ``utt_id<TAB>start<TAB>end<TAB>label``
      row per occurrence
    - ``alignment_<token>_<split>.json``: label -> list of
      ``(utt_id, audio_path, start, end)``
    - for splits whose name contains "train": a cumulative ``<token>_count.json``
      and a frequency-sorted ``<token>.lst``
    """

    def save_data(self, data_dir, dataset_split, audio_dir, audio_ext):
        """
        Save alignment info as a dictionary of token mapped to a list of occurrences with time stamps.
        Also updates the count dictionary and list of tokens (training splits only).

        data_dir: root holding ``<dataset_split>/`` TextGrids; outputs are written here too
        dataset_split: split sub-directory name, e.g. a LibriSpeech split
        audio_dir: root used to build each occurrence's audio path
        audio_ext: audio file extension (without the dot)
        """
        self.audio_dir = audio_dir
        self.audio_ext = audio_ext
        self.dataset_split = dataset_split
        self.data_dir = data_dir

        token_lst_dct = self.read_data()
        self.get_token_alignment_ordered_lst(token_lst_dct, data_dir, dataset_split)
        token_alignment_dct = self.get_token_alignment_dct(token_lst_dct)
        for key, value in token_alignment_dct.items():
            save_dct(
                os.path.join(data_dir, f"alignment_{key}_{dataset_split}.json"), value
            )

        if "train" in dataset_split:
            count_fn = {
                key: os.path.join(data_dir, f"{key}_count.json")
                for key in ["phone", "word"]
            }
            token_lst_fn = {
                key: os.path.join(data_dir, f"{key}.lst") for key in ["phone", "word"]
            }
            self.update_tokens(count_fn, token_lst_fn, token_alignment_dct)

    def read_data(self):
        """
        Read all TextGrids of the split into {"phone": [...], "word": [...]} lists
        of "utt_id start end label" strings.
        """
        wrd_lst, phn_lst = [], []
        parent_dir = os.path.join(self.data_dir, self.dataset_split)
        # glob order depends on the filesystem; sort so every output file is reproducible.
        all_fns = sorted(glob.glob(os.path.join(parent_dir, "*/*/*.TextGrid")))
        for fname in tqdm(all_fns):
            self.get_info(fname, phn_lst, wrd_lst)
        return {"phone": phn_lst, "word": wrd_lst}

    def get_token_alignment_dct(self, token_lst_dct):
        """
        Convert a list of token-level alignments to a dictionary for a list of occurrences of each token type.

        Returns {token_type: {label: [(utt_id, audio_path, start, end), ...]}} with
        start/end kept as strings.
        """
        token_alignment_dct = {}
        for key, value in token_lst_dct.items():
            per_label = token_alignment_dct[key] = {}
            for item in tqdm(value):
                utt_id, start, end, token = item.split(" ")
                # The first two "-"-separated fields of the id (presumably speaker and
                # chapter in LibriSpeech) form the audio sub-directory.
                audio_path = os.path.join(
                    self.audio_dir,
                    "/".join(utt_id.split("-")[:2]),
                    utt_id + "." + self.audio_ext,
                )
                per_label.setdefault(token, []).append((utt_id, audio_path, start, end))

        return token_alignment_dct

    def get_token_alignment_ordered_lst(self, token_lst_dct, data_dir, dataset_split):
        """
        Save the list of token-level alignments to a tsv file (one per token type).
        """
        for key, value in token_lst_dct.items():
            write_str = ["\t".join(item.split(" ")) for item in tqdm(value)]
            write_fn = os.path.join(data_dir, f"alignment_{key}_{dataset_split}.tsv")
            write_to_file("\n".join(write_str), write_fn)

    def phn_map(self, phn_label):
        """Map an aligner phone label: "sil" -> "SIL", otherwise drop a trailing 0/1/2 stress digit and lowercase."""
        if phn_label == "sil":
            return "SIL"
        if phn_label[-1] in ["0", "1", "2"]:
            return phn_label[:-1].lower()
        return phn_label.lower()

    def txt_from_tier(self, tier_content, data_lst, fname, unit):
        """
        Append "fname start end label" strings for every non-empty interval of a tier.

        Phone labels go through ``phn_map``; "spn" and "sp" labels are skipped
        (presumably the aligner's spoken-noise / short-pause markers).
        """
        for item in tier_content:
            label = item.text
            if not label:  # skip empty intervals
                continue
            if "phone" in unit:  # map to the traditional 39 phone phn set
                label = self.phn_map(label)
            if label in ["spn", "sp"]:
                continue
            data_lst.append(" ".join([fname, str(item.xmin), str(item.xmax), label]))

    def get_info(self, fname, phn_lst, wrd_lst):
        """Parse one TextGrid and append its phone and word intervals to the given lists."""
        grid = textgrids.TextGrid(fname)
        utt_id = os.path.basename(fname).split(".")[0]
        self.txt_from_tier(grid["phones"], phn_lst, utt_id, "phone")
        self.txt_from_tier(grid["words"], wrd_lst, utt_id, "word")

    def update_tokens(self, count_fn, token_lst_fn, token_alignment_dct):
        """
        Add this split's occurrence counts to the stored count files and rewrite
        each token list sorted by descending count.
        """
        count_dct = {}
        for token_type, value in count_fn.items():
            if os.path.exists(value):
                count_dct[token_type] = load_dct(value)
            else:
                count_dct[token_type] = {}
            dct = count_dct[token_type]
            for token, alignment_info_lst in token_alignment_dct[token_type].items():
                dct[token] = dct.get(token, 0) + len(alignment_info_lst)

            save_dct(value, dct)
            sorted_token_lst = sorted(dct, key=dct.get, reverse=True)
            write_to_file("\n".join(sorted_token_lst), token_lst_fn[token_type])


def combine_alignments(data_dir, data_split, token_type):
    """Merge per-subset alignment dictionaries into one file for a combined split.

    data_dir: directory holding ``alignment_<token_type>_<subset>.json`` files
    data_split: "train-clean" (clean-100 + clean-360) or "train" (all three
        training subsets)
    token_type: "phone" or "word" (used only to build file names)

    Writes ``alignment_<token_type>_<data_split>.json`` to ``data_dir``.

    Raises:
        ValueError: if ``data_split`` is not a supported combined split.
    """
    split_constituents = {
        "train-clean": ["train-clean-100", "train-clean-360"],
        "train": ["train-clean-100", "train-clean-360", "train-other-500"],
    }
    if data_split not in split_constituents:
        raise ValueError(
            f"data_split must be one of {sorted(split_constituents)}, got {data_split!r}"
        )

    combined_dct = {}
    for sub_data_split in split_constituents[data_split]:
        alignment_dct = load_dct(
            os.path.join(data_dir, f"alignment_{token_type}_{sub_data_split}.json")
        )
        for token, alignment_lst in tqdm(alignment_dct.items()):
            combined_dct.setdefault(token, []).extend(alignment_lst)

    save_dct(
        os.path.join(data_dir, f"alignment_{token_type}_{data_split}.json"),
        combined_dct,
    )


def save_one_hot_encodings(token, data_dir, save_dir, num_tokens=-1):
    """Save a one-hot vector for every entry of ``<data_dir>/<token>.lst``.

    token: token type ("phone" or "word"); also used for file names
    num_tokens: for "word", keep only the first ``num_tokens`` entries (after
        dropping "<unk>"); required for words

    Writes ``{label: one_hot_array}`` to ``<save_dir>/<token>_embed.pkl``.

    Raises:
        ValueError: if ``token == "word"`` and ``num_tokens`` is not given.
    """
    token_lst = read_lst(os.path.join(data_dir, f"{token}.lst"))
    if token == "word":
        if num_tokens == -1:
            raise ValueError("num_tokens must be set for word-level one-hot encodings")
        if "<unk>" in token_lst:
            token_lst.remove("<unk>")
        token_lst = token_lst[:num_tokens]
    rep_mat = np.eye(len(token_lst))
    # Row i of the identity matrix is the one-hot code of the i-th label.
    rep_dct = dict(zip(token_lst, rep_mat))
    save_dct(os.path.join(save_dir, f"{token}_embed.pkl"), rep_dct)

# combine_alignments()
# save_one_hot_encodings()

In [7]:
# create_data_samples
def sample_utterances(data_dir, save_fn, audio_ext, dir_depth, num_samples):
    """
    Save utterance ids and corresponding paths to the audio in a file.

    Randomly picks ``num_samples`` audio files found ``dir_depth`` directory levels
    below ``data_dir`` and writes "utt_id<TAB>path" lines to ``save_fn``.

    Raises:
        ValueError: if fewer than ``num_samples`` audio files are found.
    """
    search_path = "/".join(["*"] * dir_depth)
    # `glob` is imported as a module in this file, so call glob.glob. Sorting the
    # population makes seeded sampling reproducible across filesystems.
    all_files = sorted(glob.glob(os.path.join(data_dir, search_path + "." + audio_ext)))
    if num_samples > len(all_files):
        raise ValueError(
            f"requested {num_samples} utterances but only {len(all_files)} "
            f"*.{audio_ext} files found under {data_dir}"
        )
    chosen_fnames = random.sample(all_files, num_samples)

    chosen_sent_ids = [Path(fname).name.split(".")[0] for fname in chosen_fnames]
    write_lst = [
        "\t".join([sent_id, fname])
        for sent_id, fname in zip(chosen_sent_ids, chosen_fnames)
    ]
    write_to_file("\n".join(write_lst), save_fn)


class tokenLevelSamples:
    """Sample phone/word alignment segments and save them in duration-bounded chunks.

    Sampled segments are grouped per utterance and written by ``save_to_file`` to
    ``<save_dir>/<data_split>_segments_sample<data_sample>_<k>.json``; a new chunk
    ``k`` starts once the accumulated segment duration exceeds ``dur_threshold``.
    """

    def __init__(
        self, data_split, data_dir, data_sample, token, save_dir, dur_threshold=10000
    ):
        self.data_split = data_split
        self.data_dir = data_dir
        self.data_sample = data_sample
        self.save_dir = save_dir
        self.token = token
        self.dur_threshold = dur_threshold  # seconds

        os.makedirs(self.save_dir, exist_ok=True)

    def _num_samples(self, num_instances, min_cnt, max_cnt):
        """Return how many of a token's ``num_instances`` occurrences to sample.

        Training splits draw uniformly from ``[min_cnt, max_cnt]`` with both bounds
        clipped to ``num_instances``; other splits take ``min(num_instances, min_cnt)``.
        """
        if "train" in self.data_split:
            low = int(min(min_cnt, num_instances))
            high = int(min(max_cnt, num_instances))
            return random.randint(low, high)
        return int(min(num_instances, min_cnt))

    def sample_tokens(self, token_lst, min_cnt, max_cnt, alignment_dct):
        """
        Sample alignments such that each token has a "good" representation.

        alignment_dct: label -> list of (utt_id, audio_path, start, end)
        """
        sampled_alignments = {}
        tot_dur = 0
        for token in token_lst:
            all_alignments = alignment_dct[token]
            num_instances = len(all_alignments)
            # Bounds are clipped per token (not rebound), so a rare token does not
            # shrink the quota of every token sampled after it.
            num_samples = self._num_samples(num_instances, min_cnt, max_cnt)
            chosen_alignments_idx = np.random.choice(
                np.arange(0, num_instances), num_samples, replace=False
            )
            for idx in chosen_alignments_idx:
                sent_id, fname, start_time, end_time = all_alignments[idx]
                start_time, end_time = float(start_time), float(end_time)
                sampled_alignments.setdefault(sent_id, []).append(
                    (fname, start_time, end_time, token)
                )
                tot_dur += end_time - start_time
        print("Total duration of %s spans: %.2f seconds" % (self.token, tot_dur))
        self.split_into_sublists(sampled_alignments)

    def save_to_file(self, current_sample, alignment_dct):
        """Write one chunk ({utt_id: [audio_path, (start, end, token), ...]}) to disk."""
        print(
            "Saving %dth split of %s spans for %s sample %d"
            % (current_sample, self.token, self.data_split, self.data_sample)
        )
        save_dct(
            os.path.join(
                self.save_dir,
                f"{self.data_split}_segments_sample{self.data_sample}_{current_sample}.json",
            ),
            alignment_dct,
        )

    def split_into_sublists(self, sampled_alignments):
        """
        Split sampled alignments into chunks of at most ~``dur_threshold`` seconds.

        Each chunk maps utt_id -> [audio_path, (start, end, token), ...].
        """
        current_sample, current_dur = 0, 0
        alignment_dct = {}

        for sent_id, alignment_lst in sampled_alignments.items():
            for fname, start_time, end_time, token in alignment_lst:
                current_dur += end_time - start_time
                if sent_id not in alignment_dct:
                    # First entry of an utterance in this chunk is its audio path.
                    alignment_dct[sent_id] = [fname]
                alignment_dct[sent_id].append((start_time, end_time, token))
                if current_dur > self.dur_threshold:
                    self.save_to_file(current_sample, alignment_dct)
                    current_sample += 1
                    current_dur = 0
                    alignment_dct = {}
        if current_dur > 0:
            self.save_to_file(current_sample, alignment_dct)

    def sample_phone_alignments(self, num_phones=39):
        """
        Sample phone alignments for MI experiments (all phones except "SIL").
        """
        if "train" in self.data_split:
            min_cnt, max_cnt = 3000, 7000
        else:
            min_cnt, max_cnt = 200, 1e6
        phn_lst = read_lst(os.path.join(self.data_dir, "phone.lst"))
        alignment_dct = load_dct(
            os.path.join(self.data_dir, f"alignment_phone_{self.data_split}.json")
        )
        phn_lst.remove("SIL")
        assert len(phn_lst) == num_phones
        self.sample_tokens(phn_lst, min_cnt, max_cnt, alignment_dct)

    def sample_word_alignments(self, num_words=500):
        """
        Sample word alignments for MI experiments (the ``num_words`` most frequent words).

        Raises:
            ValueError: on a training split with ``num_words`` other than 350 or 500.
        """
        if "train" in self.data_split:
            train_min_cnt = {350: 800, 500: 600}
            if num_words not in train_min_cnt:
                raise ValueError(
                    f"num_words must be one of {sorted(train_min_cnt)} for training "
                    f"splits, got {num_words}"
                )
            min_cnt, max_cnt = train_min_cnt[num_words], 1200
        else:
            min_cnt, max_cnt = 15, 1e6
        alignment_dct = load_dct(
            os.path.join(self.data_dir, f"alignment_word_{self.data_split}.json")
        )
        wrd_lst = read_lst(os.path.join(self.data_dir, "word.lst"))
        wrd_lst.remove("<unk>")
        self.sample_tokens(wrd_lst[:num_words], min_cnt, max_cnt, alignment_dct)


class AllWrdSegments:
    """Sample up to ``num_instances`` occurrences of every listed word and pack them into splits.

    Each word's sampled segments are placed greedily into the first split whose
    total duration stays below ``dur_thresh`` (durations are in the units of the
    alignment time stamps). A word whose own duration exceeds the threshold goes
    to the first empty split, or to a new one. For each split ``k``, the labels and
    the comma-joined segments are written to ``labels_k.lst`` and
    ``word_segments_k.lst`` in ``save_dir``.
    """

    def __init__(
        self,
        alignment_data_dir,
        word_lst_pth,
        save_dir,
        dur_thresh=10000,
        num_instances=200,
    ):
        self.data_dir = alignment_data_dir
        self.dur_thresh = dur_thresh
        self.max_cnt = num_instances
        self.word_lst_pth = word_lst_pth
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

    def get_tot_dur(self, wrd_segment_lst):
        """Return the summed ``end - start`` of ``(utt_id, audio_path, start, end)`` segments."""
        start_times = np.array([float(item[2]) for item in wrd_segment_lst])
        end_times = np.array([float(item[3]) for item in wrd_segment_lst])
        return np.sum(end_times - start_times)

    def find_valid_split_idx(self, split_to_dur, tot_segment_dur):
        """Return the index of the split that should receive ``tot_segment_dur`` more duration.

        Returns ``len(split_to_dur)`` (a new split) when no existing split fits.
        """
        if tot_segment_dur > self.dur_thresh:
            keys = [split_idx for split_idx, dur in split_to_dur.items() if dur == 0]
        else:
            keys = [
                split_idx
                for split_idx, dur in split_to_dur.items()
                if (tot_segment_dur + dur) < self.dur_thresh
            ]
        return keys[0] if keys else len(split_to_dur)

    def sample_word_segments(self):
        """Sample, split and write word segments from ``alignment_word_train.json``."""
        alignment_dct = load_dct(os.path.join(self.data_dir, "alignment_word_train.json"))
        split_to_segments = {0: []}
        split_to_dur = {0: 0}
        split_to_labels = {}
        for wrd in read_lst(self.word_lst_pth):
            wrd_alignments = alignment_dct[wrd]
            tot_num_wrd_segments = len(wrd_alignments)
            num_samples = min(self.max_cnt, tot_num_wrd_segments)
            chosen_alignments_idx = np.random.choice(
                np.arange(0, tot_num_wrd_segments), num_samples, replace=False
            )
            wrd_segments = [wrd_alignments[idx] for idx in chosen_alignments_idx]
            tot_segment_dur = self.get_tot_dur(wrd_segments)
            split_idx = self.find_valid_split_idx(split_to_dur, tot_segment_dur)
            split_to_segments.setdefault(split_idx, []).extend(wrd_segments)
            split_to_dur[split_idx] = split_to_dur.get(split_idx, 0) + tot_segment_dur
            split_to_labels.setdefault(split_idx, []).extend([wrd] * len(wrd_segments))
        for split_idx, labels_lst in split_to_labels.items():
            segment_lst = [",".join(item) for item in split_to_segments[split_idx]]
            write_to_file(
                "\n".join(labels_lst),
                os.path.join(self.save_dir, f"labels_{split_idx}.lst"),
            )
            write_to_file(
                "\n".join(segment_lst),
                os.path.join(self.save_dir, f"word_segments_{split_idx}.lst"),
            )


def sample_segments(
    token_type,
    data_dir,
    data_split,
    num_tokens,
    data_sample,
    save_dir,
    dur_threshold=10000,
):
    """Sample ``token_type`` segments ("phone" or "word") for one data sample.

    Builds a ``tokenLevelSamples`` sampler and calls its
    ``sample_<token_type>_alignments`` method with ``num_tokens``.
    """
    sampler = tokenLevelSamples(
        data_split, data_dir, data_sample, token_type, save_dir, dur_threshold
    )
    run_sampling = getattr(sampler, "sample_" + token_type + "_alignments")
    run_sampling(num_tokens)


def sample_all_word_instances(
    alignment_data_dir, word_lst_pth, save_dir, dur_thresh=10000, num_instances=200
):
    """Sample word instances for processing word-by-word.

    Thin wrapper around ``AllWrdSegments``: samples up to ``num_instances``
    occurrences of every word in ``word_lst_pth`` and writes the splits to
    ``save_dir``.
    """
    AllWrdSegments(
        alignment_data_dir,
        word_lst_pth,
        save_dir,
        dur_thresh=dur_thresh,
        num_instances=num_instances,
    ).sample_word_segments()
# 获取不同level的样本
# sample_utterances
# sample_segments
# sample_all_word_instances

In [8]:
# extract_rep.py
#from model_utils import ModelLoader, FeatExtractor
#from utils import read_lst, load_dct, write_to_file

def save_rep(
    model_name,
    ckpt_pth,
    save_dir,
    utt_id_fn,
    model_type="pretrained",
    rep_type="contextualized",
    dict_fn=None,
    fbank_dir=None,
    offset=False,
    mean_pooling=False,
    span="frame",
    pckg_dir=None,
):
    """
    Extract layer-wise representations from the model and save them to ``save_dir``.

    model_name: model identifier; the part before the first "_" names the
        ModelLoader / FeatExtractor method that is called
    ckpt_pth: path to the model checkpoint
    save_dir: directory where the representations are saved
    utt_id_fn: for span="frame", a .tsv list of "utt_id<TAB>wav_path" rows; for
        span="phone"/"word", a dictionary file mapping utt_id -> [wav_path, *time_stamps]
    model_type: pretrained or finetuned
    rep_type: contextualized or local or quantized
    dict_fn: path to dictionary file in case of finetuned models
    fbank_dir: directory that has filterbanks stored (rep_type="local" also
        writes all_features*.npy there)
    offset: span representation attribute
    mean_pooling: span representation attribute
    span: frame | phone | word
    pckg_dir: forwarded to ModelLoader

    Raises ValueError if span is unsupported, if span and the kind of utt_id_fn
    disagree (see above), or if utt_id_fn lists no utterances.
    """
    assert rep_type in ["local", "quantized", "contextualized"]
    if span not in ["frame", "phone", "word"]:
        raise ValueError(f"span must be one of frame/phone/word, got {span!r}")
    is_utt_lst = ".tsv" in utt_id_fn
    if (span == "frame") != is_utt_lst:
        expected = ".tsv utterance list" if span == "frame" else "segment dictionary"
        raise ValueError(
            f"span={span!r} requires a {expected} as utt_id_fn, got {utt_id_fn!r}"
        )

    # Read the utterance list before loading the (expensive) model.
    if is_utt_lst:
        utt_id_lst = read_lst(utt_id_fn)
        label_lst = None
    else:
        utt_id_dct = load_dct(utt_id_fn)
        utt_id_lst = list(utt_id_dct.keys())
        label_lst = []
        # normpath so a trailing "/" on save_dir does not produce "labels_.lst"
        split_name = os.path.basename(os.path.normpath(save_dir))
        label_lst_fn = os.path.join(save_dir, "..", f"labels_{split_name}.lst")
    if not utt_id_lst:
        raise ValueError(f"no utterances listed in {utt_id_fn}")

    model_key = model_name.split("_")[0]
    model_obj = ModelLoader(ckpt_pth, model_type, pckg_dir, dict_fn)
    encoder, task_cfg = getattr(model_obj, model_key)()

    Path(save_dir).mkdir(exist_ok=True, parents=True)
    rep_dct = {}
    # local representations
    transformed_fbank_lst, truncated_fbank_lst = [], []
    # quantized representations
    quantized_features, quantized_indices = [], []
    quantized_features_dct, discrete_indices_dct = {}, {}

    start = time.time()
    for item in tqdm(utt_id_lst):
        if span == "frame":
            time_stamp_lst = None
            utt_id, wav_fn = item.split("\t")
        else:
            utt_id = item
            wav_fn = utt_id_dct[utt_id][0]
            time_stamp_lst = utt_id_dct[utt_id][1:]
        extract_obj = FeatExtractor(
            encoder,
            utt_id,
            wav_fn,
            rep_type,
            model_name,
            fbank_dir,
            task_cfg,
            offset=offset,
            mean_pooling=mean_pooling,
        )
        getattr(extract_obj, model_key)()
        if rep_type == "local":
            extract_obj.extract_local_rep(
                rep_dct, transformed_fbank_lst, truncated_fbank_lst
            )
        elif rep_type == "contextualized":
            extract_obj.extract_contextualized_rep(rep_dct, time_stamp_lst, label_lst)
        elif rep_type == "quantized":
            extract_obj.extract_quantized_rep(
                quantized_features,
                quantized_indices,
                quantized_features_dct,
                discrete_indices_dct,
            )

    if span in ["phone", "word"]:
        write_to_file("\n".join(label_lst), label_lst_fn)

    # utt_id_lst is non-empty, so extract_obj is bound here.
    if rep_type != "quantized":
        extract_obj.save_rep_to_file(rep_dct, save_dir)

    if rep_type == "local":
        if "avhubert" not in model_name:
            truncated_fbank_mat = np.concatenate(truncated_fbank_lst, 0)
            np.save(os.path.join(fbank_dir, "all_features.npy"), truncated_fbank_mat)
            sfx = ""
        else:
            sfx = "_by4"
        transformed_fbank_mat = np.concatenate(transformed_fbank_lst, 0)
        np.save(
            os.path.join(fbank_dir, f"all_features_downsampled{sfx}.npy"),
            transformed_fbank_mat,
        )

    elif rep_type == "quantized":
        rep_mat = np.concatenate(quantized_features, 0)
        idx_mat = np.concatenate(quantized_indices, 0)
        np.save(os.path.join(save_dir, "features.npy"), rep_mat)
        np.save(os.path.join(save_dir, "indices.npy"), idx_mat)
        save_dct(
            os.path.join(save_dir, "quantized_features.pkl"), quantized_features_dct
        )
        save_dct(os.path.join(save_dir, "discrete_indices.pkl"), discrete_indices_dct)

    print("%s representations saved to %s" % (rep_type, save_dir))

    print("Time required: %.1f mins" % ((time.time() - start) / 60))


In [9]:
# save_embeddings

In [10]:
# get_scores
class getCCA:
    """Layer-wise CCA similarity scores for one model.

    The experiment runs by calling the method named ``exp_name`` (see
    ``evaluate_cca``). Every scored layer's value is stored in ``self.score_dct``,
    keyed by layer id.
    """

    def __init__(
        self,
        model_name,
        fbank_dir,
        rep_dir,
        exp_name,
        base_layer=0,
        rep_dir2=None,
        embed_dir=None,
        sample_data_fn=None,
        span="phone",
        mean_score=False,
        eval_single_layer=False,
        layer_num=-1,
    ):
        """
        exp_name: cca_mel | cca_intra | cca_inter | cca_phone | cca_word | cca_glove | cca_agwe
        embed_dir: directory with ``<suffix>_embed.pkl`` files; only needed by the
            embedding-based experiments (cca_phone / cca_word / cca_glove / cca_agwe)
        eval_single_layer, layer_num: when set, only the layer whose id equals
            ``layer_num`` is scored
        """
        if eval_single_layer:
            assert layer_num != -1
        self.layer_num = layer_num
        self.eval_single_layer = eval_single_layer
        self.num_conv_layers = LAYER_CNT[model_name]["local"]
        self.num_transformer_layers = LAYER_CNT[model_name]["contextualized"]
        self.fbank_dir = fbank_dir
        self.rep_dir = rep_dir
        self.base_layer = base_layer
        self.rep_dir2 = rep_dir2
        # Frame-level experiments do not use embeddings, so embed_dir may be None
        # (os.path.join(None, ...) would raise TypeError).
        if embed_dir is not None:
            self.embed_fn = os.path.join(
                embed_dir, f'{exp_name.split("_")[-1]}_embed.pkl'
            )
        else:
            self.embed_fn = None
        self.sample_data_fn = sample_data_fn
        self.model_name = model_name
        self.score_dct = {}
        if exp_name in ["cca_glove", "cca_agwe", "cca_word"]:
            assert span == "word"
        elif exp_name == "cca_phone":
            assert span == "phone"
        self.span = span
        self.exp_name = exp_name
        self.mean_score = mean_score

    def get_score_flag(self, layer_id):
        """Return True if ``layer_id`` should be scored in this run."""
        if self.eval_single_layer:
            return self.layer_num == layer_id
        return True

    def get_cca_score(
        self,
        view1,
        view2,
        rep_dir,
        layer_id,
        label_lst=None,
        force_train=False,
        subset=None,
    ):
        """Score two views with ``tools.get_cca_score``, record and print the result.

        layer_id: an int (transformer layer) or a "C<n>"/"T<n>" string.
        """
        start_time = time.time()
        sim_score = tools.get_cca_score(
            view1,
            view2,
            rep_dir,
            layer_id,
            self.exp_name,
            label_lst=label_lst,
            subset=subset,
            force_train=force_train,
            mean_score=self.mean_score,
        )
        self.score_dct[layer_id] = sim_score

        print_score = np.round(sim_score, 2)
        if isinstance(layer_id, int):
            layer_type, layer_num = "Transformer", layer_id
        elif "C" in layer_id:
            layer_type, layer_num = "Conv", layer_id[1:]
        elif "T" in layer_id:
            layer_type, layer_num = "Transformer", layer_id[1:]
        else:
            # Unknown prefix: report the raw id instead of failing on an unbound name.
            layer_type, layer_num = "Layer", layer_id
        print(
            f"[{format_time(start_time)}] {layer_type} layer {layer_num}: {print_score}"
        )
        return sim_score

    def cca_mel(self):
        """CCA between filterbank features and each conv / transformer layer."""
        rep_dir_contextualized = os.path.join(
            self.rep_dir, "contextualized", "frame_level"
        )
        rep_dir_local = os.path.join(self.rep_dir, "local", "frame_level")
        all_fbank = np.load(os.path.join(self.fbank_dir, "all_features.npy"))
        sfx = "_by4" if "avhubert" in self.model_name else ""
        all_fbank_downsampled = np.load(
            os.path.join(self.fbank_dir, f"all_features_downsampled{sfx}.npy")
        )

        for layer_id in range(1, self.num_conv_layers + 1):
            if not self.get_score_flag(f"C{layer_id}"):
                continue
            rep_mat = np.load(os.path.join(rep_dir_local, f"layer_{layer_id}.npy"))
            # NOTE(review): all but the last conv layer pair the full-rate filterbanks
            # with subset="downsampled" (original comment: "downsample model
            # representations"); presumably tools.get_cca_score handles it — confirm.
            if layer_id != self.num_conv_layers:
                view1, subset = all_fbank.T, "downsampled"
            else:
                view1, subset = all_fbank_downsampled.T, "original"
            self.get_cca_score(
                view1, rep_mat.T, rep_dir_local, f"C{layer_id}", subset=subset
            )

        for layer_id in range(1, self.num_transformer_layers + 1):
            if not self.get_score_flag(f"T{layer_id}"):
                continue
            rep_mat = np.load(
                os.path.join(rep_dir_contextualized, f"layer_{layer_id}.npy")
            )
            self.get_cca_score(
                all_fbank_downsampled.T,
                rep_mat.T,
                rep_dir_contextualized,
                f"T{layer_id}",
            )

    def cca_intra(self):
        """CCA between ``base_layer`` and every transformer layer of the same model."""
        rep_dir = os.path.join(self.rep_dir, "contextualized", "frame_level")
        z_mat = np.load(os.path.join(rep_dir, f"layer_{self.base_layer}.npy"))
        for layer_id in range(1, self.num_transformer_layers + 1):
            if self.get_score_flag(layer_id):
                c_mat = np.load(os.path.join(rep_dir, f"layer_{layer_id}.npy"))
                self.get_cca_score(z_mat.T, c_mat.T, rep_dir, layer_id)

    def cca_inter(self):
        """CCA between the same transformer layer of two models (``rep_dir`` vs ``rep_dir2``)."""
        rep_dir1 = os.path.join(self.rep_dir, "contextualized", "frame_level")
        rep_dir2 = os.path.join(self.rep_dir2, "contextualized", "frame_level")
        for layer_id in range(1, self.num_transformer_layers + 1):
            if self.get_score_flag(layer_id):
                c_mat1 = np.load(os.path.join(rep_dir1, f"layer_{layer_id}.npy"))
                c_mat2 = np.load(os.path.join(rep_dir2, f"layer_{layer_id}.npy"))
                # get_cca_score takes no rep_dir2 argument; passing it raised TypeError.
                self.get_cca_score(c_mat1.T, c_mat2.T, rep_dir1, layer_id)

    def get_num_splits(self):
        """Count the ``*_<k>.json`` sample files that share ``sample_data_fn``'s prefix."""
        search_str = self.sample_data_fn.replace("_0.json", "_*.json")
        num_splits = len(glob.glob(search_str))
        assert num_splits != 0, "data not found"
        return num_splits

    def update_label_lst(self, split_num, all_labels, dir_name=None):
        """Extend ``all_labels`` with ``<dir_name>/labels_<split_num>.lst``."""
        assert dir_name is not None
        fname = os.path.join(dir_name, f"labels_{split_num}.lst")
        all_labels.extend(read_lst(fname))

    def filter_label_lst(self, all_labels, embed_dct):
        """Return indices of labels that have an embedding (single O(n) pass)."""
        num_labels = len(all_labels)
        valid_indices = [
            idx for idx, label in enumerate(all_labels) if label in embed_dct
        ]
        print(
            f"{num_labels-len(valid_indices)} of {num_labels} {self.span} segments dropped"
        )
        return valid_indices

    def cca_embed(self):
        """CCA between span-level representations and per-label embeddings."""
        if self.embed_fn is None:
            raise ValueError(f"embed_dir is required for {self.exp_name}")
        rep_dir = os.path.join(self.rep_dir, "contextualized", f"{self.span}_level")
        embed_dct = load_dct(self.embed_fn)
        num_splits = self.get_num_splits()
        all_labels = []
        for layer_id in range(self.num_transformer_layers + 1):
            if not self.get_score_flag(layer_id):
                continue
            # Label files do not depend on the layer, so they are read once: on
            # layer 0, or on the single evaluated layer.
            load_labels = layer_id == 0 or self.eval_single_layer
            all_rep = []
            for split_num in range(num_splits):
                rep_fn = os.path.join(rep_dir, str(split_num), f"layer_{layer_id}.npy")
                all_rep.extend(np.load(rep_fn))
                if load_labels:
                    self.update_label_lst(split_num, all_labels, rep_dir)

            all_rep = np.array(all_rep)  # N x d
            if load_labels:
                valid_indices = self.filter_label_lst(all_labels, embed_dct)
                all_embed = np.array([embed_dct[all_labels[idx]] for idx in valid_indices])
                valid_label_lst = [all_labels[idx] for idx in valid_indices]
            # dtype=int keeps an empty index list usable as an index array.
            all_rep = all_rep[np.array(valid_indices, dtype=int)]
            self.get_cca_score(
                all_rep.T,
                all_embed.T,
                rep_dir,
                layer_id,
                label_lst=valid_label_lst,
            )

    def cca_word(self):
        self.cca_embed()

    def cca_phone(self):
        self.cca_embed()

    def cca_glove(self):
        self.cca_embed()

    def cca_agwe(self):
        self.cca_embed()


class getMI:
    """Mutual-information score between one layer's span representations and their labels.

    The score is computed eagerly in ``__init__`` via ``tools.get_mi_score`` and
    kept in ``self.mi_score``; ``write_to_file`` appends it to ``save_fn``.
    """

    # Batch size passed to tools.get_mi_score for each supported span.
    _BATCH_SIZE = {"phone": 1500, "word": 4000}

    def __init__(
        self,
        eval_dataset_split,
        sample_data_dir,
        rep_dir,
        save_fn,
        layer_id,
        span,
        iter_num,
        data_sample,
        num_clusters,
        train_dataset_split=None,
    ):
        """
        eval_dataset_split: a split containing "train" (its own data is passed as
            all_rep/all_labels) or "dev" (train_dataset_split data is passed as
            all_rep/all_labels and the dev data as eval_rep/eval_labels)
        span: "phone" or "word"
        num_clusters: number of clusters passed to tools.get_mi_score

        Raises:
            ValueError: unsupported span or split, or a dev split without
                train_dataset_split.
        """
        self.sample_data_dir = sample_data_dir
        self.rep_dir = rep_dir
        self.save_fn = save_fn
        self.layer_id = layer_id
        self.data_sample = data_sample
        self.iter_num = iter_num

        if span not in self._BATCH_SIZE:
            raise ValueError(f"span must be 'phone' or 'word', got {span!r}")

        if "train" in eval_dataset_split:
            self.all_rep, self.all_labels = self.read_data(eval_dataset_split)
            self.eval_rep, self.eval_labels = None, None
        elif "dev" in eval_dataset_split:
            if train_dataset_split is None:
                raise ValueError("train_dataset_split is required for a dev evaluation split")
            self.all_rep, self.all_labels = self.read_data(train_dataset_split)
            self.eval_rep, self.eval_labels = self.read_data(eval_dataset_split)
        else:
            raise ValueError(
                f"eval_dataset_split must contain 'train' or 'dev', got {eval_dataset_split!r}"
            )

        max_iter = 500
        self.mi_score = tools.get_mi_score(
            num_clusters,
            self._BATCH_SIZE[span],
            max_iter,
            eval_dataset_split,
            self.all_rep,
            self.all_labels,
            self.eval_rep,
            self.eval_labels,
        )

    def write_to_file(self, mi_score):
        """
        Append "layer_id,data_sample,iter_num,score" (score rounded to 3 decimals) to save_fn.
        """
        fields = [self.layer_id, self.data_sample, self.iter_num, np.round(mi_score, 3)]
        with open(self.save_fn, "a") as f:
            f.write(",".join(map(str, fields)) + "\n")

    def read_data(self, split):
        """Load layer ``self.layer_id`` representations and labels for every chunk of ``split``.

        The representation directory is ``self.rep_dir`` with "dev-clean" replaced
        by ``split``; the chunk count is the number of
        ``<split>_segments_sample<data_sample>_*.json`` files in sample_data_dir.
        """
        rep_dir = self.rep_dir.replace("dev-clean", split)
        sample_data_fn = os.path.join(
            self.sample_data_dir, f"{split}_segments_sample{self.data_sample}_0.json"
        )
        search_str = sample_data_fn.replace("_0.json", "_*.json")
        num_splits = len(glob.glob(search_str))
        assert num_splits != 0, f"no sample files match {search_str}"
        all_rep, all_labels = [], []
        for idx in range(num_splits):
            rep_fn = os.path.join(rep_dir, str(idx), f"layer_{self.layer_id}.npy")
            all_rep.extend(np.load(rep_fn))
            all_labels.extend(read_lst(os.path.join(rep_dir, f"labels_{idx}.lst")))

        all_rep = np.array(all_rep)
        assert len(all_rep) == len(all_labels)
        return all_rep, all_labels


def evaluate_mi(
    eval_dataset_split,
    sample_data_dir,
    rep_dir,
    save_fn,
    layer_id,
    span,
    iter_num,
    data_sample,
    num_clusters,
    train_dataset_split=None,
):
    """Compute the MI score of one layer and append it to ``save_fn``."""
    scorer = getMI(
        eval_dataset_split=eval_dataset_split,
        sample_data_dir=sample_data_dir,
        rep_dir=rep_dir,
        save_fn=save_fn,
        layer_id=layer_id,
        span=span,
        iter_num=iter_num,
        data_sample=data_sample,
        num_clusters=num_clusters,
        train_dataset_split=train_dataset_split,
    )
    score = scorer.mi_score
    scorer.write_to_file(score)


def _single_layer_output(save_fn, mean_score):
    """Derive ``(sample_num, list_fn)`` for a single-layer CCA result from ``save_fn``.

    Presumes ``save_fn`` ends in ``_<tag><sample number>.json`` (e.g.
    ``cca_mel_sample12.json``). The trailing digits of that last "_" field are
    the sample number, which may have several digits. The field itself is dropped to name
    the ``.lst`` file shared by all samples, with ``_mean`` appended when
    ``mean_score`` is set.
    """
    stem = save_fn.split("_")[-1].split(".")[0]
    digits = stem[len(stem.rstrip("0123456789")):]
    # Fall back to the last character when the field has no trailing digits.
    sample_num = digits or stem[-1]
    list_fn = "_".join(save_fn.split("_")[:-1]) + ("_mean" if mean_score else "") + ".lst"
    return sample_num, list_fn


def evaluate_cca(
    model_name,
    save_fn,
    fbank_dir,
    rep_dir,
    exp_name,
    base_layer=0,
    rep_dir2=None,
    embed_dir=None,
    sample_data_fn=None,
    span="phone",
    mean_score=False,
    eval_single_layer=False,
    layer_num=-1,
):
    """Run one CCA experiment (``exp_name``) and save its scores.

    All layers: the score dictionary is saved to ``save_fn`` (with a
    ``_mean.json`` suffix when ``mean_score``). Single layer: one
    "layer_num,sample_num,score" line is appended to a ``.lst`` file shared
    by all samples (see ``_single_layer_output``).
    """
    cca_obj = getCCA(
        model_name,
        fbank_dir,
        rep_dir,
        exp_name,
        base_layer,
        rep_dir2,
        embed_dir,
        sample_data_fn,
        span,
        mean_score,
        eval_single_layer,
        layer_num,
    )
    getattr(cca_obj, exp_name)()

    if eval_single_layer:
        assert len(cca_obj.score_dct) == 1
        # Parse the un-suffixed name: adding "_mean" first would turn the last "_"
        # field into "mean.json" and lose the sample number.
        sample_num, save_fn = _single_layer_output(save_fn, mean_score)
        add_to_file(
            ",".join(map(str, [layer_num, sample_num, cca_obj.score_dct[layer_num]]))
            + "\n",
            save_fn,
        )
    else:
        if mean_score:
            save_fn = save_fn.replace(".json", "_mean.json")
        save_dct(save_fn, cca_obj.score_dct)
    print(f"Result saved at {save_fn}")

def evaluate_wordsim(model_name, wordsim_task_fn, embedding_dir, save_fn):
    """Score every layer's word embeddings on a set of word-similarity tasks.

    wordsim_task_fn: dictionary file of task_name -> list of word pairs/entries
        (passed as-is to tools.get_similarity_score)
    embedding_dir: directory with one ``layer<n>.json`` embedding dictionary per layer

    Saves {"micro average": {layer: s}, "macro average": {layer: s},
    <task_name>: {layer: s}, ...} to ``save_fn``. The micro average weights each
    task by its number of entries; the macro average weights tasks equally.

    Raises:
        ValueError: if the task file contains no tasks.
    """
    wordsim_tasks = load_dct(wordsim_task_fn)
    if not wordsim_tasks:
        raise ValueError(f"no word-similarity tasks found in {wordsim_task_fn}")
    num_transformer_layers = LAYER_CNT[model_name]["contextualized"]
    res_dct = {"micro average": {}, "macro average": {}}
    for layer_num in range(num_transformer_layers + 1):
        embed_dct = load_dct(os.path.join(embedding_dir, f"layer{layer_num}.json"))
        weighted_sum, unweighted_sum, num_pairs = 0, 0, 0
        for task_name, task_lst in wordsim_tasks.items():
            srho_score = tools.get_similarity_score(task_lst, embed_dct)
            weighted_sum += srho_score * len(task_lst)
            unweighted_sum += srho_score
            num_pairs += len(task_lst)
            res_dct.setdefault(task_name, {})[layer_num] = srho_score
        # All tasks empty: the weighted average is undefined rather than a crash.
        res_dct["micro average"][layer_num] = (
            weighted_sum / num_pairs if num_pairs else float("nan")
        )
        res_dct["macro average"][layer_num] = unweighted_sum / len(wordsim_tasks)
    save_dct(save_fn, res_dct)