In [None]:
# Work from the repository root: the `utils.*` imports and the relative
# paths used below ('./dict/blip', './test.csv') are resolved against it.
# NOTE(review): absolute machine-specific path; adjust for your checkout.
%cd D:/VNM-Multimodal-Video-Search

D:\VNM-Multimodal-Video-Search


In [None]:
from utils.embedding_based_search.embedding_based_search import EmbeddingBasedSearch
from utils.embedding_based_search.clip_engine import CLIP
from utils.embedding_based_search.blip_engine import BLIP
from utils.embedding_based_search.beit_engine import BEIT

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
BLIP_DIR = './dict/blip'

# BLIP ViT engine: its .bin file and its id -> image-path .json (presumably
# the index-position-to-keyframe mapping; confirm in blip_engine) both live
# under BLIP_DIR.
blip_vit_engine = BLIP(
    blip_bin_file=BLIP_DIR + '/blip_vit.bin',
    blip_id2image_path=BLIP_DIR + '/blip_vit.json',
    model_type='pretrain_vitL',
)

  state_dict = torch.load(cached_file, map_location="cpu")
  checkpoint = torch.load(cached_file, map_location="cpu")


In [None]:
# Search facade over the BLIP ViT engine loaded in the previous cell;
# it is the only engine handed to the facade here.
embedding_based_search = EmbeddingBasedSearch(
    blip_vit_engine=blip_vit_engine,
)

In [None]:
# Search with the BLIP ViT backbone only; every other engine flag is off.
embedding_based_search.update_searching_mode(
    clip_h14_engine=False,
    clip_h14_xlm_engine=False,
    clip_l14_engine=False,
    blip_vit_engine=True,
    blip_pretrain_engine=False,
    beit_base_engine=False,
    beit_large_engine=False,
)

In [None]:
# Sanity-check query: top-10 keyframes for a free-text description.
result = embedding_based_search.text_search(
    query_text=(
        'A wall with a painting of environmental protection. '
        'On the wall is a painting of a dolphin and a sea turtle.'
    ),
    image_path_subset=None,
    top_k=10,
)
result

{'./distilled_keyframe/L12/V003/8933.jpg': 0.43697613,
 './distilled_keyframe/L12/V003/9000.jpg': 0.43621606,
 './distilled_keyframe/L12/V003/9046.jpg': 0.42924848,
 './distilled_keyframe/L22/V013/9230.jpg': 0.4277385,
 './distilled_keyframe/L12/V003/9123.jpg': 0.42514232,
 './distilled_keyframe/L22/V013/9180.jpg': 0.4193089,
 './distilled_keyframe/L12/V003/9063.jpg': 0.41604197,
 './distilled_keyframe/L22/V013/8962.jpg': 0.4158685,
 './distilled_keyframe/L22/V013/8928.jpg': 0.4157328,
 './distilled_keyframe/L01/V013/24170.jpg': 0.41529515}

### **Evaluation**

In [None]:
def calculate_recall_at_k(ground_truth, results, k):
    """
    Compute Recall@k for a single query.

    ground_truth: iterable of correct keyframe identifiers.
    results: ranked search results -- either a {keyframe: score} dict or a
        list of keyframe identifiers; iteration order is treated as rank order.
    k: number of top-ranked results to consider.

    Returns the fraction of distinct ground-truth keyframes that appear among
    the top-k results, or 0.0 when ground_truth is empty (instead of raising
    ZeroDivisionError).
    """
    # Recall is defined over the *set* of relevant items, so a duplicated
    # ground-truth entry must not be counted twice.
    relevant = set(ground_truth)
    if not relevant:
        return 0.0

    # Top-k slice of the ranking; a set gives O(1) membership checks.
    top_k_results = set(list(results)[:k])

    return len(relevant & top_k_results) / len(relevant)


In [None]:
def calculate_map(ground_truth, results, k=None):
    """
    Compute Average Precision (AP) for a single query; averaging the values
    returned for all queries gives mean Average Precision (mAP).

    ground_truth: iterable of correct keyframe identifiers.
    results: ranked results -- a {keyframe: score} dict or a list of keyframe
        identifiers; iteration order is treated as rank order.
    k: optional rank cutoff (AP@k). None (the default) uses every result.

    AP = (sum of Precision@i over the ranks i holding a relevant keyframe)
         / N, where N is the number of distinct ground-truth keyframes
         (capped at k when a cutoff is given).
    Dividing by N, rather than by the number of relevant keyframes that were
    actually retrieved, penalises relevant keyframes the search missed.

    Returns 0.0 when ground_truth is empty, k <= 0, or nothing relevant is
    retrieved.
    """
    relevant = set(ground_truth)
    if not relevant or (k is not None and k <= 0):
        return 0.0

    ranked = list(results)
    if k is not None:
        ranked = ranked[:k]

    hits = 0
    precision_sum = 0.0
    counted = set()
    for rank, keyframe in enumerate(ranked, start=1):
        # Credit each relevant keyframe once, even if it shows up twice in
        # the ranking (e.g. two paths reduced to the same identifier).
        if keyframe in relevant and keyframe not in counted:
            counted.add(keyframe)
            hits += 1
            precision_sum += hits / rank  # Precision@rank

    denominator = len(relevant) if k is None else min(len(relevant), k)
    return precision_sum / denominator

