# MIE524 - Assignment 3
Please complete this notebook for Assignment 3.

## Q2 - Locality-Sensitive Hashing for Approximate Nearest Neighbour

In [None]:
import numpy as np
import random
import time
import unittest
from PIL import Image
from sklearn.model_selection import train_test_split
from collections import defaultdict

In [None]:
def l1(u, v):
    """
    Finds the L1 (Manhattan) distance between two vectors: sum_i |u_i - v_i|.
    u and v are 1-dimensional np.array objects of equal length.

    Both inputs are converted to float64 before subtracting so that unsigned
    integer inputs (e.g. uint8 pixel values) cannot wrap around.
    Returns a numpy float64 scalar.
    """
    diff = np.asarray(u, dtype=np.float64) - np.asarray(v, dtype=np.float64)
    return np.abs(diff).sum()

In [None]:
class my_LSH:
    """
    Locality-sensitive hash index for approximate nearest-neighbour search.

    The index holds b bands (one hash table each). Band j has its own function
    g_j that concatenates r one-bit hashes; dataset rows with equal g_j values
    share a bucket in band j. A query's candidates are the union of the
    buckets it falls into across all bands, and those candidates are ranked
    by exact L1 distance.
    """

    def __init__(self, dataset, n_bands, n_rows):
        """
        Initializes the LSH object and indexes every row of the dataset.
        dataset - 2-D array to be searched, one datapoint per row
        n_bands - number of bands (b), i.e. number of hash tables
        n_rows - number of rows (r) in each band, i.e. bits per band hash
        """
        self.n_bands = n_bands
        self.n_rows = n_rows

        self.A = dataset
        self.bands = self.create_bands()
        # One hash table per band: band hash string -> set of row indices of A.
        self.bands_buckets = [defaultdict(set) for _ in self.bands]
        self.index_data()

    def create_band_function(self, dimensions, thresholds):
        """
        Creates the band function g_j from a list of dimensions and thresholds.

        dimensions - the r column indices sampled for this band
        thresholds - the r thresholds, paired position-wise with dimensions
        Returns a function mapping a vector v to an r-character bit string
        whose i'th character is "1" if v[dimensions[i]] >= thresholds[i]
        and "0" otherwise (one locality-sensitive hash per row of the band).
        """
        # Converted once here so every call is a single vectorised comparison.
        dims = np.asarray(dimensions, dtype=np.intp)
        cuts = np.asarray(thresholds)

        def band_function(v):
            bits = np.asarray(v)[dims] >= cuts
            return "".join("1" if bit else "0" for bit in bits)
        return band_function

    def create_bands(self, num_dimensions=None, min_threshold=0, max_threshold=255):
        """
        Creates the collection of g_j functions 1<=j<=b, one for each band.
        Each function selects r dimensions (i.e. column indices of the data
        matrix) at random, and then chooses a random threshold for each
        dimension in [min_threshold, max_threshold]. For any image, if its
        value on a given dimension is greater than or equal to the randomly
        chosen threshold, that bit is set to 1. Each band function returns a
        length-r bit string of the form "0101010001101001...".

        num_dimensions - number of columns to sample from; defaults to the
                         number of columns of the indexed dataset
        """
        if num_dimensions is None:
            num_dimensions = self.A.shape[1]
        bands = []
        for _ in range(self.n_bands):
            dimensions = np.random.randint(low=0,
                                           high=num_dimensions,
                                           size=self.n_rows)
            thresholds = np.random.randint(low=min_threshold,
                                           high=max_threshold + 1,
                                           size=self.n_rows)
            bands.append(self.create_band_function(dimensions, thresholds))
        return bands

    def hash_vector(self, v):
        """
        Hashes an individual vector (i.e. image). This produces a list with b
        entries (one for each band), where each entry is the value returned by
        that band's function (a string of r bits for bands built by
        create_bands).
        """
        return [f(v) for f in self.bands]

    def index_data(self):
        """
        Hashes every row of A with the b band functions in self.bands and adds
        the row index to the matching bucket of each band's hash table.
        """
        for doc_index, row in enumerate(self.A):
            for buckets, band_hash in zip(self.bands_buckets, self.hash_vector(row)):
                buckets[band_hash].add(doc_index)

    def get_candidates(self, query_vector):
        """
        Retrieve all of the points that hash to one of the same buckets as the
        query point. Returns a set of row indices of A (possibly empty).
        """
        candidates = set()
        for buckets, band_hash in zip(self.bands_buckets, self.hash_vector(query_vector)):
            # .get() rather than [] so a query never inserts empty buckets
            # into the defaultdict-backed index.
            candidates.update(buckets.get(band_hash, ()))
        return candidates

    def lsh_search(self, query_vector, num_neighbours=10):
        """
        Run the entire LSH algorithm: collect the candidates, compute their
        exact L1 distance to the query and keep the closest num_neighbours.

        Returns (neighbours, distances): two lists, where neighbours holds row
        indices of A ordered from nearest to farthest and distances holds the
        matching L1 distances. Both lists are shorter than num_neighbours when
        fewer candidates were found, and empty when the query shares no bucket
        with any datapoint.
        """
        candidates = self.get_candidates(query_vector)
        if not candidates:
            return [], []
        # Sorted so that ties in distance are broken by row index, independently
        # of set iteration order.
        candidate_idx = np.array(sorted(candidates), dtype=np.intp)
        query = np.asarray(query_vector, dtype=np.float64)
        # Vectorised L1 distance; float64 avoids unsigned-integer wrap-around.
        distances = np.abs(self.A[candidate_idx].astype(np.float64) - query).sum(axis=1)
        order = np.argsort(distances, kind="stable")[:num_neighbours]
        return candidate_idx[order].tolist(), distances[order].tolist()

