In [6]:
import os

os.environ["CUDA_VISIBLE_DEVICES"]="4"

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from torchvision import transforms
from torchvision.utils import save_image
import torch.nn.functional as F
from random import randint
import numpy as np
import glob
from PIL import Image
from tqdm import tqdm_notebook
import torch.backends.cudnn as cudnn
cudnn.benchmark = True

from torch.nn.utils import spectral_norm
from torch.nn.init import xavier_uniform_
from torch.autograd import grad as torch_grad
import math

In [7]:
def init_weights(model, init_type='xavier', gain=0.02):
   '''
   initialize network's weights
   init_type: normal | xavier | kaiming | orthogonal
   '''

   def init_func(m):
       classname = m.__class__.__name__
       if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
           if init_type == 'normal':
               nn.init.normal_(m.weight.data, 0.0, gain)
           elif init_type == 'xavier':
#                nn.init.xavier_normal_(m.weight.data, gain=gain)
                nn.init.xavier_uniform_(m.weight.data)
           elif init_type == 'kaiming':
               nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
           elif init_type == 'orthogonal':
               nn.init.orthogonal_(m.weight.data, gain=gain)

           if hasattr(m, 'bias') and m.bias is not None:
               nn.init.constant_(m.bias.data, 0.0)

       elif classname.find('BatchNorm2d') != -1:
           nn.init.normal_(m.weight.data, 1.0, gain)
           nn.init.constant_(m.bias.data, 0.0)

   model.apply(init_func)

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class PixelNorm(nn.Module):
    def __init__(self, epsilon=1e-8):
        """
            @notice: avoid in-place ops.
            https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
        """
        super(PixelNorm, self).__init__()
        self.epsilon = epsilon

    def forward(self, x):
        tmp  = torch.mul(x, x) # or x ** 2
        tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon)

        return x * tmp1
    
