In [2]:
import torch
import os
import argparse

from torchvision.utils import save_image
from model.model import pSp, condi
from model.DNAnet import DNAnet,VAE
from utils.utils import align_face, totensor
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
args = {
    "dna" : "../saved_models\\vae_4600.pth",
    "enc" : "../pretrained_model/enc_4_2.pth",
    "gan" : "../pretrained_model/stylegan2-ffhq-config-f.pt",
    "map" : '../pretrained_model/condi_4_2.pth',
    "mom" : "result\\assets\\dad.png",
    "dad" : "result\\assets\\dad.png",
    "out" : 'result/vae_child1'
}

if not os.path.exists(args["out"]):
    os.makedirs(args["out"])
    print(f"saving to: {args['out']}")

device = "cuda:0" if torch.cuda.is_available else "cpu"
print(f"device {device}")
# Load model weights
DNA = VAE(512,device=device).to(device)
DNA.load_state_dict(torch.load(args["dna"]), strict=True)
DNA.eval()

net = nn.DataParallel(pSp(3, args["enc"], args["gan"])).to(device)
net.eval()

mapper = nn.DataParallel(condi()).to(device)
mapper.load_state_dict(torch.load(args["map"]), strict=True)
mapper.eval()

device cuda:0
Loading encoder weights from ckpt!
Loading decoder weights from pretrained!


DataParallel(
  (module): condi(
    (style1): Mapper(
      (style): Sequential(
        (0): Linear(in_features=612, out_features=512, bias=True)
        (1): PReLU(num_parameters=1)
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): PReLU(num_parameters=1)
        (4): Linear(in_features=512, out_features=512, bias=True)
      )
    )
    (style2): Mapper(
      (style): Sequential(
        (0): Linear(in_features=612, out_features=512, bias=True)
        (1): PReLU(num_parameters=1)
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): PReLU(num_parameters=1)
        (4): Linear(in_features=512, out_features=512, bias=True)
      )
    )
    (style3): Mapper(
      (style): Sequential(
        (0): Linear(in_features=612, out_features=512, bias=True)
        (1): PReLU(num_parameters=1)
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): PReLU(num_parameters=1)
        (4): Linear(in_features=512, out_f

In [4]:
def infer(args, net, DNA, mapper):
    # Inference
    mImg = align_face(args["mom"]).convert('RGB')
    dImg = align_face(args["dad"]).convert('RGB')

    testAge =  torch.ones((1,1)) * (args["age"]) /100 
    if args["gender"] == 'male':
        testGen = torch.ones(1, 1).to(device)
    else:
        testGen = torch.zeros(1, 1).to(device)

    with torch.no_grad():
        mImg = totensor(mImg).unsqueeze(0).to(device)
        dImg = totensor(dImg).unsqueeze(0).to(device)

        pW = torch.cat([net.module.encoder(mImg), net.module.encoder(dImg)],1)
        sW_hat = DNA(pW)['rec']
        sW_hat_expand = sW_hat.repeat(18, 1, 1).permute(1, 0, 2)
        sW_hat_delta = mapper(sW_hat_expand, testAge,  testGen)
        sImg_hat = net(sW_hat_expand + sW_hat_delta)
    out_file = f'{args["out"]}/result_{args["gender"]}_{args["age"]}.png'
    save_image((sImg_hat+1)/2, out_file, nrow = 1)
    return out_file, sImg_hat 

In [9]:
results = []
images = []
for gender in ["male", "female"]:
    print("Gender:",gender)
    args["gender"] = gender
    # for age in tqdm(range(5,105,10), desc = "Age", position=0):
    #     args["age"] = age
    for i in tqdm(range(10)):
        args["age"] = 25
        out_file, image = infer(args, net, DNA, mapper)
        results.append(out_file)
        images.append(image)

Gender: male


100%|██████████| 10/10 [00:26<00:00,  2.62s/it]


Gender: female


100%|██████████| 10/10 [00:24<00:00,  2.49s/it]


In [10]:
batch_images = torch.stack(images,dim=0).squeeze(1)
print(batch_images.shape)
out_path = f'{args["out"]}/results_combined.png'
save_image((batch_images+1)/2, out_path, nrow = 10)

torch.Size([20, 3, 256, 256])
