# 0. Install and Import Dependencies

In [1]:
import mediapipe as mp
import cv2
import numpy as np
import csv
import warnings
import tensorflow as tf
import pandas as pd
import pickle
# Suppress every Python warning for the rest of the session (all categories,
# all modules). Presumably done to keep notebook output clean; note it also
# hides deprecation warnings emitted by mediapipe / tensorflow.
warnings.filterwarnings('ignore')

In [2]:
# Short aliases for the MediaPipe modules used below:
#   mp_pose    -> the Pose solution (model class + POSE_CONNECTIONS)
#   mp_drawing -> helpers for rendering landmarks onto an image
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils

# 1. Load Model

In [3]:
# Directory of the exported TensorFlow SavedModel used for inference.
# Other trained variants that can be swapped in (the feature extraction in
# the prediction cell must then be switched to match):
#   "../Model/xyz_hand/", "../Model/xy/", "../Model/xy_hand/"
dir_model = "../Model/xyz/"

# Load the SavedModel, then keep a handle to its default serving signature.
# The prediction cell calls it as inferer(inputs=...) and reads 'output_0'.
loaded_model = tf.saved_model.load(dir_model)
inferer = loaded_model.signatures["serving_default"]

# 2. Set Up OpenCV and MediaPipe for Prediction

In [7]:
# Run the grip classifier on every frame of a test video and display the result.
#
# Feature layout fed to the model; it MUST match the model loaded from dir_model:
#   "../Model/xyz/"      -> USE_Z = True,  LANDMARK_SLICE = slice(None)
#   "../Model/xyz_hand/" -> USE_Z = True,  LANDMARK_SLICE = slice(13, 22)
#   "../Model/xy/"       -> USE_Z = False, LANDMARK_SLICE = slice(None)  # presumably all landmarks -- confirm
#   "../Model/xy_hand/"  -> USE_Z = False, LANDMARK_SLICE = slice(13, 22)
USE_Z = True
LANDMARK_SLICE = slice(None)

# Model output index -> label shown on screen. The label only changes when the
# top probability reaches CONFIDENCE_THRESHOLD; otherwise the previous one stays.
GRIP_LABELS = ('Good', 'Narrow', 'Wide')
CONFIDENCE_THRESHOLD = 0.9


def _landmarks_to_features(landmarks, use_z, landmark_slice):
    """Flatten pose landmarks into the single-row DataFrame passed to the model.

    landmarks: the repeated ``results.pose_landmarks.landmark`` field.
    use_z: include the z coordinate (x, y, z per landmark) or only (x, y).
    landmark_slice: which landmarks to keep, e.g. slice(13, 22) for the hand models.
    """
    selected = landmarks[landmark_slice]
    if use_z:
        coords = [[lm.x, lm.y, lm.z] for lm in selected]
    else:
        coords = [[lm.x, lm.y] for lm in selected]
    return pd.DataFrame([np.array(coords).flatten()])


def _draw_status(image, stage, prob):
    """Draw the class / probability status box in the top-left corner of image."""
    cv2.rectangle(image, (0, 0), (250, 60), (245, 117, 16), -1)
    # class label
    cv2.putText(image, 'Class', (95, 12), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
    cv2.putText(image, stage, (95, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
    # probability of the top class
    cv2.putText(image, 'PROB', (15, 12), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
    cv2.putText(image, str(round(prob, 2)), (15, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)


dir_video_test = '../data/all videos/wide6.mp4'
cap = cv2.VideoCapture(dir_video_test)
current_stage = ''

# Drawing styles are constant, so build them once instead of on every frame.
landmark_spec = mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=4)
connection_spec = mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2)

try:
    # initiate the MediaPipe Pose model
    with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # End of the video (or a read failure): frame is None and
                # cvtColor would raise, so stop cleanly.
                break

            # MediaPipe expects RGB input; OpenCV delivers BGR. MediaPipe's
            # examples mark the image read-only before process().
            image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image.flags.writeable = False
            results = pose.process(image)

            # Convert the processed RGB image back to BGR for OpenCV rendering.
            image.flags.writeable = True
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                                      landmark_spec, connection_spec)

            # Only classify when a pose was detected in this frame.
            if results.pose_landmarks is not None:
                X = _landmarks_to_features(results.pose_landmarks.landmark, USE_Z, LANDMARK_SLICE)
                # Run the model once; derive both class and probability from it.
                probs = inferer(inputs=X)['output_0'].numpy()
                grip_class = int(np.argmax(probs))
                grip_prob = np.max(probs)
                print(grip_class, grip_prob)

                if grip_prob >= CONFIDENCE_THRESHOLD and grip_class < len(GRIP_LABELS):
                    current_stage = GRIP_LABELS[grip_class]

                _draw_status(image, current_stage, grip_prob)

            # Stream the annotated frame
            cv2.imshow("Raw Cam Feed", image)

            if cv2.waitKey(10) & 0xFF == ord('q'):
                break
finally:
    # Release the video and close the window even if the loop raises.
    cap.release()
    cv2.destroyAllWindows()

2 0.9986884
2 0.98500746
2 0.99763715
2 0.9999348
2 0.99998605
2 0.9999901
2 0.99999225
2 0.9999541
2 0.9999759
2 0.9998863
2 0.9995289
2 0.9634519
2 0.9901073
2 0.99980754
2 0.99995196
2 0.9999912
2 0.9998142
2 0.9996469
2 0.99999714
2 0.99997795
2 0.99986434
2 0.99949765
2 0.99295723
2 0.99128485
2 0.9866698
0 0.64623564
0 0.90319073
0 0.9484051
0 0.99709177
0 0.9999895
0 0.9999304
0 0.9990829
2 0.6790592
0 0.54355216
2 0.70703596
0 0.9431098
2 0.915121
2 0.8194878
2 0.9682916
2 0.8883375
2 0.9871011
2 0.988892
2 0.99555594
2 0.9725876
2 0.9905253
2 0.9886284
2 0.9868153
2 0.98453206
2 0.9700057
2 0.9456595
2 0.99632096
2 0.9926973
2 0.98993343
2 0.980246
2 0.9749101
2 0.9642482
2 0.66902745
2 0.5026407
0 0.7704943
0 0.85330033
0 0.7147908
0 0.6325801
2 0.5195492
0 0.55322295
2 0.6228071
2 0.61981297
0 0.55137837
2 0.8812858
2 0.9585203
2 0.9633972
2 0.94658077
2 0.96651757
2 0.9746113
2 0.9822031
2 0.99324507
2 0.9973912
2 0.9980282
2 0.9990615
2 0.9944396
2 0.9954921
2 0.99621814
2