# pix2pix

Adapted from https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/pix2pix

In [1]:
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys
import pickle
from multiprocessing import Pool

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

from pix2pix.pix2pix_model import *
from unet.unet.unet_parts import *
from mlp import audio
from mlp import normalization
from mlp import utils as mlp
from mlp.dataset import WAVAudioDS, PolarPreprocessing

import torch.nn as nn
import torch.nn.functional as F
import torch
%load_ext autoreload

In [2]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# With no argument, %autoreload triggers a one-time reload of imported modules
# (the autoreload extension itself is loaded at the end of the imports cell).
%autoreload

In [3]:
# Training configuration. Every value below is read by later cells as a
# module-level global.
epoch = 0 # epoch to start training from (non-zero resumes from the saved '*_<epoch>.pth' checkpoints)
n_epochs = 30 # epoch index at which training stops (the loop runs range(epoch, n_epochs))
dataset_name = 'VCTK' # name of the dataset
batch_size = 4 # size of the batches
lr = 0.0002 # adam: learning rate
b1 = 0.5 # adam: decay of first order momentum of gradient
b2 = 0.999 # adam: decay of second order momentum of gradient
decay_epoch = 100 # epoch from which to start lr decay (NOTE(review): not referenced in the cells shown here)
n_cpu = 4 # number of cpu threads to use during batch generation
img_height = 64 # size of image height
img_width = 64 # size of image width
channels = 1 # number of image channels
sample_interval = 10000 # interval (in batches) between sampling of images from generators
checkpoint_interval = 1 # interval (in epochs) between model checkpoints; -1 disables checkpointing

In [4]:
# Probe CUDA once and derive both the boolean flag and the device handle from it
cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda else "cpu")

# Weight of the L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

# Loss functions: L1 for the pixel-wise term, MSE for the adversarial term
criterion_pixelwise = torch.nn.L1Loss()
criterion_GAN = torch.nn.MSELoss()

