In [2]:
import getpass
import os
import xml.etree.ElementTree as ET
from collections import Counter

import numpy as np
import pandas as pd
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from img2vec_pytorch import Img2Vec
from PIL import Image

In [3]:
# Connection settings are read from the environment so no credential is stored in the notebook.
# When ELASTIC_PASSWORD is not set, the password is requested interactively instead.
ELASTIC_HOST = os.environ.get("ELASTIC_HOST", "http://localhost:9200")
ELASTIC_USER = os.environ.get("ELASTIC_USER", "elastic")
ELASTIC_PASSWORD = os.environ.get("ELASTIC_PASSWORD") or getpass.getpass("Elasticsearch password: ")

elastic_client = Elasticsearch(hosts=[ELASTIC_HOST],
                               basic_auth=(ELASTIC_USER, ELASTIC_PASSWORD))

# Img2Vec() with no arguments uses its default backbone; the first run downloads a
# ResNet-18 checkpoint (see the cell output).
model = Img2Vec()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/jovyan/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:20<00:00, 2.31MB/s]


In [4]:
def getSimilarity(vector:list, embedding_field:str, index_name:str, size:int, k:int, candidate:int):
    """Run a kNN search against ``index_name`` and return the raw search response.

    Args:
        vector: query embedding compared against ``embedding_field``.
        embedding_field: name of the vector field searched by the kNN clause.
        index_name: Elasticsearch index to query.
        size: maximum number of hits in the response.
        k: number of nearest neighbours requested from the kNN clause.
        candidate: value passed as ``num_candidates`` to the kNN clause.

    Returns:
        The response of ``elastic_client.search``. Each hit carries only the ``no`` and
        ``class_label`` entries under ``hit["fields"]``; the document ``_source`` is suppressed.
    """
    query = {
        "size": size,
        "knn": {
            "field": embedding_field,
            "query_vector": vector,
            "k": k,
            "num_candidates": candidate,
        },
        "fields": ["no", "class_label"],
        # Must be the boolean False: the string "false" is interpreted by Elasticsearch
        # as a source-field include pattern, not as "disable _source".
        "_source": False,
    }
    return elastic_client.search(index=index_name, body=query)

In [5]:
def frequency_histogram(subclasses:list, k:int):
    """Return the ``k`` most frequent labels in ``subclasses`` together with their counts.

    Labels are ranked by count, descending; ties are broken by the label itself, descending
    (the same order as reverse-sorting ``(count, label)`` tuples).

    Args:
        subclasses: labels to count; they must be hashable and mutually comparable
            (comparison is used for tie-breaking).
        k: maximum number of labels to keep; ``k <= 0`` yields an empty dict.

    Returns:
        Insertion-ordered dict ``{label: count}`` with at most ``k`` entries.
    """
    # Counter tallies in a single pass instead of calling list.count once per distinct label.
    counts = Counter(subclasses)
    ranked = sorted(counts.items(), key=lambda item: (item[1], item[0]), reverse=True)
    return dict(ranked[:max(k, 0)])

In [6]:
def score_histogram(subclasses:list, k:int):
    """Rank labels by their summed score and return the ``k`` best.

    Each entry of ``subclasses`` is a string ``"label1,label2,...;score"``. The score
    (parsed with ``float``) is credited to every comma-separated label of that entry.
    Labels are ordered by total score, descending; ties keep first-seen order.

    Args:
        subclasses: encoded ``labels;score`` strings.
        k: maximum number of labels to return.

    Returns:
        Insertion-ordered dict ``{label: total_score}`` with at most ``k`` entries.
    """
    totals = {}
    for entry in subclasses:
        parts = entry.split(";")
        labels, score = parts[0].split(","), float(parts[1])
        for label in labels:
            totals[label] = totals[label] + score if label in totals else score

    # sorted() is stable even with reverse=True, so equal totals stay in first-seen order.
    ranking = sorted(totals.items(), key=lambda pair: pair[1], reverse=True)
    top = {}
    for rank, (label, total) in enumerate(ranking, start=1):
        if rank > k:
            break
        top[label] = total
    return top

In [7]:
def process_result(accuracy_dict, k, n, type):
    """Increment, in place, the counter stored under ``"{k}-{n}-{type}"`` in ``accuracy_dict``.

    A key seen for the first time starts at 1.
    """
    key = f"{k}-{n}-{type}"
    accuracy_dict[key] = accuracy_dict.get(key, 0) + 1

