<a href="https://colab.research.google.com/github/bolinocroustibat/movies-palettes/blob/main/movies_palettes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [None]:
import numpy as np
import cv2
import glob
import json
import os
import re
import time
from datetime import datetime, timezone
from multiprocessing import Pool, cpu_count
from pathlib import Path, PosixPath
from PIL import Image
from skimage import feature, color
from sklearn.cluster import KMeans
from tqdm.notebook import tqdm

# Configuration constants

*   **`RESIZE_W` and `RESIZE_H` (Reduced Frame Size)**

  This determines the resolution of each frame after downsampling. A reasonable value should:
    - Retain enough detail for meaningful clustering.
    - Avoid the computational overhead of processing full-size frames.

  Suggested Value: 160 x 90 (W x H)
  - This keeps the aspect ratio of typical widescreen content (16:9).
  -	At 160x90, you get 14,400 pixels per frame, which is sufficient for clustering dominant colors while ensuring faster processing.

*  **`FRAME_SKIP` (Frames Skipped)**

    This determines how many frames you skip between processed frames.
  -	Movies at 24fps typically don’t have significant color changes frame-by-frame.
  -	A higher skip value reduces computation but might miss some rapid scene changes.

  Suggested Value: `FRAME_SKIP = 60`
 	- This processes 1 frame per 2.5 seconds (approx.), enough to capture major color changes without excessive redundancy.
 	-	For a 90-minute movie:
 	   - Total frames at 24fps:  90 x 60 x 24 = 129,600
 	   - With `FRAME_SKIP = 60`, you’ll process around  129,600 / 60 = 2,160  frames per movie.

*  **`BATCH_SIZE` (Frames Processed at Once)**

  This determines how many sampled frames are read (decoded) before they are clustered.
 	-	The frames of a batch are still clustered one at a time, so the batch size mainly bounds how many decoded frames are held in memory at once.
 	-	Too large a batch size might therefore lead to memory limitations.

  Suggested Value: `BATCH_SIZE = 20`
 	- Reads 20 frames at a time, keeping memory use low.
 	- The value is also recorded in each palette's metadata.

In [None]:
# Number of k-means clusters: colors kept per frame and in the final movie palette.
CLUSTERS_NB: int = 10
# Only one frame out of every FRAME_SKIP frames is processed, to speed up the analysis.
FRAME_SKIP: int = 60
# Each sampled frame is downscaled to RESIZE_W x RESIZE_H pixels to reduce complexity.
RESIZE_W, RESIZE_H = 160, 90
# How many sampled frames are read per batch.
BATCH_SIZE: int = 20

# Mount Google Drive

In [None]:
# Make the user's Google Drive (and therefore the movie folders) reachable
# under /content/drive inside the Colab runtime.
from google.colab import drive

drive.mount("/content/drive")

## Google Drive movies path and movies colors list file path

> **WARNING**: Don't forget to mount Google Drive first

In [None]:
# Root folder holding one sub-folder per movie, and the JSON analysis file stored inside it.
MOVIES_PATH = Path("/content/drive/MyDrive/MOVIES")
FILE_PATH = MOVIES_PATH / "movies_palettes.json"

# Function to save movies colors list in a JSON file on Google Drive

In [None]:
def save_as_file(data: list, file_path: Path) -> None:
    """Save *data* as a tab-indented UTF-8 JSON file (e.g. on Google Drive).

    Non-ASCII characters (accented titles, names...) are written as-is rather
    than escaped. The data is serialized BEFORE the file is opened, so a
    serialization error (e.g. a non-JSON value in *data*) raises without
    truncating the existing file.

    Raises:
        TypeError: if *data* contains a value the json module cannot serialize.
    """
    serialized: str = json.dumps(data, ensure_ascii=False, indent="\t")
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(serialized)
    print(f'"{file_path}" successfully saved.')

# Build analysis file with list of dicts with movies info
> **WARNING**: Only to be executed if the file doesn't exist yet, otherwise it will overwrite the file.

In [None]:
# subdirs: list[PosixPath] = [p for p in MOVIES_PATH.iterdir() if p.is_dir()]

# movies: list[dict] = []

# for p in subdirs:
#     # Gather some info about the movie using the folders names
#     matches: re.Match | None = re.search(r"(.*) \((\d{4})\, (.*)\)", p.name)
#     # Prepare the data structure
#     year: str | None = None
#     director: str | None = None
#     file_path: str | None = None
#     if matches:
#         title: str = matches.group(1)
#         year  = matches.group(2)
#         director = matches.group(3)
#     else:
#       title = p.name

#     # Get all video files for this movie directory
#     file_types: tuple[str] = ('*.avi', '*.mkv', '*.mp4')
#     files_paths: list[Path] = []
#     for file_type in file_types:
#         files_paths.extend(p.glob(file_type))

