In [None]:
import sys
from pathlib import Path

import torch
import torchvision as tv
import torchvision.transforms.functional as tvf
from PIL import Image

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Make the project root importable; the relative path is resolved against the
# kernel's current working directory, not this notebook's location.
sys.path.append('../')
from models.library import qres17m

In [None]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# initialize model
# `lmb` is passed to the model constructor and also selects the checkpoint directory.
lmb = 64
model = qres17m(lmb=lmb)
wpath = f'../checkpoints/qres17m/lmb{lmb}/last_ema.pt'
# map_location='cpu' makes loading independent of the device the checkpoint was
# saved on; the model is moved to `device` right after.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
msd = torch.load(wpath, map_location='cpu')['model']
model.load_state_dict(msd)

model = model.to(device=device)
model.eval()
print(f'Using lmb={lmb}. Model weights={wpath}')

In [None]:
def load_rgb_image(path, device):
    """Load an image file as a (1, C, H, W) float tensor in [0, 1] on `device`.

    The image is converted to RGB so that grayscale or RGBA files still yield
    three channels (presumably what the model expects -- TODO confirm), and the
    file handle is closed once the pixels have been read.
    """
    with Image.open(path) as img:
        return tvf.to_tensor(img.convert('RGB')).unsqueeze_(0).to(device=device)


impath1 = Path('../images/celaba64-1.png')
impath2 = Path('../images/celaba64-2.png')

im1 = load_rgb_image(impath1, device)
im2 = load_rgb_image(impath2, device)

# forward_get_latents returns a list of dicts; keep only each entry's 'z' tensor.
stats1 = [st['z'] for st in model.forward_get_latents(im1)]
stats2 = [st['z'] for st in model.forward_get_latents(im2)]
# zip() below would silently truncate if the two latent lists differed in length.
assert len(stats1) == len(stats2), 'images produced different numbers of latents'

steps = 8
interpolations = []
alphas = torch.linspace(0, 1, steps).tolist()
L = len(stats1)
# One grid row per `keep`: the first image, `steps` interpolations, the second image.
# Latents with index < keep are linearly interpolated; the rest are passed as None
# (presumably sampled by the model at temperature 0 -- confirm in cond_sample).
for keep in range(1, L + 1):
    interpolations.append(im1.squeeze(0))
    # Loop variable is `alpha`, not `lmb`, so the model's `lmb` global is not overwritten.
    for alpha in alphas:
        latents = [
            (1 - alpha) * z1 + alpha * z2 if i < keep else None
            for i, (z1, z2) in enumerate(zip(stats1, stats2))
        ]
        # `temprature` is kept exactly as spelled in the original call
        # (presumably matches cond_sample's signature).
        sample = model.cond_sample(latents, nhw_repeat=(1, 1, 1), temprature=0.)
        interpolations.append(sample.squeeze(0))
    interpolations.append(im2.squeeze(0))

svpath = f'runs/adv_interpolation_{impath1.stem}_{impath2.stem}.png'

In [None]:
# Each row holds steps + 2 images: the first image, the interpolations, the second image.
# to_pil_image scales floats by 255 and casts to uint8, so out-of-range values could
# wrap around; clamp to [0, 1] first (the to_tensor inputs are already in that range).
im = tv.utils.make_grid(interpolations, nrow=steps + 2).clamp(0, 1)
pil_im = tvf.to_pil_image(im)

# Persist the figure to the path computed in the previous cell (relative to the CWD).
Path(svpath).parent.mkdir(parents=True, exist_ok=True)
pil_im.save(svpath)
pil_im