In [1]:
"""Process raw data."""
from functools import partial
import logging
import os
import sys
autofocus_path = os.path.abspath(os.path.join('../../..'))
if autofocus_path not in sys.path:
    sys.path.append(autofocus_path)
from autofocus.build_dataset.lpz_2016_2017.ops import (
    record_is_grayscale,
    record_mean_brightness,
    trim_bottom,
)
from autofocus.build_dataset.helpers import has_channels_equal

from pathlib import Path
import time
import numpy as np
import cv2 as cv
from typing import DefaultDict

from creevey import CustomReportingPipeline
from creevey.load_funcs.image import load_image_from_disk
from creevey.ops.image import resize
from creevey.path_funcs import replace_dir
from creevey.util.image import find_image_files
from creevey.write_funcs.image import write_image
from fastai.vision import verify_images
import pandas as pd
from typing import Union

In [2]:
from os.path import abspath, dirname

# Working directory of the kernel — assumed to be the folder that holds this
# notebook (TODO confirm when running from elsewhere).
current = abspath('')
# NOTE(review): despite the name, this holds the full directory path.
DATASET_NAME = current

# Repository root: three directories above the notebook, the same level the
# first cell appends to sys.path via '../../..'.
REPO_DIR = abspath(os.path.join(current, os.pardir, os.pardir, os.pardir))
DATA_DIR = f"{REPO_DIR}/data/"

# Alias for arguments that accept either a pathlib.Path or a plain string.
PathOrStr = Union[Path, str]

In [3]:
# Shorter-side length, in pixels, that images are resized to.
MIN_DIM = 512
# Number of parallel jobs passed to the Creevey pipeline.
N_JOBS = 5
# Number of rows cut off the bottom of every image before resizing.
# NOTE(review): the trimming step is described elsewhere in this notebook as
# removing a camera-information footer; confirm that 1 row is the intended
# value and not a leftover test setting.
NUM_PIXELS_TO_TRIM = 1

# Directories are plain strings ending in "/", so later cells build paths by
# string concatenation.
THIS_DATASET_DIR = f"{DATA_DIR}lpz_2012-2014/lpz_2012-2014/"
RAW_IMG_DIR = f"{THIS_DATASET_DIR}raw/"
# The commented-out settings below use pathlib `/` on names (RAW_DIR) that are
# not defined in this notebook; they are kept only for reference.
# RAW_CSV_FILENAMES = ["detections_2016.csv", "detections_2017.csv"]
# RAW_CSV_PATHS = [RAW_DIR / fn for fn in RAW_CSV_FILENAMES]

PROCESSED_DIR = f"{THIS_DATASET_DIR}processed/"
PROCESSED_IMAGE_DIR = f"{PROCESSED_DIR}images/"
# PROCESSED_LABELS_CSV_OUTPATH = PROCESSED_DIR / "labels.csv"

# CORRUPTED_FILES = [
#     RAW_DIR / "images_2016" / "DPT" / "D03-AMP1" / "._CHIL - D03-AMP1-JU16_00037.JPG"
# ]

In [4]:
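# # Reference copy of has_channels_equal from autofocus/build_dataset/helpers.py; the live version is imported in the first cell.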

# def has_channels_equal(image: np.array) -> bool:
#     """
#     Indicate whether all channels have equal values.

#     Assumes that channels lie along the final axis.

#     Parameters
#     ----------
#     image

#     Returns
#     -------
#     True if all channels have equal values, including when there is only
#     one channel as long as there is an axis corresponding to that
#     channel (e.g. a grayscale image with shape height x width x 1, but
#     not one with shape height x width).

#     """
#     first_channel = image[..., 0]
#     return all(
#         [
#             np.equal(image[..., channel_num], first_channel).all()
#             for channel_num in range(1, image.shape[-1])
#         ]
#     )

In [5]:
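# # Reference copies of trim_bottom, record_mean_brightness and record_is_grayscale from autofocus/build_dataset/lpz_2016_2017/ops.py; the live versions are imported in the first cell.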
# def trim_bottom(image: np.array, num_pixels: int, **kwargs) -> np.array:
#     """
#     Trim off the bottom of an image.

#     `kwargs` included only for compatibility with Creevey's
#     `CustomReportingPipeline`

#     Parameters
#     ----------
#     image
#     num_pixels
#         Height of strip to trim off the bottom of the image, in pixels

#     Returns
#     -------
#     Trimmed image

#     """
#     return image[:-num_pixels, :]


# def record_mean_brightness(
#     image: np.array, inpath: PathOrStr, log_dict: DefaultDict[str, dict]
# ) -> np.array:
#     """
#     Record the mean brightness of image.

#     Parameters
#     ----------
#     image
#     inpath
#         Image input path
#     log_dict
#         Dictionary of image metadata

#     Side effect
#     -----------
#     Adds a "mean_brightness" item to log_dict[inpath]

#     """
#     is_grayscale = has_channels_equal(image)

#     if is_grayscale:
#         image_gray = image
#     else:
#         image_gray = cv.cvtColor(src=image, code=cv.COLOR_RGB2GRAY)

#     log_dict[inpath]["mean_brightness"] = image_gray.mean()

#     return image


# def record_is_grayscale(
#     image: np.array, inpath: PathOrStr, log_dict: DefaultDict[str, dict]
# ) -> None:
#     """
#     Record whether image is grayscale.

#     In this dataset, grayscale images have been saved as three-channel
#     images with all three channels equal, so this function checks for
#     equality across channels rather than the number of channels.

#     Parameters
#     ----------
#     image
#     inpath
#         Image input path
#     log_dict
#         Dictionary of image metadata

#     Side effect
#     -----------
#     Adds a "grayscale" item to log_dict[inpath]

#     """
#     is_grayscale = has_channels_equal(image)

#     log_dict[inpath]["grayscale"] = int(is_grayscale)

#     return image

In [6]:
# Single sub-folder of RAW_IMG_DIR that _process_images reads from; presumably
# a small test subset rather than the full raw image set.
TEST_IMG_DIR = f"{RAW_IMG_DIR}SP12/DPT/D02-HUP1-SP12/"

In [7]:
def _process_images(image_dir=None):
    """
    Trim, resize and annotate images, writing the results to PROCESSED_IMAGE_DIR.

    Each image found under ``image_dir`` is loaded, has ``NUM_PIXELS_TO_TRIM``
    rows cut off the bottom, is resized with ``min_dim=MIN_DIM``, and has its
    grayscale flag and mean brightness recorded in the run record. After the
    pipeline finishes, ``_delete_bad_images`` removes processed files that
    fail verification and drops their rows.

    Parameters
    ----------
    image_dir
        Directory passed to ``find_image_files``. Defaults to
        ``TEST_IMG_DIR``, which keeps the previous behavior.
        NOTE(review): every output is written into the flat
        PROCESSED_IMAGE_DIR via ``replace_dir``, so images sharing a file name
        in different sub-folders of a larger ``image_dir`` would presumably
        overwrite each other — confirm before pointing this at RAW_IMG_DIR.

    Returns
    -------
    The Creevey run record, filtered by ``_delete_bad_images``.
    """
    if image_dir is None:
        image_dir = TEST_IMG_DIR

    # The bottom rows of camera-trap images are often a footer of camera
    # information. Such a strip is more likely to make the model learn batch
    # effects that do not generalize than to support genuine learning, so it
    # is trimmed off.
    # NOTE(review): an earlier version of this comment put the footer at 198
    # pixels, but only NUM_PIXELS_TO_TRIM rows are removed — confirm the
    # correct height for this dataset.
    trim_footer = partial(trim_bottom, num_pixels=NUM_PIXELS_TO_TRIM)
    resize_min_dim = partial(resize, min_dim=MIN_DIM)
    ops = [trim_footer, resize_min_dim, record_is_grayscale, record_mean_brightness]

    trim_resize_pipeline = CustomReportingPipeline(
        load_func=load_image_from_disk, ops=ops, write_func=write_image
    )

    image_paths = find_image_files(image_dir)
    path_func = partial(replace_dir, outdir=PROCESSED_IMAGE_DIR)

    run_record = trim_resize_pipeline.run(
        inpaths=image_paths,
        path_func=path_func,
        n_jobs=N_JOBS,
        skip_existing=False,
        # NOTE(review): only ZeroDivisionError is handed to the pipeline as a
        # catchable exception; document which step raises it and why.
        exceptions_to_catch=ZeroDivisionError,
    )
    logging.info("Checking for additional corrupted images")
    return _delete_bad_images(run_record)

In [8]:
def _delete_bad_images(run_record):
    """
    Delete unreadable processed images and drop their rows from the record.

    fastai's ``verify_images`` is run over PROCESSED_IMAGE_DIR with
    ``delete=True``. Afterwards, every row whose ``outpath`` is not an
    existing file is removed — this covers images deleted by verification as
    well as any output that was never written.

    Parameters
    ----------
    run_record
        Pipeline run record with an ``outpath`` column.

    Returns
    -------
    The rows of ``run_record`` whose output file still exists.
    """
    verify_images(PROCESSED_IMAGE_DIR, delete=True)
    still_on_disk = run_record["outpath"].apply(os.path.isfile)
    return run_record.loc[still_on_disk]

In [9]:
def _extract_seasons(file_name):
    # For parsing the seasons from the File Names
    # The season names are based on the codes provided by Lincoln Park Zoo researchers
    file_name = file_name.split("-")[3]
    if file_name.startswith(("JA", "WI")):
        return "Winter"
    elif file_name.startswith(("AP", "SP")):
        return "Spring"
    elif file_name.startswith(("JU", "SU")):
        return "Summer"
    else:
        return "Fall"

In [10]:
# Cleaned labels CSV, stored in this dataset's raw/ directory.
LABELS_PATH = f"{THIS_DATASET_DIR}raw/labels_clean.csv"
# NOTE: despite the name, RAW_CSV is the loaded DataFrame, not a path.
RAW_CSV = pd.read_csv(LABELS_PATH)

In [None]:
def _process_labels(run_record, labels_path: Union[Path, str, None] = None):
    """
    Join the pipeline's run record with the raw labels and derive metadata.

    Parameters
    ----------
    run_record
        Run record from ``_process_images``: indexed by input image path, with
        ``outpath``, ``grayscale``, ``mean_brightness``, ``skipped_existing``,
        ``exception_handled`` and ``time_finished`` columns. It is NOT
        modified.
    labels_path
        Labels CSV to read; defaults to ``LABELS_PATH``. It must have
        ``FileName``, ``ShortName`` and ``ImageDate`` columns.

    Returns
    -------
    DataFrame with columns label, grayscale, mean_brightness, date, filename,
    location and season. Images are matched to labels by bare file name with
    a left join, so unmatched images get NaN labels (and, presumably,
    duplicate FileName rows in the CSV would duplicate images — verify).
    """
    if labels_path is None:
        labels_path = LABELS_PATH

    raw_df = (
        pd.read_csv(labels_path)
        .set_index("FileName")
        # "Sure" was previously referenced as an undefined bare name (a
        # NameError); it is presumably a column to discard. errors="ignore"
        # tolerates CSVs that lack either column.
        .drop(columns=["Unnamed: 0", "Sure"], errors="ignore")
        .rename(columns={"ShortName": "label", "ImageDate": "date"})
    )

    # Re-key by bare file name so rows line up with raw_df's FileName index.
    # rename() returns a new frame, leaving the caller's run_record untouched.
    records = run_record.rename(index=lambda path: Path(path).name)

    processed_df = (
        # Drop pipeline bookkeeping first so it cannot collide with label columns.
        records.drop(
            ["skipped_existing", "exception_handled", "time_finished"], axis="columns"
        )
        .join(raw_df, how="left")
        .loc[:, ["outpath", "label", "grayscale", "mean_brightness", "date"]]
        .reset_index(drop=True)
    )
    filenames = processed_df["outpath"].apply(lambda path: Path(path).name)
    # File names look like "<...>-<...>-<location>-<season code>...", so the
    # third "-" field is the location and the fourth encodes the season.
    processed_df = processed_df.assign(
        filename=filenames,
        location=filenames.apply(lambda fn: fn.split("-")[2]),
        season=filenames.apply(_extract_seasons),
    ).drop(columns="outpath")

    return processed_df

In [None]:
def main() -> None:
    """
    Process the raw images of this dataset.

    Runs ``_process_images()``. That trims ``NUM_PIXELS_TO_TRIM`` rows off
    the bottom of each image and resizes it to ``MIN_DIM`` pixels on its
    shorter dimension. It also records grayscale and mean-brightness
    metadata, writes results to ``PROCESSED_IMAGE_DIR``, and deletes
    processed images that fail verification. The run record is not
    returned.

    Deleting known corrupted files and writing a processed labels CSV are not
    currently performed (see the TODO below).
    """
    logging.info(f"Processing images and writing results to {PROCESSED_IMAGE_DIR}")
    _process_images()

    # TODO: re-enable the remaining steps once they work for this dataset:
    #   * delete known corrupted files before processing (needs a
    #     CORRUPTED_FILES list for these images);
    #   * build labels with _process_labels(run_record) and write them to
    #     PROCESSED_DIR + "labels.csv". PROCESSED_DIR is a str here, so the
    #     old `PROCESSED_DIR / "labels.csv"` form would raise TypeError.

In [None]:
if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is also "__main__", so this cell runs
    # main() whenever the notebook is executed, not only as an exported script.
    start_time = time.perf_counter()
    logging.basicConfig(format="%(levelname)s %(asctime)s %(message)s")
    # Set the level separately: basicConfig does nothing (level included) if
    # the root logger already has handlers, as it may inside a notebook.
    logging.getLogger().setLevel(logging.INFO)

    main()

    # perf_counter is monotonic, so the duration is not skewed by system
    # clock adjustments the way time.time() differences can be.
    end_time = time.perf_counter()
    logging.info(f"Completed in {round(end_time - start_time, 2)} seconds")

In [None]:
# NOTE(review): when the notebook runs top to bottom, __name__ == "__main__",
# so the cell above has already called main(), which runs _process_images()
# but discards its run record. This cell runs the whole pipeline a second time
# (skip_existing=False, so every image is reprocessed) just to get the record
# back; consider keeping only one of the two runs.
run_record = _process_images()

In [None]:
# Quick check: (rows, columns) of the run record, one row per processed image
# whose output file survived _delete_bad_images.
run_record.shape