# Generate Pose Classification Dataset

In [1]:
import os
import cv2
import torch
import glob
import pafy
from tqdm import tqdm
from os import listdir, mkdir
from os.path import isfile, join, splitext
from pathlib import Path
from mediapipe.python.solutions import drawing_utils as mp_drawing
from mediapipe.python.solutions import pose as mp_pose

In [2]:
# Person detector: YOLOv5 loaded through torch.hub from a local checkpoint.
# The 'custom' entry point takes its weights from `path`, so any other
# YOLOv5 checkpoint (yolov5n - yolov5x6, or custom-trained) can be used here.
model = torch.hub.load(
    'ultralytics/yolov5',
    'custom',
    path='./models/yolov5m.pt',
)

Using cache found in /Users/brian/.cache/torch/hub/ultralytics_yolov5_master
YOLOv5 🚀 2022-12-19 Python-3.8.12 torch-1.12.1 CPU

Fusing layers... 
YOLOv5m summary: 290 layers, 21172173 parameters, 0 gradients
Adding AutoShape... 


In [3]:
# Video to label. Exactly one `url` assignment is active; the commented-out
# lines below are alternative videos that can be swapped in.
# NOTE(review): despite the name, these look like bare YouTube video IDs rather
# than full URLs - presumably pafy.new() accepts either form; confirm.
url = 'N-RYJobvTes'
# url = "8QOcj0m21As"
# url = 'nM55YiDntno'
# url ='MFYz7DCt9PI'
# url = '6um-5sBEoMY'
# url = 'PjIr-IbV5C4'
# url = 'yObzW5imnRA'
video = pafy.new(url)
# Select a stream of the video, preferring mp4. The stream's `.url` is what
# the labelling cell later opens with cv2.VideoCapture.
yt_video = video.getbest(preftype="mp4")
# # Video
# video_path = './videos'

In [4]:
# Root folder of the cropped volleyball-pose images; each pose class gets its
# own sub-folder directly underneath it.
crop_img_path = "/Volumes/GoogleDrive-117429523964539289019/My Drive/vball_tracking/code/datasets/volleyball_poses"

# Per-class image folders, unpacked in the order: sets, digs, spikes, others.
sets_path, digs_path, spikes_path, others_path = (
    f'{crop_img_path}/{pose_class}'
    for pose_class in ('sets', 'digs', 'spikes', 'others')
)

In [5]:
def _next_image_index(directory, prefix='player_', suffix='.jpg'):
    """Return the index to use for the next image saved into ``directory``.

    Only files named ``<prefix><digits><suffix>`` are considered, so stray
    files (e.g. ``.DS_Store`` or cloud-sync artefacts) neither crash the lookup
    nor influence the numbering.

    Args:
        directory: Folder that holds the already-saved crops of one class.
        prefix: File-name prefix that precedes the numeric index.
        suffix: File extension (including the dot) that follows the index.

    Returns:
        ``0`` when no matching file exists, otherwise the highest existing
        index plus one. The labelling loop writes ``player_{n:07d}.jpg`` first
        and increments afterwards, so starting at the highest existing index
        itself would overwrite the most recently saved image.
    """
    indices = []
    for entry in Path(directory).glob(f'{prefix}*{suffix}'):
        digits = entry.name[len(prefix):-len(suffix)]
        if digits.isdecimal():
            indices.append(int(digits))
    return max(indices) + 1 if indices else 0


# Set up file structure (one folder per pose class; existing folders are kept).
for class_dir in (digs_path, spikes_path, sets_path, others_path):
    path = Path(class_dir)
    path.mkdir(parents=True, exist_ok=True)

# Set up img numbering system: continue after the images that already exist.
digs_num = _next_image_index(digs_path)
sets_num = _next_image_index(sets_path)
spikes_num = _next_image_index(spikes_path)
others_num = _next_image_index(others_path)


In [7]:
# Interactive labelling loop.
#
# Every `nth_frame`-th frame of the video is run through the YOLOv5 model; each
# sufficiently large 'person' crop is shown with its MediaPipe pose overlay and
# the pressed key decides what happens to the raw crop:
#   d -> save as dig      a -> save as spike    s -> save as set
#   o -> save as other    q -> quit             any other key (e.g. p) -> skip

# Set to True when the user presses 'q'. Named `quit_requested` so the builtin
# `quit` is not shadowed for the rest of the kernel session.
quit_requested = False

capture = cv2.VideoCapture(yt_video.url)

