In [1]:
import os
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image

import torch
import torchvision.transforms.functional as tf

from src import model

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [3]:
harmonizer = model.Harmonizer()

if torch.cuda.is_available():
    harmonizer = harmonizer.cuda()

In [4]:
harmonizer.load_state_dict(torch.load('./pretrained/harmonizer.pth'), strict=True)
harmonizer.eval()

  harmonizer.load_state_dict(torch.load('./pretrained/harmonizer.pth'), strict=True)


Harmonizer(
  (backbone): EfficientBackbone(
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): BatchNorm2d(16, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_swish): MemoryEfficien

In [5]:
comp = Image.open("./demo/image_harmonization/example/composite/f636_1.jpg").convert('RGB')
mask = Image.open("./demo/image_harmonization/example/mask/f636_1.jpg").convert('1')
if comp.size[0] != mask.size[0] or comp.size[1] != mask.size[1]:
    print('The size of the composite image and the mask are inconsistent')
    exit()

comp = tf.to_tensor(comp)[None, ...]
mask = tf.to_tensor(mask)[None, ...]

if device.type == 'cuda':
    comp = comp.cuda()
    mask = mask.cuda()

# harmonization
with torch.no_grad():
    arguments = harmonizer.predict_arguments(comp, mask)
    harmonized = harmonizer.restore_image(comp, mask, arguments)[-1]

# save the result
harmonized = np.transpose(harmonized[0].cpu().numpy(), (1, 2, 0)) * 255
harmonized = Image.fromarray(harmonized.astype(np.uint8))

harmonized.save("result.png")