# Retrieval 성능 측정하기 - MRR, Recall at K

In [1]:
import numpy as np
# from langchain_openai import OpenAIEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings

In [2]:
# Alternative backend, currently disabled:
# embeddings_client = OpenAIEmbeddings()
# Embedding client used for both the questions and the answers in this notebook.
embeddings_client = GoogleGenerativeAIEmbeddings(model='models/embedding-001')

# Example dataset: each entry pairs a question with its reference answer.
qa_pairs = [
    ("What is the capital of France?", "Paris"),
    ("How many continents are there?", "7"),
    ("What is the largest ocean?", "Pacific Ocean"),
    ("Who wrote 'Hamlet'?", "William Shakespeare"),
    ("What is the chemical symbol for gold?", "Au"),
    ("How long is the Great Wall of China?", "Approximately 21,196 kilometers"),
    ("What is the speed of light?", "299,792,458 meters per second"),
    ("Who painted the Mona Lisa?", "Leonardo da Vinci"),
    ("What is the smallest planet in our solar system?", "Mercury"),
    ("What is the largest country by area?", "Russia"),
    ("What year did the Titanic sink?", "1912"),
    ("Who is known as the father of computers?", "Charles Babbage"),
    ("What is the capital of Japan?", "Tokyo"),
    ("How many elements are in the periodic table?", "118"),
    ("What is the hardest natural substance on Earth?", "Diamond"),
    ("What is the longest river in the world?", "The Nile River"),
    ("What is the main ingredient in glass?", "Silicon dioxide"),
    ("What planet is known as the Red Planet?", "Mars"),
    ("Who discovered penicillin?", "Alexander Fleming"),
    ("What is the capital of Australia?", "Canberra"),
]

# Parallel lists: answers[i] is the reference answer for questions[i].
questions = [question for question, _ in qa_pairs]
answers = [answer for _, answer in qa_pairs]

In [3]:
# Generate embeddings: questions first, then answers, with the same client.
# NOTE(review): both sides go through embed_documents; if the client
# distinguishes query from document embeddings, confirm this is intended.
question_embeddings, answer_embeddings = map(
    embeddings_client.embed_documents, (questions, answers)
)

In [4]:
# Wrap both sets of embeddings in NumPy arrays for the matrix math that follows.
question_embeddings, answer_embeddings = (
    np.array(vectors) for vectors in (question_embeddings, answer_embeddings)
)

In [5]:
# Inspect the shape of the question embedding array.
question_embeddings.shape

(20, 768)

In [6]:
# Inspect the shape of the answer embedding array.
answer_embeddings.shape

(20, 768)

In [7]:
def compute_mrr(gt_indices, pred_indices):
    """Compute the Mean Reciprocal Rank (MRR) over a set of queries.

    Args:
        gt_indices: Iterable holding the correct item index for each query.
        pred_indices: Iterable of rankings, one per query, each listing
            candidate indices from best to worst. A ranking may be a list,
            a tuple, or a 1-D NumPy array.

    Returns:
        The mean of ``1 / rank`` over all queries as a float, where ``rank``
        is the 1-based position of the correct index within that query's
        ranking. A query whose correct index does not appear in its ranking
        contributes 0. Empty input yields 0.0.

    Queries and rankings are paired positionally with ``zip``, so surplus
    entries in the longer of the two inputs are ignored.
    """
    reciprocal_ranks = [
        # Scan the ranking instead of calling ``list.index`` so that
        # non-list rankings work and a missing ground truth scores 0
        # rather than raising ValueError.
        next(
            (
                1.0 / rank
                for rank, candidate in enumerate(ranking, start=1)
                if candidate == gt_idx
            ),
            0.0,
        )
        for gt_idx, ranking in zip(gt_indices, pred_indices)
    ]
    if not reciprocal_ranks:
        # np.mean([]) would return nan and emit a RuntimeWarning.
        return 0.0
    return float(np.mean(reciprocal_ranks))