try:
    # The first frame is read up front and not used for labelling.
    _, image = capture.read()
    ims = []  # not used by this cell
    num_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    print(num_frames)

    nth_frame = 5

    # A single pose estimator is reused for all crops. static_image_mode=True
    # makes it treat every crop as an unrelated image, which is what creating a
    # fresh tracker per crop achieved, without re-initialising it each time.
    with mp_pose.Pose(static_image_mode=True) as pose_tracker:
        for frame in range(1, num_frames // nth_frame):

            if quit_requested:
                break

            # Advance nth_frame frames and keep only the last one.
            frame_ok = False
            for _ in range(nth_frame):
                frame_ok, image = capture.read()
                if not frame_ok:
                    break
            if not frame_ok:
                # The stream ended (or failed) before the reported frame count
                # was reached; there is no image left to process.
                break

            # inference (the model is fed RGB, OpenCV frames are BGR)
            img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            results = model(img_rgb)
            dets = results.pandas().xyxy[0]

            if dets.empty:
                continue

            # filter person
            dets = dets.loc[dets['name'] == 'person']

            # crop image by the bbox
            for _, det in dets.iterrows():
                crop_img = image[int(det.ymin):int(det.ymax), int(det.xmin):int(det.xmax)]

                # Only crops taller than 120 px and wider than 60 px are labelled.
                if crop_img.shape[0] <= 120 or crop_img.shape[1] <= 60:
                    continue

                # The overlay is drawn on a converted copy, so `crop_img` (the
                # image that gets saved) stays free of landmarks.
                output_frame = cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB)
                pose_result = pose_tracker.process(image=output_frame)
                pose_landmarks = pose_result.pose_landmarks

                # Draw pose to image
                if pose_landmarks is not None:
                    mp_drawing.draw_landmarks(
                        image=output_frame,
                        landmark_list=pose_landmarks,
                        connections=mp_pose.POSE_CONNECTIONS)
                output_frame = cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR)

                cv2.imshow('cropped_images', output_frame)

                # Wait for a key press; the mask keeps only the low byte so the
                # comparison with ord(...) also holds where waitKey sets
                # additional high bits.
                key = cv2.waitKey(0) & 0xFF

                if key == ord('q'):
                    quit_requested = True
                    break
                elif key == ord('d'):
                    # Save as a 'dig'
                    crop_img_name = f'{digs_path}/player_{digs_num:07d}.jpg'
                    digs_num += 1
                    cv2.imwrite(crop_img_name, crop_img)
                elif key == ord('a'):
                    # Save as a 'spike'
                    crop_img_name = f'{spikes_path}/player_{spikes_num:07d}.jpg'
                    spikes_num += 1
                    cv2.imwrite(crop_img_name, crop_img)
                elif key == ord('s'):
                    # Save as a 'set'
                    crop_img_name = f'{sets_path}/player_{sets_num:07d}.jpg'
                    sets_num += 1
                    cv2.imwrite(crop_img_name, crop_img)
                elif key == ord('o'):
                    # Save as an 'other'
                    crop_img_name = f'{others_path}/player_{others_num:07d}.jpg'
                    others_num += 1
                    cv2.imwrite(crop_img_name, crop_img)
                # Any other key (e.g. 'p') skips this crop without saving.
finally:
    # Always release the capture and close the preview window, even when the
    # loop is interrupted or raises.
    capture.release()
    cv2.destroyAllWindows()
    # Process one more GUI event so the window is closed before the cell ends.
    cv2.waitKey(1)

433


-1

In [78]:
# Inspection cell: height (number of rows) of the crop most recently assigned
# to `crop_img` by the labelling loop.
crop_img.shape[0]

231

In [7]:
def draw_bbox(image, dets):
    """Draw a green, 2-pixel-thick rectangle on ``image`` for every row of ``dets``.

    Args:
        image: Image handed to ``cv2.rectangle``.
        dets: Object providing ``iterrows()`` whose rows expose ``xmin``,
            ``ymin``, ``xmax`` and ``ymax`` (presumably the YOLOv5 pandas
            detections frame - confirm against callers).

    NOTE(review): no ``return`` statement is visible in this view of the file,
    so callers presumably rely on ``cv2.rectangle`` drawing onto the passed
    ``image`` in place - confirm.
    """
    for _, det in dets.iterrows():
            # Box corners are truncated to ints; (0, 255, 0) is green and 2 is the line thickness.
            image = cv2.rectangle(image, (int(det.xmin),int(det.ymin)), (int(det.xmax), int(det.ymax)), (0,255, 0), 2)