In [None]:
# IPython autoreload in mode 2: imported modules are reloaded before each cell runs,
# so edits to the project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys, os
import cv2

import torch
import torchvision
from torch.utils.data import DataLoader
import numpy as np

from omegaconf import OmegaConf

from models.expgan import EXPGAN
from dataset.ffhq import FFHQDataset
from glob import glob
from models.flamedecoder import FlameDecoder

from tqdm import tqdm
import matplotlib.pyplot as plt
import json

## Model and config

In [None]:
# Notebook configuration: checkpoint / config locations and the target device.
fn_cfg = './experiments/ffhq/config.yaml'
fn_ckpt = './pretrained_model/model_checkpoint.ckpt'

device = 'cuda'

# Load the experiment config and clear the dataset's `fn_meta_flip` entry for this notebook.
cfg = OmegaConf.load(fn_cfg)
cfg.dataset.fn_meta_flip = None

# Square image size taken from the dataset config, as a 2-tuple.
img_wh = (cfg.dataset.image_size,) * 2

In [None]:
# Build the model from the config, move it to `device`, and switch it to eval mode.
model = EXPGAN(cfg).to(device)
model.eval()
# NOTE(review): the return value is discarded here — confirm that EXPGAN.load_from_checkpoint
# loads the weights into this instance rather than returning a new model.
model.load_from_checkpoint(fn_ckpt)

# Use twice the configured coarse/fine step counts on the EMA generator's decoder,
# and switch its `perturb` flag off.
_decoder = model.net_G_ema.decoder
_eg3d_cfg = cfg.model.EG3D
_decoder.coarse_steps = _eg3d_cfg.coarse_steps * 2
_decoder.fine_steps = _eg3d_cfg.fine_steps * 2
_decoder.perturb = False

### Dataset, FLAME decoder, and example DECA parameters for the demos

In [None]:
# batch_size = 1
# Training split of the dataset, built from the `dataset` section of the config.
dataset = FFHQDataset(**cfg.dataset, split='train')
# Second FLAME decoder with masking disabled; the demo cells below call it with
# mesh_render=True to obtain a `shape_image`.
mesh_decoder = FlameDecoder(**cfg.model.flamedecoder, masking=False).cuda()

# Per-frame parameters used by the demo cells: they read
# transfer_meta['frames'][<frame id>]['exp'] and [...]['pose'].
fn_transfer = 'data/demo/meta_smooth.json'
# Use a context manager so the file handle is closed right after parsing.
with open(fn_transfer) as fp:
    transfer_meta = json.load(fp)

### Pose interpolation

In [None]:
# Render one randomly drawn sample while sweeping the second pose component
# (presumably yaw, given the output file name) and save the frames as a video.
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

steps = 120
# Two full sine periods, amplitude 1/6; plain Python floats for torch.FloatTensor below.
angles = (np.sin(np.linspace(0, np.pi * 4, steps)) / 6).tolist()

# Only the first (shuffled) batch is rendered. Move it to the GPU once instead of
# repeating the transfer for every frame.
batch = next(iter(dataloader))
batch.update({name: tensor.cuda() for name, tensor in batch.items() if isinstance(tensor, torch.Tensor)})
batch['codedict_real'].update({name: tensor.cuda() for name, tensor in batch['codedict_real'].items() if isinstance(tensor, torch.Tensor)})
shape = batch['shape_real']

frames = []
for angle in tqdm(angles):
    # Re-seed before every frame so any random sampling in the model is identical across frames.
    torch.manual_seed(13138407004106472821)
    pose = torch.FloatTensor([[0, angle, 0]]).cuda()

    flame = model.flamedecoder(batch['codedict_real'], batch['bbox_real'], pose=pose)
    uv = flame['uv']
    depth = flame['depth']
    c2w = flame['c2w']

    output = model.net_G_ema(shape, c2w, uv, depth, truncation=0.5, update_mean=False)
    # Affine map x * 0.5 + 0.5 (takes [-1, 1] to [0, 1]).
    pred = output['pred'] * 0.5 + 0.5

    # (1, C, H, W) -> (H, W, C) on the CPU.
    frames.append(pred[0].permute(1, 2, 0).detach().cpu())

