In [3]:
from typing import List, Tuple, Dict, Any
import numpy as np
import pickle
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

# Entity name used by this notebook; the span-evaluation cells further down
# hard-code the tag "I-REFERENCE" instead of deriving it from this constant.
# NOTE(review): LABEL is not referenced by any visible cell — confirm whether it
# is still needed (or use f"I-{LABEL}" in place of the literal tag).
LABEL = "REFERENCE"

In [8]:
# 3) TF-IDF + Logistic Regression baseline
# NOTE(review): later cells evaluate `y_test` / `y_pred`, which this cell does
# not define (it defines `Y_test` / `pred`) — they must come from another cell.


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    WARNING: ``pickle.load`` can execute arbitrary code while loading; only use
    this on files you created yourself or otherwise trust.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


model = Pipeline([
    ("tfidf", TfidfVectorizer(ngram_range=(1, 2), min_df=2)),
    ("clf", LogisticRegression(max_iter=500, solver="saga", class_weight="balanced")),
])

X_train = _load_pickle("LG_x_train.pkl")
Y_train = _load_pickle("LG_y_train.pkl")
X_test = _load_pickle("LG_x_test.pkl")
Y_test = _load_pickle("LG_y_test.pkl")

model.fit(X_train, Y_train)
pred = model.predict(X_test)

# The saved output of this cell is a classification report, but no statement
# here produced it; restore the call so Restart & Run All reproduces it.
# target_names map onto the classes in sorted order: the saved report lists
# NoRef before HasRef, so this assumes the negative class sorts first
# (e.g. 0/1) — TODO confirm against the label pickles.
print(classification_report(Y_test, pred, target_names=["NoRef", "HasRef"]))



              precision    recall  f1-score   support

       NoRef       0.98      0.90      0.94      2607
      HasRef       0.58      0.87      0.70       414

    accuracy                           0.90      3021
   macro avg       0.78      0.89      0.82      3021
weighted avg       0.92      0.90      0.91      3021



In [None]:
def get_spans(labels, target="I-REFERENCE"):
    """Find the maximal runs of ``target`` in ``labels``.

    Parameters
    ----------
    labels : sequence
        Per-position tags; must support ``len()`` and integer indexing.
    target : default "I-REFERENCE"
        Tag value that marks a position as inside a span.

    Returns
    -------
    list of (int, int)
        Inclusive ``(start, end)`` index pairs, one per contiguous run of
        ``target``, in ascending order of position.
    """
    spans = []
    run_start = None  # index where the currently open run began, if any
    for i in range(len(labels)):
        if labels[i] == target:
            if run_start is None:
                run_start = i
        elif run_start is not None:
            # The run ended at the previous position.
            spans.append((run_start, i - 1))
            run_start = None
    if run_start is not None:
        # A run that reaches the final position is closed after the loop.
        spans.append((run_start, len(labels) - 1))
    return spans

In [None]:
def overlap_len(a, b):
    """Count the positions shared by two inclusive spans.

    Parameters
    ----------
    a, b : (int, int)
        Inclusive ``(start, end)`` spans.

    Returns
    -------
    int
        Number of indices covered by both spans; 0 when they are disjoint.
    """
    shared = min(a[1], b[1]) - max(a[0], b[0]) + 1
    return shared if shared > 0 else 0

# Lenient span-level evaluation.
# NOTE(review): `y_test` / `y_pred` are not defined in any visible cell; they are
# presumably per-token tag sequences (e.g. "I-REFERENCE" / "O") — confirm.

# Overlap threshold, as a fraction of each gold span's length:
# - 0.0 -> plain lenient: any overlap (at least one token) counts
# - 0.5 -> at least half of the gold span's tokens must be covered
min_overlap_ratio = 0.5


