In [43]:
from pathlib import Path
script_dir = str(Path('../../scripts').absolute())
import numpy as np

import sys
import pathlib
sys.path.insert(0, script_dir)
from va_asr import inference_utils
from functools import partial
import torch
from pprint import pprint

from sklearn.metrics.pairwise import cosine_similarity

from IPython import display

WAV2VEC_CHECKPOINT_PATH = "/media/xtrem/data/experiments/nicolingua-0003-wa-wav2vec/wav2vec-training-exp-01/checkpoints/checkpoint_best.pt"
VAASR_CHECKPOINT_PATH = "/media/xtrem/code/git/nicolingua/notebooks/E401/results_401/checkpoints/VAASRCNN3PoolAvgAggMax/retrained-wav2vec_features-c_0_checkpoints/0800.pt"
MAX_SEQUENCE_LENGTH = 300
AUDIO_BASE_DIR = '/media/xtrem/data/experiments/nicolingua-0004-va-asr/datasets/gn_va_asr_dataset_2020-09-04_01/annotated_segments/'

# Load items (audio files)

In [9]:
def list_audio_files():
    selected_langs = set(["maninka"])
    selected_class_ids = set(list(range(301, 310+1)))
    audio_files_per_class = {}
    for ix, path in enumerate(Path(AUDIO_BASE_DIR).iterdir()):
        if path.suffix != '.wav': continue
        lang = path.stem.split("__")[0].split("_")[-1]
        class_name = path.stem.split("__")[-1]
         
        class_id = int(class_name.split("_")[0])
        if lang not in selected_langs: continue
        if class_id not in selected_class_ids: continue
            
        
        fq_class_name = f"{lang}_{class_id}"
        if fq_class_name not in audio_files_per_class:
            audio_files_per_class[fq_class_name] = []
        audio_files_per_class[fq_class_name].append(path)
    
    
    
    rs = np.random.RandomState(seed=42)
    for item_class in sorted(audio_files_per_class.keys()):
        samples = rs.choice(audio_files_per_class[item_class], 10, replace=False)
        for s in samples:
            yield {
                'class_name': item_class,
                'path': s
            }vaasr

SyntaxError: invalid syntax (<ipython-input-9-1810b7e8e683>, line 29)

In [10]:
all_items = list(list_audio_files())
all_item_waves = [inference_utils.load_audio_from_file(item['path']) for item in all_items]

# compute features (VAASR features)

In [11]:
class HookedModel:
    def __init__(self, model, to_numpy=True):
        self._model = model
        self._to_numpy = to_numpy
        self._backward_hooks_registered = False
        self.gradients_dict = {}

    @staticmethod
    def forward_hook(feature_dict, module_name, to_numpy, module, input, output):
        if to_numpy:
            feature_dict[module_name] = output.detach().cpu().numpy()
        else:
            feature_dict[module_name] = output

    @staticmethod
    def backward_hook(gradient_dict, module_name, to_numpy, module, input, output):
        if to_numpy:
            gradient_dict[module_name] = output.detach().cpu().numpy()
        else:
            gradient_dict[module_name] = output

    def __call__(self, x):
        handles = []
        features_dict = {}
        

        for name, module in self._model.named_modules():
            handles.append(
                module.register_forward_hook(partial(HookedModel.forward_hook, features_dict, name, self._to_numpy))
            )
            if not self._backward_hooks_registered:
                module.register_backward_hook(partial(HookedModel.backward_hook, self.gradients_dict, name, self._to_numpy))
        self._backward_hooks_registered = True

        if self._to_numpy:
            features_dict['output'] = self._model(x).detach().cpu().numpy()
        else:
            features_dict['output'] = self._model(x)

        for handle in handles:
            handle.remove()

        return features_dict

In [12]:
wav2vec_model = inference_utils.load_wav2vec(WAV2VEC_CHECKPOINT_PATH)
va_asr_model = inference_utils.load_vaasr_model(VAASR_CHECKPOINT_PATH)
va_asr_model = HookedModel(va_asr_model)

