In [2]:
import sys
sys.path.insert(0, '..')

import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.nn.functional import softplus
from torch.autograd import grad
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid, save_image
from torchvision import transforms

import numpy as np
import random
from math import log2, ceil
import pandas as pd

from PIL import Image
import matplotlib.pyplot as plt

from datetime import datetime
from time import time

import os
import warnings
warnings.filterwarnings('ignore')

import torch
from torch import nn
from model import Model
import matplotlib.pyplot as plt
import numpy as np

import torch
from torch import nn
from torchvision import models
from tqdm import tqdm_notebook as tqdm

import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from PIL import Image

# Torch device used for the ALAE model, the VGG loss network and every training batch.
# Defaults to the third CUDA GPU ('cuda:2', as originally hard-coded); set the
# FINETUNE_DEVICE environment variable (e.g. 'cuda:0' or 'cpu') to run on another
# machine without editing the notebook.
DEVICE = os.environ.get('FINETUNE_DEVICE', 'cuda:2')

# Prepare `data`

In [2]:
from dataset import SingleFaceDataset, BatchCollate
from albumentations import HorizontalFlip, Compose, KeypointParams
from albumentations.pytorch import ToTensor

# Main pipeline: horizontal-flip augmentation followed by tensor conversion.
# Keypoints accompany the images in 'xy' format.
transform = Compose(
    [HorizontalFlip(), ToTensor()],
    keypoint_params=KeypointParams(format='xy'),
)
# Second pipeline handed to the dataset as `transform_ci`: tensor conversion only,
# no augmentation. NOTE(review): presumably applied to the "center identity"
# samples — confirm in dataset.py.
transform_ci = Compose(
    [ToTensor()],
    keypoint_params=KeypointParams(format='xy'),
)

dataset = SingleFaceDataset(
    root='/root/img_dataset/id00061/',
    size=3200,
    center_identity_size=4,
    center_identity_step=32,
    transform=transform,
    transform_ci=transform_ci,
)

# Shuffled batches of 32; the trailing incomplete batch is dropped so every
# training step sees a full batch.
batch_collate = BatchCollate()
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True,
    collate_fn=batch_collate,
)

# Load `pretrained model`

In [15]:
def load_checkpoint(model, checkpoint_filename):
    """Restore the ALAE sub-modules of `model` from a training checkpoint.

    The checkpoint is read with every storage kept on the CPU; callers move the
    model to the target device afterwards. The sub-module state dicts live under
    ``checkpoint['models']`` with the checkpoint's own key names, which map onto
    the model attributes as listed in ``targets`` below.

    Args:
        model: object exposing ``encoder``, ``decoder``, ``mapping_tl``,
            ``mapping_fl`` and ``dlatent_avg`` modules.
        checkpoint_filename: path to the saved checkpoint file.

    Returns:
        None; the modules are updated in place via ``load_state_dict``.
    """
    checkpoint = torch.load(checkpoint_filename, map_location=lambda storage, loc: storage)

    # (checkpoint key, destination module) in the order they are restored.
    targets = (
        ('discriminator_s', model.encoder),
        ('generator_s', model.decoder),
        ('mapping_tl_s', model.mapping_tl),
        ('mapping_fl_s', model.mapping_fl),
        ('dlatent_avg', model.dlatent_avg),
    )

    saved_states = checkpoint['models']
    for checkpoint_key, module in targets:
        module.load_state_dict(saved_states[checkpoint_key])
        
    
def save_reconstructions(name, original, reconstruction, nrows=6):
    """Save a grid image that pairs originals with their reconstructions.

    Args:
        name: output image path, passed straight to ``save_image``.
        original: list of image batches, e.g. ``[x, x_hat]``.
        reconstruction: list of batches aligned with ``original``,
            e.g. ``[G(E(x)), G(E(x_hat))]``.
        nrows: images per grid row (``nrow`` of ``save_image``). The first
            ``nrows // 2`` samples of every batch are used and laid out as
            ``orig_0, rec_0, orig_1, rec_1, ...``, so with an even ``nrows`` each
            (original, reconstruction) batch pair fills exactly one grid row.

    Pixel values are normalised from (-1, 1) for display.
    NOTE(review): newer torchvision versions rename ``save_image``'s ``range``
    argument to ``value_range`` — confirm against the installed version.
    """
    pairs_per_row = nrows // 2

    grid_rows = []
    for orig_batch, rec_batch in zip(original, reconstruction):
        orig_items = orig_batch.split(1)
        rec_items = rec_batch.split(1)
        interleaved = []
        for idx in range(pairs_per_row):
            interleaved.extend((orig_items[idx], rec_items[idx]))
        grid_rows.append(torch.cat(interleaved, dim=0))

    save_image(torch.cat(grid_rows, dim=0), name, nrow=nrows, padding=1, normalize=True, range=(-1, 1))
    
