# Similarity Search

#### Required libraries

- tensorflow-probability 0.7.0 (for use with TensorFlow 1.14)
- hnswlib 0.3.2.0
- nmslib 1.8.1 if using search_location_brute() or search_image_brute()

##### Note: There are a lot of images shown. Use plt.rcParams["figure.figsize"] = (x,y) to adjust the size of figures in the notebook, where x and y are ints

### Importing modules

In [None]:
# This future stuff is necessary for TF
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob, sys, os
import matplotlib.pyplot as plt
# Put every entry found under modules/ at the front of the import path so the
# project modules imported below can be resolved. The loop variable is named
# `module_dir` so it does not shadow Python 2's built-in `file`.
for module_dir in glob.glob('modules/*'):
    sys.path.insert(0, module_dir)
import searchtiles, searchmodel

# The API key is a credential, so it must be supplied through the environment
# instead of being committed with the notebook. Export API_KEY before starting
# Jupyter.
# NOTE(review): the key that used to be hard-coded here is in version-control
# history and should be rotated.
if not os.environ.get("API_KEY"):
    raise RuntimeError("API_KEY is not set; export it in the environment "
                       "before starting the notebook")

## Model training

### GPU setup

In [None]:
# Configure the GPU used by the rest of the notebook.
# NOTE(review): presumably gpu_number picks the device and gpu_fraction caps
# the share of its memory -- confirm against searchmodel.gpu_setup.
searchmodel.gpu_setup(gpu_number=6, gpu_fraction=0.45)

### Preparing training dataset

In [None]:
# Call batch_inference on the Apollo dataset root. save_path is relative, so
# it is resolved against the notebook's working directory.
searchtiles.batch_inference(
    save_path='apollo_predictions_small/',
    base_path='/mnt/data/datasets/apollo_20190509/',
    num_imgs=100,
    batches=5,
    zoom_level=21,
)

In [None]:
# Select n = 100 files of category 'all' at zoom 21 from the predictions folder.
# NOTE(review): batch_inference above saves to the relative path
# 'apollo_predictions_small/', while this reads the absolute
# /mnt/data/data_filip/ location -- the two only coincide when the notebook
# runs from /mnt/data/data_filip/; confirm.
searchtiles.copy_n_files(
    base_path='/mnt/data/data_filip/apollo_predictions_small/',
    n=100,
    zoom_level=21,
    category='all',
)

In [None]:
# Apply rotate_model_tiles to the all_100_z21 training subset.
# NOTE(review): whether the tiles are rewritten in place is not visible from
# this notebook.
searchtiles.rotate_model_tiles(
    base_path='/mnt/data/data_filip/apollo_predictions_small/all_100_z21/',
)

### Training

In [None]:
# Build the three model objects used throughout the notebook. encoded_size
# matches the index_dim = 65 passed to the indexed searches further down.
(encoder,
 decoder,
 vae) = searchmodel.build_vae_resnet(encoded_size=65)

#### Either train the model or load weights from disk

In [None]:
# Train the VAE on the rotated all_100_z21 subset for 50 epochs.
# NOTE(review): run_name and tensorboard_directory presumably control where
# TensorBoard logs are written -- confirm against searchmodel.train_model.
searchmodel.train_model(
    model=vae,
    encoder=encoder,
    decoder=decoder,
    dataset_directory='/mnt/data/data_filip/apollo_predictions_small/all_100_z21/',
    epochs=50,
    run_name='100 z21 resnet',
    tensorboard_directory="/mnt/data/data_filip/tensorboard_logs/",
)

In [None]:
# Alternative to training: restore previously saved weights into the VAE.
# Troubleshooting 'No locks available' errors: either move the .h5 file into
# your /home/ directory, or run 'export HDF5_USE_FILE_LOCKING=FALSE' and
# restart the notebook.
vae.load_weights(
    '/mnt/data/data_filip/models/50k_z21_e65.h5'
)

## Creating tiles, encodings and indexes

### Fetching tiles

In [None]:
# Fetch zoom-20 tiles for the listed categories around the given Sydney
# coordinate, dated 2018-12-27, into tiles_sydney/sydney_small_20/.
# NOTE(review): the unit of box_size is not visible here -- confirm against
# searchtiles.fetch_tiles.
searchtiles.fetch_tiles(
    lon=151.1518715,
    lat=-33.8223945,
    zoom_level=20,
    datestr="2018-12-27",
    box_size=10,
    categories=[1, 2, 3, 8, 33],
    file_path='/mnt/data/data_filip/tiles_sydney/sydney_small_20/',
    verbose=True,
)

