In [1]:
import json
import torch
import faiss

from src.multimodal_retriever.retriever import Retriever
from src.utils.utils import load_model


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Init the model
retriever = Retriever()

In [3]:
# Load the model from the checkpoint
model_path = "../../trained_model/trained_model_20250731-223737/model_20250731-225721.pt"
retriever = load_model(retriever, model_path)

In [4]:
retriever.parameters()

<generator object Module.parameters at 0x7fdd30122960>

In [5]:
# Load the public test of the task 1
public_test_1_path = "../../data/VLSP 2025 - MLQA-TSR Data Release/public_test/vlsp_2025_public_test_task2.json"
with open(public_test_1_path, "r") as f:
    public_test = json.load(f)

In [6]:
# Load the embedding of the document from databases
document_embedding_path = "../../data/record_id_to_document_embedding.json"
with open(document_embedding_path, "r") as f:
    document_embeddings = json.load(f)

In [7]:
# Convert the embedding of the document from list into tensor
tensor_document_embeddings = {}
for key, values in document_embeddings.items():
    values["embedding"] = torch.tensor(values["embedding"])
    tensor_document_embeddings[key] = values

In [8]:
record_id_to_embedding = {}
for key, values in tensor_document_embeddings.items():
    record_id_to_embedding[key] = values["embedding"]

In [9]:
# Create a FAISS index. The dimension must match the embeddings
key, embedding = list(record_id_to_embedding.items())[0]
print(embedding.shape)

torch.Size([1024])


# Create a FAISS index. The dimension must match the embeddings


In [10]:
key, embedding = list(record_id_to_embedding.items())[0]
print(f"Embedding dimension: {embedding.shape}")

embedding_dim = embedding.shape[0]
index = faiss.IndexFlatIP(embedding_dim)

# Add all embeddings to the index
embeddings_matrix = []
record_ids = []
for record_id, embedding in record_id_to_embedding.items():
    embeddings_matrix.append(embedding.detach().cpu().numpy())
    record_ids.append(record_id)

embeddings_matrix = torch.stack([torch.tensor(emb) for emb in embeddings_matrix]).numpy()
index.add(embeddings_matrix)

print(f"Index created with {index.ntotal} embeddings")


Embedding dimension: torch.Size([1024])
Index created with 395 embeddings


# Load images and test the retriever


In [11]:
import os
from PIL import Image
import numpy as np

def test_retriever_on_public_test(retriever, public_test, index, record_ids, tensor_document_embeddings, k=5):
    """
    Test the retriever model on public test data

    Args:
        retriever: The trained retriever model
        public_test: List of test questions
        index: FAISS index for similarity search
        record_ids: List of record IDs corresponding to index
        tensor_document_embeddings: Dictionary mapping record_id to document info
        k: Number of top documents to retrieve
    """
    results = []
    image_base_path = "../../data/VLSP 2025 - MLQA-TSR Data Release/public_test/public_test_images/public_test_images"

    retriever.eval()
    with torch.no_grad():
        for test_item in public_test:
            test_id = test_item["id"]
            image_id = test_item["image_id"]
            question = test_item["question"]

            # Load and process image
            image_path = os.path.join(image_base_path, f"{image_id}.jpg")
            if not os.path.exists(image_path):
                print(f"Warning: Image not found - {image_path}")
                continue

            try:
                image = Image.open(image_path).convert('RGB')
                image = image.resize((224, 224))

                # Get multimodal embedding from retriever
                multimodal_embedding = retriever(image, question)
                query_embedding = multimodal_embedding.reshape(1, -1)

                # Search in FAISS index
                scores, indices = index.search(query_embedding, k)

                # Get top k results
                top_results = []
                for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
                    record_id = record_ids[idx]
                    document_text = tensor_document_embeddings[record_id]["text"]
                    top_results.append({
                        "rank": i + 1,
                        "record_id": record_id,
                        "score": float(score),
                        "text": document_text
                    })

                results.append({
                    "test_id": test_id,
                    "image_id": image_id,
                    "question": question,
                    "top_results": top_results
                })

                print(f"Processed {test_id}: {question[:50]}...")

            except Exception as e:
                print(f"Error processing {test_id}: {str(e)}")
                continue

    return results