def _match_spans_lenient(gold_spans, pred_spans, min_overlap_ratio):
    """Match sorted, non-overlapping gold and predicted spans by overlap.

    Both inputs are lists of inclusive ``(start, end)`` spans in ascending
    order, as produced by ``get_spans``. A gold/predicted pair matches when
    they share at least ``max(1, ceil(min_overlap_ratio * gold_len))`` tokens.

    Returns
    -------
    (complete, part, false_positive, false_negative)
        complete       -- gold spans matched by a prediction with identical bounds
        part           -- gold spans matched only by non-identical predictions
        false_positive -- predicted spans that match no gold span
        false_negative -- gold spans matched by no prediction
    """
    complete = part = false_negative = 0
    matched_pred = set()  # indices of predicted spans that matched some gold span
    n_pred = len(pred_spans)
    j = 0  # first predicted span that could still overlap the current gold span

    for g in gold_spans:
        g_start, g_end = g
        g_len = g_end - g_start + 1
        # Tokens of overlap required. ceil gives "at least ratio * len"
        # (floor + strict '>' demanded more than half for even lengths);
        # the epsilon stops float noise such as 0.3 * 10 from rounding up.
        required = max(1, int(np.ceil(min_overlap_ratio * g_len - 1e-9)))

        # Gold spans are sorted, so a prediction ending before this gold span
        # starts cannot overlap it or any later gold span.
        while j < n_pred and pred_spans[j][1] < g_start:
            j += 1

        matched = False
        exact = False
        k = j
        # Predictions starting no later than this gold span ends may overlap it.
        while k < n_pred and pred_spans[k][0] <= g_end:
            p = pred_spans[k]
            if overlap_len(g, p) >= required:
                matched = True
                # Only sufficient overlaps clear a prediction of being a false
                # positive; below-threshold overlaps previously vanished from
                # both TP and FP.
                matched_pred.add(k)
                if p == g:
                    exact = True
            k += 1

        if not matched:
            false_negative += 1
        elif exact:
            complete += 1
        else:
            part += 1

    false_positive = n_pred - len(matched_pred)
    return complete, part, false_positive, false_negative


gold_spans = get_spans(y_test, "I-REFERENCE")  # sorted by construction
pred_spans = get_spans(y_pred, "I-REFERENCE")  # sorted by construction

total_gold = len(gold_spans)
total_pred = len(pred_spans)
complete, part, false_positive, false_negative = _match_spans_lenient(
    gold_spans, pred_spans, min_overlap_ratio
)

print("total_gold:", total_gold)
print("total_pred:", total_pred)
print("complete:", complete)
print("part:", part)
print("false_positive:", false_positive)
print("false_negative:", false_negative)


In [None]:
# Lenient scores: a gold span matched completely or partially counts as a hit.
tp_lenient = complete + part
fp_lenient = false_positive
fn_lenient = false_negative

# Each ratio falls back to 0.0 when its denominator is zero.
if tp_lenient + fp_lenient:
    precision = tp_lenient / (tp_lenient + fp_lenient)
else:
    precision = 0.0

if tp_lenient + fn_lenient:
    recall = tp_lenient / (tp_lenient + fn_lenient)
else:
    recall = 0.0

if precision + recall:
    f1 = 2 * precision * recall / (precision + recall)
else:
    f1 = 0.0

print("Precision:", precision)
print("Recall:   ", recall)
print("F1:       ", f1)


In [None]:



# Strict span evaluation: a prediction is a hit only when its (start, end)
# equals a gold span exactly; partial overlaps count as both an FP and an FN.
# NOTE(review): this rebinds `gold_spans` / `pred_spans` from lists to sets.
gold_spans = set(get_spans(y_test, "I-REFERENCE"))
pred_spans = set(get_spans(y_pred, "I-REFERENCE"))

tp_strict = len(gold_spans.intersection(pred_spans))  # exact boundary matches
fp_strict = len(pred_spans.difference(gold_spans))    # predicted, no identical gold span
fn_strict = len(gold_spans.difference(pred_spans))    # gold, no identical prediction

# Each ratio falls back to 0.0 when its denominator is zero.
if tp_strict + fp_strict:
    precision_strict = tp_strict / (tp_strict + fp_strict)
else:
    precision_strict = 0.0

if tp_strict + fn_strict:
    recall_strict = tp_strict / (tp_strict + fn_strict)
else:
    recall_strict = 0.0

if precision_strict + recall_strict:
    f1_strict = 2 * precision_strict * recall_strict / (precision_strict + recall_strict)
else:
    f1_strict = 0.0

print("STRICT")
print("total_gold:", len(gold_spans))
print("total_pred:", len(pred_spans))
print("TP:", tp_strict)
print("FP:", fp_strict)
print("FN:", fn_strict)
print("Precision:", precision_strict)
print("Recall:   ", recall_strict)
print("F1:       ", f1_strict)