In [23]:
def get_va_asr_output(wav2vec_model, va_asr_model, wav, max_sequence_length):
    c = inference_utils.get_wav2vec_context_feature(wav2vec_model, wav, max_sequence_length)
    all_features = va_asr_model(c)
    return all_features['lin61']

In [31]:
all_item_features = []
for ix, wav in enumerate(all_item_waves):
    feature = get_va_asr_output(wav2vec_model, va_asr_model, wav, MAX_SEQUENCE_LENGTH)
    all_item_features.append(feature[0])
all_item_features = np.array(all_item_features)

# Compute all cosine similarities

In [38]:
all_cos_similarities = cosine_similarity(all_item_features)

# Example Analogies

In [58]:
random_probe_ix = np.random.randint(0, all_cos_similarities.shape[0])
sorted_similar_ixes = np.argsort(all_cos_similarities[random_probe_ix])
max_analogies = 5
for i in sorted_similar_ixes:
    if all_items[i]['class_name'] == all_items[random_probe_ix]['class_name']:
        similarity = all_cos_similarities[random_probe_ix, i]
        print(f"cos similarity: {similarity}")
        print(all_items[random_probe_ix]['path'].stem)
        display.display(display.Audio(all_item_waves[random_probe_ix], rate=16000))
        
        print(all_items[i]['path'].stem)
        display.display(display.Audio(all_item_waves[i], rate=16000))
        max_analogies -= 1
        if max_analogies == 0:
            break
        



cos similarity: 0.6201631426811218
r084_s032_d002_maninka__hawa_camara__309_eight


r118_s043_d004_maninka__tenin_keita__309_eight


cos similarity: 0.6693811416625977
r084_s032_d002_maninka__hawa_camara__309_eight


r077_s029_d002_maninka__mamadou_sow__309_eight


cos similarity: 0.6771317720413208
r084_s032_d002_maninka__hawa_camara__309_eight


r009_s006_d003_maninka__moussa_camara__309_eight


cos similarity: 0.7447845935821533
r084_s032_d002_maninka__hawa_camara__309_eight


r051_s022_d001_maninka__hadja_doumbouya__309_eight


cos similarity: 0.7479783296585083
r084_s032_d002_maninka__hawa_camara__309_eight


r121_s044_d007_maninka__sere_moussa_doumdouya__309_eight


# Example Contrasts

In [66]:
random_probe_ix = np.random.randint(0, all_cos_similarities.shape[0])
sorted_similar_ixes = np.argsort(all_cos_similarities[random_probe_ix])
max_contrasts = 5
for i in sorted_similar_ixes[::-1]:
    if all_items[i]['class_name'] != all_items[random_probe_ix]['class_name']:
        similarity = all_cos_similarities[random_probe_ix, i]
        print(f"cos similarity: {similarity}")
        print(all_items[random_probe_ix]['path'].stem)
        display.display(display.Audio(all_item_waves[random_probe_ix], rate=16000))
        
        print(all_items[i]['path'].stem)
        display.display(display.Audio(all_items[i]['path'], rate=16000))
        max_contrasts -= 1
        if max_contrasts == 0:
            break
        

cos similarity: 0.8001948595046997
r106_s039_d005_maninka__moussa_camara__307_six


r004_s003_d001_maninka__kodoba_camara__306_five


cos similarity: 0.7740617990493774
r106_s039_d005_maninka__moussa_camara__307_six


r009_s006_d001_maninka__moussa_camara__308_seven


cos similarity: 0.7503194808959961
r106_s039_d005_maninka__moussa_camara__307_six


r106_s039_d005_maninka__moussa_camara__310_nine


cos similarity: 0.721650242805481
r106_s039_d005_maninka__moussa_camara__307_six


r074_s028_d001_maninka__bilguissa_barry__308_seven


cos similarity: 0.7105814218521118
r106_s039_d005_maninka__moussa_camara__307_six


r057_s024_d002_maninka__adama_keita__305_four
