In [1]:
import torch
from torch.nn import DataParallel
from pathlib import Path
from PIL import Image
import torchvision
import argparse
import numpy as np
import time
import torch
from functools import partial
import tqdm
import matplotlib.pyplot as plt

from src.model import EmbeddingModel
from src.dataset import get_dataloader
from src.loss import LossBuilder
from src.stylegan import G_synthesis, G_mapping
from src.spherical_optimizer import SphericalOptimizer
from src.utils import open_url

In [2]:
# --- Paths ---
input_dir = 'data/input_aligned'    # images fed to get_dataloader
output_dir = 'data/output'          # created (with parents) by the setup cell
cache_dir = 'cache'                 # download cache for network weights

# --- Data / reproducibility ---
batch_size = 1
seed = None                         # None -> no manual seeding

# --- Loss ---
loss_str = '100*L2'                 # spec string handed to LossBuilder

# --- Latent / noise parameterization ---
noise_type = 'trainable'            # one of 'zero', 'fixed', 'trainable'
num_trainable_noise_layers = 18     # with 'trainable': layers with index below this are optimized
tile_latent = False                 # True: one 512-d latent shared across all 18 layers
bad_noise_layers = '17'             # '.'-separated layer indices whose noise is forced to zero

# --- Optimization ---
opt_name = 'adam'                   # key into the optimizer table: 'sgd', 'adam', 'sgdm', 'adamax'
learning_rate = 0.4
steps = 500
lr_schedule = 'linear1cycledrop'    # one of 'fixed', 'linear1cycle', 'linear1cycledrop'
save_intermediate = False           # NOTE(review): not read by any cell shown here

In [3]:
# Output location: create it up front, including any missing parents;
# an already-existing directory is fine.
out_path = Path(output_dir)
out_path.mkdir(exist_ok=True, parents=True)

# Input pipeline over the images in input_dir, batch_size images per batch.
dataloader = get_dataloader(
    input_dir,
    batch_size,
)

In [4]:
# Normalize cache_dir to a Path and make sure it exists before anything is downloaded into it.
cache_dir = Path(cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)

# Load StyleGAN: build the synthesis network on the GPU, then load pretrained weights into it.
synthesis = G_synthesis().cuda()

print("Loading Synthesis Network")
# NOTE(review): torch.load unpickles the downloaded file, which can execute arbitrary code;
# only point this at a trusted source.
with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir) as f:
    synthesis.load_state_dict(torch.load(f))

# Freeze the generator: the optimization cell updates only the latent and noise tensors.
for param in synthesis.parameters():
    param.requires_grad = False

# Nonlinearity applied after the gaussian-fit affine map in the optimization loop.
lrelu = torch.nn.LeakyReLU(negative_slope=0.2)

# Load mean + std of mapping network. Resolve the file inside cache_dir so that a
# non-default cache_dir is honoured (it used to be hard-coded to "cache/").
gaussian_fit_path = cache_dir / "gaussian_fit.pt"
if not gaussian_fit_path.exists():
    raise FileNotFoundError(
        f"{gaussian_fit_path} not found; it must be precomputed before running this notebook"
    )
gaussian_fit = torch.load(gaussian_fit_path)


Loading Synthesis Network


In [5]:
# Loop-invariant setup, hoisted out of the per-image loop.
# Layer indices whose noise is pinned to zero, parsed from the '.'-separated string.
# Empty fields are skipped, so bad_noise_layers = '' means "no bad layers".
bad_layers = {int(layer) for layer in bad_noise_layers.split('.') if layer}

opt_dict = {
    'sgd': torch.optim.SGD,
    'adam': torch.optim.Adam,
    'sgdm': partial(torch.optim.SGD, momentum=0.9),
    'adamax': torch.optim.Adamax,
}

# LR multipliers as a function of the step index x (consumed by LambdaLR):
#   fixed            - constant 1
#   linear1cycle     - triangle 0.1 -> 1 -> 0.1 over all `steps`
#   linear1cycledrop - the same triangle over the first 90% of steps, then a linear
#                      drop from 0.1 to 0.001 over the last 10%
schedule_dict = {
    'fixed': lambda x: 1,
    'linear1cycle': lambda x: (9*(1-np.abs(x/steps-1/2)*2)+1)/10,
    'linear1cycledrop': lambda x: (9*(1-np.abs(x/(0.9*steps)-1/2)*2)+1)/10 if x < 0.9*steps else 1/10 + (x-0.9*steps)/(0.1*steps)*(1/1000-1/10),
}

