<a href="https://colab.research.google.com/github/bolinocroustibat/movies-palettes/blob/main/movies_palettes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [None]:
import numpy as np
import cv2
import glob
import hashlib
import json
import os
import re
import time
from datetime import datetime, timezone
from multiprocessing import Pool, cpu_count
from pathlib import Path, PosixPath
from PIL import Image
from skimage import feature, color
from sklearn.cluster import KMeans
from tqdm.notebook import tqdm

# Configuration constants

*   **`RESIZE_W` and `RESIZE_H` (Reduced Frame Size)**

  This determines the resolution of each frame after downsampling. A reasonable value should:
    - Retain enough details for meaningful clustering.
    - Eliminate excessive computational overhead from large frame sizes.

  Suggested Value: 160 x 90 (W x H)
  - This keeps the aspect ratio of typical widescreen content (16:9).
  -	At 160x90, you get 14,400 pixels per frame, which is sufficient for clustering dominant colors while ensuring faster processing.

*  **`FRAME_SKIP` (Frames Skipped)**

    This determines how many frames you skip between processed frames.
  -	Movies at 24fps typically don’t have significant color changes frame-by-frame.
  -	A higher skip value reduces computation but might miss some rapid scene changes.

  Suggested Value: `FRAME_SKIP = 60`
 	- This processes 1 frame per 2.5 seconds (approx.), enough to capture major color changes without excessive redundancy.
 	-	For a 90-minute movie:
 	   - Total frames at 24fps:  90 x 60 x 24 = 129,600
 	   - With `FRAME_SKIP = 60`, you’ll process around  129,600 / 60 = 2,160  frames per movie.

*  **`BATCH_SIZE` (Frames Processed at Once)**

  This determines the number of frames processed in a single batch.
 	-	Larger batch sizes are computationally efficient as you can leverage batch processing in libraries like OpenCV.
 	-	However, too large a batch size might lead to memory limitations.

  Suggested Values:
  - For CPU processing, you can try `BATCH_SIZE = 20`.
  - For an NVIDIA A100 GPU, which has 40/80GB of VRAM depending on the variant, you can use a significantly larger batch, like `BATCH_SIZE = 32`



In [None]:
# --- Clustering ---
# Number of clusters used for color movies (CLUSTERS_NB) and for
# black-and-white movies (CLUSTERS_NB_BW)
CLUSTERS_NB_BW = 4
CLUSTERS_NB = 10
# Clustering backend: "cv2" or "sklearn".
# sklearn is supposedly more precise, but cv2 seems to give better results
METHOD = "cv2"

# --- Frame sampling ---
# Only one frame out of every FRAME_SKIP frames is processed, to speed things up
FRAME_SKIP = 60
# Frames are downsized to RESIZE_W x RESIZE_H pixels before clustering, to reduce complexity
RESIZE_H = 90
RESIZE_W = 160
# Number of frames clustered together in each batch
BATCH_SIZE = 32

# --- Run selection ---
# Title of the only movie to process; None means process all movies
MOVIE_TO_PROCESS = None
# When True, compute a new palette even for movies that already have one
RECALC_PALETTES = False

# Mount Google Drive

In [None]:
# Mount Google Drive under /content/drive so that the /content/drive/MyDrive/...
# paths used in this notebook resolve. Requires the google.colab package.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Google Drive movies path and movies colors list file path

> **WARNING**: Don't forget to mount Google Drive first

In [None]:
# Folder on the mounted Google Drive that contains the movies
MOVIES_PATH = Path("/content/drive/MyDrive/MOVIES/")
# JSON analysis file (movies list and their palettes) read and rewritten by this notebook
FILE_PATH = Path("/content/drive/MyDrive/MOVIES/movies_palettes.json")

# Function to save movies colors list in a JSON file on Google Drive

