<a href="https://colab.research.google.com/github/bergeramit/Cactus/blob/master/test_KNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Settings

In [16]:
import os
import csv
from pathlib import Path

import cv2
import math

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import ndimage as ndi
from sklearn import preprocessing 
from sklearn.neighbors import KNeighborsClassifier

import networkx as nx

import pandas as pd
import scipy.misc
import glob
import scipy.misc


In [56]:
# Shared project folder; both data locations below live under it.
_DRIVE_PROJECT_DIR = "/content/drive/MyDrive/CD_Project_drive"

# Directory holding one saved GIST feature file per sample.
GIST_FEATURES_DIR_PATH = _DRIVE_PROJECT_DIR + "/GIST_features"
# CSV with the ground-truth labels (read elsewhere through its 'Id' and 'Class' columns).
LABELS_PATH = _DRIVE_PROJECT_DIR + "/trainLabels.csv"

# Class names; presumably position i corresponds to class id i + 1 in the labels CSV -- TODO confirm.
LABELS = [
    'Ramnit',
    'Lollipop',
    'Kelihos_ver3',
    'Vundo',
    'Simda',
    'Tracur',
    'Kelihos_ver1',
    'Obfuscator.ACY',
    'Gatak',
]

# Intended split sizes; not referenced by any of the visible cells.
TRAIN_SET_SIZE = 300
TEST_SET_SIZE = 100

# Setup

In [47]:
# Build the sample-id -> class-label mapping from the labels CSV.
# DictReader consumes the header row ('Id', 'Class'), so the header no longer
# ends up in the mapping as a bogus "Id" -> "Class" entry. Labels stay strings.
# newline="" is what the csv module expects for files it parses itself.
with open(LABELS_PATH, "r", newline="") as f:
  name_to_label = {row["Id"]: row["Class"] for row in csv.DictReader(f)}

In [48]:
# Extract hashes from file names: the hash is the file name without its extension.
hashes = [Path(file_name).stem for file_name in os.listdir(GIST_FEATURES_DIR_PATH)]
# `dataset` carries the same stems as `hashes`, as a separate list.
dataset = list(hashes)

In [49]:
# Example hash: stem of the first file returned by os.listdir
# (the listing order is arbitrary, so this value may differ between runs).
hashes[0]

'7cPUox9BR6vdGAWIKZg0'

In [50]:
# Example map: class label of one sample, looked up in the mapping built from
# the labels CSV (the mapping is named name_to_label; hash_to_label was never defined).
name_to_label[hashes[30]]

'8'

# Gist Descriptor

In [51]:
def compute_feats(image, kernels):
    """Summarise the response of `image` to each kernel by its mean and variance.

    Every kernel is convolved with the image using wrap-around borders.

    Returns:
        A float array of shape (len(kernels), 2): column 0 holds the mean of the
        filtered image, column 1 its variance.
    """
    feats = np.zeros((len(kernels), 2), dtype=np.double)
    # Each `stats_row` is a view into `feats`, so writing to it fills the result.
    for stats_row, kernel in zip(feats, kernels):
        response = ndi.convolve(image, kernel, mode='wrap')
        stats_row[0] = response.mean()
        stats_row[1] = response.var()
    return feats

def match(feats, ref_feats):
    """Return the row index of `ref_feats` closest to `feats` (squared error).

    Ties go to the first such row. Returns None when `ref_feats` has no rows or
    no row yields an error strictly below infinity.
    """
    best_row, best_error = None, np.inf
    for row in range(ref_feats.shape[0]):
        candidate_error = np.square(feats - ref_feats[row, :]).sum()
        # Strict comparison: keeps the earliest best row and skips NaN errors.
        if candidate_error < best_error:
            best_row, best_error = row, candidate_error
    return best_row

def plot_single(gs):
    """Draw the kernel of a (kernel, powers) pair in a new grayscale figure.

    The kernel is scaled by 255 and cast to uint8 before display; the `powers`
    element of the pair is not used.
    """
    kernel, _powers = gs
    kernel_as_bytes = (kernel * 255).astype(np.uint8)
    plt.figure()
    plt.imshow(kernel_as_bytes, 'gray')
    
def power_single(gs):
    """Return the `powers` element of a (kernel, powers) pair, scaled by 255."""
    _kernel, response = gs
    return response * 255

def implot(im, gray=False):
    """Display a BGR image with matplotlib.

    The image is cast to uint8 and converted from BGR to RGB channel order
    before being shown.

    NOTE: `implot` is defined a second time further down in this notebook (in
    the KNN section); whichever cell ran last wins. This definition calls
    plt.show() like the later one, so both behave the same regardless of
    execution order.

    Args:
        im: image array in BGR channel order.
        gray: accepted for compatibility; currently not used.
    """
    rgb = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_BGR2RGB)
    plt.imshow(rgb)
    plt.show()
    