In [None]:
def calculate_mrr(ground_truth, results, k=None):
    """
    Compute the Reciprocal Rank (RR) for a single query; averaging it over all
    queries gives Mean Reciprocal Rank (MRR).

    ground_truth: iterable of correct keyframe identifiers.
    results: ranked results -- a {keyframe: score} dict or a list of keyframe
        identifiers; iteration order is treated as rank order.
    k: optional rank cutoff (RR@k). None (the default) scans every result.

    Returns 1 / rank of the first ground-truth keyframe found, or 0.0 (always
    a float) if none appears within the considered results.
    """
    relevant = set(ground_truth)
    for rank, keyframe in enumerate(results, start=1):
        if k is not None and rank > k:
            break
        if keyframe in relevant:
            return 1 / rank
    # No ground-truth keyframe within the considered ranks.
    return 0.0

In [None]:
import csv

def search_function(query_text, top_k=10):
    """
    Run a text query through the notebook-level ``embedding_based_search``.

    query_text: natural-language description of the wanted keyframe.
    top_k: number of results to request.

    Returns whatever ``embedding_based_search.text_search`` returns; the
    sample call earlier in this notebook shows a {keyframe_path: score} dict.
    """
    # image_path_subset=None: no restriction to a subset of images
    # (presumably the whole index is searched -- confirm in EmbeddingBasedSearch).
    search_kwargs = {
        'query_text': query_text,
        'image_path_subset': None,
        'top_k': top_k,
    }
    return embedding_based_search.text_search(**search_kwargs)

def extract_relevant_parts(results):
    """
    Reduce each search-result path to its 'L__/V___/id_frame' identifier.

    Args:
        results (dict): Search results {path: score}, iterated in rank order.
            A plain list of paths is accepted as well.

    Returns:
        list: 'L__/V___/id_frame' strings (file extension removed), in the
        same order as ``results``. Paths with fewer than three components
        cannot carry that identifier and are skipped.
    """
    extracted = []
    for path in results:
        # Accept Windows-style separators too, so a path such as
        # 'D:\\data\\L01\\V013\\24170.jpg' is not silently dropped.
        parts = path.replace('\\', '/').split('/')
        if len(parts) < 3:
            continue
        level, video, filename = parts[-3:]
        # Strip only the file extension; a dot elsewhere in the path must not
        # truncate the identifier.
        frame_id = filename.rsplit('.', 1)[0]
        extracted.append(f'{level}/{video}/{frame_id}')
    return extracted

def evaluate_metrics(csv_file, search_function, recall_ks=(1, 5, 20, 50, 100), rank_cutoff=50):
    """
    Compute Recall@k, mAP@rank_cutoff and MRR@rank_cutoff averaged over every
    query in a CSV file.

    csv_file: path to a CSV file with columns 'Query', 'Base_dir' and 'Frames';
        'Frames' holds whitespace-separated frame ids located under 'Base_dir'.
    search_function: callable(query_text, top_k=...) returning ranked
        {keyframe_path: score} results.
    recall_ks: cutoffs for which Recall@k is reported.
    rank_cutoff: number of top results used for mAP and MRR (default 50,
        i.e. mAP@50 and MRR@50).

    Returns a dict with keys 'Recall@<k>' for every k in recall_ks, followed
    by 'mAP' and 'MRR', each averaged over all queries.

    Raises ValueError if the CSV contains no queries.
    """
    recall_scores = {k: [] for k in recall_ks}
    map_scores = []
    mrr_scores = []

    # Request enough results to cover the largest cutoff being evaluated.
    top_k = max((*recall_ks, rank_cutoff))

    # utf-8-sig: queries may contain Vietnamese text and the file may start
    # with a BOM (e.g. saved from Excel), which would otherwise corrupt the
    # first header name; newline='' is what the csv module expects.
    with open(csv_file, 'r', encoding='utf-8-sig', newline='') as file:
        reader = csv.DictReader(file)
        for row in reader:
            query_text = row['Query']
            base_dir = row['Base_dir'].strip().rstrip('/')
            # split() (not split(' ')) so repeated or trailing spaces do not
            # create empty frame ids that could never be matched.
            frames = row['Frames'].split()
            ground_truth = [f'{base_dir}/{frame}' for frame in frames]

            raw_results = search_function(query_text, top_k=top_k)
            # Keep only the 'L__/V___/id_frame' part so results compare with ground_truth.
            ranked = extract_relevant_parts(raw_results)

            for k in recall_ks:
                recall_scores[k].append(calculate_recall_at_k(ground_truth, ranked, k=k))

            # Apply the rank cutoff here so mAP/MRR are truly @rank_cutoff.
            top_ranked = ranked[:rank_cutoff]
            map_scores.append(calculate_map(ground_truth, top_ranked))
            mrr_scores.append(calculate_mrr(ground_truth, top_ranked))

    if not map_scores:
        raise ValueError(f'No queries found in {csv_file!r}')

    def _mean(values):
        # Arithmetic mean of a non-empty list of per-query scores.
        return sum(values) / len(values)

    metrics = {f'Recall@{k}': _mean(recall_scores[k]) for k in recall_ks}
    metrics['mAP'] = _mean(map_scores)
    metrics['MRR'] = _mean(mrr_scores)
    return metrics

In [None]:
# Labelled queries (query text + ground-truth frames) used for evaluation.
csv_file = './test.csv'
metrics = evaluate_metrics(csv_file, search_function)

# Report each metric on its own line, in a fixed order.
print('Kết quả đánh giá BLIP:')
for metric_name in ('Recall@1', 'Recall@5', 'Recall@20', 'Recall@50', 'Recall@100', 'mAP', 'MRR'):
    print(f'    {metric_name}: {metrics[metric_name]:.4f}')


Kết quả đánh giá BLIP:
    Recall@1: 0.1167
    Recall@5: 0.3659
    Recall@20: 0.5849
    Recall@50: 0.6562
    Recall@100: 0.7135
    mAP: 0.3564
    MRR: 0.4345
