In [3]:
import os
from collections import deque
import cv2
import numpy as np
from typing import Generator, List, Tuple
from pathlib import Path

class VideoLoader:
    """Iterate over fixed-length frame windows sampled from videos in a directory tree.

    Iterating a loader yields ``(category, frames, start_time)`` tuples, where
    ``category`` is the name of the directory containing the video, ``frames``
    is an RGB array stacked as ``[T, H, W, C]`` with ``T == window_size``, and
    ``start_time`` is the window's starting position in seconds.
    """

    def __init__(
        self,
        root_dir: str,
        window_size: int,
        sampling_rate: float,
        inference_interval: float,
        video_extensions: tuple = ('.mp4', '.avi', '.mkv')
    ):
        """Store the loader configuration.

        Args:
            root_dir: Directory that is searched recursively for videos.
            window_size: Number of frames in every yielded window (>= 1).
            sampling_rate: Seconds between two sampled frames (> 0).
            inference_interval: Seconds between two window starts (> 0).
            video_extensions: File-name endings treated as videos; matched
                case-insensitively.

        Raises:
            ValueError: If ``window_size`` is below 1, or ``sampling_rate`` or
                ``inference_interval`` is not positive.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if not sampling_rate > 0:
            raise ValueError(f"sampling_rate must be > 0, got {sampling_rate}")
        if not inference_interval > 0:
            raise ValueError(
                f"inference_interval must be > 0, got {inference_interval}"
            )

        self.root_dir = Path(root_dir)
        self.window_size = window_size
        self.sampling_rate = sampling_rate
        self.inference_interval = inference_interval
        self.video_extensions = video_extensions

    @staticmethod
    def _seconds_to_frames(fps: float, seconds: float) -> int:
        """Convert a duration to a whole number of frames, never less than 1."""
        # Truncation can produce 0 for short durations or low frame rates; a
        # zero stride would break both the modulo test and range().
        return max(1, int(fps * seconds))

    def _find_videos(self) -> List[Tuple[str, Path]]:
        """Recursively collect video files below ``root_dir``.

        Returns:
            ``(category, path)`` pairs sorted by path, where ``category`` is
            the name of the file's parent directory. Each file appears once.
        """
        endings = tuple(ext.lower() for ext in self.video_extensions)
        # Sorting makes the iteration order reproducible across runs and
        # filesystems; is_file() skips directories whose name ends like a video.
        return [
            (path.parent.name, path)
            for path in sorted(self.root_dir.rglob('*'))
            if path.name.lower().endswith(endings) and path.is_file()
        ]

    def _extract_frames(
        self,
        video_path: Path,
        frame_interval: int,
        start_pos: int = 0
    ) -> Generator[np.ndarray, None, None]:
        """Yield every ``frame_interval``-th frame as RGB, starting at ``start_pos``.

        The frame at ``start_pos`` is always the first one yielded, and the
        stride is counted from there. The capture is released when the
        generator is exhausted, closed, or garbage-collected.

        Args:
            video_path: Video file to read.
            frame_interval: Stride, in frames, between yielded frames (>= 1).
            start_pos: Index of the first frame to read.

        Raises:
            ValueError: On the first ``next()`` call if ``frame_interval`` is
                below 1.
        """
        if frame_interval < 1:
            raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")

        cap = cv2.VideoCapture(str(video_path))
        try:
            if start_pos > 0:
                cap.set(cv2.CAP_PROP_POS_FRAMES, start_pos)

            # Count relative to start_pos so the first frame read is sampled
            # regardless of how start_pos relates to frame_interval.
            offset = 0
            # grab() advances without retrieving the image, so skipped frames
            # are never converted or copied into Python.
            while cap.grab():
                if offset % frame_interval == 0:
                    ret, frame = cap.retrieve()
                    if not ret:
                        break
                    # BGR to RGB
                    yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                offset += 1
        finally:
            # Runs even when the consumer stops early (generator.close()).
            cap.release()

    def _fill_window(
        self,
        frame_source: Generator[np.ndarray, None, None]
    ) -> List[np.ndarray]:
        """Take up to ``window_size`` frames and pad a short window with its first frame.

        Returns:
            An empty list if ``frame_source`` yields nothing, otherwise a list
            of ``window_size`` frames. Padding entries reference the first
            frame rather than copying it.
        """
        frames: List[np.ndarray] = []
        for frame in frame_source:
            frames.append(frame)
            if len(frames) >= self.window_size:
                break

        if frames:
            frames.extend([frames[0]] * (self.window_size - len(frames)))
        return frames

    def __iter__(self) -> Generator[Tuple[str, np.ndarray, float], None, None]:
        """Yield ``(category, frames, start_time)`` for every window of every video.

        Windows start every ``inference_interval`` seconds and contain
        ``window_size`` frames sampled ``sampling_rate`` seconds apart. When
        the video ends before a window is full, the remaining slots hold the
        window's first frame. Videos that cannot be opened or report a
        non-positive frame rate are skipped with a printed warning.
        """
        for category, video_path in self._find_videos():
            cap = cv2.VideoCapture(str(video_path))
            try:
                if not cap.isOpened():
                    print(f"Warning: Could not open video {video_path}")
                    continue

                fps = cap.get(cv2.CAP_PROP_FPS)
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            finally:
                cap.release()

            # "not fps > 0" also rejects NaN, which some containers report.
            if not fps > 0:
                print(f"Warning: Invalid FPS ({fps}) for video {video_path}")
                continue

            # Stride between sampled frames and between window starts.
            frame_interval = self._seconds_to_frames(fps, self.sampling_rate)
            inference_frames = self._seconds_to_frames(fps, self.inference_interval)

            for start_frame in range(0, total_frames, inference_frames):
                frame_generator = self._extract_frames(
                    video_path,
                    frame_interval,
                    start_frame
                )
                try:
                    frames = self._fill_window(frame_generator)
                finally:
                    # Release the capture now instead of waiting for GC.
                    frame_generator.close()

                if not frames:
                    continue

                # Stack to [T, H, W, C]; np.stack copies, so the padded
                # entries do not alias each other in the result.
                yield category, np.stack(frames), start_frame / fps

# Usage example
if __name__ == "__main__":
    dataset_root = "/home/piawsa6000/nas192/videos/huggingface_benchmarks_dataset/Leaderboard_bench/TEST/dataset"
    window_settings = {
        "window_size": 16,          # each window holds 16 frames
        "sampling_rate": 0.5,       # sample one frame every 0.5 seconds
        "inference_interval": 2.0,  # start a new window every 2 seconds
    }
    loader = VideoLoader(root_dir=dataset_root, **window_settings)
    separator = "-" * 50

    # frames is a numpy array with shape [T, H, W, C]
    for category, frames, start_time in loader:
        report = (
            f"Category: {category}",
            f"Frames shape: {frames.shape}",
            f"Start time: {start_time:.2f}s",
            separator,
        )
        for line in report:
            print(line)
# Example output of the usage example above (the run was interrupted manually)

Category: fire
Frames shape: (16, 2160, 3840, 3)
Start time: 0.00s
--------------------------------------------------
Category: fire
Frames shape: (16, 2160, 3840, 3)
Start time: 2.00s
--------------------------------------------------
Category: fire
Frames shape: (16, 2160, 3840, 3)
Start time: 4.00s
--------------------------------------------------
Category: fire
Frames shape: (16, 2160, 3840, 3)
Start time: 6.00s
--------------------------------------------------


KeyboardInterrupt: 