# Build the ALAE model. NOTE(review): these hyper-parameters presumably have to
# match the architecture stored in the checkpoint below, otherwise
# load_state_dict will reject it.
alae = Model(
    latent_size=256,
    layer_count=6,
    maxf=256,
    startf=64,
    mapping_layers=8,
    dlatent_avg_beta=0.995,
    style_mixing_prob=0,
    generator='GeneratorDefault',
    encoder='EncoderDefault',
)

# Restore pretrained weights (loaded on CPU), then move the whole model to DEVICE.
load_checkpoint(alae, 'training_artifacts/celeba/model_final.pth')
alae = alae.to(DEVICE)

# VGG `loss`

In [16]:
def tanh2sigmoid(batch):
    """Rescale values from the tanh range [-1, 1] to the sigmoid range [0, 1].

    Computes ``batch / 2 + 0.5`` element-wise and returns a new tensor.
    """
    halved = batch / 2
    return halved + 0.5


class VGGExtractor(nn.Module):
    """Split a VGG ``features`` stack into five consecutive stages and expose their outputs.

    Inputs are normalised with the channel statistics stored in the ``mean`` and
    ``std`` buffers before the first stage. NOTE(review): these are the usual
    ImageNet statistics, so ``x`` is presumably RGB in [0, 1] — confirm at call sites.

    NOTE(review): in torchvision's VGG19 the first ReLU is at index 1, so the stage
    names ``relu0_1`` ... ``relu4_1`` appear shifted by one relative to the common
    relu1_1 ... relu5_1 naming — verify before relying on the names.
    """

    def __init__(self, vgg):
        super(VGGExtractor, self).__init__()

        # Per-channel statistics shaped (1, 3, 1, 1) so they broadcast over NCHW input.
        channel_mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1)
        channel_std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1)
        self.register_buffer('mean', channel_mean)
        self.register_buffer('std', channel_std)

        # Consecutive, non-overlapping (end-exclusive) slices of the feature stack.
        self.relu0_1 = vgg[0:2]
        self.relu1_1 = vgg[2:7]
        self.relu2_1 = vgg[7:12]
        self.relu3_1 = vgg[12:21]
        self.relu4_1 = vgg[21:30]

    def forward(self, x, level=1):
        """Return a list with the outputs of the first ``level`` stages (at most five)."""
        stages = (self.relu0_1, self.relu1_1, self.relu2_1, self.relu3_1, self.relu4_1)

        hidden = (x - self.mean) / self.std
        features = []
        for stage in stages[:level]:
            hidden = stage(hidden)
            features.append(hidden)
        return features
    
class VGGLoss(nn.Module):
    """Weighted perceptual loss between two image batches using pretrained VGG19 features.

    Both inputs go through a frozen ``VGGExtractor``. For each of the first
    ``level`` feature maps the mean squared difference is computed, scaled by the
    matching weight in ``coefs``, and summed.

    Args:
        coefs: per-level weights; the defaults ``(1, 1/4, 1/8)`` match the
            original hard-coded values, so ``level`` may be at most ``len(coefs)``.
    """

    def __init__(self, coefs=(1, 1/4, 1/8)):
        super(VGGLoss, self).__init__()
        self.coefs = tuple(coefs)
        self.feature_exctracor = VGGExtractor(models.vgg19(pretrained=True).features).to(DEVICE)
        self.feature_exctracor.eval()
        # VGG is a fixed feature extractor and is never optimised. Without this,
        # every backward() computes (and keeps accumulating) gradients for all of
        # its weights. Gradients still flow through the network to the inputs.
        self.feature_exctracor.requires_grad_(False)

    def forward(self, inp1, inp2, level):
        """Return the weighted sum of per-level feature MSEs between ``inp1`` and ``inp2``.

        Raises:
            ValueError: if ``level`` is not in ``1..len(coefs)`` or exceeds the
                number of feature maps the extractor provides.
        """
        if not 1 <= level <= len(self.coefs):
            raise ValueError(f'level must be between 1 and {len(self.coefs)}, got {level}')

        features1 = self.feature_exctracor(inp1, level)
        features2 = self.feature_exctracor(inp2, level)
        if len(features1) < level:
            raise ValueError(f'level={level} exceeds the {len(features1)} feature maps available')

        loss = 0
        for coef, feat1, feat2 in zip(self.coefs[:level], features1, features2):
            loss += coef * (feat1 - feat2).pow(2).mean()
        return loss

vgg_loss = VGGLoss()

# Training loop

In [18]:
# os.mkdir('Finetune_id00061_correct')