# Run the test


In [12]:
test_results = test_retriever_on_public_test(
    retriever,
    public_test,
    index,
    record_ids,
    tensor_document_embeddings,
    k=5
)
print(f"Completed testing on {len(test_results)} items")


Processed public_test_51: Trong tất cả biển báo trong hình bên, hãy cho biết...
Processed public_test_52: Các biển báo xuất hiện trong hình bên là loại biển...
Processed public_test_53: Đây là biển báo cấm vượt. Đúng hay sai?...
Processed public_test_54: "Biển có viền đỏ, nền trắng, hình vẽ màu đen trong...
Processed public_test_55: Hướng đi đến Thành phố Vinh là hướng đi thẳng, đún...
Processed public_test_56: Tốc độ nào sau đây mà người lái xe di chuyển trên ...
Processed public_test_57: Biển báo trong hình có ý nghĩa gì?...
Processed public_test_58: Trên phần đường có đặt biển báo trên thì đối tượng...
Processed public_test_59: Biển báo trong ảnh cảnh  báo điều gì?...
Processed public_test_60: Các biển báo trong ảnh thuộc loại biển báo gì?...
Processed public_test_61: Khi gặp biển báo trong ảnh vào ngày lẻ với biển bá...
Processed public_test_62: Biển báo trong hình là loại biển gì?...
Processed public_test_63: Theo các biển chỉ dẫn trên ảnh thì nút giao tại ng...
Processed public_t

# Display results for first few test cases


In [13]:
for i, result in enumerate(test_results[:10]):
    print(f"\n{'='*50}")
    print(f"Test ID: {result['test_id']}")
    print(f"Image ID: {result['image_id']}")
    print(f"Question: {result['question']}")
    print(f"Top 3 Retrieved Documents:")

    for j, doc in enumerate(result['top_results'][:3]):
        print(f"\n--- Rank {doc['rank']} (Score: {doc['score']:.4f}) ---")
        print(f"Record ID: {doc['record_id']}")
        print(f"Text: {doc['text'][:200]}...")



Test ID: public_test_51
Image ID: public_test_5_6
Question: Trong tất cả biển báo trong hình bên, hãy cho biết xe gắn máy bị cấm đi thẳng trong khoảng thời gian nào?
Top 3 Retrieved Documents:

--- Rank 1 (Score: 0.9946) ---
Record ID: QCVN 41:2024/BGTVT#22
Text: Ý nghĩa sử dụng các biển báo cấm
22.1.   tên các biển như sau:
- Biển số P.101: Đường cấm;
- Biển số P.102: Cấm đi ngược chiều;
- Biển số P.103a: Cấm xe ô tô;
- Biển số P.103(b,c): Cấm xe ô tô rẽ trái...

--- Rank 2 (Score: 0.8593) ---
Record ID: QCVN 41:2024/BGTVT#M.1
Text: M.1  Nhóm biển báo cấm
Biển số P.101: Đường cấm


Biển số P.102: Cấm đi ngược chiều


Biển số P.103a: Cấm xe ôtô




| Loại đường   |   Đường cao tốc |   Đường đôi ngoài đô thị |   Đường thông thường ...

--- Rank 3 (Score: 0.8451) ---
Record ID: QCVN 41:2024/BGTVT#32
Text: Ý nghĩa sử dụng các biển hiệu lệnh
32.1. Biển hiệu lệnh có mã R và R.E với tên các biển như sau:
- Biển số R.122: Dừng lại;
- Biển số R.301(a,b,c,d,e,f,g,h): Hướng đi phải theo;
- Biển

# Filter results by similarity threshold

In [14]:
def filter_results_by_threshold(test_results, threshold=0.7):
    """
    Filter retrieved documents to only keep those with similarity score > threshold

    Args:
        test_results: List of test results from retriever
        threshold: Minimum similarity score to keep (default: 0.8)

    Returns:
        Filtered test results with high-confidence retrievals only
    """
    filtered_results = []

    for result in test_results:
        # Filter top_results to only include scores > threshold
        high_confidence_results = [
            doc for doc in result['top_results']
            if doc['score'] > threshold
        ]

        # Update ranks after filtering
        for i, doc in enumerate(high_confidence_results):
            doc['rank'] = i + 1

        # Create filtered result
        filtered_result = {
            "test_id": result['test_id'],
            "image_id": result['image_id'],
            "question": result['question'],
            "top_results": high_confidence_results,
            "original_count": len(result['top_results']),
            "filtered_count": len(high_confidence_results)
        }

        filtered_results.append(filtered_result)

    return filtered_results

