In [None]:
#Partial Macth
import json

# load predicted results and true data
with open("/content/gpt_50percent_data_finetune_predictions.json", "r", encoding="utf-8") as f:
    predictions = json.load(f)
with open("/content/test.json", "r", encoding="utf-8") as f:
    ground_truth = [json.loads(line) for line in f]

# preprocessing function
import string, re

def normalize_text(s):
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        return ''.join(ch for ch in text if ch not in set(string.punctuation))
    def lower(text):
        return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))

# token overlap F1 score calculating
def token_overlap_f1(pred_text, gold_text):
    pred_tokens = set(normalize_text(pred_text).split())
    gold_tokens = set(normalize_text(gold_text).split())
    if not pred_tokens or not gold_tokens:
        return 0.0, 0.0, 0.0
    common = pred_tokens & gold_tokens
    p = len(common) / len(pred_tokens)
    r = len(common) / len(gold_tokens)
    f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0
    return p, r, f1

# match
ps, rs, f1s = [], [], []
for item in ground_truth:
    qid = item["id"]
    gold = item["answers"]["text"][0]
    pred = predictions.get(qid, "")
    p, r, f1 = token_overlap_f1(pred, gold)
    ps.append(p)
    rs.append(r)
    f1s.append(f1)

# average metrics
partial_metrics = {
    "partial_precision": sum(ps) / len(ps),
    "partial_recall": sum(rs) / len(rs),
    "partial_f1": sum(f1s) / len(f1s)
}

import pandas as pd
pd.DataFrame([partial_metrics])


Unnamed: 0,partial_precision,partial_recall,partial_f1
0,0.954936,0.956952,0.949151


In [None]:
#Exact Macth
import json
from collections import Counter

# ====================
# Tool function：BIO label entity extraction
# ====================
def get_entities(seq, id2label, markup='bios'):
    chunks = []
    chunk = [-1, -1, -1]
    for i, tag in enumerate(seq):
        label = id2label[tag] if isinstance(tag, int) else tag
        if label.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(tuple(chunk))
            chunk = [label[2:], i, i]
        elif label.startswith("I-") and chunk[0] == label[2:]:
            chunk[2] = i
        else:
            if chunk[2] != -1:
                chunks.append(tuple(chunk))
            chunk = [-1, -1, -1]
    if chunk[2] != -1:
        chunks.append(tuple(chunk))
    return chunks

# ====================
# BIO evaluator：SeqEntityScore
# ====================
class SeqEntityScore:
    def __init__(self, id2label, markup='bios'):
        self.id2label = id2label
        self.markup = markup
        self.reset()

    def reset(self):
        self.origins, self.founds, self.rights = [], [], []

    def compute(self, origin, found, right):
        recall = 0 if origin == 0 else (right / origin)
        precision = 0 if found == 0 else (right / found)
        f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (precision + recall)
        return recall, precision, f1

    def result(self):
        class_info = {}
        origin_counter = Counter([x[0] for x in self.origins])
        found_counter = Counter([x[0] for x in self.founds])
        right_counter = Counter([x[0] for x in self.rights])
        for type_, count in origin_counter.items():
            origin, found = count, found_counter.get(type_, 0)
            right = right_counter.get(type_, 0)
            r, p, f1 = self.compute(origin, found, right)
            class_info[type_] = {"precision": round(p, 4), "recall": round(r, 4), "f1": round(f1, 4)}
        r, p, f1 = self.compute(len(self.origins), len(self.founds), len(self.rights))
        return {"precision": p, "recall": r, "f1": f1}, class_info

    def update(self, label_paths, pred_paths):
        for label_path, pre_path in zip(label_paths, pred_paths):
            gold = get_entities(label_path, self.id2label, self.markup)
            pred = get_entities(pre_path, self.id2label, self.markup)
            self.origins.extend(gold)
            self.founds.extend(pred)
            self.rights.extend([p for p in pred if p in gold])

# ====================
# BIO lable construction function：generate BIO sequence for given span
# ====================
def make_bio_labels(context, spans, label="ANSWER"):
    tokens = context.split()
    token_starts = []
    idx = 0
    for token in tokens:
        idx = context.find(token, idx)
        token_starts.append(idx)
        idx += len(token)

    labels = ["O"] * len(tokens)
    for start_char, end_char in spans:
        for i, token_start in enumerate(token_starts):
            token_end = token_start + len(tokens[i])
            if token_end <= start_char:
                continue
            if token_start >= end_char:
                break
            if start_char <= token_start < end_char:
                labels[i] = "B-" + label if labels[i] == "O" else "I-" + label
    return tokens, labels

# ====================
# main
# ====================
with open("/content/test.json", "r", encoding="utf-8") as f:
    ground_truth = [json.loads(line) for line in f]

with open("/content/gpt_50percent_data_finetune_predictions.json", "r", encoding="utf-8") as f:
    predictions = json.load(f)

label_paths = []
pred_paths = []