In [None]:
def save_as_file(data: dict, file_path: Path) -> None:
    """
    Save data as a tab-indented, UTF-8 encoded JSON file.

    NumPy booleans, integers, floats and arrays found in the data are converted
    to their native Python equivalents before being written.

    Parameters:
    data (dict): The data to serialize.
    file_path (Path): Destination of the JSON file (overwritten if it exists).

    Raises:
    TypeError: If the data contains a value that is neither JSON-native nor one
        of the handled NumPy types.
    """

    def convert_numpy_types(obj):
        """Convert NumPy types to Python native types."""
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Returning the object unchanged would make json re-visit it and fail
        # with a misleading "Circular reference detected" ValueError.
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    # Serialize completely before opening the file: mode "w" truncates it, so an
    # error raised halfway through a streamed dump would leave a corrupt file.
    payload = json.dumps(
        data, ensure_ascii=False, indent='\t', default=convert_numpy_types
    )
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(payload)
    print(f'"{str(file_path)}" successfully saved.')

# Build analysis file with list of dicts with movies info
> **WARNING**: Only to be executed if the file doesn't exist yet, otherwise it will overwrite the file.

In [None]:
# subdirs: list[PosixPath] = [p for p in MOVIES_PATH.iterdir() if p.is_dir()]

# movies: list[dict] = []

# for p in subdirs:
#     # Try to match both formats:
#     # 1. "Title (Year, Director)"
#     # 2. "Title (Year)"
#     # Single regex pattern with optional director group
#     match = re.search(r"(.*) \((\d{4})(?:\, (.*))?\)", p.name)

#     # Extract movie info
#     title: str | None = match.group(1) if match else p.name
#     year: str | None = match.group(2) if match else None
#     director: str | None = match.group(3) if match and match.group(3) else None

#     # Get all video files for this movie directory
#     file_types: tuple[str] = ("*.avi", "*.mkv", "*.mp4")
#     files_paths: list[Path] = []
#     for ft in file_types:
#         files_paths.extend(p.glob(ft))

#     # Get the unique video file path for this movie directory
#     if len(files_paths) == 1:
#         file_path: Path = files_paths[0]
#     else:
#       if len(files_paths) == 0:
#           status = "No movie file found"
#           print(f'{status} for "{title}"')
#       else:
#           status = "More than 1 video file found:"
#           for f in files_paths:
#               status += f' \"{f.name}\"'
#           print(status)

#     movie: dict = {
#         "title": title,
#         "status": status,
#         "director": director,
#         "year": year,
#         "path": str(file_path) if file_path else None,
#         "palettes": [],
#     }
#     movies.append(movie)

# # Sort alphabetically
# movies.sort(key=lambda m: m["title"])

# # Save as file
# content: dict = {
#     "last_updated": datetime.now(timezone.utc).isoformat(),
#     "movies": movies
# }
# save_as_file(data=content, file_path=FILE_PATH)

# Optional: analyze frames numbers and length

In [None]:
# for m in (pbar := tqdm(movies)):

#   pbar.set_description(f'Analyzing "{m["title"]}"')

#   if m.get("path") and not m.get("frames"):
#     cap = cv2.VideoCapture(str(m["path"]))
#     m["frames"]: int = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
#     fps: float = cap.get(cv2.CAP_PROP_FPS)
#     if fps:
#       m["length"]: int = int(m["frames"] / fps)

# save_as_file(data=movies, file_path=FILE_PATH)

# Read existing movies analysis file

In [None]:
# Load the analysis file. It is opened read-only ("r"): nothing is written back
# in this cell, so the read/write mode "r+" was not needed.
with open(FILE_PATH, "r", encoding="utf-8") as f:
    content: dict = json.load(f)

# The list of movie dicts is stored under the "movies" key of the file
movies: list[dict] = content["movies"]

print(f"Loaded {len(movies)} movies from \"{str(FILE_PATH)}\".")

Loaded 441 movies from "/content/drive/MyDrive/MOVIES/movies_palettes.json".


# Function to get runtime type

