# 0. Install and Import Dependencies

In [1]:
import mediapipe as mp
import cv2
import numpy as np
import csv
import warnings
import tensorflow as tf
import pandas as pd
import pickle
warnings.filterwarnings('ignore')

In [2]:
mp_pose = mp.solutions.pose                  # pose-estimation solution (provides Pose and POSE_CONNECTIONS)
mp_drawing = mp.solutions.drawing_utils      # helpers for rendering landmarks onto a frame

# 1. Load Model

In [3]:
# Path to the exported TensorFlow SavedModel used for prediction.
dir_model = "../model/xyz/"

# Other trained variants can be swapped in. The feature extraction in the
# prediction loop must be switched to the matching variant as well.
# (Paths use the same lowercase "model" directory as the active path above.)
# dir_model = "../model/xyz_hand/"
# dir_model = "../model/xy/"
# dir_model = "../model/xy_hand/"

loaded_model = tf.saved_model.load(dir_model)
# Default serving signature. The prediction loop calls it as
# `inferer(inputs=...)` and reads the 'output_0' entry of the returned dict,
# taking argmax/max of it as the class index and its probability.
inferer = loaded_model.signatures["serving_default"]

# 2. Set Up OpenCV and MediaPipe for Prediction

In [10]:
def _predict_momentum(results):
    """Classify one frame's pose with the loaded model.

    Returns ``(class_index, probability)``, or ``None`` when MediaPipe
    detected no pose in the frame (the landmark list is ``None``).
    """
    # model xyz uses world landmarks; the xyz_hand / xy / xy_hand variants
    # used `results.pose_landmarks` instead.
    landmarks = results.pose_world_landmarks
    # landmarks = results.pose_landmarks
    if landmarks is None:
        return None

    # model xyz: (x, y, z) of every landmark, flattened into a single row.
    row = np.array([[lm.x, lm.y, lm.z] for lm in landmarks.landmark]).flatten()
    # model xyz_hand: row = np.array([[lm.x, lm.y, lm.z] for lm in landmarks.landmark[13:22]]).flatten()
    # model xy:       row = np.array([[lm.x, lm.y] for lm in landmarks.landmark]).flatten()
    #   NOTE(review): the original snippet read `landmark[]` (invalid syntax);
    #   all landmarks assumed here -- confirm against the xy training script.
    # model xy_hand:  row = np.array([[lm.x, lm.y] for lm in landmarks.landmark[13:22]]).flatten()

    # Run the model once and derive both class and probability from the same output.
    output = inferer(inputs=pd.DataFrame([row]))['output_0'].numpy()
    return np.argmax(output), np.max(output)


def _draw_status(image, stage, prob):
    """Draw the status box (probability on the left, label on the right) onto *image* in place."""
    cv2.rectangle(image, (0, 0), (250, 60), (245, 117, 16), -1)

    # Display class
    cv2.putText(image, 'Class', (95, 12), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
    cv2.putText(image, stage, (95, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

    # Display probability
    cv2.putText(image, 'PROB', (15, 12), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
    cv2.putText(image, str(round(prob, 2)), (15, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)


dir_video_test = '../../data/momentum/test.mp4'
cap = cv2.VideoCapture(dir_video_test)
# Last confidently predicted label. It is kept across frames, so the overlay
# only changes when a prediction reaches CONFIDENCE_THRESHOLD.
current_stage = ''
CONFIDENCE_THRESHOLD = 0.9

try:
    # Initiate the MediaPipe pose model.
    with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:

        # Stream the video.
        while cap.isOpened():
            ret, frame = cap.read()
            # At end of stream (or on a read error) ret is False and frame is None;
            # stop instead of handing None to cv2.cvtColor.
            if not ret:
                break

            # Recolor feed: OpenCV frames are BGR, MediaPipe expects RGB.
            # image = cv2.flip(frame, 1)
            image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image.flags.writeable = False

            # Make detection.
            results = pose.process(image)

            # Recolor the processed image back to BGR for rendering.
            image.flags.writeable = True
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                                      mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=4),
                                      mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2))

            # Try to predict. This is best-effort: an inference failure on one
            # frame is reported but does not stop the stream.
            try:
                prediction = _predict_momentum(results)
            except Exception as e:
                print(e)
                prediction = None

            if prediction is not None:
                momentum_class, momentum_prob = prediction
                print(momentum_class, momentum_prob)

                if momentum_prob >= CONFIDENCE_THRESHOLD:
                    if momentum_class == 0:
                        current_stage = 'Nice!'
                    elif momentum_class == 1:
                        current_stage = 'Use momentum!'

                _draw_status(image, current_stage, momentum_prob)

            # Stream video result.
            cv2.imshow("Raw Cam Feed", image)

            if cv2.waitKey(10) & 0xFF == ord('q'):
                break
finally:
    # Release the capture and close windows even if the loop raised.
    cap.release()
    cv2.destroyAllWindows()

1 1.0
1 0.9999999
1 0.99999964
1 0.99999976
1 0.99999964
1 0.99999976
1 0.9999994
1 0.99999964
1 0.9999999
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 0.99999994
1 0.99999994
1 0.99999994
1 0.99999994
1 1.0
1 0.99999994
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 0.9999997
1 0.99999964
1 0.9999996
1 0.99999976
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 0.9999999
1 0.9999998
1 0.9999998
1 0.9999995
1 0.9999996
1 0.99999976
1 0.9999998
1 0.9999999
1 0.9999999
1 0.9999999
1 0.9999999
1 0.9999999
1 0.9999999
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 1.0
1 0.9999999
1 0.9999999
1 0.99999976
1 0.99999976
1 0.9999998
1 0.99999976
1 0.9999999
1 0.99999994
1 1.0
1 1.0
1 1.0
1 0.9999999
1 0.9999998
1 0.99999976
1 0.9999999
1 0.