In [None]:
import sys; sys.path.append('..')

import torch
from torchvision.transforms import transforms, functional as F
from PIL import Image

from models.bisenet import BiSeNet
from models.fran import FRAN

In [None]:
# Load the source portrait. Converting to RGB guarantees three channels, which the
# 3-channel Normalize transforms applied to this image later require (a JPEG file
# may also be stored as grayscale or CMYK).
im = Image.open('floris2.jpg').convert('RGB')

In [None]:
import numpy as np

# Face-parsing network; its per-pixel class predictions are used to build face masks.
bisenet = BiSeNet(n_classes=19)
# map_location='cpu' makes loading independent of the device the checkpoint was saved
# from; the weights are moved to the GPU once by .cuda() below.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
bisenet.load_state_dict(
    torch.load('../pretrained_models/bisenet_79999_iter.pth', map_location='cpu')
)
bisenet.eval().cuda();  # trailing ';' keeps the module repr out of the cell output

# Preprocessing for BiSeNet: 512x512 center crop, then per-channel normalisation
# (the mean/std values match the commonly used ImageNet statistics).
tfm_bise = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

def get_face_mask(im, size=1024):
    """Build a face mask for a PIL image using the BiSeNet parser.

    The image is preprocessed with ``tfm_bise`` (512x512 center crop), every pixel
    is assigned its arg-max class, and all classes except 0, 16 and 17 are kept.
    The binary mask is then resized to ``size`` x ``size``.

    Args:
        im: PIL image accepted by ``tfm_bise``.
        size: Side length in pixels of the square mask that is returned.
            Defaults to 1024, the value that used to be hard-coded, so existing
            callers are unaffected.

    Returns:
        2-D float numpy array of shape ``(size, size)`` with values in [0, 1].
        The resize uses PIL's default resampling, so values along the mask
        boundary may be fractional rather than strictly 0 or 1.
    """
    bise_input = tfm_bise(im)[None, ...].cuda()

    with torch.no_grad():
        # First network output, first batch element -> per-pixel class index.
        bise_output = bisenet(bise_input)[0][0].argmax(0).cpu().numpy()

    # Classes 0, 16 and 17 are excluded from the mask.
    # NOTE(review): presumably background / clothing / hair in this checkpoint's
    # label map - confirm against the BiSeNet training labels.
    mask = ~np.isin(bise_output, [0, 16, 17])
    mask_im = Image.fromarray((mask * 255).astype(np.uint8)).resize((size, size))
    mask = np.array(mask_im) / 255

    return mask

In [None]:
# FRAN model that takes an image plus source/target age maps (see its calls below).
fran = FRAN(padding_mode='zeros')
# map_location='cpu' makes loading independent of the device the checkpoint was saved
# from; the weights are moved to the GPU once by .cuda() below.
# NOTE(review): absolute, machine-specific checkpoint path - consider a configurable
# checkpoint directory. torch.load unpickles the file, so only load trusted checkpoints.
state_dicts = torch.load('/apollo/fdf/projects/fran/ckpts/8ij6enbo_ep11.pth', map_location='cpu')
# The checkpoint holds several state dicts; only the 'FRAN' entry is needed here.
fran.load_state_dict(state_dicts['FRAN'])
fran.eval().cuda();  # trailing ';' keeps the module repr out of the cell output

In [None]:
# Side length (pixels) of the square image fed to FRAN.
input_size = 1024