# write_video expects a uint8 [T, H, W, C] tensor. Clamp and round explicitly: a bare
# float -> uint8 cast truncates and wraps values that fall outside [0, 255].
os.makedirs('results', exist_ok=True)
video = (torch.stack(frames, dim=0).clamp(0, 1) * 255.).round().to(torch.uint8)
torchvision.io.write_video('results/demo_yaw.mp4', video, fps=30, options={'crf': '15'})

### Expression interpolation

In [None]:
# Render one randomly drawn sample while interpolating expression and jaw parameters
# between key frames of `transfer_meta`, and save the frames as a video.
# The loader is created here so the cell does not depend on one left over from another cell.
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

steps = 120

# Key frames to visit, in order; frame ids are string keys in transfer_meta['frames'].
target_frames = [1, 220, 291, 1]
target_frames = [str(i) for i in target_frames]
target_exps = [torch.FloatTensor(transfer_meta['frames'][key]['exp']).unsqueeze(0) for key in target_frames]
# Entries 3: of the stored pose are used as the jaw parameters.
target_jaws = [torch.FloatTensor(transfer_meta['frames'][key]['pose'][3:]).unsqueeze(0) for key in target_frames]

# Interpolation weights for one segment; each pair of consecutive key frames gets
# steps // (number of segments) frames.
weights = torch.linspace(0, 1, steps // (len(target_exps)-1)).unsqueeze(1)

# Piecewise-linear paths through the key frames, one row per output frame.
exps = torch.cat(
    [torch.lerp(start, end, weights) for start, end in zip(target_exps[:-1], target_exps[1:])],
    dim=0,
).cuda()
jaws = torch.cat(
    [torch.lerp(start, end, weights) for start, end in zip(target_jaws[:-1], target_jaws[1:])],
    dim=0,
).cuda()

# Only the first (shuffled) batch is rendered; move it to the GPU once.
batch = next(iter(dataloader))
batch.update({name: tensor.cuda() for name, tensor in batch.items() if isinstance(tensor, torch.Tensor)})
batch['codedict_real'].update({name: tensor.cuda() for name, tensor in batch['codedict_real'].items() if isinstance(tensor, torch.Tensor)})
shape = batch['shape_real']

# Fixed zero pose. The original cell never defined `pose` and silently reused whatever
# value an earlier cell had left in the kernel.
pose = torch.FloatTensor([[0, 0, 0]]).cuda()

frames = []
for exp, jaw in tqdm(zip(exps, jaws), total=len(exps)):
    # Re-seed before every frame so any random sampling in the model is identical across frames.
    torch.manual_seed(0)
    batch['codedict_real']['exp'] = exp.unsqueeze(0)
    batch['codedict_real']['pose'][:,3:] = jaw

    flame = model.flamedecoder(batch['codedict_real'], batch['bbox_real'], pose=pose)
    uv = flame['uv']
    depth = flame['depth']
    c2w = flame['c2w']

    output = model.net_G_ema(shape, c2w, uv, depth, truncation=0.5, update_mean=False)
    pred = output['pred'] * 0.5 + 0.5

    # (1, C, H, W) -> (H, W, C) on the CPU.
    frames.append(pred[0].permute(1, 2, 0).detach().cpu())

# write_video expects a uint8 [T, H, W, C] tensor. Clamp and round explicitly: a bare
# float -> uint8 cast truncates and wraps values that fall outside [0, 255].
os.makedirs('results', exist_ok=True)
video = (torch.stack(frames, 0).clamp(0, 1) * 255.).round().to(torch.uint8)
torchvision.io.write_video('results/demo_expr.mp4', video, fps=30, options={'crf': '15'})

### Pose and expression interpolation

In [None]:
# Render the first four dataset samples while varying pose and expression together.
# Each video frame is a 2x3 grid: [mesh render | sample 0 | sample 1] over
# [white tile | sample 2 | sample 3].
batch_size = 4  # the grid layout below indexes exactly four predictions
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size)

