<a href="https://colab.research.google.com/github/emanueltay/Optimal_Subtitle_Placement/blob/main/caption_placement_preposition_testing_1_apr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
# !unzip /content/optimal_subtitle_copied.zip -d /

In [25]:
# !pip install ultralytics

### Detect Key Objects in the Frame

In [26]:
# ! pip install ultralytics

from ultralytics import YOLO
import cv2
import torch

# Load the trained model
# model = YOLO("/content/optimal_subtitle_copied/best_100.pt")
model = YOLO("/content/optimal_subtitle_copied/best.pt")

# # ✅ Automatically select GPU if available

device = "cuda" if torch.cuda.is_available() else "cpu"

# # ✅ Move model to the selected device
model.to(device).float()  # Use FP32 (FP16 can cause issues on CPU)

# ✅ Optimize PyTorch settings
torch.backends.cudnn.benchmark = True  # Optimize for fixed-size inputs
torch.backends.cudnn.enabled = True
torch.set_num_threads(torch.get_num_threads())  # Use optimal number of CPU threads


def detect_objects(frames):
    """
    Perform batch object detection on multiple frames.

    Parameters:
        frames (list): List of frames (each a NumPy array).

    Returns:
        list: A list containing detections for each frame.
              Each element is a NumPy array of detections.
    """
    # ✅ Run the model in batch mode
    results = model(frames, batch=len(frames), verbose=False)  # Run batch inference 640

    # ✅ Extract detections for each frame
    batch_detections = []
    for result in results:
        detections = result.boxes.data.cpu().numpy()  # Convert detections to NumPy array
        batch_detections.append(detections)

    return batch_detections  # List of detections per frame

### Define Safe Zones for Subtitle Placement

In [27]:
import json

# ✅ Global in-memory cache for dynamic & shifted safe zones
safe_zone_cache = {}
used_safe_zones = {}  # Dictionary to store all assigned safe zones

def calculate_safe_zone_with_prepositions_test_new(frame_width, frame_height, detections, pre_positions, subtitle_height, margin, shift_x=20):
    """
    Calculate the safe zone for subtitle placement using pre-defined positions.
    If blocked, it attempts to shift left/right before moving vertically.
    If no predefined position works, falls back to a dynamic safe zone and caches it in memory.

    Returns:
        tuple: (position_name, coordinates)
    """

    # def zones_overlap(zone1, zone2):
    #     """Checks if two zones overlap."""
    #     x1a, y1a, x2a, y2a = zone1
    #     x1b, y1b, x2b, y2b = zone2
    #     return not (x2a < x1b or x1a > x2b or y2a < y1b or y1a > y2b)

    def zones_overlap(zone1, zone2, threshold=0.1):  # 10% threshold
      """Checks if two zones overlap significantly (IoA)."""
      x1a, y1a, x2a, y2a = zone1
      x1b, y1b, x2b, y2b = zone2

      # Calculate overlap rectangle
      inter_x1 = max(x1a, x1b)
      inter_y1 = max(y1a, y1b)
      inter_x2 = min(x2a, x2b)
      inter_y2 = min(y2a, y2b)

      # Check if there's any intersection
      if inter_x2 <= inter_x1 or inter_y2 <= inter_y1:
          return False

      inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
      zone_area = (x2a - x1a) * (y2a - y1a)

      # Intersection over Area (IoA)
      iou = inter_area / zone_area

      return iou > threshold

    # ✅ Step 1: Try Predefined Positions
    for position_name, position in sorted(pre_positions.items(), key=lambda x: x[1].get("priority", 0), reverse=True):
        x1, y1, x2, y2 = position["coordinates"]

        # if not any(zones_overlap((x1, y1, x2, y2), detection[:4]) for detection in detections):
        if not any(zones_overlap((x1, y1, x2, y2), detection[:4], threshold=0.1) for detection in detections):
            used_safe_zones[position_name] = {
                "coordinates": [x1, y1, x2, y2]
            }
            return (position_name, (x1, y1, x2, y2))

        # ✅ Step 2: Try shifting left and right
        min_width = max(0.6 * frame_width, 600)
        for shift_dir in ["left", "right"]:
            attempts = 0
            shift_value = shift_x
            while attempts < 10:
                if shift_dir == "left":
                    new_x1 = max(0, x1 - shift_value)
                    new_x2 = max(min_width, x2 - shift_value)
                else:
                    new_x1 = min(frame_width - min_width, x1 + shift_value)
                    new_x2 = min(frame_width, x2 + shift_value)

                shifted_zone = (new_x1, y1, new_x2, y2)
                if new_x2 > new_x1 and not any(zones_overlap(shifted_zone, detection[:4]) for detection in detections):
                    shifted_name = f"shifted_{position_name}"
                    used_safe_zones[shifted_name] = {
                        "coordinates": [new_x1, y1, new_x2, y2]
                    }
                    return (shifted_name, shifted_zone)

                shift_value *= 1.5
                attempts += 1

    # ✅ Step 3: Try Cached Safe Zone (if exists)
    cache_key = (frame_width, frame_height, tuple(tuple(d) for d in detections))
    if cache_key in safe_zone_cache:
        return safe_zone_cache[cache_key]

    # ✅ Step 4: Dynamic fallback position (bottom-up shifting)
    fallback_position_name = "dynamic_position"
    proposed = (0, frame_height - subtitle_height - margin, frame_width, frame_height - margin)
    while True:
        if all(not zones_overlap(proposed, detection[:4]) for detection in detections):
            safe_zone_cache[cache_key] = (fallback_position_name, proposed)
            used_safe_zones[fallback_position_name] = {
                "coordinates": list(proposed)
            }
            return (fallback_position_name, proposed)

        # Shift upwards
        x1, y1, x2, y2 = proposed

        new_y1 = y1 - subtitle_height - margin
        new_y2 = y2 - subtitle_height - margin

        if new_y1 < 0:
            break
        proposed = (x1, new_y1, x2, new_y2)

    # ✅ Step 5: Final Fallback to Top
    fallback_position_name = "fallback_top"
    final_safe_zone = (0, margin, frame_width, subtitle_height + margin)
    safe_zone_cache[cache_key] = (fallback_position_name, final_safe_zone)
    used_safe_zones[fallback_position_name] = {
        "coordinates": list(final_safe_zone)
    }
    return (fallback_position_name, final_safe_zone)