# Apply filtering with threshold 0.8

In [15]:
similarity_threshold = 0.8
filtered_test_results = filter_results_by_threshold(test_results, similarity_threshold)

print(f"Filtering results with similarity threshold > {similarity_threshold}")
print(f"Total test cases: {len(filtered_test_results)}")

Filtering results with similarity threshold > 0.8
Total test cases: 50


In [16]:
for test_result in filtered_test_results:
    top_results = test_result["top_results"]
    test_result["relevant_articles"] = []
    for top_result in top_results:
        law_article_id = top_result["record_id"].strip().split("#")
        law_id = law_article_id[0]
        article_id = law_article_id[1]
        test_result["relevant_articles"].append({
            "law_id": law_id,
            "article_id": article_id,
        })

In [18]:
def extract_law_and_article_id(record_id):
    """
    Extract law_id and article_id from record_id
    Expected format: "law_id#article_id"
    """
    try:
        parts = record_id.split('#')
        if len(parts) == 2:
            return parts[0], parts[1]
        else:
            return None, None
    except:
        return None, None

def calculate_f2_score_per_sample(retrieved_record_ids, relevant_articles):
    """
    Calculate F2 score for a single test sample

    Args:
        retrieved_record_ids: List of record IDs retrieved by the model
        relevant_articles: List of relevant articles from ground truth
                          Each item has structure: {"law_id": "...", "article_id": "..."}

    Returns:
        Dictionary with precision, recall, f2_score
    """
    # Convert relevant articles to set of (law_id, article_id) tuples
    relevant_set = set()
    for article in relevant_articles:
        relevant_set.add((article["law_id"], article["article_id"]))

    # Convert retrieved record IDs to set of (law_id, article_id) tuples
    retrieved_set = set()
    for record_id in retrieved_record_ids:
        law_id, article_id = extract_law_and_article_id(record_id)
        if law_id and article_id:
            retrieved_set.add((law_id, article_id))

    # Calculate metrics
    num_retrieved = len(retrieved_set)
    num_relevant = len(relevant_set)
    num_correct = len(retrieved_set.intersection(relevant_set))

    # Calculate precision and recall
    precision = num_correct / num_retrieved if num_retrieved > 0 else 0.0
    recall = num_correct / num_relevant if num_relevant > 0 else 0.0

    # Calculate F2 score
    if precision + recall > 0:
        f2_score = (5 * precision * recall) / (4 * precision + recall)
    else:
        f2_score = 0.0

    return {
        "precision": precision,
        "recall": recall,
        "f2_score": f2_score,
        "num_retrieved": num_retrieved,
        "num_relevant": num_relevant,
        "num_correct": num_correct,
        "retrieved_articles": retrieved_set,
        "relevant_articles": relevant_set,
        "correct_articles": retrieved_set.intersection(relevant_set)
    }


