In [9]:
import os, glob
import soundfile as sf
import numpy as np
import pandas as pd
import h5py
import tqdm
import IPython
import fairseq
import torch
from matplotlib import pyplot as plt
from sklearn.manifold import TSNE
from sklearn import svm
from sklearn import metrics
import seaborn as sns
import utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf

In [12]:
from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model

In [19]:
def wav2vec2forward(model, source, aggregation=True, hidden_layer=None):
    """
    Inference function of pretrained wav2vec2 to extract intermediate representations
    Ref: https://github.com/pytorch/fairseq/blob/89ec6e7efff867d258947acafc57189b257212d0/fairseq/models/wav2vec/wav2vec2.py

    Args:
        model: object exposing ``feature_extractor``, ``layer_norm``,
            ``quantizer``/``project_q``, ``post_extract_proj``,
            ``input_quantizer``/``project_inp``, ``encoder`` and ``final_proj``.
        source: input handed to ``model.feature_extractor``. Dim 0 of every
            output is removed with ``squeeze(0)``, so this presumably expects a
            batch of one -- TODO confirm against callers.
        aggregation: when True, every returned tensor is averaged over dim 0
            (after the squeeze).
        hidden_layer: forwarded to ``model.encoder`` as ``layer=``.

    Returns:
        dict with ``cnn_output``, ``encoder_output`` and ``context_vector``;
        ``vq`` and ``projected_vq`` are present only when ``model.quantizer``
        is not None; ``encoder_hiddens`` is present only when the encoder
        returns a non-empty list of per-layer features.
    """
    with torch.no_grad():
        cnn_features = model.feature_extractor(source).transpose(1, 2)
        features = model.layer_norm(cnn_features)

        ret = {'cnn_output': cnn_features.squeeze(0)}

        # The quantizer outputs used to be read unconditionally further down,
        # which raised UnboundLocalError for models built without a quantizer.
        # They are now only emitted when the quantizer exists.
        if model.quantizer is not None:  # this is not None in pretrained w2v
            quantized_features = model.quantizer(features, produce_targets=False)["x"]
            ret['vq'] = quantized_features.squeeze(0)
            ret['projected_vq'] = model.project_q(quantized_features).squeeze(0)

        if model.post_extract_proj is not None:  # this is not None in pretrained w2v
            features = model.post_extract_proj(features)

        if model.input_quantizer is not None:  # this is None in pretrained w2v
            features = model.input_quantizer(features, produce_targets=False)['x']
            features = model.project_inp(features)

        encoder_outputs, encoder_layers_features = model.encoder(
            features, padding_mask=None, layer=hidden_layer
        )
        context_vectors = model.final_proj(encoder_outputs)

        ret['encoder_output'] = encoder_outputs.squeeze(0)
        ret['context_vector'] = context_vectors.squeeze(0)
        if len(encoder_layers_features) > 0:
            # NOTE(review): h[0][0] takes the first element of each layer entry
            # and then index 0 along its leading axis. If the encoder returns
            # time-major tensors this selects the first frame rather than the
            # first batch item -- confirm against the fairseq version in use.
            ret['encoder_hiddens'] = [h[0][0] for h in encoder_layers_features]

        if aggregation:
            # Average every output over its leading axis; the per-layer list is
            # averaged element by element.
            ret = {
                name: (
                    [torch.mean(h, dim=0) for h in value]
                    if name == 'encoder_hiddens'
                    else torch.mean(value, dim=0)
                )
                for name, value in ret.items()
            }

        return ret

