In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"]="5"

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from torchvision import transforms
from torchvision.utils import save_image
import torch.nn.functional as F
from random import randint
import numpy as np
import glob
from PIL import Image
from tqdm import tqdm_notebook
import torch.backends.cudnn as cudnn
cudnn.benchmark = True

from torch.nn.utils import spectral_norm
from torch.nn.init import xavier_uniform_
from torch.autograd import grad as torch_grad

In [2]:
def init_weights(model, init_type='xavier', gain=0.02):
    '''
    Initialize the weights of `model` and all of its submodules in place.

    Every module whose class name contains 'Conv' or 'Linear' gets its weight
    re-initialized according to `init_type` and its bias (if any) set to 0.
    Every module whose class name contains 'BatchNorm2d' gets weight ~ N(1, gain)
    and bias 0, when those parameters exist.

    init_type: normal | xavier | kaiming | orthogonal
        Any other value leaves Conv/Linear weights untouched (biases are
        still zeroed).
    gain: std for 'normal' and for BatchNorm2d weights, gain for 'orthogonal'.
        It is NOT used by 'xavier' (xavier_uniform_ runs with its default gain)
        nor by 'kaiming'.

    Returns None.
    '''

    def init_func(m):
        classname = m.__class__.__name__
        is_conv_or_linear = 'Conv' in classname or 'Linear' in classname

        if hasattr(m, 'weight') and is_conv_or_linear:
            if init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                nn.init.xavier_uniform_(m.weight.data)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight.data, gain=gain)

            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)

        elif 'BatchNorm2d' in classname:
            # BatchNorm2d(affine=False) has weight/bias set to None; skip those
            # instead of failing on `None.data`.
            if getattr(m, 'weight', None) is not None:
                nn.init.normal_(m.weight.data, 1.0, gain)
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias.data, 0.0)

    # Module.apply visits every submodule (children first) and then `model` itself.
    model.apply(init_func)

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        # Keep dim 0 as-is and let view() infer the flattened size.
        leading = input.size(0)
        return input.view(leading, -1)

class PixelNorm(nn.Module):
    """
    Divide the input by the root-mean-square of its values along dim 1.

    Each slice along dim 1 is multiplied by rsqrt(mean(x^2, dim=1) + epsilon).
    The module has no learnable parameters.

    @notice: only out-of-place ops are used.
    https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        # Added under the rsqrt so an all-zero slice does not divide by zero.
        self.epsilon = epsilon

    def forward(self, x):
        mean_square = (x * x).mean(dim=1, keepdim=True)
        inv_rms = (mean_square + self.epsilon).rsqrt()
        return x * inv_rms
    
class encoder(nn.Module):
    """
    Fully convolutional autoencoder.

    `self.encoder` maps a 3-channel image to a 16-channel code map; its two
    stride-2 convolutions reduce each spatial side by a factor of 4 and the
    final Tanh bounds the code to [-1, 1].
    `self.decoder` maps a 16-channel code map back to a 3-channel image using
    two nearest-neighbour x2 upsampling stages; its output is also Tanh-bounded.
    """

    def __init__(self):
        super(encoder, self).__init__()
        self.conv_dim = 64

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 3, 2, 1),    # stride 2: H/2
            nn.ReLU(True),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 3, 2, 1),   # stride 2: H/4
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(256, 16, 3, 1, 1),
            nn.Tanh()
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(16, 256, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 128, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(128, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(64, 3, 3, 1, 1),
            nn.Tanh()
        )

        # init_weights() already walks every submodule through model.apply();
        # handing it to self.apply() would repeat that walk once per submodule
        # and re-initialize each layer several times for the same end result.
        init_weights(self)

    def forward(self, x):
        """Return `(feat, out)`: the 16-channel code and its reconstruction."""
        feat = self.encoder(x)
        out = self.decoder(feat)
        return feat, out

    def decode(self, x):
        """Run only the decoder on a 16-channel code map `x`."""
        return self.decoder(x)
    
class upblock(nn.Module):
    """
    x2 nearest-neighbour upsampling followed by two
    (3x3 conv -> PixelNorm -> ReLU) stages.

    The first conv maps `in_channels` to `out_channels`; the second keeps
    `out_channels`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = [nn.Upsample(scale_factor=2, mode='nearest')]
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, 3, 1, 1),
                PixelNorm(),
                nn.ReLU(True),
            ])
        self.conv1 = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv1(x)

