In [89]:
import mediapipe as mp
import cv2
import numpy as np
import os
import pandas as pd
import warnings
from itertools import count
import pickle
warnings.filterwarnings("ignore")

In [90]:
# Load the pre-trained lunge-stage classifier (saved as a pickle).
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file —
# only load model files from a trusted source.
with open('lunges_bot3_rf.pkl', mode='rb') as model_file:
    model_rf = pickle.load(model_file)

In [91]:
# Short aliases for MediaPipe's drawing helpers and its pose solution.
mp_drawing, mp_pose = mp.solutions.drawing_utils, mp.solutions.pose
# Feature column names: a leading 'class' label followed by x/y/z/visibility
# for each of the 33 landmarks (x1, y1, z1, v1, ..., x33, y33, z33, v33).
# The classifier's input columns are landmarks[1:].
landmarks = ['class'] + [
    f'{axis}{idx}' for idx in range(1, 34) for axis in ('x', 'y', 'z', 'v')
]

In [92]:
# Input video and output recording configuration.
video_path = 'lunges_video.mp4'
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Fail loudly instead of silently producing an empty result.avi.
    raise IOError(f'Could not open video file: {video_path}')

frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
size = (frame_width, frame_height)

# Match the output frame rate to the source so playback speed is preserved.
# The backend may report 0 (or NaN) when the rate is unknown; fall back to 30.
source_fps = cap.get(cv2.CAP_PROP_FPS)
output_fps = source_fps if source_fps > 0 else 30

# History of detected stages; the '' sentinel keeps current_stage[-1] valid
# before the first stage is recognised.
current_stage = ['']
writer = cv2.VideoWriter('result.avi',
                         cv2.VideoWriter_fourcc(*'MJPG'), output_fps, size)

In [93]:
# Per-frame pipeline: detect pose -> classify lunge stage -> overlay counts
# -> display and record. Press 'q' in the preview window to stop early.

# A stage is only recorded when the classifier's top probability exceeds this.
STAGE_PROB_THRESHOLD = 0.7
# Class labels treated as lunge stages (presumably the labels model_rf was
# trained on — verify against model_rf.classes_).
TRACKED_STAGES = ('up', 'left', 'right')

with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # End of video or a read failure.
                break

            # MediaPipe is fed RGB; the buffer is read-only during processing
            # and made writable again for drawing.
            image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image.flags.writeable = False
            results = pose.process(image)
            image.flags.writeable = True
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                                      mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=2),
                                      mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2)
                                      )

            if results.pose_landmarks is not None:
                try:
                    # One feature row: x, y, z, visibility for every landmark,
                    # in the same order as the landmarks[1:] column names.
                    row = np.array(
                        [[lm.x, lm.y, lm.z, lm.visibility] for lm in results.pose_landmarks.landmark],
                        dtype=float,
                    ).flatten()
                    X = pd.DataFrame([row], columns=landmarks[1:])
                    body_language_class = model_rf.predict(X)[0]
                    body_language_prob = model_rf.predict_proba(X)[0]

                    # Append only on a change of stage, so each stage's count in
                    # current_stage is the number of transitions into it.
                    if (body_language_class in TRACKED_STAGES
                            and body_language_prob.max() > STAGE_PROB_THRESHOLD
                            and current_stage[-1] != body_language_class):
                        current_stage.append(str(body_language_class))
                except Exception as e:
                    # Best effort: a failed classification must not stop playback.
                    print(e)

            # Status panel is drawn on every frame so counts stay visible even
            # when no pose is detected.
            cv2.rectangle(image, (0, 0), (280, 180), (245, 117, 16), -1)
            cv2.putText(image, f'Stage: {current_stage[-1].upper()}',
                        (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 1, cv2.LINE_AA)
            left_count = current_stage.count('left')
            cv2.putText(image, f'Left Count: {left_count}',
                        (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 1, cv2.LINE_AA)
            right_count = current_stage.count('right')
            cv2.putText(image, f'Right Count: {right_count}',
                        (10, 140), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 1, cv2.LINE_AA)

            # Show and record every frame so result.avi has no dropped frames.
            cv2.imshow('Lunges Feed', image)
            writer.write(image)
            # Single waitKey per frame: an extra call could consume the 'q' press.
            if cv2.waitKey(10) & 0xFF == ord('q'):
                break
    finally:
        # Always finalize the output file and free the capture/window, even on
        # early exit or error; an unreleased VideoWriter may leave a broken AVI.
        cap.release()
        writer.release()
        cv2.destroyAllWindows()