# Filtering, Cropping, Resizing

In [None]:
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

def resize_image(image, new_height=448):
    """
    Scale ``image`` to a square whose side is ``new_height`` pixels.

    Both output dimensions are set to ``new_height``, so the input aspect ratio
    is not preserved; callers are expected to pass an already-square crop.
    Downscaling uses ``cv2.INTER_AREA`` interpolation.

    image: image array (process_images passes the cropped cv2.imread result)
    new_height: int (side length of the square output, in pixels)
    returns: the resized image array
    """
    # cv2.resize takes its target size as (width, height); both are equal here.
    square_size = (new_height, new_height)
    return cv2.resize(image, square_size, interpolation=cv2.INTER_AREA)

# Serialises updates to the shared ``total`` counters. ``main`` runs
# ``process_images`` on a ThreadPoolExecutor, and ``total[key] += 1`` is a
# read-modify-write that is not atomic across threads, so unguarded
# increments can be lost.
_TOTAL_LOCK = threading.Lock()

# Nominal height, in pixels, of a full drawing sheet in the raw scans.
_DRAWING_PIXEL_HEIGHT = 2945
# Width / height ratio of the drawing sheet (sqrt(2); presumably an ISO 216
# "A"-series sheet - confirm).
_DRAWING_ASPECT_RATIO = np.sqrt(2)
# Images smaller than this fraction of the nominal drawing size are rejected.
_MIN_SIZE_FRACTION = 2 / 3
# Fraction of the nominal drawing height kept by the vertical center crop.
_HEIGHT_CROP_FRACTION = 0.95
# Side length, in pixels, of the square output image.
_OUTPUT_SIZE = 448


def _increment(total, key):
    """Add one to ``total[key]`` while holding the shared counter lock."""
    with _TOTAL_LOCK:
        total[key] += 1


def _lookup_labels(images_df, filename):
    """
    Return the single label row of ``images_df`` whose 'filename' equals ``filename``.

    Raises ValueError unless exactly one row matches, the same situations in
    which the previous per-column ``Series.item()`` lookups failed.
    """
    matches = images_df.loc[images_df['filename'] == filename]
    if len(matches) != 1:
        raise ValueError(
            f"expected exactly one label row for {filename!r}, found {len(matches)}"
        )
    return matches.iloc[0]


def _center_crop(image, drawing_pixel_height, crop_fraction=_HEIGHT_CROP_FRACTION):
    """
    Center crop ``image`` vertically, then horizontally to a square.

    If the image is taller than ``crop_fraction * drawing_pixel_height``, its
    height is center cropped to exactly that many rows. The width is then center
    cropped to the resulting height. With an odd leftover, the extra pixel is
    trimmed from the bottom/right edge.
    """
    target_height = int(drawing_pixel_height * crop_fraction)
    height, width = image.shape[:2]

    if height > target_height:
        top = (height - target_height) // 2
        # The window is target_height tall. The old code ended it at
        # top + drawing_pixel_height, which kept 100% of the drawing height
        # and shifted the window down instead of centering a 95% crop.
        image = image[top:top + target_height, :]
        height = target_height

    if width > height:
        left = (width - height) // 2
        image = image[:, left:left + height]

    return image


