In [22]:
import os, glob
import soundfile as sf
import numpy as np
import pandas as pd
import h5py
import tqdm
import IPython
import fairseq
import torch
from matplotlib import pyplot as plt
from sklearn.manifold import TSNE
from sklearn import svm
from sklearn import metrics
import seaborn as sns
import utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.modules import Fp32LayerNorm,  TransposeLast
import pytorch_lightning as pl
import torch.nn as nn

In [23]:
class Feat2Feat(pl.LightningModule):
    """Residual MLP mapping one feature vector per frame to another.

    Three square ``in_channel -> in_channel`` linear layers are applied in turn,
    each as ``relu(layer(x) + x)``. A final linear projection
    ``in_channel -> out_feature`` follows.

    Args:
        in_channel: size of the last dimension of the input (also the hidden width).
        out_feature: size of the last dimension of the output.
    """

    def __init__(self, in_channel=256, out_feature=128):
        super().__init__()
        # Hidden width equals the input width so the residual add in `forward` is shape-compatible.
        width = in_channel
        # Attribute names below define the state-dict keys; they must match saved checkpoints.
        self.linear_layers = nn.ModuleList(
            [nn.Linear(in_channel, width)]
            + [nn.Linear(width, width) for _ in range(2)]
        )
        self.final_linear = nn.Linear(width, out_feature)
        self.activation_func = nn.ReLU()

    def forward(self, x):
        """Run the residual blocks, then project to ``out_feature`` along the last dimension."""
        for block in self.linear_layers:
            x = self.activation_func(block(x) + x)
        return self.final_linear(x)

In [78]:
# Layer sizes must match the checkpoint loaded below (512-wide input, 768-wide output); requires a CUDA device.
model = Feat2Feat(in_channel=512, out_feature=768)
model = model.cuda()

In [85]:
# Trained Feat2Feat checkpoint; its weights are read from checkpoint['state_dict'] in the next cell.
# This load must stay active: with it commented out, a fresh kernel fails with NameError there.
# map_location="cpu" avoids materialising a second copy of the weights on the GPU;
# load_state_dict copies them into the model's (CUDA) parameters anyway.
checkpoint = torch.load(
    "/mnt/scratch09/vnguyen/SpeakerRecognition/exp/tmp/checkpoints/epoch=36-step=5364.ckpt",
    map_location="cpu",
)
# Pre-extracted wav2vec2_small features for the TIMIT test set, opened read-only.
# Kept open on purpose: later cells read from `hf`.
hf = h5py.File("/mnt/scratch09/vnguyen/SpeakerRecognition/outputs/extracted_features/wav2vec2_small/TIMIT_test.h5", 'r')
# Utterance table; its 'wav_id' column is used to build the HDF5 dataset keys.
df = pd.read_csv("/mnt/scratch09/vnguyen/SpeakerRecognition/data/TIMIT_test.csv")

In [80]:
# Copy the trained weights into the model. strict=True (the default) raises on any key mismatch,
# so the displayed result confirms every parameter was found.
model.load_state_dict(state_dict=checkpoint["state_dict"], strict=True)

<All keys matched successfully>

In [42]:
# One-off housekeeping: rename cnn-encoder-error.h5 to cnn_encoder_error_train.h5.
# Guarded so that re-running the notebook after the rename has already happened is a no-op
# instead of a shell error (the original `!mv` failed on every re-run).
_feat2feat_dir = "/mnt/scratch09/vnguyen/SpeakerRecognition/outputs/feat2feat/wav2vec2_small"
_old_train_path = os.path.join(_feat2feat_dir, "cnn-encoder-error.h5")
_new_train_path = os.path.join(_feat2feat_dir, "cnn_encoder_error_train.h5")
if os.path.exists(_old_train_path):
    # Same directory, so an atomic rename suffices; overwrites the destination like `mv`.
    os.replace(_old_train_path, _new_train_path)

In [82]:
# Delete any stale test-split error file before it is recreated.
# Done in Python rather than `!rm -rf`: it is a no-op when the file is absent and only ever
# removes a regular file, never a directory tree.
_stale_test_path = "/mnt/scratch09/vnguyen/SpeakerRecognition/outputs/feat2feat/wav2vec2_small/cnn_encoder_error_test.h5"
if os.path.isfile(_stale_test_path):
    os.remove(_stale_test_path)

In [86]:
# Output file for per-utterance prediction errors on the test set.
# Mode "w" creates the file, truncating it if it exists. It is closed by the loop cell below.
hout = h5py.File(
    "/mnt/scratch09/vnguyen/SpeakerRecognition/outputs/feat2feat/wav2vec2_small/cnn_encoder_error_test.h5",
    mode="w",
)