def get_used_safe_zones():
    """
    Returns the used safe zones as a dictionary without the "priority" field.
    This can be directly used for updating the TTML layout.
    """
    return {
        key: {"coordinates": value["coordinates"]}
        for key, value in used_safe_zones.items()
    }

### Subtitle size and margin

In [28]:
def get_subtitle_size(frame_height):
    """
    Dynamically calculate subtitle height and margin based on frame resolution.

    Parameters:
        frame_height (int): Height of the video frame.

    Returns:
        tuple: (subtitle_height, margin)
    """
    subtitle_height = max(0.15 * frame_height, 30)  # Minimum 18px for readability
    margin = max(0.02 * frame_height, 5)  # Minimum 5px to avoid text touching edges

    return int(subtitle_height), int(margin)

### Complete Pipeline for frames batch

In [29]:
import json

def get_pixel_pre_positions_from_json(json_path, frame_width, frame_height):
    """
    Reads percentage-based layout from a JSON file and converts to pixel coordinates.

    Args:
        json_path (str): Path to the JSON file containing percentages.
        frame_width (int): Width of the video frame.
        frame_height (int): Height of the video frame.

    Returns:
        dict: Dictionary of region names mapped to pixel coordinates and priority.
    """
    with open(json_path, 'r') as f:
        percentage_data = json.load(f)

    pixel_positions = {}
    for region, data in percentage_data.items():
        x1_pct, y1_pct, x2_pct, y2_pct = data["percentages"]
        pixel_positions[region] = {
            "coordinates": [
                int(x1_pct * frame_width),
                int(y1_pct * frame_height),
                int(x2_pct * frame_width),
                int(y2_pct * frame_height)
            ],
            "priority": data["priority"]
        }

    return pixel_positions

In [30]:
import json
from collections import Counter, deque
import numpy as np
import cv2

safe_zone_history = deque(maxlen=4)  # Stores past safe zones for consistency (4)
region_json_path = "/content/optimal_subtitle_copied/news_video_subtitle_positions.json"

def process_frames_batch_3fps_processed(frames, process_fps=3, video_fps=30):
    """
    Process a batch of frames at 3 FPS:
    - Detects objects in frames sampled at 3 FPS
    - Computes one safe zone for the batch
    - Overlays subtitles using the same safe zone

    Parameters:
        frames (list): List of frames (NumPy arrays).
        subtitles (list): List of subtitles corresponding to each frame.
        video_fps (int): Original FPS of the video.
        process_fps (int): FPS at which YOLO will run.

    Returns:
        list: Processed frames with subtitles.
    """

    # ✅ Step 1: Select Frames at 3 FPS for YOLO Detection
    frame_interval = video_fps // process_fps  # Process every `frame_interval` frames
    selected_indices = list(range(0, len(frames), frame_interval))

    if not selected_indices:  # Prevent empty selection
        selected_indices = [0]  # Process at least one frame

    selected_frames = [frames[i] for i in selected_indices]  # Sampled frames for YOLO

    # ✅ Step 2: Batch Object Detection on Selected Frames
    batch_detections = detect_objects(selected_frames)  # YOLO runs only on sampled frames
    frame_height, frame_width = frames[0].shape[:2]
    # print(frame_height, frame_width)

    pre_positions = get_pixel_pre_positions_from_json(region_json_path, frame_width, frame_height)

    # # ✅ Load Predefined Safe Zones (JSON file loaded once)
    # with open("/content/optimal_subtitle_copied/news_video_subtitle_positions.json", "r") as file:
    #     pre_positions = json.load(file).get(f"{frame_width}x{frame_height}", {})

    # ✅ Step 3: Compute Safe Zone for Each Sampled Frame
    subtitle_height, margin = get_subtitle_size(frame_height)

    # ✅ Step 3: Collect Safe Zone Positions for Each Frame in Batch
    batch_safe_zones = [
        calculate_safe_zone_with_prepositions_test_new(
        #  calculate_safe_zone_with_prepositions_numpy(
            frame_width, frame_height, batch_detections[i], pre_positions, subtitle_height, margin
        )[0]  # ✅ Extract only the position name
        for i in range(len(selected_frames))
    ]

    # ✅ Step 4: Determine the Most Used Safe Zone
    combined_safe_zones = batch_safe_zones + list(safe_zone_history)  # Merge with history
    # combined_safe_zones = batch_safe_zones
    # print(f"✅ Combined Safe Zones: {combined_safe_zones}")
    print(combined_safe_zones)
    zone_counts = Counter(combined_safe_zones)  # Count occurrences
    # zone_counts = Counter(batch_safe_zones)

    # ✅ Assign the most frequently used zone
    if zone_counts:
        most_common_zones = zone_counts.most_common()  # Get all zones sorted by frequency
        highest_frequency = most_common_zones[0][1]  # Find the highest occurrence count

        # ✅ Get all zones with the highest frequency
        top_zones = [zone for zone, count in most_common_zones if count == highest_frequency]

        # ✅ If there's a tie, choose the last used zone from combined_safe_zones
        final_safe_zone = next((zone for zone in reversed(combined_safe_zones) if zone in top_zones), "bottom")
    else:
        final_safe_zone = "bottom"  # ✅ Default fallback
    # print(f"✅ Final Safe Zone: {final_safe_zone}")


    # ✅ Store the final safe zone for future frames
    safe_zone_history.append(final_safe_zone)

    return final_safe_zone

In [31]:
import json
from collections import Counter, deque
import numpy as np
import cv2