class encoder(nn.Module):
    def __init__(self):
        super(encoder, self).__init__()
        self.conv_dim = 64
        
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU(True),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(256, 16, 3, 1, 1),
            nn.Tanh()
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(16, 256, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(256, 128, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(128, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(64, 3, 3, 1, 1),
            nn.Tanh()
        )

        self.apply(init_weights)

    def forward(self, x):
        feat = self.encoder(x)
        out = self.decoder(feat) 
        return feat, out
    
    def decode(self, x):
        return self.decoder(x)
    
class upblock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(upblock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            spectral_norm(nn.Conv2d(in_channels, out_channels, 3, 1, 1)),
#             PixelNorm(),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            spectral_norm(nn.Conv2d(out_channels, out_channels, 3, 1, 1)),
#             PixelNorm(),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
        )
        
    def forward(self, x):
        return self.conv1(x)

class generator(nn.Module):
    def __init__(self, num_z, imsize=64):
        super(generator, self).__init__()
        conv_dim = 64
        
        num_up = int(math.log(imsize, 2)) - 4 #start from 4x4
        in_dim = conv_dim*2**(num_up-1)
        out_dim = in_dim
        
        self.proj = spectral_norm(nn.Linear(num_z, in_dim * 4 * 4))
        decoder = [nn.BatchNorm2d(in_dim), nn.ReLU(True)]
        for i in range(num_up):
            decoder.append(upblock(in_dim, out_dim))
            in_dim = out_dim
            out_dim = out_dim // 2
            
        decoder += [spectral_norm(nn.Conv2d(in_dim, 16, 3, 1, 1)), nn.Tanh()]
        self.decoder = nn.Sequential(*decoder)
        self.apply(init_weights)

    def forward(self, x):
        out = self.proj(x).view(x.size(0),-1,4,4)
        out = self.decoder(out) 
        return out

class code_discriminator(nn.Module):
    # initializers
    def __init__(self, imsize=64):
        super(code_discriminator, self).__init__()
        num_down = int(math.log(imsize, 2)) - 4 #start from 4x4
        in_dim = 16
        out_dim = 64
        
        encoder = []
        
        for i in range(num_down):
            encoder += [spectral_norm(nn.Conv2d(in_dim, out_dim, 3, 1, 1)),
                            nn.BatchNorm2d(out_dim),
                            nn.LeakyReLU(0.2, True),
                            spectral_norm(nn.Conv2d(out_dim, out_dim*2, 3, 2, 1)),
                            nn.BatchNorm2d(out_dim*2),
                            nn.LeakyReLU(0.2, True),
                           ]
            in_dim = out_dim*2
            out_dim *= 2
            
        encoder += [spectral_norm(nn.Conv2d(in_dim, out_dim, 3, 1, 1)),
                        nn.BatchNorm2d(out_dim),
                        nn.LeakyReLU(0.2, True),
                        Flatten(), 
                        spectral_norm(nn.Linear(out_dim*4*4, 1)),
                       ]

        self.encoder = nn.Sequential(*encoder)
        self.apply(init_weights)
    
    def forward(self, x):
        out = self.encoder(x)
        return out
    
class im_discriminator(nn.Module):
    # initializers
    def __init__(self, d=64):
        super(im_discriminator, self).__init__()
        conv_dim = d        
        self.encoder = nn.Sequential(
            nn.Conv2d(3, conv_dim, 3, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(conv_dim, conv_dim * 2, 3, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(conv_dim * 2, conv_dim* 4, 3, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(conv_dim * 4, conv_dim * 8, 3, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(conv_dim * 8, conv_dim * 8, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            Flatten(), 
            nn.Linear(conv_dim*8*4*4, 1),
        )

        self.apply(init_weights)
    
    # forward method
    def forward(self, x):
        out = self.encoder(x)
        return out

In [8]:
class im_dataset(Dataset):
    def __init__(self, real_dir, imsize):
        self.real_dir = real_dir
        self.imgpaths = self.get_imgpaths()

        self.preprocessing = transforms.Compose([
                transforms.Resize(imsize),
                transforms.CenterCrop(imsize),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

    def get_imgpaths(self):
        real_paths = sorted(glob.glob('%s/*.jpg'%self.real_dir, recursive=True))
        return real_paths
    

    def __getitem__(self, idx):
        truepath = self.imgpaths[idx]
        true_im = self.preprocessing(Image.open(truepath))
        if true_im.size(0) == 1:
            return self.__getitem__(np.random.randint(0, self.__len__()))
        
        return true_im

    def __len__(self):
        return len(self.imgpaths)

class code_dataset(Dataset):
    def __init__(self, real_dir):
        self.real_dir = real_dir
        self.imgpaths = self.get_imgpaths()

    def get_imgpaths(self):
        real_paths = sorted(glob.glob('%s/*.npz'%self.real_dir))
        return real_paths
    
    def __getitem__(self, idx):
        path = self.imgpaths[idx]
        code = np.load(path)
        
        return torch.FloatTensor(code)

    def __len__(self):
        return len(self.imgpaths)

def generate_recon(epoch, img):
    encode_model.eval()
    with torch.no_grad():
        feat, out = encode_model(img)
        pic = to_img(out.cpu().data)
        save_image(pic, '%s/recon_%d.png'%(out_dir, epoch))
        pic = to_img(img.cpu().data)
        save_image(pic, '%s/original_%d.png'%(out_dir, epoch))
    encode_model.train()
        
def generate_gan(epoch, img, fixed_z):
#     code_gen.eval()
    encode_model.eval()
    with torch.no_grad():
        code = code_gen(fixed_z)
        out = encode_model.decode(code)
        pic = to_img(out.cpu().data)[:30]
        save_image(pic, '%s/local_%d.png'%(out_dir, epoch))
        
        _, out = encode_model(img)
        pic = to_img(out.cpu().data)[:30]
        save_image(pic, '%s/large_recon_%d.png'%(out_dir, epoch))

    code_gen.train()

def adjust_learning_rate(optimizers, epoch):        
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            current_lr = param_group['lr']
            if epoch == 10:
                new_lr = current_lr / 10
            else:
                new_lr = current_lr
            param_group['lr'] = new_lr
            
def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    return x


In [9]:
class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        r"""
        type = nsgan | lsgan | hinge
        """
        super(AdversarialLoss, self).__init__()

        self.type = type
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        if type == 'nsgan':
            self.criterion = nn.BCELoss()

        elif type == 'lsgan':
            self.criterion = nn.MSELoss()

        elif type == 'hinge':
            self.criterion = nn.ReLU()

    def __call__(self, outputs, is_real, is_disc=None):
        if self.type == 'hinge':
            if is_disc:
                if is_real:
                    outputs = -outputs
                return self.criterion(1 + outputs).mean()
            else:
                return (-outputs).mean()
        elif self.type == 'wgan-gp':
            if is_real:
                outputs = -outputs
            return outputs.mean()
        else:
            labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
            if self.type == 'nsgan':
                outputs = torch.sigmoid(outputs)
            loss = self.criterion(outputs, labels)
            return loss

        
def gradient_penalty(real_data, generated_data, disc):
    batch_size = real_data.size(0)

    # Calculate interpolation
    alpha = torch.rand(batch_size, 1, 1, 1).cuda()
    alpha = alpha.expand_as(real_data)
    interpolated = (alpha * real_data + (1 - alpha) * generated_data).detach().requires_grad_()

    # Calculate probability of interpolated examples
    prob_interpolated = disc(interpolated)

    # Calculate gradients of probabilities with respect to examples
    gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                           grad_outputs=torch.ones(prob_interpolated.size()).cuda(),
                           create_graph=True, retain_graph=True)[0]

    # Gradients have shape (batch_size, num_channels, img_width, img_height),
    # so flatten to easily take norm per example in batch
    gradients = gradients.view(batch_size, -1)
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    # Return gradient penalty
    return 10 * ((gradients_norm - 1) ** 2).mean()


In [10]:
out_dir = './adv_training2'
code_dir = os.path.join(out_dir, 'codes')
os.makedirs(out_dir, exist_ok=True)    
os.makedirs(code_dir, exist_ok=True)    


batch_size = 128
dataset = im_dataset('./data/img_align_celeba', 64)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

num_epochs = 100
k = 256
output_size = 256

encode_model = encoder().cuda().train()
code_gen = generator(k, output_size).cuda().train()

im_disc = im_discriminator().cuda().train()
code_disc = code_discriminator(output_size).cuda().train()

encode_model.load_state_dict(torch.load('./adv_training2/encoder.pth'))
# code_gen.load_state_dict(torch.load('./adv_training2/code_gen.pth'))
# code_disc.load_state_dict(torch.load('./adv_training2/code_disc.pth'))

loss = 'hinge'
adv_criterion = AdversarialLoss(type=loss).cuda()
l1_criterion = nn.L1Loss()

encoder_optimizer = torch.optim.Adam(encode_model.parameters(), lr=1e-4, betas=(0.5, 0.999))
im_disc_optimizer = torch.optim.Adam(im_disc.parameters(), lr=4e-4, betas=(0.5, 0.999))

gen_optimizer = torch.optim.Adam(code_gen.parameters(), lr=1e-4, betas=(.5, 0.999))
code_disc_optimizer = torch.optim.Adam(code_disc.parameters(), lr=4e-4, betas=(.5, 0.999))

### Train encoder - decoder

In [11]:
# for epoch in range(15):
#     for batch_idx, img in enumerate(dataloader):
#         adjust_learning_rate([encoder_optimizer], epoch)      
#         img = img.cuda()
#         feat, out = encode_model(img)        
#         real_out = im_disc(img)
#         fake_out = im_disc(out.detach())

#         err_real = adv_criterion(real_out, True, True)
#         err_fake = adv_criterion(fake_out, False, True)
#         gan_loss = err_real + err_fake

#         if loss == 'wgan-gp':
#             err_wgan = gradient_penalty(img, out.detach(), im_disc)
#             gan_loss += err_wgan
        
                
#         im_disc_optimizer.zero_grad()
#         gan_loss.backward()
#         im_disc_optimizer.step()
        
#         # generator
#         l1 = l1_criterion(out, img)
#         fake_out = im_disc(out)
#         adv_loss = adv_criterion(fake_out, True, False)
#         g_loss = l1 + 0.01*adv_loss
        
#         encoder_optimizer.zero_grad()
#         g_loss.backward()
#         encoder_optimizer.step()

#         if batch_idx % 100 == 0:
#             print(epoch, gan_loss.item(), l1.item(), adv_loss.item())
#             generate_recon(epoch, img)
            
#     if epoch % 1 == 0:
#         torch.save(encode_model.state_dict(), '%s/encode_model.pth'%out_dir)
#         torch.save(im_disc.state_dict(), '%s/im_disc.pth'%out_dir)

Convert high resolution images to code and save it into a folder for quicker training. How to compress np arrays efficiently? Takes too much disk space

In [12]:
# dataset = im_dataset('./data/img_align_celeba', output_size)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# count = 0
# for batch_idx, img in enumerate(dataloader):
#     img = img.cuda()
#     with torch.no_grad():
#         real_feat, _ = encode_model(img) 
#         real_feat = real_feat.cpu().numpy()
#         for i in range(real_feat.shape[0]):
#             np.savez_compressed(f'{code_dir}/{count}.npz', real_feat[i])
#             count += 1

### Train 128 generator

In [None]:
# dataset = code_dataset(code_dir)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

dataset = im_dataset('./data/img_align_celeba', output_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

encode_model.eval()
fixed_z = torch.randn([batch_size, k]).cuda()

for epoch in range(num_epochs):
    for batch_idx, img in enumerate(dataloader):
        img = img.cuda()
        z = torch.randn([img.size(0), k]).cuda()
        with torch.no_grad():
            real_feat, _ = encode_model(img)  
            
        fake_feat = code_gen(z)
        
        real_out = code_disc(real_feat)
        fake_out = code_disc(fake_feat.detach())

        err_real = adv_criterion(real_out, True, True)
        err_fake = adv_criterion(fake_out, False, True)
        gan_loss = err_real + err_fake

        if loss == 'wgan-gp':
            err_wgan = gradient_penalty(real_feat, fake_feat.detach(), code_disc)
            gan_loss += err_wgan

                
        code_disc_optimizer.zero_grad()
        gan_loss.backward()
        code_disc_optimizer.step()
        
        # generator
        fake_out = code_disc(fake_feat)
        adv_loss = adv_criterion(fake_out, True, False)
        
        gen_optimizer.zero_grad()
        adv_loss.backward()
        gen_optimizer.step()

        if batch_idx % 50 == 0:
            print(epoch, err_real.item(), err_fake.item(), adv_loss.item())
            generate_gan(epoch, img, fixed_z)
            
    if epoch % 1 == 0:
        torch.save(code_gen.state_dict(), '%s/code_gen.pth'%out_dir)
        torch.save(code_disc.state_dict(), '%s/code_disc.pth'%out_dir)

  "See the documentation of nn.Upsample for details.".format(mode))


0 0.5964428186416626 1.4191110134124756 16.387798309326172
0 0.009365899488329887 0.0 14.7371244430542
0 0.051138438284397125 0.0 4.9860334396362305
0 0.014964363537728786 0.0 3.664454221725464
0 0.0 0.1002039760351181 5.653564929962158
0 0.0 0.0007618269883096218 2.6636598110198975
0 0.0 0.5470526814460754 6.94040584564209
0 0.0 0.2772418260574341 6.751776695251465
0 0.005020281299948692 0.0 3.330946207046509
0 0.0 0.0 3.717128276824951
0 0.0 0.0 3.2507944107055664
0 0.0 0.0 3.474308490753174
0 0.0 0.28474700450897217 5.033273696899414
0 0.0 0.0 3.643026113510132
0 0.0 0.0 3.5737709999084473
0 0.0 0.0 1.8947405815124512
0 0.0 0.0 2.9227938652038574
0 0.008621095679700375 0.18504652380943298 3.679661273956299
0 0.0 0.0 2.882267951965332
0 0.0 0.0478169322013855 4.573443412780762
0 0.0 0.0 3.267589569091797
0 0.0 0.0 4.360723495483398
0 1.1747798919677734 0.0 6.892269134521484
0 0.035564251244068146 5.245824813842773 4.853349685668945
0 0.02089126594364643 0.0 3.160186767578125
0 0.0 0.

In [None]:
fake_feat.shape, real_feat.shape