In [None]:
def get_runtype_type() -> str:
    """
    Guess the runtime type from environment variables.

    Returns:
    str: "GPU" if COLAB_GPU holds a positive integer, "TPU" if TPU_DRIVER_MODE
        equals "tpu" or COLAB_TPU_ADDR is non-empty, "unknown" otherwise.
    """
    # An unset, empty or non-numeric COLAB_GPU is treated as "no GPU" instead of
    # letting int() raise a ValueError.
    try:
        gpu_flag = int(os.environ.get("COLAB_GPU") or 0)
    except ValueError:
        gpu_flag = 0

    if gpu_flag > 0:
        return "GPU"
    if os.environ.get("TPU_DRIVER_MODE") == "tpu" or os.environ.get("COLAB_TPU_ADDR"):
        return "TPU"
    return "unknown"

# Function to build a short palette ID

In [None]:
def generate_palette_id(title: str, calculation_date: str) -> str:
    """
    Build a short identifier for a palette.

    Parameters:
    title (str): The movie title.
    calculation_date (str): The date the palette was calculated.

    Returns:
    str: The first 6 hexadecimal characters of the MD5 digest of
        "<title>_<calculation_date>".
    """
    digest = hashlib.md5(f"{title}_{calculation_date}".encode()).hexdigest()
    return digest[:6]

# Main logic

This is the main logic. It might take from a few minutes to a few hours per movie file, depending on where it's run.

I suggest a GPU on Google Colab as the runtime.


In [None]:
def is_black_and_white(cap: cv2.VideoCapture, threshold: float = 0.1) -> bool:
    """
    Determine if a movie is black and white based on the average saturation.

    Frames are sampled at a regular interval over the whole movie. Sampling seeks
    inside the capture, so its read position is changed by this call.

    Parameters:
    cap (cv2.VideoCapture): The video capture object.
    threshold (float): The saturation threshold (in [0,1]) below which the movie
        is considered B&W.

    Returns:
    bool: True if the movie is black and white, False otherwise. False is also
        returned when the capture reports no frames or no frame could be read.
    """
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:
        # Nothing to sample; also avoids dividing by a zero sample count below
        return False

    sample_frames = min(total_frames, 100)  # Aim for about 100 samples
    step = max(1, total_frames // sample_frames)

    saturation_values = []
    for i in range(0, total_frames, step):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ret, frame = cap.read()
        if not ret:
            continue
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        saturation = hsv[:, :, 1].mean() / 255.0  # Normalize to [0,1]
        saturation_values.append(saturation)

    if not saturation_values:
        # No readable frame: the mean of an empty list would be NaN (with a warning)
        return False

    # Cast to a native bool so the result matches the annotation
    return bool(np.mean(saturation_values) < threshold)

In [None]:
def get_dominant_colors_cv2(
    data: np.ndarray, clusters_nb: int
) -> tuple[np.ndarray, np.ndarray]:
    """
    Cluster pixels with OpenCV's k-means and return the cluster centers.

    Parameters:
    data (np.ndarray): The samples to cluster, one per row.
    clusters_nb (int): The number of clusters to form.

    Returns:
    tuple[np.ndarray, np.ndarray]: A tuple containing:
        - centers (np.ndarray): The cluster centers, as float64 like sklearn's.
        - labels (np.ndarray): 1D array with the cluster label of each sample.
    """
    # cv2.kmeans only accepts float32 samples
    samples = data.astype(np.float32)

    # Stop on either criterion: 10 iterations or an epsilon of 1.0
    termination = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)

    # 10 attempts with random initial centers; the first returned value is unused
    _, raw_labels, raw_centers = cv2.kmeans(
        samples, clusters_nb, None, termination, 10, cv2.KMEANS_RANDOM_CENTERS
    )

    # Match sklearn's output format: float64 centers and flat labels
    return raw_centers.astype(np.float64), raw_labels.ravel()