# Preprocessing for FRAN: resize and center-crop to input_size x input_size, convert
# to a tensor, then normalise each channel with mean 0.5 / std 0.5 ([0, 1] -> [-1, 1]).
tfm = transforms.Compose(
    [
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Inverse of the normalisation above: (x - (-1)) / 2 maps [-1, 1] back to [0, 1].
inv_norm = transforms.Normalize(mean=(-1.0, -1.0, -1.0), std=(2.0, 2.0, 2.0))

In [None]:
# Face mask and normalised input tensor (with batch axis) for the portrait in `im`.
mask = get_face_mask(im)
t = tfm(im)[None, ...].cuda()

src_age = 29
tgt_ages = [80]

reaged_ims = []

# The source-age map does not depend on the target age, so it is built once here
# instead of on every loop iteration. Shape: (1, 1, input_size, input_size).
src_age_map = torch.full((1, 1, input_size, input_size), float(src_age)).cuda()

for tgt_age in tgt_ages:
    # Target age where the mask is 1, source age where it is 0 (linear blend between).
    tgt_age_map = torch.tensor((mask * (tgt_age - src_age)) + src_age).float()
    # center_crop keeps the map at input_size x input_size; it is a no-op when the
    # mask already has that size.
    tgt_age_map = F.center_crop(tgt_age_map[None, None, ...], input_size).cuda()

    with torch.no_grad():
        out = fran(t, src_age_map, tgt_age_map)[0].cpu()

    # Undo the [-1, 1] normalisation, clamp to the valid range and shrink for display.
    im_out = F.to_pil_image(inv_norm(out).clip(min=0, max=1)).resize((512, 512))
    reaged_ims.append(im_out)

In [None]:
# Show all re-aged results side by side: convert each PIL image to an array, join
# them along the width axis, and let the resulting image be the cell output.
Image.fromarray(np.concatenate([np.asarray(reaged) for reaged in reaged_ims], axis=1))

In [None]:
import cv2


def get_video_frames(video_path):
    """Decode every frame of a video into a list of PIL images.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        List of PIL images in decoding order. The list is empty when the capture
        cannot be opened or yields no frames. All frames are held in memory at once.
    """
    frames = []

    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV delivers BGR; reverse the channel axis to get RGB for PIL.
            frames.append(Image.fromarray(frame[..., ::-1]))
    finally:
        # Always free the decoder and its file handle, even if decoding raises.
        cap.release()

    return frames

In [None]:
# Decode the whole test clip up front; every frame is kept in memory as a PIL image.
frames = get_video_frames(video_path='test.mov')

In [None]:
from tqdm import tqdm

src_age = 29
tgt_age = 50

# Constant age maps shared by every frame, shape (1, 1, input_size, input_size).
# Here the target age is applied uniformly; the face mask is only used afterwards
# to blend the re-aged result back onto the original frame.
src_age_map = torch.full((1, 1, input_size, input_size), float(src_age)).cuda()
tgt_age_map = torch.full_like(src_age_map, float(tgt_age))

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# NOTE(review): the frame rate is fixed at 30 fps and is not read from the source
# video - confirm it matches the input clip.
out = cv2.VideoWriter(f'output_fran_{tgt_age}.mp4', fourcc, 30.0, (input_size, input_size))

# Same geometry as `tfm`, but without tensor conversion / normalisation, so the
# original frame lines up pixel-for-pixel with the network output.
resize_crop = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
])

try:
    # The loop variable is `frame` (not `im`) so the portrait stored in `im` is not
    # overwritten; re-running the single-image cell afterwards still uses it.
    for frame in tqdm(frames):
        t = tfm(frame)[None, ...].cuda()

        with torch.no_grad():
            t_re_aged = fran(t, src_age_map, tgt_age_map)[0].cpu()

        # get_face_mask returns a 1024x1024 mask by default, matching input_size above.
        f_mask = get_face_mask(frame)
        im_re_aged = F.to_pil_image(inv_norm(t_re_aged).clip(min=0, max=1))

        # Blend: re-aged pixels weighted by the mask, original pixels by (1 - mask).
        im_re_aged = np.array(im_re_aged) * f_mask[..., None]
        im_bg = np.array(resize_crop(frame)) * (1 - f_mask)[..., None]
        im_out = (im_bg + im_re_aged).astype(np.uint8)

        # OpenCV expects BGR; hand it an explicit contiguous copy of the flipped array.
        out.write(np.ascontiguousarray(im_out[..., ::-1]))
finally:
    # Finalise the output file even if processing a frame fails part-way through.
    out.release()