In [None]:
# default_exp model

In [None]:
#hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# export
import sys
# Make the local `rtvc` code importable: `./rtvc/` is put first on the module
# search path, then its `data_objects` folder and the parent directory are
# appended so the `rtvc.*` imports in the next cell resolve.
sys.path.insert(0,'./rtvc/')
sys.path.append('./rtvc/data_objects')
sys.path.append('..')

In [None]:
#export
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch import jit
from rtvc.encoder.model import SpeakerEncoder
from rtvc.encoder.visualizations import Visualizations
from torch.autograd import Function
from rtvc.encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
from rtvc.encoder.params_model import *
from rtvc.encoder.model import SpeakerEncoder
from rtvc.encoder.preprocess import preprocess_voxceleb1, _init_preprocess_dataset

This jit implementation of [Mish (1908.08681)](https://arxiv.org/abs/1908.08681) was taken from @rwightman's excellent [repo](https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/activations/activations_jit.py) 

In [None]:
#export
#from https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/activations/activations_jit.py

@torch.jit.script
def mish_jit_fwd(x):
    """Mish forward pass, compiled with TorchScript: ``x * tanh(softplus(x))``."""
    softplus_x = F.softplus(x)
    return x * softplus_x.tanh()

In [None]:
#export
@torch.jit.script
def mish_jit_bwd(x, grad_output):
    """Mish backward pass, compiled with TorchScript.

    Returns ``grad_output`` times the derivative of ``x * tanh(softplus(x))``,
    i.e. ``tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)**2)`` with
    ``sp = softplus(x)``. Everything is recomputed from ``x``.
    """
    tanh_sp = torch.tanh(F.softplus(x))
    sech2_sp = 1 - tanh_sp * tanh_sp  # sech^2 expressed through tanh
    derivative = tanh_sp + x * x.sigmoid() * sech2_sp
    return grad_output.mul(derivative)

In [None]:
#export
class MishJitAutoFn(Function):
    """Custom autograd function for Mish built on the scripted kernels.

    Only the input tensor is saved for the backward pass; ``mish_jit_bwd``
    recomputes softplus, tanh and sigmoid from it.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        return mish_jit_bwd(saved_input, grad_output)

In [None]:
#export
def mish_jit(x, inplace=False):
    """Functional Mish backed by ``MishJitAutoFn``.

    ``inplace`` is accepted only for signature parity with other activation
    functions; the result is always a new tensor.
    """
    del inplace  # intentionally ignored
    return MishJitAutoFn.apply(x)

In [None]:
#export
class MishJit(nn.Module):
    """Mish activation module that dispatches to ``MishJitAutoFn``.

    Args:
        inplace: stored on the module for API compatibility only; the
            activation never modifies its input in place.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return MishJitAutoFn.apply(x)

# Normalization Functions

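`LinearNorm` and `ConvNorm` are adapted from [NVIDIA's Tacotron 2 implementation](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2); `LayerNorm` normalizes over the last dimension with a learnable gain and bias.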

In [None]:
# export
# LinearNorm,ConvNorm adapted from TacoTron 2 Implementation : https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2
class LinearNorm(nn.Module):
    """``nn.Linear`` wrapper whose weight is Xavier-uniform initialized.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        bias: whether the linear layer has a bias term.
        w_init_gain: nonlinearity name passed to ``nn.init.calculate_gain``
            to scale the Xavier initialization.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        layer = nn.Linear(in_dim, out_dim, bias=bias)
        gain = nn.init.calculate_gain(w_init_gain)
        nn.init.xavier_uniform_(layer.weight, gain=gain)
        self.linear_layer = layer

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(nn.Module):
    """``nn.Conv1d`` wrapper with Xavier-uniform weight initialization.

    When ``padding`` is None it is derived as ``dilation * (kernel_size - 1) / 2``,
    which keeps the sequence length unchanged for ``stride=1``. That derivation
    requires an odd ``kernel_size``, which is checked with an assertion.

    ``w_init_gain`` is the nonlinearity name handed to
    ``nn.init.calculate_gain`` to scale the initialization.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int((kernel_size - 1) * dilation / 2)

        conv = nn.Conv1d(in_channels, out_channels,
                         kernel_size=kernel_size, stride=stride,
                         padding=padding, dilation=dilation, bias=bias)
        nn.init.xavier_uniform_(conv.weight,
                                gain=nn.init.calculate_gain(w_init_gain))
        self.conv = conv

    def forward(self, signal):
        return self.conv(signal)


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable gain/bias.

    Unlike ``nn.LayerNorm``, this divides by ``Tensor.std`` (default,
    Bessel-corrected) plus ``eps``, rather than by ``sqrt(var + eps)``.

    Args:
        features: size of the last dimension (shape of ``gamma``/``beta``).
        eps: value added to the standard deviation for numerical stability.
    """

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        scale = x.std(dim=-1, keepdim=True) + self.eps
        return self.gamma * centered / scale + self.beta

# Li-GRU with Mish 

Li-GRUs were first introduced in [Light Gated Recurrent Units for Speech Recognition (1803.10225)](https://arxiv.org/abs/1803.10225). The core ideas behind Li-GRUs are (1) removing the reset gate of the GRU, because in speech the past state is almost always relevant, and (2) replacing the *tanh* activation with ReLU coupled with batch normalization. [PyTorch-Kaldi](https://github.com/mravanelli/pytorch-kaldi) found Li-GRU to be the best-performing model on TIMIT, using a bidirectional 5-layer stack of Li-GRUs with `hidden_dim` of 550 and dropout of 0.2 between layers. 

Brak extends this idea to speaker encoding, but replaces the ReLU with Mish. Brak uses a 3-layer stack of Li-GRUs with `hidden_dim` of 256 and dropout of 0.2 between layers. The initial interest in Mish arose from the similarity of its curve, and that of its predecessor [Swish (1710.05941)](https://arxiv.org/abs/1710.05941), to an inverted low-pass filter response with resonance. 



**Swish**

<img src='../brak/swish.png' width=300 height=300/>


**Low Pass Filter**

<img src='../brak/low pass filter with resonance.gif'  width=300 height=300/>

In [None]:
# export
# Adapted from pytorch-kaldi https://github.com/mravanelli/pytorch-kaldi/blob/master/neural_networks.py
class LiGru(nn.Module):
    """Stack of light gated recurrent unit (Li-GRU) layers using Mish.

    Each layer computes, for every time step ``t``::

        z_t = sigmoid(W_z x_t + U_z h_{t-1})            # update gate
        c_t = mish(W_h x_t + U_h h_{t-1}) * drop_mask   # candidate state
        h_t = z_t * h_{t-1} + (1 - z_t) * c_t

    The recurrence is time-major: ``forward`` loops over dim 0 of its input and
    treats dim 1 as the batch, so ``x`` is ``(time, batch, inp_dim)`` and the
    result is ``(time, batch, out_dim)``.

    Dropout: in training mode a Bernoulli(1 - ps) mask is drawn once per layer
    and forward call and shared by all time steps (it is not rescaled). In
    ``eval()`` mode, or when the legacy ``test_flag`` attribute is set to True,
    the candidate is instead multiplied by the constant ``1 - ps``.

    Args:
        inp_dim: number of input features per time step.
        layer_dims: hidden size of each stacked layer; defaults to five layers
            of 550 units.
        ps: dropout probability applied to the candidate state.
        activations: unused; kept for signature compatibility.
        bn_inp: batch-normalize the input features before the first layer.
        ln_inp: layer-normalize the input features before the first layer.
        use_ln: layer-normalize every hidden state.
        use_bn: batch-normalize the feed-forward projections ``W_h x`` and ``W_z x``.
        bidir: also run each layer over the reversed sequence and concatenate
            both directions, doubling the feature size.
        orth_init: orthogonally initialize the recurrent matrices ``U_h``/``U_z``.

    Raises:
        ValueError: if ``layer_dims`` is empty.
    """

    def __init__(self, inp_dim=40, layer_dims=None, ps=0.2,
                 activations=None, bn_inp=False, ln_inp=False,
                 use_ln=False, use_bn=True, bidir=True, orth_init=True,
                 ):
        super(LiGru, self).__init__()
        if layer_dims is None:
            # Same default as before, built per call instead of a shared
            # mutable default argument.
            layer_dims = [550, 550, 550, 550, 550]
        if len(layer_dims) == 0:
            raise ValueError("LiGru requires at least one entry in `layer_dims`")

        self.input_dim = inp_dim  # n_channels
        self.ps = ps  # dropout probability
        self.bidir = bidir
        self.orthinit = orth_init
        self.activation = MishJit()
        # Legacy switches kept so code that reads or sets them keeps working:
        # setting `test_flag = True` still forces inference-time dropout, and
        # `model.eval()` now does the same.
        self.to_do = 'train'
        self.test_flag = self.to_do != 'train'

        self.wh = nn.ModuleList([])  # Candidate state, feed-forward weights
        self.uh = nn.ModuleList([])  # Candidate state, recurrent weights
        self.wz = nn.ModuleList([])  # Update gate, feed-forward weights
        self.uz = nn.ModuleList([])  # Update gate, recurrent weights
        self.ln = nn.ModuleList([])  # Layer Norm
        self.bn_wh = nn.ModuleList([])  # Batch Norm
        self.bn_wz = nn.ModuleList([])  # Batch Norm

        self.ligru_lay = list(layer_dims)
        self.N_ligru_lay = len(self.ligru_lay)
        self.use_bn = use_bn
        self.use_ln = use_ln
        self.bn_inp = bn_inp
        self.ln_inp = ln_inp
        # Not used by `forward` any more (tensors follow the input's device);
        # kept as an attribute for backward compatibility.
        self.use_cuda = torch.cuda.is_available()

        if self.bn_inp:
            self.bn0 = nn.BatchNorm1d(self.input_dim, momentum=0.05)
        if self.ln_inp:
            self.ln0 = LayerNorm(self.input_dim)

        # Feed-forward biases are dropped whenever batch or layer norm is enabled.
        add_bias = not (self.use_ln or self.use_bn)
        current_input = self.input_dim
        for hidden in self.ligru_lay:
            # Feed-forward connections
            self.wh.append(nn.Linear(current_input, hidden, bias=add_bias))
            self.wz.append(nn.Linear(current_input, hidden, bias=add_bias))

            # Recurrent connections
            self.uh.append(nn.Linear(hidden, hidden, bias=False))
            self.uz.append(nn.Linear(hidden, hidden, bias=False))

            if self.orthinit:
                nn.init.orthogonal_(self.uh[-1].weight)
                nn.init.orthogonal_(self.uz[-1].weight)

            self.bn_wh.append(nn.BatchNorm1d(hidden, momentum=0.05))
            self.bn_wz.append(nn.BatchNorm1d(hidden, momentum=0.05))
            self.ln.append(LayerNorm(hidden))

            # The next layer sees both directions when bidirectional.
            current_input = 2 * hidden if self.bidir else hidden

        self.out_dim = self.ligru_lay[-1] + self.bidir * self.ligru_lay[-1]

    @staticmethod
    def _norm_over_time(norm, t):
        """Apply a 1-D norm layer to a (time, batch, feat) tensor by flattening time and batch."""
        return norm(t.reshape(t.shape[0] * t.shape[1], t.shape[2])).reshape(t.shape)

    def _drop_mask(self, h_init):
        """Dropout mask for one layer, on the same device/dtype as ``h_init``."""
        if self.training and not self.test_flag:
            # One mask reused for every time step of this layer.
            return torch.bernoulli(torch.full_like(h_init, 1 - self.ps))
        # Inference: scale by the keep probability, since the training mask is
        # not rescaled.
        return h_init.new_full((1,), 1 - self.ps)

    def forward(self, x):
        """Run the Li-GRU stack over a time-major sequence.

        Args:
            x: tensor of shape ``(time, batch, inp_dim)``.

        Returns:
            Tensor of shape ``(time, batch, out_dim)`` holding the last layer's
            hidden state at every time step.
        """
        # Applying Layer/Batch Norm to the input features
        if self.ln_inp:
            x = self.ln0(x)
        if self.bn_inp:
            x = self._norm_over_time(self.bn0, x)

        for i, hidden in enumerate(self.ligru_lay):
            if self.bidir:
                # The reversed sequence is processed as extra batch entries.
                x = torch.cat([x, torch.flip(x, [0])], 1)

            # Initial state created on the input's device and dtype.
            h_init = x.new_zeros((x.shape[1], hidden))
            drop_mask = self._drop_mask(h_init)

            # Feed-forward affine transformations (all steps in parallel)
            wh_out = self.wh[i](x)
            wz_out = self.wz[i](x)
            if self.use_bn:
                wh_out = self._norm_over_time(self.bn_wh[i], wh_out)
                wz_out = self._norm_over_time(self.bn_wz[i], wz_out)

            # Processing time steps
            hiddens = []
            ht = h_init
            for k in range(x.shape[0]):
                zt = torch.sigmoid(wz_out[k] + self.uz[i](ht))
                at = wh_out[k] + self.uh[i](ht)
                hcand = self.activation(at) * drop_mask
                ht = zt * ht + (1 - zt) * hcand
                if self.use_ln:
                    ht = self.ln[i](ht)
                hiddens.append(ht)
            h = torch.stack(hiddens)

            if self.bidir:
                # Split the batch back into the forward half and the backward
                # half, re-reversing the latter so both are in input time order.
                half = h.shape[1] // 2
                h_f = h[:, :half]
                h_b = torch.flip(h[:, half:].contiguous(), [0])
                h = torch.cat([h_f, h_b], 2)

            # Setup x for the next hidden layer
            x = h
        return x


In [None]:
# export
# subclasses the rtvc SpeakerEncoder (imported from rtvc.encoder.model)
class GRUVoiceEncoder(SpeakerEncoder):
    """``SpeakerEncoder`` variant that swaps the LSTM for a Li-GRU stack.

    Constructor arguments are forwarded unchanged to ``SpeakerEncoder``. The
    ``lstm`` and ``relu`` attributes are set to ``None``. They are replaced by
    a unidirectional ``LiGru`` (``self.gru``) and by a Mish activation
    (``self.mish``) applied after the projection layer ``self.linear``.
    """

    def __init__(self, *args, **kwargs):
        super(GRUVoiceEncoder, self).__init__(*args, **kwargs)
        # NOTE(review): `input_size`, `hidden_dim`, `n_layers`, `output_dim` and
        # `device` are not defined in this notebook; presumably they come from
        # the `params_model` star import -- confirm.
        self.lstm = None
        self.gru = LiGru(input_size, [hidden_dim] * n_layers, ps=0.2, bidir=False).to(device)
        self.relu = None
        self.mish = MishJit()
        self.linear = nn.Linear(in_features=hidden_dim,
                                out_features=output_dim).to(device)

    def forward(self, utterances):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :return: the L2-normalized embeddings as a tensor of shape (batch_size, output_dim)
        """
        # LiGru is time-major (it iterates over dim 0 and treats dim 1 as the
        # batch), so put frames first. Feeding the batch-first tensor directly
        # would run the recurrence across the utterances of the batch instead
        # of across the frames of each utterance.
        frames_first = utterances.transpose(0, 1).contiguous()
        hn = self.gru(frames_first)  # (n_frames, batch_size, hidden_dim)
        # Hidden state of the last layer at the final frame.
        embeds_raw = self.mish(self.linear(hn[-1]))
        # L2-normalize it; F.normalize clamps the norm, so an all-zero vector
        # does not turn into NaNs.
        return F.normalize(embeds_raw, p=2, dim=-1)

In [None]:
from nbdev.export import notebook2script
# Export the `#export` cells of the project's notebooks to the library modules
# (the "Converted ..." lines below list the notebooks that were processed).
notebook2script()

Converted 00_core.ipynb.
Converted 01_data.ipynb.
Converted 02_model.ipynb.
Converted 03_train.ipynb.
Converted index.ipynb.