def get_dominant_colors_sklearn(
    data: np.ndarray, clusters_nb: int
) -> tuple[np.ndarray, np.ndarray]:
    """
    Cluster pixels with scikit-learn's k-means and return the cluster centers.

    The number of clusters is used as given; it is not adjusted to the data.

    Parameters:
    data (np.ndarray): A 2D array where each row is a pixel.
    clusters_nb (int): The number of clusters to form.

    Returns:
    tuple[np.ndarray, np.ndarray]: A tuple containing:
        - centers (np.ndarray): The values of the cluster centers.
        - labels (np.ndarray): The label of the cluster each pixel belongs to.
    """
    model = KMeans(n_clusters=clusters_nb, n_init="auto")
    model.fit(data)
    return model.cluster_centers_, model.labels_


def get_dominant_luminances(
    data: np.ndarray, clusters_nb: int
) -> tuple[np.ndarray, np.ndarray]:
    """
    Find the dominant luminance values in an image using k-means clustering.
    Returns these luminance values as grayscale RGB values (where R=G=B).

    Parameters:
    data (np.ndarray): Luminance values, or a BGR color image (any array with
        more than 2 dimensions is converted to grayscale first)
    clusters_nb (int): Number of luminance levels to identify

    Returns:
    tuple[np.ndarray, np.ndarray]: A tuple containing:
        - centers (np.ndarray): Array of shape (clusters_nb, 3) where each row is
          a grayscale RGB value [v,v,v] representing a dominant luminance level,
          sorted from darkest to lightest
        - labels (np.ndarray): For each pixel, the index (row of centers) of the
          cluster it belongs to, in the shape returned by cv2.kmeans
    """
    # Convert to grayscale if input is color image
    if len(data.shape) > 2:
        data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)

    # Reshape to a single float32 column, as required by cv2.kmeans
    data_reshaped = data.reshape(-1, 1).astype(np.float32)

    # Define k-means parameters
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    flags = cv2.KMEANS_RANDOM_CENTERS

    # Perform k-means clustering on luminance values
    _, labels, centers = cv2.kmeans(
        data_reshaped, clusters_nb, None, criteria, 10, flags
    )

    # Sort centers from darkest to lightest
    sort_idx = np.argsort(centers.flatten())
    centers = centers[sort_idx]

    # Reordering the centers changes their indices, so the labels must be
    # remapped too: new_position[old_index] is where that cluster ended up.
    new_position = np.empty_like(sort_idx)
    new_position[sort_idx] = np.arange(sort_idx.size)
    labels = new_position[labels].astype(labels.dtype)

    # Convert luminance values to RGB format (all channels equal)
    # e.g., luminance value 127 becomes [127, 127, 127]
    centers_rgb = np.column_stack([centers.flatten()] * 3)

    return centers_rgb, labels

In [None]:
def process_batch(frames: list[np.ndarray], is_bw: bool) -> np.ndarray:
    """
    Process a batch of frames at once to extract dominant colors.

    Parameters:
    frames (list): List of BGR frames to process
    is_bw (bool): Whether to process as black and white

    Returns:
    np.ndarray: Array of dominant colors found across all frames, most frequent
        first. Clusters that received no pixel are left out.

    Raises:
    ValueError: If frames is empty, or if METHOD is neither "cv2" nor "sklearn".
    """
    if len(frames) == 0:
        raise ValueError("process_batch() needs at least one frame")

    # Resize each frame before gathering them in one array, so that the
    # full-resolution batch never has to be copied into memory
    resized = np.array([cv2.resize(f, (RESIZE_W, RESIZE_H)) for f in frames])

    if is_bw:
        # Convert to grayscale and process all frames
        gray_batch = np.array([cv2.cvtColor(f, cv2.COLOR_BGR2GRAY) for f in resized])
        all_pixels = gray_batch.reshape(-1, 1)
        centers, labels = get_dominant_luminances(all_pixels, CLUSTERS_NB_BW)
    else:
        # Convert to RGB and process all frames
        rgb_batch = np.array([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in resized])

        # Reshape to process all pixels from all frames
        all_pixels = rgb_batch.reshape(-1, 3)

        # Get dominant colors across all frames at once
        if METHOD == "cv2":
            centers, labels = get_dominant_colors_cv2(all_pixels, CLUSTERS_NB)
        elif METHOD == "sklearn":
            centers, labels = get_dominant_colors_sklearn(
                all_pixels, CLUSTERS_NB
            )
        else:
            raise ValueError(f"Unknown method: {METHOD}")

    # Sort clusters by frequency, most frequent first. np.unique only lists the
    # labels actually used, so argsort positions must be translated back to
    # label ids before indexing the centers (they differ if a cluster is empty).
    used_labels, counts = np.unique(labels, return_counts=True)
    by_frequency = used_labels[np.argsort(-counts)]
    centers = centers[by_frequency]

    return centers