### Cropping and aligning dataset and creating encoded predictions

In [None]:
# Crop, rotate and encode the fetched tiles with the trained encoder; the
# encodings go to save_path and the tile coordinates to a .npy file.
# NOTE(review): the input folder is sydney_small_20 (fetched at zoom_level 20)
# while the outputs are labelled sydney_small_21_e65 -- confirm that the zoom
# label in the output name is intended.
searchtiles.crop_rotate_encode(
    base_path='/mnt/data/data_filip/tiles_sydney/sydney_small_20/',
    categories=[1, 2, 3, 8, 33],
    encoder=encoder,
    save_path='/mnt/data/data_filip/encoded_predictions/sydney_small_21_e65/',
    coordinates_path='/mnt/data/data_filip/encoded_predictions/sydney_small_21_e65_coordinates.npy',
    split=5,
)

### Creating search index

In [None]:
# Build the search index from the encodings written by crop_rotate_encode.
# encoding_path must be the directory that step saved to
# (.../sydney_small_21_e65/); the previous value, .../sydney_small_21/, is not
# produced anywhere in this notebook.
# NOTE(review): ef_construction, M and index_space = 'ip' are presumably
# passed through to hnswlib (inner-product space) -- confirm against
# searchtiles.create_index.
searchtiles.create_index(
    encoding_path='/mnt/data/data_filip/encoded_predictions/sydney_small_21_e65/',
    save_path='/mnt/data/data_filip/indexes/sydney_small_21/',
    ef_construction=10000,
    M=100,
    num_threads=8,
    index_space='ip',
)

## Indexed Search

### Search using a given set of lat/long coordinates
##### The example provided is lon = 151.110166, lat = -33.772612 for category 1 (pools) at z19
##### This finds long, thin pools that are big enough to span a z19 tile (like 50m public pools)

In [None]:
# Indexed search seeded by a lon/lat location at zoom 19. Only the first of
# the five categories carries a non-zero weight, so the query is driven by
# category 1 alone.
(result_coordinates,
 result_images,
 result_ids) = searchtiles.search_location(
    encoder=encoder,
    index_directory='/mnt/data/data_filip/indexes/sydney_bigger_19_e65/',
    coordinates_directory='/mnt/data/data_filip/encoded_predictions/sydney_bigger_19_e65_coordinates.npy',
    lon=151.110166,
    lat=-33.772612,
    datestr='2018-12-27',
    dataset_datestr='2018-12-27',
    zoom_level=19,
    categories=[1, 2, 3, 8, 33],
    category_weights=[1, 0, 0, 0, 0],
    index_space='ip',
    index_dim=65,
    num_nearest=25,
    show_images=True,
)

### Search using a given image
##### The image should be 128x128xC, where C is the number of categories you wish to use

##### The example provided is a search for three-pronged flat roofs at zoom 19 using a quickly drawn query image

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Keep only the first channel of the hand-drawn query image; slicing with `:1`
# (rather than indexing with 0) preserves the trailing channel axis.
query_png = plt.imread('/mnt/data/data_filip/queryimages/6-bad.png')
im = query_png[:, :, :1]

# Pass a batch of one through the VAE and keep the first channel of the mean
# of its output for the first (only) batch entry.
vae_output = vae(np.array([im]))
im_model = vae_output.mean()[0, :, :, :1]

print('Original image (left) and that image passed through the VAE (right)')
print('The right-hand image is used for the search query')
# Left panel: original query, right panel: VAE output.
for slot, panel in enumerate((im, im_model), start=1):
    plt.subplot(1, 2, slot)
    plt.imshow(panel[:, :, 0])
plt.show()

In [None]:
# Indexed search seeded by the VAE-processed query image (im_model),
# restricted to category 33 at zoom 19.
(result_coordinates,
 result_images,
 result_ids) = searchtiles.search_image(
    preds=im_model,
    encoder=encoder,
    index_directory='/mnt/data/data_filip/indexes/sydney_bigger_19_e65/',
    coordinates_directory='/mnt/data/data_filip/encoded_predictions/sydney_bigger_19_e65_coordinates.npy',
    dataset_datestr='2018-12-27',
    zoom_level=19,
    categories=[33],
    category_weights=[1],
    index_space='ip',
    index_dim=65,
    num_nearest=25,
    show_images=False,
)

