In [None]:
from datasets.speaker_audio_dataset import SpeakerAudioDataset
from model.layers.lstmp import LSTMPCell

In [34]:
# Audio front-end settings. The base window/hop sizes (1024 / 256 / 1024) are
# scaled by sample_rate / 16000 and truncated to whole samples by int().
sample_rate = 22050
mel_params = dict(
    n_fft=int(1024 * (sample_rate / 16000)),
    hop_length=int(256 * (sample_rate / 16000)),
    win_length=int(1024 * (sample_rate / 16000)),
    n_mels=80,
)

# Constructor keyword arguments for the encoder; input_size equals n_mels above.
# NOTE(review): hidden_size is exactly projection_size + 1 -- presumably the
# smallest value nn.LSTM accepts (proj_size must be < hidden_size); confirm
# this is intentional rather than a placeholder for a larger cell size.
model_params = dict(
    input_size=80,
    hidden_size=257,
    projection_size=256,
    embedding_size=256,
    num_layers=3,
)

In [3]:
# Build the dataset from the LibriTTS dev-clean folder, using the sample rate
# and mel settings configured in the cell above.
dataset = SpeakerAudioDataset(
    '../data/utterance_corpuses/LibriTTS/dev-clean',
    sample_rate,
    mel_params,
)

# Pull the first item as a quick sanity check; it unpacks into two values.
# NOTE(review): the names suggest (target, input) order -- confirm against
# SpeakerAudioDataset.__getitem__.
(test_Y, test_X) = dataset[0]

In [59]:
import torch
from torch import nn
from torch.nn import functional as func

class SpeakerVerificationLSTMEncoder(nn.Module):
    """Encode a sequence of feature frames into one L2-normalised embedding.

    Pipeline: stacked ``nn.LSTM`` with projection -> final hidden state of the
    last layer -> ``nn.Linear`` -> ReLU -> L2 normalisation.

    Args:
        input_size: Number of features per input frame.
        hidden_size: LSTM cell size. ``nn.LSTM`` requires this to be strictly
            larger than ``projection_size``.
        projection_size: Size of the projected LSTM hidden state
            (``proj_size`` of ``nn.LSTM``).
        embedding_size: Size of the returned embedding.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 projection_size,
                 embedding_size,
                 num_layers
                ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.projection_size = projection_size
        self.embedding_size = embedding_size
        self.num_layers = num_layers

        # batch_first=True: batched inputs are laid out as (batch, seq, feature).
        self.lstm = nn.LSTM(
            self.input_size,
            self.hidden_size,
            self.num_layers,
            proj_size=self.projection_size,
            batch_first=True
        )

        # Maps the projected hidden state to the embedding space.
        self.linear = nn.Linear(
            in_features=self.projection_size,
            out_features=self.embedding_size
        )

        self.relu = nn.ReLU()

    def forward(self, x):
        """Compute the embedding for ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``, e.g.
                ``(64, 636, 80)``, or an unbatched ``(seq_len, input_size)``
                sequence.

        Returns:
            Tensor of shape ``(batch, embedding_size)`` (or
            ``(embedding_size,)`` for unbatched input). Components are
            non-negative because of the ReLU, and each embedding is scaled to
            unit L2 norm.
        """
        # Only the final hidden state is needed; the per-step outputs and the
        # cell state are discarded.
        _, (hx, _) = self.lstm(x)

        # hx[-1] is the last layer's final (projected) hidden state.
        embedding = self.relu(self.linear(hx[-1]))

        # Normalise over the feature axis. Using dim=-1 instead of dim=1 gives
        # the same result for batched input and also works when there is no
        # batch axis.
        return func.normalize(embedding, p=2, dim=-1)

In [61]:
# Plan for the verification tuples:
#   1. Build tuples (x_j, (x_k1, ..., x_kM)): one utterance x_j plus M utterances.
#   2. A tuple is positive when speakers j and k are the same, negative otherwise.
#   3. Alternate positive and negative tuples while generating.
#   4. Run every utterance through the LSTM encoder to get L2-normalised
#      embeddings (e_j, (e_k1, ..., e_kM)).
#   5. Average (e_k1, ..., e_kM) into the centroid c_k.

# Smoke test: push a random (64, 636, 80) batch through a fresh encoder.
model = SpeakerVerificationLSTMEncoder(**model_params)
x = model(torch.randn((64, 636, 80)))
# TODO: compute the centroid of each row.
(x, x.shape)

(tensor([[0.0318, 0.0099, 0.0000,  ..., 0.0507, 0.0000, 0.0421],
         [0.0324, 0.0092, 0.0000,  ..., 0.0505, 0.0000, 0.0423],
         [0.0307, 0.0107, 0.0000,  ..., 0.0530, 0.0000, 0.0432],
         ...,
         [0.0356, 0.0106, 0.0000,  ..., 0.0518, 0.0000, 0.0440],
         [0.0327, 0.0093, 0.0000,  ..., 0.0526, 0.0000, 0.0429],
         [0.0339, 0.0103, 0.0000,  ..., 0.0515, 0.0000, 0.0410]],
        grad_fn=<DivBackward0>),
 torch.Size([64, 256]))