In [1]:
import torch.nn as nn
import torch
import numpy as np

In [2]:
class GLU(nn.Module):
    """Self-gated activation: scales every element by its own sigmoid.

    Note that, unlike ``torch.nn.GLU``, the input is not split in two halves;
    the same tensor serves as both the value and the gate, so the output has
    the same shape as the input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input * sigmoid(input)`` element-wise."""
        gate = torch.sigmoid(input)
        return input * gate


class PixelShuffle(nn.Module):
    """Trade channels for width on a 3-D tensor by reshaping.

    An input of shape ``(N, C, W)`` is returned as
    ``(N, C // upscale_factor, W * upscale_factor)``. The elements keep their
    flat memory order (this is a reshape, not an interleaving sub-pixel
    shuffle).

    Args:
        upscale_factor: Integer factor by which the width grows and the
            channel count shrinks. ``C`` must be divisible by it, otherwise
            the reshape raises ``RuntimeError``.
    """

    def __init__(self, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor

    def forward(self, input):
        """Reshape ``input`` according to ``self.upscale_factor``."""
        factor = self.upscale_factor
        n = input.shape[0]
        # The factor was previously hard-coded to 2, silently ignoring the
        # value passed to the constructor.
        c_out = input.shape[1] // factor
        w_new = input.shape[2] * factor
        # reshape (rather than view) also accepts non-contiguous inputs and
        # still returns a view whenever one is possible.
        return input.reshape(n, c_out, w_new)

In [3]:
class ResidualLayer(nn.Module):
    """Gated 1-D residual block.

    Two parallel Conv1d + InstanceNorm1d branches (value and gate) map
    ``in_channels -> out_channels``; their gated product is projected back to
    ``in_channels`` by a third Conv1d + InstanceNorm1d branch and added to the
    block input.

    Args:
        in_channels: Channel count of the block input and output.
        out_channels: Channel count inside the gated branches.
        kernel_size: Kernel size shared by all three convolutions.
        stride: Accepted for interface compatibility but not used; every
            convolution is built with ``stride=1``.
        padding: Padding shared by all three convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()

        def conv_norm(c_in, c_out):
            # Conv1d followed by an affine instance norm; stride is fixed to 1.
            return nn.Sequential(
                nn.Conv1d(in_channels=c_in,
                          out_channels=c_out,
                          kernel_size=kernel_size,
                          stride=1,
                          padding=padding),
                nn.InstanceNorm1d(num_features=c_out, affine=True),
            )

        # Creation order is kept (value, gate, output) so seeded
        # initialisation and state_dict layout are unchanged.
        self.conv1d_layer = conv_norm(in_channels, out_channels)
        self.conv_layer_gates = conv_norm(in_channels, out_channels)
        self.conv1d_out_layer = conv_norm(out_channels, in_channels)

    def forward(self, input):
        """Return ``input`` plus the gated, re-projected branch output."""
        value = self.conv1d_layer(input)
        gate = self.conv_layer_gates(input)
        gated = value * torch.sigmoid(gate)
        return input + self.conv1d_out_layer(gated)

In [4]:
class DownsampleLayer(nn.Module):
    """Gated 2-D convolution block.

    Two Conv2d + InstanceNorm2d branches with identical hyper-parameters are
    applied to the same input; the first is multiplied by the sigmoid of the
    second. Any spatial downsampling comes from the ``stride`` argument.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by each branch.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()

        def branch():
            # One Conv2d + affine InstanceNorm2d stack.
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=kernel_size,
                          stride=stride,
                          padding=padding),
                nn.InstanceNorm2d(num_features=out_channels, affine=True),
            )

        self.convLayer = branch()
        self.convLayer_gates = branch()

    def forward(self, input):
        """Return the value branch gated by the sigmoid of the gate branch."""
        value = self.convLayer(input)
        gate = torch.sigmoid(self.convLayer_gates(input))
        return value * gate