for item in ground_truth:
    qid = item["id"]
    context = item["context"].lower()
    true_text = item["answers"]["text"][0].strip().lower()
    true_start = item["answers"]["answer_start"][0]
    true_end = true_start + len(true_text)
    _, true_labels = make_bio_labels(context, [(true_start, true_end)])

    if qid in predictions:
        pred_text = predictions[qid].strip().lower()
        pred_start = context.find(pred_text)
        if pred_start != -1:
            pred_end = pred_start + len(pred_text)
            _, pred_labels = make_bio_labels(context, [(pred_start, pred_end)])
        else:
            _, pred_labels = make_bio_labels(context, [])
    else:
        _, pred_labels = make_bio_labels(context, [])

    label_paths.append(true_labels)
    pred_paths.append(pred_labels)

# ====================
# Evaluation
# ====================
id2label = {0: "O", 1: "B-ANSWER", 2: "I-ANSWER"}  # 实际仅用标签字符串，不依赖整数映射
evaluator = SeqEntityScore(id2label)
evaluator.update(label_paths, pred_paths)
overall_result, class_info = evaluator.result()

print("📊 Overall Evaluation:")
print(overall_result)
print("\n📘 Per-Entity Type Evaluation:")
print(class_info)


📊 Overall Evaluation:
{'precision': 0.9210307564422278, 'recall': 0.9089417555373257, 'f1': 0.9149463253509496}

📘 Per-Entity Type Evaluation:
{'ANSWER': {'precision': 0.921, 'recall': 0.9089, 'f1': 0.9149}}


In [None]:
#Calculate label types proportion
import json
import pandas as pd
from collections import Counter

# load test.json
with open("/content/test.json", "r", encoding="utf-8") as f:
    ground_truth = [json.loads(line) for line in f]

# Defined questions
QUESTION_ENTITY_MAP = {
    "What is the name of the host protein/RBP/host factor interacting with the virus?": "Host Protein",
    "What experimental methods were used to detect the virus-host interaction?": "Experimental Method",
    "What is the infection time for the experiment?": "Infection Time",
    "What is the name of the virus whose protein interacts with host factors/proteins?": "Virus Name",
    "What is the name of the virus whose RNA interacts with host factors/proteins?": "Virus Name2",
    "What type of cell was infected by the virus in this experiment?": "Cell Type",
    "What function does the host protein have on the virus ?": "Host Protein Function",
    "What is the name of the virus protein interacting with the host protein/RBP/host factor?": "Virus Protein",
    "What is the host species that virus infected?": "Host Species",
    "What is the strain/subtype of the virus studied?": "Virus Strain",
    "Which RNA structures within the viral genome are preferentially bound by host proteins?": "RNA Structure Preference",
    "Where is the binding site/region located on the Virus?": "Binding Site",
    "What tissue/organ does the infected cell originate from?": "Cell Origin",
    "What is the name of the table that includes the interaction between viral RNA and host protein?": "Table Name",
    "What function does the virus protein have on the host?": "Virus Protein Function",
}


def normalize(text):
    return text.strip().rstrip("?").lower()
entity_counter = Counter()
for item in ground_truth:
    q_norm = normalize(item["question"])
    for template, etype in QUESTION_ENTITY_MAP.items():
        if normalize(template) == q_norm:
            entity_counter[etype] += 1
            break

# DataFrame
df_entity_counts = pd.DataFrame(entity_counter.items(), columns=["Entity Type", "Sample Count"])
df_entity_counts = df_entity_counts.sort_values(by="Sample Count", ascending=False).reset_index(drop=True)
print(df_entity_counts.to_markdown(index=False))


| Entity Type              |   Sample Count |
|:-------------------------|---------------:|
| Host Protein             |            129 |
| Virus Protein            |            101 |
| Virus Name               |             62 |
| Experimental Method      |             57 |
| Cell Type                |             45 |
| Virus Name2              |             43 |
| Host Species             |             29 |
| Infection Time           |             20 |
| Virus Strain             |             10 |
| Binding Site             |             10 |
| Cell Origin              |              8 |
| Table Name               |              7 |
| Virus Protein Function   |              6 |
| Host Protein Function    |              5 |
| RNA Structure Preference |              2 |


In [None]:
# calculate metrics by entity types
import json
from collections import Counter
import pandas as pd

# define question entity map
QUESTION_ENTITY_MAP = {
    "What is the name of the host protein/RBP/host factor interacting with the virus?": "Host Protein",
    "What experimental methods were used to detect the virus-host interaction?": "Experimental Method",
    "What is the infection time for the experiment?": "Infection Time",
    "What is the name of the virus whose protein interacts with host factors/proteins?": "Virus Name",
    "What type of cell was infected by the virus in this experiment?": "Cell Type",
    "What function does the host protein have on the virus ?": "Host Protein Function",
    "What is the name of the virus protein interacting with the host protein/RBP/host factor?": "Virus Protein",
    "What is the host species that virus infected?": "Host Species",
    "What is the strain/subtype of the virus studied?": "Virus Strain",
    "Which RNA structures within the viral genome are preferentially bound by host proteins?": "RNA Structure Preference",
    "Where is the binding site/region located on the Virus?": "Binding Site",
    "What tissue/organ does the infected cell originate from?": "Cell Origin",
    "What is the name of the table that includes the interaction between viral RNA and host protein?": "Table Name",
    "What function does the virus protein have on the host?": "Virus Protein Function",
}