In [None]:
def process_movie(movie: dict) -> dict:
    """
    Calculate the palette for a movie.

    This function processes a movie file to extract its dominant color palette.
    For black and white movies, it extracts dominant luminance values instead.
    One frame out of every FRAME_SKIP is read, frames are clustered in batches
    of up to BATCH_SIZE, and the batch results are clustered once more to build
    the final palette.

    Parameters:
    movie (dict): A dictionary containing movie information including:
        - path: Path to the movie file
        - title: Movie title
        - palettes: List of previously calculated palettes

    Returns:
    dict: The input movie dictionary (modified in place) with a new palette
        entry appended, or unchanged if the movie has fewer than FRAME_SKIP
        frames or no color could be extracted
    """
    start_time: float = time.time()

    # The movies progress bar is a global created by the notebook's main loop;
    # update it when present so this function also works when called on its own.
    movies_pbar = globals().get("pbar")
    if movies_pbar is not None:
        movies_pbar.set_description(f'Processing movie "{movie["title"]}"...')

    cap = cv2.VideoCapture(str(movie["path"]))
    try:
        frames_count: int = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frames_to_process: int = frames_count // FRAME_SKIP
        if frames_to_process < 1:
            print(f'Movie "{movie["title"]}" has less than {FRAME_SKIP} frames, skipping...')
            return movie

        # Detect if movie is black and white
        is_bw: bool = bool(is_black_and_white(cap))
        if is_bw:
            print(f'Movie "{movie["title"]}" is black and white, with {frames_count} frames.')
        else:
            print(f'Movie "{movie["title"]}" is color, with {frames_count} frames.')

        frame_nb = 0
        all_centers = []
        frames_processed = 0

        with tqdm(total=frames_to_process, desc="Frames processed") as frame_pbar:
            while frame_nb < frames_count:
                batch_frames = []

                # Calculate how many frames we still need to process
                remaining_frames = frames_to_process - frames_processed
                current_batch_size = min(BATCH_SIZE, remaining_frames)

                # Try to read up to current_batch_size frames
                for _ in range(current_batch_size):
                    if frame_nb >= frames_count:
                        break

                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_nb)
                    status, frame = cap.read()
                    # Move on to the next sampled position whether or not the
                    # read succeeded; unreadable frames are simply left out
                    frame_nb += FRAME_SKIP

                    if status:
                        batch_frames.append(frame)

                # An empty batch means nothing is left to process (or a whole
                # batch of reads failed): stop here
                if not batch_frames:
                    break

                # Process batch and extend results
                batch_centers = process_batch(batch_frames, is_bw)
                all_centers.extend(batch_centers)
                frame_pbar.update(len(batch_frames))
                frames_processed += len(batch_frames)
    finally:
        # Always free the video file handle, including on early return or error
        cap.release()

    if len(all_centers) == 0:
        print(f"No colors extracted for {movie['title']}. Skipping.")
        return movie

    # Combine all centers and do final clustering
    all_centers = np.vstack(all_centers)
    clusters_nb: int = CLUSTERS_NB_BW if is_bw else CLUSTERS_NB
    # KMeans refuses more clusters than samples, so cap the count
    kmeans = KMeans(n_clusters=min(clusters_nb, len(all_centers)), n_init="auto")
    kmeans.fit(all_centers)
    final_colors: np.ndarray = kmeans.cluster_centers_
    final_palette = [tuple(map(int, c)) for c in final_colors]

    # Calculate the total processing duration
    end_time: float = time.time()
    duration: int = int(end_time - start_time)
    # Calculation date
    now: str = datetime.now(timezone.utc).isoformat()

    palette: dict = {
        "id": generate_palette_id(title=movie["title"], calculation_date=now),
        "calculation_date": now,
        "calculation_duration_seconds": duration,
        "runtime": get_runtype_type(),
        "clustering_method": METHOD,
        "is_black_and_white": is_bw,
        "clusters_nb": clusters_nb,
        "frame_skip": FRAME_SKIP,
        "resize": {"width": RESIZE_W, "height": RESIZE_H},
        "batch_size": BATCH_SIZE,
        "colors": final_palette,
    }
    movie["palettes"].append(palette)

    return movie


