In [11]:
import torch
import torch.nn as nn
from torch import Tensor

__all__ = ["ResBlock", "MelResNet"]


class ResBlock(nn.Module):
    r"""Residual block of two 1x1 convolutions with batch normalization.

    Computes ``x + BN(Conv1x1(ReLU(BN(Conv1x1(x)))))``. Both convolutions use
    ``kernel_size=1`` and map ``num_dims`` channels to ``num_dims`` channels,
    so the output has exactly the same shape as the input (which the residual
    addition requires).

    Args:
        num_dims (int, optional): Number of channels in the ResBlock. (Default: ``128``)
    """

    def __init__(self, num_dims: int = 128) -> None:
        super().__init__()

        self.resblock_model = nn.Sequential(
            nn.Conv1d(in_channels=num_dims, out_channels=num_dims, kernel_size=1, bias=False),
            nn.BatchNorm1d(num_dims),
            # In-place is safe here: it overwrites the BatchNorm output, not the block input.
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=num_dims, out_channels=num_dims, kernel_size=1, bias=False),
            nn.BatchNorm1d(num_dims)
        )

    def forward(self, x: Tensor) -> Tensor:
        r"""Apply the conv/BN stack and add the skip connection.

        Args:
            x (Tensor): Tensor of dimension (batch_size, num_dims, input_length).

        Returns:
            Tensor: Tensor with the same dimension as ``x``,
            i.e. (batch_size, num_dims, input_length).
        """
        return x + self.resblock_model(x)


class MelResNet(nn.Module):
    r"""1-D convolutional residual network applied along the time axis.

    Layout: an unpadded ``kernel_size``-wide convolution maps ``input_dims``
    channels to ``hidden_dims``, followed by BatchNorm + ReLU, then
    ``res_blocks`` :class:`ResBlock` layers, and finally a 1x1 convolution to
    ``output_dims`` channels. Because the first convolution has no padding, the
    time axis shrinks by ``kernel_size - 1`` frames (4 with the default).

    Args:
        res_blocks (int, optional): Number of ResBlocks. (Default: ``10``)
        input_dims (int, optional): Number of input channels. (Default: ``100``)
        hidden_dims (int, optional): Number of hidden channels. (Default: ``128``)
        output_dims (int, optional): Number of output channels. (Default: ``128``)
        kernel_size (int, optional): Kernel size of the first (unpadded) convolution.
            (Default: ``5``)
    """

    def __init__(self, res_blocks: int = 10,
                 input_dims: int = 100,
                 hidden_dims: int = 128,
                 output_dims: int = 128,
                 kernel_size: int = 5) -> None:
        super().__init__()

        res_block_layers = [ResBlock(hidden_dims) for _ in range(res_blocks)]

        # Layer order is kept stable so state_dict keys (melresnet_model.N.*) do not change.
        self.melresnet_model = nn.Sequential(
            nn.Conv1d(in_channels=input_dims, out_channels=hidden_dims,
                      kernel_size=kernel_size, bias=False),
            nn.BatchNorm1d(hidden_dims),
            nn.ReLU(inplace=True),
            *res_block_layers,
            nn.Conv1d(in_channels=hidden_dims, out_channels=output_dims, kernel_size=1)
        )

    def forward(self, x: Tensor) -> Tensor:
        r"""Run the network on a batch of feature frames.

        Args:
            x (Tensor): Tensor of dimension (batch_size, input_dims, input_length).

        Returns:
            Tensor: Tensor of dimension
            (batch_size, output_dims, input_length - kernel_size + 1),
            i.e. ``input_length - 4`` with the default ``kernel_size=5``.
        """
        return self.melresnet_model(x)

In [13]:
# Smoke-check MelResNet on a random batch shaped (batch, input_dims, frames).
torch.manual_seed(0)  # reproducible random input on Restart & Run All
res_block = 10
in_dims = 100
compute_dims = 128
res_out_dims = 128
x = torch.rand(32, in_dims, 20)
model = MelResNet(res_blocks=res_block, input_dims=in_dims,
                  hidden_dims=compute_dims, output_dims=res_out_dims)
output = model(x)
# The unpadded kernel_size=5 input convolution trims 4 frames: 20 -> 16.
assert output.shape == (32, res_out_dims, 20 - 4)
print(output.shape)

torch.Size([32, 128, 16])


In [44]:
# Plain function so it can be called directly in the notebook; when moved to a
# test module it can be parametrized over these arguments with
# ``pytest.mark.parametrize`` (pytest is not imported here).
def test_waveform(batch_size: int = 2,
                  num_features: int = 200,
                  input_dims: int = 100,
                  output_dims: int = 128) -> None:
    """Check MelResNet's output shape for the given configuration.

    Args:
        batch_size: Number of examples in the random input batch.
        num_features: Length of the input along the last (time) axis.
        input_dims: Number of input channels; forwarded to ``MelResNet``.
        output_dims: Number of output channels; forwarded to ``MelResNet``.

    Raises:
        AssertionError: If the output shape is not
            (batch_size, output_dims, num_features - 4).
    """
    # Forward the channel sizes so non-default arguments configure the model
    # instead of mismatching a default-built one.
    model = MelResNet(input_dims=input_dims, output_dims=output_dims)
    x = torch.rand(batch_size, input_dims, num_features)
    out = model(x)

    # MelResNet's default kernel_size=5 input convolution trims 4 frames.
    assert out.size() == (batch_size, output_dims, num_features - 4)

In [45]:
# Run the MelResNet shape check with its default configuration, spelled out.
test_waveform(batch_size=2, num_features=200, input_dims=100, output_dims=128)