In [1]:
import tensorflow as tf
from model_trainer import *
import matplotlib.pyplot as plt
from tensorflow.keras.mixed_precision import experimental as mixed_precision
# Ask TensorFlow to grow GPU memory on demand rather than reserving it all at start-up.
# set_memory_growth raises RuntimeError once the GPUs have been initialized; that case is
# reported and the notebook carries on with the default allocation behaviour.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
except RuntimeError as err:
    print(err)

# Keras 'mixed_float16' policy: layers compute in float16 while keeping float32 variables.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: GeForce RTX 2070 SUPER, compute capability 7.5


In [2]:
import cv2
import numpy as np
from pathlib import Path
from tqdm.notebook import trange

In [3]:
import backbone_models
import specific_models

In [4]:
import skimage

In [5]:
# Pick the test video deterministically. Path.iterdir() yields entries in an arbitrary,
# filesystem-dependent order, and the directory may also contain non-file entries
# (e.g. subdirectories) that cannot be opened as a video.
vid_dir = Path('data/test')
vid_candidates = sorted(p for p in vid_dir.iterdir() if p.is_file())
if not vid_candidates:
    raise FileNotFoundError(f'no files found in {vid_dir}')
vid_path = vid_candidates[0]

# Weights checkpoint, plus the backbone and per-keypoint heads used to build the model it is loaded into.
model_path = 'savedmodels/hr538_316_aug_500/50'
bb_model = backbone_models.hr_5_3_8
# Both keypoints ('nose', 'tail') use the same head architecture.
head_model = specific_models.conv3_16
sp_model = {keypoint: head_model for keypoint in ('nose', 'tail')}

In [6]:
# Model input: 240 rows x 320 cols x 3 channels (the frames are resized to this before inference).
model_input = tf.keras.Input((240, 320, 3))
testmodel = ChaserModel(model_input, bb_model, sp_model)

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 240, 320, 3) 0                                            
__________________________________________________________________________________________________
HR_0 (HighResolutionModule)     [(None, 240, 320, 8) 3760        input_1[0][0]                    
__________________________________________________________________________________________________
HR_1 (HighResolutionModule)     [(None, 240, 320, 8) 19336       HR_0[0][0]                       
__________________________________________________________________________________________________
HR_2 (HighResolutionModule)     [(None, 240, 320, 8) 92704       HR_1[0][0]                       
                                                                 HR_1[0][1]            

In [7]:
# A TF-format load_weights restores lazily and silently leaves any model variable without a
# checkpoint entry at its initial value. Fail loudly instead, e.g. when model_path is wrong or
# bb_model/sp_model differ from the architecture the checkpoint was saved from.
# assert_existing_objects_matched() returns the same load status, so the cell output is unchanged.
testmodel.load_weights(model_path).assert_existing_objects_matched()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x20625771b50>

In [8]:
# cv2 takes sizes as (width, height); numpy images are indexed (rows, cols) = (height, width).
original_size = (640, 480)
original_hw = original_size[::-1]
model_size = (320, 240)
model_hw = model_size[::-1]
batch_size = 64  # frames per forward pass

In [9]:
def _heatmap_peaks(heatmaps, out_hw):
    """Return the peak (row, col) of every heatmap in a batch, mapped onto an output frame.

    Args:
        heatmaps: array whose first axis indexes the batch; the remaining axes form the
            heatmap grid. Mapping onto ``out_hw`` requires exactly two of them, i.e.
            (batch, rows, cols) -- assumed to match the model heads, TODO confirm.
        out_hw: (height, width) of the frame the peaks are mapped onto.

    Returns:
        int array of shape (batch, 2) with peak coordinates in ``out_hw`` pixels,
        truncated toward zero.
    """
    n = heatmaps.shape[0]
    grid_hw = heatmaps.shape[1:]
    flat_peaks = heatmaps.reshape(n, -1).argmax(axis=1)
    peaks = np.stack(np.unravel_index(flat_peaks, grid_hw), axis=1)
    # Scale by the heatmap grid itself, so heads whose output resolution differs from
    # the model input are still mapped correctly.
    return (peaks * np.divide(out_hw, grid_hw)).astype(int)


marker_radius = 10  # radius, in original-frame pixels, of the disk drawn at each keypoint

cap = cv2.VideoCapture(str(vid_path))
if not cap.isOpened():
    raise IOError(f'could not open video: {vid_path}')
# Write at the source frame rate so playback speed matches; fall back to 30 fps when the
# container does not report one (cap.get returns 0).
fps = cap.get(cv2.CAP_PROP_FPS) or 30
Path('results').mkdir(parents=True, exist_ok=True)
fourcc = cv2.VideoWriter_fourcc(*'MP4V')
writer = cv2.VideoWriter('results/result.mp4', fourcc, fps, original_size)
try:
    if not writer.isOpened():
        raise IOError('could not open results/result.mp4 for writing')

    frames = []           # frames resized to model_size with channels reversed, fed to the model
    original_frames = []  # full-size frames as decoded; annotated in place and written out
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # VideoWriter does not write frames whose size differs from the size it was opened with.
        if frame.shape[:2] != original_hw:
            raise ValueError(
                f'video frame is {frame.shape[1]}x{frame.shape[0]}, '
                f'but original_size is {original_size[0]}x{original_size[1]}'
            )
        # OpenCV decodes BGR; [..., 2::-1] reverses the channel order to RGB for the model.
        frames.append(cv2.resize(frame, dsize=model_size)[..., 2::-1])
        original_frames.append(frame)
    print('loaded')

    # Ceiling division so the final partial batch is processed instead of dropped.
    batch_num = (len(frames) + batch_size - 1) // batch_size
    count = 0
    for i in trange(batch_num):
        output = testmodel(np.array(frames[i * batch_size:(i + 1) * batch_size]))
        nose_hms, tail_hms = output['nose'].numpy(), output['tail'].numpy()
        nose_poses = _heatmap_peaks(nose_hms, original_hw)
        tail_poses = _heatmap_peaks(tail_hms, original_hw)
        for nose_pos, tail_pos in zip(nose_poses, tail_poses):
            frame = original_frames[count]
            nose_rr, nose_cc = skimage.draw.disk(nose_pos, marker_radius, shape=original_hw)
            tail_rr, tail_cc = skimage.draw.disk(tail_pos, marker_radius, shape=original_hw)
            frame[nose_rr, nose_cc] = [0, 255, 0]  # green (BGR)
            frame[tail_rr, tail_cc] = [255, 0, 0]  # blue (BGR)
            writer.write(frame)
            count += 1
finally:
    # Release even if inference fails, so the output file is finalized and handles are freed.
    cap.release()
    writer.release()
print('done')

loaded


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=89.0), HTML(value='')))


done
