In [1]:
from hw_tts.utils.audio import Audio

In [2]:
# Project audio helper from hw_tts.utils.audio, constructed with its default arguments.
# NOTE(review): `audio` is not referenced by any of the cells shown here — confirm it is still needed.
audio = Audio()

In [4]:
import torch
import torch.nn as nn 

# Seed the global torch RNG so this random input, and the random weight initialisation
# of the modules built afterwards, are reproducible on Restart & Run All.
torch.manual_seed(0)
# Random test input; the ResBlock cells feed it to Conv1d layers with 255 input channels,
# i.e. it is used as (batch=16, channels=255, time=4).
x = torch.randn(16, 255, 4)
class ResBlock(nn.Module):
    """Residual pointwise-conv block followed by PReLU and a stride-3 max-pool.

    Main branch: Conv1x1 -> BatchNorm -> PReLU -> Conv1x1 -> BatchNorm (convs bias-free).
    Shortcut: identity when ``in_dims == out_dims``, otherwise a bias-free 1x1 conv.
    The sum goes through PReLU and ``MaxPool1d(3)``, which shrinks the time axis by 3x.

    Adapted from the speaker-encoder residual blocks in
        https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py
        https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py

    Args:
        in_dims: channels of the input (N, in_dims, L).
        out_dims: channels of the output.
    """

    def __init__(self, in_dims, out_dims):
        super().__init__()
        # Layers are created in the same order as before so parameter names,
        # state_dict layout and RNG consumption during init are unchanged.
        self.conv1 = nn.Conv1d(in_dims, out_dims, 1, bias=False)
        self.conv2 = nn.Conv1d(out_dims, out_dims, 1, bias=False)
        self.batch_norm1 = nn.BatchNorm1d(out_dims)
        self.batch_norm2 = nn.BatchNorm1d(out_dims)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()
        self.maxpool = nn.MaxPool1d(3)
        # The shortcut needs a projection only when the channel count changes.
        self.downsample = in_dims != out_dims
        if self.downsample:
            self.conv_downsample = nn.Conv1d(in_dims, out_dims, 1, bias=False)

    def forward(self, x):
        shortcut = self.conv_downsample(x) if self.downsample else x
        hidden = self.prelu1(self.batch_norm1(self.conv1(x)))
        hidden = self.batch_norm2(self.conv2(hidden))
        return self.maxpool(self.prelu2(hidden + shortcut))

In [5]:
# Smoke test: a freshly initialised ResBlock with an identity shortcut (255 -> 255 channels)
# applied to x; the displayed value is the Frobenius norm of the output.
ResBlock(255, 255)(x).norm()

tensor(100.9344, grad_fn=<LinalgVectorNormBackward0>)

In [6]:
import torch.nn as nn


class ResBlock(nn.Module):
    """Residual pointwise-conv block, Sequential-based layout.

    ``first``: Conv1x1 -> BatchNorm -> PReLU -> Conv1x1 -> BatchNorm (bias-free convs).
    ``downsample``: bias-free 1x1 conv on the shortcut when channel counts differ, else None.
    ``second``: PReLU -> MaxPool1d(3) applied to the residual sum.

    Args:
        input_size: channels of the input (N, input_size, L).
        out_size: channels of the output.
    """

    def __init__(self, input_size, out_size):
        super().__init__()
        main_layers = [
            nn.Conv1d(input_size, out_size, 1, bias=False),
            nn.BatchNorm1d(out_size),
            nn.PReLU(),
            nn.Conv1d(out_size, out_size, 1, bias=False),
            nn.BatchNorm1d(out_size),
        ]
        self.first = nn.Sequential(*main_layers)
        # Only project the shortcut when the channel count changes; None means identity.
        self.downsample = (
            nn.Sequential(nn.Conv1d(input_size, out_size, 1, bias=False))
            if input_size != out_size
            else None
        )
        self.second = nn.Sequential(nn.PReLU(), nn.MaxPool1d(3))

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.second(self.first(x) + shortcut)
# Same smoke test for the Sequential-based ResBlock defined just above. Its weights are drawn
# afresh, so the norm is not expected to equal the earlier cell's value exactly.
ResBlock(255, 255)(x).norm()

tensor(103.1035, grad_fn=<LinalgVectorNormBackward0>)

In [7]:
# New random input for the layer-norm comparison: GlobalLayerNorm2 unpacks its shape as
# (B, C, T), so this is 16 samples x 4 channels x 10 time steps.
x = torch.randn(16, 4, 10)

