In [1]:
import tensorflow as tf
import cv2
from tensorflow.keras.utils import CustomObjectScope
import numpy as np
import os

In [2]:
# Small constant added to numerator and denominator so the ratio is defined
# even when both tensors sum to zero.
smooth = 1e-15


def dice_coef(y_true, y_pred):
    """Return the Dice coefficient between two tensors.

    Both inputs are passed through a Keras ``Flatten`` layer and then summed
    over every element (batch included), giving a single scalar:

        (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor, multiplied element-wise with ``y_true``.

    Returns:
        Scalar tensor holding the smoothed Dice coefficient.
    """
    true_flat = tf.keras.layers.Flatten()(y_true)
    pred_flat = tf.keras.layers.Flatten()(y_pred)

    overlap = tf.reduce_sum(true_flat * pred_flat)
    numerator = 2. * overlap + smooth
    denominator = tf.reduce_sum(true_flat) + tf.reduce_sum(pred_flat) + smooth
    return numerator / denominator

def dice_loss(y_true, y_pred):
    """Return the Dice loss, defined as ``1 - dice_coef(y_true, y_pred)``.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.

    Returns:
        Scalar tensor; smaller values mean larger overlap.
    """
    coefficient = dice_coef(y_true, y_pred)
    return 1.0 - coefficient

In [3]:
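# Disabled alternative: load "model.h5" with dice_loss / dice_coef registered
# as Keras custom objects.
#with CustomObjectScope({'dice_loss': dice_loss, 'dice_coef': dice_coef}):
#    model = tf.keras.models.load_model("model.h5")
#    #model.summary()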

In [4]:
# The model is only used for inference below, so skip restoring the training
# configuration (optimizer, loss, metrics). With compile=False, loading no
# longer depends on custom objects such as dice_loss / dice_coef being
# registered through custom_objects / CustomObjectScope.
model = tf.keras.models.load_model("mobilev3_unet.h5", compile=False)

In [5]:
# Height and width (in pixels) that every webcam frame is resized to before
# it is handed to the model.
H, W = 560, 560

In [6]:
# Live background replacement: grab webcam frames, segment the foreground
# with `model`, and composite it over a user-selectable background image.
#
# Keys (with the OpenCV window focused):
#   a - previous background
#   d - next background
#   q - quit

BG_DIR = "bg"

imgBg = cv2.imread("bg/1.jpg")  # not used below; kept in case other cells read it
# os.listdir returns entries in arbitrary order; sort so that cycling with
# 'a' / 'd' is stable between runs.
listImg = sorted(os.listdir(BG_DIR))
indexImg = 0

imgList = []
for imgPath in listImg:
    img = cv2.imread(os.path.join(BG_DIR, imgPath))
    # cv2.imread returns None (instead of raising) for anything it cannot
    # decode, e.g. sub-directories or non-image files; skip those.
    if img is not None:
        imgList.append(img)

if not imgList:
    raise FileNotFoundError(f"no readable background images found in '{BG_DIR}'")

cap = cv2.VideoCapture(0)

try:
    # Loop over webcam frames until the device closes, a read fails, or the
    # user presses 'q'.
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret or frame is None:
            # Camera disconnected or no frame available: stop instead of
            # crashing on `frame.shape`.
            break

        h, w = frame.shape[:2]
        ori_frame = frame

        # Model input: resize to (W, H), add a batch axis, scale to [0, 1].
        model_input = cv2.resize(frame, (W, H))
        model_input = np.expand_dims(model_input, axis=0) / 255.0

        # verbose=0 suppresses the per-call progress bar that would otherwise
        # be printed for every frame.
        mask = model.predict(model_input, verbose=0)[0]
        mask = cv2.resize(mask, (w, h))
        mask = (mask > 0.5).astype(np.float32)
        mask = np.expand_dims(mask, axis=-1)

        # Backgrounds can have any resolution; bring the selected one to the
        # frame size so the element-wise blend below is well defined.
        background = imgList[indexImg]
        if background.shape[:2] != (h, w):
            background = cv2.resize(background, (w, h))

        # mask is 1.0 where the frame is kept and 0.0 where the background
        # shows; its trailing axis of size 1 broadcasts over the colour
        # channels.
        final_frame = ori_frame * mask + background * (1.0 - mask)
        final_frame = final_frame.astype(np.uint8)

        # Show result to the user on the desktop
        cv2.imshow('Bg replacement', final_frame)

        # Keep only the low byte: on some platforms waitKey returns extra
        # high bits that would make the comparisons with ord(...) fail.
        key = cv2.waitKey(1) & 0xFF
        if key == ord('a'):
            if indexImg > 0:
                indexImg -= 1
        elif key == ord('d'):
            if indexImg < len(imgList) - 1:
                indexImg += 1
        elif key == ord('q'):
            break
finally:
    # Always release the webcam / capture device and close the preview
    # window, even if an exception interrupts the loop.
    cap.release()
    cv2.destroyAllWindows()

