# Pre-process the dataset for FelineFlow
The original Cats vs. Dogs dataset contains more than 12,000 images of cats alone, each of a different size! We need to modify the dataset to suit our needs.

In this notebook, we will apply the following transformations:
- Select 4096 images for training and testing.
- Crop the images to a 1:1 aspect ratio.
- Check that each image contains a single cat in frame, using Haar Cascades (and choose another image otherwise).
- Downscale the image to 128x128.

Let's get started by importing the necessary modules.

In [109]:
import os
import cv2

Let's create a function to centre-crop an image to a 1:1 aspect ratio.

In [110]:
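def crop_square(img_path):
    # cv2.imread returns None if the file cannot be read
    img = cv2.imread(img_path)
    height, width = img.shape[:2]
    if height == width:
        return img
    if height > width:
        # Portrait: keep the central band of rows
        return img[(height-width)//2: (height+width)//2, :]
    # Landscape: keep the central band of columns
    return img[:, (width-height)//2: (width+height)//2]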
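
The slicing arithmetic in `crop_square` can be checked without loading an image. The following is a minimal sketch; `center_crop_bounds` is a hypothetical helper, not part of the notebook, that returns the row and column bounds of the central square for a given height and width.

```python
def center_crop_bounds(height, width):
    """Return (top, bottom, left, right) bounds of the central square."""
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    return top, top + side, left, left + side

# A 300x500 (height x width) landscape image keeps columns 100..399
print(center_crop_bounds(300, 500))  # (0, 300, 100, 400)
# A 500x300 portrait image keeps rows 100..399
print(center_crop_bounds(500, 300))  # (100, 400, 0, 300)
```

The bounds are half-open, so they can be used directly as NumPy slices: `img[top:bottom, left:right]`.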

Next, a function to resize the image to a specified resolution.

In [111]:
def resize(img, res):
    return cv2.resize(img, res, interpolation=cv2.INTER_AREA)
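
`cv2.INTER_AREA` is the interpolation that the OpenCV documentation recommends for shrinking images, while `cv2.INTER_CUBIC` or `cv2.INTER_LINEAR` are suggested for enlarging. If some source images might be smaller than the target resolution, the flag could be chosen per image. A minimal sketch, assuming a hypothetical helper `pick_interpolation` (not part of the notebook) that returns the flag name as a string:

```python
def pick_interpolation(src_size, dst_size):
    """Return the name of an OpenCV interpolation flag.

    src_size and dst_size are (width, height) tuples.
    """
    src_w, src_h = src_size
    dst_w, dst_h = dst_size
    # Not enlarging in either dimension: area averaging reduces aliasing
    if dst_w <= src_w and dst_h <= src_h:
        return "INTER_AREA"
    # Otherwise the image is being enlarged in at least one dimension
    return "INTER_CUBIC"

print(pick_interpolation((500, 500), (128, 128)))  # INTER_AREA
print(pick_interpolation((100, 100), (128, 128)))  # INTER_CUBIC
```

The returned name can be turned into the actual flag with `getattr(cv2, name)` and passed as the `interpolation` argument of `cv2.resize`.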

Next, a function that recognises cat faces. We will be using the [Cat Frontal Face Haar Cascade](https://github.com/opencv/opencv/blob/master/data/haarcascades/haarcascade_frontalcatface_extended.xml) contributed by Joseph Howse, provided in the OpenCV GitHub repository.

In [129]:
def isValidCat(img):
    # Haar cascades operate on single-channel images
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    cascade = cv2.CascadeClassifier('./haarcascade_frontalcatface.xml')
    faces = cascade.detectMultiScale(gray, 1.1, 3)  # image, scaleFactor, minNeighbors
    # Valid only if exactly one cat face is detected
    return len(faces) == 1

In [113]:
isValidCat(crop_square('./cats_source/122.jpg'))

False

And finally, a function to generate the dataset.

In [116]:
def generate_dataset(input_dir, output_dir, res, num_images):
    os.makedirs(output_dir, exist_ok=True)

    input_files = os.listdir(input_dir)
    if len(input_files) < num_images:
        raise Exception("Not enough images in source directory")
    count = 1
    for file in input_files:
        # count is the index of the next image to save, so stop once it exceeds num_images
        if count > num_images:
            break
        try:
            cropped_image = crop_square(os.path.join(input_dir, file))
            # Skip images that fail the cat check
            if not isValidCat(cropped_image):
                continue

            resized_image = resize(cropped_image, res)
            destination_path = os.path.join(output_dir, str(count)) + '.jpg'
            cv2.imwrite(destination_path, resized_image)
            print(destination_path, 'saved')
            count += 1
        except AttributeError:
            # crop_square fails on unreadable files (cv2.imread returned None)
            continue
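
The core of the loop above is "keep the first N items that pass a check". A minimal sketch of that pattern in isolation, using a hypothetical `take_valid` helper (not part of the notebook) and plain integers in place of images:

```python
from itertools import islice

def take_valid(items, is_valid, n):
    """Return the first n items for which is_valid(item) is truthy."""
    return list(islice(filter(is_valid, items), n))

# Stand-in data: keep the first 3 even numbers
selected = take_valid(range(10), lambda x: x % 2 == 0, 3)
print(selected)  # [0, 2, 4]
# Fewer than n valid items: the result is simply shorter
print(take_valid(range(3), lambda x: x % 2 == 0, 3))  # [0, 2]
```

Comparing `len(selected)` with `n` afterwards shows whether the source ran out of valid items before the target was reached, which a check on the total file count alone cannot reveal.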

Let's specify the parameters and watch the script work its magic!

In [117]:
generate_dataset('./cats_source', './cats_processed', (256,256), 4096)

./cats_processed\1.jpg saved
./cats_processed\2.jpg saved
./cats_processed\3.jpg saved
./cats_processed\4.jpg saved
./cats_processed\5.jpg saved
./cats_processed\6.jpg saved
./cats_processed\7.jpg saved
./cats_processed\8.jpg saved
./cats_processed\9.jpg saved
./cats_processed\10.jpg saved
./cats_processed\11.jpg saved
./cats_processed\12.jpg saved
./cats_processed\13.jpg saved
./cats_processed\14.jpg saved
./cats_processed\15.jpg saved
./cats_processed\16.jpg saved
./cats_processed\17.jpg saved
./cats_processed\18.jpg saved
./cats_processed\19.jpg saved
./cats_processed\20.jpg saved
./cats_processed\21.jpg saved
./cats_processed\22.jpg saved
./cats_processed\23.jpg saved
./cats_processed\24.jpg saved
./cats_processed\25.jpg saved
./cats_processed\26.jpg saved
./cats_processed\27.jpg saved
./cats_processed\28.jpg saved
./cats_processed\29.jpg saved
./cats_processed\30.jpg saved
./cats_processed\31.jpg saved
./cats_processed\32.jpg saved
./cats_processed\33.jpg saved
./cats_processed\34

KeyboardInterrupt: 