In [None]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from tqdm import tqdm




def _augment_frame(frame):
    """Randomly crop, horizontally flip and brightness-jitter one H x W x C frame."""
    h, w, _ = frame.shape

    # Random crop keeping 3/4 of each side. randint's upper bound is exclusive
    # and must be positive, so frames smaller than 4 px use a zero offset.
    crop_start_x = np.random.randint(0, max(w // 4, 1))
    crop_start_y = np.random.randint(0, max(h // 4, 1))
    frame = frame[crop_start_y:crop_start_y + 3 * h // 4,
                  crop_start_x:crop_start_x + 3 * w // 4, :]

    # Horizontal flip with probability 0.5
    if np.random.rand() > 0.5:
        frame = cv2.flip(frame, 1)

    # Brightness scaling by a factor in [0.8, 1.2], clipped back to uint8 range
    alpha = 1.0 + np.random.uniform(-0.2, 0.2)
    return np.clip(alpha * frame, 0, 255).astype(np.uint8)


def processVideo(videoPath, imageSize, num_snippets = 6, frames_per_snippet = 16, augmentation = True):
    """Sample a video into equally long snippets of uniformly spaced frames.

    The video is divided into ``num_snippets`` consecutive time intervals and
    ``frames_per_snippet`` frames are read at evenly spaced times inside each
    interval (both interval endpoints included, so neighbouring snippets share
    a boundary time).

    Args:
        videoPath: Path (or any source accepted by ``cv2.VideoCapture``) of
            the video to read.
        imageSize: ``(height, width)`` of the output frames.
        num_snippets: Number of temporal segments to split the video into.
        frames_per_snippet: Number of frames sampled from each segment.
        augmentation: If true, every frame is independently randomly cropped,
            flipped and brightness-scaled before resizing.

    Returns:
        A float32 tensor of shape
        ``(num_snippets, frames_per_snippet, 3, height, width)`` holding pixel
        values in ``[0, 255]``; the channel order is whatever OpenCV decoded
        (BGR by default).

    Raises:
        IOError: If the video cannot be opened.
        ValueError: If the video reports a non-positive FPS or frame count.
        RuntimeError: If a frame cannot be decoded.
    """
    cap = cv2.VideoCapture(videoPath)
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open video: {videoPath!r}")

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if fps <= 0 or total_frames <= 0:
            raise ValueError(
                f"Invalid video metadata for {videoPath!r}: "
                f"fps={fps}, frame count={total_frames}"
            )
        duration = total_frames / fps

        height, width = imageSize
        # Stay one frame short of the end (as the original sampling did), but
        # never go below index 0 for very short videos.
        last_index = max(total_frames - 2, 0)

        snippets = []
        for i in range(num_snippets):
            # Start and end time of this snippet
            start_time = duration * (i / num_snippets)
            end_time = duration * ((i + 1) / num_snippets)

            # Uniformly sample time points within the snippet interval
            keyframe_time = np.linspace(start_time, end_time, frames_per_snippet)

            keyframes = []
            for t in keyframe_time:
                # Seek to and read the frame at the selected time
                frame_index = min(int(t * fps), last_index)
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
                ok, frame = cap.read()
                if not ok or frame is None:
                    raise RuntimeError(
                        f"Failed to read frame {frame_index} from {videoPath!r}"
                    )

                if augmentation:
                    frame = _augment_frame(frame)

                # cv2.resize expects dsize as (width, height)
                frame = cv2.resize(frame, (width, height))
                # Move channels first with permute; a plain view() would only
                # reinterpret the H x W x C buffer and scramble the pixels.
                frameTensor = (
                    torch.tensor(frame, dtype=torch.float32)
                    .permute(2, 0, 1)
                    .contiguous()
                )
                keyframes.append(frameTensor)
            snippets.append(torch.stack(keyframes))
    finally:
        # Release the capture even when sampling fails part-way through
        cap.release()
    return torch.stack(snippets)
    