## 피치 트래커(Pitch tracker box) 표시 여부를 판별하는 분류(Classification) 모델 추론
- name: 남기범
- project: 야구 AI 캐스터
- stack: roboflow (1.1.34), tensorflow (2.15.0), cv2 (4.8.0), numpy (1.25.2)

In [None]:
# Libraries used and their versions.
# Kept as comments on purpose: as bare statements these lines are not valid
# Python (`4.8.0` is not a number literal), so executing the cell raised a
# SyntaxError and stopped "Run all" at the very first code cell.
# roboflow==1.1.34
# tensorflow==2.15.0
# cv2==4.8.0  (installed via the opencv-python pip package)
# numpy==1.25.2

# 추론모델: Real-time detection

In [None]:
# 10번째 프레임마다 모델에 넣어 판별하도록 변경 (프레임별 연산시간 20~40 ms)
# 판별 결과 Pitchbox가 있는지 없는지 판정 (probability > 0.95)
# Pitchbox가 있다고 연속 3번 판정하면 1, 아니면 0을 출력
import cv2
import numpy as np
import tensorflow as tf

def detect_pitch_tracker(video_path, classifier=None, *, frame_interval=10,
                         threshold=0.95, consecutive=3):
    """Print, for sampled frames of a video, whether the pitch tracker box is shown.

    Every ``frame_interval``-th frame (starting with frame 0) is resized to
    224x224, scaled to [0, 1] and classified. A sampled frame is a "hit" when
    the model output exceeds ``threshold``. After each sampled frame this
    prints ``1`` if the last ``consecutive`` sampled frames were all hits and
    ``0`` otherwise. This is the "3 consecutive detections" rule described at
    the top of this cell. Pass ``consecutive=1`` to print the raw per-frame
    decision instead.

    Args:
        video_path: File path or other source accepted by ``cv2.VideoCapture``.
        classifier: Keras model to run. Defaults to the notebook-global
            ``model`` (loaded in the next cell), which is the model this
            function has always used.
        frame_interval: Classify one frame out of every ``frame_interval``.
        threshold: Output value above which a sampled frame counts as a hit.
        consecutive: Number of consecutive hits required before printing 1.

    Raises:
        ValueError: If ``frame_interval`` or ``consecutive`` is less than 1.
        OSError: If the video cannot be opened.
    """
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")
    if consecutive < 1:
        raise ValueError(f"consecutive must be >= 1, got {consecutive}")
    if classifier is None:
        classifier = model  # notebook-global, resolved at call time

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            # Without this check a bad path or unsupported codec printed nothing.
            raise OSError(f"Could not open video: {video_path!r}")

        frame_count = 0
        streak = 0  # consecutive sampled frames classified as hits
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % frame_interval == 0:
                # NOTE(review): cv2 decodes frames as BGR; confirm the model
                # was trained on BGR (not RGB) input.
                resized_frame = cv2.resize(frame, (224, 224))
                input_tensor = np.expand_dims(resized_frame, axis=0) / 255.0

                prediction = classifier.predict(input_tensor, verbose=0)

                # Assumes a single sigmoid output, i.e. prediction has shape
                # (1, 1) -- TODO confirm against the trained model.
                if prediction[0] > threshold:
                    streak += 1
                else:
                    streak = 0

                result = 1 if streak >= consecutive else 0
                # Print result for real-time verification
                print(result)

            frame_count += 1
    finally:
        # Release the capture even if preprocessing or inference raises.
        cap.release()

In [None]:
# Load the trained classifier first: detect_pitch_tracker reads the global `model`.
model = tf.keras.models.load_model(filepath='best_model_dataset_5_test_1.0.h5')

# Sample clip to run the real-time check on.
video_path = 'videoplayback_360p.mp4'
detect_pitch_tracker(video_path=video_path)

## 모델 검증: Tracker ON(= 1) / OFF(= 0)으로 판정된 프레임을 각각 모아 영상으로 저장

In [None]:
import cv2
import numpy as np
import tensorflow as tf
import time

# Reload the classifier for this verification cell;
# detect_pitch_tracker_verbose reads the global `model`.
model = tf.keras.models.load_model(filepath='best_model_dataset_5_test_1.0.h5')
video_path = 'videoplayback_360p.mp4'  # input clip to split by predicted class