# Shape of the image discriminator (PatchGAN) output: one channel, with each
# spatial dimension divided by 2**4 = 16
# (presumably four stride-2 stages in Discriminator -- TODO confirm)
patch = (1, img_height // 16, img_width // 16)

In [5]:
class UNet(nn.Module):
    """U-Net generator assembled from the ``unet_parts`` building blocks.

    The network applies an input convolution followed by four ``down``
    stages (the two deepest are built with ``dropout=0.5``), then four ``up``
    stages (built with ``bilinear=True``) that each receive the previous
    output together with the matching encoder feature map, and finally an
    ``outconv`` projection.

    Attribute names are part of the ``state_dict`` key layout and must not be
    renamed, otherwise previously saved weights will no longer load.

    Args:
        in_channels: number of channels of the input tensor. Defaults to 1,
            which reproduces the previously hard-coded architecture.
        out_channels: number of channels of the output tensor. Defaults to 1,
            which reproduces the previously hard-coded architecture.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super(UNet, self).__init__()
        # Encoder
        self.inc = inconv(in_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512, dropout=0.5)
        self.down4 = down(512, 512, dropout=0.5)
        # Decoder: the first argument of each `up` is presumably the channel
        # count after joining the skip connection -- TODO confirm in unet_parts
        self.up1 = up(1024, 256, bilinear=True)
        self.up2 = up(512, 128, bilinear=True)
        self.up3 = up(256, 64, bilinear=True)
        self.up4 = up(128, 64, bilinear=True)
        self.outc = outconv(64, out_channels)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Args:
            x: input batch accepted by ``inconv``.

        Returns:
            The output of ``outconv`` applied to the last decoder stage.
        """
        # Encoder: keep every intermediate feature map for the skip connections
        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        bottleneck = self.down4(enc4)
        # Decoder: pair each stage with the encoder output of the same depth
        dec = self.up1(bottleneck, enc4)
        dec = self.up2(dec, enc3)
        dec = self.up3(dec, enc2)
        dec = self.up4(dec, enc1)
        return self.outc(dec)
    
# Instantiate the generator with its default arguments and move it to `device`
generator = UNet().to(device)

In [6]:
# Build the discriminator and place both networks and the loss modules on
# `device`. `.to(device)` is a no-op on CPU-only hosts, so no `if cuda:` guard
# is needed.
discriminator = Discriminator(in_channels=channels).to(device)
generator = generator.to(device)
criterion_GAN.to(device)
criterion_pixelwise.to(device)

if epoch != 0:
    # Resume from the checkpoints written by the training loop.
    # map_location='cpu' lets checkpoints saved on a GPU load on a machine
    # without one; load_state_dict then copies the values into the modules'
    # own parameters, wherever those live.
    generator.load_state_dict(torch.load('generator_%d.pth' % (epoch), map_location='cpu'))
    discriminator.load_state_dict(torch.load('discriminator_%d.pth' % (epoch), map_location='cpu'))
else:
    # Fresh run: the generator starts from pretrained weights, the
    # discriminator from the random initialisation in weights_init_normal.
    #generator.apply(weights_init_normal)
    generator.load_state_dict(torch.load('model_weights.pt', map_location='cpu'))
    discriminator.apply(weights_init_normal)

# Optimizers (created after the models are on their final device)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

In [7]:
fs = 48000
bs = batch_size
stroke_width = 32
patch_width = img_width
patch_height = img_height
nperseg = 256


def _load_pickle(path):
    """Return the object pickled at ``path``, closing the file afterwards.

    NOTE: unpickling can execute arbitrary code -- only use this on file
    lists that were produced locally / come from a trusted source.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


train_files = _load_pickle("train.pk")
val_files = _load_pickle("valid.pk")

# Mask used to build the network input from a target patch; converted to
# float so it can be multiplied with the patch.
stroke_mask = mlp.build_stroke_purge_mask(patch_width, patch_height, stroke_width, fs, nperseg)
purge_mask = stroke_mask.float()


def _apply_purge_mask(x):
    """Build the source (gapped) patch by multiplying with ``purge_mask``."""
    return x * purge_mask


preprocess = PolarPreprocessing(
    normalization.norm_mag,
    normalization.norm_phase,
    patch_width,
    patch_height
)

torch.multiprocessing.set_sharing_strategy('file_system')
# The pool is only handed to the dataset constructors; it is closed when the
# `with` block exits.
with Pool(n_cpu) as p:
    ds_valid = WAVAudioDS(files=val_files, mk_source=_apply_purge_mask, preprocess=preprocess,
                          patch_width=patch_width, proc_pool=p, nperseg=nperseg, random_patches=True)
    ds_train = WAVAudioDS(files=train_files, mk_source=_apply_purge_mask, preprocess=preprocess,
                          patch_width=patch_width, proc_pool=p, nperseg=nperseg, random_patches=True)

val_dataloader = DataLoader(ds_valid, batch_size=bs, num_workers=n_cpu, shuffle=False)
# NOTE(review): the training loader is not shuffled; presumably
# random_patches=True already randomises what each index yields -- confirm.
dataloader = DataLoader(ds_train, batch_size=bs, num_workers=n_cpu, shuffle=False)

HBox(children=(IntProgress(value=0), HTML(value='')))




HBox(children=(IntProgress(value=0, max=200), HTML(value='')))




In [8]:
# Float tensor type matching the selected device; used to cast input batches
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

def sample_images(batches_done):
    """Save a generated sample from the validation set.

    Takes the first batch yielded by ``val_dataloader``, keeps channel 0 of
    the source and target tensors, runs the generator on the source and
    writes source, generated output and target -- concatenated along
    dimension -2 -- to ``'<batches_done>.pt'`` in the working directory.

    Args:
        batches_done: value used as the stem of the output file name.
    """
    imgs = next(iter(val_dataloader))
    real_A = torch.unsqueeze(imgs[0][:,0,:,:], 1).type(Tensor)
    real_B = torch.unsqueeze(imgs[1][:,0,:,:], 1).type(Tensor)
    # Sampling needs no gradients: skip building the autograd graph.
    # The generator's train/eval mode is left untouched to preserve the
    # existing output.
    with torch.no_grad():
        fake_B = generator(real_A)
    img_sample = torch.cat((real_A, fake_B, real_B), -2)
    torch.save(img_sample, '%s.pt' % (batches_done))
    #save_image(img_sample, 'images/%s/%s.png' % (dataset_name, batches_done), nrow=5, normalize=True)

# Training

In [9]:
def _save_results(history, path='results.pkl'):
    """Pickle the per-batch loss history to ``path``."""
    with open(path, 'wb') as fp:
        pickle.dump(history, fp)


prev_time = time.time()
# One tuple per batch: (epoch, loss_D, loss_G, loss_pixel, loss_GAN)
results = []
# Number of batches per epoch; constant, so computed once
n_batches = len(dataloader)

try:
    # NOTE: the loop variable deliberately reuses the global `epoch`, as before.
    for epoch in range(epoch, n_epochs):
        for i, batch in enumerate(dataloader):
            # Model inputs: channel 0 of source / target, as (N, 1, H, W)
            real_A = torch.unsqueeze(batch[0][:,0,:,:], 1).type(Tensor) # Gap
            real_B = torch.unsqueeze(batch[1][:,0,:,:], 1).type(Tensor) # Original (No gap)
            # Adversarial ground truths, created with the dtype and device of
            # the inputs and sized for the actual batch
            valid = real_A.new_ones((real_A.size(0), *patch))
            fake = real_A.new_zeros((real_A.size(0), *patch))

            # ------------------
            #  Train Generators
            # ------------------

            optimizer_G.zero_grad()

            # GAN loss
            fake_B = generator(real_A)
            pred_fake = discriminator(fake_B, real_A)
            loss_GAN = criterion_GAN(pred_fake, valid)

            # Pixel-wise loss
            loss_pixel = criterion_pixelwise(fake_B, real_B)

            # Total loss
            loss_G = loss_GAN + lambda_pixel * loss_pixel

            loss_G.backward()

            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Real loss
            pred_real = discriminator(real_B, real_A)
            loss_real = criterion_GAN(pred_real, valid)

            # Fake loss (detach so no gradient flows back into the generator)
            pred_fake = discriminator(fake_B.detach(), real_A)
            loss_fake = criterion_GAN(pred_fake, fake)

            # Total loss
            loss_D = 0.5 * (loss_real + loss_fake)

            loss_D.backward()
            optimizer_D.step()

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left from the duration of the last batch
            batches_done = epoch * n_batches + i
            batches_left = n_epochs * n_batches - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s" %
                                                            (epoch, n_epochs,
                                                            i, n_batches,
                                                            loss_D.item(), loss_G.item(),
                                                            loss_pixel.item(), loss_GAN.item(),
                                                            time_left))
            results.append((epoch, loss_D.item(), loss_G.item(), loss_pixel.item(), loss_GAN.item()))

            # If at sample interval save image
            if batches_done % sample_interval == 0:
                sample_images(batches_done)

        # Persist the loss history once per epoch; re-pickling the whole list
        # on every batch made total I/O grow quadratically with batch count.
        _save_results(results)

        if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(generator.state_dict(), 'generator_%d.pth' % (epoch))
            torch.save(discriminator.state_dict(), 'discriminator_%d.pth' % (epoch))
finally:
    # Also runs on KeyboardInterrupt or an error, so the history of a
    # partially completed epoch is not lost.
    _save_results(results)



[Epoch 0/30] [Batch 156/443] [D loss: 0.320208] [G loss: 23.890461, pixel: 0.232825, adv: 0.608012] ETA: 4:22:57.932459

KeyboardInterrupt: 