In [2]:
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from typing import List, Tuple

from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.modules import (
    Fp32GroupNorm,
    Fp32LayerNorm,
    GradMultiply,
    GumbelVectorQuantizer,
    LayerNorm,
    MultiheadAttention,
    SamePad,
    TransposeLast,
)

In [3]:
# Allowed values for ``extractor_mode``.
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])


@dataclass
class ConvFeatureExtractionModelConfig(FairseqDataclass):
    """Configuration for ``ConvFeatureExtractionModel``.

    Every field carries its own ``help`` text in ``metadata``; the comments
    below only summarise how the model's ``__init__`` consumes each value.
    """

    # Normalisation layout: "default" puts a group norm in the first conv block
    # only, "layer_norm" puts a layer norm in every block.
    extractor_mode: EXTRACTOR_MODE_CHOICES = field(
        default="default",
        metadata={
            "help": "mode for feature extractor. default has a single group norm with d "
            "groups in the first conv block, whereas layer_norm has layer norms in "
            "every block (meant to use with normalize=True)"
        },
    )

    # Python expression (as a string) that evaluates to a list of
    # (dim, kernel_size, stride) tuples, one per conv block.
    conv_feature_layers: str = field(
        default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
        metadata={
            "help": "string describing convolutional feature extraction layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )

    # Forwarded as ``bias=`` to every nn.Conv1d in the stack.
    conv_bias: bool = field(
        default=False,
        metadata={"help": "include bias in conv encoder"},
    )

    # Forwarded as ``p=`` to the nn.Dropout that follows every convolution.
    conv_dropout: float = field(
        default=0.0,
        metadata={"help": "ratio of dropout"},
    )

In [7]:
class ConvFeatureExtractionModel(nn.Module):
    """Stack of un-padded 1-D convolution blocks over a single-channel sequence.

    Each block is ``Conv1d -> Dropout -> [normalisation] -> GELU``. Which
    normalisation (if any) a block gets depends on ``cfg.extractor_mode``:

    * ``"default"``    - group norm in the first block only, none afterwards.
    * ``"layer_norm"`` - layer norm (over the channel axis) in every block.

    The blocks are stored in ``self.conv_layers`` (an ``nn.ModuleList``) in the
    order given by ``cfg.conv_feature_layers``.
    """

    def __init__(
        self,
        # Quoted so the annotation is not evaluated when the class is defined;
        # the class can then be (re)defined without the config class in scope.
        cfg: "ConvFeatureExtractionModelConfig"
    ):
        """Build the conv stack described by ``cfg``.

        Args:
            cfg: object exposing ``extractor_mode``, ``conv_feature_layers``,
                ``conv_bias`` and ``conv_dropout``.

        Raises:
            AssertionError: if the mode is unknown or a layer spec is not a
                3-tuple ``(dim, kernel_size, stride)``.
        """
        super().__init__()

        mode = cfg.extractor_mode
        # NOTE(security): eval() executes arbitrary Python, so the layer spec
        # must come from a trusted config. ast.literal_eval is not a drop-in
        # replacement: it rejects the list `+` / `*` expressions that the
        # default spec relies on.
        conv_layers = eval(cfg.conv_feature_layers)
        conv_bias = cfg.conv_bias
        dropout = cfg.conv_dropout

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            """Return one ``nn.Sequential`` conv block mapping n_in -> n_out channels."""

            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            # The norm layers are sized from this block's own ``n_out`` rather
            # than from the enclosing loop variable, so ``block`` is correct
            # independently of how and where it is called.
            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        # Move channels last so the layer norm normalises over
                        # them, then restore the original layout.
                        TransposeLast(),
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    # One group per channel.
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        # forward() inserts a single channel axis, so the first conv sees 1 channel.
        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    # In "default" mode only the very first block is normalised.
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            # The next block consumes what this one produces.
            in_d = dim

    def forward(self, x):
        """Run ``x`` through every conv block in order.

        Args:
            x: input tensor laid out as B x T; a channel axis of size 1 is
                inserted at position 1 before the first convolution.

        Returns:
            The output of the last conv block (B x C x T'), or the unsqueezed
            input if no layers were configured.
        """
        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x

In [18]:
# Only the conv stack is overridden here; extractor_mode, conv_bias and
# conv_dropout keep their dataclass defaults and can be passed as further
# keyword arguments when needed.
config = ConvFeatureExtractionModelConfig(
    conv_feature_layers="[(512, 10, 5)] + [(512, 5, 5)] * 4",
)

config

ConvFeatureExtractionModelConfig(_name=None, extractor_mode='default', conv_feature_layers='[(512, 10, 5)] + [(512, 5, 5)] * 4', conv_bias=False, conv_dropout=0.0)

In [22]:
# Expand the layer-spec string into the list of (dim, kernel_size, stride)
# tuples; the model constructor evaluates the same string the same way.
eval(config.conv_feature_layers)

[(512, 10, 5), (512, 5, 5), (512, 5, 5), (512, 5, 5), (512, 5, 5)]

In [20]:
# Build the conv feature extractor from the config defined above.
model = ConvFeatureExtractionModel(config)

In [21]:
# Display the ModuleList of conv blocks registered by the constructor.
model.conv_layers

ModuleList(
  (0): Sequential(
    (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
    (1): Dropout(p=0.0, inplace=False)
    (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
    (3): GELU()
  )
  (1): Sequential(
    (0): Conv1d(512, 512, kernel_size=(5,), stride=(5,), bias=False)
    (1): Dropout(p=0.0, inplace=False)
    (2): GELU()
  )
  (2): Sequential(
    (0): Conv1d(512, 512, kernel_size=(5,), stride=(5,), bias=False)
    (1): Dropout(p=0.0, inplace=False)
    (2): GELU()
  )
  (3): Sequential(
    (0): Conv1d(512, 512, kernel_size=(5,), stride=(5,), bias=False)
    (1): Dropout(p=0.0, inplace=False)
    (2): GELU()
  )
  (4): Sequential(
    (0): Conv1d(512, 512, kernel_size=(5,), stride=(5,), bias=False)
    (1): Dropout(p=0.0, inplace=False)
    (2): GELU()
  )
)

In [7]:
import netron

# Single source of truth for the exported file, so the exporter and the viewer
# can never point at different paths.
ONNX_PATH = "ConvFeatExtrc.onnx"

# Dummy input used only to trace the model: one sequence of 10000 samples.
input_val = torch.randn(1, 10000)
torch.onnx.export(model, input_val, ONNX_PATH)

# Serve the exported graph in the Netron viewer.
netron.start(ONNX_PATH)

Serving 'ConvFeatExtrc.onnx' at http://localhost:8080


('localhost', 8080)

In [9]:
from torchinfo import summary

# Per-layer output shapes and parameter counts for a single 9000-sample input.
# NOTE(review): the recorded output lists 5,632 parameters for the first Conv1d
# (512*10 weights + 512 biases), which does not match the bias-free model shown
# by `model.conv_layers` -- it looks stale; re-run after Restart & Run All.
summary(model, input_size=(1, 9000))

Layer (type:depth-idx)                   Output Shape              Param #
ConvFeatureExtractionModel               --                        --
├─ModuleList: 1-1                        --                        --
│    └─Sequential: 2-1                   [1, 512, 1799]            --
│    │    └─Conv1d: 3-1                  [1, 512, 1799]            5,632
│    │    └─Dropout: 3-2                 [1, 512, 1799]            --
│    │    └─Fp32GroupNorm: 3-3           [1, 512, 1799]            1,024
│    │    └─GELU: 3-4                    [1, 512, 1799]            --
│    └─Sequential: 2-2                   [1, 512, 899]             --
│    │    └─Conv1d: 3-5                  [1, 512, 899]             786,944
│    │    └─Dropout: 3-6                 [1, 512, 899]             --
│    │    └─GELU: 3-7                    [1, 512, 899]             --
│    └─Sequential: 2-3                   [1, 512, 449]             --
│    │    └─Conv1d: 3-8                  [1, 512, 449]             786,944