safe_zone_history = deque(maxlen=2)
region_json_path = "/content/optimal_subtitle_copied/news_video_subtitle_positions.json"

def process_frames_batch_3fps_processed_test(frames, process_fps=3, video_fps=30):
    frame_interval = video_fps // process_fps
    selected_indices = list(range(0, len(frames), frame_interval))
    if not selected_indices:
        selected_indices = [0]
    selected_frames = [frames[i] for i in selected_indices]

    batch_detections = detect_objects(selected_frames)
    frame_height, frame_width = frames[0].shape[:2]
    pre_positions = get_pixel_pre_positions_from_json(region_json_path, frame_width, frame_height)
    subtitle_height, margin = get_subtitle_size(frame_height)

    batch_safe_zones = [
        calculate_safe_zone_with_prepositions_test_new(
            frame_width, frame_height, batch_detections[i], pre_positions, subtitle_height, margin
        )[0]
        for i in range(len(selected_frames))
    ]

    # ✅ Detect sudden change in safe zones
    if len(set(batch_safe_zones)) > 1:
        final_safe_zone = batch_safe_zones[-1]  # Use most recent zone
    else:
        final_safe_zone = batch_safe_zones[0]  # All zones consistent

    safe_zone_history.append(final_safe_zone)
    return final_safe_zone

In [32]:
import xml.etree.ElementTree as ET

def print_ttml_with_updated_regions(ttml_file_path, subtitle_data):
    """
    Prints the TTML <p> elements with updated regions, removing 'region' if it's None.

    Parameters:
        ttml_file_path (str): Path to the TTML file.
        subtitle_data (list): List of subtitles in the format:
            [{"start": start_time, "end": end_time, "text": text, "region": "region_id"}]
    """

    # ✅ Load TTML File
    tree = ET.parse(ttml_file_path)
    root = tree.getroot()
    ns = {'ttml': 'http://www.w3.org/ns/ttml'}

    # ✅ Find All <p> Elements (Subtitles) and Update Regions
    for p in root.findall('.//ttml:p', ns):
        start_time = convert_ttml_time_to_seconds(p.attrib.get("begin", "0.0s"))
        end_time = convert_ttml_time_to_seconds(p.attrib.get("end", "0.0s"))

        # ✅ Find Matching Subtitle
        matched_subtitle = next((sub for sub in subtitle_data if sub["start"] <= start_time <= sub["end"]), None)

        if matched_subtitle:
            if matched_subtitle["region"] is not None:
                p.attrib["region"] = matched_subtitle["region"]  # ✅ Assign Correct Region
            elif "region" in p.attrib:
                del p.attrib["region"]  # ✅ Remove `region` if it's None

    # ✅ Print Updated TTML Content
    updated_ttml = ET.tostring(root, encoding="utf-8").decode("utf-8")
    print(updated_ttml)  # ✅ Print instead of writing to a file

In [33]:
import xml.etree.ElementTree as ET
import json
import math

def generate_updated_ttml(ttml_file_path, output_ttml_path, json_data, subtitle_data, frame_width, frame_height):
    """
    Generates a new TTML file with updated subtitle styles, layout regions, and assigned regions for subtitles.

    Parameters:
        ttml_file_path (str): Path to the input TTML file.
        output_ttml_path (str): Path to save the updated TTML file.
        json_data (dict): JSON data containing subtitle positions.
        subtitle_data (list): List of subtitles with timestamps and regions.
        frame_width (int): Width of the video frame.
        frame_height (int): Height of the video frame.

    Returns:
        None (Writes updated TTML file to disk)
    """

    # ✅ Load TTML File
    tree = ET.parse(ttml_file_path)
    root = tree.getroot()

    # # ✅ Define Namespace for TTML
    # ns = {'ttml': 'http://www.w3.org/ns/ttml'}
    # ET.register_namespace("", ns["ttml"])

    # # ✅ Find or Create the <head> Element
    # head_element = root.find('.//ttml:head', ns)
    # if head_element is None:
    #     head_element = ET.Element("{http://www.w3.org/ns/ttml}head")
    #     root.insert(0, head_element)  # Insert <head> at the top

    # ✅ Load TTML File
    tree = ET.parse(ttml_file_path)
    root = tree.getroot()

    # ✅ Preserve All Original Root Attributes (Ensuring All Namespaces Remain)
    root_attribs = root.attrib.copy()  # Copy attributes before modification

    # ✅ Extract Namespace (from <tt> root tag)
    namespace_uri = root.tag.split("}")[0].strip("{")  # Extracts URI from "{namespace}tag"
    ns = {"ttml": namespace_uri} if namespace_uri else {}

    # ✅ Restore All Root Attributes (Explicitly Add Missing Namespaces)
    root.attrib.clear()
    root.attrib.update(root_attribs)  # ✅ Restore original attributes

    # ✅ Ensure `xmlns:tts` is Explicitly Set (if missing)
    if "xmlns:tts" not in root.attrib:
        root.set("xmlns:tts", "http://www.w3.org/ns/ttml#styling")  # ✅ Add missing styling namespace

    # ✅ Find or Create the <head> Element (Using Preserved Namespace)
    head_element = root.find(f'.//{{{namespace_uri}}}head', ns)
    if head_element is None:
        head_element = ET.Element(f"{{{namespace_uri}}}head")
        root.insert(0, head_element)  # Insert <head> as the first child

    # ✅ Find or Create the <styling> Element
    styling_element = head_element.find('.//ttml:styling', ns)
    if styling_element is None:
        styling_element = ET.Element("{http://www.w3.org/ns/ttml}styling")
        head_element.insert(0, styling_element)  # Insert before layout

    # ✅ Remove Any Existing <style> Elements (Always Replacing)
    for style in styling_element.findall('.//ttml:style', ns):
        styling_element.remove(style)

    # ✅ Define and Add the New Style Element
    new_style = ET.Element("{http://www.w3.org/ns/ttml}style", attrib={
        "xml:id": "s0",
        "tts:color": "white",
        "tts:fontSize": "70%",
        "tts:fontFamily": "sansSerif",
        "tts:backgroundColor": "black",
        "tts:displayAlign": "center",
        "tts:wrapOption": "wrap"
    })
    styling_element.append(new_style)

    # ✅ Find or Create the <layout> Element
    layout_element = head_element.find('.//ttml:layout', ns)
    if layout_element is None:
        layout_element = ET.Element("{http://www.w3.org/ns/ttml}layout")
        head_element.append(layout_element)

    # ✅ Remove ALL existing <region> elements inside <layout>
    for region in list(layout_element):
        layout_element.remove(region)

    # ✅ Insert Subtitle Regions from JSON
    for region_name, region_data in json_data.items():
        x1, y1, x2, y2 = region_data["coordinates"]

        print(frame_height,frame_width)

        # Convert absolute pixel values to TTML percentages
        origin_x = (x1 / frame_width) * 100
        origin_y = (y1 / frame_height) * 100
        extent_x = ((x2 - x1) / frame_width) * 100
        extent_y = ((y2 - y1) / frame_height) * 100

        # Construct the region XML element
        region_element = ET.Element("{http://www.w3.org/ns/ttml}region", attrib={
            "tts:origin": f"{math.ceil(origin_x)}% {math.ceil(origin_y)}%",
            "tts:extent": f"{math.ceil(extent_x)}% {math.ceil(extent_y)}%",
            "tts:displayAlign": "center",
            "tts:textAlign": "center",
            "xml:id": region_name
        })

        # Add to <layout>
        layout_element.append(region_element)

    # ✅ Find All <p> Elements (Subtitles) and Update Regions
    for p in root.findall('.//ttml:p', ns):
        start_time = convert_ttml_time_to_seconds(p.attrib.get("begin", "0.0s"))
        end_time = convert_ttml_time_to_seconds(p.attrib.get("end", "0.0s"))

        # ✅ Find Matching Subtitle
        matched_subtitle = next((sub for sub in subtitle_data if sub["start"] <= start_time <= sub["end"]), None)

        if matched_subtitle:
            if matched_subtitle["region"] is not None:
                p.attrib["region"] = matched_subtitle["region"]  # ✅ Assign Correct Region
            elif "region" in p.attrib:
                del p.attrib["region"]  # ✅ Remove `region` if it's None

    # ✅ Save Updated TTML File
    # tree.write(output_ttml_path, encoding="utf-8", xml_declaration=True)
    # print(f"✅ Updated TTML file saved: {output_ttml_path}")
    updated_ttml = ET.tostring(root, encoding="utf-8").decode("utf-8")
    print(updated_ttml)

