In [1]:
import cv2
import time
import os
import glob
import json
import numpy as np
import natsort
import pickle5 as pickle
import matplotlib.pyplot as plt
from multiprocessing import Process
from pathlib import Path
import glob
import faiss
from models.vit.main import main_val as ret_vit_val
from models.sift_vlad.rank_val import ret_vlad_val as ret_vlad_val

# Validation
    1. Parsing
    2. Matching & Validation

## 1. Parsing Data for validation
    - json example
    {
      "version": "5.1.1",
      "flags": {},
      "shapes": [
        {
          "label": "1",
          "points": [
            [
              234.0,
              362.0
            ],
            [
              251.0,
              362.0
            ],
            [
              251.0,
              426.0
            ],
            [
              234.0,
              426.0
            ]
          ],
          "group_id": null,
          "shape_type": "polygon",
          "flags": {
            "matched": true,
            "changed": false
          }
        },
    - cropped file name: {cropped_index}_{label}_{changed_flag}.jpg
        ex) 0_100_False.jpg -> label 100 marks a sign that has no matching pair
            1_3_False.jpg -> changed flag False means the sign design is unchanged

In [2]:
def parsing(query_dir, db_dir):
    """Crop every labelled sign out of the query and db panoramas.

    For each of the two directories, the ``*.jpg`` images are paired with the
    ``*.json`` label files (both lists sorted by path). Every entry of a
    label's ``shapes`` list is cropped by the axis-aligned bounding box of its
    ``points`` and written to
    ``<parent>/<dir name>_val/<image id>/<crop index>_<label>_<changed>.jpg``.

    Args:
        query_dir: str, directory holding the query images and labels. It must
            end with a path separator because the glob patterns are built by
            plain string concatenation.
        db_dir: str, directory holding the db images and labels (same rule).

    Returns:
        tuple(str, str): ``(query_val_dir, db_val_dir)``, the roots the crops
        were written under.

    Raises:
        Exception: if a directory has a different number of images and labels,
            or if a paired image/label have different numeric ids.
        IOError: if an image cannot be read by ``cv2.imread``.
    """

    for img_lbl_dir in [query_dir, db_dir]:  # parsing both query and db

        img_paths = sorted(glob.glob(img_lbl_dir + '*.jpg'))
        lbl_paths = sorted(glob.glob(img_lbl_dir + '*.json'))

        # zip() would silently drop the unpaired tail, so fail loudly instead.
        if len(img_paths) != len(lbl_paths):
            raise Exception(
                f"ID Matching Error: {len(img_paths)} images but "
                f"{len(lbl_paths)} labels in {img_lbl_dir}"
            )

        save_dir = str(Path(img_lbl_dir).parent) + '/' + str(Path(img_lbl_dir).stem) + '_val/'

        for img_path, lbl_path in zip(img_paths, lbl_paths):

            # check that the image and the label belong together:
            # the id is the integer in front of '@' in the file stem
            img_id = int(Path(img_path).stem.split('@')[0])
            lbl_id = int(Path(lbl_path).stem.split('@')[0])
            if img_id != lbl_id:
                raise Exception("ID Matching Error")

            # read the original image; cv2.imread returns None instead of raising
            img_org = cv2.imread(img_path)
            if img_org is None:
                raise IOError(f"Cannot read image: {img_path}")

            # open label
            with open(lbl_path, "r") as js:
                lbl = json.load(js)

            dir_cropped = save_dir + f"/{img_id}/"

            # a dedicated index name keeps the crop counter separate from any outer loop
            for crop_idx, shape in enumerate(lbl['shapes']):

                # bounding box of the polygon
                xs = [int(pt[0]) for pt in shape['points']]
                ys = [int(pt[1]) for pt in shape['points']]

                # Clamp at 0: a negative value used as a slice bound would be
                # interpreted by numpy as "count from the end" and crop the
                # wrong region.
                min_x, max_x = max(min(xs), 0), max(max(xs), 0)
                min_y, max_y = max(min(ys), 0), max(max(ys), 0)

                cropped_image = img_org[min_y: max_y, min_x: max_x]

                # match info
                match = shape['label']  # if matched specific id, unmatched 100
                change = str(shape['flags']['changed'])  # 'False' or 'True'

                # save cropped image
                os.makedirs(dir_cropped, exist_ok=True)
                cropped_fname = f"{crop_idx}_{match}_{change}"
                cv2.imwrite(dir_cropped + cropped_fname + '.jpg', cropped_image)

    print("\nParsing...\nDone.")

    query_val_dir = str(Path(query_dir).parent) + '/' + str(Path(query_dir).stem) + '_val/'
    db_val_dir = str(Path(db_dir).parent) + '/' + str(Path(db_dir).stem) + '_val/'

    return query_val_dir, db_val_dir

In [4]:
### exec
# Ground-truth folders; the trailing '/' is required because parsing()
# builds its glob patterns by string concatenation.
query_dir, db_dir = './data/gt/query/', './data/gt/db/'

# Crop all labelled signs and keep the two output roots for the next cells.
query_val_dir, db_val_dir = parsing(query_dir=query_dir, db_dir=db_dir)
query_val_dir, db_val_dir


Parsing...
Done.


('data/gt/query_val/', 'data/gt/db_val/')

## 2. Path Dictionary
    - query
    - db

In [5]:
def _collect_crop_paths(val_dir):
    """Map every panorama sub-folder of ``val_dir`` to its crop image paths.

    Args:
        val_dir: str, root folder ending with '/', containing one sub-folder
            per panorama id.

    Returns:
        dict[str, list[str]]: sub-folder name -> naturally sorted paths of the
        ``.jpg`` / ``.png`` files inside it.
    """
    crops_by_panorama = {}
    for pnorm_idx in os.listdir(val_dir):

        pnorm_dir = val_dir + pnorm_idx + '/'
        # Skip stray files in the root; os.listdir() on them would raise.
        if not os.path.isdir(pnorm_dir):
            continue

        # Test the extension itself: a substring test such as 'jpg' in name
        # would also accept names that merely contain those letters.
        crops_by_panorama[pnorm_idx] = natsort.natsorted(
            [pnorm_dir + name for name in os.listdir(pnorm_dir) if name.endswith(('.jpg', '.png'))]
        )
    return crops_by_panorama


# DB dataset root path
pnorm_db_dic = _collect_crop_paths(db_val_dir)  # panorama db

# Query dataset root path
pnorm_q_dic = _collect_crop_paths(query_val_dir)  # panorama query

## 3. Matching & Rank
    - vit
    - sift

In [5]:
# result_path: folder passed to both retrieval functions in the next cell.
# device: device string passed to both retrieval functions.
# NOTE(review): the next cell also passes `batch_size` and `num_workers`,
# which are not defined in any visible cell - define them here before running.
result_path, device = './data/result/', 'cuda'

In [None]:
# Run SIFT-VLAD and ViT retrieval for every query panorama, two processes at a time.
start = time.perf_counter()
for i, p_id in enumerate(pnorm_q_dic.keys()):  # p_id : panorama_id

    # progress
    print(f"Panoram_ID : {p_id} - {i+1}/{len(pnorm_q_dic)}")

    # Only a missing db panorama is an "ID matching error". Checking it
    # explicitly (instead of a bare except) lets every other failure, e.g. an
    # undefined variable, surface with its real traceback.
    if p_id not in pnorm_db_dic:
        print("ID matching error")
        print(f"query ID : {p_id}")
        continue

    # NOTE(review): batch_size and num_workers are not defined in any visible
    # cell; they must exist before this cell is run.
    proc1 = Process(target=ret_vlad_val, args=(pnorm_q_dic[p_id], pnorm_db_dic[p_id], p_id, p_id, result_path, "cs", device))
    proc2 = Process(target=ret_vit_val, args=(pnorm_q_dic[p_id], pnorm_db_dic[p_id], p_id, p_id, result_path, batch_size, num_workers, device))

    proc1.start(); proc2.start()
    proc1.join(); proc2.join()

    # An exception inside a child process does not propagate to the notebook;
    # the exit code is the only sign that a retrieval run failed.
    for proc_name, proc in (("sift_vlad", proc1), ("vit", proc2)):
        if proc.exitcode != 0:
            print(f"{proc_name} retrieval failed for panorama {p_id} (exit code {proc.exitcode})")

print("total_ret_time", time.perf_counter()-start)

## 4. Merge & Validation : merge_topk_val
    - Merge SIFT & ViT
        1. Cosine-Similarity Matching Score
    - Sort& Filter by topk
    - Validation
        - mAP@topk
            - micro/macro
        - Recall@1
        - Precision@1

In [10]:
#input : path
#evaluation is always performed; the per-panorama macro AP is printed
def merge_topk_val(result_path, q_panorama_id, db_panorama_id, topk, match_weight=1/4, method='vit', algo='max'):

    '''
    Merge the pair scores of one or more retrieval methods for a single
    query/db panorama pair, keep the top-k db candidates per query crop and
    evaluate them against the labels encoded in the crop names.

    Pair files are read from
    "{result_path}/eval/{m}_best_pair/pair_{q_panorama_id}-{db_panorama_id}_{m}.txt"
    for every m in method.split('_'). Each non-empty line is
    "<query file>,<db file>,<score>"; everything after the first '.' of a file
    name is dropped. Scores of the same (query, db) pair are summed over methods.

    Crop names look like "{crop_index}_{label}_{changed}"; label 100 marks a
    crop without a matching partner.

    topk: int, the number of matching candidates
    match_weight: float, threshold weight of whether matched or not
        (not used inside this function; kept for interface compatibility)
    result_path: str, folder path saving results, './data/result/'
    method: str, what you want to use, ['vit', 'sift', 'vit_sift', 'sift_vit']
    algo: str, matching algorithm, ['max', 'erase']
        - 'max' : matching pairs for maximizing score
        - 'erase' : matching pairs removing top1 prediction sequentially.
        Any value other than 'max' is treated as 'erase'.

    returns: (result_topk, pap, (ap, matched_cnt), (matched_score, unmatched_score))
        - result_topk: dict, query name -> [(db name, score), ...] sorted by
          descending score, at most topk entries
        - pap: number of correct candidates divided by matched_cnt
          (left undivided when matched_cnt is 0)
        - ap: sum of 1/rank over the correct candidates
        - matched_cnt: number of queries whose label is not 100
        - matched_score / unmatched_score: scores of the correct / incorrect
          candidates

    raises: FileNotFoundError (from open) when a pair file does not exist
    '''

    #merge
    #query -> {db -> summed score}; nested dicts avoid joining and re-splitting
    #the two names on '-', which breaks as soon as a name contains '-'
    pair_scores = {}

    for m in method.split('_'):
        pair_file = f"{result_path}/eval/{m}_best_pair/pair_{q_panorama_id}-{db_panorama_id}_{m}.txt"
        with open(pair_file, "r") as f:
            for line in f:
                line = line.strip()
                if not line: #tolerate blank / trailing empty lines
                    continue
                data = line.split(',')
                img1 = data[0].split('.')[0]
                img2 = data[1].split('.')[0]
                score = float(data[2]) #cosine_similarity

                db_scores = pair_scores.setdefault(img1, {})
                if img2 not in db_scores:
                    db_scores[img2] = score
                else: #the same q-db pair appears in several methods -> sum the scores
                    db_scores[img2] += score

    #change form : query -> [(db, score), ...]
    result_dict = {q: list(db_scores.items()) for q, db_scores in pair_scores.items()}

    #Matching Algorithm
    result_topk = {}

    if algo == 'max':

        for q, candidates in result_dict.items():
            #use match score sorted
            result_topk[q] = sorted(candidates, key=lambda x: -x[1])[:topk] #(db, score)

    else: #algo == 'erase'
        erased_dbs = set() #db crops already taken as some query's top1

        for q, candidates in result_dict.items():

            #erase top1 matching
            remaining = [(db, score) for (db, score) in candidates if db not in erased_dbs]

            if remaining: #not empty
                #use match score sorted
                topk_list = sorted(remaining, key=lambda x: -x[1])[:topk] #(db, score)

                #remember the top1 matched db (name only, not score)
                if topk_list: #empty only when topk == 0
                    erased_dbs.add(topk_list[0][0])
            else: #empty -> score maximization
                topk_list = sorted(candidates, key=lambda x: -x[1])[:topk] #(db, score)

            result_topk[q] = topk_list


    ############ evaluation ################
    #ex) query_name : '0_100_False'/ db_name : '2_100_False'
    ########################################

    #scores of matched pairs
    matched_score = []
    #scores of unmatched pairs
    unmatched_score = []

    #panorama mAP : +1 for every correct candidate inside topk, else 0 -> AP per panorama
    pap = 0
    #crop mAP : AP per sign crop
    ap = 0

    for q_name, ranked in result_topk.items():
        #query_label
        q_lbl = int(q_name.split('_')[1])

        for rank, (db_name, score) in enumerate(ranked, start=1):
            #db_label
            db_lbl = int(db_name.split('_')[1])

            if q_lbl == db_lbl and q_lbl != 100: #label 100 = not matched, never counted as correct

                #if matched, record score
                matched_score.append(score)

                #panorama mAP
                pap += 1
                #crop mAP
                ap += 1/rank
            else: #unmatched

                unmatched_score.append(score)

    #crop mAP denominator: queries that do have a ground-truth partner.
    #Compare the parsed label, not the raw name: a substring test for '100'
    #would also drop e.g. crop index 100 ('100_5_False').
    matched_cnt = sum(1 for q_name in result_topk if int(q_name.split('_')[1]) != 100)

    #panorama mAP
    if matched_cnt == 0:
        print(f"there is no gt, pass!")
    else:
        pap/=matched_cnt
        print(f"macro AP@{topk} is {round(pap,2)}") #result_topk itself is the filter

    return result_topk, pap, (ap, matched_cnt), (matched_score, unmatched_score)

In [11]:
####### hyper params ########

# folder that holds the per-method pair files read by merge_topk_val
result_path = './data/result/'

# retrieval method(s) to merge and the matching algorithm
method, algo = 'vit', 'max'

# number of candidates kept per query
topk = 1

# weight of the mean matched score when the match threshold is computed
match_weight = 1/4

In [14]:
### exec

def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


#### variables
#result : panorama id -> {query name -> [(db_name, score), ...]}
final_result_topk = {}

#score
matched_score = []
unmatched_score = []

#panorama mAP
pap_list = []
#crop mAP
cap = 0
matched_cnt = 0

# check time
start = time.perf_counter()

# run merge_val
for i, p_id in enumerate(pnorm_q_dic.keys()): #p_id : panorama_id

    #progress
    print(f"Panoram_ID : {p_id} - {i+1}/{len(pnorm_q_dic)}")

    try: # query ID == db ID
        result_topk, pap, cap_tuple, score_tuple = merge_topk_val(result_path=result_path, q_panorama_id=p_id, db_panorama_id=p_id,
                                                                  topk=topk, match_weight=match_weight, method=method, algo=algo)
    except Exception as err: # e.g. no result file for this panorama id
        # chain the original error so the real cause stays visible
        raise Exception(f"panorama_id {p_id} error") from err

    print(result_topk)
    final_result_topk[p_id] = result_topk #{query: [(db_name, score), ..., (db_name, score)]}

    #scores used for the match threshold
    matched_score += score_tuple[0]
    unmatched_score += score_tuple[1]

    #panorama mAP
    pap_list.append(pap)
    #crop mAP == recall
    cap += cap_tuple[0]
    matched_cnt += cap_tuple[1]

################ Recall& Precision #####################
#check time (measured once so both lines report the same run)
total_time = time.perf_counter()-start
print("total_time", total_time)
print("query_time", _ratio(total_time, len(pnorm_q_dic)))

# magic line for distinguish 'matched' from 'unmatched'
# NOTE: if either score list is empty np.mean returns NaN, the threshold is
# NaN and no candidate is predicted positive below.
match_thres = match_weight*np.mean(matched_score) + (1-match_weight)*np.mean(unmatched_score)

# final result dict : panorama id -> {query name -> [db names above the threshold]}
final_dict = {}

# TP : True Positive
# FP : False Positive
tp = 0
fp = 0

for p_id, panorama_topk in final_result_topk.items(): #p_id : panorama_id

    thres_result = {}

    for q_name, ranked in panorama_topk.items():
        #query_label
        q_lbl = int(q_name.split('_')[1])

        thres_list = []

        for rank, (db_name, score) in enumerate(ranked, start=1):
            #db_label
            db_lbl = int(db_name.split('_')[1])

            if score > match_thres: # positive : predict this db crop
                if q_lbl == db_lbl and q_lbl != 100:
                    tp += 1
                else:
                    fp += 1

                thres_list.append(db_name)
            # negative : nothing is predicted for this candidate

        thres_result[q_name] = thres_list

    final_dict[p_id] = thres_result

# Ratios print as nan instead of raising ZeroDivisionError when there is no
# ground truth (matched_cnt == 0) or no positive prediction (tp + fp == 0).
print(f'macro_mAP@{topk} : {round(np.mean(pap_list),2)}')
print(f'micro_mAP@{topk} : {round(_ratio(cap, matched_cnt),2)}')
# NOTE(review): the original note says recall should be top1 without the
# match_thres filter, but tp is counted after thresholding - confirm intent.
print(f"recall@1:{round(_ratio(tp, matched_cnt),2)}")
# precision is computed on the predictions that passed match_thres
print(f"precision@1:{round(_ratio(tp, tp+fp),2)}")

ID : 870 - 0/97
macro AP@1 is 1.0
{'0_16_False': [('2_16_False', 0.17323413)], '1_5_False': [('6_5_False', 0.43436024)], '2_16_False': [('2_16_False', 0.24676284)], '3_100_False': [('0_100_False', 0.16534372)], '4_100_False': [('5_100_False', 0.19759382)], '5_100_False': [('8_100_False', 0.20762773)], '6_6_False': [('9_6_False', 0.29445603)], '7_10_False': [('11_10_False', 0.20513643)], '8_9_False': [('16_9_False', 0.2966834)], '9_7_False': [('13_7_False', 0.48219416)], '10_15_False': [('22_15_False', 0.56877494)], '11_12_False': [('10_12_False', 0.33753365)], '12_11_False': [('18_11_False', 0.32965082)], '13_4_False': [('19_4_False', 0.46326485)], '14_8_False': [('20_8_False', 0.6241819)], '15_2_False': [('24_2_False', 0.49871507)], '16_13_False': [('25_13_False', 0.6760753)], '17_1_False': [('17_1_False', 0.37368995)], '18_100_False': [('0_100_False', 0.19516557)], '19_14_False': [('27_14_False', 0.4602993)], '20_100_False': [('21_100_False', 0.22136295)], '21_100_False': [('23_100_F