In [51]:
import torchaudio
import torch

In [52]:
# load file at data/vcc2016_training/SF1/100001.wav
waveform, sample_rate = torchaudio.load('../data/vcc2016_training/SF1/100001.wav')
waveform.size()

torch.Size([1, 56314])

# TODO:
- Just try to implement the architecture of CycleGAN
  - <<Here>>
- Then run a single training loop
- Once that works, train the network with SGD
  - This might require actually implementing the dataset

# Next:
- Debug shape mismatches with sample input

In [69]:
class Downsample(torch.nn.Module):
   
    def __init__(self, in_channels, out_channels, kernel_size, stride):
      super(Downsample, self).__init__()
      self.in_channels = in_channels
      self.out_channels = out_channels
      self.conv = torch.nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride)
      self.norm = torch.nn.InstanceNorm2d(num_features=out_channels, affine=True)
      self.glu = torch.nn.GLU(dim=1)
    
    def forward(self, x):
      # from IPython.core.debugger import set_trace;set_trace()
      x = self.conv(x)
      x = self.norm(x)
      x = self.glu(x)
      return x
  
class ResidualBlock(torch.nn.Module):
  def __init__(self, in_channels1, out_channels1, in_channels2, out_channels2, kernel_size, stride):
    super(ResidualBlock, self).__init__()
    self.conv1 = torch.nn.Conv2d(in_channels=in_channels1,
                                out_channels=out_channels1,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding='same')
    self.norm1 = torch.nn.InstanceNorm2d(num_features=out_channels1, affine=True)
    self.glu = torch.nn.GLU(dim=1)
    self.conv2 = torch.nn.Conv2d(in_channels=in_channels2,
                                out_channels=out_channels2,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding='same')
    self.norm2 = torch.nn.InstanceNorm2d(num_features=out_channels2, affine=True)
  
  def forward(self, x):
    residual = x.clone()
    x = self.conv1(x)
    x = self.norm1(x)
    x = self.glu(x)
    x = self.conv2(x)
    x = self.norm2(x)
    return x + residual

class Upsample(torch.nn.Module):
    
    def __init__(self, in_channels, out_channels, kernel_size, stride):
      super(Upsample, self).__init__()
      self.conv = torch.nn.Conv2d(in_channels=in_channels,
                                          out_channels=out_channels,
                                          kernel_size=kernel_size,
                                          stride=stride)
      self.pixel_shuffle = torch.nn.PixelShuffle(upscale_factor=2)
      self.norm = torch.nn.InstanceNorm2d(num_features=out_channels/2, affine=True)
      self.glu = torch.nn.GLU(dim=1)
    
    def forward(self, x):
      x = self.conv(x)
      x = self.pixel_shuffle(x)
      x = self.norm(x)
      x = self.glu(x)
      return x

class Generator(torch.nn.Module):
    
    def __init__(self):
      super(Generator, self).__init__()
      self.conv1 = torch.nn.Conv2d(in_channels=24, out_channels=128, kernel_size=(1, 5), stride=(1, 2))
      self.glu = torch.nn.GLU(dim=1)
      self.downsample_twice = torch.nn.Sequential(
        Downsample(in_channels=64, out_channels=256, kernel_size=(1, 5), stride=(1, 2)),
        Downsample(in_channels=128, out_channels=512*2, kernel_size=(1, 5), stride=(1, 2))
      )
      self.residual_blocks = torch.nn.Sequential(
        *[ResidualBlock(in_channels1=512, out_channels1=1024,
                        in_channels2=512, out_channels2=512,
                        kernel_size=(1, 3), stride=(1, 1)) for _ in range(6)]
      )
      self.upsample_twice = torch.nn.Sequential(
         Upsample(in_channels=512, out_channels=1024, kernel_size=(1, 5), stride=(1, 1)),
         Upsample(in_channels=512, out_channels=512, kernel_size=(1, 5), stride=(1, 1)),
      )
      self.conv2 = torch.nn.Conv2d(in_channels=256, out_channels=24, kernel_size=(1, 15), stride=(1, 1))

    def forward(self, x):
      x = self.conv1(x)
      x = self.glu(x)
      x = self.downsample_twice(x)
      # from IPython.core.debugger import set_trace;set_trace()
      x = self.residual_blocks(x)
      x = self.upsample_twice(x)
      x = self.conv2(x)
      return x

In [70]:
def test_residual_block():
  residual = ResidualBlock(in_channels1=1024, out_channels1=1024, 
              in_channels2=512, out_channels2=1024,
              kernel_size=(1, 3), stride=(1, 1))

  residual.forward(torch.randn(1, 1024, 1, 1024)).size()
# test_residual_block()

def test_downsample_block():
  downsample = Downsample(in_channels=24, out_channels=256, kernel_size=(1, 5), stride=(1, 2))
  print(downsample.forward(torch.randn(1, 24, 1, 1024)).size())
# test_downsample_block()

def test_generator():
  generator = Generator()
  generator.forward(torch.randn(1, 24, 1, 1024)).size()
test_generator()

RuntimeError: weight should contain 256 elements not 2048