In [1]:
import os

import torch
from torchvision.utils import make_grid
from torchvision.utils import save_image
from util.image import unnormalize

In [None]:
from torchvision import transforms
from net import PConvUNet
from celeb import CelebA
import opt
from util.io import load_ckpt

class Argument():
    """Plain attribute container holding the settings for this notebook.

    Of these settings, the notebook cells visible here read only ``root``,
    ``mask_root``, ``lr`` and ``image_size``.
    NOTE(review): the remaining fields look like options of a training
    script and may be unused here -- confirm before removing any.

    Args:
        **overrides: optional ``name=value`` pairs replacing the defaults
            below, e.g. ``Argument(image_size=256)``. Calling ``Argument()``
            with no arguments yields exactly the defaults.

    Raises:
        TypeError: if an override names a setting that does not exist, so
            that typos are reported instead of silently ignored.
    """

    def __init__(self, **overrides):
        # Data locations and output directories.
        self.root = './data_original'
        self.mask_root = './irregular_mask/disocclusion_img_mask/'
        self.save_dir = './snapshots/default'
        self.log_dir = './logs/default'
        # Optimisation settings.
        self.lr = 0.0008
        self.lr_finetune = 0.0001
        self.max_iter = 9999999
        self.batch_size = 24
        self.n_threads = 4
        # Reporting / checkpoint intervals.
        self.save_model_interval = 1000
        self.vis_interval = 500
        self.log_interval = 100
        # Side length used to build the (image_size, image_size) target size.
        self.image_size = 512

        # Only names that already exist as defaults may be overridden.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(
                'unknown Argument option(s): ' + ', '.join(unknown))
        vars(self).update(overrides)
args = Argument()

# Prefer the GPU, but fall back to the CPU so the notebook can still run on a
# machine without CUDA instead of failing at the first `.to(device)`.
# NOTE(review): restoring a checkpoint on a CPU-only machine may additionally
# require a `map_location` inside `load_ckpt` -- confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Square target size shared by the image and mask pipelines.
size = (args.image_size, args.image_size)

# Images: resize to the target size, convert to a tensor, then normalise with
# the mean/std constants defined in `opt`.
img_tf = transforms.Compose(
    [transforms.Resize(size=size), transforms.ToTensor(),
     transforms.Normalize(mean=opt.MEAN, std=opt.STD)])

# Masks: take a random crop of the target size rather than rescaling
# (replace RandomCrop with transforms.Resize(size=size) to rescale instead).
mask_tf = transforms.Compose(
    [
        transforms.RandomCrop(size=size),
        transforms.ToTensor()
    ])

dataset_val = CelebA(args.root, args.mask_root, img_tf, mask_tf, 'test_')

lr = args.lr

# Build the network once, directly on the target device. (A second, unused
# PConvUNet() used to be constructed and immediately discarded here.)
model = PConvUNet().to(device)

# The optimizer is created only over trainable parameters and is passed to
# load_ckpt together with the model so both are restored from the checkpoint.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
ckpt_name = './snapshots/default/ckpt/86000.pth'
start_iter = load_ckpt(
    ckpt_name, [('model', model)], [('optimizer', optimizer)])

# Re-apply the configured learning rate after loading, so the value from
# `args` takes precedence over whatever the restored optimizer state holds.
for param_group in optimizer.param_groups:
    param_group['lr'] = lr
print('Starting from iter ', start_iter)

In [None]:
dataset = dataset_val
image_num = 6

# Inference only: switch the network to evaluation mode so that layers which
# behave differently during training (e.g. batch normalisation, dropout) do
# not affect the preview. The mode persists on `model` after this cell.
model.eval()

# Index 0 is sampled `image_num` times on purpose.
# NOTE(review): presumably each access pairs the image with a different
# random mask (mask_tf uses RandomCrop), giving one image under several
# masks -- confirm against CelebA.__getitem__. To preview distinct samples
# instead, use `dataset[i] for i in range(image_num)`.
samples = [dataset[0] for _ in range(image_num)]
image, mask, gt = (torch.stack(field) for field in zip(*samples))

with torch.no_grad():
    output, _ = model(image.to(device), mask.to(device))
# Bring the prediction back to the CPU, where `image` and `mask` live.
output = output.cpu()

# Composite: keep `image` weighted by `mask` and fill the remainder
# (1 - mask) with the network output.
output_comp = mask * image + (1 - mask) * output

# Concatenate the five stages along the batch axis; with nrow=image_num each
# stage (input, mask, raw output, composite, ground truth) fills one grid row.
grid = make_grid(
    torch.cat((unnormalize(image), mask, unnormalize(output),
               unnormalize(output_comp), unnormalize(gt)), dim=0),
    nrow=image_num
)

In [None]:
filename = './data_original/output/029.jpg'
# Create the output directory first: writing the image fails if the parent
# directory does not exist yet.
os.makedirs(os.path.dirname(filename), exist_ok=True)
save_image(grid, filename)