In [19]:
def calculate_overall_f2_score(filtered_test_results, public_test):
    """
    Calculate F2 score for all test samples

    Args:
        filtered_test_results: Results from the retriever model (filtered)
        public_test: Ground truth test data

    Returns:
        Dictionary with overall metrics and per-sample details
    """
    # Create mapping from test_id to public_test item
    public_test_dict = {item["id"]: item for item in public_test}

    per_sample_metrics = []
    total_precision = 0.0
    total_recall = 0.0
    total_f2 = 0.0
    samples_processed = 0

    for result in filtered_test_results:
        test_id = result["test_id"]

        # Find corresponding ground truth
        if test_id not in public_test_dict:
            print(f"Warning: Test ID {test_id} not found in public test data")
            continue

        ground_truth = public_test_dict[test_id]
        relevant_articles = ground_truth.get("relevant_articles", [])

        # Extract retrieved record IDs
        retrieved_record_ids = [doc["record_id"] for doc in result["top_results"]]

        # Calculate metrics for this sample
        sample_metrics = calculate_f2_score_per_sample(retrieved_record_ids, relevant_articles)
        sample_metrics["test_id"] = test_id
        sample_metrics["question"] = result["question"]

        per_sample_metrics.append(sample_metrics)

        # Accumulate for overall metrics
        total_precision += sample_metrics["precision"]
        total_recall += sample_metrics["recall"]
        total_f2 += sample_metrics["f2_score"]
        samples_processed += 1

    # Calculate averages
    avg_precision = total_precision / samples_processed if samples_processed > 0 else 0.0
    avg_recall = total_recall / samples_processed if samples_processed > 0 else 0.0
    avg_f2 = total_f2 / samples_processed if samples_processed > 0 else 0.0

    return {
        "overall_metrics": {
            "avg_precision": avg_precision,
            "avg_recall": avg_recall,
            "avg_f2_score": avg_f2,
            "samples_processed": samples_processed,
            "total_samples": len(filtered_test_results)
        },
        "per_sample_metrics": per_sample_metrics
    }


In [20]:
# Calculate F2 scores
print("Calculating F2 scores...")
f2_results = calculate_overall_f2_score(filtered_test_results, public_test)

# Display overall results
overall = f2_results["overall_metrics"]
print(f"\n{'='*60}")
print("F2 SCORE EVALUATION RESULTS")
print(f"{'='*60}")
print(f"Samples processed: {overall['samples_processed']}/{overall['total_samples']}")
print(f"Average Precision: {overall['avg_precision']:.4f}")
print(f"Average Recall: {overall['avg_recall']:.4f}")
print(f"Average F2 Score: {overall['avg_f2_score']:.4f}")

Calculating F2 scores...

F2 SCORE EVALUATION RESULTS
Samples processed: 50/50
Average Precision: 0.1480
Average Recall: 0.3417
Average F2 Score: 0.2682


# Analyze filtering impact

In [27]:
def analyze_filtering_impact(filtered_results, threshold):
    """Analyze the impact of filtering on retrieval results"""

    total_cases = len(filtered_results)
    cases_with_results = sum(1 for r in filtered_results if len(r['top_results']) > 0)
    cases_without_results = total_cases - cases_with_results

    # Calculate total documents before/after filtering
    total_original = sum(r['original_count'] for r in filtered_results)
    total_filtered = sum(r['filtered_count'] for r in filtered_results)

    # Calculate average filtered count per case
    avg_filtered_per_case = total_filtered / total_cases if total_cases > 0 else 0

    # Get statistics for cases with results
    if cases_with_results > 0:
        scores_above_threshold = []
        for result in filtered_results:
            for doc in result['top_results']:
                scores_above_threshold.append(doc['score'])

        avg_high_confidence_score = np.mean(scores_above_threshold) if scores_above_threshold else 0
        min_score = min(scores_above_threshold) if scores_above_threshold else 0
        max_score = max(scores_above_threshold) if scores_above_threshold else 0
    else:
        avg_high_confidence_score = 0
        min_score = 0
        max_score = 0

    print(f"\n{'=' * 50}")
    print(f"FILTERING ANALYSIS (Threshold: {threshold})")
    print(f"{'=' * 50}")
    print(f"Total test cases: {total_cases}")
    print(f"Cases with high-confidence results: {cases_with_results} ({cases_with_results / total_cases * 100:.1f}%)")
    print(
        f"Cases without high-confidence results: {cases_without_results} ({cases_without_results / total_cases * 100:.1f}%)")
    print(f"Total documents before filtering: {total_original}")
    print(f"Total documents after filtering: {total_filtered}")
    print(f"Filtering retention rate: {total_filtered / total_original * 100:.1f}%" if total_original > 0 else "N/A")
    print(f"Average high-confidence documents per case: {avg_filtered_per_case:.2f}")

    if cases_with_results > 0:
        print(f"Average score of high-confidence results: {avg_high_confidence_score:.4f}")
        print(f"Score range: {min_score:.4f} - {max_score:.4f}")

    return {
        "total_cases": total_cases,
        "cases_with_results": cases_with_results,
        "cases_without_results": cases_without_results,
        "retention_rate": total_filtered / total_original if total_original > 0 else 0,
        "avg_filtered_per_case": avg_filtered_per_case,
        "avg_high_confidence_score": avg_high_confidence_score
    }


