In [None]:
from final_evaluation import *

In [None]:
import os
import pandas as pd

# Root of the retrieval test split; later cells read `queries/` and `labels.csv` from it.
# Set the TEST_RETRIEVAL_DIR environment variable to point elsewhere instead of editing this path.
test_dir = os.environ.get('TEST_RETRIEVAL_DIR', '/media/minhduck/One Touch/official/images/test_retrieval')
# Per-split working directory: the test path flattened into a single folder name
# (e.g. '/a/b' -> 'working/_a_b').
output_dir = os.path.join('working/', test_dir.replace('/', '_'))

In [None]:
# `capture_depth_buffer` is not defined in this notebook; presumably it comes from
# `from final_evaluation import *`.
# NOTE(review): presumably renders the .stl meshes under test_dir into the database images
# that image_retrieval reads from output_dir/pairs/<i>.png — confirm in final_evaluation.
print("Capturing depth buffer of stl meshes...")
capture_depth_buffer(test_dir, output_dir)

In [None]:
from argparse import Namespace
# Candidate checkpoints: the backbone name passed to TimmExtractor and the directory
# holding that checkpoint's `pytorch_model.bin`. Switch models via SELECTED_MODEL
# instead of keeping commented-out copies of the config.
MODEL_CONFIGS = {
  'caformer_m36': dict(
    model_name='caformer_m36.sail_in22k_ft_in1k_384',
    save_path='./working/trained_models/oml/ckpt_02',
  ),
  'caformer_s36': dict(
    model_name='caformer_s36.sail_in22k_ft_in1k_384',
    save_path='./working/trained_models/oml/ckpt_05',
  ),
}
SELECTED_MODEL = 'caformer_s36'

# Settings shared by every checkpoint above.
args = Namespace(
  **MODEL_CONFIGS[SELECTED_MODEL],
  no_head=True,
  device='cuda',
)


def _embed_image(model, image_path, device):
  """Embed one image file with `model` and return the output as a numpy array.

  The image is read with OpenCV, passed through `test_transform`, reduced to its first
  channel (`[[0], :, :]` keeps the channel axis), given a leading batch axis and fed to
  `model` without gradient tracking. The caller must already have moved `model` to
  `device` and put it in eval mode.

  Raises:
    FileNotFoundError: if OpenCV cannot read `image_path`.
  """
  image = cv2.imread(image_path)
  if image is None:
    # cv2.imread signals a missing/unreadable file by returning None instead of raising.
    raise FileNotFoundError('Could not read image: %s' % image_path)
  x = test_transform(image=image)['image'][[0], :, :]
  x = x.to(device)[None]
  with torch.no_grad():
    vec = model(x)
  return vec.detach().cpu().numpy()


def image_retrieval(args, test_dir, output_dir):
  """Rank database depth images for every query image and write the top-5 to pred.csv.

  Loads a `TimmExtractor` checkpoint from `<args.save_path>/pytorch_model.bin` (stored as a
  `(best_score, state_dict)` pair). It then embeds every database image
  `<output_dir>/pairs/<i>.png` and every query image `<test_dir>/queries/<i>.png`, and scores
  each database item by negative squared L2 distance to the query embedding. Finally it
  writes `<output_dir>/pred.csv` with columns `query_name` and `predictions`; the latter
  holds the five best `<i>.stl` names, comma-separated, best first.

  Args:
    args: namespace providing `model_name`, `save_path` and `device`; all of its fields are
      also passed to `TimmExtractor` as its config.
    test_dir: directory containing the `queries/` folder.
    output_dir: directory containing the `pairs/` folder; `pred.csv` is written here.

  Raises:
    FileNotFoundError: if the checkpoint or any expected image file is missing.
  """
  checkpoint_path = os.path.join(args.save_path, 'pytorch_model.bin')
  if not os.path.isfile(checkpoint_path):
    # Explicit raise instead of `assert False`, which is stripped under `python -O`.
    raise FileNotFoundError('Model does not exist: %s' % checkpoint_path)

  config = dict(args._get_kwargs())
  extractor = TimmExtractor(args.model_name, config)
  # map_location='cpu' lets a checkpoint saved on GPU load on any machine; the model is
  # moved to args.device below.
  # NOTE(review): torch.load may unpickle arbitrary objects — only load trusted checkpoints.
  best_score, state_dict = torch.load(checkpoint_path, map_location='cpu')
  extractor.load_state_dict(state_dict)
  print('Loaded pretrained model with score=%.4f' % best_score)

  db_img_path = os.path.join(output_dir, 'pairs')
  q_img_path = os.path.join(test_dir, 'queries')
  # Files are expected to be named 0.png .. (n-1).png; the listing only supplies the count.
  n_db = len(os.listdir(db_img_path))
  n_q = len(os.listdir(q_img_path))

  extractor.to(args.device)
  extractor.eval()  # set once; nothing below switches back to training mode

  db_vecs = [
    _embed_image(extractor, os.path.join(db_img_path, '%d.png' % i), args.device)
    for i in tqdm(range(n_db), position=0, leave=True)
  ]
  # One flattened embedding per row, so each query is scored with a single array expression.
  db_matrix = np.stack([v.reshape(-1) for v in db_vecs]) if db_vecs else None

  def score_query(q_vec):
    """Negative squared L2 distance from `q_vec` to every database embedding (higher = closer)."""
    if db_matrix is None:
      return np.empty(0)
    return -((db_matrix - q_vec.reshape(1, -1)) ** 2).sum(axis=1)

  records = []
  for i in tqdm(range(n_q), position=0, leave=True):
    img_name = os.path.join(q_img_path, '%d.png' % i)
    q_vec = _embed_image(extractor, img_name, args.device)
    # argsort is ascending; reverse so the highest score (closest item) comes first.
    rank = np.argsort(score_query(q_vec))[::-1]
    records.append({
      'query_name': os.path.basename(img_name),
      'predictions': ','.join('%d.stl' % j for j in rank[:5]),
    })
  # Explicit columns keep the header row even when there are no queries.
  pd.DataFrame(records, columns=['query_name', 'predictions']).to_csv(
    os.path.join(output_dir, 'pred.csv'), index=False)
      
      