In [20]:
def load_wav2vec2(mode='base', evaluate=True):
    """
    Build a Wav2Vec2Model from the base checkpoint's config and initialise its weights.

    Args:
        mode: ``'random'`` calls ``utils.reset_all_weights`` on the freshly
            built model and loads no checkpoint weights. ``'finetune'``
            overlays the entries of the 960h checkpoint whose key starts with
            ``w2v_encoder.w2v_model.`` (prefix stripped) onto the base state
            dict before loading it. Any other value loads the base state dict
            unchanged.
        evaluate: when True, the model is switched to eval mode.

    Returns:
        The model, moved to CUDA.
    """
    # Deserialize on CPU so the checkpoint tensors are not restored onto the
    # device they were saved from; the model itself is moved to CUDA below.
    checkpoint_base = torch.load('./pretrained_checkpoints/wav2vec_small.pt', map_location='cpu')
    wav2vec2 = Wav2Vec2Model.build_model(
        convert_namespace_to_omegaconf(checkpoint_base['args']).model,
        task='audio_pretraining',
    )

    if mode == 'random':
        utils.reset_all_weights(wav2vec2)
    else:
        if mode == 'finetune':
            prefix = 'w2v_encoder.w2v_model.'
            checkpoint_finetune = torch.load(
                './pretrained_checkpoints/wav2vec_small_960h.pt', map_location='cpu'
            )
            for key, value in checkpoint_finetune['model'].items():
                if key.startswith(prefix):
                    checkpoint_base['model'][key[len(prefix):]] = value
        wav2vec2.load_state_dict(checkpoint_base['model'])

    wav2vec2.cuda()
    if evaluate:
        wav2vec2.eval()
    return wav2vec2

In [21]:
# Load the model in 'base' mode; with the default evaluate=True, load_wav2vec2
# returns it on CUDA and in eval mode.
wav2vec2 = load_wav2vec2('base')

In [55]:
# Exploratory: display the quantizer's codebook indices.
wav2vec2.quantizer.get_codebook_indices()

tensor([  0, 320,   0,  ..., 638, 319, 639], device='cuda:0')

In [30]:
# Utterance table for the extraction loop, which reads its 'wav_id' and
# 'wav_path' columns. NOTE: the name `df` is reassigned to a different table
# further down the notebook.
df = pd.read_csv('./data/LibriSpeech_dev_clean.csv')

In [10]:
!mkdir ./outputs/extracted_features/wav2vec2_small-finetune_speaker_identification

In [31]:
# Output file for the extracted features. Mode 'w' truncates an existing file.
# The handle stays open until hout.close() is called in a later cell.
hout = h5py.File("./outputs/extracted_features/wav2vec2_small/LibriSpeech_dev_3s.h5", 'w')

In [32]:
# Every utterance is cut to this many samples; shorter files are skipped.
# (Presumably 3 s at 16 kHz, given the "_3s" output name -- TODO confirm.)
NUM_SAMPLES = 48000
FEATURE_NAMES = ['cnn_output', 'vq', 'projected_vq', 'encoder_output', 'context_vector']

# iterrows() is a generator, so tqdm needs an explicit total to show progress.
for i, row in tqdm.tqdm(df.iterrows(), total=len(df)):
    wav_id = row['wav_id']
    s, sr = sf.read(row['wav_path'], dtype='float32', stop=NUM_SAMPLES)

    if len(s) < NUM_SAMPLES:
        continue

    # aggregation=False keeps the un-averaged outputs.
    output = wav2vec2forward(wav2vec2, torch.tensor(s).unsqueeze(0).cuda(), aggregation=False)

    # One dataset per (utterance, feature) pair, keyed "<wav_id>-<feature_name>".
    for feature_name in FEATURE_NAMES:
        hout.create_dataset(f"{wav_id}-{feature_name}", data=output[feature_name].cpu().numpy())

2703it [01:13, 36.94it/s]


In [33]:
# Close the HDF5 output file opened for writing in an earlier cell.
hout.close()

In [74]:
# One-off rename from the "_average" suffix to "_averaged". Not re-runnable:
# mv fails once the source file no longer exists.
!mv ./outputs/extracted_features/wav2vec2_small/TIMIT_test_average.h5 ./outputs/extracted_features/wav2vec2_small/TIMIT_test_averaged.h5