#     # Get the unique video file path for this movie directory
#     if len(files_paths) == 1:
#         file_path: Path = files_paths[0]
#         status: str = "Movie file found"
#     else:
#       if len(files_paths) == 0:
#           status = "No movie file found"
#           print(f'{status} for "{title}"')
#       else:
#           status = "More than 1 video file found:"
#           for f in files_paths:
#               status += f' \"{f.name}\"'
#           print(status)

#     movie: dict = {
#         "title": title,
#         "status": status,
#         "director": director,
#         "year": year,
#         "path": str(file_path) if file_path else None,
#         "palettes": [],
#     }
#     movies.append(movie)

# # Sort alphabetically
# movies.sort(key=lambda m: m["title"])

# # Save as file
# save_as_file(data=movies, file_path=FILE_PATH)

# Optional: analyze frames numbers and length

In [None]:
# for m in (pbar := tqdm(movies)):

#   pbar.set_description(f'Analyzing "{m["title"]}"')

#   if m.get("path") and not m.get("frames"):
#     cap = cv2.VideoCapture(str(m["path"]))
#     m["frames"]: int = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
#     fps: float = cap.get(cv2.CAP_PROP_FPS)
#     if fps:
#       m["length"]: int = int(m["frames"] / fps)

# save_as_file(data=movies, file_path=FILE_PATH)

# Read existing movies analysis file

In [None]:
# Load the existing list of movie dicts. Only read access is needed, so the
# file is opened with "r" (the former "r+" also required write permission).
with open(FILE_PATH, "r", encoding="utf-8") as f:
    movies: list[dict] = json.load(f)

print(f"Loaded {len(movies)} movies from \"{str(FILE_PATH)}\".")

# Runtime type

In [None]:
def get_runtype_type() -> str:
    """Guess the Colab accelerator type from environment variables.

    Returns "GPU" when COLAB_GPU parses to a positive integer, "TPU" when
    TPU_DRIVER_MODE equals "tpu" or COLAB_TPU_ADDR is non-empty, and
    "unknown" otherwise.
    """
    env = os.environ
    if int(env.get("COLAB_GPU", 0)) > 0:
        return "GPU"
    on_tpu = env.get("TPU_DRIVER_MODE") == "tpu" or bool(env.get("COLAB_TPU_ADDR"))
    return "TPU" if on_tpu else "unknown"

# Main logic

This is the main logic. It might take from a few minutes to a few hours per movie file, depending on where it's run.

I suggest a GPU on Google Colab as the runtime.


In [None]:
def get_dominant_colors(
    data: np.ndarray, clusters_nb: int
) -> tuple[np.ndarray, np.ndarray]:
    """
    Cluster color samples using OpenCV k-means and return the cluster centers.

    Parameters:
    data (np.ndarray): A 2D array with one color sample per row (callers in
        this notebook pass BGR pixels or LAB values). Converted to float32.
    clusters_nb (int): The requested number of clusters. It is clamped to the
        number of samples, since k-means cannot form more clusters than samples.

    Returns:
    tuple[np.ndarray, np.ndarray]: A tuple containing:
        - centers (np.ndarray): One row per cluster center, in the same color
          space as ``data``.
        - labels (np.ndarray): The cluster index of each sample, as returned
          by cv2.kmeans.

    Raises:
    ValueError: If ``data`` contains no samples.
    """
    samples: np.ndarray = data.astype(np.float32)
    if len(samples) == 0:
        raise ValueError("get_dominant_colors() needs at least one sample")
    # Small inputs (e.g. a frame's few salient pixels) would otherwise make cv2.kmeans fail.
    clusters_nb = min(clusters_nb, len(samples))

    # Termination: at most 10 iterations, or epsilon 1.0, whichever comes first.
    criteria: tuple[int, int, float] = (
        cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER,
        10,
        1.0,
    )
    labels: np.ndarray
    centers: np.ndarray
    # 10 attempts, each starting from random initial centers.
    _compactness, labels, centers = cv2.kmeans(
        samples, clusters_nb, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS
    )
    return centers, labels

In [None]:
def get_salient_mask(image: np.ndarray) -> np.ndarray:
    """
    Compute a saliency map for the input image with OpenCV's Spectral
    Residual static saliency model, highlighting its most important regions.

    Parameters:
    image (np.ndarray): The input color image in BGR format.

    Returns:
    np.ndarray: A grayscale (NOT binary) uint8 saliency map with values in
                [0, 255]; higher values indicate more salient regions.
                Threshold it to obtain a binary mask.

    Raises:
    RuntimeError: If OpenCV reports that the saliency computation failed.
    """
    # The saliency model is run on the grayscale version of the image
    gray: np.ndarray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
    success, saliency_map = saliency.computeSaliency(gray)
    if not success:
        # Without this check a failure would surface later as a confusing TypeError on None
        raise RuntimeError("Spectral residual saliency computation failed")

    # Scale to [0, 255] (the *255 assumes a map in [0, 1] — clip guards against values outside it)
    return np.clip(saliency_map * 255, 0, 255).astype("uint8")