def np2to3(im):
    """Convert a 2-D image to a 3-channel image by copying it into each channel.

    Args:
        im: 2-D array of shape (rows, cols).

    Returns:
        A float64 array of shape (rows, cols, 3) where every channel equals `im`.

    Raises:
        ValueError: if `im` is not 2-D (shape cannot be unpacked into two values).
    """
    rows, cols = im.shape
    new_im = np.zeros((rows, cols, 3))
    # Broadcast the single plane across the channel axis in one vectorised
    # assignment instead of a Python loop over every pixel.
    new_im[...] = im[:, :, np.newaxis]
    return new_im

def power(image, kernel):
    """Return the magnitude of the response of `image` to a complex kernel.

    The image is standardised (zero mean, unit standard deviation) and then
    convolved, with wrap-around borders, with the real and imaginary parts of
    `kernel`; the result is sqrt(real_response**2 + imag_response**2).

    A perfectly constant image has zero standard deviation. Dividing by it
    would turn the whole response into NaN, so in that case the centred image
    (all zeros) is used as-is and the response is all zeros.

    Args:
        image: numeric array to filter.
        kernel: complex-valued kernel with the same number of dimensions.

    Returns:
        Array with the same shape as `image` holding the response magnitude.
    """
    # Normalize images for better comparison.
    centered = image - image.mean()
    spread = image.std()
    # `== 0` (not `<= 0`) so NaN inputs still propagate exactly as before.
    normalized = centered if spread == 0 else centered / spread
    real_response = ndi.convolve(normalized, np.real(kernel), mode='wrap')
    imag_response = ndi.convolve(normalized, np.imag(kernel), mode='wrap')
    return np.sqrt(real_response**2 + imag_response**2)

def compute_avg(img, grid=4):
    """Average a 2-D array over a `grid` x `grid` partition of its area.

    Rows and columns are each split into `grid` nearly equal index ranges
    (np.array_split semantics: earlier ranges get the extra element when the
    size is not divisible), and the mean of every resulting cell is returned.

    Each cell covers its full index range. The previous implementation sliced
    `min:max`, which is end-exclusive and therefore dropped the last row and
    column of every cell, so values produced by this version differ slightly
    from values produced with the old slicing.

    Args:
        img: 2-D array.
        grid: number of cells along each axis (default 4).

    Returns:
        Array of shape (grid, grid) with the per-cell means.

    Raises:
        ValueError: if `img` is not 2-D or has fewer than `grid` rows or columns
            (some cell would be empty).
    """
    n_rows, n_cols = img.shape
    if n_rows < grid or n_cols < grid:
        raise ValueError(
            "image of shape {} is too small for a {}x{} grid".format(img.shape, grid, grid)
        )

    row_chunks = np.array_split(np.arange(n_rows), grid)
    col_chunks = np.array_split(np.arange(n_cols), grid)

    cell_means = [
        # +1 because slice ends are exclusive and rows[-1]/cols[-1] belong to the cell.
        np.mean(img[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1])
        for rows in row_chunks
        for cols in col_chunks
    ]
    return np.array(cell_means).reshape((grid, grid))

