In [1]:
# 2048 // 2 + 1 == 1025 -- matches the one-sided FFT bin count for a 2048-point
# transform (presumably what this cell checks). Floor division of two ints is
# already an int, so no int() wrapper is needed.
2048 // 2 + 1

1025

In [2]:
# Half of 2048 as an exact integer (same value as int(2048 / 2.0)).
2048 // 2

1024

In [1]:
# 2048 / 0.8 truncated to an int; 2048 * 5 // 4 yields the same 2560 using exact
# integer arithmetic rather than a float round-trip.
2048 * 5 // 4

2560

In [1]:
# Symbol inventory: an "EOS" token, then space and punctuation, the lowercase
# ASCII letters a-z, and finally accented letters plus curly quote characters.
# Each string below is split into one-character entries.
symbols = (
    ["EOS"]
    + list(" !,-.;?")
    + list("abcdefghijklmnopqrstuvwxyz")
    + list("àâèéêü’“”")
)
len(symbols) * 2

86

In [None]:
  def forward(self, x):
        """
        Run the PostNet convolution stack over a mel spectrogram.

        Axes 1 and 2 of the input are swapped so channels precede time. The
        tensor then goes through five conv -> batch-norm -> tanh -> dropout
        stages and a final conv -> batch-norm -> dropout stage without tanh.
        Finally axes 1 and 2 are swapped back.

        Args:
            x (Tensor): Input mel spectrogram; the shape comments below suggest
                (N, TIME, FREQ) layout -- confirm against the caller.

        Returns:
            Tensor: Refined mel spectrogram, transposed back to the input's
                axis order.
        """
        h = x.transpose(2, 1)  # (N, FREQ, TIME)
        tanh_stages = (
            (self.conv_1, self.bn_1, self.dropout_1),
            (self.conv_2, self.bn_2, self.dropout_2),
            (self.conv_3, self.bn_3, self.dropout_3),
            (self.conv_4, self.bn_4, self.dropout_4),
            (self.conv_5, self.bn_5, self.dropout_5),
        )
        for conv, norm, drop in tanh_stages:
            h = drop(torch.tanh(norm(conv(h))))  # (N, POSNET_DIM, TIME)
        # The last stage deliberately has no tanh nonlinearity.
        h = self.dropout_6(self.bn_6(self.conv_6(h)))  # (N, FREQ, TIME)
        return h.transpose(1, 2)
    
class DecoderPreNet(nn.Module):
    """
    Decoder pre-network that processes mel spectrograms before the main decoder.

    Applies two Linear -> ReLU -> dropout stages, projecting each frame from
    ``hp.mel_freq`` to ``hp.embedding_size`` features.

    By default dropout is applied even when the module is in eval mode; the
    original code hard-coded ``training=True``. This is presumably deliberate,
    following Tacotron 2, which keeps decoder pre-net dropout active at
    inference to add output variation. Pass ``always_dropout=False`` to tie
    dropout to ``self.training`` instead.
    """

    def __init__(self, dropout_p=0.5, always_dropout=True):
        """
        Initialize the DecoderPreNet with its layers.

        Args:
            dropout_p (float): Dropout probability applied after each ReLU.
                Defaults to 0.5, the value previously hard-coded in ``forward``.
            always_dropout (bool): If True (the default, matching the original
                behaviour), dropout stays active regardless of
                ``self.training``. If False, it is only active in training mode.
        """
        super(DecoderPreNet, self).__init__()
        self.linear_1 = nn.Linear(hp.mel_freq, hp.embedding_size)
        self.linear_2 = nn.Linear(hp.embedding_size, hp.embedding_size)
        # Plain Python attributes (not parameters/buffers), so state_dict keys
        # are identical to the original module's.
        self.dropout_p = dropout_p
        self.always_dropout = always_dropout

    def forward(self, x):
        """
        Forward pass of the DecoderPreNet.

        Args:
            x (Tensor): Input mel spectrogram; its last axis must have
                ``hp.mel_freq`` features to match ``linear_1``.

        Returns:
            Tensor: Processed features with ``hp.embedding_size`` entries on
                the last axis.
        """
        apply_dropout = self.always_dropout or self.training
        for linear in (self.linear_1, self.linear_2):
            x = F.relu(linear(x))
            x = F.dropout(x, p=self.dropout_p, training=apply_dropout)
        return x

