In [1]:
import functools
import os

import matplotlib.pyplot as plt
import numpy as np

from utils import json_load

In [2]:
# Load the semantic-similarity evaluator output for one (embedding model,
# response model) pair. Other configurations used with this notebook:
# MODEL_NAME = 'text-embedding-3-small', response_model = '4_gemma'.
# The `small_3_*` names come from text-embedding-3-small; they hold results for
# whichever MODEL_NAME is active and are kept because later cells use them.
MODEL_NAME = "text-embedding-ada-002"
response_model = '2_llama_zh'
model_base_name = '-'.join(MODEL_NAME.rsplit('-', 2)[-2:])  # e.g. "ada-002"

small_3_dir = os.path.join('data', 'deliverables', model_base_name + "_semantic_similarity")
small_3_file_path = os.path.join(small_3_dir, response_model + '_semantic_result.json')
small_3_data = json_load(small_3_file_path)

print(len(small_3_data))
small_3_data['1']

load data from: data/deliverables/ada-002_semantic_similarity/2_llama_zh_semantic_result.json
80


{'reference_answer': '常見針灸配穴法中,所指的「四關穴」,為下列何穴位之組合?',
 'reference_context': '1.常見針灸配穴法中，所指的「四關穴」，為下列何穴位之組合？\n\xa0\nA.上星、日月\nB.合谷、太衝\nC.內關、外關\nD.上關、下關',
 'response': '台九気终飯加度泩。吃的“六南分。，是一条有分题不别的宽器。',
 'semantic_score': 0.776408433921812}

In [3]:
# Load the normalized exact-match results for the same response model; each
# record's boolean 'ispass' flag is used below as the binary ground-truth label.
dataset_dir = os.path.join('data', 'source', 'normalized_exact_match_result')
label_file_path = os.path.join(dataset_dir, '{}_evaluation_result.json'.format(response_model))
label_data = json_load(label_file_path)
label_data[0]

load data from: data/source/normalized_exact_match_result/2_llama_zh_evaluation_result.json


{'qid': '1',
 'ispass': False,
 'label': '常見針灸配穴法中,所指的「四關穴」,為下列何穴位之組合?',
 'pred': '台九気终飯加度泩。吃的“六南分。，是一条有分题不别的宽器。',
 'nlabel': '常見針灸配穴法中所指的四關穴為下列何穴位之組合',
 'npred': '台九気终飯加度泩吃的六南分是一条有分题不别的宽器'}

In [4]:
# Build two aligned vectors for the precision/recall analysis:
#   labels         -> 1 if the exact-match evaluation passed, else 0
#   small_3_scores -> semantic-similarity score for the same qid
# The exact-match 'label' is the same question-stem text that the semantic side
# uses as its reference, so its pass/fail flag is treated as the ground truth.
labels = []
small_3_scores = []  # named after text-embedding-3-small; holds MODEL_NAME scores

for label in label_data:
    qid = label['qid']
    labels.append(1 if label['ispass'] else 0)
    small_3_scores.append(small_3_data[qid]['semantic_score'])


## Threshold selection for the `small_3` similarity scores

In [5]:
from sklearn.metrics import precision_recall_curve
# `precision_recall_curve` returns one more PR point than thresholds: the final
# point (precision=1, recall=0) has no threshold attached.
precisions, recalls, thresholds = precision_recall_curve(labels, small_3_scores)
print(*(attr for arr in (precisions, recalls, thresholds) for attr in (type(arr), arr.shape)))

<class 'numpy.ndarray'> (81,) <class 'numpy.ndarray'> (81,) <class 'numpy.ndarray'> (80,)


In [6]:
import evaluate  # pip install evaluate

def _get_binary_prediction(pred, thr):
    """Binarize scores: 1.0 when a score is above ``thr``, else 0.0.

    A score equal to ``thr`` is negative (the test is ``p <= thr -> 0.0``).
    ``thr=None`` means "no threshold" and marks every score positive. This is
    the threshold ``get_topk_highlight`` reports for its first PR point
    (idx 0), so such a highlight can be passed straight back in.
    """
    if thr is None:
        return [1.0 for _ in pred]
    return [0.0 if p <= thr else 1.0 for p in pred]


@functools.lru_cache(maxsize=None)
def _load_precision_recall_metric():
    """Build the combined ``evaluate`` precision+recall metric once.

    ``evaluate.load`` may fetch and import metric scripts, so repeating it on
    every call (e.g. in a threshold sweep) is needlessly slow. Call
    ``_load_precision_recall_metric.cache_clear()`` to force a reload.
    """
    precision_metric = evaluate.load("precision")
    recall_metric = evaluate.load("recall")
    return evaluate.combine([precision_metric, recall_metric])


def _get_precision_recall(pred, label):
    """Compute precision and recall of binary ``pred`` against ``label``.

    Returns:
        dict from the combined ``evaluate`` metric, e.g.
        ``{'precision': ..., 'recall': ...}``.
    """
    combined = _load_precision_recall_metric()
    return combined.compute(predictions=pred, references=label)