In [73]:
# List the feature files currently present for this model.
!ls ./outputs/extracted_features/wav2vec2_small

LibriSpeech_devclean_averaged.h5      TIMIT_test_word.h5
TIMIT_test.h5			      TIMIT_train_phoneme-cnn_output.h5
TIMIT_test_average.h5		      TIMIT_train_phoneme-context_vector.h5
TIMIT_test_phoneme-cnn_output.h5      TIMIT_train_phoneme-encoder_output.h5
TIMIT_test_phoneme-context_vector.h5  TIMIT_train_phoneme-projected_vq.h5
TIMIT_test_phoneme-encoder_output.h5  TIMIT_train_phoneme-vq.h5
TIMIT_test_phoneme-projected_vq.h5    TIMIT_train_word.h5
TIMIT_test_phoneme-vq.h5	      VCTK_averaged.h5
TIMIT_test_phoneme.h5


In [46]:
# Output file for per-key averaged features. Mode 'w' truncates an existing
# file. Rebinds `hout`, so any previously opened handle must already be closed.
hout = h5py.File("./outputs/extracted_features/wav2vec2_small/TIMIT_train_average.h5", 'w')

In [47]:
# Copy every dataset from `hin` into `hout`, replaced by its mean over axis 0.
# NOTE(review): `hin` is not defined anywhere in the visible notebook; it is
# presumably an h5py.File opened for reading in a cell that was deleted --
# confirm and restore that cell so the notebook runs from a fresh kernel.
for key in hin:
    values = hin[key][:]
    hout.create_dataset(key, data=values.mean(axis=0))

In [48]:
# Close both HDF5 handles used by the averaging loop.
hout.close()
hin.close()

In [68]:
# Destructive clean-up: removes two feature files without prompting.
# NOTE(review): the first path is spelled "IMIT_train.h5" -- confirm whether
# that is a deliberately removed mis-named file or a typo for "TIMIT_train.h5".
!rm -rf ./outputs/extracted_features/wav2vec2_small/IMIT_train.h5 ./outputs/extracted_features/wav2vec2_small/TIMIT_train_average.h5

In [86]:
# Sanity check on the first row of `df`: read the whole file (no `stop=`
# truncation) and run it through the model without averaging.
row = df.iloc[0]
wav_id = row['wav_id']
s, sr = sf.read(row['wav_path'], dtype='float32')
output = wav2vec2forward(wav2vec2, torch.tensor(s).unsqueeze(0).cuda(), aggregation=False)

In [60]:
# Open the TIMIT test feature file read-only for inspection.
# NOTE: no matching h5.close() appears in the visible cells.
h5 = h5py.File("./outputs/extracted_features/wav2vec2_small/TIMIT_test.h5", 'r')

In [None]:
# Exploratory: show the dataset names stored in the file.
h5.keys()

In [44]:
# Rebinds `df` to the TIMIT test table; cells below this one see this table,
# not the LibriSpeech one loaded earlier.
df = pd.read_csv("./data/TIMIT_test.csv")

In [45]:
# Duration in seconds (frames / sample rate) of every file in df['wav_path'].
# sf.info reads the file's metadata only, so the waveforms are not decoded.
lengths = [sf.info(wav_path).duration for wav_path in df['wav_path']]

In [46]:
# Attach the per-file durations as a new column (mutates `df` in place).
df['length'] = lengths

In [47]:
# Persist the table, including the added 'length' column, without the index.
# NOTE: this overwrites ./data/TIMIT_test.csv on disk.
df.to_csv("./data/TIMIT_test.csv", index=False)

In [38]:
# Play back whichever waveform `s` / sample rate `sr` were assigned most
# recently in the kernel. IPython is already imported in the first cell; the
# import is repeated here.
import IPython
IPython.display.display(IPython.display.Audio(data=s, rate=sr))