In [17]:
def autoencode(model, batch):
    """Encode ``batch`` with ``model`` and decode it back (a reconstruction pass).

    The batch is moved to DEVICE, encoded at lod=5 / blend=1, and the first element
    returned by ``model.encode`` is repeated 12 times along dim 1 before decoding
    with ``noise='batch_constant'``.
    NOTE(review): lod=5 and the 12 copies presumably correspond to layer_count=6
    (last level of detail, two style inputs per layer) — confirm against Model.
    """
    latents = model.encode(batch.to(DEVICE), 5, 1)[0]
    per_layer_codes = latents.repeat(1, 12, 1)
    return model.decoder(per_layer_codes, lod=5, blend=1, noise='batch_constant')

In [None]:
# cuDNN is disabled globally for the rest of the kernel session (affects later cells too).
# NOTE(review): the reason is not visible here — presumably a workaround for a
# cuDNN problem with the decoder; confirm before re-enabling.
torch.backends.cudnn.enabled = False

alae.train()

experiment_name = 'Finetune_id00061_correct'

# Create the output directory. Unlike a bare `except: pass`, this tolerates only
# "already exists" and still surfaces real errors such as permission problems.
os.makedirs(experiment_name, exist_ok=True)

SAVE_IMAGES_EACH = 98  # batches between image dumps, stats writes and best-model checks
EPOCHS = 1000
CHECKPOINT_EACH = 20   # epochs between periodic checkpoints
BEST_WINDOW = 100      # number of most recent batch losses averaged for best-model tracking
reconstruction_loss = []  # per-batch VGG loss history, written to stats.csv

BS = 32      # NOTE(review): unused below; the dataloader's batch_size=32 is what applies
scale = 128  # NOTE(review): unused below

# Only the decoder is fine-tuned; encoder and mapping networks are not in the optimiser.
optimizer_G = Adam(alae.decoder.parameters(), betas=(0, 0.99), lr=1e-4)
best_loss = float('inf')  # lowered every time a better model is saved

for epoch in range(EPOCHS):
    for batch_idx, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc=f'Epoch: [{epoch+1}/{EPOCHS}]'):

        # Map images to [-1, 1] (assumes the dataset yields values in [0, 1] — TODO confirm).
        real_samples = batch['image'].to(DEVICE).add(-0.5).mul(2)
        # NOTE(review): clamp range is slightly wider than [-1, 1]; rationale not visible here.
        x_reconstruction = autoencode(alae, real_samples).clamp(-1.2, 1.2)

        # The VGG extractor normalises [0, 1] inputs, so map both batches back first.
        loss = vgg_loss(tanh2sigmoid(real_samples), tanh2sigmoid(x_reconstruction), 3)
        reconstruction_loss.append(loss.item())

        # Zero gradients on the whole model, not just the decoder: backward() also
        # populates encoder/mapping .grad, which would otherwise accumulate forever.
        # The decoder (the only stepped part) is affected exactly as before.
        alae.zero_grad()
        loss.backward()
        optimizer_G.step()

        if (batch_idx + 1) % SAVE_IMAGES_EACH == 0:
            name = f'Epoch{epoch}_Img{batch_idx}.png'
            image_path = os.path.join(experiment_name, name)

            real_cpu = real_samples.detach().cpu()
            rec_cpu = x_reconstruction.detach().cpu()
            save_reconstructions(image_path,
                                 [real_cpu, real_cpu[-16:]],
                                 [rec_cpu, rec_cpu[-16:]], nrows=6)

            # Load and close the saved grid before overwriting it with the resized version.
            with Image.open(image_path) as grid:
                resized_grid = grid.resize((1536, 512))
            resized_grid.save(image_path)

            # Save losses
            pd.DataFrame(reconstruction_loss).to_csv(os.path.join(experiment_name, 'stats.csv'), index=False)

            # Track the best model by the mean of the most recent batch losses, which
            # reflects the current weights. best_loss must be updated, otherwise
            # BestModel.pt would be overwritten at every check.
            recent_losses = reconstruction_loss[-BEST_WINDOW:]
            recent_loss = sum(recent_losses) / len(recent_losses)
            if recent_loss < best_loss:
                best_loss = recent_loss
                torch.save({'AE': alae.state_dict(),
                            'optAE': optimizer_G.state_dict()},
                           os.path.join(experiment_name, 'BestModel.pt'))

    # Periodic checkpoint; also save after the final epoch so the last weights are kept.
    if epoch % CHECKPOINT_EACH == 0 or epoch == EPOCHS - 1:
        torch.save({'AE': alae.state_dict(),
                    'optAE': optimizer_G.state_dict()},
                   os.path.join(experiment_name, f'Finetuned_Epoch{epoch}.pt'))