In [13]:
class GlobalLayerNorm(nn.Module):
    """Global layer normalisation over the channel and time axes.

    Each sample of a 3-D input (N, C, T) is standardised with one mean and one
    biased variance computed over dims 1 and 2. It is then scaled by ``gamma`` and
    shifted by ``beta``, both of shape (dim, 1), which broadcast over the time axis
    (so C is expected to equal ``dim``).

    Args:
        dim: number of channels; size of the per-channel affine parameters.
        eps: constant added to the variance for numerical stability.
    """

    def __init__(self, dim, eps=1e-05):
        super().__init__()
        self.eps = eps
        self.normalized_dim = dim
        # Registered beta-then-gamma to keep the original parameter order.
        self.beta = nn.Parameter(torch.zeros(dim, 1))
        self.gamma = nn.Parameter(torch.ones(dim, 1))

    def forward(self, x):
        reduce_dims = (1, 2)
        mu = x.mean(dim=reduce_dims, keepdim=True)
        sigma2 = x.var(dim=reduce_dims, keepdim=True, unbiased=False)
        # Same evaluation order as gamma * (x - mu) / sqrt(var + eps) + beta.
        scaled = self.gamma * (x - mu)
        return scaled / torch.sqrt(sigma2 + self.eps) + self.beta

    
class GlobalLayerNorm2(nn.Module):
    """Global layer norm built on ``nn.LayerNorm`` over the flattened (C, T) axes.

    The (B, C, T) input is flattened to (B, C * T), normalised by ``nn.LayerNorm``
    over that last axis, and reshaped back. The LayerNorm affine parameters have
    one entry per (channel, time) element, so ``dim`` must equal C * T. The module
    therefore only accepts inputs with that exact C * T.

    Args:
        dim: C * T of the expected (B, C, T) input.
    """

    def __init__(self, dim):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim)

    def forward(self, x):
        B, C, T = x.shape
        # reshape instead of view: view raises on non-contiguous inputs (e.g. a
        # transposed tensor); reshape only copies when it has to.
        return self.layernorm(x.reshape(B, C * T)).reshape(B, C, T)

In [14]:
# NOTE(review): identical redefinition of GlobalLayerNorm2 from the previous cell; this later
# definition is the one in effect for the cells below.
class GlobalLayerNorm2(nn.Module):
    """Global layer norm built on ``nn.LayerNorm`` over the flattened (C, T) axes.

    The (B, C, T) input is flattened to (B, C * T), normalised by ``nn.LayerNorm``
    over that last axis, and reshaped back. The LayerNorm affine parameters have
    one entry per (channel, time) element, so ``dim`` must equal C * T.

    Args:
        dim: C * T of the expected (B, C, T) input.
    """

    def __init__(self, dim):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim)

    def forward(self, x):
        B, C, T = x.shape
        # reshape instead of view: view raises on non-contiguous inputs (e.g. a
        # transposed tensor); reshape only copies when it has to.
        return self.layernorm(x.reshape(B, C * T)).reshape(B, C, T)

In [15]:
# LayerNorm-based variant: its size must be C * T of x, derived from the shape rather than hard-coded.
res1 = GlobalLayerNorm2(x.shape[1] * x.shape[2])(x)

In [16]:
# Hand-written variant: per-channel affine, so its size is the channel count C of x.
res2 = GlobalLayerNorm(x.shape[1])(x)

In [17]:
# Both implementations start from an all-ones scale and zero shift, so they should agree up
# to float32 rounding; assert it so a regression fails loudly instead of just printing a number.
diff_norm = (res1 - res2).norm()
assert diff_norm < 1e-4, f"GlobalLayerNorm and GlobalLayerNorm2 disagree: {diff_norm}"
diff_norm

tensor(1.7022e-06, grad_fn=<LinalgVectorNormBackward0>)

In [37]:
class TCNBlock(nn.Module):
    """Residual temporal-convolution block.

    1x1 Conv -> PReLU -> GlobalLayerNorm -> depthwise dilated Conv -> PReLU
    -> GlobalLayerNorm -> 1x1 Conv, whose result is added back to the input.

    Args:
        in_channels: channels of the input (N, in_channels, L); the output has the same.
        conv_channels: hidden channels used inside the block.
        kernel_size: kernel size of the depthwise convolution.
        dilation: dilation of the depthwise convolution.

    Raises:
        ValueError: if ``dilation * (kernel_size - 1)`` is odd. The depthwise conv is
            padded symmetrically by half of that amount, so an odd value would
            shorten the time axis by one and the residual sum would fail.
    """

    def __init__(self,
                 in_channels=256,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1):
        super().__init__()

        # Total padding a stride-1 conv needs to keep the sequence length unchanged.
        total_pad = dilation * (kernel_size - 1)
        if total_pad % 2:
            raise ValueError(
                "dilation * (kernel_size - 1) must be even so symmetric padding keeps the "
                f"sequence length; got kernel_size={kernel_size}, dilation={dilation}"
            )

        self.net = nn.Sequential(
            nn.Conv1d(in_channels, conv_channels, 1),
            nn.PReLU(),
            GlobalLayerNorm(conv_channels, eps=1e-05),
            nn.Conv1d(
                conv_channels,
                conv_channels,
                kernel_size,
                groups=conv_channels,  # depthwise: one filter per channel
                padding=total_pad // 2,
                dilation=dilation,
                bias=True),
            nn.PReLU(),
            GlobalLayerNorm(conv_channels, eps=1e-05),
            nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        )

    def forward(self, x):
        """Return ``net(x) + x`` for x of shape (N, in_channels, L)."""
        return self.net(x) + x