In [87]:
import torchaudio

# For every utterance in `df`: predict the `feat_out` features from the `feat_in` features,
# then store the prediction error (prediction - target) in `hout` under the target's key.
feat_in = 'cnn_output'
feat_out = 'encoder_output'

model.eval()  # inference only
# Send inputs to wherever the model's parameters live instead of hard-coding `.cuda()`.
device = next(model.parameters()).device

try:
    for _, row in tqdm.tqdm(df.iterrows(), total=len(df), desc='feat2feat error'):
        key_in = row['wav_id'] + '-' + feat_in
        key_out = row['wav_id'] + '-' + feat_out
        features = hf[key_in][:]
        target = hf[key_out][:]
        with torch.no_grad():
            pred = model(torch.from_numpy(features).to(device))
        pred_np = pred.cpu().numpy()
        # Fail loudly rather than let NumPy broadcasting write a wrongly-shaped error array.
        if pred_np.shape != target.shape:
            raise ValueError(
                f"{key_out}: prediction shape {pred_np.shape} does not match target shape {target.shape}"
            )
        diff = pred_np - target
        hout.create_dataset(name=key_out, data=diff)
finally:
    # Close the output file even if an utterance fails, so the handle is not leaked in the kernel.
    hout.close()

In [71]:
# Inspect the input features (`cnn_output`) of the last utterance processed by the loop above.
# NOTE(review): this cell's execution count (71) is lower than the loop's (87), so the output
# below predates the current loop run and may not match a fresh Restart & Run All.
features

array([[ 0.18760265,  0.19597238,  0.4132182 , ...,  0.29623666,
         0.2521737 , -0.29260963],
       [-0.06108914,  0.02278471, -0.00902695, ...,  0.2952789 ,
         0.6248071 , -0.2813234 ],
       [-0.05911512,  0.04213655,  0.02152432, ...,  0.29912993,
         0.6175385 , -0.28398037],
       ...,
       [ 0.0036359 , -0.00586716,  0.21506296, ...,  0.19844064,
         0.60485405, -0.11605246],
       [-0.04414704,  0.00434542,  0.01885987, ...,  0.29016116,
         0.6359349 , -0.27952406],
       [ 0.32570258,  0.09383832,  0.30810487, ...,  0.35674322,
         0.2538185 , -0.14667474]], dtype=float32)

In [72]:
# Inspect the target (`encoder_output`) array for the last `key_out` visited by the loop above.
# NOTE(review): the output below was produced at execution count 72, before `hf` (count 85) and
# `key_out` (count 87) were last reassigned, so it may be stale.
hf[key_out][:]

array([[ 0.00250155, -0.00966159, -0.00612373, ...,  0.00853509,
        -0.00382528, -0.02455596],
       [-0.010848  , -0.01951417,  0.00089338, ...,  0.00656117,
         0.00143483, -0.00916419],
       [-0.00299092, -0.00057096, -0.00745986, ...,  0.00927412,
        -0.01053994, -0.03214316],
       ...,
       [ 0.00423075, -0.01344961,  0.0057536 , ..., -0.00133647,
         0.00087205, -0.01811837],
       [-0.00959727, -0.0059984 ,  0.00649665, ...,  0.00246625,
        -0.00193869, -0.01098885],
       [ 0.00972552,  0.0162667 , -0.00122207, ..., -0.01982652,
        -0.0115924 , -0.01359041]], dtype=float32)

In [33]:
# NOTE(review): duplicates the `hf[key_out][:]` inspection cell above. Its output (execution
# count 33) is older still and differs from that cell's, so at most one of them reflects the
# current state. Consider deleting one.
hf[key_out][:]

array([[ 0.2550169 ,  0.07193729,  0.43718314, ...,  0.5358205 ,
         0.28440067, -0.21044958],
       [ 0.07654906,  0.13120219,  0.44204962, ...,  0.28140178,
         0.2342612 , -0.19971497],
       [ 0.18318653,  0.19171728,  0.55717504, ..., -0.01525378,
         0.2879535 ,  0.20958944],
       ...,
       [ 0.18395099,  0.19168429,  0.5584028 , ..., -0.01493351,
         0.2884831 ,  0.20827399],
       [-0.08312055,  0.14251192,  0.15681204, ...,  0.12978798,
         0.61423016, -0.42786285],
       [ 0.1359564 ,  0.06289283,  0.2886348 , ...,  0.4319388 ,
         0.16654636, -0.08482879]], dtype=float32)