def compute_recall_at_k(gt_indices, pred_indices, k=1):
    """Compute Recall@K: the share of queries answered within the top k.

    Args:
        gt_indices: Sized collection holding the correct item index for
            each query.
        pred_indices: Iterable of rankings, one per query, each listing
            candidate indices from best to worst.
        k: Number of top-ranked candidates inspected per query. Values
            below 1 inspect nothing, so the result is 0.0.

    Returns:
        ``hits / len(gt_indices)`` as a float, where a hit is a query whose
        correct index appears among the first ``k`` entries of its ranking.
        Empty ``gt_indices`` yields 0.0.
    """
    total = len(gt_indices)
    if total == 0:
        # Avoid ZeroDivisionError when there are no queries to score.
        return 0.0
    # A negative k would slice "all but the last |k|" candidates, which is
    # not a top-k cutoff; clamp it so it behaves like k=0.
    top_k = max(k, 0)
    hits = sum(
        1
        for gt_idx, ranking in zip(gt_indices, pred_indices)
        if gt_idx in ranking[:top_k]
    )
    return hits / total

In [8]:
# Ground truth: question i is answered by answer i.
gt_indices = [query_id for query_id, _ in enumerate(questions)]
# Score every question against every answer with a plain dot product
# (no normalization), then rank the answers per question, best first.
similarity_matrix = question_embeddings @ answer_embeddings.T
pred_indices = np.argsort(-similarity_matrix, axis=1)

In [9]:
# Display the ranked answer indices; each row is one question, best match first.
pred_indices

array([[ 0, 12,  9,  5, 19, 17,  4,  6,  3, 11, 13, 14,  1,  7, 15,  2,
         8, 10, 18, 16],
       [ 5,  2,  6,  9, 17, 15,  0, 12,  1,  3,  7,  4, 13, 19, 14, 11,
         8, 18, 10, 16],
       [ 2,  5,  6, 15,  9, 14, 17,  4,  3,  7, 13,  0, 12,  1,  8, 18,
        11, 19, 16, 10],
       [ 3,  7, 11, 18,  0,  9, 17, 13,  4, 14, 10,  5,  6, 12,  1, 19,
         8, 15,  2, 16],
       [14,  8, 18,  6, 16, 17,  4,  9,  7,  0,  2, 13, 12,  3, 11,  1,
        10, 19, 15,  5],
       [ 5,  6,  7,  9,  0,  3, 15, 13,  1, 12,  2,  4, 11, 17, 19, 10,
        14, 18, 16,  8],
       [ 6,  5,  7,  1, 14, 13,  4,  3, 17,  0,  9, 11, 12, 15,  8,  2,
        10, 18, 16, 19],
       [ 7,  3,  0, 11,  4, 14, 18,  9,  6, 12, 17,  1, 13, 15, 10,  2,
         5, 19,  8, 16],
       [17,  8,  6,  2,  5,  1,  4, 14,  9, 13,  7, 15, 19,  0,  3, 11,
        12, 16, 18, 10],
       [ 5,  9,  6,  4,  2, 13, 12,  1, 17,  0, 19, 14, 15,  3,  7,  8,
        10, 11, 18, 16],
       [10,  2,  3,  9,  6, 14

In [10]:
# MRR 및 Recall@k 계산
mrr = compute_mrr(gt_indices, pred_indices.tolist())
recall_at_1 = compute_recall_at_k(gt_indices, pred_indices.tolist(), k=1)
recall_at_3 = compute_recall_at_k(gt_indices, pred_indices.tolist(), k=3)

In [11]:
# Report the three metrics, one per line.
print(
    f"MRR: {mrr}",
    f"Recall@1: {recall_at_1}",
    f"Recall@3: {recall_at_3}",
    sep="\n",
)

MRR: 0.8126984126984127
Recall@1: 0.7
Recall@3: 0.9
