Note: the data used in this project can be downloaded from: https://www.dropbox.com/s/cst9awcjpp08k33/50_categories.tar.gz

## Imports

In [26]:
from skimage import feature, filters
from skimage.io import imread
from skimage.segmentation import felzenszwalb
from skimage.color import rgb2grey
from itertools import combinations
import numpy as np
import os
import time
import pickle as pkl
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, zero_one_loss
import matplotlib.pyplot as plt
%matplotlib inline

## Compute Features from Data

In [2]:
def feature_extract(img_file):
    
    # read image
    im_arr = imread(img_file)
    
    # also get greyscale image to use for some of the feature extraction
    gi = rgb2grey(im_arr)
    
    # check if image has only one color channel and if so stack the image three times to ensure that there are three
    # (unfortunately) identical color channels
    if len(im_arr.shape) == 2:
        im_arr = np.dstack((im_arr, im_arr, im_arr))
    
    ## instatiate list (later to convert to array) to hold features
    features = []
    
    # first do some dumb features:
    
    # calculate the mean in each color
    color_means = im_arr.mean(axis=1).mean(axis=0)
    
    # use the ratio of maximum value in each color to the mean of each color
    color_max_div_mean = im_arr.max(axis=1).max(axis=0) / color_means
    features += list(color_max_div_mean)
     
    # use the ratio of standard deviations in each color to mean in each color as another set of features
    color_std_div_mean = im_arr.std(axis=1).mean(axis=0) / color_means
    features += list(color_std_div_mean)
    
    # use ratios of means, and correlation coefficients between flattened as additional features
    mean_ratios = []
    corr_coefs = []
    for idx_pair in combinations(range(3), 2):
        mean_ratios.append(color_means[idx_pair[0]] / color_means[idx_pair[1]])
        corr_coefs.append(np.corrcoef(im_arr[:,:,idx_pair[0]].flatten(), im_arr[:,:,idx_pair[1]].flatten())[0,1])
    features += mean_ratios + corr_coefs
    
    # encode edge information
    for i in range(3):
        features.append(np.mean(filters.sobel(im_arr[:,:,i])))
        features.append(np.mean(filters.sobel_v(im_arr[:,:,i])))
        features.append(np.mean(filters.sobel_h(im_arr[:,:,i])))
        
    # encode segmentation information
    features.append(felzenszwalb(im_arr).mean() / im_arr.mean())
    
    # include proportion of edges detected in greyscale image relative to number of pixels
    features.append(np.count_nonzero(feature.canny(gi)) / (np.shape(gi)[0] * np.shape(gi)[1]))
    
    # include ratio of max to mean of SIFT extracted features
    tmp = feature.daisy(gi)
    features.append(tmp.max() / tmp.mean())
    
    return np.array(features)

In [4]:
%%timeit
feature_extract('50_categories/bat/bat_0060.jpg')

370 ms ± 16.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Run Feature Extraction

In [5]:
path = '50_categories/'

# get labels
labels = np.array([dr for dr in os.listdir(path) if '.DS_Store' not in dr])

# process features
X = []
y = []
for label in labels:
    print('Processing label: {}'.format(label))
    for fl in [fl for fl in os.listdir(path + label) if '.DS_Store' not in fl]:
        X.append(feature_extract(path + label + '/' + fl))
        y.append(label)
        
X = np.array(X)
y = np.array(y)

Processing label: gorilla
Processing label: raccoon
Processing label: crab
Processing label: blimp
Processing label: snail
Processing label: airplanes
Processing label: dog
Processing label: dolphin
Processing label: goldfish
Processing label: giraffe
Processing label: bear
Processing label: killer-whale
Processing label: penguin
Processing label: zebra
Processing label: duck
Processing label: conch
Processing label: camel
Processing label: owl
Processing label: helicopter
Processing label: starfish
Processing label: saturn
Processing label: galaxy
Processing label: goat
Processing label: iguana
Processing label: elk
Processing label: hummingbird
Processing label: triceratops
Processing label: porcupine
Processing label: teddy-bear
Processing label: comet
Processing label: hot-air-balloon
Processing label: leopards
Processing label: toad
Processing label: mussels
Processing label: kangaroo
Processing label: speed-boat
Processing label: bat
Processing label: swan
Processing label: octop

## Perform Classification

In [37]:
# prepare training and testing data
X_tr_tmp, X_test_tmp, y_train, y_test = train_test_split(X, y, test_size = 0.2, stratify = y, random_state = 100)

# fit and scale training data
X_scaler = StandardScaler()
X_train = X_scaler.fit_transform(X_tr_tmp)

# use scaling from training data to transform testing data
X_test = X_scaler.transform(X_test_tmp)

In [38]:
# determine baseline
d_clf = DummyClassifier(strategy='prior')
d_clf.fit(X_train, y_train)
d_clf.score(X_test, y_test)

0.12603062426383982

In [39]:
# do random forest classification with default params
rf_clf = RandomForestClassifier(class_weight='balanced', n_jobs=-1, random_state=100)
rf_clf.fit(X_train, y_train)
rf_clf.score(X_test, y_test)

0.21554770318021202

