In [1]:
# Reinstall the package from local source to get the latest fixes
# %pip install -e ..

In [2]:
#!pip -q install locisimiles

In [3]:
from locisimiles.evaluator import IntertextEvaluator
from locisimiles.pipeline import (
    ClassificationPipeline,
    ClassificationPipelineWithCandidategeneration,
    pretty_print,
)
from locisimiles.document import Document

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Load example query and source documents
query_doc = Document("./hieronymus_samples.csv", author="Hieronymus")
source_doc = Document("./vergil_samples.csv", author="Vergil")

print("Loaded query and source documents:")
print(f"Query Document: {query_doc}")
print(f"Source Document: {source_doc}")
print("=" * 70)

Loaded query and source documents:
Query Document: Document('hieronymus_samples.csv', segments=11, author='Hieronymus', meta={})
Source Document: Document('vergil_samples.csv', segments=10, author='Vergil', meta={})


## Two-Stage Pipeline (Retrieval + Classification)

In [5]:
# Load the pipeline with pre-trained models
pipeline_two_stage = ClassificationPipelineWithCandidategeneration(
    classification_name="julian-schelb/PhilBerta-class-latin-intertext-v1",
    embedding_model_name="julian-schelb/SPhilBerta-emb-lat-intertext-v1",
    device="cpu",
)

# Run the pipeline with the query and source documents
results_two_stage = pipeline_two_stage.run(
    query=query_doc,    # Query document
    source=source_doc,  # Source document
    top_k=10             # Number of top similar candidates to classify
)
print("Results of the two-stage pipeline run:")
pretty_print(results_two_stage)

Embedding query segments: 100%|██████████| 11/11 [00:00<00:00, 305545.32it/s]
Embedding source segments: 100%|██████████| 10/10 [00:00<00:00, 287281.10it/s]
Check candidates: 100%|██████████| 11/11 [00:04<00:00,  2.37it/s]

Results of the two-stage pipeline run:

▶ Query segment 'hier. adv. iovin. 1.1':
  verg. aen. 10.636          sim=+0.661  P(pos)=0.974
  verg. ecl. 8.62            sim=+0.610  P(pos)=0.000
  verg. ecl. 3.26            sim=+0.609  P(pos)=0.000
  verg. aen. 10.875          sim=+0.599  P(pos)=0.000
  verg. georg. 2.475         sim=+0.598  P(pos)=0.000
  verg. georg. 1.197         sim=+0.545  P(pos)=0.000
  verg. ecl. 3.49            sim=+0.485  P(pos)=0.000
  verg. aen. 1.177           sim=+0.484  P(pos)=0.000
  verg. aen. 4.172           sim=+0.437  P(pos)=0.000
  verg. aen. 11.508          sim=+0.407  P(pos)=0.000

▶ Query segment 'hier. adv. iovin. 1.41':
  verg. aen. 11.508          sim=+0.883  P(pos)=0.988
  verg. ecl. 3.49            sim=+0.559  P(pos)=0.000
  verg. aen. 4.172           sim=+0.539  P(pos)=0.000
  verg. ecl. 8.62            sim=+0.503  P(pos)=0.000
  verg. aen. 1.177           sim=+0.501  P(pos)=0.000
  verg. georg. 2.475         sim=+0.483  P(pos)=0.000
  verg. aen.




In [6]:
evaluator = IntertextEvaluator(
    query_doc=query_doc,
    source_doc=source_doc,
    ground_truth_csv="./ground_truth.csv",
    pipeline=pipeline_two_stage,
    top_k=10,
    threshold="auto",
    auto_threshold_metric="smr",
)

print("Single sentence:\n", evaluator.evaluate_single_query("hier. adv. iovin. 1.41"))
print("\nPer-sentence head:\n", evaluator.evaluate_all_queries().head(10))
print("\nMacro scores:\n", evaluator.evaluate(average="macro"))
print("\nMicro scores:\n", evaluator.evaluate(average="micro"))

