In [1]:
import torch
import librosa

# Frontend

In [22]:
from model.frontend.default import DefaultFrontend

# Feature extractor built with its default settings; the printed model summary
# shows it as an STFT (n_fft=512, hop_length=128) followed by an 80-bin log-mel
# stage at sr=16000.
frontend = DefaultFrontend()

# Spec Augmentation

In [23]:
from model.specaug.specaug import SpecAug

# Spec-augmentation module: mask widths are drawn from [0, 30] along the
# frequency axis and [0, 40] along the time axis. Every other option is left
# at SpecAug's defaults.
specaug = SpecAug(
    freq_mask_width_range=[0, 30],
    time_mask_width_range=[0, 40],
)

# Normalization

In [4]:
from model.layers.global_mvn import GlobalMVN

# Global mean/variance normalization driven by a precomputed statistics file.
# NOTE(review): the path is relative, so it presumably resolves against the
# kernel's working directory — confirm before running from another location.
normalizer = GlobalMVN(
    stats_file="test_utils/feats_stats.npz",
)

# Encoder

In [30]:
from model.encoder.transformer_encoder import TransformerEncoder

# Transformer encoder: 80-dimensional input (the same value as the frontend's
# n_mels in the printed summary), the "conv2d6" input layer, 18 blocks and a
# 512-dimensional output. All three dropout rates are set to 0.1.
encoder = TransformerEncoder(
    input_size=80,
    input_layer="conv2d6",
    num_blocks=18,
    output_size=512,
    dropout_rate=0.1,
    positional_dropout_rate=0.1,
    attention_dropout_rate=0.1,
)

# Decoder and CTC

In [6]:
from model.decoder.transformer_decoder import  TransformerDecoder

# Transformer decoder over a 3262-entry vocabulary. encoder_output_size is 512,
# the same value the encoder is built with as output_size; both attention
# dropout rates are 0.1.
decoder = TransformerDecoder(
    vocab_size=3262,
    encoder_output_size=512,
    self_attention_dropout_rate=0.1,
    src_attention_dropout_rate=0.1,
)

In [26]:
from model.ctc import CTC

# CTC module: odim uses the same 3262 as the decoder's vocab_size, and
# encoder_output_size the same 512 as the encoder's output_size.
ctc = CTC(
    odim=3262,
    encoder_output_size=512,
)

# ASR model (no label-smoothing options are set here)

In [31]:
from model.asr_model import ASRModel

# Assemble the full ASR model from the components built in the cells above.
# preencoder, postencoder and joint_network are explicitly disabled with None.
# NOTE(review): token_list has only 3 entries while vocab_size is 3262 — it
# looks like a placeholder; confirm ASRModel never indexes it by token id.
model = ASRModel(
    vocab_size=3262,
    token_list=["aaa", "bbb", "<blank>"],
    frontend=frontend,
    specaug=specaug,
    normalize=normalizer,
    preencoder=None,
    encoder=encoder,
    postencoder=None,
    decoder=decoder,
    ctc=ctc,
    joint_network=None,
)

In [32]:
# Bare expression: the notebook shows the model's repr (its nested module
# structure) as this cell's output.
model

ASRModel(
  (frontend): DefaultFrontend(
    (stft): Stft(n_fft=512, win_length=512, hop_length=128, center=True, normalized=False, onesided=True)
    (logmel): LogMel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000.0, htk=False)
  )
  (specaug): SpecAug(
    (time_warp): TimeWarp(window=5, mode=bicubic)
    (freq_mask): MaskAlongAxis(mask_width_range=[0, 30], num_mask=2, axis=freq)
    (time_mask): MaskAlongAxis(mask_width_range=[0, 40], num_mask=2, axis=time)
  )
  (normalize): GlobalMVN(stats_file=test_utils/feats_stats.npz, norm_means=True, norm_vars=True)
  (encoder): TransformerEncoder(
    (embed): Conv2dSubsampling6(
      (conv): Sequential(
        (0): Conv2d(1, 512, kernel_size=(3, 3), stride=(2, 2))
        (1): ReLU()
        (2): Conv2d(512, 512, kernel_size=(5, 5), stride=(3, 3))
        (3): ReLU()
      )
      (out): Sequential(
        (0): Linear(in_features=6144, out_features=512, bias=True)
        (1): PositionalEncoding(
          (dropout): Dropout(p=0.1, in