## Testing the Code

### Integrate with srt file and video fps

In [34]:
# !pip install pysrt

In [35]:
def get_subtitles_for_frames(frame_times, subtitle_data):
    """
    Retrieves subtitle texts for a batch of frame timestamps.

    Parameters:
        frame_times (list): List of timestamps (in seconds).
        subtitle_data (list): List of subtitles in the format:
            [{"start": start_time, "end": end_time, "text": text, "region": region}, ...]

    Returns:
        list: List of subtitle texts corresponding to each frame timestamp.
    """
    frame_subtitles = []

    for time in frame_times:
        subtitle_text = ""  # Default to empty string

        for subtitle in subtitle_data:
            if subtitle["start"] <= time <= subtitle["end"]:
                subtitle_text = subtitle["text"].replace("\n", " ")  # Remove newlines
                break  # Stop once we find a match

        frame_subtitles.append(subtitle_text)

    return frame_subtitles

In [36]:
import subprocess

def get_video_fps(video_path):
    """Extracts FPS from a video using FFmpeg."""
    cmd = ["ffmpeg", "-i", video_path]

    # ✅ Use stdout and stderr explicitly
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

    # ✅ Parse FPS from FFmpeg output
    for line in result.stderr.split("\n"):
        if "Stream" in line and "Video" in line and "fps" in line:
            fps_value = float(line.split("fps")[0].strip().split()[-1])  # Extract FPS
            return fps_value

    return 30  # Default to 30 FPS if not found

In [37]:
## combined srt and ttml timestamp

import pysrt
import xml.etree.ElementTree as ET
import re

def get_subtitle_timestamps(subtitle_file, file_type="auto"):
    """
    Extracts subtitle timestamps from an SRT or TTML file.

    Parameters:
        subtitle_file (str): Path to the subtitle file.
        file_type (str): "srt" for SRT, "ttml" for TTML, or "auto" to detect from extension.

    Returns:
        list of tuples: Each tuple contains (start_time, end_time) in seconds.
    """
    # Auto-detect file type
    if file_type == "auto":
        if subtitle_file.endswith(".srt"):
            file_type = "srt"
        elif subtitle_file.endswith(".ttml") or subtitle_file.endswith(".xml"):
            file_type = "ttml"
        else:
            raise ValueError("Unsupported subtitle file format. Use 'srt' or 'ttml'.")

    # Process SRT
    if file_type == "srt":
        return get_srt_timestamps(subtitle_file)

    # Process TTML
    elif file_type == "ttml":
        return get_ttml_timestamps(subtitle_file)

    else:
        raise ValueError("Invalid file type specified. Use 'srt' or 'ttml'.")

def get_srt_timestamps(srt_file):
    """Extracts subtitle timestamps from an SRT file."""
    subs = pysrt.open(srt_file)
    subtitle_timestamps = []

    for sub in subs:
        start_time = sub.start.hours * 3600 + sub.start.minutes * 60 + sub.start.seconds + sub.start.milliseconds / 1000
        end_time = sub.end.hours * 3600 + sub.end.minutes * 60 + sub.end.seconds + sub.end.milliseconds / 1000
        subtitle_timestamps.append((start_time, end_time))

    return subtitle_timestamps

