In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to the
# project modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [26]:
import asyncio
import json
import os
import sys
from typing import List

from datasets import Dataset
from agents import Runner, RunConfig

from config import settings  # type: ignore
from data import (  # type: ignore
    build_prompt,
    get_label_mapping,
    sanitize_tokens,
    validate_predicted_ids,
)
from model_provider import GEMINI_PROVIDER  # type: ignore
from ner_agent import build_ner_agent  # type: ignore

from data import load_conll2003_test_split  # type: ignore
from runner import run_dataset  # type: ignore





In [28]:
# Announce start-up, then either load the dataset or fail fast when the
# settings object carries no API key.
print("[NER] Environment loaded. Starting...")
if settings.api_key:
    print("[NER] Loading CoNLL2003 test split...")
    ds = load_conll2003_test_split()
else:
    raise ValueError("GOOGLE_API_KEY is required")


[NER] Environment loaded. Starting...
[NER] Loading CoNLL2003 test split...


In [16]:
# Peek at the first record of the loaded split to see which fields it carries.
ds[0]

{'id': '0',
 'tokens': ['SOCCER',
  '-',
  'JAPAN',
  'GET',
  'LUCKY',
  'WIN',
  ',',
  'CHINA',
  'IN',
  'SURPRISE',
  'DEFEAT',
  '.'],
 'pos_tags': [21, 8, 22, 37, 22, 22, 6, 22, 15, 12, 21, 7],
 'chunk_tags': [11, 0, 11, 21, 11, 12, 0, 11, 13, 11, 12, 0],
 'ner_tags': [0, 0, 5, 0, 0, 0, 0, 1, 0, 0, 0, 0]}

In [29]:
# get_label_mapping returns a pair; only its first element (the ordered label
# names) is used in this notebook. The label count is needed later to
# validate the ids predicted by the model.
label_names, _ = get_label_mapping(ds)
num_labels = len(label_names)

# Show the label vocabulary followed by its size, one per line.
print(label_names, num_labels, sep="\n")

