# In [3]:
import cv2
import pickle
import numpy as np
from sklearn.cluster import KMeans
from sklearn import svm
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score


class CIFAR_10:
    """Bag-of-visual-words image classification experiment on CIFAR-10 batches.

    Constructing an instance runs the whole experiment: it loads the training
    and test batches, learns a visual vocabulary by clustering SIFT
    descriptors with KMeans, encodes every image as a histogram over that
    vocabulary, and finally prints the test accuracy of a linear SVM and a
    Logistic Regression classifier.

    Attributes set by ``__init__``:
        INPUT_PATH, TEST_BATCH: dataset directory and test batch file name.
        n_clusters: vocabulary size (number of KMeans clusters).
        random_state: seed forwarded to KMeans (``None`` = unseeded).
        sift: the OpenCV SIFT extractor.
        train_features, train_labels, test_features, test_labels:
            histogram matrices and label lists used by the classifiers.
    """

    def __init__(self, input_path='./cifar-10-batches-py', n_clusters=128,
                 train_batches=('/data_batch_1',), test_batch='/test_batch',
                 random_state=None):
        """Load the data, build the features and print classifier accuracies.

        Args:
            input_path: Directory containing the pickled batch files.
            n_clusters: Number of KMeans clusters (visual words).
            train_batches: Batch file name(s) used for training. Each name is
                appended verbatim to ``input_path``, so it must start with
                ``'/'``. A single string is accepted as well.
            test_batch: Batch file name used for evaluation (same convention).
            random_state: Seed passed to KMeans; ``None`` keeps the original
                unseeded behaviour, an int makes runs reproducible.
        """
        # Path for the input data
        self.INPUT_PATH = input_path
        self.TEST_BATCH = test_batch

        # Number of clusters for KMeans clustering
        self.n_clusters = n_clusters
        self.random_state = random_state

        # SIFT feature extractor
        self.sift = cv2.SIFT_create()

        # A bare string would otherwise be iterated character by character.
        if isinstance(train_batches, str):
            train_batches = (train_batches,)

        # Load the training data and perform preprocessing
        train_images, self.train_labels = self._load_batches(train_batches)
        vocab = self.preprocessing(train_images)
        self.train_features = self.BOW(train_images, vocab)

        # Load the test data and encode it with the vocabulary learned above
        test_images, self.test_labels = self.get_imgs_labels(self.TEST_BATCH)
        self.test_features = self.BOW(test_images, vocab)

        # Calculate and print the accuracies for SVM and Logistic Regression classifiers
        self.accuracy_calc()

    # Function to unpickle data from the CIFAR-10 dataset
    def unpickle(self, file):
        """Return the object stored in the pickle file at path ``file``.

        Keys of the loaded dict are bytes because ``encoding='bytes'`` is used.
        """
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file. Only point this at batch files from a trusted source.
        with open(file, 'rb') as fo:
            return pickle.load(fo, encoding='bytes')

    # Function to load the images and labels from a batch file
    def get_imgs_labels(self, batch_path):
        """Load one batch file and return ``(images, labels)``.

        Args:
            batch_path: File name appended verbatim to ``self.INPUT_PATH``
                (so it must start with ``'/'``).

        Returns:
            A tuple ``(img_list, labels)`` where ``img_list`` is a list of
            ``(32, 32, 3)`` arrays and ``labels`` is the ``b'labels'`` entry
            of the batch, unchanged.
        """
        loaded_data = self.unpickle(self.INPUT_PATH + batch_path)

        # Extract images and labels from the loaded data
        raw = np.asarray(loaded_data[b'data'])
        labels = loaded_data[b'labels']

        # Each row holds 3072 values: three consecutive 1024-value planes,
        # each a row-major 32x32 channel. Split into (N, 3, 32, 32) and move
        # the channel axis last so every image becomes (32, 32, 3).
        stacked = raw.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
        # Make the memory layout contiguous so each per-image slice is a
        # plain C-ordered array, as the original per-image dstack produced.
        stacked = np.ascontiguousarray(stacked)

        return (list(stacked), labels)

    def _load_batches(self, batch_paths):
        """Load several batch files and concatenate their images and labels."""
        images, labels = [], []
        for batch_path in batch_paths:
            batch_images, batch_labels = self.get_imgs_labels(batch_path)
            images.extend(batch_images)
            labels.extend(batch_labels)
        return images, labels

    def _extract_descriptors(self, img):
        """Return the SIFT descriptors of ``img``, or ``None`` if there are none."""
        # Images are converted with the RGB->gray conversion before detection.
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        _, desc = self.sift.detectAndCompute(gray, None)
        # Treat "no result" and "empty result" alike: neither can be clustered
        # nor assigned to a cluster.
        if desc is None or len(desc) == 0:
            return None
        return desc

    # Function to preprocess the images using SIFT feature extractor and KMeans clustering
    def preprocessing(self, imgs):
        """Learn the visual vocabulary from the SIFT descriptors of ``imgs``.

        Args:
            imgs: Iterable of ``(32, 32, 3)`` images.

        Returns:
            A fitted ``KMeans`` model with ``self.n_clusters`` clusters.

        Raises:
            ValueError: If no descriptor could be extracted from any image
                (previously this surfaced as an opaque scikit-learn error).
        """
        descriptors_list = [
            desc
            for desc in map(self._extract_descriptors, imgs)
            if desc is not None
        ]
        if not descriptors_list:
            raise ValueError(
                'No SIFT descriptors were found in the given images; '
                'cannot build the KMeans vocabulary.'
            )

        # Concatenate the descriptors into a single array
        descriptors = np.concatenate(descriptors_list, axis=0)

        return KMeans(
            n_clusters=self.n_clusters,
            random_state=self.random_state,
        ).fit(descriptors)

    # Function to compute the Bag of Words representation of images
    def BOW(self, imgs, vocab):
        """Encode each image as a histogram of visual-word occurrences.

        Args:
            imgs: Iterable of images.
            vocab: Fitted clustering model exposing ``predict``.

        Returns:
            Integer array of shape ``(len(imgs), self.n_clusters)``; row ``i``
            counts how many descriptors of image ``i`` fall in each cluster.
            Images without descriptors yield an all-zero row.
        """
        imgs = list(imgs)
        histograms = np.zeros((len(imgs), self.n_clusters), dtype=int)

        for row, img in zip(histograms, imgs):
            descriptors = self._extract_descriptors(img)
            if descriptors is not None:
                # `row` is a view into `histograms`, so this fills it in place.
                row += np.bincount(
                    vocab.predict(descriptors), minlength=self.n_clusters
                )

        return histograms

    def accuracy_calc(self):
        """Fit and evaluate the linear SVM and Logistic Regression classifiers."""
        self.print_accuracy(svm.SVC(C=0.005, kernel='linear'), "Accuracy (SVM): ")
        self.print_accuracy(LogisticRegression(max_iter=1000), "Accuracy (LR): ")

    def print_accuracy(self, model, score):
        """Fit ``model`` on the training features and print its test accuracy.

        Args:
            model: Estimator exposing ``fit`` and ``predict``.
            score: Text label printed in front of the accuracy value (despite
                its name, this is the message prefix, not a number).
        """
        model.fit(self.train_features, self.train_labels)
        prediction = model.predict(self.test_features)
        accuracy = accuracy_score(self.test_labels, prediction)
        print(score, accuracy)


# Run the whole experiment: constructing CIFAR_10 loads the data, builds the
# bag-of-words features and prints the SVM and Logistic Regression accuracies.
CIFAR_DATA = CIFAR_10()

# Output:
# Accuracy (SVM):  0.2569
# Accuracy (LR):  0.2594
