In [None]:
#!/usr/bin/env python

"""image_classification.py: Classify fish images by type (bag of visual words)"""

import os
import glob
import joblib
import cv2
import numpy as np
from scipy.cluster import vq

import pandas as pd

from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import StandardScaler


__author__ = "Pradeep Kumar A.V."


# Class-folder name -> numeric label (labels start at 1, in this order).
CLASSES = {
    name: label
    for label, name in enumerate(
        ('ALB', 'BET', 'DOL', 'LAG', 'NoF', 'OTHER', 'SHARK', 'YFT'),
        start=1)
}

# Numeric label -> class-folder name (inverse of CLASSES).
CLASSES_REV = dict((label, name) for name, label in CLASSES.items())


class Saliency(object):
    """Generate saliency map from RGB images with the spectral residual method

        This class implements an algorithm that is based on the spectral
        residual approach (Hou & Zhang, 2007).
    """
    def __init__(self, img, use_numpy_fft=True, gauss_kernel=(3, 3)):
        """Constructor

            This method initializes the saliency algorithm.

            :param img: an RGB input image (3 channels; the histogram
                        equalization step requires 8-bit data)
            :param use_numpy_fft: flag whether to use NumPy's FFT (True) or
                                  OpenCV's FFT (False)
            :param gauss_kernel: Kernel size for Gaussian blur, or None to
                                 skip the blur
        """
        self.use_numpy_fft = use_numpy_fft
        self.gauss_kernel = gauss_kernel
        # Only the shape of frame_orig is used later (to scale the map up).
        self.frame_orig = self._enhance_image(img)

        # downsample image for processing
        # NOTE(review): the small frame is built from the raw input, not
        # from the contrast-enhanced frame_orig -- confirm this is intended.
        self.small_shape = (24, 24)
        self.frame_small = cv2.resize(img, self.small_shape[1::-1])

    @staticmethod
    def _enhance_image(img):
        """Equalize the luminance histogram of an RGB image

        :param img: RGB color image
        :return: enhanced RGB image
        """
        img_yuv = cv2.cvtColor(img, cv2.COLOR_RGB2YUV)

        # equalize the histogram of the Y channel
        img_yuv[:, :, 0] = cv2.equalizeHist(img_yuv[:, :, 0])

        # convert the YUV image back to RGB format (the inverse of the
        # RGB2YUV conversion above; YUV2BGR would swap the R and B channels)
        return cv2.cvtColor(img_yuv, cv2.COLOR_YUV2RGB)

    def get_saliency_map(self):
        """Returns a saliency map

            This method generates a saliency map for the image that was
            passed to the class constructor.

            :returns: float32 grayscale saliency map with the same height
                      and width as the input image, scaled so that its
                      maximum is 1 (left unscaled if the map is all zero)
        """
        # multiple channels: consider each channel independently
        n_channels = self.frame_small.shape[2]
        sal = np.zeros(self.frame_small.shape, dtype=np.float32)
        for c in range(n_channels):
            sal[:, :, c] = self._get_channel_sal_magn(
                self.frame_small[:, :, c])

        # overall saliency: channel mean
        sal = np.mean(sal, 2)

        # postprocess: blur, square, and normalize
        if self.gauss_kernel is not None:
            sal = cv2.GaussianBlur(sal, self.gauss_kernel, sigmaX=1,
                                   sigmaY=0)
        sal **= 2
        peak = np.max(sal)
        if peak > 0:
            # guard against dividing by zero on a completely flat map
            sal = np.float32(sal) / peak

        # scale up
        sal = cv2.resize(sal, self.frame_orig.shape[1::-1])

        return sal

    def _get_channel_sal_magn(self, channel):
        """Returns the spectral-residual saliency magnitude of one channel

            The channel is transformed to the frequency domain, its
            log-amplitude spectrum is replaced by the residual against a
            3x3 box-blurred copy, and the result is transformed back using
            the original phase. The magnitude of that inverse transform is
            returned.

            :param channel: single-channel input image
            :returns: magnitude of the inverse transform, same height and
                      width as ``channel``
        """
        # do FFT and split it into amplitude and phase
        if self.use_numpy_fft:
            img_dft = np.fft.fft2(channel)
            magnitude, angle = cv2.cartToPolar(np.real(img_dft),
                                               np.imag(img_dft))
        else:
            img_dft = cv2.dft(np.float32(channel),
                              flags=cv2.DFT_COMPLEX_OUTPUT)
            magnitude, angle = cv2.cartToPolar(img_dft[:, :, 0],
                                               img_dft[:, :, 1])

        # get log amplitude (clipped so that log of zero is never taken)
        log_ampl = np.log10(magnitude.clip(min=1e-9))

        # blur log amplitude with avg filter
        log_ampl_blur = cv2.blur(log_ampl, (3, 3))

        # residual
        # NOTE(review): the log above is base 10 while this is the natural
        # exponential; kept unchanged because switching either one alters
        # the resulting map -- confirm the intended base.
        residual = np.exp(log_ampl - log_ampl_blur)

        # back to cartesian frequency domain, then to the image domain
        if self.use_numpy_fft:
            real_part, imag_part = cv2.polarToCart(residual, angle)
            img_combined = np.fft.ifft2(real_part + 1j*imag_part)
            magnitude, _ = cv2.cartToPolar(np.real(img_combined),
                                           np.imag(img_combined))
        else:
            img_dft[:, :, 0], img_dft[:, :, 1] = cv2.polarToCart(residual,
                                                                 angle)
            img_combined = cv2.idft(img_dft)
            magnitude, _ = cv2.cartToPolar(img_combined[:, :, 0],
                                           img_combined[:, :, 1])

        return magnitude


