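# pix2pix

Adapted from the PyTorch-GAN pix2pix implementation (https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/pix2pix) and applied here to two-channel (magnitude/phase) patches extracted from VCTK audio recordings.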

In [1]:
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys
import pickle
from multiprocessing import Pool

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

from pix2pix.pix2pix_model import *
from mlp import audio
from mlp import normalization
from mlp import utils as mlp
from mlp.dataset import WAVAudioDS, PolarPreprocessing

import torch.nn as nn
import torch.nn.functional as F
import torch
%load_ext autoreload

In [2]:
%matplotlib inline
# Bare `%autoreload` reloads all (non-excluded) modules once, right now.
# NOTE(review): if edits to pix2pix/mlp should be picked up automatically before
# every cell execution, `%autoreload 2` is probably what was intended — confirm.
%autoreload

In [3]:
# ---- Training configuration (everything a reader might tune) ----
epoch = 0  # epoch to start training from
n_epochs = 30  # number of epochs of training
dataset_name = 'VCTK'  # name of the dataset
batch_size = 1  # size of the batches
lr = 0.0002  # adam: learning rate
b1 = 0.5  # adam: decay of first-order momentum (moving average of the gradient)
b2 = 0.999  # adam: decay of second-order momentum (moving average of the squared gradient)
# NOTE(review): decay_epoch is not used by the training loop in this notebook — no lr
# scheduler is created, so the learning rate stays constant.
decay_epoch = 100  # epoch from which to start lr decay
n_cpu = 8  # number of cpu threads to use during batch generation
img_height = 64  # patch height fed to the networks
img_width = 64  # patch width fed to the networks
channels = 2  # number of input/output channels (presumably magnitude and phase — confirm)
sample_interval = 500  # save a validation sample every this many batches
checkpoint_interval = -1  # save model checkpoints every this many epochs; -1 disables checkpointing

In [4]:
# Use the GPU whenever torch can see one (is_available() already returns a bool).
cuda = torch.cuda.is_available()

# Loss functions:
#  - criterion_GAN: MSE between discriminator outputs and all-ones / all-zeros targets
#    (a least-squares adversarial objective).
#  - criterion_pixelwise: L1 reconstruction loss between generated and target patches.
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100

# Shape of one PatchGAN discriminator output map: a single channel whose spatial
# size is the input patch shrunk by a factor of 2**4.
# NOTE(review): assumes Discriminator downsamples four times by 2 — confirm in pix2pix_model.
_patch_downsample = 2 ** 4
patch = (1, img_height // _patch_downsample, img_width // _patch_downsample)

In [5]:
# Build the generator/discriminator pair; both consume and produce `channels`-channel patches.
generator = GeneratorUNet(in_channels=channels, out_channels=channels)
discriminator = Discriminator(in_channels=channels)

if cuda:
    # Move both networks, then both loss modules, onto the GPU (same order as before).
    generator, discriminator = generator.cuda(), discriminator.cuda()
    for _criterion in (criterion_GAN, criterion_pixelwise):
        _criterion.cuda()

# Apply the custom weight initialisation to every submodule, generator first.
for _net in (generator, discriminator):
    _net.apply(weights_init_normal)

# One Adam optimizer per network, sharing the lr / betas from the config cell.
optimizer_G, optimizer_D = (
    torch.optim.Adam(_net.parameters(), lr=lr, betas=(b1, b2))
    for _net in (generator, discriminator)
)

In [6]:
fs = 48000  # presumably the audio sample rate in Hz (passed to build_stroke_purge_mask)
bs = batch_size
stroke_width = 32
patch_width = img_width
patch_height = img_height
nperseg = 256  # presumably the STFT segment length; shared by the mask builder and both datasets

# NOTE(review): pickle.load can execute arbitrary code — only load trusted split files.
# `with` guarantees the file handles are closed.
with open("train.pk", "rb") as fp:
    train_files = pickle.load(fp)[:4000]  # only the first 4000 training files are used
with open("valid.pk", "rb") as fp:
    val_files = pickle.load(fp)

stroke_mask = mlp.build_stroke_purge_mask(patch_width, patch_height, stroke_width, fs, nperseg)
purge_mask = stroke_mask.float()

preprocess = PolarPreprocessing(
    normalization.norm_mag,
    normalization.norm_phase,
    patch_width,
    patch_height
)


def mask_source(x):
    """Return ``x * purge_mask``; passed to both datasets as their ``mk_source`` callable."""
    return x * purge_mask


torch.multiprocessing.set_sharing_strategy('file_system')
# NOTE(review): the pool is terminated when this `with` exits, so WAVAudioDS presumably
# only uses `proc_pool` while it is being constructed — confirm in mlp.dataset.
with Pool(n_cpu) as p:
    ds_valid = WAVAudioDS(files=val_files, mk_source=mask_source, preprocess=preprocess,
                          patch_width=patch_width, proc_pool=p, nperseg=nperseg, random_patches=False)
    ds_train = WAVAudioDS(files=train_files, mk_source=mask_source, preprocess=preprocess,
                          patch_width=patch_width, proc_pool=p, nperseg=nperseg, random_patches=False)

val_dataloader = DataLoader(ds_valid, batch_size=bs, num_workers=n_cpu, shuffle=False)
# NOTE(review): training batches are not shuffled — confirm this is intentional.
dataloader = DataLoader(ds_train, batch_size=bs, num_workers=n_cpu, shuffle=False)

HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4000), HTML(value='')))




In [7]:
# Dataloaders are built from WAVAudioDS in the previous cell; the image-folder
# loaders of the original PyTorch-GAN example are not used in this notebook.

# Legacy typed-tensor alias: `x.type(Tensor)` casts to float32 on the device selected by `cuda`.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

def sample_images(batches_done):
    """Save one generated sample from the validation set.

    Takes the first batch from a fresh iterator over ``val_dataloader``, runs
    ``generator`` on its source tensor (``imgs[0]``), and saves
    ``cat((real_A, fake_B, real_B), dim=-2)`` to ``'<batches_done>.pt'`` in the
    current working directory.

    Args:
        batches_done: Number of training batches processed so far; used only to
            name the output file.

    Note:
        The generator is not switched to eval mode, so any train-mode layers
        behave exactly as they do during training.
    """
    imgs = next(iter(val_dataloader))
    real_A = imgs[0].type(Tensor)
    real_B = imgs[1].type(Tensor)
    # Inference only: no_grad avoids building (and holding) an autograd graph for this pass.
    with torch.no_grad():
        fake_B = generator(real_A)
    img_sample = torch.cat((real_A, fake_B, real_B), -2)
    torch.save(img_sample, '%s.pt' % (batches_done))

# Training

In [8]:
def _save_results(records, path='results.pkl'):
    """Overwrite ``path`` with a pickle of the per-batch loss records collected so far."""
    with open(path, 'wb') as fp:
        pickle.dump(records, fp)


prev_time = time.time()
# One tuple per batch: (epoch, loss_D, loss_G, loss_pixel, loss_GAN).
results = []
try:
    for epoch in range(epoch, n_epochs):
        for i, batch in enumerate(dataloader):

            # Model inputs: batch[0] feeds the generator (A), batch[1] is the target (B).
            real_A = batch[0].type(Tensor)
            real_B = batch[1].type(Tensor)

            # Adversarial ground truths, sized per batch so a short final batch still matches.
            valid = torch.ones((real_A.size(0), *patch)).type(Tensor)
            fake = torch.zeros((real_A.size(0), *patch)).type(Tensor)

            # ------------------
            #  Train Generators
            # ------------------

            optimizer_G.zero_grad()

            # GAN loss: the generator is rewarded when the discriminator labels fake_B as valid.
            fake_B = generator(real_A)
            pred_fake = discriminator(fake_B, real_A)
            loss_GAN = criterion_GAN(pred_fake, valid)
            # Pixel-wise loss
            loss_pixel = criterion_pixelwise(fake_B, real_B)

            # Total loss
            loss_G = loss_GAN + lambda_pixel * loss_pixel

            loss_G.backward()

            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Also clears the gradients that loss_G.backward() left on the discriminator.
            optimizer_D.zero_grad()

            # Real loss
            pred_real = discriminator(real_B, real_A)
            loss_real = criterion_GAN(pred_real, valid)

            # Fake loss (fake_B is detached so no gradient flows back into the generator)
            pred_fake = discriminator(fake_B.detach(), real_A)
            loss_fake = criterion_GAN(pred_fake, fake)

            # Total loss
            loss_D = 0.5 * (loss_real + loss_fake)

            loss_D.backward()
            optimizer_D.step()

            # --------------
            #  Log Progress
            # --------------

            # Read each scalar once; .item() copies to the host (a device sync on GPU).
            d_val, g_val = loss_D.item(), loss_G.item()
            pixel_val, adv_val = loss_pixel.item(), loss_GAN.item()

            # Approximate time left, extrapolated from the duration of the last batch only.
            batches_done = epoch * len(dataloader) + i
            batches_left = n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s" %
                             (epoch, n_epochs,
                              i, len(dataloader),
                              d_val, g_val,
                              pixel_val, adv_val,
                              time_left))
            sys.stdout.flush()
            results.append((epoch, d_val, g_val, pixel_val, adv_val))

            # At each sample interval, save a validation sample and persist the loss log.
            # Persisting only periodically avoids re-pickling the whole, growing list on
            # every batch, which made logging cost quadratic in the number of batches.
            if batches_done % sample_interval == 0:
                sample_images(batches_done)
                _save_results(results)

        if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(generator.state_dict(), 'generator_%d.pth' % (epoch))
            torch.save(discriminator.state_dict(), 'discriminator_%d.pth' % (epoch))
finally:
    # Persist the complete loss log even when training is interrupted (e.g. KeyboardInterrupt).
    _save_results(results)



[Epoch 0/30] [Batch 5510/36736] [D loss: 0.005751] [G loss: 35.259941, pixel: 0.343074, adv: 0.952521] ETA: 6 days, 15:36:05.4777558

KeyboardInterrupt: 