In [1]:
import torch
import torch.nn as nn

In [2]:
def conv_block(in_channels, out_channels, kernel_size, stride, padding):
    """Build one Conv1d -> BatchNorm1d -> ReLU stage.

    Args:
        in_channels: channel count of the incoming tensor.
        out_channels: channels produced by the convolution; also the number of
            features normalized by the BatchNorm1d layer.
        kernel_size, stride, padding: forwarded positionally to ``nn.Conv1d``.

    Returns:
        An ``nn.Sequential`` holding the convolution (index 0), the batch norm
        (index 1) and the ReLU (index 2).
    """
    layers = [
        nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding),
        nn.BatchNorm1d(out_channels),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

In [3]:
class WavGAN(nn.Module):
    """Small 1-D convolutional generator that doubles the length of its input.

    Pipeline: conv_block(1->64) -> conv_block(64->128) -> nearest-neighbour x2
    upsample -> conv_block(128->64) -> Conv1d(64->1) with no norm/activation.

    NOTE(review): despite the name, this is not the published WaveGAN
    architecture; it is layer-for-layer identical to the MelGAN and HiFiGAN
    classes in this notebook.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept as-is so parameter init and state_dict order match.
        # Kernel 3 / stride 1 / padding 1 preserves the sequence length.
        self.conv1 = conv_block(1, 64, 3, 1, 1)
        self.conv2 = conv_block(64, 128, 3, 1, 1)
        # Repeats every time step twice: the only stage that changes the length.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv3 = conv_block(128, 64, 3, 1, 1)
        # Final projection back to a single channel, left linear.
        self.conv4 = nn.Conv1d(64, 1, 3, 1, 1)

    def forward(self, x):
        """Run the stages in order and return the single-channel, 2x-longer output."""
        for stage in (self.conv1, self.conv2, self.upsample, self.conv3, self.conv4):
            x = stage(x)
        return x

In [4]:
class MelGAN(nn.Module):
    """Small 1-D convolutional generator that doubles the length of its input.

    Pipeline: conv_block(1->64) -> conv_block(64->128) -> nearest-neighbour x2
    upsample -> conv_block(128->64) -> Conv1d(64->1) with no norm/activation.

    NOTE(review): despite the name, this is not the published MelGAN
    architecture (it takes a 1-channel input, not a mel spectrogram); it is
    layer-for-layer identical to the WavGAN and HiFiGAN classes in this notebook.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept as-is so parameter init and state_dict order match.
        # Kernel 3 / stride 1 / padding 1 preserves the sequence length.
        self.conv1 = conv_block(1, 64, 3, 1, 1)
        self.conv2 = conv_block(64, 128, 3, 1, 1)
        # Repeats every time step twice: the only stage that changes the length.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv3 = conv_block(128, 64, 3, 1, 1)
        # Final projection back to a single channel, left linear.
        self.conv4 = nn.Conv1d(64, 1, 3, 1, 1)

    def forward(self, x):
        """Run the stages in order and return the single-channel, 2x-longer output."""
        for stage in (self.conv1, self.conv2, self.upsample, self.conv3, self.conv4):
            x = stage(x)
        return x


