In [1]:
import torch.nn as nn
import torch.nn.functional as F

class CombinedModel(nn.Module):
    """Mel projection, stop projection and convolutional postnet in one module.

    The module owns three pieces:

    * ``layer1`` / ``layer2``: two ``LinearNorm`` layers mapping
      ``embed_dim -> embed_dim // 2 -> num_mels`` with a ReLU in between.
    * ``stop_linear``: a plain ``nn.Linear`` mapping ``embed_dim -> 1``.
    * ``convolutions``: a stack of ``ConvNorm`` + ``BatchNorm1d`` pairs whose
      output is added back onto the mel projection as a residual.

    Args:
        configs: Nested mapping. The keys read here are
            ``configs['EncDec_Configs']['embed_dim']``,
            ``configs['Audio_Configs']['num_mels']`` and, under
            ``configs['Postnet_Configs']``, ``postnet_embedding_dim``,
            ``postnet_kernel_size`` and ``postnet_n_convolutions``.
        dropout: Probability used by every ``nn.Dropout`` in the module.

    Note:
        The first and last postnet convolutions are always created, so the
        stack holds at least two blocks even when ``postnet_n_convolutions``
        is smaller than 2.
    """

    def __init__(self, configs, dropout=0.1):
        super().__init__()

        embed_dim = configs['EncDec_Configs']['embed_dim']
        num_mels = configs['Audio_Configs']['num_mels']
        half_dim = int(embed_dim // 2)

        # Mel projection head.
        self.layer1 = LinearNorm(embed_dim, half_dim)
        self.layer2 = LinearNorm(half_dim, num_mels)
        self.drop1 = nn.Dropout(dropout)
        self.drop2 = nn.Dropout(dropout)

        # Postnet: num_mels -> postnet_dim -> ... -> postnet_dim -> num_mels.
        postnet_cfg = configs['Postnet_Configs']
        postnet_dim = postnet_cfg['postnet_embedding_dim']
        kernel_size = postnet_cfg['postnet_kernel_size']
        n_convolutions = postnet_cfg['postnet_n_convolutions']

        self.dropout = nn.Dropout(dropout)
        channel_plan = (
            [(num_mels, postnet_dim, 'tanh')]
            + [(postnet_dim, postnet_dim, 'tanh')
               for _ in range(1, n_convolutions - 1)]
            + [(postnet_dim, num_mels, 'linear')]
        )
        self.convolutions = nn.ModuleList([
            self._conv_block(in_ch, out_ch, kernel_size, gain)
            for in_ch, out_ch, gain in channel_plan
        ])

        # Stop projection reads the same input as the mel head.
        self.stop_linear = nn.Linear(embed_dim, 1)

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, w_init_gain):
        """Build one postnet stage: a ``ConvNorm`` followed by ``BatchNorm1d``."""
        return nn.Sequential(
            ConvNorm(in_channels, out_channels,
                     kernel_size=kernel_size, stride=1,
                     # NOTE(review): presumably "same" padding for odd kernel
                     # sizes, assuming ConvNorm forwards these arguments to a
                     # 1-D convolution -- confirm against ConvNorm.
                     padding=int((kernel_size - 1) / 2),
                     dilation=1, w_init_gain=w_init_gain),
            nn.BatchNorm1d(out_channels),
        )

    def forward(self, x):
        """Run the mel head, the stop head and the postnet on ``x``.

        Args:
            x: Input tensor fed to both ``layer1`` and ``stop_linear``.
                Presumably ``(batch, time, embed_dim)``; the code only fixes
                that the mel projection has dims 1 and 2 to swap -- confirm
                against the caller.

        Returns:
            A 3-tuple ``(x_mel_linear, stoplinear_output, mel_out)``: the mel
            projection, the stop projection, and the mel projection plus the
            postnet residual.
        """
        # Mel projection; dropout is applied after both linear layers.
        x_mel_linear = self.drop1(F.relu(self.layer1(x)))
        x_mel_linear = self.drop2(self.layer2(x_mel_linear))

        # Stop projection works on the raw input, not on the mel projection.
        stoplinear_output = self.stop_linear(x)

        # The convolution stack consumes dims 1 and 2 swapped; swap back after.
        x_postnet = x_mel_linear.transpose(1, 2)
        *hidden_convs, final_conv = self.convolutions
        for conv in hidden_convs:
            # Tensor.tanh() equals torch.tanh() but does not need the bare
            # ``torch`` name, which ``import torch.nn as nn`` does not bind.
            x_postnet = self.dropout(conv(x_postnet).tanh())
        # The last stage has no tanh, only dropout.
        x_postnet = self.dropout(final_conv(x_postnet))
        x_postnet = x_postnet.transpose(1, 2)

        # Postnet output is a residual on top of the mel projection.
        mel_out = x_mel_linear + x_postnet

        return x_mel_linear, stoplinear_output, mel_out