In [38]:
class TCNBlock_Spk(TCNBlock):
    """TCNBlock variant that conditions on a speaker embedding by concatenation.

    ``aux`` is tiled along the time axis of ``x`` and concatenated to it on dim 1.
    The result goes through the parent block, which is therefore built for
    ``in_channels + spk_embed_dim`` input channels.

    NOTE(review): the parent's forward adds its own input back (residual) on the
    concatenated tensor. The output therefore has ``in_channels + spk_embed_dim``
    channels rather than ``in_channels``, with the embedding channels carried along.
    Confirm this is intended before stacking these blocks.
    """

    def __init__(self, in_channels, spk_embed_dim, conv_channels, kernel_size, dilation):
        concat_channels = in_channels + spk_embed_dim
        super().__init__(concat_channels, conv_channels, kernel_size, dilation)

    def forward(self, x, aux):
        # aux is presumably (N, spk_embed_dim): add a trailing time axis and copy it
        # to x's length so it can be stacked with x along the channel axis.
        time_steps = x.shape[-1]
        aux_tiled = aux.unsqueeze(-1).repeat(1, 1, time_steps)
        return super().forward(torch.cat([x, aux_tiled], 1))

In [None]:
class TCNBlock(nn.Module):
    """Residual temporal-convolution block with a separate input projection.

    ``conv`` (1x1) -> ``net`` [PReLU -> GlobalLayerNorm -> depthwise dilated Conv
    -> PReLU -> GlobalLayerNorm -> 1x1 Conv], whose result is added to the input.
    Keeping the first projection as ``self.conv`` lets subclasses swap it for one
    with a different number of input channels.

    Args:
        in_channels: channels of the input (N, in_channels, L); the output has the same.
        conv_channels: hidden channels used inside the block.
        kernel_size: kernel size of the depthwise convolution.
        dilation: dilation of the depthwise convolution.
        causal: accepted for interface compatibility.
            NOTE(review): currently unused; padding is always symmetric (non-causal).

    Raises:
        ValueError: if ``dilation * (kernel_size - 1)`` is odd. The depthwise conv is
            padded symmetrically by half of that amount, so an odd value would
            shorten the time axis by one and the residual sum would fail.
    """

    def __init__(self,
                 in_channels=256,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1, causal=False):
        super().__init__()

        # Total padding a stride-1 conv needs to keep the sequence length unchanged.
        total_pad = dilation * (kernel_size - 1)
        if total_pad % 2:
            raise ValueError(
                "dilation * (kernel_size - 1) must be even so symmetric padding keeps the "
                f"sequence length; got kernel_size={kernel_size}, dilation={dilation}"
            )

        self.conv = nn.Conv1d(in_channels, conv_channels, 1)
        self.net = nn.Sequential(
            nn.PReLU(),
            GlobalLayerNorm(conv_channels, eps=1e-05),
            nn.Conv1d(
                conv_channels,
                conv_channels,
                kernel_size,
                groups=conv_channels,  # depthwise: one filter per channel
                padding=total_pad // 2,
                dilation=dilation,
                bias=True),
            nn.PReLU(),
            GlobalLayerNorm(conv_channels, eps=1e-05),
            nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        )

    def forward(self, x):
        """Return ``net(conv(x)) + x`` for x of shape (N, in_channels, L)."""
        y = self.conv(x)
        y = self.net(y)
        return y + x


class TCNBlock_Spk(TCNBlock):
    """
    Temporal convolutional network block conditioned on a speaker embedding:
        1x1Conv - PReLU - Norm - DConv - PReLU - Norm - SConv, plus a residual.
    Only the first 1x1 conv sees the speaker embedding: the embedding is tiled along
    time and concatenated to the input channels. The residual adds the original ``x``,
    so the output keeps ``in_channels`` channels.

    Input: 3D tensor [N, C_in, L] with C_in == in_channels
    Input speaker embedding ``aux``: 2D tensor [N, D] with D == spk_embed_dim
    Output: 3D tensor [N, C_in, L]
    """

    def __init__(self,
                 in_channels=256,
                 spk_embed_dim=100,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False):
        super().__init__(in_channels, conv_channels, kernel_size, dilation, causal=causal)
        # Replace the parent's input projection so it also accepts the embedding channels.
        self.conv = nn.Conv1d(in_channels + spk_embed_dim, conv_channels, 1)

    def forward(self, x, aux):
        # `th` was never imported; use `torch`.
        aux = torch.unsqueeze(aux, -1)
        aux = aux.repeat(1, 1, x.shape[-1])
        y = torch.cat([x, aux], 1)
        y = self.conv(y)
        # y is already (N, C, L). The old `.squeeze()` dropped the batch axis when
        # N == 1 (or time when L == 1), breaking GlobalLayerNorm's reduction over dims (1, 2).
        y = self.net(y)
        return y + x