In [40]:
# do grid search over parameters with random forest classifier (note GridSearch provides access to best model by default when calling score/predict/etc)
parameters = {'n_estimators': [10, 50, 150, 200, 300], 'max_depth': [10, 50, 100], 'min_samples_split': [2, 3, 4, 5]}
cross_val = StratifiedKFold(n_splits=6, random_state = 100)
gs = GridSearchCV(RandomForestClassifier(class_weight='balanced', n_jobs=-1, random_state=100), parameters, cv = cross_val, n_jobs=-1)
gs.fit(X_train, y_train)
gs.score(X_test, y_test)

0.2767962308598351

## Evaluate Classification

In [41]:
def eval_class(clf, X_test, y_test):
    
    pred = clf.predict(X_test)
    
    print('Classification Metrics, between 0 and 1\n')
    
    print('Accuracy Score: {:.3f}'.format(accuracy_score(y_test, pred)))
    print('proportion of correct classifications - higher better\n')
    
    print('Precision Score: {:.3f}'.format(precision_score(y_test, pred, average='weighted')))
    print('tp / (tp + fp), how good at not having fp - higher better\n')
    
    print('Recall Score: {:.3f}'.format(recall_score(y_test, pred, average='weighted')))
    print('tp / (tp + fn), how good at finding positives - higher better\n')
    
    print('Zero-One Loss: {:.3f}'.format(zero_one_loss(y_test, pred)))
    print('fraction of misclassifications - smaller better')
    
    print('\nFeature Importances: {}'.format(clf.feature_importances_))

In [42]:
eval_class(rf_clf, X_test, y_test)

Classification Metrics, between 0 and 1

Accuracy Score: 0.216
proportion of correct classifications - higher better

Precision Score: 0.191
tp / (tp + fp), how good at not having fp - higher better

Recall Score: 0.216
tp / (tp + fn), how good at finding positives - higher better

Zero-One Loss: 0.784
fraction of misclassifications - smaller better

Feature Importances: [0.04257338 0.04363274 0.04058682 0.03779667 0.04043667 0.04382056
 0.04664514 0.04677576 0.04814077 0.04128299 0.04487028 0.04145193
 0.03660696 0.03296103 0.03440791 0.03893964 0.03573133 0.03743222
 0.04086285 0.03518597 0.04531534 0.05204991 0.04869899 0.04379412]


In [43]:
eval_class(gs.best_estimator_, X_test, y_test)

Classification Metrics, between 0 and 1

Accuracy Score: 0.277
proportion of correct classifications - higher better

Precision Score: 0.227
tp / (tp + fp), how good at not having fp - higher better

Recall Score: 0.277
tp / (tp + fn), how good at finding positives - higher better

Zero-One Loss: 0.723
fraction of misclassifications - smaller better

Feature Importances: [0.0425916  0.04262152 0.04069943 0.03980456 0.04040433 0.04217345
 0.04661018 0.04615242 0.04600586 0.04545546 0.04376005 0.04185155
 0.03786187 0.03571514 0.03728017 0.03784756 0.03602148 0.03800747
 0.03942975 0.03478082 0.04032359 0.05230722 0.04684658 0.04544794]


## Save Model for Future Use

In [44]:
with open('model.pkl', 'wb') as f:
    pkl.dump({'model': gs, 'X_scaler': X_scaler}, f)

## Package Classifier for Future Use

In [45]:
def run_final_classifier(path, img_type = '.jpg', model_file = 'model.pkl', output_fname = 'predicted_classes.txt', return_arrays = False):
    
    # read model and scaler from model file
    with open(model_file, 'rb') as f:
        m = pkl.load(f)
    model = m['model']
    X_scaler = m['X_scaler']
    
    # do basic validation of provide path
    if '/' != path[-1]:
        path += '/'
    
    # get list of image files to classify
    im_files = np.array([fl for fl in os.listdir(path) if img_type in fl])
    
    # get timing estimate
    t1 = time.time()
    feature_extract(path + im_files[0])
    t2 = time.time()
    dt = t2 - t1
    
    # do feature extraction
    print('Extracting Features from {} images.'.format(len(im_files)))
    X = []
    for idx, fl in enumerate(im_files):
        if idx % 50 == 0:
            print('\n\tIteration: {} of {}'.format(idx, len(im_files)))
            print('\tEstimated Time Remaining: {:.1f} seconds'.format(dt * (len(im_files) - idx)))
        X.append(feature_extract(path + fl))
        
    # scale X data
    X = X_scaler.transform(np.array(X))
    
    # do classification
    predicted_classes = model.predict(X)
    
    print('\nClasses predicted, writing output to: {}'.format(output_fname))
    
    # write output file
    with open(output_fname, 'w') as f:
        f.write('{:<20} {}\n'.format('filename', 'predicted_class'))
        f.write('-'*37 + '\n')
        for idx, fl in enumerate(im_files):
            f.write('{:<20} {}\n'.format(fl, predicted_classes[idx]))
            
    # optionally return arrays
    if return_arrays is True:
        return im_files, predicted_classes

In [None]:
run_final_classifier('50_categories/bat/')

Extracting Features from 71 images.

	Iteration: 0 of 71
	Estimated Time Remaining: 182.8 seconds

	Iteration: 50 of 71
	Estimated Time Remaining: 54.1 seconds
