In [4]:
model_path = "./vg-hubert_3"
wav_file = "/home/ldap-users/Share/Data/librispeech/train-clean-100/19/198/19-198-0037.flac"
tgt_layer = 9
threshold = 0.7

In [5]:
import torch
import soundfile as sf
import os
import pickle
from models import audio_encoder
from itertools import groupby
from operator import itemgetter

def cls_attn_seg(cls_attn_weights, threshold, spf, audio_len_in_sec):

    threshold_value = torch.quantile(cls_attn_weights, threshold, dim=-1, keepdim=True) # [n_h, T]
    boundary_idx = torch.where((cls_attn_weights >= threshold_value).float().sum(0) > 0)[0].cpu().numpy()

    word_boundaries_list = []
    word_boundary_intervals = []
    attn_boundary_intervals = []

    for k, g in groupby(enumerate(boundary_idx), lambda ix : ix[0] - ix[1]):
        seg = list(map(itemgetter(1), g))
        t_s, t_e = seg[0], min(seg[-1]+1, cls_attn_weights.shape[-1])
        if len(seg) > 1:
            attn_boundary_intervals.append([spf*t_s, spf*t_e])

    for left, right in zip(attn_boundary_intervals[:-1], attn_boundary_intervals[1:]):
        word_boundaries_list.append((left[1]+right[0])/2.)
    
    for i in range(len(word_boundaries_list)-1):
        word_boundary_intervals.append([word_boundaries_list[i], word_boundaries_list[i+1]])
    return {"attn_boundary_intervals": attn_boundary_intervals, "word_boundary_intervals": word_boundary_intervals}

# setup model
with open(os.path.join(model_path, "args.pkl"), "rb") as f:
    model_args = pickle.load(f)
model = audio_encoder.AudioEncoder(model_args)
bundle = torch.load(os.path.join(model_path, "best_bundle.pth"))
model.carefully_load_state_dict(bundle['dual_encoder'], load_all=True)
model.eval()
model = model.cuda()

# load waveform (do not layer normalize the waveform!)
audio, sr = sf.read(wav_file, dtype = 'float32')
assert sr == 16000
audio_len_in_sec = len(audio) / sr
audio = torch.from_numpy(audio).unsqueeze(0).cuda() # [T] -> [1, T]

# model forward
with torch.no_grad():
    model_out = model(audio, padding_mask=None, mask=False, need_attention_weights=True, tgt_layer=tgt_layer)
feats = model_out['features'].squeeze(0)[1:] # [1, T+1, D] -> [T, D]
spf = audio.shape[-1]/sr/feats.shape[-2]
attn_weights = model_out['attn_weights'].squeeze(0) # [1, num_heads, T+1, T+1] -> [num_heads, T+1, T+1] (for the two T+1, first is target length then the source)
cls_attn_weights = attn_weights[:, 0, 1:] # [num_heads, T+1, T+1] -> [num_heads, T]
out = cls_attn_seg(cls_attn_weights, threshold, spf, audio_len_in_sec) # out contains attn boundaries and word boundaries in intervals


Ignoring trm.cls_token due to not existing or size mismatch
Ignoring trm.pos_embed due to not existing or size mismatch
Ignoring trm.patch_embed.proj.weight due to not existing or size mismatch
Ignoring trm.patch_embed.proj.bias due to not existing or size mismatch
Ignoring trm.blocks.0.norm1.weight due to not existing or size mismatch
Ignoring trm.blocks.0.norm1.bias due to not existing or size mismatch
Ignoring trm.blocks.0.attn.qkv.weight due to not existing or size mismatch
Ignoring trm.blocks.0.attn.qkv.bias due to not existing or size mismatch
Ignoring trm.blocks.0.attn.proj.weight due to not existing or size mismatch
Ignoring trm.blocks.0.attn.proj.bias due to not existing or size mismatch
Ignoring trm.blocks.0.norm2.weight due to not existing or size mismatch
Ignoring trm.blocks.0.norm2.bias due to not existing or size mismatch
Ignoring trm.blocks.0.mlp.fc1.weight due to not existing or size mismatch
Ignoring trm.blocks.0.mlp.fc1.bias due to not existing or size mismatch
Ignori

In [6]:
print(len(out["attn_boundary_intervals"]), "   ", out["attn_boundary_intervals"])
print(len(out["word_boundary_intervals"]), "   ", out["word_boundary_intervals"])

11     [[0.2009478672985782, 0.36170616113744075], [0.38180094786729857, 0.5425592417061611], [0.6832227488151659, 0.8640758293838863], [0.9042654028436019, 1.0650236966824644], [1.1654976303317535, 1.4870142180094787], [2.4917535545023695, 2.6324170616113745], [2.692701421800948, 3.014218009478673], [3.034312796208531, 3.235260663507109], [3.2754502369668246, 3.3759241706161136], [3.3960189573459716, 3.6974407582938387], [3.757725118483412, 3.9586729857819907]]
9     [[0.3717535545023697, 0.6128909952606635], [0.6128909952606635, 0.8841706161137441], [0.8841706161137441, 1.115260663507109], [1.115260663507109, 1.989383886255924], [1.989383886255924, 2.662559241706161], [2.662559241706161, 3.024265402843602], [3.024265402843602, 3.2553554502369666], [3.2553554502369666, 3.385971563981043], [3.385971563981043, 3.7275829383886254]]