class generator(nn.Module):
    """
    Map a latent vector to a 16-channel, Tanh-bounded code map.

    The latent vector is projected by a linear layer and reshaped to
    (batch, in_dim, 4, 4); each `upblock` then doubles the spatial size.

    num_z:  size of the latent vector.
    imsize: 64 -> two upblocks, 4x4 -> 16x16 output.
            any other value -> three upblocks, 4x4 -> 32x32 output.
    """

    def __init__(self, num_z, imsize=64):
        super(generator, self).__init__()
        self.conv_dim = 64

        # Channel multipliers (in units of conv_dim) before/after each upblock.
        if imsize == 64:
            widths = [4, 2, 1]        # [4,4] -> [8,8] -> [16,16]
        else:
            widths = [8, 4, 2, 1]     # [4,4] -> [8,8] -> [16,16] -> [32,32]

        self.in_dim = self.conv_dim * widths[0]
        self.proj = nn.Sequential(
            nn.Linear(num_z, self.in_dim * 4 * 4),
        )

        # Layers are appended in the same order as the former per-branch
        # definitions, so Sequential indices (and state_dict keys) are unchanged.
        layers = [PixelNorm(), nn.ReLU(True)]
        for w_in, w_out in zip(widths, widths[1:]):
            layers.append(upblock(self.conv_dim * w_in, self.conv_dim * w_out))
        layers.append(nn.Conv2d(self.conv_dim, 16, 3, 1, 1))
        layers.append(nn.Tanh())
        self.decoder = nn.Sequential(*layers)

        # init_weights() already recurses over all submodules; calling it once
        # on self avoids the repeated re-initialization of self.apply(init_weights).
        init_weights(self)

    def forward(self, x):
        """x: (batch, num_z) latent batch. Returns the generated code map."""
        x = self.proj(x).view(x.size(0), self.in_dim, 4, 4)
        x = self.decoder(x)
        return x

