# A Step-by-Step Workflow to Produce 4DGS Volumetric Video

## **Stage 0** - Function Definition & Initialization
---

#### Function definition -->

In [None]:
import glob
import cv2
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import sys
import torch
from multiprocessing import freeze_support
import subprocess
import shutil


def extract_frames(video_path, output_dir):
    """Dump every frame of a video into ``output_dir`` as numbered JPEG files.

    Frames are written in decode order as ``00000.jpg``, ``00001.jpg``, ...

    Args:
        video_path: Path of the video handed to ``cv2.VideoCapture``.
        output_dir: Destination directory; created when missing.

    Raises:
        OSError: If a decoded frame cannot be written to disk.
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists().
    os.makedirs(output_dir, exist_ok=True)

    video_capture = cv2.VideoCapture(video_path)
    frame_count = 0

    try:
        if not video_capture.isOpened():
            # Without this hint an unreadable video only shows up as "0 frames".
            print(f"Warning: could not open video {video_path}.")

        while True:
            ret, frame = video_capture.read()
            if not ret:
                break
            frame_filename = os.path.join(output_dir, f"{frame_count:05d}.jpg")
            # cv2.imwrite signals failure through its return value.
            if not cv2.imwrite(frame_filename, frame):
                raise OSError(f"Failed to write frame to {frame_filename}")
            frame_count += 1
    finally:
        # Release the decoder even when reading or writing raised.
        video_capture.release()

    print(f"Extracted {frame_count} frames from {video_path} to {output_dir}.")

def move_to_folder(src, dst_path, dst_name):
    """Move ``src`` to ``dst_path/dst_name`` while keeping its file extension.

    Any missing directories of the destination are created first; the move
    itself is a ``Path.rename``.
    """
    origin = Path(f"{src}")
    target = Path(f"{dst_path}/{dst_name}{origin.suffix}")
    target.parent.mkdir(parents=True, exist_ok=True)
    origin.rename(target)

def rotate_images_in_folder(folder_path, opencv_rotate = cv2.ROTATE_90_CLOCKWISE):
    """Rotate every image file directly inside ``folder_path`` in place.

    Files are selected by extension (case-insensitive); files that
    ``cv2.imread`` cannot decode are skipped. Each rotated image overwrites
    its source file.

    Args:
        folder_path: Directory whose images are rotated.
        opencv_rotate: Rotation code passed to ``cv2.rotate``.
    """
    suffixes = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
    rotated_count = 0

    for entry in os.listdir(folder_path):
        if not entry.lower().endswith(suffixes):
            continue

        target = os.path.join(folder_path, entry)
        pixels = cv2.imread(target)
        if pixels is None:
            continue

        cv2.imwrite(target, cv2.rotate(pixels, opencv_rotate))
        rotated_count += 1

    print(f"Completed! Rotated {rotated_count} images in {folder_path}")

def convert_jpg_to_png(folder_path):
    """Replace each JPEG directly inside ``folder_path`` with a 4-channel PNG.

    Every ``.jpg``/``.jpeg`` file that can be decoded is converted from BGR
    to BGRA, saved next to the original with a ``.png`` extension, and the
    JPEG is then deleted.
    """
    jpeg_suffixes = ('.jpg', '.jpeg')
    converted_count = 0

    for entry in os.listdir(folder_path):
        if not entry.lower().endswith(jpeg_suffixes):
            continue

        jpg_path = os.path.join(folder_path, entry)
        bgr = cv2.imread(jpg_path)
        if bgr is None:
            continue

        bgra = cv2.cvtColor(bgr, cv2.COLOR_BGR2BGRA)
        png_path = os.path.splitext(jpg_path)[0] + '.png'
        cv2.imwrite(png_path, bgra)

        os.remove(jpg_path)
        converted_count += 1

    print(f"Completed! Converted {converted_count} images in {folder_path} to PNG format.")

def convert_png_to_jpg(folder_path):
    """Replace each PNG directly inside ``folder_path`` with a JPEG.

    Every ``.png`` file that can be decoded is re-encoded next to the
    original with a ``.jpg`` extension; the PNG is deleted only after the
    JPEG was written successfully.
    """
    png_suffixes = ('.png',)
    converted_count = 0

    for file in os.listdir(folder_path):
        if not file.lower().endswith(png_suffixes):
            continue

        file_path = os.path.join(folder_path, file)
        img = cv2.imread(file_path)
        if img is None:
            continue

        jpg_file_path = os.path.splitext(file_path)[0] + '.jpg'
        # cv2.imread yields BGR and cv2.imwrite expects BGR, so the image is
        # written as-is. Converting to RGB first would swap red and blue in
        # the saved file.
        if not cv2.imwrite(jpg_file_path, img):
            # Keep the source when encoding failed instead of losing the frame.
            print(f"Warning: failed to write {jpg_file_path}; keeping {file_path}")
            continue

        os.remove(file_path)
        converted_count += 1

    print(f"Completed! Converted {converted_count} images in {folder_path} to JPG format.")

def detect_flash(images, pixel_position, pixel_size, channel_threshold: float = 5000, skip_frames: int = 150):
    """Find the two frames where a watched image region suddenly gets redder.

    A square region of interest (ROI) around ``pixel_position`` is compared
    frame to frame. The cube of the change in the ROI's mean of channel
    index 2 (red, for BGR images as loaded by ``cv2.imread``) is tested
    against ``channel_threshold``; cubing keeps the sign, so only increases
    can trigger. The first trigger is the start frame; the scan then jumps
    ``skip_frames`` ahead and the next trigger is the end frame. Small ROI
    previews are shown with matplotlib along the way.

    Args:
        images: Ordered list of image file paths.
        pixel_position: ``(x, y)`` centre of the ROI in image pixels.
        pixel_size: Half the side length of the ROI in pixels.
        channel_threshold: Minimum cubed red-mean increase that counts as a flash.
        skip_frames: Frames to jump over after the start flash was found.

    Returns:
        ``(start_frame, end_frame)`` as indices into ``images``; ``-1`` marks
        a flash that was not found.

    Raises:
        FileNotFoundError: If a listed image cannot be read.
    """
    if not images:
        return -1, -1

    x, y = pixel_position
    # Clamp the upper-left corner: a negative slice start would wrap around
    # to the far side of the image instead of clipping at the border.
    top, left = max(0, y - pixel_size), max(0, x - pixel_size)
    bottom, right = y + pixel_size, x + pixel_size
    # A negative skip would rewind the scan in front of the start flash.
    skip = max(0, skip_frames)

    def read_roi(path):
        frame = cv2.imread(path)
        if frame is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return frame[top:bottom, left:right]

    def show_preview(roi, title):
        plt.figure(figsize=(1, 1))
        plt.imshow(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB))
        plt.title(title)
        plt.axis('off')
        plt.show()

    start_frame = -1
    end_frame = -1

    last_roi = read_roi(images[0])
    show_preview(last_roi, 'ROI Calculation Preview')

    i = 0
    while i < len(images):
        curr_roi = read_roi(images[i])
        curr_bgr_mean = np.mean(curr_roi, axis=(0, 1))
        last_bgr_mean = np.mean(last_roi, axis=(0, 1))
        r_diff = curr_bgr_mean[2] - last_bgr_mean[2]
        r_diff_cube = r_diff * r_diff * r_diff

        last_roi = curr_roi

        if r_diff_cube > channel_threshold:
            if start_frame >= 0:
                show_preview(curr_roi, 'End Frame Preview')
                end_frame = i
                break

            show_preview(curr_roi, 'Start Frame Preview')
            start_frame = i

            i += skip
            if i >= len(images):
                # Sequence ends before the skip target: no end flash to find.
                break
            # Re-base the comparison on the frame we jumped to.
            last_roi = read_roi(images[i])
            continue

        i += 1

    return start_frame, end_frame
    
def get_pixel_position(event, x, y, flags, param):
    """OpenCV mouse callback that records a left click in original-image pixels.

    ``param`` is ``(scale_factor, clicked_flag, pix_x, pix_y)``; the last
    three are one-element lists used as output slots. The clicked display
    coordinates are divided by ``scale_factor`` before being stored.
    """
    scale_factor, clicked_flag, pix_x, pix_y = param

    if event != cv2.EVENT_LBUTTONDOWN:
        return

    pix_x[0], pix_y[0] = int(x / scale_factor), int(y / scale_factor)
    clicked_flag[0] = True

def roi_pixel_selection(image_path):
    """Show an image and let the user pick one pixel with a left click.

    The image is scaled down (never up) to fit a 1000x800 window. The loop
    ends on a click, on any key press, or when the window is closed.

    Returns:
        ``(x, y)`` in original-image pixels, or ``None`` if nothing was clicked.
    """
    # Largest window size the preview may occupy
    max_width, max_height = 1000, 800

    img = cv2.imread(image_path)
    original_height, original_width = img.shape[:2]

    # Shrink to fit, but never enlarge small images
    scale_factor = min(1.0, max_width / original_width, max_height / original_height)
    display_size = (int(original_width * scale_factor), int(original_height * scale_factor))
    display_img = cv2.resize(img, display_size)

    # One-element lists act as mutable output slots for the mouse callback
    clicked_flag = [False]
    roi_pix_x = [-1]
    roi_pix_y = [-1]

    window_name = 'Click on Image to Select ROI Pixel'
    cv2.namedWindow(window_name)
    cv2.setMouseCallback(window_name, get_pixel_position, (scale_factor, clicked_flag, roi_pix_x, roi_pix_y))
    cv2.imshow(window_name, display_img)

    while not clicked_flag[0]:
        key_pressed = cv2.waitKey(1) != -1
        if key_pressed or cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) < 1:
            break

    cv2.destroyAllWindows()

    if roi_pix_x[0] == -1 or roi_pix_y[0] == -1:
        return None
    return (roi_pix_x[0], roi_pix_y[0])
    
def rs_align_campos(rs_path, import_path):
    """Run the RealityScan executable headless on one image folder.

    The process is started as
    ``<rs_path> -headless -addFolder <import_path> -align -exportXMP -quit``.
    The flag names suggest alignment followed by an XMP export; confirm the
    exact semantics against the RealityScan command-line documentation.

    Args:
        rs_path: Path of the executable to launch.
        import_path: Folder passed to ``-addFolder``.

    Returns:
        ``True`` when the process exits with code 0.

    Raises:
        RuntimeError: When the process exits with a non-zero code; the
            command, stdout and stderr are printed first.
    """
    cmd = [
        # str() so that ' '.join(cmd) below also works for Path arguments
        str(rs_path), "-headless",
        "-addFolder", str(import_path),
        "-align",
        "-exportXMP",
        "-quit",
    ]

    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode != 0:
        print(f"Error running command: {' '.join(cmd)}")
        print(f"stdout: {result.stdout}")
        print(f"stderr: {result.stderr}")
        raise RuntimeError(f"RealityScan command failed with return code {result.returncode}")

    return True
    
def rs_align_xmp(rs_path, import_path, export_path, xml_path):
    """Run the RealityScan executable headless and export a registration.

    The process is started as
    ``<rs_path> -headless -addFolder <import_path> -align
    -exportRegistration <export_path>/placeholder.txt <xml_path> -quit``.
    Confirm the exact flag semantics against the RealityScan command-line
    documentation.

    Args:
        rs_path: Path of the executable to launch.
        import_path: Folder passed to ``-addFolder``.
        export_path: Folder in which ``placeholder.txt`` is named as export target.
        xml_path: Export settings file passed after the export target.

    Returns:
        ``True`` when the process exits with code 0.

    Raises:
        RuntimeError: When the process exits with a non-zero code; the
            command, stdout and stderr are printed first.
    """
    cmd = [
        # str() so that ' '.join(cmd) below also works for Path arguments
        str(rs_path), "-headless",
        "-addFolder", str(import_path),
        "-align",
        "-exportRegistration", f"{str(export_path)}/placeholder.txt", str(xml_path),
        "-quit",
    ]

    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode != 0:
        print(f"Error running command: {' '.join(cmd)}")
        print(f"stdout: {result.stdout}")
        print(f"stderr: {result.stderr}")
        raise RuntimeError(f"RealityScan command failed with return code {result.returncode}")

    return True

#### Global path definitions -->

In [None]:
# Define path for raw videos
videos_by_view = r"C:\repos\_DATASETS_\4dgs-250729\gs-take-00\views"

# Define path for soar mesh raw color feed
soar_sequence = r"C:\repos\_DATASETS_\4dgs-250729\gs-take-00\take1"

# Define the output path for training dataset
output_dataset = r"C:\repos\_DATASETS_\4dgs-250729\gs-take-00\out"


# Working folders derived from the output path (Windows-style separators):
# per-view RGB frame sequences written in stage 1
stage_1_rgb_sequence_by_view = rf"{output_dataset}\rgb_sequence_by_view"
# per-frame folders used as input for the RealityScan alignment
rs_training_dataset = rf"{output_dataset}\rs_train"
# per-frame export folders written by the registration export in stage 3
postshot_training_dataset = rf"{output_dataset}\postshot_train"

#### Dependencies path definitions -->

In [None]:
# Project path to RobustVideoMatting
rvm_path = r"C:\repos\RobustVideoMatting"

# NOTE(review): the local checkout is put on sys.path, but torch.hub.load below
# is given the GitHub repo id, so presumably the hub copy is what gets loaded — confirm.
sys.path.append(rvm_path)
# Matting model, moved to the GPU (.cuda() requires a CUDA-capable device)
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3").cuda()
# Conversion helper from the same hub repo; called in the matting cell
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")

#### Raw dataset extension definitions -->

In [None]:
# Glob patterns used to collect frame images from a folder
image_extensions = ['*.jpg', '*.jpeg', '*.png']
# Glob patterns used to collect the raw video captures
video_extensions = ['*.mp4', '*.avi', '*.mov', '*.mkv', '*.wmv', '*.flv', '*.webm', '*.m4v', '*.3gp', '*.mpg', '*.mpeg']

## **Stage 1** - RGB Sequence Extraction from Raw Dataset
---

### Organize soar raw color feeds

#### Report data found -->

In [None]:
# Collect every soar color frame, sorted by path
images = sorted(
    path
    for pattern in image_extensions
    for path in glob.glob(os.path.join(soar_sequence, pattern))
)

print(f"Found {len(images)} images in {soar_sequence}")

# Bucket the frames by view name, i.e. the file-name part before the first '.'.
# A new bucket is opened whenever an unseen view name appears; every other
# frame goes into the most recently opened bucket.
view_names = []
soar_frames_by_view = []

for image_path in images:
    view_name = os.path.basename(image_path).split('.')[0]

    if view_name not in view_names:
        view_names.append(view_name)
        soar_frames_by_view.append([])

    soar_frames_by_view[-1].append(image_path)

for i, (name, view) in enumerate(zip(view_names, soar_frames_by_view)):
    print(f"View: {i}, Name: {name}, Number of frames: {len(view)}")


#### Image rotation if needed -->

In [None]:
# Rotate every soar color frame by 90 degrees clockwise.
# NOTE: the rotated images overwrite the originals, so running this cell
# again rotates them a second time.
rotate_images_in_folder(soar_sequence, cv2.ROTATE_90_CLOCKWISE)

#### Organize reported soar data -->

In [None]:
# One output folder per soar view: <stage 1 folder>/soar_view_<index>
soar_rgb_sequence_folders = [
    os.path.join(stage_1_rgb_sequence_by_view, f"soar_view_{view_idx}")
    for view_idx, _ in enumerate(soar_frames_by_view)
]
for folder in soar_rgb_sequence_folders:
    os.makedirs(folder, exist_ok=True)

# Move each view's frames into its folder, renamed to a 5-digit running index
for view_folder, view_frames in zip(soar_rgb_sequence_folders, soar_frames_by_view):
    for frame_idx, frame_path in enumerate(view_frames):
        if frame_path:
            move_to_folder(frame_path, view_folder, f"{frame_idx:05d}")

print("Organized soar raw color feeds into respective folders.")

### Organize raw video captures

#### Report data found -->

In [None]:
# Collect all raw video captures, sorted by path
videos = sorted(
    match
    for pattern in video_extensions
    for match in glob.glob(os.path.join(videos_by_view, pattern))
)

print(f"Found {len(videos)} videos in {videos_by_view}:")
for video_path in videos:
    print(video_path)

#### Extract frames and organize -->

In [None]:
# Decode each video into its own folder: <stage 1 folder>/recorded_<index>
total_videos = len(videos)
for view_idx, video_path in enumerate(videos):
    print(f"Processing video {view_idx+1}/{total_videos}: {video_path}")
    extract_frames(
        video_path,
        os.path.join(stage_1_rgb_sequence_by_view, f"recorded_{view_idx:02d}"),
    )

#### Data preview. Stage 1 --> Stage 2

In [None]:
# All per-view folders produced by stage 1, as sorted full paths
rgb_sequence_by_view_folders = [
    os.path.join(stage_1_rgb_sequence_by_view, name)
    for name in sorted(os.listdir(stage_1_rgb_sequence_by_view))
    if os.path.isdir(os.path.join(stage_1_rgb_sequence_by_view, name))
]
# Stage 2 works on the same folder list
stage_2_rgb_sequence_by_view_folders = rgb_sequence_by_view_folders

for i, folder in enumerate(stage_2_rgb_sequence_by_view_folders):
    images = sorted(
        path
        for pattern in image_extensions
        for path in glob.glob(os.path.join(folder, pattern))
    )
    print(f"View: {i}, Path: {folder}, Number of frames: {len(images)}")

## **Stage 2** - Frame Synchronization & Alpha Extraction

#### ROI selection -->

In [None]:
# Ask the user to click the LED position on the first frame of every view
led_roi = []

for folder in stage_2_rgb_sequence_by_view_folders:
    images = sorted(
        path
        for pattern in image_extensions
        for path in glob.glob(os.path.join(folder, pattern))
    )
    led_roi.append(roi_pixel_selection(images[0]))

for i, roi in enumerate(led_roi):
    if roi is None:
        print("No ROI selected.")
    else:
        print(f"View: {i}, ROI pixel position: {roi}")

#### Flash frame detection -->

In [None]:
# For every view, find the start/end flash frames. The distance found in one
# view (minus 90 frames) is used as the skip distance for the next view; it
# falls back to 900 whenever a flash was not found.
crop_frame_locs = []
frame_count = 900 # init approximate frame count
for i, view_folder in enumerate(stage_2_rgb_sequence_by_view_folders):
    images = sorted(
        path
        for pattern in image_extensions
        for path in glob.glob(os.path.join(view_folder, pattern))
    )

    start, end = detect_flash(images, led_roi[i], 10, 10000, frame_count - 90)
    crop_frame_locs.append([start, end])

    if start == -1:
        frame_count = 900
        print(f"View: {i}, Start: Not Found, End: Not Found")
    elif end == -1:
        frame_count = 900
        print(f"View: {i}, Start: {os.path.basename(images[start])}, End: Not Found")
    else:
        frame_count = end - start
        print(f"View: {i}: Count: {end - start}, from({os.path.basename(images[start])}) to({os.path.basename(images[end])})")

#### Synchronize frames -->

In [None]:
# Trim every view to the frames between its two flash frames (inclusive) and
# renumber what is left.
for i, view_folder in enumerate(stage_2_rgb_sequence_by_view_folders):
    start, end = crop_frame_locs[i]

    # -1 marks a flash that was not detected. With end == -1 the trailing
    # range would start at index 0 and delete every frame of the view, so
    # such views are left untouched.
    if start < 0 or end < 0:
        print(f"View: {i}, flash frames not detected (start={start}, end={end}); skipping synchronization.")
        continue

    images = []
    for ext in image_extensions:
        images.extend(glob.glob(os.path.join(view_folder, ext)))
    images = sorted(images)

    # Drop everything before the start flash and after the end flash
    for image in images[:start] + images[end + 1:]:
        os.remove(image)

    # Rename the remaining frames to 00000.jpg, 00001.jpg, ...
    # NOTE(review): the .jpg extension is used even for files matched as
    # .png/.jpeg — confirm this is intended.
    images = images[start:end + 1]
    for j, image in enumerate(images):
        new_name = f"{j:05d}.jpg"
        new_path = os.path.join(view_folder, new_name)
        os.rename(image, new_path)

    print(f"View: {i}, Updated frames: {len(images)}")


#### Save single frame for camera data -->

In [None]:
# generate training data - first frame

# Sorted frame paths of every view
frames_by_views = [
    sorted(
        path
        for pattern in image_extensions
        for path in glob.glob(os.path.join(view_folder, pattern))
    )
    for view_folder in stage_2_rgb_sequence_by_view_folders
]

# Move the first frame of each view into <rs_train>/frame_0, named by view index
first_frame_folder = os.path.join(rs_training_dataset, "frame_0")
for view_number, view_frames in enumerate(frames_by_views):
    move_to_folder(view_frames[0], first_frame_folder, f"{view_number:05d}")

#### Convert to png if RGBA is needed -->

In [None]:
# Re-encode the first-frame JPEGs as 4-channel PNGs; the JPEG files are deleted.
convert_jpg_to_png(first_frame_folder)

#### RGB sequence matting -->

In [None]:
# Remove background

temp_folder = rf"{output_dataset}\temp"
if not os.path.exists(temp_folder):
    os.makedirs(temp_folder)

# Options shared by every view. Not set here: output_alpha / output_foreground
# (raw alpha / foreground outputs) and output_video_mbps (video output only).
matting_options = {
    "downsample_ratio": None,       # [Optional] If None, make downsampled max size be 512px.
    "output_type": 'png_sequence',  # Choose "video" or "png_sequence"
    "seq_chunk": 15,                # Process n frames at once for better parallelism.
    "num_workers": 5,               # Only for image sequence input. Reader threads.
    "progress": True,               # Print conversion progress.
}

frames_nobg_by_view_folders = []

for view_idx, view_folder in enumerate(stage_2_rgb_sequence_by_view_folders):
    temp_sequence_folder = os.path.join(temp_folder, f"view_{view_idx:02d}")
    convert_video(
        model,                                    # The loaded model, can be on any device (cpu or cuda).
        input_source=view_folder,                 # A video file or an image sequence directory.
        output_composition=temp_sequence_folder,  # File path if video; directory path if png sequence.
        **matting_options,
    )
    frames_nobg_by_view_folders.append(temp_sequence_folder)

#### Organize RGBA sequence -->

In [None]:
temp_folder = rf"{output_dataset}\temp"

# Matted per-view folders found on disk, as sorted full paths
frames_nobg_by_view_folders = [
    os.path.join(temp_folder, name)
    for name in sorted(os.listdir(temp_folder))
    if os.path.isdir(os.path.join(temp_folder, name))
]

# Sorted frame paths of every matted view
frames_nobg_by_views = []
for folder in frames_nobg_by_view_folders:
    images = sorted(
        path
        for pattern in image_extensions
        for path in glob.glob(os.path.join(folder, pattern))
    )
    frames_nobg_by_views.append(images)

for i, view in enumerate(frames_nobg_by_views):
    print(f"View: {i}, Number of frames: {len(view)}")


#### Crop synchronized RGBA sequence -->

In [None]:
# Only frame indices available in every view are kept
min_num_of_frames = min(map(len, frames_nobg_by_views))

# Regroup from per-view to per-frame: <rs_train>/frame_<frame index>/<view index>.<ext>
for view_idx, view_frames in enumerate(frames_nobg_by_views):
    for frame_idx, frame_path in enumerate(view_frames[:min_num_of_frames]):
        frame_dir = os.path.join(rs_training_dataset, f"frame_{frame_idx:05d}")
        move_to_folder(frame_path, frame_dir, f"{view_idx:05d}")

#### Data preview. Stage 2 --> Stage 3

In [None]:
# All per-frame folders of the RealityScan dataset, as sorted full paths
stage_3_frame_folders = [
    os.path.join(rs_training_dataset, name)
    for name in sorted(os.listdir(rs_training_dataset))
    if os.path.isdir(os.path.join(rs_training_dataset, name))
]

for folder in stage_3_frame_folders:
    images = sorted(
        path
        for pattern in image_extensions
        for path in glob.glob(os.path.join(folder, pattern))
    )
    print(f"Frame: {folder}, Views: {len(images)}")

## **Stage 3** - Volumetrization & Reconstruction

#### Define paths to RealityScan files -->

In [None]:
# Path to RealityScan.exe
rs_exe_path = r"C:\Program Files\Epic Games\RealityScan_2.0\RealityScan.exe"

# Path to Colmap data export profile
# (passed to rs_align_xmp as the settings file of -exportRegistration)
rx_export_xml = r"C:\repos\_DATASETS_\4dgs-250729\gs-take-00\colmap_profile_img_rgba_wxmp.xml"

#### Extract camera intrinsic & extrinsic data -->

In [None]:
# Align the first frame folder, then copy the XMP files it produced into
# every other frame folder.
first_frame = stage_3_frame_folders[0]

if rs_align_campos(rs_exe_path, first_frame):
    print(f"First frame {first_frame} aligned successfully.")

    xmp_files = glob.glob(os.path.join(first_frame, "*.xmp"))

    if not xmp_files:
        print(f"No XMP files found in {first_frame}, skipping subsequent frames alignment.")
        # sys.exit instead of the interactive-only exit() helper
        sys.exit(1)

    for xmp_file in xmp_files:
        for frame_folder in stage_3_frame_folders[1:]:
            shutil.copy2(xmp_file, frame_folder)

        print(f"Copied XMP file {xmp_file} to all frames.")
else:
    # rs_align_campos raises RuntimeError on a non-zero exit code, so this
    # branch only runs if it ever returns a falsy value. first_frame is a
    # plain path string, hence os.path.basename rather than `.name`.
    print(f"Failed to align first frame {os.path.basename(first_frame)}. Exiting.")
    sys.exit(1)

#### Construct sparse point model -->

In [None]:
# Process subsequent frames
for i, images_path in enumerate(stage_3_frame_folders[1:]):

    export_path = rf"{postshot_training_dataset}/frame_{i:05d}"
    os.makedirs(export_path, exist_ok=True)

    # Name of the folder being processed. The messages used to read a stale
    # `frame_folder` variable left over from another cell and called `.name`
    # on a plain string, which made the error handler itself fail.
    frame_name = os.path.basename(images_path)

    try:
        if rs_align_xmp(rs_exe_path, images_path, export_path, rx_export_xml):
            print(f"Frame {i} aligned successfully.")
        else:
            print(f"Failed to align frame {frame_name}.")
    except Exception as e:
        # Best effort: report the failing frame and carry on with the next one
        print(f"Error processing frame {frame_name}: {e}")

## Below is for testing
---

In [None]:

def extract_brightness_contrast(reference_image_path):
    """Measure brightness, contrast and per-channel statistics of an image.

    Brightness is the mean and contrast the standard deviation of the
    grayscale image; channel statistics are taken from the RGB image.

    Returns:
        Dict with keys ``target_brightness``, ``target_contrast``,
        ``channel_means`` and ``channel_stds`` (both ordered R, G, B) and
        ``reference_image_path``.
    """
    bgr = cv2.imread(reference_image_path)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Global statistics come from the grayscale version
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    target_brightness = np.mean(gray)
    target_contrast = np.std(gray)

    # Per-channel statistics for colour correction, in R, G, B order
    channel_means = []
    channel_stds = []
    for channel in range(3):
        plane = rgb[:, :, channel]
        channel_means.append(np.mean(plane))
        channel_stds.append(np.std(plane))

    brightness_contrast_data = dict(
        target_brightness=target_brightness,
        target_contrast=target_contrast,
        channel_means=channel_means,
        channel_stds=channel_stds,
        reference_image_path=reference_image_path,
    )

    print(f"Reference image: {reference_image_path}")
    print(f"Target brightness: {target_brightness:.2f}")
    print(f"Target contrast: {target_contrast:.2f}")
    print(f"Channel means (R,G,B): [{channel_means[0]:.2f}, {channel_means[1]:.2f}, {channel_means[2]:.2f}]")
    print(f"Channel stds (R,G,B): [{channel_stds[0]:.2f}, {channel_stds[1]:.2f}, {channel_stds[2]:.2f}]")

    return brightness_contrast_data

def apply_brightness_contrast(input_image_path, brightness_contrast_data):
    """Match an image's per-channel mean and std to reference statistics.

    Each RGB channel with a non-zero standard deviation is standardised and
    then rescaled to the reference channel's std and mean. The result is
    clipped to [0, 255] and written back over ``input_image_path``.

    Args:
        input_image_path: Image file to correct in place.
        brightness_contrast_data: Dict as returned by
            ``extract_brightness_contrast``.
    """
    bgr_input = cv2.imread(input_image_path)
    rgb_input = cv2.cvtColor(bgr_input, cv2.COLOR_BGR2RGB)

    # Reference statistics to match
    goal_brightness = brightness_contrast_data['target_brightness']
    goal_contrast = brightness_contrast_data['target_contrast']
    goal_means = brightness_contrast_data['channel_means']
    goal_stds = brightness_contrast_data['channel_stds']

    # Statistics of the image before correction
    gray_before = cv2.cvtColor(bgr_input, cv2.COLOR_BGR2GRAY)
    brightness_before = np.mean(gray_before)
    contrast_before = np.std(gray_before)

    means_before = [np.mean(rgb_input[:, :, c]) for c in range(3)]
    stds_before = [np.std(rgb_input[:, :, c]) for c in range(3)]

    work = rgb_input.copy().astype(np.float32)

    for c, (mean_before, std_before) in enumerate(zip(means_before, stds_before)):
        if not std_before > 0:
            # Flat channel: nothing to rescale, and it would divide by zero
            continue

        # Standardise (mean 0, std 1); written back so the next step starts
        # from the float32 array
        work[:, :, c] = (work[:, :, c] - mean_before) / std_before

        # Rescale to the reference std and shift to the reference mean
        work[:, :, c] = work[:, :, c] * goal_stds[c] + goal_means[c]

    # Back to valid 8-bit pixel values, then to BGR for saving
    corrected_rgb = np.clip(work, 0, 255).astype(np.uint8)
    corrected_bgr = cv2.cvtColor(corrected_rgb, cv2.COLOR_RGB2BGR)

    # Overwrite the input file
    cv2.imwrite(input_image_path, corrected_bgr)

    # Report before/after statistics
    gray_after = cv2.cvtColor(corrected_bgr, cv2.COLOR_BGR2GRAY)
    brightness_after = np.mean(gray_after)
    contrast_after = np.std(gray_after)

    print(f"\nProcessed: {input_image_path}")
    print(f"Brightness: {brightness_before:.2f} -> {brightness_after:.2f} (Target: {goal_brightness:.2f})")
    print(f"Contrast: {contrast_before:.2f} -> {contrast_after:.2f} (Target: {goal_contrast:.2f})")
    print("✓ Original image replaced with corrected version")

# Example usage:

# Step 1: measure brightness/contrast of the reference image
reference_path = r"C:\Users\otuga\Desktop\test_frames\1.jpg"
bc_data = extract_brightness_contrast(reference_path)

print("\n" + "="*50)

# Step 2: apply the reference statistics to every image in the test folder
# (each file is overwritten in place)
image_extensions = ['*.jpg', '*.jpeg', '*.png']
test_frames_dir = r"C:\Users\otuga\Desktop\test_frames"

image_paths = [
    match
    for pattern in image_extensions
    for match in glob.glob(os.path.join(test_frames_dir, pattern))
]

for img_path in image_paths:
    apply_brightness_contrast(img_path, bc_data)
    print("-" * 30)

print("\n✓ All images have been processed and replaced with corrected versions!")