In [None]:
def enhance_saturation(image: np.ndarray, factor: float = 1.5, threshold: int = 50) -> np.ndarray:
    """
    Enhance the saturation of an image and black out low-saturation pixels.

    Parameters:
    image (np.ndarray): The input image in BGR format.
    factor (float): The factor by which saturation is multiplied (int or float).
    threshold (int): Pixels whose boosted saturation is not above this value
        are set to black.

    Returns:
    np.ndarray: The enhanced image, in BGR format.
    """
    # Convert image to HSV color space (channel 1 is saturation)
    hsv: np.ndarray = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # Boost saturation in float64: multiplying the uint8 channel by an int factor
    # would otherwise overflow and wrap around before the clip.
    boosted = np.clip(hsv[:, :, 1].astype(np.float64) * factor, 0, 255)
    hsv[:, :, 1] = boosted.astype(np.uint8)  # truncates, as the implicit cast did

    # Black out pixels whose saturation is still too low
    saturation_mask = hsv[:, :, 1] > threshold
    hsv[~saturation_mask] = 0

    # Convert back to BGR color space
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

In [None]:
def process_frame(frame: np.ndarray, clusters_nb: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Extract the dominant colors of the salient, saturated regions of a frame.

    Parameters:
    frame (np.ndarray): A BGR frame (as decoded by OpenCV).
    clusters_nb (int): Maximum number of colors to cluster the salient pixels into.

    Returns:
    tuple[np.ndarray, np.ndarray]:
        - centers: LAB cluster centers whose L channel is above 50 (OpenCV
          8-bit LAB scale), one row per kept cluster; shape (0, 3) if none.
        - labels: for each salient pixel belonging to a kept cluster, its index
          into ``centers``, shape (M, 1); shape (0,) if no cluster is kept.
    """
    # To reduce complexity, resize the image
    frame = cv2.resize(frame, (RESIZE_W, RESIZE_H))

    # The saliency mask is computed on the saturation-enhanced image,
    # but colors are taken from the original (resized) frame.
    enhanced = enhance_saturation(frame)
    salient_mask = get_salient_mask(enhanced) > 128  # binary mask
    salient_pixels = frame[salient_mask]  # shape (N, 3), BGR
    if salient_pixels.size == 0:
        # Return empty centers and labels if no salient data exists
        return np.empty((0, 3)), np.empty((0,))

    # Convert to LAB for better perceptual clustering. cvtColor needs a
    # 3-channel image, so view the pixels as an (N, 1, 3) image first.
    lab_data = cv2.cvtColor(
        salient_pixels.reshape(-1, 1, 3), cv2.COLOR_BGR2LAB
    ).reshape(-1, 3)

    # Perform k-means clustering (never ask for more clusters than pixels)
    centers, labels = get_dominant_colors(lab_data, min(clusters_nb, len(lab_data)))

    # Filter out low-luminance colors
    valid = centers[:, 0] > 50
    if not valid.any():
        # If no valid colors remain, return empty results
        return np.empty((0, 3)), np.empty((0,))

    # Keep the labels consistent with the filtered centers: drop pixels of
    # removed clusters and renumber the remaining cluster indices.
    new_index = np.cumsum(valid) - 1
    flat_labels = labels.ravel()
    kept = valid[flat_labels]
    return centers[valid], new_index[flat_labels[kept]].reshape(-1, 1)


In [None]:
def _frame_cluster_colors(frame: np.ndarray) -> list[np.ndarray]:
    """Return the dominant colors of one BGR frame, largest cluster first."""
    # Downscale to reduce complexity, then flatten to one BGR pixel per row.
    pixels: np.ndarray = cv2.resize(frame, (RESIZE_W, RESIZE_H)).reshape(-1, 3)
    centers, labels = get_dominant_colors(pixels, CLUSTERS_NB)
    # Number of pixels per cluster, for example: [4450 2148 745 2048 609]
    cluster_sizes: np.ndarray = np.bincount(labels.flatten())
    return [centers[cluster_idx] for cluster_idx in np.argsort(-cluster_sizes)]


def process_movie(movie: dict, pbar=None) -> None:
    """
    Calculate a color palette for a movie, append it to the movie's
    "palettes" list and save the whole notebook-global `movies` list.

    Every FRAME_SKIP-th frame is decoded (BATCH_SIZE frames at a time),
    downscaled and clustered into CLUSTERS_NB colors; all these per-frame
    colors are then clustered again with scikit-learn KMeans to obtain the
    final palette, stored as RGB integer triplets.

    Parameters:
    movie (dict): the movie details from the movies list ("title", "path"
        and "palettes" are used).
    pbar: optional outer tqdm progress bar whose description is updated.
        Defaults to the notebook-global `pbar` (bound by the final loop), if any.
    """
    start_time: float = time.time()  # Record the start time
    title: str = movie["title"]

    if pbar is None:
        pbar = globals().get("pbar")

    colors: list[np.ndarray] = []
    cap = cv2.VideoCapture(str(movie["path"]))
    try:
        frames_count: int = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if pbar is not None:
            pbar.set_description(
                f'Processing "{title}", with {frames_count} frames.'
            )

        # Sampled frame numbers: 0, FRAME_SKIP, 2 * FRAME_SKIP, ...
        frame_numbers = range(0, frames_count, FRAME_SKIP)
        with tqdm(total=len(frame_numbers), desc="Frames processed") as frame_pbar:
            for batch_start in range(0, len(frame_numbers), BATCH_SIZE):
                batch = frame_numbers[batch_start:batch_start + BATCH_SIZE]

                # Extract a batch of images/frames; unreadable frames are
                # skipped instead of aborting the batch (or the whole movie).
                batch_frames: list[np.ndarray] = []
                for frame_nb in batch:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_nb)
                    status, frame = cap.read()
                    if not status:
                        print(f"Frame {frame_nb} could not be read. Skipping.")
                        continue
                    batch_frames.append(frame)

                for frame in batch_frames:
                    colors.extend(_frame_cluster_colors(frame))

                frame_pbar.update(len(batch))
    finally:
        cap.release()  # Free the video decoder even if processing fails

    if not colors:
        print(f"No colors extracted for {title}. Skipping.")
        return

    # Perform K-means clustering on all the per-frame colors (never ask for
    # more clusters than there are samples).
    kmeans: KMeans = KMeans(n_clusters=min(CLUSTERS_NB, len(colors)), n_init="auto")
    kmeans.fit(np.array(colors))

    # OpenCV decodes frames as BGR: reverse the channels so the palette is really RGB.
    rgb_centers: np.ndarray = kmeans.cluster_centers_[:, ::-1]

    # Calculate the total processing duration
    end_time: float = time.time()  # Record the end time
    duration = int(end_time - start_time)  # Duration in seconds

    palette: dict = {
        # Add the parameters it used to calculate the colors
        "calculation_date": datetime.now(timezone.utc).strftime("%Y/%m/%d_%H:%M:%S"),
        "calculation_duration_seconds": duration,
        "runtime": get_runtype_type(),
        "clusters_nb": CLUSTERS_NB,
        "frame_skip": FRAME_SKIP,
        "resize": {"width": RESIZE_W, "height": RESIZE_H},
        "batch_size": BATCH_SIZE,
        # Round the cluster centers to integer RGB values in [0, 255]
        "colors": np.clip(np.rint(rgb_centers), 0, 255).astype(int).tolist(),
    }
    movie["palettes"].append(palette)

    # Debug: display output
    print(f'Colors for "{title}":')
    for color in palette["colors"]:
        img = Image.new(mode='RGB', size=(200, 30), color=tuple(color))
        display(img)

    # Saving the updated global list of dicts (all movies) as a file
    save_as_file(data=movies, file_path=FILE_PATH)

# Final loop through movies

In [None]:
# Set to True to compute an additional palette even for movies that already have one.
RECALC_PALETTES = False

# `m` and `pbar` stay notebook globals: process_movie may rely on them.
for m in (pbar := tqdm(movies)):
    has_palette = len(m["palettes"]) > 0
    if has_palette and not RECALC_PALETTES:
        print(f'"{m["title"]}" already has at least one color palette calculated, skipping...')
    elif not m.get("path"):
        print(f'"{m["title"]}" has no filepath, skipping...')
    else:
        process_movie(movie=m)

# Test: display colors palettes for each movie

In [None]:
# Display every stored palette of each movie that has at least one.
for m in tqdm(movies):
    # `or []` tolerates a missing/None "palettes" key (len(None) used to raise TypeError)
    palettes: list[dict] = m.get("palettes") or []
    if not palettes:
        continue
    print(m["title"])
    for p in palettes:
        print(f'\nPalette calculated on {p.get("calculation_date", "unknown")}:')
        for color in p["colors"]:
            img = Image.new(mode='RGB', size=(200, 30), color=tuple(color))
            display(img)
    print("\n-------------------------------------\n")