def process_images(filename, image_dir_path, new_image_dir_path, new_image_filenames,
                   images_df, total, *, save=False):
    """
    Filter, crop and resize one raw image, updating the shared ``total`` counters.

    Steps; the function returns early as soon as a step rejects the image:

    1. Skip if already processed (``filename`` in ``new_image_filenames``); not counted.
    2. Skip if the file name does not end in '.jpg'; not counted.
    3. Skip (count 'written_form') if the label row marks a written form.
    4. Skip (count 'duplicate') if the label row marks a duplicate not to keep.
    5. Skip (count 'corrupted') if the label row marks the image as corrupted.
    6. Skip (count 'portrait') if the image is taller than it is wide.
    7. Skip (count 'small') if either dimension is below 2/3 of the nominal
       drawing size.
    8. If taller than 95% of the nominal drawing height, center crop the height
       to exactly that 95% height.
    9. Center crop the width to the (cropped) height, giving a square.
    10. Resize to 448 x 448.
    11. If ``save`` is True, write the result to ``new_image_dir_path``.
        Then count 'saved'.

    Flag columns are compared explicitly with ``== True`` / ``== False``, not
    by truthiness, so a missing label (NaN) never rejects an image.

    If cv2.imread returns None, the file name is printed and the image is
    skipped. Any exception (for example zero or several label rows, or a failed
    write) is printed and the image is skipped without being counted.

    filename: str
    image_dir_path: str (path to the original image directory)
    new_image_dir_path: str (path to the new image directory)
    new_image_filenames: container of filenames already in the new image
        directory. It is membership-tested on every call, so a set is preferred.
    images_df: pandas DataFrame with 'filename', 'written_form',
        'keep_duplicate' and 'corrupted' columns
    total: dict of counters ('written_form', 'duplicate', 'corrupted',
        'portrait', 'small', 'saved'), incremented in place under a lock
    save: bool, keyword-only. Defaults to False (dry run), matching the previous
        behaviour where the write was commented out. In that case 'saved'
        counts images that passed every filter.
    """
    try:
        # 1. Already processed.
        if filename in new_image_filenames:
            return

        # 2. Only jpg files are considered.
        if not filename.endswith('.jpg'):
            return

        labels = _lookup_labels(images_df, filename)

        # 3. Written form.
        if labels['written_form'] == True:  # noqa: E712 - NaN must not match
            _increment(total, 'written_form')
            return

        # 4. Duplicate that should not be kept.
        if labels['keep_duplicate'] == False:  # noqa: E712 - NaN must not match
            _increment(total, 'duplicate')
            return

        # 5. Corrupted.
        if labels['corrupted'] == True:  # noqa: E712 - NaN must not match
            _increment(total, 'corrupted')
            return

        image = cv2.imread(os.path.join(image_dir_path, filename))
        if image is None:
            print('Filename is None: ', filename)
            return

        height, width = image.shape[:2]

        # 6. Portrait orientation.
        if height > width:
            _increment(total, 'portrait')
            return

        # 7. Too small relative to the nominal drawing size.
        drawing_pixel_width = _DRAWING_PIXEL_HEIGHT * _DRAWING_ASPECT_RATIO
        if (height < _MIN_SIZE_FRACTION * _DRAWING_PIXEL_HEIGHT
                or width < _MIN_SIZE_FRACTION * drawing_pixel_width):
            _increment(total, 'small')
            return

        # 8-9. Center crop to 95% of the drawing height, then to a square.
        image = _center_crop(image, _DRAWING_PIXEL_HEIGHT, _HEIGHT_CROP_FRACTION)

        # 10. Resize.
        resized_image = resize_image(image, new_height=_OUTPUT_SIZE)

        # 11. Save (optional). cv2.imwrite reports failure via its return value
        # rather than raising, so check it explicitly.
        if save:
            output_path = os.path.join(new_image_dir_path, filename)
            if not cv2.imwrite(output_path, resized_image):
                raise OSError(f"cv2.imwrite failed for {output_path}")

        _increment(total, 'saved')

    except Exception as e:
        # Best effort per image: report and move on so one bad file does not
        # abort the whole batch.
        print(f"Error processing {filename}: {e}")

def main():
    """
    Run the filter / crop / resize pipeline over every file in the raw image
    directory, then print how many images fell into each outcome category.

    The directory and label paths below are placeholders and must be edited
    before running.
    """

    ### EDIT PATHS ###
    repo_root = Path.cwd().parent
    image_dir_path = "/path/to/raw_images"
    labels_path = repo_root / "data/image_preprocessing_labels.xlsx"
    new_image_dir_path = "/path/to/preprocessed_images"

    images_df = pd.read_excel(labels_path)

    os.makedirs(new_image_dir_path, exist_ok=True)

    # Per-outcome counters, updated by process_images.
    total = {'portrait': 0,
             'small': 0,
             'written_form': 0,
             'duplicate': 0,
             'corrupted': 0,
             'saved': 0
             }

    # Dynamically set the number of workers based on CPU count
    # (os.cpu_count() may return None, which lets the executor pick its default).
    num_workers = os.cpu_count()
    print('Number of Workers: ', num_workers)

    image_filenames = os.listdir(image_dir_path)
    # A set makes each per-image "already processed?" check O(1); a list would
    # be scanned in full for every input image.
    new_image_filenames = set(os.listdir(new_image_dir_path))

    # Bind all arguments except the filename so executor.map can supply it.
    worker = partial(process_images,
                     image_dir_path=image_dir_path,
                     new_image_dir_path=new_image_dir_path,
                     new_image_filenames=new_image_filenames,
                     images_df=images_df,
                     total=total)

    # Process the images in parallel; list() drains the iterator so tqdm
    # advances and all work finishes before the summary is printed.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        list(tqdm(executor.map(worker, image_filenames), total=len(image_filenames)))

    # Print the total count of images per category.
    for category, count in total.items():
        print(f'{category}: {count}')

if __name__ == '__main__':
    # Run the preprocessing pipeline only when this file is executed directly,
    # not when it is imported.
    main()

portrait: 145
small: 191
written_form: 58
duplicate: 497
corrupted: 7
saved: 9700