In [1]:
import os
import cv2
import pickle
import numpy as np

In [2]:
def create_HR_LR_images_from_video(
        video_path,
        skip_seconds=(2, 2),
        frame_interval=10,
        scale_factor=0.5,
        output_name="",
        class_label=None):
    """
    Extract frames from a video and save them as paired HR / LR PNG images.

    Frames are read sequentially. Every frame whose index lies in
    [start_frame, end_frame) and falls on the sampling interval is cropped to
    a square of side min(width, height), positioned on the largest Otsu
    contour (image centre when no contour is found). The crop is written as
    the high-resolution (HR) image; a randomly degraded, down-scaled copy is
    written as the low-resolution (LR) image.

    Output layout (relative to the current working directory):
        images/HR/<output_name>/<output_name><n>.png
        images/LR/<output_name>/<output_name><n>.png
        images/interpolation_map.pkl   LR basename -> interpolation name
        images/class_labels_map.pkl    HR basename -> class_label

    If numbered images already exist in the HR directory, numbering continues
    after the highest existing index. Files that merely share the prefix but
    have a non-numeric remainder are ignored.

    The degradations draw from NumPy's global random state, so results are
    only reproducible if the caller seeds it (np.random.seed) beforehand.

    Parameters:
        video_path (str): Path to the input video file.
        skip_seconds (tuple): (start_skip, end_skip), non-negative seconds to
            skip at the start and at the end of the video.
        frame_interval (int): Save every N-th frame of the kept range
            (0 means save all frames).
        scale_factor (float): Factor by which HR dimensions are multiplied to
            obtain the LR dimensions.
        output_name (str): Name of the output sub-directory and file prefix.
        class_label (int): Non-negative classification label assigned to all
            HR images from this video. Required, despite the None default.

    Raises:
        ValueError: If an argument is invalid, the video cannot be opened, or
            the skipped seconds leave no frames to process.
        FileNotFoundError: If video_path does not exist.
    """

    png_ext = ".png"

    def smart_square_crop(img):
        """
        Crop img to a square of side min(width, height).

        The square is centred on the bounding box of the largest external
        contour of the Otsu-thresholded grayscale image and shifted as needed
        to stay inside the image. Without any contour the centre square is
        returned.
        """
        h, w = img.shape[:2]
        crop_size = min(w, h)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Otsu picks the threshold automatically; bright regions become foreground.
        _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # findContours returns (contours, hierarchy) or (image, contours, hierarchy)
        # depending on the OpenCV major version; [-2] is the contour list in both.
        contours = cv2.findContours(
            thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]

        if len(contours) > 0:
            largest = max(contours, key=cv2.contourArea)
            x, y, ww, hh = cv2.boundingRect(largest)
            cx = x + ww // 2
            cy = y + hh // 2
            half = crop_size // 2
            # Clamp so the crop window stays fully inside the image.
            left = min(max(0, cx - half), w - crop_size)
            top = min(max(0, cy - half), h - crop_size)
        else:
            # Fallback: centre crop.
            left = (w - crop_size) // 2
            top = (h - crop_size) // 2

        return img[top:top + crop_size, left:left + crop_size]

    def degrade_image(hr_image, scale_factor=0.5):
        """
        Apply random degradations to an HR image and down-scale it.

        Steps, each applied with the stated probability: Gaussian blur (0.7),
        horizontal motion blur (0.3), resize with a randomly chosen
        interpolation (always), additive Gaussian noise (0.7), JPEG
        re-encoding at quality 20-59 (0.7).

        Returns (lr_image, interp_name), where interp_name is the name of the
        OpenCV interpolation flag used for the resize.
        """
        # NumPy scalars are converted with int() before being handed to OpenCV
        # so we do not depend on the bindings accepting NumPy integer types.
        if np.random.rand() < 0.7:
            ksize = int(np.random.choice([3, 5, 7]))
            sigma = np.random.uniform(0.8, 2.0)
            hr_image = cv2.GaussianBlur(hr_image, (ksize, ksize), sigmaX=sigma)

        if np.random.rand() < 0.3:
            size = int(np.random.choice([5, 7, 9]))
            # Averaging kernel with a single non-zero (middle) row.
            kernel_motion_blur = np.zeros((size, size))
            kernel_motion_blur[(size - 1) // 2, :] = 1.0 / size
            hr_image = cv2.filter2D(hr_image, -1, kernel_motion_blur)

        interp_code_to_name = {
            cv2.INTER_LINEAR: "INTER_LINEAR",
            cv2.INTER_CUBIC: "INTER_CUBIC",
            cv2.INTER_AREA: "INTER_AREA",
            cv2.INTER_LANCZOS4: "INTER_LANCZOS4",
        }
        interp_method = int(np.random.choice([
            cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4
        ]))
        interp_name = interp_code_to_name.get(interp_method, str(interp_method))

        h, w = hr_image.shape[:2]
        # Never request a zero-sized image, which cv2.resize rejects.
        lr_size = (max(1, int(w * scale_factor)), max(1, int(h * scale_factor)))
        lr_image = cv2.resize(hr_image, lr_size, interpolation=interp_method)

        if np.random.rand() < 0.7:
            noise_std = np.random.uniform(2, 10)
            noise = np.random.normal(0, noise_std, lr_image.shape).astype(np.float32)
            lr_image = np.clip(
                lr_image.astype(np.float32) + noise, 0, 255
            ).astype(np.uint8)

        if np.random.rand() < 0.7:
            encode_param = [
                int(cv2.IMWRITE_JPEG_QUALITY), int(np.random.randint(20, 60))
            ]
            encoded_ok, encimg = cv2.imencode('.jpeg', lr_image, encode_param)
            # Keep the un-compressed image if encoding failed.
            if encoded_ok:
                lr_image = cv2.imdecode(encimg, 1)

        return lr_image, interp_name

    def load_pickle_map(path, description):
        """
        Load a dict from a pickle file; return {} when it is missing or unusable.
        """
        if not os.path.exists(path):
            return {}
        try:
            with open(path, 'rb') as f:
                # NOTE(security): pickle.load can execute arbitrary code. This is
                # acceptable only because the file is written by this function;
                # never point it at a file from an untrusted source.
                loaded = pickle.load(f)
        except Exception as e:
            # The file is rewritten at the end of the run, so make the data
            # loss visible instead of silently starting over.
            print(f"Warning: failed to load {description} from {path}: {e}; starting empty.")
            return {}
        if not isinstance(loaded, dict):
            print(f"Warning: {description} at {path} is not a dict; starting empty.")
            return {}
        return loaded

    # ---- Argument validation (runs before any file or video is touched) ----
    if not video_path or not isinstance(video_path, str):
        raise ValueError("video_path must be a non-empty string.")
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found: {video_path}")
    if not isinstance(skip_seconds, tuple) or len(skip_seconds) != 2:
        raise ValueError("skip_seconds must be a tuple of two values (start_skip, end_skip).")
    if not all(
        isinstance(s, (int, float, np.integer, np.floating)) and s >= 0
        for s in skip_seconds
    ):
        raise ValueError("skip_seconds values must be non-negative numbers.")
    if not isinstance(frame_interval, int) or frame_interval < 0:
        raise ValueError("frame_interval must be a non-negative integer.")
    if not isinstance(output_name, str) or not output_name:
        raise ValueError("output_name must be a non-empty string.")
    if not isinstance(scale_factor, (int, float)) or scale_factor <= 0:
        raise ValueError("scale_factor must be a positive number.")
    if not isinstance(class_label, int) or class_label < 0:
        raise ValueError("class_label must be a non-negative integer.")

    interp_map_path = "images/interpolation_map.pkl"
    class_map_path = "images/class_labels_map.pkl"

    cropped = None    # last HR crop written in this run (None if nothing was saved)
    lr_image = None   # last LR image written in this run

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Calculate start and end frames after skipping initial/final seconds
        start_frame = int(skip_seconds[0] * fps)
        end_frame = total_frames - int(skip_seconds[1] * fps)
        if start_frame < 0 or end_frame < 0:
            raise ValueError("Skip seconds result in negative frame indices.")
        elif start_frame >= total_frames or end_frame > total_frames:
            raise ValueError("Skip seconds exceed total video duration.")
        elif start_frame >= end_frame:
            raise ValueError("Start frame must be less than end frame after skipping seconds.")

        # Output directories are created only once the frame range is known to be valid.
        full_HR_output_dir = f"images/HR/{output_name}"
        full_LR_output_dir = f"images/LR/{output_name}"
        os.makedirs(full_HR_output_dir, exist_ok=True)
        os.makedirs(full_LR_output_dir, exist_ok=True)

        interp_map = load_pickle_map(interp_map_path, "interpolation map")
        class_map = load_pickle_map(class_map_path, "class labels map")

        # Continue numbering after the highest "<output_name><digits>.png" present.
        existing_indices = []
        for fname in os.listdir(full_HR_output_dir):
            if fname.startswith(output_name) and fname.endswith(png_ext):
                index_text = fname[len(output_name):-len(png_ext)]
                if index_text.isdecimal():
                    existing_indices.append(int(index_text))
        first_index = max(existing_indices) + 1 if existing_indices else 0
        saved_count = first_index

        current_frame = 0
        # Frames at or beyond end_frame are never saved, so stop decoding there.
        while current_frame < end_frame:
            ret, frame = cap.read()
            if not ret:
                break

            on_interval = (
                frame_interval == 0
                or (current_frame - start_frame) % frame_interval == 0
            )
            if current_frame >= start_frame and on_interval:
                image_name = f"{output_name}{saved_count}{png_ext}"

                # --- HR crop ---
                cropped = smart_square_crop(frame)
                cv2.imwrite(os.path.join(full_HR_output_dir, image_name), cropped)
                class_map[image_name] = class_label

                # --- LR version + the interpolation used to create it ---
                lr_image, interp_name = degrade_image(
                    cropped, scale_factor=scale_factor
                )
                cv2.imwrite(os.path.join(full_LR_output_dir, image_name), lr_image)
                interp_map[image_name] = interp_name

                saved_count += 1
            current_frame += 1
    finally:
        # Release the capture on every path, including validation errors above.
        cap.release()

    # Persist interpolation mapping
    try:
        with open(interp_map_path, 'wb') as f:
            pickle.dump(interp_map, f)
    except Exception as e:
        print(f"Warning: failed to save interpolation map: {e}")

    # Persist class labels mapping
    try:
        with open(class_map_path, 'wb') as f:
            pickle.dump(class_map, f)
    except Exception as e:
        print(f"Warning: failed to save class labels map: {e}")

    # Print analysis
    print("=== VIDEO ANALYSIS ===")
    print(f"Total frames: {total_frames}")
    print(f"Total frames (after skipped seconds): {end_frame - start_frame}")
    print(f"Images saved in this run: {saved_count - first_index}")
    print(f"Total images in directory: {saved_count}")
    print(f"Original HR frame size (width x height): {frame_width} x {frame_height}")
    if cropped is not None and lr_image is not None:
        print(f"Cropped HR frame size (width x height): {cropped.shape[1]} x {cropped.shape[0]}")
        print(f"LR frame size (width x height): {lr_image.shape[1]} x {lr_image.shape[0]}")
    else:
        # Happens when the video could not be decoded up to start_frame.
        print("No frames were saved in this run; cropped HR / LR sizes are unavailable.")
    print(f"Saved interpolation map entries: {len(interp_map)} -> {interp_map_path}")
    print(f"Saved class labels map entries: {len(class_map)} -> {class_map_path}")

In [3]:
# Root directory that holds one sub-folder of videos per class.
video_base_dir = "videos"

# Maximum number of successfully processed videos per sub-folder.
folder_max_videos = {
    "low_z_offset": 1,
    "high_z_offset": 0,
}

# Class id assigned to every image extracted from a sub-folder.
folder_class_id = {
    "low_z_offset": 0,
    "high_z_offset": 1,
}

# Sub-folder name -> number of videos processed without error.
processed = {}
for subfolder in folder_max_videos:
    max_videos = folder_max_videos[subfolder]
    subdir = os.path.join(video_base_dir, subfolder)

    if os.path.isdir(subdir):
        class_id = folder_class_id.get(subfolder)
        if class_id is not None:
            # Alphabetical order keeps the selection of videos deterministic.
            videos = sorted(
                entry for entry in os.listdir(subdir)
                if entry.lower().endswith(".mp4")
            )

            count = 0
            remaining = iter(videos)
            # Only successful runs count towards the per-folder limit, so a
            # failing video is skipped and the next one is tried instead.
            while count < max_videos:
                video_file = next(remaining, None)
                if video_file is None:
                    break
                video_path = os.path.join(subdir, video_file)

                # "<defect>_<digits>.mp4" -> "<defect>"; any other name is used as is.
                name_no_ext, _ext = os.path.splitext(video_file)
                head, separator, tail = name_no_ext.rpartition("_")
                if separator and tail.isdigit():
                    defect_type = head
                else:
                    defect_type = name_no_ext

                try:
                    create_HR_LR_images_from_video(
                        video_path,
                        skip_seconds=(2, 2),
                        frame_interval=50,
                        scale_factor=0.5,
                        output_name=defect_type,
                        class_label=class_id,
                    )
                except Exception as e:
                    print(f"Error processing {video_path}: {e}")
                else:
                    count += 1

            processed[subfolder] = count
        else:
            print(f"Warning: no class id defined for {subfolder}; skipping.")
    else:
        print(f"Skipping missing folder: {subdir}")

print("Summary per folder:", processed)

=== VIDEO ANALYSIS ===
Total frames: 1597
Total frames (after skipped seconds): 1477
Images saved in this run: 30
Total images in directory: 30
Original HR frame size (width x height): 478 x 850
Cropped HR frame size (width x height): 478 x 478
LR frame size (width x height): 239 x 239
Saved interpolation map entries: 30 -> images/interpolation_map.pkl
Saved class labels map entries: 30 -> images/class_labels_map.pkl
Summary per folder: {'low_z_offset': 1, 'high_z_offset': 0}