In [None]:
def plot(A, row_nums, base_filename):
    """
    Shows each requested row of A as a 20x20 image and writes it to a PNG.

    A - 2-D array whose rows each hold 400 values (reshaped to 20x20)
    row_nums - iterable of row indices to render
    base_filename - filename prefix; each image is saved as
                    base_filename + "-" + <row index> + ".png"
    """
    for idx in row_nums:
        img = Image.fromarray(A[idx, :].reshape(20, 20))
        # Normalise any non-RGB image (e.g. grayscale or floating-point) to RGB
        # before it is shown and saved.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # `display` is not imported in this cell; assumed to be the
        # IPython/Jupyter builtin.
        display(img)
        out_name = base_filename + "-" + str(idx) + ".png"
        img.save(out_name)

In [None]:
#### TESTS #####

class TestLSH(unittest.TestCase):
    """Unit tests for the l1 helper and the my_LSH hashing API."""

    def test_l1(self):
        u = np.array([1, 2, 3, 4])
        v = np.array([2, 3, 2, 3])
        self.assertEqual(l1(u, v), 4)

    def test_hash_data(self):
        f1 = lambda v: sum(v)
        f2 = lambda v: sum([x * x for x in v])
        A = np.array([[1, 2, 3], [4, 5, 6]])
        self.assertEqual(f1(A[0, :]), 6)
        self.assertEqual(f2(A[0, :]), 14)

        # my_LSH has no standalone hash_vector(functions, v) / hash_data API;
        # build an instance without running __init__ (which would draw random
        # bands) and install the hand-written band functions instead.
        index = my_LSH.__new__(my_LSH)
        index.A = A
        index.bands = [f1, f2]

        self.assertTrue(np.array_equal(index.hash_vector(A[0, :]), np.array([6, 14])))
        hashed = np.array([index.hash_vector(row) for row in A])
        self.assertTrue(np.array_equal(hashed, np.array([[6, 14], [15, 77]])))

        # index_data should file each row under its hash in every band.
        index.bands_buckets = [defaultdict(set) for _ in index.bands]
        index.index_data()
        self.assertEqual(index.bands_buckets[0][6], {0})
        self.assertEqual(index.bands_buckets[1][77], {1})

    ### You may write your own tests here

### b) Split data with 100 random query points

In [None]:
# Load the patches; presumably one flattened image patch per row (verify
# against patches.csv).
data = np.genfromtxt("patches.csv", delimiter=",")
# Hold out 100 random rows as query points; the fixed seed keeps the split
# reproducible across runs.
dataset, query_points = train_test_split(
    data,
    test_size=100,
    random_state=42,
)

### c) Compare LSH and linear search for the 100 query points

In [None]:
# Build the LSH index over the non-query rows: 10 bands, each hashing 24 rows
# (bits).
lsh = my_LSH(dataset, n_bands=10, n_rows=24)

In [None]:
def linear_search(A, query_vector, num_neighbours):
    """
    Finds the nearest neighbours to a given vector, using linear search.

    A - 2-D array, one datapoint per row
    query_vector - 1-D array with as many entries as A has columns
    num_neighbours - how many neighbours to return (all rows if A is smaller)
    Returns (neighbours, distances): a list of row indices of A ordered from
    nearest to farthest by L1 distance (ties broken by row index) and the
    list of the corresponding L1 distances.
    """
    # float64 so unsigned-integer inputs cannot wrap around on subtraction.
    data = np.asarray(A, dtype=np.float64)
    query = np.asarray(query_vector, dtype=np.float64)
    distances = np.abs(data - query).sum(axis=1)
    order = np.argsort(distances, kind="stable")[:num_neighbours]
    return order.tolist(), distances[order].tolist()

In [2]:
# TODO: YOUR CODE HERE

### d) Plot errors vs b and r

In [None]:
def lsh_error(lsh_dist, linear_dist):
    """
    Computes the error measure: the ratio of the summed distances of the
    neighbours found by LSH to the summed distances of the true neighbours
    found by linear search (1.0 means LSH found equally close neighbours).

    NOTE(review): this assumes the usual definition for this exercise,
        error = (1/m) * sum_j [ sum_i d(x_ij, z_j) / sum_i d(x*_ij, z_j) ],
    where z_j is the j'th query, x_ij its i'th LSH neighbour and x*_ij its
    i'th linear-search neighbour. Confirm it against the assignment handout.

    lsh_dist, linear_dist - distances for a single query (1-D sequences), or
        for m queries (2-D, one row per query, rows of equal length). Pass the
        same number of neighbours for both methods.
    Returns the ratio for a single query, or the mean of the per-query ratios
    for a batch, as a Python float. If a query's linear-search distances sum
    to zero, the ratio is undefined and numpy yields inf or nan.
    """
    lsh_arr = np.asarray(lsh_dist, dtype=np.float64)
    lin_arr = np.asarray(linear_dist, dtype=np.float64)
    # Sum over the neighbour axis (the last one), then average over queries.
    ratios = lsh_arr.sum(axis=-1) / lin_arr.sum(axis=-1)
    return float(np.mean(ratios))

In [None]:
# TODO: YOUR CODE HERE

### e) Plot 10 nearest neighbours

In [1]:
# TODO: YOUR CODE HERE

### f) Change hash function

In [None]:
# TODO: YOUR CODE HERE