In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
device = "cuda"
import pickle
import torch
import numpy as np
import PIL.Image
from pathlib import Path
from IPython.display import display, clear_output
from nokogiri.working_dir import working_dir
from collections import defaultdict
from importlib import reload
import projector
import torch
import torch.nn.functional as F
from script_util import rcParams
import matplotlib.pyplot as plt
plt.rcParams.update(rcParams)

In [None]:
# Load the EMA generator from a training snapshot pickle and move it to `device`.
# NOTE(review): pickle.load can execute arbitrary code — only point `fm` at
# checkpoints you trust.
fm = Path("/data/natsuki/training116/00030-v4-mirror-auto4-gamma100-batch64-noaug-resumecustom/network-snapshot-011289.pkl")
assert fm.is_file()
with fm.open('rb') as f:
    G = pickle.load(f)['G_ema'].to(device)
def mapping(seed=1, psi=1):
    """Map a seeded random latent through the global generator's mapping network.

    Args:
        seed: seed for ``np.random.RandomState``; draws z of shape (1, G.z_dim)
            (float64, numpy's default, moved to ``device`` as-is).
        psi: truncation psi forwarded to ``G.mapping``.

    Returns:
        The result of ``G.mapping`` for one sample with an all-zero label.
    """
    rng = np.random.RandomState(seed)
    z = torch.from_numpy(rng.randn(1, G.z_dim)).to(device)
    zero_label = torch.zeros([1, G.c_dim], device=device)
    return G.mapping(z, zero_label, truncation_psi=psi)
def synthesis(w, synth=True):
    """Convert a latent (or an already-rendered image batch) to a PIL RGB image.

    Args:
        w: latent passed to ``G.synthesis`` when ``synth`` is True; otherwise an
            image tensor in NCHW layout used directly.
        synth: whether to run the generator's synthesis network (const noise).

    Returns:
        A PIL 'RGB' image of the first batch element. Values are mapped with
        (x + 1) * 127.5, clamped to [0, 255] and truncated to uint8, so the
        input is presumably in [-1, 1].
    """
    images = G.synthesis(w, noise_mode='const') if synth else w
    scaled = (images + 1) * (255/2)
    # NCHW -> NHWC for PIL.
    pixels = scaled.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)
    first_image = pixels[0].cpu().numpy()
    return PIL.Image.fromarray(first_image, 'RGB')
def load_target(target_fname):
    """Load an image file as a projection target for the global generator G.

    The image is converted to RGB, center-cropped to a square and resized to
    ``G.img_resolution`` x ``G.img_resolution`` with Lanczos filtering.

    Args:
        target_fname: path of the image file to load.

    Returns:
        (target_tensor, target_pil):
            target_tensor -- uint8 tensor of shape
                (3, G.img_resolution, G.img_resolution) on ``device``;
            target_pil -- the full RGB image before cropping/resizing.
    """
    # Close the underlying file promptly; convert() returns a fully loaded copy.
    with PIL.Image.open(target_fname) as raw_pil:
        target_pil = raw_pil.convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    # Center-crop to a square so the resize below does not distort the aspect ratio.
    target_pil_cropped = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    # Resize the cropped square, not the full image.
    target_pil_resized = target_pil_cropped.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil_resized, dtype=np.uint8)
    # HWC -> CHW layout for torch.
    target_tensor = torch.tensor(target_uint8.transpose([2, 0, 1]), device=device)
    return target_tensor, target_pil

In [None]:
# Load the pretrained Danbooru tagger from the bizarre-pose-estimator repo.
# The import and the './_train/...' checkpoint path are resolved while that
# repo is the working directory (presumably the import relies on it too).
with working_dir("/home/natsuki/bizarre-pose-estimator"):
    from _train.danbooru_tagger.models.kate import Model as DanbooruTagger
    tagger_ckpt = (
        './_train/danbooru_tagger/runs/waning_kate_vulcan0001/checkpoints/'
        'epoch=0022-val_f2=0.4461-val_loss=0.0766.ckpt'
    )
    danbooru_tagger = DanbooruTagger.load_from_checkpoint(tagger_ckpt)
    danbooru_tagger.eval()
    danbooru_tagger.cuda()
    # Frozen feature extractor: gradients flow to the inputs only.
    for tagger_param in danbooru_tagger.parameters():
        tagger_param.requires_grad_(False)
def biz_feat(imgs, tagger=None, eps=1e-8):
    """Extract raw Danbooru-tagger features from a batch of images.

    Images are resized to 256x256 (area interpolation) and min-max rescaled to
    [0, 1] using the min/max of the whole batch, so inputs in any value range
    end up on a common scale.

    Args:
        imgs: float image tensor in NCHW layout (a 2-D ``size`` for
            F.interpolate requires 4-D input).
        tagger: model to run; defaults to the global ``danbooru_tagger``.
        eps: lower bound on the batch's value range, so a constant image maps
            to zeros instead of NaN (0/0).

    Returns:
        The ``"raw"`` entry of the tagger's output.
    """
    if tagger is None:
        tagger = danbooru_tagger
    imgs = F.interpolate(imgs, size=(256, 256), mode='area')
    lo, hi = imgs.min(), imgs.max()
    imgs = (imgs - lo) / (hi - lo).clamp_min(eps)
    return tagger(imgs)["raw"]
def biz_loss(target, synth):
    """Danbooru-tagger loss between target and synthesized features.

    ``target`` is squashed with a sigmoid before being handed to
    ``danbooru_tagger.loss`` alongside the unmodified ``synth``.
    NOTE(review): presumably both are raw tagger logits from ``biz_feat``, so
    the target becomes soft tag probabilities — confirm against the tagger.

    Returns:
        The ``"loss"`` entry of the tagger's loss output.
    """
    soft_target = torch.sigmoid(target)
    loss_out = danbooru_tagger.loss(soft_target, synth)
    return loss_out["loss"]

In [None]:
# Re-import projector.py so edits take effect without restarting the kernel.
reload(projector)
# Sort the candidates: glob yields files in arbitrary filesystem order, which
# would make the chosen target non-reproducible.
target_dir = Path("/data/natsuki/danbooru2020/v4/0000/")
target_paths = sorted(target_dir.glob("*.png"))
if not target_paths:
    raise FileNotFoundError(f"no *.png files found in {target_dir}")
target_tensor, target_pil = load_target(target_paths[0])
# Project with the tagger features/loss as an additional term (weight 1);
# the returned steps are iterated in the next cell.
projected_w_steps = projector.project(
    G,
    target=target_tensor,
    num_steps=100,
    device=device,
    yield_more=True,
    additional_feat=biz_feat,
    additional_loss=biz_loss,
    additional_weight=1,
)
display(target_pil)

In [None]:
# Consume the projection steps: keep every yielded value per key in `record`
# and refresh the output with the current image, distances and log entry.
record = defaultdict(list)
for data in projected_w_steps:
    for key in data:
        record[key].append(data[key])
    clear_output(wait=True)
    display(data["pil"], data["dist"], data["additional_dist"], data["log"])