In [1]:
import cv2
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image

from model.vgg19_encoder import vgg19_encoder
from model.vgg19_decoder import vgg19_decoder
from utils.image_loader import image_loader
from utils.load_transform import load_transform

In [2]:
# Use the GPU when available; fall back to CPU so the notebook does not fail on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Build the VGG19 encoder/decoder pair on the configured device and load pretrained weights.
encoder = vgg19_encoder().to(device)
decoder = vgg19_decoder().to(device)

encoder_pt_path = './model/weights/vgg19_encoder.pt'
decoder_pt_path = './model/weights/vgg19_decoder.pt'

# map_location lets checkpoints saved on one device be loaded onto the configured `device`.
encoder_pt = torch.load(encoder_pt_path, map_location=device)
decoder_pt = torch.load(decoder_pt_path, map_location=device)

# strict=False: mismatched keys between checkpoint and encoder are ignored instead of raising.
# NOTE(review): presumably the checkpoint holds more layers than the truncated encoder — confirm.
encoder.load_state_dict(encoder_pt, strict=False)
decoder.load_state_dict(decoder_pt)

# Inference only: put both networks in eval mode.
encoder = encoder.eval()
decoder = decoder.eval()

In [4]:
# Encode the style image once; its statistics are reused for every content frame.
transform = load_transform()

style_image_path = './data/picasso.jpg'
style_image = image_loader(style_image_path, transform=transform, device=device)
# Inference only: skip building an autograd graph that would otherwise stay alive with style_feature.
with torch.no_grad():
    style_feature = encoder(style_image)
style_feature = style_feature.to(device)

In [5]:
# Sanity check: size of the encoded style feature map (presumably batch, channels, height, width).
style_feature.shape

torch.Size([1, 512, 64, 64])

In [7]:
from utils.calc_mean_std import calc_mean_std_

# Mean/std statistics of the style feature, reused for every frame. They are constants for
# the whole run, so detach them from any autograd graph recorded while encoding the style image.
style_mean, style_std = calc_mean_std_(style_feature.detach())

def adaptive_instance_normalization(content_feat=None, style_mean=None, style_std=None):
    """Adaptive Instance Normalization (AdaIN), as proposed in the AdaIN paper.

    Transfers the style statistics onto a content feature map. The content
    feature is first standardized with its own mean/std (computed by
    ``calc_mean_std_``), then scaled by ``style_std`` and shifted by
    ``style_mean``.

    Args:
        content_feat (torch.Tensor): encoder feature map of the content image.
        style_mean (torch.Tensor): mean statistics of the style feature, as
            returned by ``calc_mean_std_``; must be expandable to
            ``content_feat.size()``.
        style_std (torch.Tensor): standard-deviation statistics of the style
            feature, as returned by ``calc_mean_std_``; must be expandable to
            ``content_feat.size()``.

    Returns:
        torch.Tensor: tensor of the same size as ``content_feat`` whose
        statistics are the given style statistics.

    Raises:
        ValueError: if any argument is missing (``None``).
    """
    # The signature defaults to None for every argument; fail clearly instead of
    # with an AttributeError deep inside the computation.
    if content_feat is None or style_mean is None or style_std is None:
        raise ValueError('content_feat, style_mean and style_std are all required')

    size = content_feat.size()
    content_mean, content_std = calc_mean_std_(content_feat)

    # Standardize the content feature with its own mean and standard deviation.
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    # Then give the normalized feature the style feature's statistics.
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)

In [8]:
def style_transfer(encoder, decoder, content_input, style_mean, style_std, alpha=1.0):
    """Stylize ``content_input`` using precomputed style statistics.

    Args:
        encoder: network mapping the content image to a feature map.
        decoder: network mapping a feature map back to an image.
        content_input (torch.Tensor): preprocessed content image batch.
        style_mean (torch.Tensor): style mean statistics for AdaIN.
        style_std (torch.Tensor): style standard-deviation statistics for AdaIN.
        alpha (float): style strength in [0, 1]; 1.0 decodes the fully
            stylized feature, 0.0 decodes the original content feature.

    Returns:
        torch.Tensor: decoder output for the blended feature (no autograd history).

    Raises:
        AssertionError: if ``alpha`` is outside [0, 1].
    """
    assert (0.0 <= alpha <= 1.0)
    # Pure inference: don't record an autograd graph for every processed frame.
    with torch.no_grad():
        content_feature = encoder(content_input)
        feature = adaptive_instance_normalization(content_feature, style_mean=style_mean, style_std=style_std)
        # Linear blend: the closer alpha is to 1, the stronger the style.
        feature = feature * alpha + content_feature * (1 - alpha)
        return decoder(feature)