class UpsampleLayer(nn.Module):
    """2-D upsampling block: Conv2d -> PixelShuffle(2) -> InstanceNorm2d -> GLU.

    The pixel shuffle uses a fixed factor of 2, so the convolution's
    ``out_channels`` is divided by 4 after the shuffle (hence the norm is
    built with ``out_channels // 4`` features) while both spatial dimensions
    double.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the convolution (before shuffling).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        conv = nn.Conv2d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding)
        shuffle = nn.PixelShuffle(upscale_factor=2)
        norm = nn.InstanceNorm2d(num_features=out_channels // 4, affine=True)
        self.module = nn.Sequential(conv, shuffle, norm, GLU())

    def forward(self, x):
        """Apply the conv / shuffle / norm / activation stack to ``x``."""
        return self.module(x)

In [113]:
class Generator(nn.Module):
    """Generator network: gated 2-D conv, 2-D downsampling, six 1-D residual
    blocks, 2-D upsampling and a final 2-D convolution.

    ``forward`` adds a channel axis at position 1 before the first Conv2d and
    removes it again from the output. The reshapes between the 2-D and 1-D
    parts hard-code ``256 * 65`` channels, i.e. the network is written for
    inputs whose second-to-last axis has size 65 after the two stride-2
    downsampling stages (NOTE(review): e.g. 257 input features — confirm
    against the data pipeline).
    """

    def __init__(self):
        super().__init__()

        # Gated input convolution (value branch and gate branch).
        # TODO (from original author): confirm in_channels=1.
        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=128,
                               kernel_size=(5, 15),
                               stride=(1, 1),
                               padding=(2, 7))
        self.conv1_gates = nn.Conv2d(in_channels=1,
                                     out_channels=128,
                                     kernel_size=(5, 15),
                                     stride=1,
                                     padding=(2, 7))

        # 2-D downsampling stages.
        self.downSample1 = DownsampleLayer(in_channels=128, out_channels=256,
                                           kernel_size=5, stride=2, padding=2)
        self.downSample2 = DownsampleLayer(in_channels=256, out_channels=256,
                                           kernel_size=5, stride=2, padding=2)

        # 2-D -> 1-D projection (applied after flattening channels x features).
        self.conv2dto1dLayer = nn.Sequential(
            nn.Conv1d(in_channels=256 * 65, out_channels=256,
                      kernel_size=1, stride=1, padding=0),
            nn.InstanceNorm1d(num_features=256, affine=True),
        )

        # Six identical residual blocks, registered as residualLayer1..6 so
        # attribute names and state_dict keys stay the same.
        for block_idx in range(1, 7):
            setattr(self,
                    "residualLayer{}".format(block_idx),
                    ResidualLayer(in_channels=256, out_channels=512,
                                  kernel_size=3, stride=1, padding=1))

        # 1-D -> 2-D projection (inverse of the flattening above).
        self.conv1dto2dLayer = nn.Sequential(
            nn.Conv1d(in_channels=256, out_channels=256 * 65,
                      kernel_size=1, stride=1, padding=0),
            nn.InstanceNorm1d(num_features=256 * 65, affine=True),
        )

        # 2-D upsampling stages.
        self.upSample1 = UpsampleLayer(in_channels=256, out_channels=1024,
                                       kernel_size=5, stride=1, padding=2)
        self.upSample2 = UpsampleLayer(in_channels=256, out_channels=512,
                                       kernel_size=5, stride=1, padding=2)

        # Final projection back to a single channel.
        self.lastConvLayer = nn.Conv2d(in_channels=128,
                                       out_channels=1,
                                       kernel_size=(6, 16),
                                       stride=(1, 1),
                                       padding=(1, 6))

    def forward(self, input):
        """Run the generator on ``input`` and return the converted tensor.

        A channel axis is inserted at dim 1 on the way in and squeezed out of
        the result on the way out.
        """
        spec = input.unsqueeze(1)

        # Gated input convolution.
        hidden = self.conv1(spec) * torch.sigmoid(self.conv1_gates(spec))

        # Downsample.
        hidden = self.downSample2(self.downSample1(hidden))

        # 2-D -> 1-D: fold the feature axis into the channel axis.
        hidden = hidden.view(hidden.size(0), 256 * 65, 1, -1).squeeze(2)
        hidden = self.conv2dto1dLayer(hidden)

        # Residual stack.
        for block_idx in range(1, 7):
            block = getattr(self, "residualLayer{}".format(block_idx))
            hidden = block(hidden)

        # 1-D -> 2-D: unfold channels back into (channels, features).
        hidden = self.conv1dto2dLayer(hidden).unsqueeze(2)
        hidden = hidden.view(hidden.size(0), 256, 65, -1)

        # Upsample and project to one channel.
        hidden = self.upSample2(self.upSample1(hidden))
        return self.lastConvLayer(hidden).squeeze(1)

In [146]:
# Discriminator
class Discriminator(nn.Module):
    """Fully convolutional 2-D discriminator.

    A 3x3 convolution with GLU activation is followed by three stride-2
    downsampling stages (Conv2d + InstanceNorm2d + GLU) and a final
    single-channel convolution whose output is passed through a sigmoid, so
    the result is a map of scores in (0, 1) rather than a single scalar.
    """

    def __init__(self):
        super().__init__()

        self.convLayer1 = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=128,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=(1, 1)),
            GLU(),
        )

        # Downsampling stages.
        self.downSample1 = self.downSample(in_channels=128, out_channels=256,
                                           kernel_size=(3, 3), stride=(2, 2),
                                           padding=1)
        self.downSample2 = self.downSample(in_channels=256, out_channels=512,
                                           kernel_size=(3, 3), stride=(2, 2),
                                           padding=1)
        self.downSample3 = self.downSample(in_channels=512, out_channels=1024,
                                           kernel_size=(6, 3), stride=(2, 2),
                                           padding=1)

        # Output projection to a single-channel score map.
        self.outputConvLayer = nn.Sequential(
            nn.Conv2d(in_channels=1024,
                      out_channels=1,
                      kernel_size=(1, 3),
                      stride=(1, 1),
                      padding=(0, 1)),
        )

    def downSample(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one Conv2d + affine InstanceNorm2d + GLU stage."""
        conv = nn.Conv2d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding)
        norm = nn.InstanceNorm2d(num_features=out_channels, affine=True)
        return nn.Sequential(conv, norm, GLU())

    def forward(self, input):
        """Score ``input`` and return a sigmoid-activated map.

        The input is documented by the original author as
        ``[batch_size, num_features, time]``; a channel axis is inserted at
        dim 1 to obtain the ``[batch_size, 1, num_features, time]`` layout the
        2-D convolutions need.
        """
        features = self.convLayer1(input.unsqueeze(1))
        for stage in (self.downSample1, self.downSample2, self.downSample3):
            features = stage(features)
        return torch.sigmoid(self.outputConvLayer(features))