In [8]:
def get_process_result(accuracy_dict, k, n, type):
    """Return the counter stored under ``"{k}-{n}-{type}"``, or 0 if it was never recorded."""
    return accuracy_dict.get(f"{k}-{n}-{type}", 0)

In [9]:
def print_process_result(accuracy_dict, k_list, n_list):
    """Print positive/negative counts and accuracy for every ``(k, n)`` combination.

    Accuracy is ``positive / (positive + negative)``. When no query was recorded for a pair
    (e.g. the dataset was empty), the accuracy is printed as ``nan`` instead of raising
    ``ZeroDivisionError``.

    Args:
        accuracy_dict: counters filled by ``process_result``.
        k_list: k values to report (outer loop).
        n_list: n values to report (inner loop).
    """
    for k in k_list:
        for n in n_list:
            positive = get_process_result(accuracy_dict, k, n, 'positive')
            negative = get_process_result(accuracy_dict, k, n, 'negative')
            total = positive + negative
            accuracy = positive / total if total else float("nan")
            print("k={} - n={} - Positive: {} - Negative: {} - "
                  "Accuracy: {} ".format(k, n, positive, negative, accuracy))

In [10]:
def transform_process_result(accuracy_dict, k_list, n_list):
    """Build a ``len(k_list) x len(n_list)`` accuracy matrix.

    Row ``i`` corresponds to ``k_list[i]`` and column ``j`` to ``n_list[j]``. Each cell is
    ``positive / (positive + negative)``; cells with no recorded queries are ``nan``
    instead of raising ``ZeroDivisionError``.

    Args:
        accuracy_dict: counters filled by ``process_result``.
        k_list: k values (rows).
        n_list: n values (columns).

    Returns:
        A float NumPy array of accuracies.
    """
    matrix = np.zeros((len(k_list), len(n_list)))
    for i, k in enumerate(k_list):
        for j, n in enumerate(n_list):
            positive = get_process_result(accuracy_dict, k, n, 'positive')
            negative = get_process_result(accuracy_dict, k, n, 'negative')
            total = positive + negative
            matrix[i, j] = positive / total if total else np.nan
    return matrix

In [11]:
def generate_embedding(segment):
    """Embed an image segment with the notebook-level Img2Vec ``model``.

    ``segment`` is the cropped image produced by the evaluation loop. The vector from
    ``model.get_vec`` (presumably a NumPy array) is returned as a plain Python list so it
    can be sent in an Elasticsearch query body.
    """
    return model.get_vec(segment).tolist()

In [12]:
def get_files(path:str, extension:str):
    """List the names (not full paths) of entries in ``path`` that end with ``extension``.

    The names are sorted because ``os.listdir`` returns entries in arbitrary order. Sorting
    makes the processing order, and therefore the printed log, reproducible between runs.
    """
    return sorted(name for name in os.listdir(path) if name.endswith(extension))

In [13]:
def get_classes_segments(path:str, file_name:str):
    """Parse a VOC-style XML annotation into one entry per bounding box.

    Reads every ``<object>`` directly under the root element. Within an object, ``<name>``
    gives the class and each ``<bndbox>`` yields one segment.

    Args:
        path: directory containing the annotation (with or without a trailing separator).
        file_name: annotation file name inside ``path``.

    Returns:
        A list of dicts with keys:
          - ``'class'``: text of the object's ``<name>``, or ``None`` if no ``<name>``
            precedes the box within the same ``<object>``.
          - ``'segment_values'``: box coordinates as ints, in document order of the
            ``<bndbox>`` children.
          - ``'bndbox'``: the same coordinates keyed by child tag (e.g. ``'xmin'``), so
            callers need not rely on the order the file lists them in.
    """
    segments = []
    root = ET.parse(os.path.join(path, file_name)).getroot()
    for obj in root.findall("object"):
        # Reset per object so a box never inherits the class of the previous object.
        class_ = None
        for node in obj:
            if node.tag == "name":
                class_ = node.text
            elif node.tag == "bndbox":
                # float() first so coordinates written as "12.0" are accepted as well.
                coords = [(coord.tag, int(float(coord.text))) for coord in node]
                segments.append({
                    'class': class_,
                    'segment_values': [value for _, value in coords],
                    'bndbox': dict(coords),
                })
    return segments

