In [None]:
from random import randint
import sys; sys.path.append('..')

import torch
from torchvision.transforms import transforms, functional as F
from PIL import Image

import numpy as np

from models.fran import FRAN
from datasets.fran_dataset import FRANDataset

In [None]:
# Checkpoint to evaluate; change this path to inspect a different training run.
CKPT_PATH = '../ckpts/8ij6enbo_last.pth'

# map_location='cpu' makes the load independent of the device the tensors were
# saved from (e.g. cuda:1 on a single-GPU box, or a machine without CUDA); the
# model is moved to its compute device after load_state_dict.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
# NOTE(review): presumably a dict of state_dicts keyed by sub-module name -- confirm.
state_dicts = torch.load(CKPT_PATH, map_location='cpu')

In [None]:
# Build the re-aging network and restore its trained weights.
# NOTE(review): the meaning of the 'zeros' constructor argument is defined in
# models.fran (possibly a padding mode) -- TODO confirm.
fran = FRAN('zeros')
fran.load_state_dict(state_dicts['FRAN'])  # strict by default: raises on key mismatch

# Switch layers such as dropout / batch-norm to inference behaviour and place the
# model on the GPU when available, falling back to CPU so this cell runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fran.eval().to(device);

In [None]:
# --- Preprocessing ------------------------------------------------------------
# NOTE(review): presumably mirrors the transform used during training -- TODO confirm.
ccrop_size = (512, 512)  # (height, width) of the center crop
_pre_steps = [
    transforms.ToTensor(),              # 8-bit PIL image -> float CHW tensor in [0, 1]
    transforms.CenterCrop(ccrop_size),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
]
tfm = transforms.Compose(_pre_steps)

# Inverse of the Normalize step above: (x - (-1)) / 2 == (x + 1) / 2,
# i.e. maps [-1, 1] back to [0, 1] for display.
inv_norm = transforms.Normalize(mean=[-1, -1, -1.], std=[2, 2, 2.])

# --- Validation data ----------------------------------------------------------
# Validation split (is_val=True) taken as fold 0 of a 5-fold partition.
# NOTE(review): exact n_subsample semantics live in FRANDataset -- TODO document.
ds_val = FRANDataset(
    data_root='../data/FRAN_dataset/',
    transform=tfm,
    is_val=True,
    val_fold=0,
    num_folds=5,
    n_subsample=40,
)

In [None]:
# Pick a random validation sample and re-age it to each of several target ages.
assert len(ds_val) > 0, 'validation dataset is empty'
idx = randint(0, len(ds_val) - 1)  # randint is inclusive on both ends
src_img, src_age, _, _ = ds_val[idx]

# Follow the model's device instead of assuming CUDA.
model_device = next(fran.parameters()).device
src_img = src_img.to(model_device)[None, ...]  # add a batch dimension of 1

tgt_ages = [20, 40, 60, 80]
thumb_size = (128, 128)  # (width, height) of each displayed thumbnail

# The source-age map does not depend on the target age, so build it once:
# a single-channel map with the image's spatial size, filled with src_age.
src_age_map = torch.ones_like(src_img[:, :1, ...]) * src_age

reaged_ims = []
with torch.no_grad():
    for tgt_age in tgt_ages:
        tgt_age_map = torch.ones_like(src_age_map) * tgt_age
        # presumably returns (1, C, H, W) -- take the single batch item
        out = fran(src_img, src_age_map, tgt_age_map)[0].cpu()
        # inv_norm undoes the [-1, 1] input normalization; clip keeps values
        # inside [0, 1] so the PIL conversion does not wrap around.
        im = F.to_pil_image(inv_norm(out).clip(min=0, max=1)).resize(thumb_size)
        reaged_ims.append(im)

# Lay the thumbnails out left-to-right, in the order of `tgt_ages`.
Image.fromarray(np.concatenate([np.asarray(im) for im in reaged_ims], axis=1))