def get_precision_recall(pred, label, thr):
    """Precision/recall of similarity scores thresholded at ``thr``.

    Args:
        pred: similarity scores.
        label: 0/1 ground-truth labels aligned with ``pred``.
        thr: decision threshold; scores strictly above it are positive.
            ``None`` marks every score positive.

    Returns:
        dict with ``"precision"`` and ``"recall"`` keys.
    """
    binary_pred = _get_binary_prediction(pred, thr)
    return _get_precision_recall(binary_pred, label)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Sanity check at the second-smallest threshold: only samples scoring at or below it are predicted negative.
get_precision_recall(small_3_scores, labels, thr=thresholds[1])

{'precision': 0.3717948717948718, 'recall': 1.0}

In [8]:
def get_f1_score(precisions, recalls):
    """Element-wise F1 score from matching precision and recall values.

    F1 = 2 * P * R / (P + R), defined as 0.0 wherever P + R <= 0.

    Args:
        precisions: array-like of precision values (e.g. from
            ``precision_recall_curve``).
        recalls: array-like of recall values, same shape as ``precisions``.

    Returns:
        np.ndarray of float F1 scores with the broadcast shape of the inputs.
    """
    # asarray so plain Python lists are combined element-wise instead of
    # being concatenated by `+`.
    precisions = np.asarray(precisions, dtype=float)
    recalls = np.asarray(recalls, dtype=float)
    denom = precisions + recalls
    # Only divide where the denominator is positive. np.where would still
    # evaluate 0/0 everywhere and emit "invalid value" RuntimeWarnings.
    return np.divide(
        2 * precisions * recalls,
        denom,
        out=np.zeros_like(denom),
        where=denom > 0,
    )

# F1 at every PR point (one more entry than `thresholds`).
f1_scores = get_f1_score(precisions=precisions, recalls=recalls)
np.shape(f1_scores)

(81,)

In [9]:
def get_topk_highlight(precisions, recalls, k=3, thresholds=None):
    """Return the ``k`` precision/recall points with the highest F1 score.

    Args:
        precisions, recalls: arrays from ``precision_recall_curve``
            (``len(thresholds) + 1`` entries each).
        k: number of points to return, best first.
        thresholds: threshold array from the same ``precision_recall_curve``
            call. Defaults to the notebook-level ``thresholds`` variable so
            existing calls keep working; pass it explicitly to avoid relying
            on global state.

    Returns:
        list of dicts with keys ``f1``, ``precision``, ``recall``,
        ``threshold`` and ``idx`` (index into ``precisions``/``recalls``).

    Raises:
        ValueError: if ``thresholds`` does not have exactly one fewer element
            than ``precisions``.

    Threshold convention: sklearn's ``precisions[idx]`` counts scores
    ``>= thresholds[idx]``. With the default ``drop_intermediate=False`` the
    thresholds are the sorted distinct scores, so that set equals scores
    strictly ``> thresholds[idx - 1]``. That is the rule applied by
    ``get_precision_recall``, so the reported threshold is
    ``thresholds[idx - 1]``. ``idx == 0`` (every sample positive) has no such
    lower bound and reports ``None``.
    """
    if thresholds is None:
        # Backward compatibility: the original version read this global implicitly.
        thresholds = globals()["thresholds"]
    precisions = np.asarray(precisions)
    recalls = np.asarray(recalls)
    if len(thresholds) != len(precisions) - 1:
        raise ValueError(
            "thresholds must have exactly one fewer element than precisions/recalls "
            f"(got {len(thresholds)} thresholds for {len(precisions)} PR points)"
        )

    f1_scores = get_f1_score(precisions, recalls)
    # Stable sort on -F1: ties keep the lower idx first (lower threshold,
    # higher recall), so the ranking is deterministic.
    topk_idx = np.argsort(-f1_scores, kind="stable")[:k]

    highlights = []
    for idx in topk_idx:
        thr = None if idx == 0 else thresholds[idx - 1]
        highlights.append({
            "f1": float(f1_scores[idx]),
            "precision": float(precisions[idx]),
            "recall": float(recalls[idx]),
            "threshold": None if thr is None else float(thr),
            "idx": int(idx),
        })
    return highlights

In [10]:
# Top-3 precision/recall points ranked by F1, best first.
highlights = get_topk_highlight(precisions, recalls, k=3)
highlights

[{'f1': 0.9508196721311475,
  'precision': 0.90625,
  'recall': 1.0,
  'threshold': 0.9642988120688676,
  'idx': 48},
 {'f1': 0.9491525423728815,
  'precision': 0.9333333333333333,
  'recall': 0.9655172413793104,
  'threshold': 0.9676289405210952,
  'idx': 50},
 {'f1': 0.9454545454545454,
  'precision': 1.0,
  'recall': 0.896551724137931,
  'threshold': 0.9899387683709545,
  'idx': 54}]

In [11]:
# Re-score with the best highlight's threshold; should reproduce its precision/recall from the PR curve.
get_precision_recall(pred=small_3_scores, label=labels, thr=highlights[0]['threshold'])

{'precision': 0.90625, 'recall': 1.0}