def get_ttml_timestamps(ttml_file):
    """Extracts subtitle timestamps from a TTML file."""
    tree = ET.parse(ttml_file)
    root = tree.getroot()
    ns = {'ttml': 'http://www.w3.org/ns/ttml'}

    subtitle_timestamps = []

    for p in root.findall('.//ttml:p', ns):
        start_time = p.attrib.get("begin", "0.0s")
        end_time = p.attrib.get("end", "0.0s")

        start_seconds = convert_ttml_time_to_seconds(start_time)
        end_seconds = convert_ttml_time_to_seconds(end_time)

        subtitle_timestamps.append((start_seconds, end_seconds))

    return subtitle_timestamps

# def convert_ttml_time_to_seconds(ttml_time):
#     """
#     Converts TTML time format (HH:MM:SS.mmm or MM:SS.mmm or SS.mmm or SS.mmm's') to seconds.

#     Parameters:
#         ttml_time (str): TTML-formatted time.

#     Returns:
#         float: Time in seconds.
#     """
#     ttml_time = ttml_time.rstrip('s')  # Remove trailing 's' if present
#     parts = ttml_time.split(":")

#     if len(parts) == 3:  # HH:MM:SS.mmm
#         hours, minutes, seconds = map(float, parts)
#     elif len(parts) == 2:  # MM:SS.mmm
#         hours, minutes, seconds = 0, *map(float, parts)
#     else:  # SS.mmm
#         hours, minutes, seconds = 0, 0, float(parts[0])

#     return hours * 3600 + minutes * 60 + seconds

def convert_ttml_time_to_seconds(ttml_time):
    """
    Converts TTML time format (HH:MM:SS.mmm, MM:SS.mmm, SS.mmm, or SS,mmm) to seconds.

    Parameters:
        ttml_time (str): TTML-formatted time.

    Returns:
        float: Time in seconds (with millisecond precision).
    """

    # ✅ Remove trailing 's' if present and replace ',' with '.'
    ttml_time = ttml_time.rstrip('s').replace(',', '.')

    # ✅ Use regex to extract time components
    match = re.match(r"(?:(\d+):)?(?:(\d+):)?(\d+)(?:\.(\d+))?", ttml_time)

    if not match:
        raise ValueError(f"Invalid TTML time format: {ttml_time}")

    # ✅ Extract components safely
    hours = int(match.group(1)) if match.group(1) else 0
    minutes = int(match.group(2)) if match.group(2) else 0
    seconds = int(match.group(3)) if match.group(3) else 0
    milliseconds = int(match.group(4)) if match.group(4) else 0

    return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000.0

In [38]:
import xml.etree.ElementTree as ET
import pysrt
import os

def parse_subtitle_file(file_path):
    """
    Parses either an SRT or TTML subtitle file and extracts subtitles.

    Parameters:
        file_path (str): Path to the subtitle file.

    Returns:
        list: List of subtitles in the format:
            [
                {"start": start_time, "end": end_time, "text": "subtitle text", "region": "region_id"}
            ]
    """
    extension = os.path.splitext(file_path)[-1].lower()
    subtitle_data = []

    if extension == ".srt":
        subs = pysrt.open(file_path)
        for sub in subs:
            start_time = (
                sub.start.hours * 3600 + sub.start.minutes * 60 + sub.start.seconds + sub.start.milliseconds / 1000
            )
            end_time = (
                sub.end.hours * 3600 + sub.end.minutes * 60 + sub.end.seconds + sub.end.milliseconds / 1000
            )
            text = sub.text.replace("\n", " ")  # Convert newlines to spaces

            subtitle_data.append({
                "start": start_time,
                "end": end_time,
                "text": text,
                "region": None  # SRT doesn't support regions
            })

    elif extension == ".ttml":
        # ✅ Register TTML Namespaces
        ET.register_namespace('', "http://www.w3.org/ns/ttml")  # Default TTML namespace
        ET.register_namespace('ttp', "http://www.w3.org/ns/ttml#parameter")
        ET.register_namespace('tts', "http://www.w3.org/ns/ttml#styling")
        ET.register_namespace('ttm', "http://www.w3.org/ns/ttml#metadata")

        # ✅ Parse TTML File
        tree = ET.parse(file_path)
        root = tree.getroot()
        ns = {'ttml': 'http://www.w3.org/ns/ttml'}

        # ✅ Extract Subtitle Data
        for p in root.findall('.//ttml:p', ns):
            start_time = convert_ttml_time_to_seconds(p.attrib.get("begin", "0.0s"))
            end_time = convert_ttml_time_to_seconds(p.attrib.get("end", "0.0s"))
            text = " ".join(p.itertext()).strip()
            region = p.attrib.get("region", None)

            subtitle_data.append({
                "start": start_time,
                "end": end_time,
                "text": text,
                "region": region
            })

    else:
        raise ValueError("Unsupported subtitle format. Only SRT and TTML are supported.")

    return subtitle_data

In [42]:
import os
import cv2
import time

# ✅ Optional Resize Parameter (None = no resize)
# resize_resolution = (640, 360)  # Example: downscale to 640x360; set to None to disable resizing
# resize_resolution = (1280, 720)
# resize_resolution = (854, 480)
resize_resolution = None

# ✅ Start Load Timer
start_load_time = time.time()

# ✅ Define file paths
video_input_path = "/content/optimal_subtitle_copied/ABC_News.mp4"
file_path = "/content/optimal_subtitle_copied/TTML_file/ABC_News.ttml"
output_path = "/content/output.ttml"

# ✅ Load Video Metadata
fps = get_video_fps(video_input_path)
print(f"✅ Corrected FPS: {fps}")

subtitle_data = parse_subtitle_file(file_path)
subtitle_timestamps = get_subtitle_timestamps(file_path)

cap = cv2.VideoCapture(video_input_path)

# ✅ Original Resolution for TTML generation
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
print(f"🎞 Frame Dimensions: {frame_width}x{frame_height}")

total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
video_duration = total_frames / fps

