In [161]:

import torch
import numpy as np
from autoencoder import ConvDecoder, ConvEncoder
from sklearn.neighbors import NearestNeighbors
import torchvision.transforms as T
import os
from PIL import Image
import matplotlib.pyplot as plt
import json
import time
import random
%matplotlib inline

In [162]:
def load_image_tensor(image_path, device):
    image_tensor = T.ToTensor()(Image.open(image_path))
    image_tensor = image_tensor.unsqueeze(0)
    #print(image_tensor.shape)
    # input_images = image_tensor.to(device)
    return image_tensor

In [163]:
image_paths=[]
with open("geological_map.json", 'r', encoding='utf-8') as f:
    image_paths=json.load(f)
print(len(image_paths))

29998


In [164]:
def compute_similar_images(image_path, num_images, embedding, device):
    image_tensor = load_image_tensor(image_path, device)
    # image_tensor = image_tensor.to(device)

    with torch.no_grad():
        image_embedding = encoder(image_tensor).cpu().detach().numpy()

    #print(image_embedding.shape)

    flattened_embedding = image_embedding.reshape((image_embedding.shape[0], -1))
    #print(flattened_embedding.shape)

    knn = NearestNeighbors(n_neighbors=num_images, metric="jaccard")
    knn.fit(embedding)
    start_time=time.time()
    _, indices = knn.kneighbors(flattened_embedding)
    end_time=time.time()
    time_taken=end_time-start_time
    #print("Time taken: ",end_time-start_time)
    indices_list = indices.tolist()
    
    #print(indices_list)
    return indices_list,time_taken


In [165]:
def plot_similar_images(indices_list):
    indices = indices_list[0]
    print("total indices: ", len(indices))
    print(indices_list)
    for index in indices:
        # img_name = str(index - 1) + ".jpg"
        # print(img_name)
        img_path = image_paths[index]
        print(img_path)
        img = Image.open(img_path).convert("RGB")
        plt.imshow(img)
        plt.show()

In [166]:
TEST_IMAGE_PATH = "geological_similarity/schist/ZZ5Z5.jpg"
NUM_IMAGES = 50
ENCODER_MODEL_PATH = "geological_encoding.pt"
EMBEDDING_PATH = "geological_embed.npy"

In [167]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
encoder = ConvEncoder()

# Load the state dict of encoder
encoder.load_state_dict(torch.load(ENCODER_MODEL_PATH, map_location=device))
encoder.eval()
encoder.to(device)

# Loads the embedding
embedding = np.load(EMBEDDING_PATH)
# test_img = Image.open(TEST_IMAGE_PATH).convert("RGB")
# plt.imshow(test_img)
# plt.show()
# indices_list = compute_similar_images(TEST_IMAGE_PATH, NUM_IMAGES, embedding, device)
# plot_similar_images(indices_list)

In [168]:
def write_to_file(test_img_path,indices_list,file_name):
    indices = indices_list[0]
    results=[]
    for index in indices:
        img_path = image_paths[index]
        results.append(img_path)
    #results_string=' '.join(i for i in results)
    #print(results)
    with open(file_name, 'r', encoding='utf-8') as f:
        current_dir=json.load(f)
    current_dir[test_img_path]=results
    with open(file_name, "w") as outfile:
        json.dump(current_dir, outfile)


In [169]:
testing_times=[]

In [170]:
def test_method(test_img_path,file_name):
    test_img = Image.open(test_img_path).convert("RGB")
    # plt.imshow(test_img)
    # plt.show()
    indices_list,time_taken = compute_similar_images(test_img_path, NUM_IMAGES, embedding, device)
    write_to_file(test_img_path,indices_list,file_name)
    #plot_similar_images(indices_list)
    testing_times.append(time_taken)
    return time_taken

In [171]:
testing_images=[]
for i in range(0,100):
    random_index=random.randint(0, len(image_paths)-1)
    testing_images.append(image_paths[random_index])

for i in range(len(testing_images)):
    time_taken=test_method(testing_images[i],"jaccard_knn_results.json")
    #print(time_taken)
    print("Done with : ", i)

Done with :  0
Done with :  1
Done with :  2
Done with :  3
Done with :  4
Done with :  5
Done with :  6
Done with :  7
Done with :  8
Done with :  9
Done with :  10
Done with :  11
Done with :  12
Done with :  13
Done with :  14
Done with :  15
Done with :  16
Done with :  17
Done with :  18
Done with :  19
Done with :  20
Done with :  21
Done with :  22
Done with :  23
Done with :  24
Done with :  25
Done with :  26
Done with :  27
Done with :  28
Done with :  29
Done with :  30
Done with :  31
Done with :  32
Done with :  33
Done with :  34
Done with :  35
Done with :  36
Done with :  37
Done with :  38
Done with :  39
Done with :  40
Done with :  41
Done with :  42
Done with :  43
Done with :  44
Done with :  45
Done with :  46
Done with :  47
Done with :  48
Done with :  49
Done with :  50
Done with :  51
Done with :  52
Done with :  53
Done with :  54
Done with :  55
Done with :  56
Done with :  57
Done with :  58
Done with :  59
Done with :  60
Done with :  61
Done with :  62
Do

In [47]:
# dic={"Hello":"Wow"}
# with open("cosine_knn_results.json", "w") as outfile:
#     json.dump(dic, outfile)

In [172]:
testing_times

[0.042998552322387695,
 0.0449984073638916,
 0.04399871826171875,
 0.04800057411193848,
 0.06400179862976074,
 0.018003463745117188,
 0.04400157928466797,
 0.062000274658203125,
 0.04704618453979492,
 0.04599905014038086,
 0.04799914360046387,
 0.02096843719482422,
 0.04430103302001953,
 0.04596757888793945,
 0.04305076599121094,
 0.055005788803100586,
 0.07800436019897461,
 0.04404759407043457,
 0.044049978256225586,
 0.04798436164855957,
 0.05100250244140625,
 0.04400324821472168,
 0.04398322105407715,
 0.04700446128845215,
 0.04402875900268555,
 0.04705405235290527,
 0.05299639701843262,
 0.04605436325073242,
 0.04496407508850098,
 0.046003103256225586,
 0.051011085510253906,
 0.056000471115112305,
 0.044043779373168945,
 0.0439457893371582,
 0.04504656791687012,
 0.04300045967102051,
 0.04994535446166992,
 0.04294776916503906,
 0.06400060653686523,
 0.06417465209960938,
 0.05300021171569824,
 0.048996686935424805,
 0.06503701210021973,
 0.046944618225097656,
 0.04293942451477051,
 

In [173]:
with open("jaccard_knn_results.json", 'r', encoding='utf-8') as f:
    content=json.load(f)
print(len(content.keys()))

100


In [174]:
textfile = open("timing_results/jaccard_knn_timings.txt", "w")
for i in range(len(testing_times)):

    textfile.write(str(testing_times[i])+"\n")

textfile.close()

### Next cell is to read the timings files

In [160]:
stored_times=[]
with open("timing_results/jaccard_knn_timings.txt","r") as f:
    content=f.read()
stored_times=content.split('\n')
print(len(stored_times))

101