def detect_pitch_tracker_verbose(video_path, classifier=None, *,
                                 frame_interval=10, threshold=0.95,
                                 detected_path='pitch_tracker_detected.mp4',
                                 not_detected_path='pitch_tracker_not_detected.mp4'):
    """Classify sampled frames and save them into one video per predicted class.

    Every ``frame_interval``-th frame (starting with frame 0) is resized to
    224x224, scaled to [0, 1] and classified. The per-frame decision
    (``1`` if the model output exceeds ``threshold``, else ``0``; no
    consecutive-hit smoothing, since this cell verifies the raw classifier)
    and the processing time are printed. The original full-size frame is
    then appended to ``detected_path`` (result 1) or ``not_detected_path``
    (result 0).

    Frames are streamed to ``cv2.VideoWriter`` as they are classified rather
    than buffered in lists, so memory use does not grow with video length.
    Each output file is created only if at least one frame lands in it. The
    outputs use the source frame rate while holding only every
    ``frame_interval``-th frame, so they play back faster than real time.

    Args:
        video_path: File path or other source accepted by ``cv2.VideoCapture``.
        classifier: Keras model to run. Defaults to the notebook-global
            ``model`` loaded at the top of this cell.
        frame_interval: Classify one frame out of every ``frame_interval``.
        threshold: Output value above which a frame is classified as 1.
        detected_path: Output video for frames classified as 1.
        not_detected_path: Output video for frames classified as 0.

    Raises:
        ValueError: If ``frame_interval`` is less than 1.
        OSError: If the input video or an output video cannot be opened.
    """
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")
    if classifier is None:
        classifier = model  # notebook-global, resolved at call time

    output_paths = {1: detected_path, 0: not_detected_path}
    writers = {}  # result -> cv2.VideoWriter, opened on the first frame of that class

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise OSError(f"Could not open video: {video_path!r}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some sources report 0 (or NaN) fps; use a nominal rate so the
            # output videos get a usable frame rate.
            fps = 30.0

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % frame_interval == 0:
                start_time = time.perf_counter()

                # NOTE(review): cv2 decodes frames as BGR; confirm the model
                # was trained on BGR (not RGB) input.
                resized_frame = cv2.resize(frame, (224, 224))
                input_tensor = np.expand_dims(resized_frame, axis=0) / 255.0
                prediction = classifier.predict(input_tensor, verbose=0)

                # Assumes a single sigmoid output, i.e. prediction has shape
                # (1, 1) -- TODO confirm against the trained model.
                result = 1 if prediction[0] > threshold else 0

                # Print result for real-time verification
                print(result)

                processing_time = time.perf_counter() - start_time
                print(f"Frame {frame_count}: Processing time = {processing_time:.4f} seconds")

                # Written after timing so the logged time covers only
                # preprocessing and inference, not video encoding.
                writer = writers.get(result)
                if writer is None:
                    height, width = frame.shape[:2]
                    writer = cv2.VideoWriter(output_paths[result],
                                             cv2.VideoWriter_fourcc(*'mp4v'),
                                             fps, (width, height))
                    # Register before checking so `finally` releases it too.
                    writers[result] = writer
                    if not writer.isOpened():
                        # Otherwise every write() would silently do nothing.
                        raise OSError(f"Could not open video writer: {output_paths[result]!r}")
                writer.write(frame)

            frame_count += 1
    finally:
        # Release all handles even if inference or encoding fails midway.
        cap.release()
        for writer in writers.values():
            writer.release()

    if 1 in writers:
        print(f"Saved frames with result 1 to {detected_path}")
    if 0 in writers:
        print(f"Saved frames with result 0 to {not_detected_path}")


detect_pitch_tracker_verbose(video_path)

## 사용 라이브러리 및 환경 정리

In [None]:
# Libraries used and their versions.
# google.colab is not needed by the final inference code.
# Kept as comments on purpose: as bare statements these lines are not valid
# Python (`4.8.0` is not a number literal), so executing the cell raised a
# SyntaxError and stopped "Run all".
# roboflow==1.1.34
# tensorflow==2.15.0
# cv2==4.8.0  (installed via the opencv-python pip package)
# numpy==1.25.2
# google.colab==1.0.0

In [None]:
!pip freeze > requirements.txt

## 참고 링크

Roboflow 프로젝트 링크: https://app.roboflow.com/2024ksebpitchboxdetection/pitchbox_classification/1/export

YouTube 영상 파일 다운로드 (URL 입력): https://ssyoutube.com/ko34aM/youtube-video-downloader