class TransformerTTS(nn.Module):
    """
    Main Transformer-based Text-to-Speech model.

    This model combines encoder, decoder, and various auxiliary networks to
    generate mel spectrograms from text:

    - An encoder pre-net and three encoder blocks build a normalised text memory.
    - A decoder pre-net and three decoder blocks process the target mel
      (teacher forcing) while attending over that memory.
    - Linear heads predict mel frames and stop-token logits.
    - A post-net adds a residual refinement to the mel prediction.
    """

    def __init__(self, device="cuda"):
        """
        Initialize the TransformerTTS model with its components.

        Args:
            device (str): Accepted for backward compatibility but not used by
                this constructor. Masks and position indices are created on the
                devices of the input tensors in ``forward``.
        """
        super(TransformerTTS, self).__init__()
        # Construction order is kept unchanged so parameter initialisation
        # consumes the RNG in the same sequence as before.
        self.encoder_prenet = EncoderPreNet()
        self.decoder_prenet = DecoderPreNet()
        self.postnet = PostNet()
        # One learned position table (hp.max_mel_time rows) is shared by the
        # text and mel streams.
        self.pos_encoding = nn.Embedding(num_embeddings=hp.max_mel_time, embedding_dim=hp.embedding_size)
        self.encoder_block_1 = EncoderBlock()
        self.encoder_block_2 = EncoderBlock()
        self.encoder_block_3 = EncoderBlock()
        self.decoder_block_1 = DecoderBlock()
        self.decoder_block_2 = DecoderBlock()
        self.decoder_block_3 = DecoderBlock()
        self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq)
        self.linear_2 = nn.Linear(hp.embedding_size, 1)
        self.norm_memory = nn.LayerNorm(normalized_shape=hp.embedding_size)

    @staticmethod
    def _causal_additive_mask(rows, cols, device):
        """
        Build a (rows, cols) additive attention mask blocking the strict upper triangle.

        Entry (i, j) is 0.0 when j <= i and -inf when j > i. Both tensors are
        allocated directly on ``device``, so no host-to-device copy is needed.
        """
        blocked = torch.triu(
            torch.ones((rows, cols), dtype=torch.bool, device=device), diagonal=1
        )
        return torch.zeros((rows, cols), device=device).masked_fill(blocked, float("-inf"))

    def forward(self, text, text_len, mel, mel_len):
        """
        Forward pass of the TransformerTTS model.

        Args:
            text (Tensor): Input text tensor; axis 0 is batch, axis 1 is text length
            text_len (Tensor): Lengths of input texts
            mel (Tensor): Target mel spectrogram; axis 0 is batch, axis 1 is time
            mel_len (Tensor): Lengths of target mel spectrograms

        Returns:
            Tuple[Tensor, Tensor, Tensor]:
                - Predicted mel spectrogram after the post-net residual.
                - Predicted mel spectrogram before the post-net.
                - Stop-token logits with the trailing size-1 axis squeezed.

            Padded mel frames are zeroed in both mel outputs, and their stop
            logits are set to 1e3.
        """
        N = text.shape[0]
        S = text.shape[1]
        TIME = mel.shape[1]
        neg_inf = float("-inf")

        # Additive masks (0 = attend, -inf = blocked). They are stored on self,
        # as before, so code that reads them after forward keeps working.
        self.src_key_padding_mask = torch.zeros((N, S), device=text.device).masked_fill(
            ~mask_from_seq_lengths(text_len, max_length=S), neg_inf
        )
        # NOTE(review): encoder self-attention is causal, so each text position
        # only sees earlier positions. Kept as-is because trained weights depend
        # on it; confirm this is intended.
        self.src_mask = self._causal_additive_mask(S, S, text.device)
        self.tgt_key_padding_mask = torch.zeros((N, TIME), device=mel.device).masked_fill(
            ~mask_from_seq_lengths(mel_len, max_length=TIME), neg_inf
        )
        self.tgt_mask = self._causal_additive_mask(TIME, TIME, mel.device)
        # NOTE(review): decoder frame t may only attend to text positions <= t,
        # which is a hard forward-alignment constraint on cross-attention. Kept
        # as-is; confirm this is intended.
        self.memory_mask = self._causal_additive_mask(TIME, S, mel.device)

        # Encoder
        text_x = self.encoder_prenet(text)
        pos_codes = self.pos_encoding(torch.arange(hp.max_mel_time, device=mel.device))
        text_x = text_x + pos_codes[: text_x.shape[1]]
        for block in (self.encoder_block_1, self.encoder_block_2, self.encoder_block_3):
            text_x = block(
                text_x,
                attn_mask=self.src_mask,
                key_padding_mask=self.src_key_padding_mask,
            )
        text_x = self.norm_memory(text_x)

        # Decoder
        mel_x = self.decoder_prenet(mel)
        mel_x = mel_x + pos_codes[:TIME]
        for block in (self.decoder_block_1, self.decoder_block_2, self.decoder_block_3):
            mel_x = block(
                x=mel_x,
                memory=text_x,
                x_attn_mask=self.tgt_mask,
                x_key_padding_mask=self.tgt_key_padding_mask,
                memory_attn_mask=self.memory_mask,
                memory_key_padding_mask=self.src_key_padding_mask,
            )

        # Output heads: the post-net predicts a residual on top of the linear mel.
        mel_linear = self.linear_1(mel_x)
        mel_postnet = mel_linear + self.postnet(mel_linear)
        stop_token = self.linear_2(mel_x)

        # Boolean mask of padded frames with a trailing size-1 axis. masked_fill
        # broadcasts it over the mel bins, so no per-bin copy is materialised.
        pad_frames = self.tgt_key_padding_mask.ne(0).unsqueeze(-1)
        mel_linear = mel_linear.masked_fill(pad_frames, 0)
        mel_postnet = mel_postnet.masked_fill(pad_frames, 0)
        # Large positive logit on padded frames -- presumably read as "stop" by
        # the loss / inference code.
        stop_token = stop_token.masked_fill(pad_frames, 1e3).squeeze(2)

        return mel_postnet, mel_linear, stop_token