filtering_analysis = analyze_filtering_impact(filtered_test_results, similarity_threshold)


FILTERING ANALYSIS (Threshold: 0.9)
Total test cases: 50
Cases with high-confidence results: 42 (84.0%)
Cases without high-confidence results: 8 (16.0%)
Total documents before filtering: 250
Total documents after filtering: 49
Filtering retention rate: 19.6%
Average high-confidence documents per case: 0.98
Average score of high-confidence results: 0.9486
Score range: 0.9037 - 0.9936


# Display filtered results for cases with high-confidence retrievals

In [103]:
print(f"\n{'=' * 60}")
print("HIGH-CONFIDENCE RETRIEVAL RESULTS")
print(f"{'=' * 60}")

high_confidence_cases = [r for r in filtered_test_results if len(r['top_results']) > 0]

for i, result in enumerate(high_confidence_cases[:5]):  # Show first 5 cases with results
    print(f"\n{'-' * 40}")
    print(f"Test ID: {result['test_id']}")
    print(f"Question: {result['question']}")
    print(f"High-confidence documents found: {result['filtered_count']}")

    for doc in result['top_results'][:3]:  # Show top 3 high-confidence results
        print(f"\n  Rank {doc['rank']} (Score: {doc['score']:.4f})")
        print(f"  Record ID: {doc['record_id']}")
        print(f"  Text: {doc['text'][:200]}...")

# Display cases without high-confidence results
print(f"\n{'=' * 60}")
print("CASES WITHOUT HIGH-CONFIDENCE RESULTS")
print(f"{'=' * 60}")

no_confidence_cases = [r for r in filtered_test_results if len(r['top_results']) == 0]

for i, result in enumerate(no_confidence_cases[:5]):  # Show first 5 cases without results
    print(f"\n{'-' * 40}")
    print(f"Test ID: {result['test_id']}")
    print(f"Question: {result['question']}")
    print(f"Original results: {result['original_count']}, After filtering: {result['filtered_count']}")



HIGH-CONFIDENCE RETRIEVAL RESULTS

----------------------------------------
Test ID: public_test_51
Question: Trong tất cả biển báo trong hình bên, hãy cho biết xe gắn máy bị cấm đi thẳng trong khoảng thời gian nào?
High-confidence documents found: 5

  Rank 1 (Score: 0.9669)
  Record ID: QCVN 41:2024/BGTVT#22
  Text: Ý nghĩa sử dụng các biển báo cấm
22.1.   tên các biển như sau:
- Biển số P.101: Đường cấm;
- Biển số P.102: Cấm đi ngược chiều;
- Biển số P.103a: Cấm xe ô tô;
- Biển số P.103(b,c): Cấm xe ô tô rẽ trái...

  Rank 2 (Score: 0.8715)
  Record ID: QCVN 41:2024/BGTVT#32
  Text: Ý nghĩa sử dụng các biển hiệu lệnh
32.1. Biển hiệu lệnh có mã R và R.E với tên các biển như sau:
- Biển số R.122: Dừng lại;
- Biển số R.301(a,b,c,d,e,f,g,h): Hướng đi phải theo;
- Biển số R.302(a,b,c)...

  Rank 3 (Score: 0.8403)
  Record ID: QCVN 41:2024/BGTVT#36
  Text: Ý nghĩa sử dụng các biển chỉ dẫn
36.1. Biển chỉ dẫn trên các đường ô tô không phải là đường cao tốc có mã “I” với tên các biển như sa

In [104]:
# Save filtered results
filtered_results_path = "../../results/filtered_public_test_task1_results.json"
os.makedirs(os.path.dirname(filtered_results_path), exist_ok=True)

# Add metadata to the saved results
filtered_results_with_metadata = {
    "metadata": {
        "similarity_threshold": similarity_threshold,
        "total_test_cases": len(filtered_test_results),
        "cases_with_high_confidence_results": len(high_confidence_cases),
        "cases_without_high_confidence_results": len(no_confidence_cases),
        "filtering_retention_rate": filtering_analysis["retention_rate"],
        "average_high_confidence_score": filtering_analysis["avg_high_confidence_score"]
    },
    "results": filtered_test_results
}

