In [1]:
import os, glob
import soundfile as sf
import numpy as np
import pandas as pd
import h5py
import tqdm
import IPython
import fairseq
import torch
from matplotlib import pyplot as plt
from sklearn.manifold import TSNE
from sklearn import svm
from sklearn import metrics
import seaborn as sns
import utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf

2021-11-24 17:42:46 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


In [12]:
from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model

In [166]:
checkpoint_base = torch.load('./pretrained_checkpoints/wav2vec_small.pt')
wav2vec2 = Wav2Vec2Model.build_model(convert_namespace_to_omegaconf(checkpoint_base['args']).model, task='audio_pretraining')

checkpoint_finetune = torch.load('./pretrained_checkpoints/wav2vec_small_960h.pt')
utils.reset_all_weights(wav2vec2)
for key in checkpoint_finetune['model']:
    if 'w2v_encoder.w2v_model.' == key[:len('w2v_encoder.w2v_model.')]:
        checkpoint_base['model'][key[len('w2v_encoder.w2v_model.'):]] = checkpoint_finetune['model'][key]

wav2vec2.load_state_dict(checkpoint_base['model'])
wav2vec2.eval()

Wav2Vec2Model(
  (feature_extractor): ConvFeatureExtractionModel(
    (conv_layers): ModuleList(
      (0): Sequential(
        (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
        (3): GELU()
      )
      (1): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
      (2): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
      (3): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
      (4): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
      (5

In [155]:
def wav2vec2forward(model, source, aggregation=True, hidden_layer=None):
    """
    Inference function of pretrained wav2vec2 to extract intermediate representations
    Ref: https://github.com/pytorch/fairseq/blob/89ec6e7efff867d258947acafc57189b257212d0/fairseq/models/wav2vec/wav2vec2.py
    """
    with torch.no_grad():
        cnn_features = model.feature_extractor(source)
        
        cnn_features = cnn_features.transpose(1, 2)
        features = model.layer_norm(cnn_features)

        if model.quantizer: # this is not None in pretrained w2v
            q = model.quantizer(features, produce_targets=False)
            quantized_features = q["x"]
            projected_quantized_features = model.project_q(quantized_features)

        if model.post_extract_proj is not None: # this is not None in pretrained w2v
            features = model.post_extract_proj(features)

        if model.input_quantizer is not None: # this is None in pretrained w2v
            q = model.input_quantizer(features, produce_targets=False)
            features = q['x']
            features = model.project_inp(features)
            
        encoder_outputs, encoder_layers_features = model.encoder(features, padding_mask=None, layer=hidden_layer)
            
        context_vectors = model.final_proj(encoder_outputs)
        
        ret = dict()
        ret['cnn_output'] = cnn_features.squeeze(0)
        ret['vq'] = quantized_features.squeeze(0)
        ret['projected_vq'] = projected_quantized_features.squeeze(0)
        ret['encoder_output'] = encoder_outputs.squeeze(0)
        ret['context_vector'] = context_vectors.squeeze(0)
        if len(encoder_layers_features) > 0:
            ret['encoder_hiddens'] = [h[0][0] for h in encoder_layers_features]
        
        if aggregation:
            ret['cnn_output'] = torch.mean(ret['cnn_output'], dim=0)
            ret['vq'] = torch.mean(ret['vq'], dim=0)
            ret['projected_vq'] = torch.mean(ret['projected_vq'], dim=0)
            ret['encoder_output'] = torch.mean(ret['encoder_output'], dim=0)
            ret['context_vector'] = torch.mean(ret['context_vector'], dim=0)
            if len(encoder_layers_features) > 0:
                ret['encoder_hiddens'] = [torch.mean(h, dim=0) for h in ret['encoder_hiddens']]
        
        return ret

In [161]:
df = pd.read_csv('./data/TIMIT_train.csv')

In [134]:
!rm -rf ./outputs/extracted_features/wav2vec2_small/TIMIT_train_word.h5

In [167]:
hf = h5py.File("./outputs/extracted_features/wav2vec2_small-random_init/TIMIT_train_word.h5", 'w')

In [None]:
hf.close()

In [None]:
for i, row in tqdm.tqdm(df.iterrows()):
    wav_id = row['wav_id']
    word_path = row['wav_path'][:-7] + 'WRD'
    with open(word_path) as f:
        words = f.read().strip('\n').split('\n')
    wav, sr = sf.read(row['wav_path'], dtype='float32')
    for j, word in enumerate(words):
        word = word.split(' ')
#         if word[2] not in valid_words:
#             continue
        s = wav[int(word[0]):int(word[1])+1]
        if len(s) < 400:
            s = np.concatenate((s, np.zeros(400-len(s), dtype=np.float32)))
        output = wav2vec2forward(wav2vec2, torch.tensor(s).unsqueeze(0), aggregation=True)

        hf.create_dataset(f"{wav_id}-{word[2]}_{j}-cnn_output", data=output['cnn_output'].cpu())
        hf.create_dataset(f"{wav_id}-{word[2]}_{j}-vq", data=output['vq'].cpu())
        hf.create_dataset(f"{wav_id}-{word[2]}_{j}-projected_vq", data=output['projected_vq'].cpu())
        hf.create_dataset(f"{wav_id}-{word[2]}_{j}-encoder_output", data=output['encoder_output'].cpu())
        hf.create_dataset(f"{wav_id}-{word[2]}_{j}-context_vector", data=output['context_vector'].cpu())

3989it [53:31,  3.48s/it]

In [132]:
for i, row in df.iterrows():
    wav_id = row['wav_id']
    phoneme_path = row['wav_path'][:-7] + 'PHN'
    word_path = row['wav_path'][:-7] + 'WRD'
    with open(phoneme_path) as f:
        phonemes = f.read().strip('\n').split('\n')
    with open(word_path) as f:
        words = f.read().strip('\n').split('\n')
    wav, sr = sf.read(row['wav_path'], dtype='float32')
    word_idx = -1
    word_end_pos = -1
    if phonemes[0].split(' ')[2] != 'h#':
        print(phonemes[0].split(' ')[2])
    if phonemes[-1].split(' ')[2] != 'h#':
        print(phonemes[-1].split(' ')[2])
    for j, p in enumerate(phonemes):
        p = p.split(' ')
        if p[2] == 'h#':
            continue
        try:
            if int(p[0]) >= word_end_pos:
                word_idx += 1
                word_end_pos = int(words[word_idx].split(' ')[1])
                current_word = words[word_idx].split(' ')[2]
        except:
            continue
        s = wav[int(p[0]):int(p[1])+1]
        if len(s) < 400:
            s = np.concatenate((s, np.zeros(400-len(s), dtype=np.float32)))
        output = wav2vec2forward(wav2vec2, torch.tensor(s).unsqueeze(0), aggregation=True)

        hf.create_dataset(f"{wav_id}-{current_word}_{word_idx}-{p[2]}_{j}-cnn_output", data=output['cnn_output'].cpu())
        hf.create_dataset(f"{wav_id}-{current_word}_{word_idx}-{p[2]}_{j}-vq", data=output['vq'].cpu())
        hf.create_dataset(f"{wav_id}-{current_word}_{word_idx}-{p[2]}_{j}-projected_vq", data=output['projected_vq'].cpu())
        hf.create_dataset(f"{wav_id}-{current_word}_{word_idx}-{p[2]}_{j}-encoder_output", data=output['encoder_output'].cpu())
        hf.create_dataset(f"{wav_id}-{current_word}_{word_idx}-{p[2]}_{j}-context_vector", data=output['context_vector'].cpu())

KeyboardInterrupt: 

In [114]:
from collections import defaultdict
words_list = defaultdict(int)
for i, row in df.iterrows():
    wav_id = row['wav_id']
    word_path = row['wav_path'][:-7] + 'WRD'
    with open(word_path) as f:
        words = f.read().strip('\n').split('\n')
    for w in words:
        words_list[w.split(' ')[2]] += 1

In [111]:
valid_word_list = set()
for w in words_list:
    if words_list[w] >= 7:
        valid_word_list.add(w)

In [120]:
valid_word_list_test = set()
for w in words_list:
    if words_list[w] >= 7:
        valid_word_list_test.add(w)

In [139]:
valid_words = valid_word_list_test.intersection(valid_word_list)

In [152]:
import pickle
with open("data/TIMIT/valid_words.pkl", 'wb') as f:
    pickle.dump(valid_words, f)

In [105]:
{k: v for k, v in sorted(words_list.items(), key=lambda item: item[1], reverse=True)}

{'the': 1603,
 'to': 1018,
 'in': 947,
 'a': 867,
 'that': 612,
 'she': 572,
 'an': 571,
 'your': 565,
 'all': 545,
 'had': 526,
 'like': 518,
 'me': 517,
 'and': 492,
 "don't": 488,
 'water': 479,
 'dark': 473,
 'year': 473,
 'oily': 470,
 'rag': 470,
 'wash': 469,
 'ask': 464,
 'carry': 463,
 'suit': 462,
 'greasy': 462,
 'of': 455,
 'is': 401,
 'you': 274,
 'are': 238,
 'was': 236,
 'he': 233,
 'for': 216,
 'his': 190,
 'with': 188,
 'be': 176,
 'it': 171,
 'on': 167,
 'we': 154,
 'this': 152,
 'they': 141,
 'by': 130,
 'her': 127,
 'from': 125,
 'as': 125,
 'have': 124,
 'not': 119,
 'but': 103,
 'will': 100,
 'i': 99,
 'do': 93,
 'him': 84,
 'my': 83,
 'or': 78,
 'no': 76,
 'were': 75,
 'at': 74,
 'can': 73,
 'new': 68,
 'up': 68,
 'would': 67,
 'every': 65,
 'now': 60,
 'our': 60,
 'how': 59,
 'each': 59,
 'their': 58,
 'big': 57,
 'so': 57,
 'often': 56,
 'may': 55,
 'has': 55,
 'out': 55,
 'never': 54,
 'many': 51,
 'into': 51,
 'if': 50,
 'saw': 50,
 'get': 50,
 'one': 49,
 'm

In [89]:
!cat ./data/data/TRAIN/DR1/MRAI0/SX72.WRD

2210 9400 spring
9400 16887 street
16887 20386 is
20386 24680 straight
24680 31040 ahead


In [93]:
s, sr = sf.read("./data/data/TRAIN/DR1/MRAI0/SX72.WAV.wav", dtype='float32')
IPython.display.Audio(data=s[31040 :31645    ], rate=sr)

In [43]:
def wav2vec2featurize(model, source, feature_name):
    """
    Inference function of pretrained wav2vec2 to extract intermediate representations
    Ref: https://github.com/pytorch/fairseq/blob/89ec6e7efff867d258947acafc57189b257212d0/fairseq/models/wav2vec/wav2vec2.py
    """
    assert feature_name in ['cnn_output', 'vq', 'projected_vq', 'encoder_output', 'context_vector']
    cnn_features = model.feature_extractor(source)

    cnn_features = cnn_features.transpose(1, 2)
    
    if feature_name == 'cnn_output':
        return cnn_features.squeeze(0)
    
    features = model.layer_norm(cnn_features)

    if model.quantizer: # this is not None in pretrained w2v
        q = model.quantizer(features, produce_targets=False)
        quantized_features = q["x"]
        if feature_name == 'vq':
            return quantized_features.squeeze(0)
        projected_quantized_features = model.project_q(quantized_features)
        if feature_name == 'projected_vq':
            return projected_quantized_features.squeeze(0)

    if model.post_extract_proj is not None: # this is not None in pretrained w2v
        features = model.post_extract_proj(features)

    if model.input_quantizer is not None: # this is None in pretrained w2v
        q = model.input_quantizer(features, produce_targets=False)
        features = q['x']
        features = model.project_inp(features)

    encoder_outputs, encoder_layers_features = model.encoder(features, padding_mask=None, layer=None)
    
    if feature_name == 'encoder_output':
        return encoder_outputs.squeeze(0)

    context_vectors = model.final_proj(encoder_outputs)

    return context_vectors.squeeze(0)

In [46]:
wav2vec2featurize(wav2vec2, torch.tensor(s[39561:40313]).unsqueeze(0), 'encoder_output')

tensor([[ 0.0458,  1.0363, -0.3802,  ..., -1.2763,  0.5118, -0.3924],
        [-0.7767, -0.6042, -0.1070,  ..., -0.0194, -0.7107, -1.8827]],
       grad_fn=<SqueezeBackward1>)

In [45]:
s, sr = sf.read("./data/data/TRAIN/DR1/FCJF0/SA1.WAV.wav", dtype='float32')
IPython.display.Audio(data=s[4559:5723], rate=sr)

In [72]:
wav2vec2

Wav2Vec2Model(
  (feature_extractor): ConvFeatureExtractionModel(
    (conv_layers): ModuleList(
      (0): Sequential(
        (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
        (3): GELU()
      )
      (1): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
      (2): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
      (3): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
      (4): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
      (5