## Brute-force Search
#### Use this if you haven't built indexes yet. It will take up to a minute to do the search

### Search using a given set of lat/long coordinates
##### The example provided is lon = 151.105687, lat = -33.766541, for category 3 (solar panels) at z21
##### This finds long, thin solar panels

In [None]:
# Brute-force search seeded by a lon/lat location at zoom 21. It reads the raw
# encodings directory instead of a prebuilt index; only the third listed
# category (3) carries a non-zero weight.
(result_coordinates,
 result_images,
 result_ids) = searchtiles.search_location_brute(
    encoder=encoder,
    encodings_directory='/mnt/data/data_filip/encoded_predictions/sydney_bigger_21_e65/',
    coordinates_directory='/mnt/data/data_filip/encoded_predictions/sydney_bigger_21_e65_coordinates.npy',
    lon=151.105687,
    lat=-33.766541,
    datestr='2018-12-27',
    dataset_datestr='2018-12-27',
    zoom_level=21,
    categories=[1, 2, 3, 8, 33],
    category_weights=[0, 0, 1, 0, 0],
    num_nearest=25,
    show_images=False,
)

### Search using a given image
##### The image should be 128x128xC, where C is the number of categories you wish to use
##### The example provided is a multi-category search for pools surrounded by high (>2m) vegetation

In [None]:
import matplotlib.pyplot as plt
import numpy as np

plt.rcParams["figure.figsize"] = (5, 5)

# One query drawing per channel: channel 0 is the pool mask, channel 1 the
# vegetation mask. Only the first colour channel of each PNG is used.
query_files = (
    '/mnt/data/data_filip/queryimages/multi-pool.png',
    '/mnt/data/data_filip/queryimages/multi-veg.png',
)
im = np.zeros((128, 128, 2), dtype=np.float32)
for channel, query_file in enumerate(query_files):
    im[:, :, channel] = np.float32(plt.imread(query_file)[:, :, 0])

# Each channel is passed through the VAE separately, as a single-channel batch
# of one, and the first channel of the output mean is stored.
im_model = np.zeros((128, 128, 2), dtype=np.float32)
for channel in range(2):
    single_channel = im[:, :, channel:channel + 1]
    im_model[:, :, channel] = vae(np.array([single_channel])).mean()[0, :, :, 0]

print('Original pool (left) and vegetation (right) query images')
for channel in range(2):
    plt.subplot(1, 2, channel + 1)
    plt.imshow(im[:, :, channel])
plt.show()

print('Above query images passed through the VAE. These are used to search')
for channel in range(2):
    plt.subplot(1, 2, channel + 1)
    plt.imshow(im_model[:, :, channel])
plt.show()

##### Pools are weighted 0.5x such that the search values the fact that there is a pool surrounded by vegetation, rather than the exact pool shape

In [None]:
# Brute-force search seeded by the two-channel query (im_model): category 1 is
# weighted 0.5 and category 8 is weighted 1.
(result_coordinates,
 result_images,
 result_ids) = searchtiles.search_image_brute(
    preds=im_model,
    encoder=encoder,
    encodings_directory='/mnt/data/data_filip/encoded_predictions/sydney_bigger_21_e65/',
    coordinates_directory='/mnt/data/data_filip/encoded_predictions/sydney_bigger_21_e65_coordinates.npy',
    dataset_datestr='2018-12-27',
    zoom_level=21,
    categories=[1, 8],
    category_weights=[0.5, 1],
    num_nearest=25,
    show_images=False,
)

#### Show the nearest images in a grid

In [None]:
# Show up to 100 of the returned images in a 5-column grid.
# The grid is sized from the number of images actually returned: the searches
# in this notebook ask for num_nearest = 25, so indexing a fixed 100 entries
# would run past the end of result_images.
# NOTE(review): assumes result_images holds one image per hit -- confirm
# against the searchtiles search functions.
grid_columns = 5
shown_images = result_images[:100]
# Ceiling division; at least one row so the figure height stays positive.
grid_rows = max(1, -(-len(shown_images) // grid_columns))
# 3 inches of height per row, i.e. the original (15, 60) for a full 20-row grid.
plt.rcParams["figure.figsize"] = (15, 3 * grid_rows)
for position, image in enumerate(shown_images, start=1):
    plt.subplot(grid_rows, grid_columns, position)
    plt.imshow(image)
    plt.axis('off')