# Generate Pose Classification Dataset

In [1]:
import cv2
import torch
from tqdm import tqdm
from os import listdir, mkdir
from os.path import isfile, join, splitext
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5m')  # or yolov5n - yolov5x6, custom

# Video
video_path = './videos'

# Volleyball Poses Images Path
crop_img_path = './datasets/volleyball_poses'

# Sets Image Path
sets = '/sets'

# Digs Image Path
digs = '/digs'

# Others
others = '/others'

Using cache found in /Users/hao/.cache/torch/hub/ultralytics_yolov5_master
YOLOv5 🚀 2022-12-19 Python-3.10.9 torch-1.13.1 CPU

Fusing layers... 
YOLOv5m summary: 290 layers, 21172173 parameters, 0 gradients
Adding AutoShape... 


In [3]:
video_files = [f for f in listdir(video_path) if isfile(join(video_path, f))]
print(video_files)

['20221209_154007.mp4', '20221209_153751.mp4', '20221209_153931.mp4', '20221209_153611.mp4']


In [4]:
quit = False

for video_file in video_files:

    if quit:
        break

    capture = cv2.VideoCapture(f'{video_path}/{video_file}')

    _, image = capture.read()
    ims = []
    num_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    print(num_frames)

    # Set up img numbering system
    sets_num = 0
    digs_num = 0
    others_num = 0

    # Set up file structure
    path = Path(f'{crop_img_path}/{splitext(video_file)[0]}{digs}')
    path.mkdir(parents=True, exist_ok=True)
    path = Path(f'{crop_img_path}/{splitext(video_file)[0]}{sets}')
    path.mkdir(parents=True, exist_ok=True)
    path = Path(f'{crop_img_path}/{splitext(video_file)[0]}{others}')
    path.mkdir(parents=True, exist_ok=True)

    for frame in range(1,num_frames):

        if quit:
            break

        _, image = capture.read()

        # inference
        img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        results = model(img_rgb)
        dets = results.pandas().xyxy[0]

        if not dets.empty:

            # filter person
            dets = dets.loc[dets['name'] == 'person']

            # crop image by the bbox
            for _, det in dets.iterrows():
                crop_img = image[int(det.ymin):int(det.ymax), int(det.xmin):int(det.xmax)]

                cv2.imshow('cropped_images', crop_img)

                key = cv2.waitKey(0)

                if key == ord('q'):
                    quit = True
                    break
                elif key == ord('a'):
                    # Save as a 'dig'
                    crop_img_name = f'{crop_img_path}/{splitext(video_file)[0]}{digs}/player_{digs_num}.jpg'
                    digs_num += 1
                elif key == ord('s'):
                    # Save as a 'set'
                    crop_img_name = f'{crop_img_path}/{splitext(video_file)[0]}{sets}/player_{sets_num}.jpg'
                    sets_num += 1
                elif key == ord('d'):
                    # Save as a 'other'
                    crop_img_name = f'{crop_img_path}/{splitext(video_file)[0]}{others}/player_{others_num}.jpg'
                    others_num +=1

                cv2.imwrite(crop_img_name, crop_img)

# After the loop release the cap object
capture.release()
# Destroy all the windows
cv2.destroyAllWindows()
cv2.waitKey(1)

1522


In [None]:
def draw_bbox(image, dets):
    for _, det in dets.iterrows():
            image = cv2.rectangle(image, (int(det.xmin),int(det.ymin)), (int(det.xmax), int(det.ymax)), (0,255, 0), 2)