<a href="https://colab.research.google.com/github/leeds1219/AArticle_Review/blob/main/DDD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# YouTube video ids to download and turn into dataset samples.
youtube_ids = [
    "lJiyXBRglgs",
    "q7tqLBK5HTQ",
    "-8zo9XKvnEs",
    "ktnmaWrxCww",
]

In [None]:
import os

import cv2
import moviepy.editor as mpy
import pytube # $ pip install pytube
import torch

#from pytube import YouTube
#YouTube('https://youtu.be/9bZkp7q19f0').streams.first().download()
#yt = YouTube('http://youtube.com/watch?v=9bZkp7q19f0')
#yt.streams
#.filter(progressive=True, file_extension='mp4')
#.order_by('resolution')
#.desc()
#.first()
#.download()

class YouTubeVideoDataset(torch.utils.data.Dataset):
    """Dataset that yields the central frame and the audio of YouTube videos.

    Every sample is identified by a YouTube video id. The video is downloaded
    into ``video_dir`` on first access; the central frame and the audio track
    are written next to it as ``<id>.png`` and ``<id>.wav``.

    Args:
        youtube_ids: Sequence of YouTube video ids.
        video_dir (str): Directory holding the downloaded videos and the
            files extracted from them.
        transform (callable, optional): Applied to the central frame image
            (the array returned by ``cv2.imread``) before it is returned.
    """

    def __init__(self, youtube_ids, video_dir, transform=None):
        self.youtube_ids = youtube_ids
        self.video_dir = video_dir
        self.transform = transform

    def __len__(self):
        """Return the number of video ids in the dataset."""
        return len(self.youtube_ids)

    def __getitem__(self, idx):
        """Return ``(central_frame, audio_path)`` for the video at ``idx``.

        Returns:
            tuple: The (optionally transformed) central frame image and the
            path of the extracted WAV file. ``(None, None)`` is returned when
            the video cannot be downloaded or processed; the error is printed.
        """
        youtube_id = self.youtube_ids[idx]
        youtube_url = 'https://youtu.be/' + youtube_id
        video_path = os.path.join(self.video_dir, f"{youtube_id}.mp4")
        # Per-video output files: the helpers' default paths are shared by all
        # samples, so parallel DataLoader workers would overwrite each other.
        frame_path = os.path.join(self.video_dir, f"{youtube_id}.png")
        audio_path = os.path.join(self.video_dir, f"{youtube_id}.wav")

        try:
            # Download video if not already downloaded
            if not os.path.exists(video_path):
                # Progressive streams carry both video and audio in one file.
                stream = (
                    pytube.YouTube(youtube_url)
                    .streams.filter(progressive=True, file_extension="mp4")
                    .first()
                )
                if stream is None:
                    raise RuntimeError("no progressive mp4 stream available")
                stream.download(output_path=self.video_dir, filename=f"{youtube_id}.mp4")

            # extract_central_frame only reports success; the image itself is
            # read back from the file it wrote.
            if not extract_central_frame(video_path, frame_path):
                raise RuntimeError("could not extract the central frame")
            central_frame = cv2.imread(frame_path)
            if central_frame is None:
                raise RuntimeError(f"could not read extracted frame {frame_path}")

            # extract_audio returns nothing, so the written file is the result.
            extract_audio(video_path, audio_path)
            if not os.path.exists(audio_path):
                raise RuntimeError("could not extract the audio track")

            # Preprocess frame if needed (no audio preprocessing is applied)
            if self.transform is not None:
                central_frame = self.transform(central_frame)
            return central_frame, audio_path
        except Exception as e:
            # Best effort: a single broken video must not abort the whole run.
            print(f"Error processing video {youtube_id}: {e}")
            return None, None

def extract_central_frame(video_path, output_image_path="central_image.png", size=(1024, 1024)):
    """Extracts the middle frame from a video, resizes it, and saves it.

    Args:
        video_path (str): Path to the input video file.
        output_image_path (str, optional): Path to save the extracted image.
            Defaults to "central_image.png".
        size (tuple, optional): Target ``(width, height)`` passed to
            ``cv2.resize``. Defaults to (1024, 1024).

    Returns:
        bool: True if the frame was extracted and saved successfully, False otherwise.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print("Failed to open the video.")
            return False

        # Get total frames and middle frame index (the reported frame count
        # may be zero or negative for some inputs, so clamp to the first frame)
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        middle_frame_index = max(num_frames // 2, 0)

        # Find the middle frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, middle_frame_index)

        # Read the frame
        ret, frame = cap.read()
        if not ret:
            print("Failed to read the middle frame from the video.")
            return False

        # Resize the frame
        resized_frame = cv2.resize(frame, size)

        # Save the resized frame; imwrite signals failure through its return
        # value, so it has to be checked before reporting success.
        if not cv2.imwrite(output_image_path, resized_frame):
            print(f"Failed to write the frame to {output_image_path}.")
            return False
        return True
    except cv2.error as e:
        print(f"Error processing video: {e}")
        return False
    finally:
        # Release the video capture
        cap.release()

def extract_audio(video_path, output_audio_path="audio.wav"):
    """Extracts audio from a video and saves it as a WAV file.

    Errors are reported with ``print`` rather than raised, so callers should
    check that ``output_audio_path`` exists afterwards.

    Args:
        video_path (str): Path to the input video file.
        output_audio_path (str, optional): Path to save the extracted audio.
            Defaults to "audio.wav".

    Returns:
        None
    """
    # NOTE: the previous handler caught `mpy.error`; moviepy.editor does not
    # appear to provide that name, so the handler itself would raise
    # AttributeError. moviepy reports ffmpeg/file problems as IOError (OSError).
    try:
        clip = mpy.VideoFileClip(video_path)
        try:
            if clip.audio is None:
                print(f"Error extracting audio from video: no audio track in {video_path}")
                return
            clip.audio.write_audiofile(output_audio_path)
        finally:
            # Release the reader resources held by the clip.
            clip.close()
    except OSError as e:
        print(f"Error extracting audio from video: {e}")

def collate_skip_failed(batch):
    """Collate a batch while dropping samples that failed to load.

    The dataset signals a failed video with ``None`` entries, which the
    default collate function cannot batch.

    Args:
        batch (list): Samples as returned by the dataset's ``__getitem__``.

    Returns:
        The default-collated batch of the usable samples, or None when every
        sample in the batch failed.
    """
    usable = [sample for sample in batch if all(part is not None for part in sample)]
    if not usable:
        return None
    return torch.utils.data.default_collate(usable)


# Worker processes may re-import the main module on platforms that spawn
# instead of fork, so the loading loop must not run at import time.
if __name__ == "__main__":
    dataset = YouTubeVideoDataset(youtube_ids, "videos_dir")
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
        collate_fn=collate_skip_failed,
    )

    for batch in dataloader:
        if batch is None:
            # Every sample of this batch failed to load.
            continue
        central_frames, audio = batch