def compute_gist_descriptor(img_loc):
    """Compute a 512-value GIST-style descriptor for the image at `img_loc`.

    The image is read as grayscale, scaled to [0, 1], and filtered with a bank
    of 32 Gabor kernels (8 orientations x 4 frequencies). The magnitude
    response of every kernel is averaged over a 4x4 grid, and the resulting
    32 * 16 values are returned as a flat vector.

    Args:
        img_loc: path of the image file.

    Returns:
        1-D array of length 512.

    Raises:
        OSError: if the image cannot be read.
    """
    # cv2.IMREAD_GRAYSCALE is the imread flag for a single-channel read. The
    # previous argument, cv2.COLOR_GRAY2BGR, is a cvtColor conversion code and
    # not an imread flag.
    image = cv2.imread(img_loc, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError("Could not read image: {}".format(img_loc))
    image = image / 255.0

    # One (kernel, response) pair and one description string per filter.
    results = []
    kernel_params = []

    for step in range(8):
        theta = step / 8. * np.pi
        for frequency in (0.1, 0.2, 0.3, 0.4):
            # NOTE(review): gabor_kernel is not imported in the visible import
            # cell -- presumably skimage.filters.gabor_kernel; TODO add the import.
            kernel = gabor_kernel(frequency, theta=theta)
            params = 'theta={}, frequency={:.2f}'.format(theta * 180 / np.pi, frequency)
            kernel_params.append(params)
            results.append((kernel, power(image, kernel)))

    print("User these Filters:")
    print(kernel_params)
    print()

    # 16 grids * 32 kernels = 512
    return np.array([compute_avg(power_single(pair)) for pair in results]).reshape(512,)

def get_feature_vector_path(image_path, dest_dir):
    """Return the path of the ".gist" feature file for an image.

    The file name is the image's base name with its final extension replaced
    by ".gist"; earlier dots in the name are kept ("a.b.png" -> "a.b.gist").
    A name without any extension keeps its full name ("abc" -> "abc.gist")
    instead of collapsing to a bare ".gist".

    Args:
        image_path: path of the source image.
        dest_dir: directory in which the feature file lives.

    Returns:
        `dest_dir` joined with "<image stem>.gist".
    """
    stem, _extension = os.path.splitext(os.path.basename(image_path))
    return os.path.join(dest_dir, stem + ".gist")
                
def save_feature_vector(feature_vector, image_path, dest_dir="./GIST_features"):
    """Write a feature vector to "<dest_dir>/<image stem>.gist" as comma-separated text.

    This single definition replaces two consecutive definitions of the same
    name, of which only the second was ever in effect.

    Only the final extension of the image name is stripped, so dotted names
    keep their full stem ("a.b.png" -> "a.b.gist"). The destination directory
    is created if needed, including missing parent directories.

    Args:
        feature_vector: array to store (written with np.savetxt).
        image_path: path of the image the vector was computed from.
        dest_dir: output directory; defaults to "./GIST_features".
    """
    stem, _extension = os.path.splitext(os.path.basename(image_path))
    dest_path = os.path.join(dest_dir, stem + ".gist")
    # exist_ok avoids the check-then-create race and handles nested directories.
    os.makedirs(dest_dir, exist_ok=True)
    np.savetxt(dest_path, feature_vector, delimiter=',')


def load_feature_vector(feature_path):
    """Read a comma-delimited text file at `feature_path` into a numpy array."""
    return np.loadtxt(feature_path, delimiter=',')

# KNN

In [52]:
def KNN(X, Y, y_test, n_neighbors=3):
    """Fit a k-nearest-neighbours classifier and predict labels for `y_test`.

    Distances between samples are computed with calculate_gist_distance.

    Args:
        X: training feature vectors.
        Y: training labels, one per entry of X.
        y_test: feature vectors to classify (despite the name, these are
            samples passed to predict(), not labels).
        n_neighbors: number of neighbours that vote on each prediction
            (default 3, the value that used to be hard-coded).

    Returns:
        The predicted label for every entry of `y_test`.
    """
    model = KNeighborsClassifier(n_neighbors=n_neighbors, metric=calculate_gist_distance)
    model.fit(X, Y)
    return model.predict(y_test)


# Distance used to find the closest training image: sum of squared differences
def ssd(imageA, imageB):
    """Return the sum of squared element-wise differences between two arrays.

    Both inputs are cast to float before subtracting.
    """
    difference = imageA.astype("float") - imageB.astype("float")
    return np.sum(np.square(difference))


def calculate_gist_distance(current_feature_vector, other_feature_vector):
    """Distance between two feature vectors: their sum of squared differences (see ssd)."""
    distance = ssd(current_feature_vector, other_feature_vector)
    return distance


def implot(im, gray=False):
    """Show a BGR image: cast to uint8, convert to RGB order, draw and render it.

    The `gray` argument is accepted but not used.
    """
    as_bytes = im.astype(np.uint8)
    rgb = cv2.cvtColor(as_bytes, cv2.COLOR_BGR2RGB)
    plt.imshow(rgb)
    plt.show()


def prepare_data(classsificaton, gist_feature_dir):
    """Load every feature file in a directory together with its label.

    Args:
        classsificaton: mapping from file stem to class label.
        gist_feature_dir: directory containing the saved feature files.

    Returns:
        (X, Y): list of feature vectors (NaNs replaced via np.nan_to_num) and
        the list of matching labels, in directory-listing order.

    Raises:
        KeyError: if a file's stem has no entry in `classsificaton`.
    """
    X, Y = [], []
    for entry in os.listdir(gist_feature_dir):
        vector = load_feature_vector(os.path.join(gist_feature_dir, entry))
        X.append(np.nan_to_num(vector))
        # The label is keyed by the file name without its extension.
        Y.append(classsificaton[Path(entry).stem])

    return X, Y

def get_classification(labels_csv):
    """Read the labels CSV and return a dict mapping each 'Id' to its 'Class'.

    If an Id occurs more than once, the last occurrence wins.
    """
    labels_frame = pd.read_csv(labels_csv)
    ids, classes = (labels_frame[column].tolist() for column in ('Id', 'Class'))

    return {sample_id: label for sample_id, label in zip(ids, classes)}


def get_items_to_predict(source_dir):
    """Load every feature file in `source_dir` for prediction.

    Returns:
        (names, fvs): the file stems and the corresponding feature vectors
        (NaNs replaced via np.nan_to_num), in directory-listing order.
    """
    names, fvs = [], []
    for entry in os.listdir(source_dir):
        vector = np.nan_to_num(load_feature_vector(os.path.join(source_dir, entry)))
        names.append(Path(entry).stem)
        fvs.append(vector)

    return names, fvs



In [53]:
# Run the KNN on Train Data
# NOTE: the model is fitted on, and then asked to predict, the very same feature
# directory, so `preds` reflects training-set fit, not held-out performance.
classification = get_classification(LABELS_PATH)
# Training set: one feature vector (X) and one label (Y) per file in the directory.
X, Y = prepare_data(classification, GIST_FEATURES_DIR_PATH)
# Despite its name, y_test holds feature vectors (the samples to classify), not labels.
names, y_test = get_items_to_predict(GIST_FEATURES_DIR_PATH)
preds = KNN(X, Y, y_test)
# Print one "<file stem> <predicted class>" line per sample.
for name, pred in zip(names, preds):
    print(name, pred)

7cPUox9BR6vdGAWIKZg0 3
GeTocyt4jDP3OpSI2u5W 3
gelhQLoiRE0HZTOI2fYv 7
3vsl9KYdAo1tibP5DRSX 9
gBvO2rS0FyHbzcR5Qp76 3
gBUSDYdZqJMGCQuzWTrX 9
3v7fe0Diw9cOu6dAzsaE 1
783NTBAd2k1E6wHDGjiJ 2
76cGRgDJyFW08rY1tz59 1
J4xI6SAhl3OuEdWmw1co 8
J4zTlSBqN5r1v7MIYHQm 1
j4LbkH0Dx8eXOJnPBtZT 2
3suymN9GSHAtvYZ71Kaw 3
3RiBw0ntobeCh5VEQrY1 4
72cG6vkEjAuFBzym5P4e 3
6ZO9dMHY5JkCofwTSyvm 1
g9izZfbqECNJw7RojkIY 3
6ZdpisIvCL1GornaSU4D 9
3oA2emBxRQTkvDrVqPfL 9
6xjq1v93FCPGUZpyslVA 2
6WuXimIJaPd3x8QoGUH0 3
izRtJ6ELUuCPWAxB9ZOs 8
G8bXrIAFQ0HYlnZ7eDym 2
3mCGDSjRVxpeoaYXLfc9 9
g7qEpdzW3mhFL1Tf0kHV 8
iz2hCIvZncJ9jDUY7MXq 8
6WbENDkcC750euPGqApQ 4
g5NnsVCLJdb1lzPBEUeT 1
6v7e1UtEQSGMgmw2Ay5F 3
3k1meX0gV2WMjAvGDrCq 3
6tvBurLRHc15EYh2mXba 8
3IZHa2xXVu6KrgNoPd9W 6
G3sXTkHj07CSAKW84LQb 2
6tmqe5B8JDsQnyhGo7vP 2
IxN43MpQlgV0ZrPdGjk9 1
6tMgf5pGIsvHSn83WZyx 1
G37jI5Fu8YEJVgWhiKCQ 1
6tGhqxwai7TfWSby8lUO 9
3hYGWBHZ6iJr9X8KxENg 3
3GIiqFjPozTtc4Kms2NW 1
3FRuYdNiL4ye96gBoWK2 2
6rBCZSJ29AP1VQtkwz7u 4
g1639t4hNmJ8DWwBO5cA 9
G0lOhNpMKqB

In [54]:
def calculate_accuracy(names, preds, classification=None):
    """Return the percentage of predictions that match the ground-truth labels.

    Args:
        names: sample identifiers, one per prediction.
        preds: predicted labels, aligned with `names`.
        classification: optional mapping from identifier to true label. When
            omitted, it is loaded from LABELS_PATH via get_classification, as
            before; passing it avoids re-reading the CSV on every call.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if `names` and `preds` differ in length (zip would silently
            drop the surplus and skew the result) or if there are no predictions.
        KeyError: if a name has no entry in the classification mapping.
    """
    total = len(preds)
    if len(names) != total:
        raise ValueError(
            "names and preds must have the same length ({} != {})".format(len(names), total)
        )
    if total == 0:
        raise ValueError("cannot compute accuracy without predictions")

    if classification is None:
        classification = get_classification(LABELS_PATH)

    correctly_classified = sum(
        1 for name, pred in zip(names, preds) if classification[name] == pred
    )
    return 100 * (float(correctly_classified) / total)

In [55]:
# Accuracy (in percent) of the predictions over the same data the model was trained on
calculate_accuracy(names, preds)

94.69964664310953