with open(filtered_results_path, "w", encoding="utf-8") as f:
    json.dump(filtered_results_with_metadata, f, indent=2, ensure_ascii=False)

print(f"\nFiltered results saved to: {filtered_results_path}")




Filtered results saved to: ../../results/filtered_public_test_task1_results.json


# Create a summary report

In [106]:
def create_summary_report(filtered_results, threshold):
    """Create a detailed summary report of filtering results"""

    print(f"\n{'=' * 60}")
    print("DETAILED SUMMARY REPORT")
    print(f"{'=' * 60}")

    # Distribution of filtered document counts
    count_distribution = {}
    for result in filtered_results:
        count = result['filtered_count']
        count_distribution[count] = count_distribution.get(count, 0) + 1

    print(f"\nDistribution of high-confidence documents per test case:")
    for count in sorted(count_distribution.keys()):
        cases = count_distribution[count]
        percentage = cases / len(filtered_results) * 100
        print(f"  {count} documents: {cases} cases ({percentage:.1f}%)")

    # Score distribution for high-confidence results
    all_scores = []
    for result in filtered_results:
        for doc in result['top_results']:
            all_scores.append(doc['score'])

    if all_scores:
        score_ranges = {
            f"{threshold:.1f}-0.85": sum(1 for s in all_scores if threshold < s <= 0.85),
            "0.85-0.90": sum(1 for s in all_scores if 0.85 < s <= 0.90),
            "0.90-0.95": sum(1 for s in all_scores if 0.90 < s <= 0.95),
            "0.95-1.00": sum(1 for s in all_scores if 0.95 < s <= 1.00)
        }

        print(f"\nScore distribution of high-confidence results:")
        for range_name, count in score_ranges.items():
            percentage = count / len(all_scores) * 100 if all_scores else 0
            print(f"  {range_name}: {count} documents ({percentage:.1f}%)")


create_summary_report(filtered_test_results, similarity_threshold)



DETAILED SUMMARY REPORT

Distribution of high-confidence documents per test case:
  1 documents: 1 cases (2.0%)
  2 documents: 6 cases (12.0%)
  3 documents: 4 cases (8.0%)
  4 documents: 2 cases (4.0%)
  5 documents: 37 cases (74.0%)

Score distribution of high-confidence results:
  0.8-0.85: 135 documents (61.9%)
  0.85-0.90: 38 documents (17.4%)
  0.90-0.95: 17 documents (7.8%)
  0.95-1.00: 28 documents (12.8%)


# Optional: Adjust threshold and compare results

In [107]:
print(f"\n{'=' * 60}")
print("THRESHOLD COMPARISON")
print(f"{'=' * 60}")

thresholds_to_test = [0.7, 0.75, 0.8, 0.85, 0.9]
threshold_comparison = {}

for thresh in thresholds_to_test:
    filtered_at_thresh = filter_results_by_threshold(test_results, thresh)
    cases_with_results = sum(1 for r in filtered_at_thresh if len(r['top_results']) > 0)
    total_docs = sum(r['filtered_count'] for r in filtered_at_thresh)

    threshold_comparison[thresh] = {
        'cases_with_results': cases_with_results,
        'total_documents': total_docs,
        'coverage_rate': cases_with_results / len(filtered_at_thresh) * 100
    }

print(f"{'Threshold':<10} {'Cases w/ Results':<15} {'Total Docs':<12} {'Coverage Rate':<12}")
print(f"{'-' * 50}")
for thresh, stats in threshold_comparison.items():
    print(
        f"{thresh:<10.2f} {stats['cases_with_results']:<15} {stats['total_documents']:<12} {stats['coverage_rate']:<12.1f}%")


THRESHOLD COMPARISON
Threshold  Cases w/ Results Total Docs   Coverage Rate
--------------------------------------------------
0.70       50              250          100.0       %
0.75       50              243          100.0       %
0.80       50              218          100.0       %
0.85       48              83           96.0        %
0.90       40              45           80.0        %
