In [1]:
#| default_exp validate_images

In [2]:
#| export

import daft
import numpy as np
from typing import Callable
from functools import partial

## Define validations

In [3]:
#| export

# Size threshold (bytes on disk) compared against the `size` column in `size_nontrivial`.
MIN_BYTES: int = 500

In [4]:
# | export

def split_on_condition(df: daft.DataFrame, condition: Callable[[daft.DataFrame], daft.DataFrame]):
    """Splits a DataFrame into accepted and dropped rows based on a filtering condition.

    Args:
        df (daft.DataFrame): The input DataFrame.
        condition (Callable[[daft.DataFrame], daft.DataFrame]): A function returning the subset
            of `df`'s rows that pass the check. It must keep all of `df`'s columns.

    Returns:
        Tuple[daft.DataFrame, Optional[daft.DataFrame]]: (accepted_df, dropped_df). When the
        condition removes no rows, `accepted_df` is the original `df` and `dropped_df` is None.
    """
    # Re-select the input's columns so the accepted frame keeps `df`'s column order even when
    # the condition reorders them (e.g. a groupby puts the key first). This keeps the schema
    # stable across pipeline steps and comparable in the set difference below.
    filtered_df = condition(df).select(*df.column_names)
    if filtered_df.count_rows() < df.count_rows():
        # Dropped rows = rows of the input that are absent from the accepted output.
        # NOTE: this is a distinct set difference, so exact-duplicate input rows of which one
        # copy was kept do not appear in `dropped_df`.
        dropped_df = df.except_distinct(filtered_df)
        return filtered_df, dropped_df
    return df, None

In [5]:
# | export

# Define filtering functions
def size_nontrivial(df: daft.DataFrame, min_bytes: "int | None" = None) -> daft.DataFrame:
    """Keeps images that are at least `min_bytes` in size on disk.

    Args:
        df: DataFrame with a `size` column (presumably bytes on disk — TODO confirm upstream).
        min_bytes: Inclusive lower bound on `size`. Defaults to the module-level `MIN_BYTES`,
            looked up at call time so later reassignments of the constant still apply.

    Returns:
        The rows of `df` whose `size` is >= the threshold.
    """
    threshold = MIN_BYTES if min_bytes is None else min_bytes
    # `>=` so the threshold itself passes, matching "at least".
    return df.filter(df["size"] >= threshold)

In [6]:
# | export

@daft.udf(return_dtype=daft.DataType.bool())
def array_not_oblong(arrs: daft.Series, max_oblongness: float = 4.0) -> list:
    """Flags which arrays are NOT oblong, i.e. have aspect ratio below `max_oblongness`.

    The aspect ratio of an element is max(h/w, w/h), computed from the first two entries of
    its `.shape` (assumed to be height and width — TODO confirm for the image column layout).

    Args:
        arrs: A daft Series whose non-null elements expose `.shape` (e.g. numpy arrays).
        max_oblongness: Exclusive upper bound on the aspect ratio for an element to pass.

    Returns:
        One entry per element: True if its aspect ratio is < `max_oblongness`, False otherwise
        (including arrays with a zero-length side), and None for null elements.
    """
    values = arrs.to_pylist()
    flags = [None] * len(values)
    present = [i for i, a in enumerate(values) if a is not None]
    if not present:
        # Empty batch or all nulls: the 2-D indexing below would fail on an empty array.
        return flags

    shapes = np.array([values[i].shape[:2] for i in present], dtype=float)
    # A zero-length side yields inf (h/0) or nan (0/0); both compare False against the
    # bound, so such arrays are rejected. Silence the corresponding numpy warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        max_aspects = np.max(shapes / shapes[:, ::-1], axis=1)

    for i, ok in zip(present, max_aspects < max_oblongness):
        flags[i] = bool(ok)
    return flags

In [7]:
# | export

