In [1]:
import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm

In [2]:
# Load the TFLite model from disk and allocate its tensors so that it is
# ready for set_tensor()/invoke() calls in the cells below.
MODEL_PATH = 'tflite_models/mobv3_small_07_head.tflite'
interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()

In [3]:
import cv2
import skimage.draw
from pathlib import Path

In [4]:
import time

In [5]:
# Pick the first entry that the directory listing yields as the test video.
# NOTE: Path.iterdir() yields entries in arbitrary order, so which file is
# "first" depends on the filesystem.
vid_dir = Path('data/test')
vid_path = next(vid_dir.iterdir())

In [6]:
# Output video frame size as (width, height), and the same size reordered as
# (height, width).
original_size = (640, 480)
original_hw = original_size[::-1]

In [7]:
def draw_head(interpreter, original_frame, radius=10, color=(0, 255, 0)):
    """Run the model on one frame and paint a disk at the heat-map maximum.

    The frame is resized to the interpreter's input size, its first three
    channels are reversed, and a leading batch axis is added before it is fed
    to the interpreter.  The position of the largest value in the (squeezed)
    first output tensor is scaled to the frame's resolution and a filled disk
    is drawn there.

    Args:
        interpreter: Object with allocated tensors exposing
            ``get_input_details``, ``get_output_details``, ``set_tensor``,
            ``invoke`` and ``get_tensor`` (e.g. ``tf.lite.Interpreter``).
        original_frame: Image array whose first two axes are (height, width).
            **Modified in place** -- the disk is painted directly onto it.
        radius: Disk radius in pixels (default 10).
        color: Per-channel value written into the disk (default ``(0, 255, 0)``,
            green in both RGB and BGR channel order).

    Returns:
        Tuple ``(frame, invoke_seconds, total_seconds)`` where ``frame`` is the
        same object as ``original_frame``, ``invoke_seconds`` covers
        ``set_tensor`` + ``invoke`` + the output-details lookup, and
        ``total_seconds`` covers the whole call.
    """
    # perf_counter() is monotonic and high-resolution, which makes it the
    # appropriate clock for measuring short intervals.
    total_start = time.perf_counter()
    original_hw = original_frame.shape[:2]

    input_details = interpreter.get_input_details()[0]
    # shape[2:0:-1] -> (shape[2], shape[1]); assumes an NHWC input so this is
    # the (width, height) order cv2.resize expects -- TODO confirm for new models.
    target_size = tuple(input_details['shape'][2:0:-1])
    # [..., 2::-1] reverses the first three channels (presumably BGR -> RGB
    # for frames decoded by OpenCV).
    resized_frame = cv2.resize(original_frame, dsize=target_size)[..., 2::-1]
    resized_frame = resized_frame[np.newaxis, ...]

    invoke_start = time.perf_counter()
    interpreter.set_tensor(input_details['index'], resized_frame)
    interpreter.invoke()
    output_idx = interpreter.get_output_details()[0]['index']
    invoke_end = time.perf_counter()

    hm = np.squeeze(interpreter.get_tensor(output_idx))

    # Index of the heat-map maximum, rescaled to frame pixels.  The builtin
    # ``int`` replaces ``np.int``, which was removed in NumPy 1.24.
    peak = np.unravel_index(hm.argmax(), hm.shape)
    scale = np.divide(original_hw, hm.shape)
    pos = np.multiply(peak, scale).astype(int)

    # shape=original_hw clips the disk to the frame bounds.
    rr, cc = skimage.draw.disk(pos, radius, shape=original_hw)
    original_frame[rr, cc] = color

    total_end = time.perf_counter()
    return original_frame, invoke_end - invoke_start, total_end - total_start

In [8]:
# Decode the whole clip into memory, then run draw_head on every frame while
# accumulating the per-frame timings it reports.
cap = cv2.VideoCapture(str(vid_path))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # codec consumed by the writer cell

# Fail loudly instead of silently processing zero frames when the file is
# missing or cannot be decoded.
if not cap.isOpened():
    raise IOError(f'could not open video: {vid_path}')

frames = []
drawn_frames = []
try:
    while True:
        ret, frame = cap.read()
        if not ret:  # end of stream (or a decode failure)
            break
        frames.append(frame)
finally:
    # Always give the capture handle back, even if decoding raised.
    cap.release()

total_time_acc = 0
invoke_time_acc = 0
print('invoking')
for frame in tqdm(frames):
    # draw_head paints onto `frame` in place, so `drawn_frames` ends up
    # holding the same array objects as `frames`.
    drawn_frame, invoke_time, total_time = draw_head(interpreter, frame)
    total_time_acc += total_time
    invoke_time_acc += invoke_time
    drawn_frames.append(drawn_frame)
print(f'total time: {total_time_acc}')
print(f'invoke time: {invoke_time_acc}')



invoking


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5699.0), HTML(value='')))


total time: 192.03338503837585
invoke time: 188.15914154052734


In [9]:
# Encode the annotated frames to an MP4 file at 30 fps.
print('writing')
out_path = Path('results/tflite_result.mp4')
# cv2.VideoWriter does not create missing directories; it just fails to open.
out_path.parent.mkdir(parents=True, exist_ok=True)

# The writer's frame size must match the frames being written, so take it
# from the data and only fall back to the configured size when there are no
# frames at all.
if drawn_frames:
    frame_h, frame_w = drawn_frames[0].shape[:2]
    frame_size = (frame_w, frame_h)
else:
    frame_size = original_size

writer = cv2.VideoWriter(str(out_path), fourcc, 30, frame_size)
if not writer.isOpened():
    raise IOError(f'could not open video writer for: {out_path}')
try:
    for df in tqdm(drawn_frames):
        writer.write(df)
finally:
    # Releasing finalises the container; do it even if a write raised.
    writer.release()


writing


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5699.0), HTML(value='')))


