In [1]:
import sys, os, cv2, torch
sys.path.append(f"{os.getcwd()}")
from PIL import Image
from torchvision.transforms import transforms
from yacs.config import CfgNode as CN
from style_transfer.networks import AdaIN, AdaINConfig
from torchvision.utils import save_image

In [7]:
def _loadRGB(path):
    """Open the image at ``path`` and return a fully loaded RGB copy.

    ``Image.convert`` returns a *new* image instead of converting in place, so
    its result must be kept. Converting inside the ``with`` block also forces
    the pixel data to be read before the underlying file handle is closed.
    """
    with Image.open(path) as image:
        return image.convert("RGB")


def transform(contentPath, stylePath, outputName="test_adain.png"):
    """Stylise one content image with one style image using AdaIN and save it.

    Builds an AdaIN network from ``style_transfer/config/adain.yaml`` with
    colour preservation disabled and ``alpha = 1.0``, loads the VGG encoder and
    decoder weights, resizes both images so their shorter edge equals the
    content image's shorter edge, runs the transfer and writes the result as
    an image file under ``config.save_dir`` (``../.output/results``).

    Args:
        contentPath: Path to the content (photo) image.
        stylePath: Path to the style image.
        outputName: File name of the result inside ``config.save_dir``.
            Defaults to ``"test_adain.png"``, the name used previously.

    Returns:
        The path the stylised image was written to.

    Raises:
        OSError: If OpenCV reports that the output file could not be written.
    """
    config = AdaINConfig.create("style_transfer/config/adain.yaml")
    config.defrost()
    config.preserve_color = False
    config.alpha = 1.0
    config.save_dir = "../.output/results"
    config.freeze()

    network = AdaIN(config)
    network.loadModel({
        "vgg": "../../../Models/AdaIN/vgg_normalised.pth",
        "decoder": "../../../Models/AdaIN/decoder.pth"
    })

    # Force 3-channel RGB so RGBA / grayscale / palette inputs match the encoder.
    content = _loadRGB(contentPath)
    style = _loadRGB(stylePath)

    # An int passed to Resize scales the *shorter* edge to that size, so this
    # leaves the content image unchanged and rescales the style image to match it.
    smallestSize = min(content.size)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(smallestSize),
    ])

    contentTensor = preprocess(content)
    styleTensor = preprocess(style)

    # Pure inference: no gradients are needed for the stylised output.
    with torch.no_grad():
        styledImage = network.transformTo(contentTensor, styleTensor)

    # NOTE(review): assumes transformTo returns a batched (N, C, H, W) RGB tensor
    # with values in [0, 1] (the original code scaled it by 255) — confirm.
    # Quantise explicitly rather than relying on cv2.imwrite's implicit float cast.
    imageArray = (
        styledImage[0]
        .detach()
        .clamp(0.0, 1.0)
        .mul(255)
        .round()
        .to(torch.uint8)
        .permute(1, 2, 0)
        .contiguous()
        .cpu()
        .numpy()
    )
    # OpenCV writes channels in BGR order.
    imageArray = cv2.cvtColor(imageArray, cv2.COLOR_RGB2BGR)

    # cv2.imwrite does not create directories and signals failure only via its
    # return value, so create the directory up front and check the result.
    os.makedirs(config.save_dir, exist_ok=True)
    outputPath = os.path.join(config.save_dir, outputName)
    if not cv2.imwrite(outputPath, imageArray):
        raise OSError(f"Failed to write stylised image to {outputPath}")
    return outputPath

In [8]:
# Stylise the custom photo with the custom style image; transform() saves the
# result under ../.output/results.
transform(
    contentPath="../../../Datasets/custom/Photo.png",
    stylePath="../../../Datasets/custom/Style.png",
)