class code_discriminator(nn.Module):
    """
    Critic over 16-channel code maps; outputs one unbounded score per sample.

    imsize: 64 -> two stride-2 convs; the final Linear expects a 4x4 map,
                  i.e. a 16x16 input code.
            any other value -> three stride-2 convs; the final Linear expects
                  a 4x4 map, i.e. a 32x32 input code.
    """

    # initializers
    def __init__(self, imsize=64):
        super(code_discriminator, self).__init__()
        conv_dim = 64

        # Trunk shared by both configurations: two stride-2 stages (H -> H/4).
        layers = [
            nn.Conv2d(16, conv_dim, 3, 1, 1),                # H
            nn.ReLU(True),
            nn.Conv2d(conv_dim, conv_dim * 2, 3, 2, 1),      # H/2
            nn.ReLU(True),
            nn.Conv2d(conv_dim * 2, conv_dim * 2, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(conv_dim * 2, conv_dim * 4, 3, 2, 1),  # H/4
            nn.ReLU(True),
            nn.Conv2d(conv_dim * 4, conv_dim * 4, 3, 1, 1),
            nn.ReLU(True),
        ]
        final_channels = conv_dim * 4

        if imsize != 64:
            # One extra stride-2 stage for the larger code map (H/8).
            layers += [
                nn.Conv2d(conv_dim * 4, conv_dim * 8, 3, 2, 1),
                nn.ReLU(True),
                nn.Conv2d(conv_dim * 8, conv_dim * 8, 3, 1, 1),
                nn.ReLU(True),
            ]
            final_channels = conv_dim * 8

        # The classifier head assumes the feature map is 4x4 at this point.
        layers += [
            Flatten(),
            nn.Linear(final_channels * 4 * 4, 1),
        ]
        # Same layer order as before, so Sequential indices / state_dict keys
        # are unchanged.
        self.encoder = nn.Sequential(*layers)

        # init_weights() already recurses over all submodules; calling it once
        # on self avoids the repeated re-initialization of self.apply(init_weights).
        init_weights(self)

    # forward method
    def forward(self, x):
        """Return a (batch, 1) tensor of critic scores for code maps `x`."""
        out = self.encoder(x)
        return out
    
class im_discriminator(nn.Module):
    """
    Spectrally-normalized critic over 3-channel images; outputs one unbounded
    score per sample.

    d: base channel width. Four stride-2 convolutions reduce each spatial side
       by 16, and the first Linear layer expects a 4x4 map, i.e. a 64x64 input.
    """

    # initializers
    def __init__(self, d=64):
        super(im_discriminator, self).__init__()
        conv_dim = d
        self.encoder = nn.Sequential(
            spectral_norm(nn.Conv2d(3, conv_dim, 3, 2, 1)),                  # H/2
            nn.ReLU(True),
            spectral_norm(nn.Conv2d(conv_dim, conv_dim * 2, 3, 2, 1)),       # H/4
            nn.ReLU(True),
            spectral_norm(nn.Conv2d(conv_dim * 2, conv_dim * 4, 3, 2, 1)),   # H/8
            nn.ReLU(True),
            spectral_norm(nn.Conv2d(conv_dim * 4, conv_dim * 8, 3, 2, 1)),   # H/16
            nn.ReLU(True),
            spectral_norm(nn.Conv2d(conv_dim * 8, conv_dim * 8, 3, 1, 1)),
            nn.ReLU(True),
            Flatten(),
            spectral_norm(nn.Linear(conv_dim * 8 * 4 * 4, 128)),
            nn.ReLU(True),
            spectral_norm(nn.Linear(128, 1)),
        )

        # init_weights() already recurses over all submodules; calling it once
        # on self avoids the repeated re-initialization of self.apply(init_weights).
        init_weights(self)

    # forward method
    def forward(self, x):
        """Return a (batch, 1) tensor of critic scores for images `x`."""
        out = self.encoder(x)
        return out

In [3]:
class im_dataset(Dataset):
    """
    Dataset over the `*.jpg` files directly inside `real_dir`.

    Each item is resized so its shorter side is `imsize`, center-cropped to
    `imsize` x `imsize`, converted to a tensor and normalized with mean 0.5 /
    std 0.5 per channel (three channels).

    Images whose PIL mode is not 'RGB' are skipped: a uniformly random other
    index is returned in their place.
    """

    def __init__(self, real_dir, imsize):
        self.real_dir = real_dir
        self.imgpaths = self.get_imgpaths()

        self.preprocessing = transforms.Compose([
                transforms.Resize(imsize),
                transforms.CenterCrop(imsize),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

    def get_imgpaths(self):
        """Return the sorted list of `<real_dir>/*.jpg` paths (non-recursive)."""
        # The pattern has no '**', so glob's recursive flag would have no effect.
        return sorted(glob.glob('%s/*.jpg' % self.real_dir))

    def __getitem__(self, idx):
        truepath = self.imgpaths[idx]
        # The context manager closes the file handle once the image is processed.
        with Image.open(truepath) as im:
            # Check the mode before preprocessing so the 3-channel Normalize
            # only ever receives 3-channel input.
            if im.mode != 'RGB':
                return self.__getitem__(np.random.randint(0, self.__len__()))
            true_im = self.preprocessing(im)

        return true_im

    def __len__(self):
        return len(self.imgpaths)

def generate_recon(epoch, img):
    """
    Save the autoencoder's reconstruction of `img` next to the originals.

    Writes `recon_<epoch>.png` and `original_<epoch>.png` into `out_dir`.
    `encode_model` is switched to eval mode for the forward pass and put back
    into train mode afterwards.
    """
    encode_model.eval()
    with torch.no_grad():
        _, recon = encode_model(img)
        outputs = (
            (recon, '%s/recon_%d.png'),
            (img, '%s/original_%d.png'),
        )
        for batch, pattern in outputs:
            # to_img maps the [-1, 1] range back to [0, 1] before saving.
            save_image(to_img(batch.cpu().data), pattern % (out_dir, epoch))
    encode_model.train()
        
def generate_gan(epoch, fixed_z):
    """
    Save a grid of images decoded from generated codes.

    `fixed_z` is passed through `code_gen`, the resulting codes are decoded
    with `encode_model.decode`, and the images are written to
    `<out_dir>/local_<epoch>.png`.

    Both models are put in eval mode for the forward pass and are returned to
    whatever mode they were in on entry, even if saving fails.
    """
    # Remember the current modes so this helper has no lasting side effect on
    # either model (encode_model used to be left in eval mode).
    gen_was_training = code_gen.training
    enc_was_training = encode_model.training

    code_gen.eval()
    encode_model.eval()
    try:
        with torch.no_grad():
            code = code_gen(fixed_z)
            out = encode_model.decode(code)
            pic = to_img(out.cpu().data)
            save_image(pic, '%s/local_%d.png'%(out_dir, epoch))
    finally:
        code_gen.train(gen_was_training)
        encode_model.train(enc_was_training)

def adjust_learning_rate(optimizers, epoch, num_epochs):
    """
    Step-decay the learning rate of every param group of every optimizer.

    At epoch 15 each learning rate is divided by 2, at epoch 30 by 5; at any
    other epoch the rates are written back unchanged. `num_epochs` is accepted
    but not used.
    """
    if epoch == 15:
        divisor = 2
    elif epoch == 30:
        divisor = 5
    else:
        divisor = None

    for opt in optimizers:
        for group in opt.param_groups:
            lr = group['lr']
            group['lr'] = lr if divisor is None else lr / divisor
            
def to_img(x):
    """Map values from [-1, 1] to [0, 1] and clamp anything outside that range."""
    return (0.5 * (x + 1)).clamp(0, 1)


In [4]:
class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337

    Supported types:
      nsgan   - sigmoid + binary cross-entropy against the real/fake label
      lsgan   - mean squared error against the real/fake label
      hinge   - discriminator: mean(relu(1 - D(real))) / mean(relu(1 + D(fake)));
                generator: -mean(D(fake))
      wgan-gp - -mean(outputs) when `is_real` is true, +mean(outputs) otherwise
                (the gradient penalty itself is not computed here)
    """

    _SUPPORTED_TYPES = ('nsgan', 'lsgan', 'hinge', 'wgan-gp')

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        r"""
        type = nsgan | lsgan | hinge | wgan-gp
        target_real_label / target_fake_label: label values used by the
        nsgan and lsgan criteria.

        Raises ValueError for an unsupported `type`.
        """
        super(AdversarialLoss, self).__init__()

        # Fail at construction time rather than with an AttributeError on
        # `self.criterion` during the first loss evaluation.
        if type not in self._SUPPORTED_TYPES:
            raise ValueError('unsupported adversarial loss type: %r (expected one of %s)'
                             % (type, ', '.join(self._SUPPORTED_TYPES)))

        self.type = type
        # Buffers follow the module across .cuda() / .to() calls.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        if type == 'nsgan':
            self.criterion = nn.BCELoss()

        elif type == 'lsgan':
            self.criterion = nn.MSELoss()

        elif type == 'hinge':
            self.criterion = nn.ReLU()
        # 'wgan-gp' needs no criterion module.

    # Implemented as forward() (not __call__) so nn.Module's own __call__
    # machinery, including registered hooks, keeps working. Instances are
    # still invoked as `loss(outputs, is_real, is_disc)`.
    def forward(self, outputs, is_real, is_disc=None):
        r"""
        outputs: raw discriminator scores.
        is_real: whether `outputs` should be treated as real samples.
        is_disc: only consulted by the hinge loss; truthy selects the
                 discriminator form, falsy the generator form.

        Returns a scalar loss tensor.
        """
        if self.type == 'hinge':
            if is_disc:
                if is_real:
                    outputs = -outputs
                return self.criterion(1 + outputs).mean()
            else:
                return (-outputs).mean()
        elif self.type == 'wgan-gp':
            if is_real:
                outputs = -outputs
            return outputs.mean()
        else:
            labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
            if self.type == 'nsgan':
                # BCELoss expects probabilities, so squash the raw scores first.
                outputs = torch.sigmoid(outputs)
            loss = self.criterion(outputs, labels)
            return loss

        
def gradient_penalty(real_data, generated_data, disc, gp_weight=10):
    """
    WGAN-GP style gradient penalty for the critic `disc`.

    A random point is sampled on the segment between each real sample and the
    matching generated sample; the penalty is
    `gp_weight * mean((||d disc(x) / dx||_2 - 1)^2)` over those points.

    real_data, generated_data: tensors of identical shape, batch on dim 0,
        any number of trailing dimensions. Both are treated as constants
        (the interpolation is detached).
    disc: callable mapping such a tensor to per-sample scores.
    gp_weight: multiplier of the penalty (defaults to 10).

    Returns a scalar tensor that is differentiable w.r.t. `disc`'s parameters.
    Works on whatever device `real_data` lives on.
    """
    batch_size = real_data.size(0)

    # One mixing coefficient per sample, broadcast over all trailing dims.
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device)
    interpolated = (alpha * real_data + (1 - alpha) * generated_data).detach().requires_grad_()

    # Critic scores of the interpolated samples.
    prob_interpolated = disc(interpolated)

    # Gradient of the scores w.r.t. the interpolated inputs. create_graph=True
    # keeps the graph so the penalty itself can be back-propagated.
    gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                           grad_outputs=torch.ones_like(prob_interpolated),
                           create_graph=True, retain_graph=True)[0]

    # Flatten per sample; reshape (unlike view) also accepts non-contiguous grads.
    gradients = gradients.reshape(batch_size, -1)
    # The small constant keeps sqrt differentiable when the gradient is zero.
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    return gp_weight * ((gradients_norm - 1) ** 2).mean()



In [5]:
# Directory that receives sample images and checkpoints.
out_dir = './adv_training2'
# makedirs(exist_ok=True) creates missing parents too and does not fail when
# the directory already exists.
os.makedirs(out_dir, exist_ok=True)

batch_size = 256

# 64x64 images for the autoencoder stage.
dataset = im_dataset('./data/img_align_celeba', 64)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

num_epochs = 50
k = 128  # latent vector size fed to the code generator
encode_model = encoder().cuda().train()
# imsize=128 selects the three-upblock generator / deeper code critic.
code_gen = generator(k, 128).cuda().train()

im_disc = im_discriminator().cuda().train()
code_disc = code_discriminator(128).cuda().train()

# To resume from the checkpoint written by the autoencoder training cell:
# encode_model.load_state_dict(torch.load('%s/encode_model.pth' % out_dir))

adv_criterion = AdversarialLoss(type='wgan-gp').cuda()
l1_criterion = nn.L1Loss()

encoder_optimizer = torch.optim.Adam(encode_model.parameters(), lr=1e-4)
gen_optimizer = torch.optim.Adam(code_gen.parameters(), lr=1e-4, betas=(0, 0.9))

# Critics use a 4x larger learning rate than the networks they train against.
im_disc_optimizer = torch.optim.Adam(im_disc.parameters(), lr=4e-4, betas=(0, 0.9))
code_disc_optimizer = torch.optim.Adam(code_disc.parameters(), lr=4e-4, betas=(0, 0.9))

### Train encoder - decoder

In [6]:
# Stage 1: adversarially train the autoencoder (encode_model) against the
# image critic (im_disc). Each batch does one critic step followed by one
# autoencoder step. The epoch count is hard-coded to 20; `num_epochs` from the
# setup cell is not used here.
for epoch in range(20):
    for batch_idx, img in enumerate(dataloader):
        img = img.cuda()
        feat, out = encode_model(img)        
        # --- critic step ---
        # Reconstructions are detached so this step only updates im_disc.
        real_out = im_disc(img)
        fake_out = im_disc(out.detach())

        err_real = adv_criterion(real_out, True, True)
        err_fake = adv_criterion(fake_out, False, True)
        err_wgan = gradient_penalty(img, out.detach(), im_disc)
        
        gan_loss = err_real + err_fake + err_wgan
                
        im_disc_optimizer.zero_grad()
        gan_loss.backward()
        im_disc_optimizer.step()
        
        # --- autoencoder step ---
        # L1 reconstruction loss plus 0.1 x adversarial loss from the
        # just-updated critic. The backward pass also accumulates gradients in
        # im_disc's parameters; they are cleared by im_disc_optimizer.zero_grad()
        # at the start of the next critic step.
        l1 = l1_criterion(out, img)
        fake_out = im_disc(out)
        adv_loss = adv_criterion(fake_out, True, False)
        g_loss = l1 + 0.1*adv_loss
        
        encoder_optimizer.zero_grad()
        g_loss.backward()
        encoder_optimizer.step()

        # Every 100 batches: log the losses and save reconstructions of the
        # current batch (files are keyed by epoch, so later batches overwrite
        # earlier ones within the same epoch).
        if batch_idx % 100 == 0:
            print(epoch, gan_loss.item(), l1.item(), adv_loss.item())
            generate_recon(epoch, img)
            
    # `epoch % 1` is always 0, so checkpoints are overwritten after every epoch.
    if epoch % 1 == 0:
        torch.save(encode_model.state_dict(), '%s/encode_model.pth'%out_dir)
        torch.save(im_disc.state_dict(), '%s/im_disc.pth'%out_dir)

### Train the code generator on 128x128 images

In [None]:
# Stage 2: train the code generator (code_gen) against the code critic
# (code_disc). The same image folder is reloaded at 128x128; the names
# `dataset` and `dataloader` from the setup cell are rebound here.
dataset = im_dataset('./data/img_align_celeba', 128)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Fixed latent batch reused for every preview image.
fixed_z = torch.randn([batch_size, k]).cuda()

for epoch in range(50):
    for batch_idx, img in enumerate(dataloader):
        img = img.cuda()
        # One latent vector per image in the batch (the last batch may be smaller).
        z = torch.randn([img.size(0), k]).cuda()
        # The encoder only provides "real" codes; it runs without gradients
        # and no optimizer updates it in this cell.
        with torch.no_grad():
            real_feat, _ = encode_model(img)  
            
        fake_feat = code_gen(z)
        
        # --- critic step ---
        # Generated codes are detached so this step only updates code_disc.
        real_out = code_disc(real_feat)
        fake_out = code_disc(fake_feat.detach())

        err_real = adv_criterion(real_out, True, True)
        err_fake = adv_criterion(fake_out, False, True)
        err_wgan = gradient_penalty(real_feat, fake_feat.detach(), code_disc)

        gan_loss = err_real + err_fake + err_wgan
                
        code_disc_optimizer.zero_grad()
        gan_loss.backward()
        code_disc_optimizer.step()
        
        # --- generator step ---
        # Scores come from the just-updated critic. The backward pass also
        # accumulates gradients in code_disc's parameters; they are cleared by
        # code_disc_optimizer.zero_grad() at the start of the next critic step.
        fake_out = code_disc(fake_feat)
        adv_loss = adv_criterion(fake_out, True, False)
        
        gen_optimizer.zero_grad()
        adv_loss.backward()
        gen_optimizer.step()

        # Every 100 batches: log the losses and save images decoded from fixed_z
        # (files are keyed by epoch, so later batches overwrite earlier ones
        # within the same epoch).
        if batch_idx % 100 == 0:
            print(epoch, gan_loss.item(), adv_loss.item())
            generate_gan(epoch, fixed_z)
    # `epoch % 1` is always 0, so checkpoints are overwritten after every epoch.
    if epoch % 1 == 0:
        torch.save(code_gen.state_dict(), '%s/code_gen.pth'%out_dir)
        torch.save(code_disc.state_dict(), '%s/code_disc.pth'%out_dir)