steps = 120
angles = np.linspace(0, np.pi * 4, steps)

# First and second pose components follow sin / cos of the same phase, amplitude 1/6.
x_angles = (np.sin(angles) / 6).tolist()
y_angles = (np.cos(angles) / 6).tolist()

# Key frames to visit, in order; frame ids are string keys in transfer_meta['frames'].
target_frames = [1, 220, 150, 291, 1]
target_frames = [str(i) for i in target_frames]
target_exps = [torch.FloatTensor(transfer_meta['frames'][key]['exp']).unsqueeze(0) for key in target_frames]
# Entries 3: of the stored pose are used as the jaw parameters.
target_jaws = [torch.FloatTensor(transfer_meta['frames'][key]['pose'][3:]).unsqueeze(0) for key in target_frames]

# Interpolation weights for one segment between consecutive key frames.
weights = torch.linspace(0, 1, steps // (len(target_exps)-1)).unsqueeze(1)

# Piecewise-linear paths through the key frames, one row per output frame.
exps = torch.cat(
    [torch.lerp(start, end, weights) for start, end in zip(target_exps[:-1], target_exps[1:])],
    dim=0,
).cuda()
jaws = torch.cat(
    [torch.lerp(start, end, weights) for start, end in zip(target_jaws[:-1], target_jaws[1:])],
    dim=0,
).cuda()

# Only the first batch is rendered; move it to the GPU once.
batch = next(iter(dataloader))
batch.update({name: tensor.cuda() for name, tensor in batch.items() if isinstance(tensor, torch.Tensor)})
batch['codedict_real'].update({name: tensor.cuda() for name, tensor in batch['codedict_real'].items() if isinstance(tensor, torch.Tensor)})
shape = batch['shape_real']

# One fresh seed, re-applied before every frame so any random sampling in the model
# is identical across frames.
seed = torch.seed()

frames = []
n_frames = min(len(x_angles), len(y_angles), len(exps), len(jaws))
for x_angle, y_angle, exp, jaw in tqdm(zip(x_angles, y_angles, exps, jaws), total=n_frames):
    torch.manual_seed(seed)

    # The same pose / expression / jaw is applied to every sample in the batch.
    pose = torch.FloatTensor([[x_angle, y_angle, 0]]).cuda().expand(batch_size, -1)
    batch['codedict_real']['exp'] = exp.unsqueeze(0).expand(batch_size, -1)
    batch['codedict_real']['pose'][:,3:] = jaw

    flame = model.flamedecoder(batch['codedict_real'], batch['bbox_real'], pose=pose)
    uv = flame['uv']
    depth = flame['depth']
    c2w = flame['c2w']

    # Mesh render of the first sample only, as (H, W, C) clipped to [0, 1].
    flame_fullhead = mesh_decoder(batch['codedict_real'], batch['bbox_real'], pose=pose, mesh_render=True)
    shape_image = torch.clip(flame_fullhead['shape_image'][0].permute(1,2,0), 0, 1).detach().cpu()

    output = model.net_G_ema(shape, c2w, uv, depth, truncation=0.5, update_mean=False)
    pred = output['pred'] * 0.5 + 0.5
    pred = pred.permute(0,2,3,1).detach().cpu()

    top = torch.cat([shape_image, pred[0], pred[1]], dim=1)
    bot = torch.cat([torch.ones_like(shape_image), pred[2], pred[3]], dim=1)
    frames.append(torch.cat([top, bot], dim=0))

# write_video expects a uint8 [T, H, W, C] tensor. Clamp and round explicitly: a bare
# float -> uint8 cast truncates and wraps values that fall outside [0, 255].
os.makedirs('results', exist_ok=True)
video = (torch.stack(frames, 0).clamp(0, 1) * 255.).round().to(torch.uint8)
torchvision.io.write_video('results/demo_pose_expr.mp4', video, fps=30, options={'crf': '15'})

### Low- and high-resolution results before and after StyleGAN upsampling

In [None]:
# Render the first dataset sample while varying pose and expression, showing
# [mesh render | upsampled `aux_out` | final `pred`] side by side in each frame.
dataloader = DataLoader(dataset, shuffle=False, batch_size=1)

steps = 120
angles = np.linspace(0, np.pi * 4, steps)

# First and second pose components follow sin / cos of the same phase, amplitude 1/6.
x_angles = (np.sin(angles) / 6).tolist()
y_angles = (np.cos(angles) / 6).tolist()

# Key frames to visit, in order; frame ids are string keys in transfer_meta['frames'].
target_frames = [1, 220, 150, 291, 1]
target_frames = [str(i) for i in target_frames]
target_exps = [torch.FloatTensor(transfer_meta['frames'][key]['exp']).unsqueeze(0) for key in target_frames]
# Entries 3: of the stored pose are used as the jaw parameters.
target_jaws = [torch.FloatTensor(transfer_meta['frames'][key]['pose'][3:]).unsqueeze(0) for key in target_frames]

# Interpolation weights for one segment between consecutive key frames.
weights = torch.linspace(0, 1, steps // (len(target_exps)-1)).unsqueeze(1)

# Piecewise-linear paths through the key frames, one row per output frame.
exps = torch.cat(
    [torch.lerp(start, end, weights) for start, end in zip(target_exps[:-1], target_exps[1:])],
    dim=0,
).cuda()
jaws = torch.cat(
    [torch.lerp(start, end, weights) for start, end in zip(target_jaws[:-1], target_jaws[1:])],
    dim=0,
).cuda()

# Only the first batch is rendered; move it to the GPU once.
batch = next(iter(dataloader))
batch.update({name: tensor.cuda() for name, tensor in batch.items() if isinstance(tensor, torch.Tensor)})
batch['codedict_real'].update({name: tensor.cuda() for name, tensor in batch['codedict_real'].items() if isinstance(tensor, torch.Tensor)})
shape = batch['shape_real']

# One fresh seed, re-applied before every frame so any random sampling in the model
# is identical across frames.
seed = torch.seed()

frames = []
n_frames = min(len(x_angles), len(y_angles), len(exps), len(jaws))
for x_angle, y_angle, exp, jaw in tqdm(zip(x_angles, y_angles, exps, jaws), total=n_frames):
    torch.manual_seed(seed)

    pose = torch.FloatTensor([[x_angle, y_angle, 0]]).cuda().expand(1, -1)
    batch['codedict_real']['exp'] = exp.unsqueeze(0).expand(1, -1)
    batch['codedict_real']['pose'][:,3:] = jaw

    flame = model.flamedecoder(batch['codedict_real'], batch['bbox_real'], pose=pose)
    uv = flame['uv']
    depth = flame['depth']
    c2w = flame['c2w']

    # Mesh render as (H, W, C), clipped to [0, 1].
    flame_fullhead = mesh_decoder(batch['codedict_real'], batch['bbox_real'], pose=pose, mesh_render=True)
    shape_image = torch.clip(flame_fullhead['shape_image'][0].permute(1,2,0), 0, 1).detach().cpu()

    output = model.net_G_ema(shape, c2w, uv, depth, truncation=0.5, update_mean=False)
    pred = output['pred'] * 0.5 + 0.5
    pred_low = output['aux_out'] * 0.5 + 0.5

    # Upsample the auxiliary output to the final output's spatial size so the two can be
    # concatenated side by side (this was a hard-coded 256x256).
    pred_low = torch.nn.functional.interpolate(pred_low, size=tuple(pred.shape[-2:]), mode='bilinear')

    pred = pred.permute(0,2,3,1).detach().cpu()
    pred_low = pred_low.permute(0,2,3,1).detach().cpu()
    frames.append(torch.cat([shape_image, pred_low[0], pred[0]], dim=1))

# write_video expects a uint8 [T, H, W, C] tensor. Clamp and round explicitly: a bare
# float -> uint8 cast truncates and wraps values that fall outside [0, 255].
os.makedirs('results', exist_ok=True)
video = (torch.stack(frames, 0).clamp(0, 1) * 255.).round().to(torch.uint8)
torchvision.io.write_video('results/demo_low_high_res.mp4', video, fps=30, options={'crf': '15'})

## Scripts for figures

### Pose interpolation

In [None]:
# Figure: one randomly drawn sample rendered at seven values of the second pose
# component, concatenated left to right and written to results/pose_interp.png.
y_angles = np.linspace(-0.9, 0.9, 7)
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

# Only the first (shuffled) batch is rendered; move it to the GPU once.
batch = next(iter(dataloader))
batch.update({name: tensor.cuda() for name, tensor in batch.items() if isinstance(tensor, torch.Tensor)})
batch['codedict_real'].update({name: tensor.cuda() for name, tensor in batch['codedict_real'].items() if isinstance(tensor, torch.Tensor)})
shape = batch['shape_real']

frames = []
for angle in y_angles:
    # Re-seed before every view so any random sampling in the model is identical across views.
    torch.manual_seed(100)
    pose = torch.FloatTensor([[0, angle, 0]]).cuda()

    # mesh_render=True is kept from the original call; its extra `shape_image` output is
    # not used here. NOTE(review): drop the flag if it does not affect uv/depth/c2w.
    flame = model.flamedecoder(batch['codedict_real'], batch['bbox_real'], pose=pose, mesh_render=True)
    uv = flame['uv']
    depth = flame['depth']
    c2w = flame['c2w']

    output = model.net_G_ema(shape, c2w, uv, depth, truncation=0.5, update_mean=False)
    pred = output['pred'] * 0.5 + 0.5
    frames.append(pred[0].permute(1,2,0).detach().cpu().numpy())

# Clip before the uint8 conversion: values slightly outside [0, 1] would otherwise wrap around.
pred = np.clip(np.concatenate(frames, axis=1), 0, 1)
os.makedirs('results', exist_ok=True)
# [:,:,::-1] reverses the channel order for cv2.imwrite (OpenCV stores images as BGR).
cv2.imwrite('results/pose_interp.png', np.ascontiguousarray((pred * 255).round().astype(np.uint8)[:,:,::-1]))
plt.figure(figsize=(15,15))
plt.imshow(pred);


### Expression interpolation

In [None]:
# Figure: one randomly drawn sample rendered with expression / jaw parameters linearly
# interpolated between two fixed dataset samples, written to results/expression_interp.png.
steps = 7
weights = torch.linspace(0, 1, steps).unsqueeze(1)

rand_1 = 170 # np.random.randint(len(dataset))
rand_2 = 41788 # np.random.randint(len(dataset))
# Fetch each endpoint sample once and read both the expression and the jaw from it.
code_1 = dataset[rand_1]['codedict_real']
code_2 = dataset[rand_2]['codedict_real']
exp1 = code_1['exp'][None,...]
exp2 = code_2['exp'][None,...]
# Entries 3: of the pose vector are used as the jaw parameters.
jaw1 = code_1['pose'][3:][None,...]
jaw2 = code_2['pose'][3:][None,...]

print(rand_1, rand_2)

exps = torch.lerp(exp1, exp2, weights).cuda()
jaws = torch.lerp(jaw1, jaw2, weights).cuda()

# Created here so the cell does not depend on a loader left over from another cell.
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

# Only the first (shuffled) batch is rendered; move it to the GPU once.
batch = next(iter(dataloader))
batch.update({name: tensor.cuda() for name, tensor in batch.items() if isinstance(tensor, torch.Tensor)})
batch['codedict_real'].update({name: tensor.cuda() for name, tensor in batch['codedict_real'].items() if isinstance(tensor, torch.Tensor)})

# One fresh seed, re-applied before every step so any random sampling in the model
# is identical across steps.
seed = torch.seed()
pose = torch.FloatTensor([[0,0,0]]).cuda()

frames = []
for exp, jaw in zip(exps, jaws):
    torch.manual_seed(seed)

    batch['codedict_real']['exp'] = exp.unsqueeze(0)
    batch['codedict_real']['pose'][:,3:] = jaw
    # NOTE(review): this writes the jaw into columns 100: of the generator's conditioning
    # vector; the layout of `shape_real` is not visible here — confirm it.
    batch['shape_real'][:,100:] = jaw

    # mesh_render=True is kept from the original call; its extra `shape_image` output is
    # not used here. NOTE(review): drop the flag if it does not affect uv/depth/c2w.
    flame = model.flamedecoder(batch['codedict_real'], batch['bbox_real'], pose=pose, mesh_render=True)
    uv = flame['uv']
    depth = flame['depth']
    c2w = flame['c2w']
    shape = batch['shape_real']

    output = model.net_G_ema(shape, c2w, uv, depth, truncation=0.5, update_mean=False)
    pred = output['pred'] * 0.5 + 0.5
    frames.append(pred[0].permute(1,2,0).detach().cpu().numpy())

# Clip before the uint8 conversion: values slightly outside [0, 1] would otherwise wrap around.
pred = np.clip(np.concatenate(frames, axis=1), 0, 1)
os.makedirs('results', exist_ok=True)
# [:,:,::-1] reverses the channel order for cv2.imwrite (OpenCV stores images as BGR).
cv2.imwrite('results/expression_interp.png', np.ascontiguousarray((pred * 255).round().astype(np.uint8)[:,:,::-1]))
plt.figure(figsize=(15,15))
plt.imshow(pred);


### Shape interpolation

In [None]:
# Figure: one randomly drawn sample rendered with the shape code linearly interpolated
# between two fixed dataset samples, written to results/shape_interp.png.
steps = 7
weights = torch.linspace(0, 1, steps).unsqueeze(1)

rand_1 = 2132 # np.random.randint(len(dataset))
rand_2 = 15 # np.random.randint(len(dataset))
shape1 = dataset[rand_1]['codedict_real']['shape'][None,...]
shape2 = dataset[rand_2]['codedict_real']['shape'][None,...]

print(rand_1, rand_2)

shapes = torch.lerp(shape1, shape2, weights).cuda()

# Created here so the cell does not depend on a loader left over from another cell.
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

# Only the first (shuffled) batch is rendered; move it to the GPU once.
batch = next(iter(dataloader))
batch.update({name: tensor.cuda() for name, tensor in batch.items() if isinstance(tensor, torch.Tensor)})
batch['codedict_real'].update({name: tensor.cuda() for name, tensor in batch['codedict_real'].items() if isinstance(tensor, torch.Tensor)})

# One fresh seed, re-applied before every step so any random sampling in the model
# is identical across steps.
seed = torch.seed()
pose = torch.FloatTensor([[0,0,0]]).cuda()

frames = []
# `shape_code` is the interpolated code and `cond` is the generator's conditioning vector.
# The original reused the single name `shape` for both inside the loop.
for shape_code in shapes:
    torch.manual_seed(seed)

    batch['codedict_real']['shape'] = shape_code.unsqueeze(0)
    # The first 100 columns of the conditioning vector are overwritten with the shape code.
    batch['shape_real'][:,:100] = shape_code

    # mesh_render=True is kept from the original call; its extra `shape_image` output is
    # not used here. NOTE(review): drop the flag if it does not affect uv/depth/c2w.
    flame = model.flamedecoder(batch['codedict_real'], batch['bbox_real'], pose=pose, mesh_render=True)
    uv = flame['uv']
    depth = flame['depth']
    c2w = flame['c2w']
    cond = batch['shape_real']

    output = model.net_G_ema(cond, c2w, uv, depth, truncation=0.5, update_mean=False)
    pred = output['pred'] * 0.5 + 0.5
    frames.append(pred[0].permute(1,2,0).detach().cpu().numpy())

# Clip before the uint8 conversion: values slightly outside [0, 1] would otherwise wrap around.
pred = np.clip(np.concatenate(frames, axis=1), 0, 1)
os.makedirs('results', exist_ok=True)
# [:,:,::-1] reverses the channel order for cv2.imwrite (OpenCV stores images as BGR).
cv2.imwrite('results/shape_interp.png', np.ascontiguousarray((pred * 255).round().astype(np.uint8)[:,:,::-1]))
plt.figure(figsize=(15,15))
plt.imshow(pred);


### Latent vector (w space) interpolation

In [None]:
# Figure: one randomly drawn sample rendered while interpolating in w space between two
# randomly seeded noise vectors, written to results/identity_interp.png.

# NOTE(review): the noise below is sampled with a hard-coded width of 512 rather than
# `z_dim`; confirm the two agree before replacing the literal.
z_dim = model.net_G_ema.z_dim

steps = 7
weights = torch.linspace(0, 1, steps).unsqueeze(1).cuda()

# Created here so the cell does not depend on a loader left over from another cell.
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

# Only the first (shuffled) batch is rendered; move it to the GPU once.
batch = next(iter(dataloader))
batch.update({name: tensor.cuda() for name, tensor in batch.items() if isinstance(tensor, torch.Tensor)})
batch['codedict_real'].update({name: tensor.cuda() for name, tensor in batch['codedict_real'].items() if isinstance(tensor, torch.Tensor)})
shape = batch['shape_real']

# Two random integers are used as torch seeds for the endpoint noise vectors; they are
# printed so a result can be reproduced by hard-coding them.
rand_1 = np.random.randint(len(dataset))
rand_2 = np.random.randint(len(dataset))
print(rand_1, rand_2)

torch.manual_seed(rand_1)
latent1 = torch.randn(1, 512)
torch.manual_seed(rand_2)
latent2 = torch.randn(1, 512)

# The mapping network is fed the noise concatenated with the conditioning vector.
latent1 = torch.cat([latent1.cuda(), shape], dim=-1)
latent2 = torch.cat([latent2.cuda(), shape], dim=-1)

latent1_w = model.net_G_ema.mapping_network(latent1)
latent2_w = model.net_G_ema.mapping_network(latent2)

# Interpolate the first w vector of each endpoint, then repeat it along dim 1.
# NOTE(review): 29 is presumably the number of w vectors the generator consumes — confirm.
latents_w = torch.lerp(latent1_w[:,0], latent2_w[:,0], weights)
latents_w = latents_w.unsqueeze(1).expand(-1, 29, -1)

# One fresh seed, re-applied before every step so any random sampling in the model
# is identical across steps.
seed = torch.seed()
pose = torch.FloatTensor([[0,0,0]]).cuda()

frames = []
for latent in latents_w:
    torch.manual_seed(seed)

    # mesh_render=True is kept from the original call; its extra `shape_image` output is
    # not used here. NOTE(review): drop the flag if it does not affect uv/depth/c2w.
    flame = model.flamedecoder(batch['codedict_real'], batch['bbox_real'], pose=pose, mesh_render=True)
    uv = flame['uv']
    depth = flame['depth']
    c2w = flame['c2w']

    output = model.net_G_ema(shape, c2w, uv, depth, truncation=0.5, update_mean=False, w=latent.unsqueeze(0))
    pred = output['pred'] * 0.5 + 0.5
    frames.append(pred[0].permute(1,2,0).detach().cpu().numpy())

# Clip before the uint8 conversion: values slightly outside [0, 1] would otherwise wrap around.
pred = np.clip(np.concatenate(frames, axis=1), 0, 1)
os.makedirs('results', exist_ok=True)
# [:,:,::-1] reverses the channel order for cv2.imwrite (OpenCV stores images as BGR).
cv2.imwrite('results/identity_interp.png', np.ascontiguousarray((pred * 255).round().astype(np.uint8)[:,:,::-1]))
plt.figure(figsize=(15,15))
plt.imshow(pred);