In [9]:
def frame_process(f):
    """Convert an OpenCV frame into a PIL image.

    ``cv2.VideoCapture.read`` returns color frames in BGR channel order, while
    ``Image.fromarray`` interprets a 3-channel uint8 array as RGB. Without
    conversion the red and blue channels would be swapped before stylization.

    Args:
        f (numpy.ndarray): frame as returned by ``cv2.VideoCapture.read``.

    Returns:
        PIL.Image.Image: the frame as an RGB image (non-3-channel arrays are
        wrapped unchanged).
    """
    if f.ndim == 3 and f.shape[2] == 3:
        f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
    return Image.fromarray(f)

In [24]:
# Writer for the stylized clip. 'mp4v' is a FourCC supported by the .mp4 container
# ('DIVX' is not, and makes the FFmpeg backend fall back with a warning).
# frameSize is (width, height); frames written with a different size are not recorded.
out = cv2.VideoWriter(filename='test.mp4', fourcc=cv2.VideoWriter_fourcc(*'mp4v'), fps=25, frameSize=(512,512))
assert out.isOpened(), 'failed to open VideoWriter for test.mp4'

In [17]:
# Close any OpenCV HighGUI windows that are still open (e.g. the 'image' preview).
cv2.destroyAllWindows()

In [22]:
# Capture up to MAX_FRAMES webcam frames; press 'q' in the preview window to stop early.
MAX_FRAMES = 250  # 10 s of footage at 25 fps

all_frames = []

cap = cv2.VideoCapture(0)
assert cap.isOpened(), 'could not open webcam 0'
try:
    while len(all_frames) < MAX_FRAMES:
        ret, frame = cap.read()
        if not ret:
            break
        all_frames.append(frame)
        cv2.imshow('image', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always release the camera so later cells (or a re-run) can reopen it.
    cap.release()
    cv2.destroyAllWindows()


In [25]:
# Stylize the captured frames and write them to the video file.
try:
    # Inference only: no autograd graph per frame.
    with torch.no_grad():
        for frame in tqdm(all_frames, desc='stylizing'):
            content = frame_process(frame)
            content = transform(content).unsqueeze(0).to(device)
            result = style_transfer(encoder=encoder, decoder=decoder, content_input=content, style_mean=style_mean, style_std=style_std, alpha=0.8)
            result = result.squeeze(0)
            # CHW float (presumably in [0, 1]) -> HWC uint8 on the CPU.
            result = result.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
            # Decoder output is RGB (the live-preview cell flips it before imshow); VideoWriter expects BGR.
            # NOTE(review): the frame must match the writer's frameSize (512x512) or it is not recorded.
            out.write(cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
finally:
    # Finalize the file even if stylization fails midway.
    out.release()

In [26]:
# Live preview: stylize frames in real time. Press 'q' in the window to stop.
video_path = './data/sample.mp4'
# NOTE(review): video_path is defined but the capture reads webcam 0;
# pass video_path instead to stylize the sample clip.
cap = cv2.VideoCapture(0)
assert cap.isOpened(), 'could not open capture source'

try:
    # Inference only: no autograd graph per frame.
    with torch.no_grad():
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            content = frame_process(frame)
            content = transform(content).unsqueeze(0).to(device)
            # Use a dedicated name so the VideoWriter `out` from the earlier cell is not clobbered.
            stylized = style_transfer(encoder=encoder, decoder=decoder, content_input=content, style_mean=style_mean, style_std=style_std, alpha=0.8)
            stylized = stylized.squeeze(0)
            stylized = stylized.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
            # RGB -> BGR for cv2.imshow.
            stylized = stylized[:, :, ::-1]
            cv2.imshow('image', stylized)
            if cv2.waitKey(10) & 0xFF == ord('q'):
                break
finally:
    # Release the capture device and close the preview even on error/interrupt.
    cap.release()
    cv2.destroyAllWindows()