# ✅ End Load Timer
end_load_time = time.time()
load_duration = end_load_time - start_load_time
print(f"📦 Total Load Time: {load_duration:.2f} seconds")

# ✅ Start Run Timer
start_run_time = time.time()

frame_buffer = []
timestamp_buffer = []
subtitle_index = 0
frame_number = 0
total_video_read_time = 0
total_yolo_time = 0
total_region_assign_time = 0

while cap.isOpened():
    read_start = time.time()
    ret, frame = cap.read()
    read_end = time.time()

    total_video_read_time += (read_end - read_start)
    if not ret:
        break

    # ✅ Resize frame for faster processing
    if resize_resolution:
        frame = cv2.resize(frame, resize_resolution)
        frame_width, frame_height = resize_resolution

    frame_time = frame_number / fps

    if subtitle_index < len(subtitle_data):
        current_subtitle = subtitle_data[subtitle_index]
        current_start = current_subtitle["start"]
        current_end = current_subtitle["end"]

        if current_start <= frame_time <= current_end:
            frame_buffer.append(frame)
            timestamp_buffer.append(frame_time)

        if frame_time > current_end:
            if frame_buffer:
                subtitles = [current_subtitle]

                detect_start = time.time()
                processed_frames = process_frames_batch_3fps_processed_test(frame_buffer)
                detect_end = time.time()
                total_yolo_time += (detect_end - detect_start)

                region_start = time.time()
                current_subtitle["region"] = processed_frames
                region_end = time.time()
                total_region_assign_time += (region_end - region_start)

                frame_buffer.clear()
                timestamp_buffer.clear()
                subtitle_index += 1

    frame_number += 1

# ✅ Final batch
if frame_buffer and subtitle_index < len(subtitle_data):
    subtitles = [subtitle_data[subtitle_index]]
    detect_start = time.time()
    processed_frames = process_frames_batch_3fps_processed_test(frame_buffer)
    detect_end = time.time()
    total_yolo_time += (detect_end - detect_start)
    subtitle_data[subtitle_index]["region"] = processed_frames

cap.release()

# ✅ End Run Timer
end_run_time = time.time()
run_duration = end_run_time - start_run_time
run_minutes, run_seconds = divmod(run_duration, 60)

# ✅ Generate TTML Layout using original resolution
ttml_gen_start = time.time()
layout = get_used_safe_zones()
generate_updated_ttml(file_path, output_path, layout, subtitle_data, frame_width, frame_height)
ttml_gen_end = time.time()
ttml_generation_time = ttml_gen_end - ttml_gen_start

minutes, seconds = divmod(video_duration, 60)

# ✅ Final Timing Summary
print("\n📊 PROFILING SUMMARY")
print(f"🎬 Video Duration: {int(minutes)}m {int(seconds)}s")
print(f"📦 Load Time: {load_duration:.2f}s")
print(f"🚀 Run Time: {int(run_minutes)}m {int(run_seconds)}s")
print(f"📥 Video Read Time: {total_video_read_time:.2f}s")
print(f"🔍 YOLO Detection Time: {total_yolo_time:.2f}s")
print(f"📐 Region Assignment Time: {total_region_assign_time:.2f}s")
print(f"📝 TTML Generation Time: {ttml_generation_time:.2f}s")
print(f"✅ Output TTML saved to: {output_path}")

✅ Corrected FPS: 29.97
🎞 Frame Dimensions: 1280x720
📦 Total Load Time: 0.12 seconds
720 1280
720 1280
720 1280
720 1280
720 1280
720 1280
720 1280
720 1280
720 1280
720 1280
720 1280
720 1280
720 1280
720 1280
720 1280
720 1280
720 1280
<tt xmlns="http://www.w3.org/ns/ttml" xmlns:ttm="http://www.w3.org/ns/ttml#metadata" xmlns:ttp="http://www.w3.org/ns/ttml#parameter" ttp:timeBase="media" xml:lang="en" xmlns:tts="http://www.w3.org/ns/ttml#styling">
  <head>
    <metadata>
      <ttm:title />
    </metadata>
    <styling>
      <style xml:id="s0" tts:color="white" tts:fontSize="70%" tts:fontFamily="sansSerif" tts:backgroundColor="black" tts:displayAlign="center" tts:wrapOption="wrap" /></styling>
    <layout>
      <region tts:origin="10% 70%" tts:extent="80% 16%" tts:displayAlign="center" tts:textAlign="center" xml:id="middle3" /><region tts:origin="10% 56%" tts:extent="80% 15%" tts:displayAlign="center" tts:textAlign="center" xml:id="middle2" /><region tts:origin="28% 70%" tts:extent="

In [41]:
## CV2

import os
import cv2
import time
import numpy as np

# ✅ Optional Resize Parameter for YOLO detection (None = no resize)
resize_resolution = (854, 480)  # or None
# resize_resolution = (640, 360)
stream_fps = 5  # Optional frame skipping (not implemented in this script)

# ✅ Define file paths
video_input_path = "/content/optimal_subtitle_copied/ABC_News.mp4"
file_path = "/content/optimal_subtitle_copied/TTML_file/ABC_News.ttml"
output_path = "/content/output.ttml"

# ⏳ Load metadata
start_load_time = time.time()
fps = get_video_fps(video_input_path)
subtitle_data = parse_subtitle_file(file_path)
subtitle_timestamps = get_subtitle_timestamps(file_path)
cap = cv2.VideoCapture(video_input_path)

# ⏳ Get original resolution for TTML output
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
video_duration = total_frames / fps
end_load_time = time.time()

# ⏱ Profiling stats
load_duration = end_load_time - start_load_time
total_video_read_time = 0
total_yolo_time = 0
total_region_assign_time = 0

# ✅ Use resized resolution for detection if specified
if resize_resolution:
    frame_width, frame_height = resize_resolution
    print(f"🔄 Resized Frame Dimensions for Detection: {frame_width}x{frame_height}")