class FishClassifierBOVW(object):
    """Bag-of-visual-words image classifier using Gaussian naive Bayes

    Local descriptors are extracted from every image, clustered with
    k-means into ``n_visual_words`` visual words, and each image is then
    represented by the (standardized) histogram of its words.
    """

    def __init__(self, pre_trained_model=None, descriptor='ORB',
                 n_visual_words=5, use_saliency=False):
        """
        :param pre_trained_model: path of a joblib-pickled classifier to
         load, or None to start without a model
        :param descriptor: 'SIFT' or 'ORB'; any other value uses the raw
         pixel values as descriptors
        :param n_visual_words: number of k-means clusters (visual words)
        :param use_saliency: weight images by their saliency map before
         extracting descriptors
        """
        # NOTE: joblib.load unpickles the file; only load trusted models.
        self.model = joblib.load(pre_trained_model) \
            if pre_trained_model else None
        self.n_visual_words = n_visual_words
        self.descriptor = descriptor
        self.use_saliency = use_saliency
        # Visual-word centers and feature scaler learned during training.
        # They are kept on the instance only (not stored in model.pkl), so
        # a classifier loaded from disk starts without them.
        self.vocabulary = None
        self.scaler = None

    # Helper functions
    @staticmethod
    def _load_img(path):
        """
        :param path: path of image to be loaded.
        :return: cv2 image object in RGB channel order
        :raises IOError: if the file cannot be read as an image
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None
            raise IOError("Unable to read image: %s" % path)
        # Convert the image from cv2 default BGR format to RGB
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _pretty_print(msg):
        """Print ``msg`` framed by lines of '=' after a blank line"""
        print()
        print('=' * len(msg))
        print(msg)
        print('=' * len(msg))

    def _detect_and_describe(self, image):
        """
        :param image: Input RGB color image
        :return: keypoints and features tuple; keypoints are None when raw
         pixel values are used as features
        """
        # detect and extract features from the image
        if self.descriptor == 'SIFT':
            # SIFT lives in the main module in newer OpenCV builds and in
            # the contrib xfeatures2d module in older ones.
            create_sift = getattr(cv2, 'SIFT_create', None)
            if create_sift is None:
                create_sift = cv2.xfeatures2d.SIFT_create
            descriptor = create_sift()
        elif self.descriptor == 'ORB':
            descriptor = cv2.ORB_create()
        else:
            # no detector: every pixel becomes one 3-value descriptor
            data = image.reshape((-1, 3))
            data = np.float32(data)
            return None, data
        (kps, features) = descriptor.detectAndCompute(image, None)

        # convert the keypoints from KeyPoint objects to NumPy
        # arrays
        kps = np.float32([kp.pt for kp in kps])
        if features is None:
            # no keypoints found: use an empty descriptor matrix so that
            # stacking and quantization keep working
            features = np.zeros((0, descriptor.descriptorSize()),
                                dtype=np.float32)
        else:
            features = np.float32(features)

        # return a tuple of keypoints and features
        return kps, features

    @staticmethod
    def _kmeans_clustering(data, k=5):
        """
        :param data: input data, one float32 sample per row
        :param k: K value
        :return: k-Means cluster centers
        :raises ValueError: if there are fewer samples than clusters
        """
        if len(data) < k:
            raise ValueError(
                "Need at least %d descriptors to form %d clusters, got %d"
                % (k, k, len(data)))
        crit = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
        flags = cv2.KMEANS_RANDOM_CENTERS
        _, _, centers = cv2.kmeans(data, k, None, crit, 10, flags)
        return centers

    @staticmethod
    def _histogram_features(descriptor_sets, centers):
        """Count, per image, how many descriptors fall on each visual word

        :param descriptor_sets: sequence of descriptor matrices, one per
         image
        :param centers: visual-word centers, one per row
        :return: float32 array of shape (n_images, n_words)
        """
        n_words = len(centers)
        histograms = np.zeros((len(descriptor_sets), n_words), "float32")
        for row, des in enumerate(descriptor_sets):
            if len(des) == 0:
                # image without descriptors keeps an all-zero histogram
                continue
            words, _ = vq.vq(des, centers)
            histograms[row] = np.bincount(words, minlength=n_words)
        return histograms

    #  Main wrapper methods

    def _extract_img_features(self, img_data_dir, mode='train'):
        """
        :param img_data_dir: directory path where the images reside.
         The training images should reside in class named folders
        :param: mode: 'train' or 'test' (anything other than 'train' is
         treated as 'test')
        :return: tuple of (file paths, scaled BOVW feature matrix, float32
         label column). Labels are all zero in test mode.
        :raises ValueError: if no files are found
        """
        is_train = mode == 'train'
        if is_train:
            pattern = os.path.join(img_data_dir, '*', '*')
        else:
            pattern = os.path.join(img_data_dir, '*')
        # sorted so that the processing order is reproducible
        files = sorted(glob.glob(pattern))
        if not files:
            raise ValueError("No images found under %s" % img_data_dir)
        dataset_size = len(files)
        resp = np.zeros((dataset_size, 1))

        print("\nProcessing images, and generating descriptors..\n")
        descriptor_sets = []
        for idx, path in enumerate(files):
            print("Processing image %s" % path)
            img = self._load_img(path)
            if self.use_saliency:
                smap = Saliency(img).get_saliency_map()
                # Weight all channels by the saliency map. The uint8 image
                # cannot be multiplied in place by a float map, so compute
                # in float and convert back.
                img = (img * smap[:, :, np.newaxis]).astype(np.uint8)
            _, des = self._detect_and_describe(img)
            descriptor_sets.append(des)
            if is_train:
                # the parent folder name is the class label
                class_name = os.path.basename(os.path.dirname(path))
                resp[idx] = CLASSES[class_name]

        reuse_training_state = (not is_train
                                and self.vocabulary is not None
                                and self.scaler is not None)
        if reuse_training_state:
            # Test images must be described with the vocabulary and scaling
            # learned from the training set, otherwise the histogram bins
            # do not line up with what the classifier was trained on.
            centers = self.vocabulary
            scaler = self.scaler
            im_features = self._histogram_features(descriptor_sets, centers)
        else:
            print("\nClustering the descriptors to form BOVW dictionary..\n")
            centers = self._kmeans_clustering(np.vstack(descriptor_sets),
                                              self.n_visual_words)
            im_features = self._histogram_features(descriptor_sets, centers)
            scaler = StandardScaler().fit(im_features)
            if is_train:
                self.vocabulary = centers
                self.scaler = scaler

        # Scaling the values of features
        im_features = scaler.transform(im_features)

        resp = np.float32(resp)
        return files, im_features, resp

    def train_classifier(self, train_data_dir, save_model=True):
        """
        :param train_data_dir: training data directory
        :param save_model: save classifier model as a pickle file
        :return: None
        """
        # Extract features and train the classifier
        self._pretty_print("Extracting training image features")
        train_files, train_data, train_resp = \
            self._extract_img_features(train_data_dir)
        self._pretty_print("Training the classifier")
        self.model = GaussianNB()
        # the classifier expects a 1-D label vector
        self.model.fit(train_data, train_resp.ravel())
        if save_model:
            joblib.dump(self.model, 'model.pkl', protocol=2)

    def test_classifier(self, test_data_dir, submission_file="submission.csv"):
        """
        :param test_data_dir: test data directory
        :param submission_file: file name to save predictions
        :return: None
        :raises RuntimeError: if no model has been trained or loaded
        """
        if self.model is None:
            raise RuntimeError(
                "No classifier available: train one or load a saved model")
        self._pretty_print("Extracting testing image features")
        test_files, test_data, _ = \
            self._extract_img_features(test_data_dir, mode='test')
        self._pretty_print("Testing the classifier")
        predictions = self.model.predict_proba(test_data)

        # one probability column per class known to the model
        columns = [CLASSES_REV[int(entry)] for entry in self.model.classes_]
        submission = pd.DataFrame(predictions, columns=columns)
        images = [os.path.basename(f) for f in test_files]
        submission.insert(0, 'image', images)
        submission.to_csv(submission_file, index=False)


def main():
    """
    Main wrapper to call the classifier

    Loads the saved model when one exists, otherwise trains a new
    classifier on the training directory, then writes predictions for the
    test directory to the submission file.

    :return: None
    """
    pre_trained_model_file = 'model.pkl'
    training_data_dir = '../input/train'
    testing_data_dir = '../input/test_stg1'
    submission_file_name = 'Bag_of_visual_words_ORB_NB.csv'

    parameters = {
        'descriptor': 'ORB',
        'n_visual_words': 5,
        'use_saliency': True
    }

    if os.path.exists(pre_trained_model_file):
        # a saved model is available: reuse it instead of retraining
        classifier = FishClassifierBOVW(pre_trained_model_file, **parameters)
    else:
        # no saved model yet: train one before predicting
        classifier = FishClassifierBOVW(**parameters)
        classifier.train_classifier(training_data_dir)

    classifier.test_classifier(testing_data_dir, submission_file_name)


if __name__ == '__main__':
    main()


In [None]:
from subprocess import check_output
# Show what is in the current working directory.
print(str(check_output(["ls"]), "utf8"))