In [1]:
import cv2
import torch
import re
from neural_style.transformer_net import TransformerNet
from neural_style import utils
import time

In [2]:
# Pretrained style checkpoints; the webcam demo alternates between them.
STYLE_TRANSFORM_PATH = [
    "./saved_models/rain_princess.pth",
    "./saved_models/candy.pth",
]
# When True, each stylised frame is passed through utils.transfer_color
# together with the original webcam frame.
PRESERVE_COLOR = False
# Capture resolution requested from the webcam: half of 1280x720.
WIDTH, HEIGHT = 1280 // 2, 720 // 2

# Main: real-time webcam style transfer

In [3]:
def _load_style_state_dicts(paths, device):
    """Load each checkpoint in `paths` onto `device` and strip stale InstanceNorm keys.

    Returns a list of state dicts in the same order as `paths`.
    """
    deprecated_key = re.compile(r'in\d+\.running_(mean|var)$')
    state_dicts = []
    for path in paths:
        # map_location lets GPU-saved checkpoints load on CPU-only machines and
        # keeps the cached weights on the inference device, so style swaps copy
        # device-to-device.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        sdict = torch.load(path, map_location=device)
        # Remove saved deprecated running_* keys in InstanceNorm from the checkpoint.
        for key in list(sdict.keys()):
            if deprecated_key.search(key):
                del sdict[key]
        state_dicts.append(sdict)
    return state_dicts


def webcam(style_transform_path, width=1280, height=720,
           preserve_color=None, switch_interval=5.0):
    """Run live style transfer on the default webcam and show it in a window.

    Frames are mirrored, stylised by a TransformerNet and displayed. Every
    `switch_interval` seconds the network switches to the next checkpoint in
    `style_transform_path`, wrapping around after the last one. Press Esc in
    the window to quit. Nothing is written to disk.

    Args:
        style_transform_path: sequence of paths to TransformerNet checkpoints
            (state dicts loadable with ``torch.load``).
        width: capture width requested from the camera (the driver may choose
            a different resolution).
        height: capture height requested from the camera.
        preserve_color: if truthy, ``utils.transfer_color`` is applied using the
            original frame. ``None`` (default) falls back to the module-level
            ``PRESERVE_COLOR`` setting.
        switch_interval: seconds between style switches.

    Raises:
        ValueError: if `style_transform_path` is empty.
    """
    if not style_transform_path:
        raise ValueError("style_transform_path must contain at least one checkpoint")
    if preserve_color is None:
        preserve_color = PRESERVE_COLOR

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load Transformer Network
    print("Loading Transformer Network")
    state_dicts = _load_style_state_dicts(style_transform_path, device)
    idx = 0
    net = TransformerNet()
    net.load_state_dict(state_dicts[idx])
    net = net.to(device)
    net.eval()  # inference only
    print("Done Loading Transformer Network")

    # Set webcam settings
    cam = cv2.VideoCapture(0)
    cam.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    cam.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

    # try/finally so the camera and window are released even if inference fails.
    try:
        with torch.no_grad():
            last_switch = time.monotonic()
            while True:
                ok, frame = cam.read()
                if not ok:
                    # Camera unavailable or stream ended; `frame` would be None.
                    print("Failed to read frame from webcam; stopping")
                    break

                # Mirror horizontally
                frame = cv2.flip(frame, 1)

                # Generate image
                content_tensor = utils.itot(frame).to(device)
                generated_tensor = net(content_tensor)
                generated_image = utils.ttoi(generated_tensor.detach())
                if preserve_color:
                    generated_image = utils.transfer_color(frame, generated_image)

                # NOTE(review): assumes ttoi yields values in [0, 255]; scaling to
                # [0, 1] is what cv2.imshow expects for float images.
                generated_image = generated_image / 255

                cv2.imshow('Demo webcam', generated_image)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit

                # Cycle through all checkpoints (no-op when only one is given).
                if len(state_dicts) > 1 and time.monotonic() - last_switch > switch_interval:
                    idx = (idx + 1) % len(state_dicts)
                    net.load_state_dict(state_dicts[idx])
                    last_switch = time.monotonic()
    finally:
        cam.release()
        cv2.destroyAllWindows()

In [4]:
# Launch the demo with the configured checkpoints and capture size (Esc quits).
webcam(style_transform_path=STYLE_TRANSFORM_PATH, width=WIDTH, height=HEIGHT)

Loading Transformer Network
Done Loading Transformer Network


# Debug: step-by-step checkpoint loading and device checks

In [None]:
# Load every style checkpoint. map_location lets GPU-saved checkpoints load on
# CPU-only machines too (without it torch.load restores to the saved device).
print("Loading Transformer Network")
debug_device = "cuda" if torch.cuda.is_available() else "cpu"
state_dicts = [torch.load(path, map_location=debug_device) for path in STYLE_TRANSFORM_PATH]

In [None]:
# remove saved deprecated running_* keys in InstanceNorm from the checkpoint
for sdict in state_dicts:
    for k in list(sdict.keys()):
        if re.search(r'in\d+\.running_(mean|var)$', k):
            del sdict[k]

In [None]:
# Fresh TransformerNet for interactive checks; weights are loaded in a later cell.
net = TransformerNet()

In [None]:
# Move the network to the GPU when one is present; a hard-coded "cuda" would
# raise on CPU-only machines.
net = net.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the first checkpoint's weights into `net`.
net.load_state_dict(state_dicts[0])

In [None]:
# Which device the parameters are on after loading the first checkpoint.
next(net.parameters()).device

In [None]:
# NOTE(review): identical to the previous cell with nothing run in between; redundant.
next(net.parameters()).device

In [None]:
# Swap in the second checkpoint: the same operation the webcam loop performs
# when it cycles styles. Requires at least two entries in STYLE_TRANSFORM_PATH.
net.load_state_dict(state_dicts[1])

In [None]:
# Re-check the parameters' device after swapping in the second checkpoint.
next(net.parameters()).device

In [None]:
# A state dict is a (Ordered)dict and has no `.to()`, so move each tensor
# individually. Values are replaced in place, so the dict object itself (and
# anything torch attached to it) is kept. Falls back to CPU when no GPU exists.
# NOTE(review): assumes every value is a tensor, as in a saved module state dict.
target_device = "cuda" if torch.cuda.is_available() else "cpu"
first_sdict = state_dicts[0]
for key, value in list(first_sdict.items()):
    first_sdict[key] = value.to(target_device)

In [None]:
# Start a wall-clock timer (the same pattern the webcam loop uses to time style switches).
st = time.time()

In [None]:
# Seconds elapsed since `st` was set; the webcam loop switches styles once this exceeds 5.
time.time()-st