else:
    print(f"🖼 Using Original Resolution for Detection: {frame_width}x{frame_height}")


# print(f"🎞 Frame Dimensions: {frame_width}x{frame_height}")
print(f"📦 Total Load Time: {load_duration:.2f} seconds")

# ⏱ Start processing
start_run_time = time.time()

frame_buffer = []
subtitle_index = 0
frame_number = 0
skip_interval = int(round(fps / stream_fps)) if stream_fps < fps else 1

while cap.isOpened():
    read_start = time.time()
    ret, frame = cap.read()
    read_end = time.time()
    total_video_read_time += (read_end - read_start)

    if not ret:
        break

    # Skip frames not in interval
    if frame_number % skip_interval != 0:
        frame_number += 1
        continue

    detection_frame = cv2.resize(frame, resize_resolution) if resize_resolution else frame
    frame_time = frame_number / fps

    if subtitle_index < len(subtitle_data):
        sub = subtitle_data[subtitle_index]
        if sub["start"] <= frame_time <= sub["end"]:
            frame_buffer.append(detection_frame)

        if frame_time > sub["end"]:
            if frame_buffer:
                detect_start = time.time()
                region = process_frames_batch_3fps_processed_test(frame_buffer)
                detect_end = time.time()
                total_yolo_time += (detect_end - detect_start)

                assign_start = time.time()
                subtitle_data[subtitle_index]["region"] = region
                assign_end = time.time()
                total_region_assign_time += (assign_end - assign_start)

            frame_buffer.clear()
            subtitle_index += 1

    frame_number += 1

# ⏳ Final frame batch
if frame_buffer and subtitle_index < len(subtitle_data):
    detect_start = time.time()
    region = process_frames_batch_3fps_processed_test(frame_buffer)
    detect_end = time.time()
    total_yolo_time += (detect_end - detect_start)
    subtitle_data[subtitle_index]["region"] = region

cap.release()
end_run_time = time.time()

# ✅ TTML Generation
ttml_gen_start = time.time()
layout = get_used_safe_zones()
generate_updated_ttml(file_path, output_path, layout, subtitle_data, frame_width, frame_height)
ttml_gen_end = time.time()
ttml_generation_time = ttml_gen_end - ttml_gen_start

# ✅ Summary
run_minutes, run_seconds = divmod(end_run_time - start_run_time, 60)
minutes, seconds = divmod(video_duration, 60)

print("\n📊 PROFILING SUMMARY")
print(f"🎬 Video Duration: {int(minutes)}m {int(seconds)}s")
print(f"📦 Load Time: {load_duration:.2f}s")
print(f"🚀 Run Time: {int(run_minutes)}m {int(run_seconds)}s")
print(f"📥 Video Read Time: {total_video_read_time:.2f}s")
print(f"🔍 YOLO Detection Time: {total_yolo_time:.2f}s")
print(f"📐 Region Assignment Time: {total_region_assign_time:.5f}s")
print(f"📝 TTML Generation Time: {ttml_generation_time:.5f}s")
print(f"✅ Output TTML saved to: {output_path}")

🔄 Resized Frame Dimensions for Detection: 854x480
📦 Total Load Time: 0.12 seconds
480 854
480 854
480 854
480 854
480 854
480 854
480 854
480 854
480 854
480 854
480 854
480 854
480 854
<tt xmlns="http://www.w3.org/ns/ttml" xmlns:ttm="http://www.w3.org/ns/ttml#metadata" xmlns:ttp="http://www.w3.org/ns/ttml#parameter" ttp:timeBase="media" xml:lang="en" xmlns:tts="http://www.w3.org/ns/ttml#styling">
  <head>
    <metadata>
      <ttm:title />
    </metadata>
    <styling>
      <style xml:id="s0" tts:color="white" tts:fontSize="70%" tts:fontFamily="sansSerif" tts:backgroundColor="black" tts:displayAlign="center" tts:wrapOption="wrap" /></styling>
    <layout>
      <region tts:origin="10% 70%" tts:extent="80% 15%" tts:displayAlign="center" tts:textAlign="center" xml:id="middle3" /><region tts:origin="10% 56%" tts:extent="80% 15%" tts:displayAlign="center" tts:textAlign="center" xml:id="middle2" /><region tts:origin="22% 70%" tts:extent="79% 15%" tts:displayAlign="center" tts:textAlign="c

In [20]:
## ffmpeg

import os
import cv2
import time
import subprocess
import numpy as np

# ✅ Optional Resize and FPS Reduction Parameters
resize_resolution = (854, 480)  # (width, height)
stream_fps = 5

# ✅ Define file paths
video_input_path = "/content/optimal_subtitle_copied/test_video_4.mp4"
file_path = "/content/optimal_subtitle_copied/TTML_file/test_video_4_eng.ttml"
output_path = "/content/output.ttml"

# ✅ Start Load Timer
start_load_time = time.time()

fps = get_video_fps(video_input_path)
subtitle_data = parse_subtitle_file(file_path)
subtitle_timestamps = get_subtitle_timestamps(file_path)

# ✅ Get original resolution
cap = cv2.VideoCapture(video_input_path)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
video_duration = total_frames / fps
cap.release()

end_load_time = time.time()
load_duration = end_load_time - start_load_time

# ✅ Start FFmpeg Setup + Conversion Timer
ffmpeg_setup_start = time.time()

width, height = resize_resolution
frame_size = width * height * 3
ffmpeg_cmd = [
    'ffmpeg',
    '-i', video_input_path,
    '-vf', f'fps={stream_fps},scale={width}:{height}',
    '-f', 'rawvideo',
    '-pix_fmt', 'bgr24',
    '-loglevel', 'quiet',
    '-'
]
pipe = subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, bufsize=10**8)

ffmpeg_conversion_start = time.time()  # ✅ Conversion begins (after setup)
ffmpeg_setup_time = ffmpeg_conversion_start - ffmpeg_setup_start

# ✅ Start Frame Processing Timer
processing_start = time.time()