In [14]:
# Dataset / index configuration used by the evaluation loop.
dataset_path, index_name, field = (
    "./datasets/road/test/",  # test split; each .jpg has a same-named .xml annotation
    "road_image_segment",     # Elasticsearch index holding the reference segment embeddings
    "embedding",              # vector field searched by the kNN query
)

In [15]:
# Evaluate kNN classification of annotated road-image segments.
# For every bounding box in the test split: crop and embed the segment, retrieve its max_n
# nearest neighbours from Elasticsearch, and for each (k, n) pair count the query as
# 'positive' when the true class is among the k most frequent labels of the first n hits
# (ranking by number of occurrences), otherwise 'negative'.
k_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
n_list = [1, 5, 10, 25, 50, 75, 100]
file_list = get_files(dataset_path, ".jpg")
accuracy_dict = {}
max_n = max(n_list)   # retrieve enough hits to cover the largest n
candidate = 100       # num_candidates for the kNN search
assert candidate >= max_n, "num_candidates must be at least k"
query_id = 0

print("\nInitializing the ranking strategy based on sum of occurrencies!!!\n")
for file_id, file in enumerate(file_list, start=1):
    image_file = os.path.join(dataset_path, file)
    xml_file = os.path.splitext(file)[0] + ".xml"
    classes_segments = get_classes_segments(dataset_path, xml_file)
    # Convert inside the context manager so the underlying file handle is released.
    with Image.open(image_file) as raw_image:
        image = raw_image.convert("RGB")
    print("Image: ", file_id, " -> ", image_file)

    for class_segment in classes_segments:
        query_id += 1
        class_ = class_segment.get("class")
        # NOTE(review): positional unpacking assumes each <bndbox> lists
        # <xmin>, <xmax>, <ymin>, <ymax> in that order - confirm against the dataset's XML.
        xmin, xmax, ymin, ymax = class_segment.get("segment_values")[:4]
        segment = image.crop((xmin, ymin, xmax, ymax))
        vector = generate_embedding(segment)

        result = getSimilarity(vector, field, index_name, max_n, max_n, candidate)
        hits = result['hits']['hits']

        # Collect each hit's class_label value (iterable of labels, since it is extended below);
        # hits without it are reported rather than silently skipped by a bare except.
        hit_labels = []
        for hit in hits:
            fields = hit.get("fields", {})
            if "class_label" in fields:
                hit_labels.append(fields["class_label"])
            else:
                print('Error ', fields.get('no'))

        print("Query id: ", query_id, " - Classes: ", class_, " - Hits: ", len(hits), " - ", xml_file)

        for n in n_list:
            # Flatten the labels of the first n usable hits once per n, then rank for each k.
            top_n_labels = [label for labels in hit_labels[:n] for label in labels]
            for k in k_list:
                outcome = 'positive' if class_ in frequency_histogram(top_n_labels, k) else 'negative'
                process_result(accuracy_dict, k, n, outcome)

print_process_result(accuracy_dict, k_list, n_list)
matrix = transform_process_result(accuracy_dict, k_list, n_list)
print("Accuracy by k and n")
print(matrix)


Initializing the ranking strategy based on sum of occurrencies!!!

Image:  1  ->  ./datasets/road/test/China_Drone_002204_jpg.rf.43dd1e80df9ea2ae9f085e1acc671500.jpg


GET http://localhost:9200/road_image_segment/_search [status:N/A request:0.001s]
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/urllib3/connection.py", line 203, in _new_conn
    sock = connection.create_connection(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/urllib3/util/connection.py", line 85, in create_connection
    raise err
  File "/opt/conda/lib/python3.11/site-packages/urllib3/util/connection.py", line 73, in create_connection
    sock.connect(sa)
ConnectionRefusedError: [Errno 111] Connection refused

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/elasticsearch/connection/http_urllib3.py", line 172, in perform_request
    response = self.pool.urlopen(method, url, body, retries=Retry(False), headers=request_headers, **kw)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

ConnectionError: ConnectionError(<urllib3.connection.HTTPConnection object at 0x7f5249917410>: Failed to establish a new connection: [Errno 111] Connection refused) caused by: NewConnectionError(<urllib3.connection.HTTPConnection object at 0x7f5249917410>: Failed to establish a new connection: [Errno 111] Connection refused)