# Display palette

In [None]:
def display_palette(palette: dict) -> None:
    """
    Print a palette's metadata, then show each of its colors as a swatch.

    Parameters:
    palette (dict): A palette entry; every key except "colors" is printed as
        "key: value", and each item of "colors" is shown as a 200x30 image.
    """
    # Metadata first; colors are rendered separately below
    metadata = ((key, value) for key, value in palette.items() if key != "colors")
    for key, value in metadata:
        print(f'{key}: {value}')

    # One solid-color image per palette color
    for rgb in palette["colors"]:
        swatch = Image.new(mode="RGB", size=(200, 30), color=tuple(rgb))
        display(swatch)

# Final loop through movies

In [None]:
# Compute a palette for each eligible movie and save the file after each one.
# `pbar` must keep this name: process_movie() updates its description.
for idx, m in enumerate(pbar := tqdm(movies)):
    # Skip movies not matching MOVIE_TO_PROCESS, if specified
    if MOVIE_TO_PROCESS and m["title"] != MOVIE_TO_PROCESS:
        continue

    # Entries without a "palettes" key get an empty list
    existing_palettes: list = m.setdefault("palettes", [])

    if existing_palettes and not RECALC_PALETTES:
        print(f'"{m["title"]}" already has at least one color palette calculated, skipping...')
        continue

    if not m.get("path"):
        print(f'"{m["title"]}" has no filepath, skipping...')
        continue

    # Debug: display previous palette
    if existing_palettes:
        print("Previous palette:")
        display_palette(palette=existing_palettes[-1])

    # process_movie() appends to the same list, so record its size first
    palettes_before: int = len(existing_palettes)
    updated_m: dict = process_movie(m)

    # process_movie() returns the movie untouched when it had to skip it: in that
    # case there is nothing new to display or to save
    if len(updated_m["palettes"]) == palettes_before:
        continue

    # Debug: display new palette
    print("Calculated palette:")
    display_palette(palette=updated_m["palettes"][-1])

    # Update file data (by position, rather than searching the list by equality)
    movies[idx] = updated_m
    data: dict = {
        "last_updated": datetime.now(timezone.utc).isoformat(),
        "movies": movies
    }
    save_as_file(data=data, file_path=FILE_PATH)

# Test: display colors palettes for each movie

In [None]:
# Show every stored palette of every movie as a series of color swatches
for m in (pbar := tqdm(movies)):
    movie_palettes = m.get("palettes")
    if not movie_palettes:
        # No "palettes" key or an empty list: nothing to show for this movie
        continue

    print(m["title"])
    for p in movie_palettes:
        print(f'\nPalette calculated on {p.get("calculation_date", "unknown")}, runtime: {p.get("runtime", "unknown")}')
        # The loop variable is not named `color`, which would overwrite the
        # `color` module imported from skimage
        for rgb in p["colors"]:
            img = Image.new(mode='RGB', size=(200,30), color=tuple(rgb))
            display(img)
    print("\n-------------------------------------\n")