In [1]:
from sklearn.pipeline import make_pipeline 
from sklearn.preprocessing import StandardScaler 

from sklearn.linear_model import LogisticRegression, RidgeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier

In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split

In [37]:
EXPORT_PATH = 'push-up-reload-2.csv'

In [38]:
df = pd.read_csv(EXPORT_PATH)
df['class'].value_counts()

arriba    3599
abajo     2294
Name: class, dtype: int64

In [39]:
X = df.drop('class', axis=1) # features
y = df['class'] # target

In [40]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1234)

In [41]:
from sklearn.pipeline import make_pipeline 
from sklearn.preprocessing import StandardScaler 

from sklearn.linear_model import LogisticRegression, RidgeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier

In [42]:
pipelines = {
    'lr':make_pipeline(StandardScaler(), LogisticRegression(penalty = 'l2', class_weight='balanced', max_iter=20000)),
    'rc':make_pipeline(StandardScaler(), RidgeClassifier()),
    'rf':make_pipeline(StandardScaler(), RandomForestClassifier()),
    'gb':make_pipeline(StandardScaler(), GradientBoostingClassifier()),
}

In [43]:
fit_models = {}
for algo, pipeline in pipelines.items():
    model = pipeline.fit(X_train, y_train)
    fit_models[algo] = model

In [44]:
fit_models['gb'].predict(X_test)

array(['arriba', 'abajo', 'arriba', ..., 'arriba', 'abajo', 'arriba'],
      dtype=object)

In [45]:
from sklearn.metrics import accuracy_score, precision_score, recall_score # Accuracy metrics 
import pickle 

In [46]:
for algo, model in fit_models.items():
    yhat = model.predict(X_test)
    print(algo, accuracy_score(y_test, yhat),
          precision_score(y_test.values, yhat, average="weighted", pos_label=1),
          recall_score(y_test.values, yhat, average="weighted", pos_label=1))

lr 0.9236425339366516 0.9245580661901776 0.9236425339366516
rc 0.9292986425339367 0.9291628029636481 0.9292986425339367
rf 0.9869909502262444 0.9869870934524283 0.9869909502262444
gb 0.9773755656108597 0.9773647363937634 0.9773755656108597


In [47]:
yhat = fit_models['gb'].predict(X_test)

In [48]:
yhat[:10]

array(['arriba', 'abajo', 'arriba', 'arriba', 'abajo', 'arriba', 'abajo',
       'arriba', 'arriba', 'abajo'], dtype=object)

In [66]:
with open('flexiones_cb', 'wb') as f:
    pickle.dump(fit_models['gb'], f)

In [67]:
with open('flexiones_cb', 'rb') as f:
    model = pickle.load(f)

In [68]:
landmarks = ['class']
for val in range(1,33+1):
    landmarks += ['x{}'.format(val), 'y{}'.format(val), 'z{}'.format(val), 'v{}'.format(val)]

In [69]:
import cv2
import mediapipe as mp
import numpy as np

In [70]:
mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose

In [None]:
cap = cv2.VideoCapture("flexiones_1.mp4")

# Curl counter variables
counter = 0 
current_stage = None

## Setup mediapipe instance
with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:
    while cap.isOpened():
        ret, frame = cap.read()
        
        # Recolor image to RGB
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image.flags.writeable = False
      
        # Make detection
        results = pose.process(image)
    
        # Recolor back to BGR
        image.flags.writeable = True
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        
        # Detección de Pose
        mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, 
                                 mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=4),
                                 mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2)
                                 )
        
        # Extract landmarks
        try:
            
            row = np.array([[res.x, res.y, res.z, res.visibility] for res in results.pose_landmarks.landmark]).flatten().tolist()
            X = pd.DataFrame([row], columns=landmarks[1:])
            
            body_language_class = model.predict(X)[0]
            
            body_language_prob = model.predict_proba(X)[0]
            print(body_language_class, body_language_prob)
            
            if body_language_class or body_language_prob:
                print('predicted')
            else :
                print('prediction error')

            if body_language_class == "arriba" and body_language_prob[body_language_prob.argmax()] >= 0.7:
                current_stage = "Arriba"
                # print(current_stage)
            elif current_stage=="Arriba" and body_language_class == "abajo" and body_language_prob[body_language_prob.argmax()] >= 0.7:
                current_stage = "Abajo"
                counter +=1
            print(current_stage, counter )    
            # obtener el estado del box
            cv2.rectangle(image, (0,0), (250, 60), (245, 117, 16), -1)
            
            # Mostrar las clases
            cv2.putText(image, 'CLASS'
                        , (95,12), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
            cv2.putText(image, str(body_language_class.split(' ')[0])
                        , (90,40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
            
            # Mostrar Probabilidad
            cv2.putText(image, 'PROB'
                        , (15,12), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
            cv2.putText(image, str(round(body_language_prob[np.argmax(body_language_prob)],2))
                       , (10,40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
            
            # Mostrar Conteo
            cv2.putText(image, 'COUNT'
                        , (180,12), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
            cv2.putText(image, str(counter)
                        , (195,40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
  
        except Exception as e:
            pass

        cv2.imshow('En vivo sin procesar', image)

        if cv2.waitKey(10) & 0xFF == ord('q'):
            break

cap.release()
cv2.destroyAllWindows()