def normalize(text):
    return text.strip().rstrip("?").lower()
normalized_question_map = {normalize(q): v for q, v in QUESTION_ENTITY_MAP.items()}

# BIO label extraction function
def get_entities(seq, id2label, markup='bios'):
    chunks = []
    chunk = [-1, -1, -1]
    for i, tag in enumerate(seq):
        label = id2label[tag] if isinstance(tag, int) else tag
        if label.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(tuple(chunk))
            chunk = [label[2:], i, i]
        elif label.startswith("I-") and chunk[0] == label[2:]:
            chunk[2] = i
        else:
            if chunk[2] != -1:
                chunks.append(tuple(chunk))
            chunk = [-1, -1, -1]
    if chunk[2] != -1:
        chunks.append(tuple(chunk))
    return chunks

# BIO evaluator
class SeqEntityScore:
    def __init__(self, id2label, markup='bios'):
        self.id2label = id2label
        self.markup = markup
        self.reset()

    def reset(self):
        self.origins, self.founds, self.rights = [], [], []

    def compute(self, origin, found, right):
        recall = 0 if origin == 0 else (right / origin)
        precision = 0 if found == 0 else (right / found)
        f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (precision + recall)
        return recall, precision, f1

    def result(self):
        r, p, f1 = self.compute(len(self.origins), len(self.founds), len(self.rights))
        return {"precision": round(p,4), "recall": round(r,4), "f1": round(f1,4)}

    def update(self, label_paths, pred_paths):
        for label_path, pre_path in zip(label_paths, pred_paths):
            gold = get_entities(label_path, self.id2label, self.markup)
            pred = get_entities(pre_path, self.id2label, self.markup)
            self.origins.extend(gold)
            self.founds.extend(pred)
            self.rights.extend([p for p in pred if p in gold])

# BIO label generator
def make_bio_labels(context, spans, label="ANSWER"):
    tokens = context.split()
    token_starts = []
    idx = 0
    for token in tokens:
        idx = context.find(token, idx)
        token_starts.append(idx)
        idx += len(token)
    labels = ["O"] * len(tokens)
    for start_char, end_char in spans:
        for i, token_start in enumerate(token_starts):
            token_end = token_start + len(tokens[i])
            if token_end <= start_char: continue
            if token_start >= end_char: break
            labels[i] = "B-" + label if labels[i] == "O" else "I-" + label
    return tokens, labels

# load data
with open("/content/test.json", "r", encoding="utf-8") as f:
    ground_truth = [json.loads(line) for line in f]
with open("/content/pubmedBERT_predict_predictions.json", "r", encoding="utf-8") as f:
    predictions = json.load(f)

id2label = {0: "O", 1: "B-ANSWER", 2: "I-ANSWER"}

# every entity type with independent evaluator
evaluators = {etype: SeqEntityScore(id2label) for etype in set(QUESTION_ENTITY_MAP.values())}

for item in ground_truth:
    norm_q = normalize(item["question"])
    entity_type = normalized_question_map.get(norm_q)
    if entity_type is None:
        continue
    qid = item["id"]
    context = item["context"].lower()
    true_text = item["answers"]["text"][0].strip().lower()
    true_start = item["answers"]["answer_start"][0]
    true_end = true_start + len(true_text)
    _, true_labels = make_bio_labels(context, [(true_start, true_end)])

    if qid in predictions:
        pred_text = predictions[qid].strip().lower()
        pred_start = context.find(pred_text)
        if pred_start != -1:
            pred_end = pred_start + len(pred_text)
            _, pred_labels = make_bio_labels(context, [(pred_start, pred_end)])
        else:
            _, pred_labels = make_bio_labels(context, [])
    else:
        _, pred_labels = make_bio_labels(context, [])

    evaluators[entity_type].update([true_labels], [pred_labels])

# Metrics results by labels
results = []
for etype, evaluator in evaluators.items():
    metrics = evaluator.result()
    results.append({
        "Entity Type": etype,
        "Precision": metrics["precision"],
        "Recall": metrics["recall"],
        "F1": metrics["f1"]
    })

# DataFrame
df_results = pd.DataFrame(results).sort_values(by="F1", ascending=False).reset_index(drop=True)
df_results


Unnamed: 0,Entity Type,Precision,Recall,F1
0,Virus Strain,1.0,1.0,1.0
1,Host Species,1.0,1.0,1.0
2,RNA Structure Preference,1.0,1.0,1.0
3,Infection Time,1.0,1.0,1.0
4,Binding Site,1.0,0.9444,0.9714
5,Cell Type,0.9573,0.9739,0.9655
6,Table Name,1.0,0.9333,0.9655
7,Experimental Method,0.9427,0.9511,0.9469
8,Cell Origin,1.0,0.8889,0.9412
9,Virus Name,0.8533,1.0,0.9209