Embedding query segments: 100%|██████████| 11/11 [00:00<00:00, 322638.77it/s]
Embedding source segments: 100%|██████████| 10/10 [00:00<00:00, 384798.53it/s]
Check candidates: 100%|██████████| 11/11 [00:03<00:00,  2.98it/s]

[IntertextEvaluator] Auto-threshold enabled: best smr at threshold=0.10
Single sentence:
 {'query_id': 'hier. adv. iovin. 1.41', 'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'accuracy': 1.0, 'errors': 0, 'tp': 1, 'fp': 0, 'fn': 0, 'tn': 9, 'fpr': 0.0, 'fnr': 0.0, 'smr': 0.0}

Per-sentence head:
                  query_id  precision  recall   f1  accuracy  errors  tp  fp  \
0   hier. adv. iovin. 1.1        1.0     1.0  1.0       1.0       0   1   0   
1  hier. adv. iovin. 1.41        1.0     1.0  1.0       1.0       0   1   0   
2  hier. adv. iovin. 2.36        1.0     1.0  1.0       1.0       0   1   0   
3  hier. adv. pelag. 1.23        1.0     1.0  1.0       1.0       0   1   0   
4  hier. adv. pelag. 3.11        1.0     1.0  1.0       1.0       0   1   0   
5   hier. adv. pelag. 3.4        1.0     1.0  1.0       1.0       0   1   0   
6  hier. adv. rufin. 1.17        1.0     1.0  1.0       1.0       0   1   0   
7   hier. adv. rufin. 1.5        1.0     1.0  1.0       1.0       0   1 




### Find Optimal Probability Threshold

In [7]:
# Find the optimal threshold that maximizes F1 score
best_result, all_thresholds_df = evaluator.find_best_threshold(
    metric="smr",           # Optimize for SMR (can also use 'f1', 'precision', 'recall', 'accuracy', 'fpr', 'fnr')
    average="micro",       # Use micro-averaging
)

print(f"Best threshold: {best_result['best_threshold']}")
print(f"Best SMR score: {best_result['best_smr']:.4f}")
print(f"\nAll metrics at best threshold:")
for k, v in best_result.items():
    if k not in ['best_threshold', 'best_smr']:
        print(f"  {k}: {v:.4f}" if isinstance(v, float) else f"  {k}: {v}")

print("\n\nMetrics across all thresholds:")
print(all_thresholds_df.to_string(index=False))

Best threshold: 0.1
Best SMR score: 0.0000

All metrics at best threshold:
  precision: 1.0000
  recall: 1.0000
  f1: 1.0000
  accuracy: 1.0000
  fpr: 0.0000
  fnr: 0.0000
  smr: 0.0000
  tp: 10.0000
  fp: 0.0000
  fn: 0.0000
  tn: 100.0000


Metrics across all thresholds:
 threshold  precision  recall       f1  accuracy  fpr      fnr      smr   tp  fp  fn    tn
       0.1        1.0     1.0 1.000000  1.000000  0.0 0.000000 0.000000 10.0 0.0 0.0 100.0
       0.2        1.0     1.0 1.000000  1.000000  0.0 0.000000 0.000000 10.0 0.0 0.0 100.0
       0.3        1.0     1.0 1.000000  1.000000  0.0 0.000000 0.000000 10.0 0.0 0.0 100.0
       0.4        1.0     1.0 1.000000  1.000000  0.0 0.000000 0.000000 10.0 0.0 0.0 100.0
       0.5        1.0     0.9 0.947368  0.990909  0.0 0.009091 0.009091  9.0 0.0 1.0 100.0
       0.6        1.0     0.9 0.947368  0.990909  0.0 0.009091 0.009091  9.0 0.0 1.0 100.0
       0.7        1.0     0.9 0.947368  0.990909  0.0 0.009091 0.009091  9.0 0.0 1.0 100.

### Evaluate Metrics for Different Top-K Values

Evaluate metrics (including TP, FP, FN, TN) for different `k` values without re-running predictions.

In [8]:
# Evaluate metrics for different k values (number of candidates to classify)
# This uses the existing predictions - no re-running required!
k_results = evaluator.evaluate_k_values(
    k_values=[1, 2, 3, 5, 10],  # k values to evaluate (must be <= original top_k)
    average="micro",            # Use micro-averaging
)

# Print results for each k
for k, metrics in k_results.items():
    print(f"\n=== k = {k} ===")
    print(f"  Precision: {metrics['precision']:.4f}")
    print(f"  Recall:    {metrics['recall']:.4f}")
    print(f"  F1:        {metrics['f1']:.4f}")
    print(f"  TP: {metrics['tp']}, FP: {metrics['fp']}, FN: {metrics['fn']}, TN: {metrics['tn']}")


=== k = 1 ===
  Precision: 1.0000
  Recall:    1.0000
  F1:        1.0000
  TP: 10.0, FP: 0.0, FN: 0.0, TN: 100.0

=== k = 2 ===
  Precision: 1.0000
  Recall:    1.0000
  F1:        1.0000
  TP: 10.0, FP: 0.0, FN: 0.0, TN: 100.0

=== k = 3 ===
  Precision: 1.0000
  Recall:    1.0000
  F1:        1.0000
  TP: 10.0, FP: 0.0, FN: 0.0, TN: 100.0

=== k = 5 ===
  Precision: 1.0000
  Recall:    1.0000
  F1:        1.0000
  TP: 10.0, FP: 0.0, FN: 0.0, TN: 100.0

=== k = 10 ===
  Precision: 1.0000
  Recall:    1.0000
  F1:        1.0000
  TP: 10.0, FP: 0.0, FN: 0.0, TN: 100.0


## Classification-Only Pipeline

In [8]:
# Load the classification-only pipeline
pipeline_clf = ClassificationPipeline(
    classification_name="julian-schelb/PhilBerta-class-latin-intertext-v1",
    device="cpu",
)

# Run the pipeline - classifies all pairs
results_clf = pipeline_clf.run(
    query=query_doc,
    source=source_doc,
    batch_size=32,
)

Classifying pairs: 100%|██████████| 11/11 [00:03<00:00,  2.80it/s]


In [9]:
# Filter and display high-probability results
threshold = 0.7
print(f"Filtered results (P(positive) > {threshold}):")
for query_id, pairs in results_clf.items():
    high_prob_pairs = [(seg, sim, prob) for seg, sim, prob in pairs if prob > threshold]
    if high_prob_pairs:
        print(f"\n▶ Query segment {query_id!r}:")
        for src_seg, sim, ppos in high_prob_pairs:
            print(f"  {src_seg.id:<25}  P(pos)={ppos:.3f}")

Filtered results (P(positive) > 0.7):

▶ Query segment 'hier. adv. iovin. 1.1':
  verg. aen. 10.636          P(pos)=0.974

▶ Query segment 'hier. adv. iovin. 1.41':
  verg. aen. 11.508          P(pos)=0.988

▶ Query segment 'hier. adv. iovin. 2.36':
  verg. aen. 4.172           P(pos)=0.997

▶ Query segment 'hier. adv. pelag. 1.23':
  verg. ecl. 8.62            P(pos)=0.981

▶ Query segment 'hier. adv. pelag. 3.11':
  verg. ecl. 3.49            P(pos)=0.995

▶ Query segment 'hier. adv. rufin. 1.17':
  verg. ecl. 3.26            P(pos)=0.997

▶ Query segment 'hier. adv. rufin. 1.5':
  verg. aen. 10.875          P(pos)=0.997

▶ Query segment 'hier. adv. rufin. 1.6':
  verg. aen. 1.177           P(pos)=0.993

▶ Query segment 'hier. adv. rufin. 3.28':
  verg. georg. 2.475         P(pos)=0.994


In [10]:
evaluator_clf = IntertextEvaluator(
    query_doc=query_doc,
    source_doc=source_doc,
    ground_truth_csv="./ground_truth.csv",
    pipeline=pipeline_clf,
    top_k=len(source_doc),  # All pairs
    threshold=0.5,
)

print("Single sentence:\n", evaluator_clf.evaluate_single_query("hier. adv. iovin. 1.41"))
print("\nPer-sentence head:\n", evaluator_clf.evaluate_all_queries().head(10))
print("\nMacro scores:\n", evaluator_clf.evaluate(average="macro"))
print("\nMicro scores:\n", evaluator_clf.evaluate(average="micro"))

Classifying pairs: 100%|██████████| 11/11 [00:04<00:00,  2.67it/s]

Single sentence:
 {'query_id': 'hier. adv. iovin. 1.41', 'precision': 1.0, 'recall': 1.0, 'f1': 1.0, 'accuracy': 1.0, 'errors': 0, 'tp': 1, 'fp': 0, 'fn': 0, 'tn': 9, 'fpr': 0.0, 'fnr': 0.0, 'smr': 0.0}

Per-sentence head:
                  query_id  precision  recall   f1  accuracy  errors  tp  fp  \
0   hier. adv. iovin. 1.1        1.0     1.0  1.0       1.0       0   1   0   
1  hier. adv. iovin. 1.41        1.0     1.0  1.0       1.0       0   1   0   
2  hier. adv. iovin. 2.36        1.0     1.0  1.0       1.0       0   1   0   
3  hier. adv. pelag. 1.23        1.0     1.0  1.0       1.0       0   1   0   
4  hier. adv. pelag. 3.11        1.0     1.0  1.0       1.0       0   1   0   
5   hier. adv. pelag. 3.4        0.0     0.0  0.0       0.9       1   0   0   
6  hier. adv. rufin. 1.17        1.0     1.0  1.0       1.0       0   1   0   
7   hier. adv. rufin. 1.5        1.0     1.0  1.0       1.0       0   1   0   
8   hier. adv. rufin. 1.6        1.0     1.0  1.0       1.0      




### Inspect Input Encoding

In [11]:
# Get the first query and candidate pair
first_query_id = list(query_doc.ids())[0]
first_query_text = query_doc.get_text(first_query_id)
first_candidate_text = source_doc.get_text(list(source_doc.ids())[0])

# Debug the encoding
debug_info = pipeline_clf.debug_input_sequence(first_query_text, first_candidate_text)

print(f"\nInput Text (with special tokens):")
print(debug_info['input_text'])
print(f"\nInput IDs (first 20): {debug_info['input_ids']}")
print(f"Total tokens: {len(debug_info['input_ids'])}")


Input Text (with special tokens):
<s>Furiosas Apollinis uates legimus; et illud Uirgilianum: Dat sine mente sonum.</s></s>tum dea nube caua tenuem sine uiribus umbram in faciem Aeneae uisu mirabile monstrum Dardaniis ornat telis clipeumque iubasque diuini assimulat capitis, dat inania uerba, dat sine mente sonum gressusque effingit euntis, morte obita qualis fama est uolitare figuras aut quae sopitos deludunt somnia sensus.</s>

Input IDs (first 20): [0, 48157, 2035, 346, 47074, 489, 1604, 16836, 31, 361, 2000, 1931, 410, 75, 588, 4186, 30, 49974, 1747, 6664, 30821, 18, 2, 2, 746, 30897, 29120, 1808, 2338, 24795, 324, 1747, 43883, 31334, 306, 8298, 16723, 381, 12147, 89, 33370, 53290, 28818, 304, 1563, 46427, 33064, 42918, 1995, 6164, 8195, 3576, 1721, 1338, 345, 24206, 13375, 16, 8358, 47435, 21042, 16, 8358, 1747, 6664, 30821, 27783, 7816, 1471, 43780, 26124, 282, 16, 5795, 765, 1174, 10403, 10160, 427, 5805, 14004, 35587, 610, 691, 16603, 7009, 1872, 454, 433, 37650, 5964, 18, 2]
T