frame_buffer = []
subtitle_index = 0
frame_number = 0
frame_time = 0

total_video_read_time = 0
total_yolo_time = 0
total_region_assign_time = 0

while True:
    read_start = time.time()
    raw_frame = pipe.stdout.read(frame_size)
    read_end = time.time()

    total_video_read_time += (read_end - read_start)
    if not raw_frame:
        break

    frame = np.frombuffer(raw_frame, dtype=np.uint8).reshape((height, width, 3))
    frame_time = frame_number * (1 / stream_fps)
    original_frame_index = int(frame_time * fps)
    frame_time = original_frame_index / fps

    if subtitle_index < len(subtitle_data):
        current_subtitle = subtitle_data[subtitle_index]
        current_start = current_subtitle["start"]
        current_end = current_subtitle["end"]

        if current_start <= frame_time <= current_end:
            frame_buffer.append(frame)

        if frame_time > current_end:
            if frame_buffer:
                detect_start = time.time()
                processed_frames = process_frames_batch_3fps_processed_test(
                    frame_buffer, process_fps=3, video_fps=stream_fps)
                detect_end = time.time()
                total_yolo_time += (detect_end - detect_start)

                region_start = time.time()
                current_subtitle["region"] = processed_frames
                region_end = time.time()
                total_region_assign_time += (region_end - region_start)

                frame_buffer.clear()
                subtitle_index += 1
            else:
                subtitle_index += 1

    frame_number += 1

# ✅ Final batch
if frame_buffer and subtitle_index < len(subtitle_data):
    detect_start = time.time()
    processed_frames = process_frames_batch_3fps_processed_test(frame_buffer, process_fps=3, video_fps=stream_fps)
    detect_end = time.time()
    total_yolo_time += (detect_end - detect_start)
    subtitle_data[subtitle_index]["region"] = processed_frames

pipe.stdout.close()
pipe.wait()

ffmpeg_conversion_end = time.time()
ffmpeg_conversion_duration = ffmpeg_conversion_end - ffmpeg_conversion_start

processing_end = time.time()
processing_duration = processing_end - processing_start

# ✅ TTML Generation
ttml_gen_start = time.time()
layout = get_used_safe_zones()
generate_updated_ttml(file_path, output_path, layout, subtitle_data, width, height)
ttml_gen_end = time.time()
ttml_generation_time = ttml_gen_end - ttml_gen_start

minutes, seconds = divmod(video_duration, 60)

# ✅ Profiling Summary
print("\n📊 PROFILING SUMMARY")
print(f"🎬 Video Duration: {int(minutes)}m {int(seconds)}s")
print(f"📦 Load Time: {load_duration:.2f}s")
print(f"⚙️ FFmpeg Setup Time: {ffmpeg_setup_time:.2f}s")
print(f"🎞 FFmpeg Conversion Time: {ffmpeg_conversion_duration:.2f}s")
print(f"🚀 Frame Processing Time: {processing_duration:.2f}s")
print(f"📥 Video Read Time: {total_video_read_time:.2f}s")
print(f"🔍 YOLO Detection Time: {total_yolo_time:.2f}s")
print(f"📐 Region Assignment Time: {total_region_assign_time:.4f}s")
print(f"📝 TTML Generation Time: {ttml_generation_time:.4f}s")
print(f"✅ Output TTML saved to: {output_path}")

480 854
480 854
480 854
480 854
<tt xmlns="http://www.w3.org/ns/ttml" xmlns:ttm="http://www.w3.org/ns/ttml#metadata" xmlns:ttp="http://www.w3.org/ns/ttml#parameter" ttp:timeBase="media" xml:lang="en" xmlns:tts="http://www.w3.org/ns/ttml#styling">
  <head>
    <metadata>
      <ttm:title />
    </metadata>
    <styling>
      <style xml:id="s0" tts:color="white" tts:fontSize="70%" tts:fontFamily="sansSerif" tts:backgroundColor="black" tts:displayAlign="center" tts:wrapOption="wrap" /></styling>
    <layout>
      <region tts:origin="23% 203%" tts:extent="180% 23%" tts:displayAlign="center" tts:textAlign="center" xml:id="bottom" /><region tts:origin="10% 70%" tts:extent="80% 15%" tts:displayAlign="center" tts:textAlign="center" xml:id="middle3" /><region tts:origin="23% 147%" tts:extent="180% 34%" tts:displayAlign="center" tts:textAlign="center" xml:id="between_m3_m2_1" /><region tts:origin="0% 70%" tts:extent="71% 15%" tts:displayAlign="center" tts:textAlign="center" xml:id="shifted_mid

In [None]:
!zip -r /content/optimal_subtitle_copied.zip /content/optimal_subtitle_copied

  adding: content/optimal_subtitle_copied/ (stored 0%)
  adding: content/optimal_subtitle_copied/test_video_4.mp4 (deflated 1%)
  adding: content/optimal_subtitle_copied/subtitle_pre_positions(updated).json (deflated 85%)
  adding: content/optimal_subtitle_copied/subtitle_regions_scaled_test.json (deflated 48%)
  adding: content/optimal_subtitle_copied/news_video_subtitle_positions.json (deflated 74%)
  adding: content/optimal_subtitle_copied/CBS_news.mp4 (deflated 0%)
  adding: content/optimal_subtitle_copied/best.pt (deflated 11%)
  adding: content/optimal_subtitle_copied/test_video_2.mp4 (deflated 2%)
  adding: content/optimal_subtitle_copied/test_video_3.mp4 (deflated 0%)
  adding: content/optimal_subtitle_copied/TTML_file/ (stored 0%)
  adding: content/optimal_subtitle_copied/TTML_file/test_long_video_17_eng.ttml (deflated 75%)
  adding: content/optimal_subtitle_copied/TTML_file/NBC_news.ttml (deflated 74%)
  adding: content/optimal_subtitle_copied/TTML_file/CBS_News.ttml (deflate