In [5]:
class HiFiGAN(nn.Module):
    """Small 1-D convolutional generator that doubles the length of its input.

    Pipeline: conv_block(1->64) -> conv_block(64->128) -> nearest-neighbour x2
    upsample -> conv_block(128->64) -> Conv1d(64->1) with no norm/activation.

    NOTE(review): despite the name, this is not the published HiFi-GAN
    architecture; it is layer-for-layer identical to the WavGAN and MelGAN
    classes in this notebook.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept as-is so parameter init and state_dict order match.
        # Kernel 3 / stride 1 / padding 1 preserves the sequence length.
        self.conv1 = conv_block(1, 64, 3, 1, 1)
        self.conv2 = conv_block(64, 128, 3, 1, 1)
        # Repeats every time step twice: the only stage that changes the length.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv3 = conv_block(128, 64, 3, 1, 1)
        # Final projection back to a single channel, left linear.
        self.conv4 = nn.Conv1d(64, 1, 3, 1, 1)

    def forward(self, x):
        """Run the stages in order and return the single-channel, 2x-longer output."""
        for stage in (self.conv1, self.conv2, self.upsample, self.conv3, self.conv4):
            x = stage(x)
        return x

In [6]:
class AudioSuperResolutionModel(nn.Module):
    """Ensemble that averages the outputs of a WavGAN, a MelGAN and a HiFiGAN.

    Each sub-generator is constructed fresh here (no weights are shared with
    any other instances), receives the same raw input, and the three outputs
    are combined by an unweighted mean.
    """

    def __init__(self):
        super().__init__()
        self.wavgan = WavGAN()
        self.melgan = MelGAN()
        self.hifigan = HiFiGAN()

    def forward(self, input_audio):
        """Return the element-wise mean of the three generators' outputs."""
        # Branches are evaluated in the order wavgan, melgan, hifigan.
        branch_sum = self.wavgan(input_audio) + self.melgan(input_audio) + self.hifigan(input_audio)
        return branch_sum / 3

In [7]:
# Standalone generators plus the ensemble. The ensemble builds its own
# WavGAN/MelGAN/HiFiGAN internally, so it does not share weights with these three.
wavgan_model, melgan_model, hifigan_model = WavGAN(), MelGAN(), HiFiGAN()
ensemble_model = AudioSuperResolutionModel()

In [9]:
def summary(model, input_size):
    """Print a plain-text summary of a PyTorch module.

    Prints the module's repr, its total and trainable parameter counts and,
    when ``input_size`` is given, the output produced by a dummy forward pass
    on a zero tensor of that shape.

    Args:
        model: any ``nn.Module``.
        input_size: shape of one input tensor, e.g. ``(batch, channels, length)``
            for the Conv1d-based models in this notebook. Pass ``None`` to skip
            the forward-pass check.
    """
    print("Model Summary:")
    print("=" * 50)
    print(model)
    print("=" * 50)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
    if input_size is not None:
        # Remember every submodule's mode so the caller's train/eval state is
        # restored exactly, even if submodules were in mixed modes.
        saved_modes = [(module, module.training) for module in model.modules()]
        # eval() makes BatchNorm use (and not update) its running statistics,
        # so the probe has no side effects on the model.
        model.eval()
        try:
            first_param = next(model.parameters(), None)
            if first_param is None:
                dummy = torch.zeros(input_size)
            else:
                # Match the weights' device/dtype so the probe works on GPU / half models.
                dummy = torch.zeros(input_size, device=first_param.device, dtype=first_param.dtype)
            with torch.no_grad():
                output = model(dummy)
        finally:
            for module, mode in saved_modes:
                module.training = mode
        if isinstance(output, torch.Tensor):
            out_desc = tuple(output.shape)
        else:
            out_desc = type(output).__name__
        print(f"Input shape: {tuple(dummy.shape)}")
        print(f"Output shape: {out_desc}")
    print("=" * 50)

In [10]:
# Example input laid out as (batch, channels, length): the generators start with
# a 1-input-channel Conv1d, so the channel dimension must be 1.
input_size = (1, 1, 100)
summary(model=ensemble_model, input_size=input_size)

Model Summary:
AudioSuperResolutionModel(
  (wavgan): WavGAN(
    (conv1): Sequential(
      (0): Conv1d(1, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (conv2): Sequential(
      (0): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (upsample): Upsample(scale_factor=2.0, mode='nearest')
    (conv3): Sequential(
      (0): Conv1d(128, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (conv4): Conv1d(64, 1, kernel_size=(3,), stride=(1,), padding=(1,))
  )
  (melgan): MelGAN(
    (conv1): Sequential(
      (0): Conv1d(1, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, a