def img_not_oblong(df: daft.DataFrame, img_col: str = "img", max_oblongness: float = 4.0) -> daft.DataFrame:
    """Keeps images whose aspect ratio max(h/w, w/h) is strictly below `max_oblongness`.

    With the default of 4.0 this keeps images strictly between 1:4 and 4:1.

    Args:
        df: DataFrame containing an image/array column.
        img_col: Name of that column. Its values must already be arrays exposing `.shape`
            (this function does not decode raw bytes — TODO confirm upstream decoding).
        max_oblongness: Exclusive upper bound on the aspect ratio.

    Returns:
        The rows of `df` that pass, with the same columns as `df`.
    """
    flag_col = "is_not_oblong"  # transient column, removed before returning
    checkable = df.with_column(flag_col, array_not_oblong(df[img_col], max_oblongness=max_oblongness))
    return checkable.filter(checkable[flag_col]).exclude(flag_col)

In [8]:
# | export

def img_name_distinct(df: daft.DataFrame, name_col: str = "img_name") -> daft.DataFrame:
    """Keeps one row per distinct value of `name_col` (image filename).

    Args:
        df: Input DataFrame; must contain `name_col`.
        name_col: Column whose values must be unique in the result.

    Returns:
        A DataFrame with the same columns, in the same order, as `df`, and one row per
        distinct `name_col` value.
    """
    # Iterate `df.column_names` rather than a set: set order of strings varies between
    # processes (hash randomization), which made the output schema nondeterministic.
    other_cols = [c for c in df.column_names if c != name_col]
    if not other_cols:
        # Only the name column exists: deduplicating it is a plain distinct.
        return df.distinct()
    # NOTE(review): `any_value()` is evaluated independently per column, so for a duplicated
    # name the kept values are not guaranteed to come from the same original row — confirm
    # this is acceptable, or switch to a row-preserving dedup.
    aggs = [daft.col(c).any_value() for c in other_cols]
    # Restore the input column order (groupby puts the key column first).
    return df.groupby(name_col).agg(*aggs).select(*df.column_names)

## Collect validations as pipeline

In [9]:
#| export

def do_validations(df: daft.DataFrame, validations: list[Callable]
                    ) -> tuple[daft.DataFrame, daft.DataFrame]:
    """Runs each validation in order, removing the rows that fail it.

    Args:
        df: The DataFrame to validate.
        validations: Callables that take a DataFrame and return the rows that pass. They are
            applied sequentially; each sees only the rows accepted by the previous ones.

    Returns:
        (accepted_df, dropped_df): the rows that passed every validation, and the rows dropped
        by any validation (concatenated across steps), or None if nothing was dropped.
    """
    all_dropped = None
    for validation in validations:
        # `functools.partial` objects and other callables may lack `__qualname__`.
        name = getattr(validation, "__qualname__", repr(validation))
        print(f"Checking {name}")
        df, dropped = split_on_condition(df, validation)

        # Compare with None explicitly: `if dropped:` would truth-test a lazy daft DataFrame
        # (falling back to `len()`), which is not a "has rows" check.
        if dropped is not None:
            print(f"{dropped.count_rows()} images failed check for {name} and will be dropped.")
            # `limit` is lazy; `collect()` materializes so the printed frame shows the row.
            print(dropped.limit(1).collect())  # Print first dropped row as an example
            if all_dropped is None:
                all_dropped = dropped
            else:
                # Align column order so the union's schemas match.
                all_dropped = all_dropped.concat(dropped.select(*all_dropped.column_names))

    return df, all_dropped


In [10]:
#| export

# Ordered validation steps: each takes a DataFrame and returns the rows that pass.
# Order matters, because every step only sees the rows kept by the steps before it.
pipeline: list[Callable[[daft.DataFrame], daft.DataFrame]] = [
    size_nontrivial,
    img_not_oblong,
    img_name_distinct,
]

# Convenience entry point: `validate_images(df)` runs `do_validations` over the steps above.
validate_images = partial(do_validations, validations=pipeline)

In [11]:
# | hide
# Regenerate the library module from this notebook's `#| export` cells.
import nbdev

nbdev.nbdev_export()