In [1]:
import torch
import cv2
import numpy as np
import wandb
import os
import time
import torchvision.transforms as T
import PIL
import matplotlib.pyplot as plt
from model import load_model, HED

In [2]:
# Run on CUDA when a GPU is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def preprocess(img, device, convert = True, mean=(104.00698793, 116.66876762, 122.67891434)):
    """Turn an OpenCV image into a mean-subtracted (1, C, H, W) float32 tensor.

    Args:
        img: HxWx3 image array (e.g. from ``cv2.imread`` or ``VideoCapture.read``).
            The caller's array is never modified; a float32 copy is made.
        device: torch device the resulting tensor is moved to.
        convert: if True, swap channels BGR -> RGB before mean subtraction.
        mean: per-channel values subtracted from the image; their order must
            match the channel order of the image after the optional conversion.

    Returns:
        float32 tensor of shape (1, C, H, W) on ``device`` (no /255 scaling).
    """
    if convert:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # astype always returns a copy, so the in-place subtraction below is safe.
    img = img.astype(np.float32)
    # NOTE(review): the default means look like the classic HED/Caffe BGR means,
    # yet convert=True feeds RGB here; confirm the checkpoint's expected order.
    img -= np.asarray(mean)
    # HWC -> CHW, exactly what T.ToTensor() does for float input (it only
    # rescales uint8), then add the batch dimension.
    chw = np.ascontiguousarray(img.transpose(2, 0, 1))
    return torch.from_numpy(chw).unsqueeze(0).to(device)

def model_inference(model,x):
    """Run the model on one preprocessed batch and return its last output as numpy.

    Args:
        model: torch module whose output is indexable; the last element is
            used (treated here as the fused edge prediction).
        x: input tensor, e.g. as produced by ``preprocess``.

    Returns:
        numpy array of the last output, moved to CPU, with every size-1
        dimension squeezed out (a (1, 1, H, W) output becomes H x W).
    """
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        # Call the module instead of .forward() so registered hooks still run.
        preds = model(x)
        fused = preds[-1]  # use fuse output
    return fused.cpu().squeeze().numpy()

def prepare_model(path,device):
    """Load a checkpoint and ready the model for inference.

    Args:
        path: checkpoint location, passed straight to ``load_model``.
        device: torch device the model is moved to.

    Returns:
        the loaded model on ``device``, switched to eval mode.
    """
    # Load from the argument (not a notebook global) so any checkpoint works.
    model = load_model(path)
    model = model.to(device)
    # eval() turns off training-only behaviour (e.g. dropout, batch-norm stat
    # updates) if the model has any.
    model.eval()
    return model

In [4]:
# Load the model and run edge detection on a single example image.
model_path = 'hed_checkpoint.pt'
model = prepare_model(model_path, device)

# Read the image (OpenCV loads BGR) and convert it to a model-ready tensor.
image_path = 'data/bear.jpg'
img = cv2.imread(image_path)
if img is None:
    # cv2.imread signals a missing/unreadable file by returning None, not raising.
    raise FileNotFoundError(f"Could not read image: {image_path}")
x = preprocess(img, device)

edge = model_inference(model,x)

# Display the edge map until a key is pressed.
cv2.imshow('example',edge)
cv2.waitKey(0)
cv2.destroyAllWindows()

In [5]:
# Webcam example
import cv2

def show_webcam(mirror=False, concat = True, camera_index=0):
    """Stream webcam frames through the edge model and display them live (Esc quits).

    Relies on the notebook-level ``model`` and ``device``.

    Args:
        mirror: flip each captured frame horizontally before inference, so the
            frame stays on the left and the edge map on the right.
        concat: if True show frame and edge map side by side; otherwise show
            only the edge map.
        camera_index: OpenCV capture device index (0 = default camera).
    """
    cam = cv2.VideoCapture(camera_index)
    if not cam.isOpened():
        raise RuntimeError(f"Could not open camera {camera_index}")
    try:
        while True:
            ret_val, img = cam.read()
            if not ret_val:
                break  # no frame available (camera disconnected / stream ended)
            if mirror:
                img = cv2.flip(img, 1)

            # Preprocess and do inference
            x = preprocess(img,device,convert=True)
            edge_map = model_inference(model,x)

            # Edge map -> 3-channel uint8 so it can be concatenated with the frame.
            edge_map_rgb = cv2.cvtColor(edge_map, cv2.COLOR_GRAY2RGB)
            # Clip first so out-of-range values saturate instead of wrapping in uint8.
            edge_map_rgb = (255 * np.clip(edge_map_rgb, 0, 1)).astype(np.uint8)

            if concat:
                display_img = cv2.hconcat([img, edge_map_rgb])
            else:
                display_img = edge_map

            cv2.imshow('my webcam', display_img)
            # Keep only the low byte: some platforms set extra modifier bits.
            if (cv2.waitKey(1) & 0xFF) == 27:
                break  # esc to quit
    finally:
        # Always free the camera and windows, even if inference raises.
        cam.release()
        cv2.destroyAllWindows()

show_webcam()