for ref_im, ref_im_name in dataloader:
    ref_im = ref_im.cuda()

    # Re-seed per image so every image starts from the same random draws.
    # `is not None` rather than truthiness, so that seed = 0 is honoured.
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True

    batch_size = ref_im.shape[0]

    # Latent to optimize: one 512-d code per image that is tiled across all 18 layers
    # when tile_latent is set, otherwise an independent code per layer.
    latent_layers = 1 if tile_latent else 18
    latent = torch.randn((batch_size, latent_layers, 512), dtype=torch.float,
                         requires_grad=True, device='cuda')

    # Per-layer noise inputs for the synthesis network.
    noise = []       # all 18 noise tensors, in layer order
    noise_vars = []  # the subset that the optimizer updates
    for i in range(18):
        # Layer i takes a square noise map with side 2**(i//2 + 2): 4, 4, 8, 8, ..., 1024, 1024.
        res = (batch_size, 1, 2**(i//2+2), 2**(i//2+2))

        if noise_type == 'zero' or i in bad_layers:
            new_noise = torch.zeros(res, dtype=torch.float, device='cuda')
            new_noise.requires_grad = False
        elif noise_type == 'fixed':
            new_noise = torch.randn(res, dtype=torch.float, device='cuda')
            new_noise.requires_grad = False
        elif noise_type == 'trainable':
            new_noise = torch.randn(res, dtype=torch.float, device='cuda')
            # Only layers with index below num_trainable_noise_layers are optimized.
            if i < num_trainable_noise_layers:
                new_noise.requires_grad = True
                noise_vars.append(new_noise)
            else:
                new_noise.requires_grad = False
        else:
            raise ValueError(f"unknown noise type: {noise_type!r}")

        noise.append(new_noise)

    var_list = [latent] + noise_vars
    opt = SphericalOptimizer(opt_dict[opt_name], var_list, lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_dict[lr_schedule])

    loss_builder = LossBuilder(ref_im, loss_str).cuda()

    min_loss = np.inf
    best_summary = ""
    start_t = time.time()
    gen_im = None
    j = -1  # keeps the it/s report below well-defined when steps == 0

    print("Optimizing")
    t = tqdm.trange(steps)
    for j in t:
        opt.opt.zero_grad()

        # Tile the shared latent across all 18 layers when tile_latent is set (expand is a view).
        latent_in = latent.expand(-1, 18, -1) if tile_latent else latent

        # Apply learned linear mapping to match latent distribution to that of the mapping network
        latent_in = lrelu(latent_in*gaussian_fit["std"] + gaussian_fit["mean"])

        # Normalize image to [0,1] instead of [-1,1]
        gen_im = (synthesis(latent_in, noise)+1)/2

        # Calculate losses
        loss, loss_dict = loss_builder(latent_in, gen_im)
        loss_dict['TOTAL'] = loss

        # Track the best loss as a plain float instead of holding on to the
        # autograd-attached tensor across iterations.
        loss_value = loss.item()
        if loss_value < min_loss:
            min_loss = loss_value
            best_summary = f'BEST ({j+1}) | ' + ' | '.join(
                f'{x}: {y:.4f}' for x, y in loss_dict.items())

        loss.backward()
        opt.step()
        scheduler.step()

        # Show the L2 term when loss_str has one; otherwise fall back to the total
        # loss instead of raising KeyError.
        shown = 'L2' if 'L2' in loss_dict else 'TOTAL'
        t.set_description('{}: {:.10f}'.format(shown, loss_dict[shown]))

    total_t = time.time() - start_t
    its = (j + 1) / total_t if total_t > 0 else 0.0
    current_info = f' | time: {total_t:.1f} | it/s: {its:.2f} | batchsize: {batch_size}'
    print(best_summary + current_info)

  0%|          | 0/500 [00:00<?, ?it/s]

Optimizing


L2: 0.0000345961: 100%|██████████| 500/500 [01:59<00:00,  4.19it/s]

BEST (500) | L2: 0.0000 | TOTAL: 0.0035 | time: 119.3 | it/s: 4.19 | batchsize: 1





In [6]:
# Convert the generated image from the optimization loop to PIL and display it inline.
# NOTE(review): gen_im is produced before the final opt.step() of the last iteration,
# and only the last batch from the dataloader is kept.
assert gen_im is not None, "gen_im is None; the optimization ran 0 steps"
toPIL = torchvision.transforms.ToPILImage()
# Take the first image of the batch (same as squeeze(0) when batch_size == 1).
# The values are nominally in [0, 1] after the (x+1)/2 rescale; clamp guards overshoot.
img = toPIL(gen_im[0].detach().cpu().clamp(0, 1))
img  # last expression: Jupyter renders the PIL image inline (img.show() opens an external viewer)

In [7]:
# Generate new noise
re_noise = []
for i in range(18):
    # dimension of the ith noise tensor
    if i < 30:
        res = (batch_size, 1, 2**(i//2+2), 2**(i//2+2))

        new_noise = torch.randn(res, dtype=torch.float, device='cuda')
        new_noise.requires_grad = False
        re_noise.append(new_noise)
    else:
        re_noise.append(noise[i])
    


In [8]:
# Re-render the optimized latent with the freshly sampled noise.
# NOTE(review): latent_in is the value computed in the last loop iteration, before the
# final opt.step(), so it trails the optimized `latent` by one update.
# latent_in still requires grad; no_grad avoids building an unused autograd graph.
with torch.no_grad():
    new_im = (synthesis(latent_in, re_noise)+1)/2
# First image of the batch (same as squeeze(0) when batch_size == 1), clamped to [0, 1].
img = toPIL(new_im[0].cpu().clamp(0, 1))
img  # last expression: Jupyter renders the PIL image inline