# Embed the database and query images, then write the top-5 predictions to output_dir/pred.csv.
image_retrieval(
  args=args,
  test_dir=test_dir,
  output_dir=output_dir,
)

In [None]:
# pred.csv holds each query's ranked predictions (not the ground truth), so name the
# second column `predictions`, matching how calculate_mrr_at_5 reads this file.
pd.read_csv(os.path.join(output_dir, 'pred.csv'), skiprows=1, names=['query_name', 'predictions'])

In [None]:


def calculate_mrr_at_5(gt_csv, pred_csv, verbose=False):
    """Compute the mean reciprocal rank at 5 (MRR@5) of retrieval predictions.

    Both CSVs have their first row skipped as a header and are read with explicit names:

    * ``gt_csv``: ``query_name, correct_output``. Only the first comma-separated entry of
      ``correct_output`` is used as the target.
    * ``pred_csv``: ``query_name, predictions``. ``predictions`` is a comma-separated,
      best-first list; only its first five entries count.

    Entries are whitespace-stripped before comparison. A query's reciprocal rank is
    ``1 / position`` of the target in its top five, or 0 when the target is absent. It is
    also 0 when the query has no row (or an empty ``predictions`` field) in ``pred_csv``.
    If several prediction rows share a query name, the first one is used.

    Args:
        gt_csv: path to the ground-truth CSV.
        pred_csv: path to the predictions CSV.
        verbose: if True, print each query's top five, target and reciprocal rank.

    Returns:
        The mean reciprocal rank as a float, or 0.0 when ``gt_csv`` has no rows.
    """
    gt_df = pd.read_csv(gt_csv, skiprows=1, names=['query_name', 'correct_output'])
    pred_df = pd.read_csv(pred_csv, skiprows=1, names=['query_name', 'predictions'])

    # Build the lookup once instead of filtering pred_df per query; setdefault keeps the
    # first row for duplicated query names.
    predictions_by_query = {}
    for query_name, predictions in zip(pred_df['query_name'], pred_df['predictions']):
        predictions_by_query.setdefault(query_name, predictions)

    reciprocal_ranks = []
    for query_name, correct_output in zip(gt_df['query_name'], gt_df['correct_output']):
        target = str(correct_output).split(',')[0].strip()
        predictions = predictions_by_query.get(query_name)
        # A missing query (None) or an empty predictions cell (NaN) counts as a miss
        # rather than crashing.
        if pd.isna(predictions):
            top5 = []
        else:
            top5 = [p.strip() for p in str(predictions).split(',')[:5]]

        reciprocal_rank = 1.0 / (top5.index(target) + 1) if target in top5 else 0.0
        if verbose:
            print(query_name, top5, target, '=>', reciprocal_rank)
        reciprocal_ranks.append(reciprocal_rank)

    if not reciprocal_ranks:
        return 0.0
    return sum(reciprocal_ranks) / len(reciprocal_ranks)

# Score the written predictions against the split's ground-truth labels (MRR@5).
calculate_mrr_at_5(
    gt_csv=os.path.join(test_dir, 'labels.csv'),
    pred_csv=os.path.join(output_dir, 'pred.csv'),
)