['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
9


In [30]:
# Create the NER agent once; the inference loop below reuses this instance
# for every row.
print("[NER] Building agent...")
agent = build_ner_agent()


[NER] Building agent...


In [None]:
limit=100
dataset=ds
import time

n = min(limit, len(dataset))
print(f"[NER] Beginning inference for {n} rows...")
results: List[List[int]] = []
for i in range(n):
    # Check if i is one of the trigger values
    if i in {30, 60, 90}:
        print(f"Reached {i}, pausing for 30 seconds...")
        time.sleep(30)   # <-- pause execution for 30 seconds
    #tokens = sanitize_tokens(dataset[i]["tokens"])  # type: ignore[index]
    tokens = dataset[i]["tokens"]
    print(f"[NER] Row {i}: {len(tokens)} tokens")
    prompt = build_prompt(tokens, label_names)
    result = await Runner.run(agent, prompt, run_config=RunConfig(model_provider=GEMINI_PROVIDER))
    text = result.final_output if hasattr(result, "final_output") else str(result)
    arr = json.loads(text)
    preds = [int(x) for x in arr]
    #Basic validation; fallback to O if invalid length
    if not validate_predicted_ids(preds, len(tokens), num_labels):
        preds = [0] * len(tokens)
    results.append(preds)

dataset = dataset.select(range(n))
dataset = dataset.add_column("pred_ner_tags", results)
print("[NER] Inference complete.")


In [10]:
# Sanity check: number of tokens in the first row of `dataset`.
len(dataset[0]['tokens'])

12

In [11]:
# Debug view of the prompt text. `prompt` is only assigned inside the
# inference loop, so this shows whatever value the loop assigned last.
prompt

"Tokens: ['SOCCER', '-', 'JAPAN', 'GET', 'LUCKY', 'WIN', ',', 'CHINA', 'IN', 'SURPRISE', 'DEFEAT', '.']\nLabel names (index aligned): [O, B-PER, I-PER, B-ORG, I-ORG, B-LOC, I-LOC, B-MISC, I-MISC]\nReturn a list of integers, one per token."

In [62]:
# Debug check: length of the most recently parsed model output. `arr` is only
# assigned inside the inference loop, so that cell must have run first.
len(arr)

12

In [64]:
# Print the number of gold tags in row 0, then the length of the `preds` list
# left over from the inference loop. The two are only comparable when the
# loop last processed row 0.
print(len(dataset[0]['ner_tags']))
print(len(preds))

12
12


In [12]:
# Show the first row again; once the inference cell has run, it also carries
# the `pred_ner_tags` column next to the gold `ner_tags`.
dataset[0]

{'id': '0',
 'tokens': ['SOCCER',
  '-',
  'JAPAN',
  'GET',
  'LUCKY',
  'WIN',
  ',',
  'CHINA',
  'IN',
  'SURPRISE',
  'DEFEAT',
  '.'],
 'pos_tags': [21, 8, 22, 37, 22, 22, 6, 22, 15, 12, 21, 7],
 'chunk_tags': [11, 0, 11, 21, 11, 12, 0, 11, 13, 11, 12, 0],
 'ner_tags': [0, 0, 5, 0, 0, 0, 0, 1, 0, 0, 0, 0],
 'pred_ner_tags': [0, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0, 0]}

In [37]:
from evaluation import evaluate_ner_predictions

In [38]:
import pandas as pd

# Location of the saved predictions file. Set the NER_PREDICTIONS_PATH
# environment variable to point at your own copy; the default below is the
# original author's machine-specific path and will not exist elsewhere.
PREDICTIONS_PATH = os.environ.get(
    "NER_PREDICTIONS_PATH",
    "/Users/gabrieldiasmp/Documents/pasta_gabriel/codigo/agents/7_named_entity_recognition/output/predictions.parquet",
)
df = pd.read_parquet(PREDICTIONS_PATH)

# Preview the first rows
df.head()

Unnamed: 0,id,tokens,pos_tags,chunk_tags,ner_tags,pred_ner_tags
0,0,"[SOCCER, -, JAPAN, GET, LUCKY, WIN, ,, CHINA, ...","[21, 8, 22, 37, 22, 22, 6, 22, 15, 12, 21, 7]","[11, 0, 11, 21, 11, 12, 0, 11, 13, 11, 12, 0]","[0, 0, 5, 0, 0, 0, 0, 1, 0, 0, 0, 0]","[0, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0, 0]"
1,1,"[Nadim, Ladki]","[22, 22]","[11, 12]","[1, 2]","[1, 2]"
2,2,"[AL-AIN, ,, United, Arab, Emirates, 1996-12-06]","[22, 6, 22, 22, 23, 11]","[11, 0, 11, 12, 12, 12]","[5, 0, 5, 6, 6, 0]","[5, 0, 5, 6, 6, 0]"
3,3,"[Japan, began, the, defence, of, their, Asian,...","[22, 38, 12, 21, 15, 29, 16, 22, 21, 15, 12, 1...","[11, 21, 11, 12, 13, 11, 12, 12, 12, 13, 11, 1...","[5, 0, 0, 0, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, ...","[5, 0, 0, 0, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, ..."
4,4,"[But, China, saw, their, luck, desert, them, i...","[10, 22, 38, 29, 21, 37, 28, 15, 12, 21, 21, 1...","[0, 11, 21, 11, 12, 21, 11, 13, 11, 12, 12, 13...","[0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [40]:
# id2label example:
# Tag-id -> label-name lookup used for evaluation; each label's position in
# the list is its integer tag id.
id2label = dict(enumerate(
    ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
))

In [41]:
# Score the saved predictions in `df`. The returned dict (displayed in the
# next cell) holds 'overall', 'per_label' and 'report_text' entries.
# NOTE(review): id2label is presumably used to turn integer tag ids into
# label strings -- confirm in evaluation.py.
result = evaluate_ner_predictions(df, id2label)

In [45]:
# Display the complete evaluation result returned by evaluate_ner_predictions.
result

{'overall': {'f1': 0.7115902964959567,
  'precision': 0.8571428571428571,
  'recall': 0.6082949308755761},
 'per_label':       precision    recall  f1-score  support
 LOC    0.947368  0.878049  0.911392     82.0
 MISC   0.800000  0.545455  0.648649     22.0
 ORG    0.125000  0.500000  0.200000      2.0
 PER    0.854545  0.423423  0.566265    111.0,
 'report_text': '              precision    recall  f1-score   support\n\n         LOC       0.95      0.88      0.91        82\n        MISC       0.80      0.55      0.65        22\n         ORG       0.12      0.50      0.20         2\n         PER       0.85      0.42      0.57       111\n\n   micro avg       0.86      0.61      0.71       217\n   macro avg       0.68      0.59      0.58       217\nweighted avg       0.88      0.61      0.70       217\n'}

In [44]:
# Headline metrics first (filled from the 'overall' entry by key name),
# then the full per-label report text.
print("F1: {f1} \nPrecision: {precision} \nRecall: {recall}".format(**result['overall']))
print(result['report_text'])

F1: 0.7115902964959567 
Precision: 0.8571428571428571 
Recall: 0.6082949308755761
              precision    recall  f1-score   support

         LOC       0.95      0.88      0.91        82
        MISC       0.80      0.55      0.65        22
         ORG       0.12      0.50      0.20         2
         PER       0.85      0.42      0.57       111

   micro avg       0.86      0.61      0.71       217
   macro avg       0.68      0.59      0.58       217
weighted avg       0.88      0.61      0.70       217