In [144]:
# prev # inp = torch.tensor(np.random.randn(1, 36, 128), dtype=torch.float32)
# torch.Size([1, 36, 128])
# torch.Size([1, 1, 5, 16])

In [145]:
# Shape smoke test: push one random (batch=1, features=257, time=501) tensor
# through a fresh generator and feed the result to a fresh discriminator.
generator = Generator()
discriminator = Discriminator()

# Sample directly in torch instead of round-tripping through a float64 NumPy array.
input_tensor = torch.randn(1, 257, 501, dtype=torch.float32)

# Only shapes are inspected, so skip building the autograd graph.
with torch.no_grad():
    gen_result = generator(input_tensor)
    disc_out = discriminator(gen_result)

# The generator is expected to preserve the input shape.
assert gen_result.shape == input_tensor.shape, (
    "generator changed the tensor shape: {} -> {}".format(
        tuple(input_tensor.shape), tuple(gen_result.shape)))

print(gen_result.shape)
print(disc_out.shape)

torch.Size([1, 257, 501])
torch.Size([1, 1, 31, 63])


In [147]:
# training
def train_step(
        generator_src_trg,
        generator_trg_src,
        disc_src,
        disc_trg,
        train_loader,
        gen_optimizer,
        disc_optimizer,
        config
):
    """Run one pass over ``train_loader``, alternating generator and
    discriminator updates with least-squares adversarial losses.

    Args:
        generator_src_trg: Module mapping domain A (source) to domain B (target).
        generator_trg_src: Module mapping domain B to domain A.
        disc_src: Discriminator for domain A.
        disc_trg: Discriminator for domain B.
        train_loader: Iterable yielding ``(real_A, real_B)`` tensor pairs. If
            it exposes a ``set_description`` method (as tqdm progress bars
            do), it is called once per iteration with a loss summary.
        gen_optimizer: Optimizer over the parameters of both generators.
        disc_optimizer: Optimizer over the parameters of both discriminators.
        config: Object with a ``device`` attribute. Optional attributes
            ``cycle_loss_lambda`` (default 10) and ``identity_loss_lambda``
            (default 5) weight the cycle and identity terms.

    Returns:
        A list with one dict of float metrics per iteration (keys:
        ``generator_loss``, ``d_loss``, ``generator_loss_A2B``,
        ``generator_loss_B2A``, ``identity_loss``, ``cycle_loss``,
        ``d_loss_A``, ``d_loss_B``). Empty if the loader yields nothing.
    """
    # Loss weights can be overridden from config; defaults match the values
    # that were previously hard-coded.
    cycle_loss_lambda = getattr(config, "cycle_loss_lambda", 10)
    identity_loss_lambda = getattr(config, "identity_loss_lambda", 5)

    # Progress reporting is optional: the original code referenced an
    # undefined ``pbar``; use the loader itself when it is tqdm-like.
    report = getattr(train_loader, "set_description", None)
    history = []

    for idx, (real_A, real_B) in enumerate(train_loader):
        real_A = real_A.to(config.device)
        real_B = real_B.to(config.device)

        # ------------------------- generator update -------------------------
        fake_B = generator_src_trg(real_A)
        cycle_A = generator_trg_src(fake_B)

        fake_A = generator_trg_src(real_B)
        cycle_B = generator_src_trg(fake_A)

        identity_A = generator_trg_src(real_A)
        identity_B = generator_src_trg(real_B)

        d_fake_A = disc_src(fake_A)
        d_fake_B = disc_trg(fake_B)

        # Sum of the two L1 terms. The original nested the second mean inside
        # the first, which gives the same value but hides the intent.
        cycle_loss = (torch.mean(torch.abs(real_A - cycle_A))
                      + torch.mean(torch.abs(real_B - cycle_B)))
        identity_loss = (torch.mean(torch.abs(real_A - identity_A))
                         + torch.mean(torch.abs(real_B - identity_B)))

        generator_loss_A2B = torch.mean((1 - d_fake_B) ** 2)
        generator_loss_B2A = torch.mean((1 - d_fake_A) ** 2)

        generator_loss = (generator_loss_A2B + generator_loss_B2A
                          + cycle_loss_lambda * cycle_loss
                          + identity_loss_lambda * identity_loss)

        gen_optimizer.zero_grad()
        disc_optimizer.zero_grad()
        generator_loss.backward()
        gen_optimizer.step()

        # ----------------------- discriminator update -----------------------
        # The generators are not trained in this phase, so their outputs are
        # produced without autograd tracking; this avoids backpropagating the
        # discriminator loss through both generators.
        with torch.no_grad():
            generated_A = generator_trg_src(real_B)
            cycled_B = generator_src_trg(generated_A)
            generated_B = generator_src_trg(real_A)
            cycled_A = generator_trg_src(generated_B)

        d_real_A = disc_src(real_A)
        d_real_B = disc_trg(real_B)
        d_fake_A = disc_src(generated_A)
        d_cycled_B = disc_trg(cycled_B)
        d_fake_B = disc_trg(generated_B)
        d_cycled_A = disc_src(cycled_A)

        d_loss_A_real = torch.mean((1 - d_real_A) ** 2)
        d_loss_A_fake = torch.mean((0 - d_fake_A) ** 2)
        d_loss_A = (d_loss_A_real + d_loss_A_fake) / 2.0

        d_loss_B_real = torch.mean((1 - d_real_B) ** 2)
        d_loss_B_fake = torch.mean((0 - d_fake_B) ** 2)
        d_loss_B = (d_loss_B_real + d_loss_B_fake) / 2.0

        # Second adversarial term: cycled samples are also treated as fake.
        d_loss_A_cycled = torch.mean((0 - d_cycled_A) ** 2)
        d_loss_B_cycled = torch.mean((0 - d_cycled_B) ** 2)

        d_loss_A_2nd = (d_loss_A_real + d_loss_A_cycled) / 2.0
        d_loss_B_2nd = (d_loss_B_real + d_loss_B_cycled) / 2.0

        d_loss = (d_loss_A + d_loss_B) / 2.0 + (d_loss_A_2nd + d_loss_B_2nd) / 2.0

        disc_optimizer.zero_grad()
        gen_optimizer.zero_grad()
        d_loss.backward()
        disc_optimizer.step()

        metrics = {
            "generator_loss": generator_loss.item(),
            "d_loss": d_loss.item(),
            "generator_loss_A2B": generator_loss_A2B.item(),
            "generator_loss_B2A": generator_loss_B2A.item(),
            "identity_loss": identity_loss.item(),
            "cycle_loss": cycle_loss.item(),
            "d_loss_A": d_loss_A.item(),
            "d_loss_B": d_loss_B.item(),
        }
        history.append(metrics)

        if callable(report):
            report(
                "Iter:{} Generator Loss:{:.4f} Discrimator Loss:{:.4f} GA2B:{:.4f} GB2A:{:.4f} G_id:{:.4f} G_cyc:{:.4f} D_A:{:.4f} D_B:{:.4f}".format(
                    idx,
                    metrics["generator_loss"],
                    metrics["d_loss"],
                    metrics["generator_loss_A2B"],
                    metrics["generator_loss_B2A"],
                    metrics["identity_loss"],
                    metrics["cycle_loss"],
                    metrics["d_loss_A"],
                    metrics["d_loss_B"])
            )

    return history

In [None]:
def main(config):
    """Entry point placeholder; not implemented yet.

    Args:
        config: Run configuration (currently unused).

    Raises:
        NotImplementedError: Always. It is a subclass of ``Exception``, so
            callers that caught the previous generic ``Exception`` still work.
    """
    raise NotImplementedError("tO be implemented!")