# Lab 3: Named Entity Recognition (NER) with Transformers

Goals:
1) Use a BERT-based token classification model for NER.
2) Prompt a Gemma chat model to perform NER.
3) Evaluate results for both approaches.

We'll use a small English dataset with PERSON/ORG/LOC entities.

## Setup
We adopt the same caching pattern as previous labs. Gemma generation is optional (enable only with sufficient resources).

In [3]:
import os, json, re
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForCausalLM, pipeline

BASE_DIR = os.path.join('../../lab3')
DATA_DIR = os.path.join(BASE_DIR, 'data')
CACHE_DIR = os.path.join(BASE_DIR, 'models_cache')
os.makedirs(CACHE_DIR, exist_ok=True)

RUN_GEMMA = False

print('Transformers:', __import__('transformers').__version__)
print('Torch:', torch.__version__)

Transformers: 4.52.3
Torch: 2.7.0+cu126


## Data: sentences with gold entities
We evaluate on a small dataset with PERSON/ORG/LOC labels defined in `lab3/data/ner_examples.json`.

In [4]:
with open(os.path.join(DATA_DIR, 'ner_examples.json'), 'r', encoding='utf-8') as f:
    DATA = json.load(f)
len(DATA), DATA[0]['text']

(10, 'Barack Obama visited Stanford University in California.')

## Part 1: BERT-based NER (token classification)


In [19]:
BERT_NER_ID = 'dslim/bert-base-NER'
# Load tokenizer and token classification model directly (no pipeline)
tokenizer_bert_ner = AutoTokenizer.from_pretrained(BERT_NER_ID, cache_dir=CACHE_DIR)
model_bert_ner = AutoModelForTokenClassification.from_pretrained(BERT_NER_ID, cache_dir=CACHE_DIR)
print(model_bert_ner)
id2label = model_bert_ner.config.id2label
print('Trained to assign these labels: ', id2label)

Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

In [33]:
# First, identify the offset of each token in the original text
text = 'Barack Obama visited Stanford University in California'
enc_offsets = tokenizer_bert_ner(text, return_offsets_mapping=True, truncation=True)
offsets = enc_offsets['offset_mapping']
enc_offsets

{'input_ids': [101, 14319, 7661, 3891, 8036, 1239, 1107, 1756, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 6), (7, 12), (13, 20), (21, 29), (30, 40), (41, 43), (44, 54), (0, 0)]}

In [34]:
# use the tokenizer and model to get the logits for the input text
# assign a probability to each token to belong to one of the classes
inputs = tokenizer_bert_ner(input, return_tensors='pt')
logits = model_bert_ner(**inputs).logits
print(logits.shape)
print('Logits for one token: ', logits[0][0])

# find the actual predictions based on the probabilities
pred_ids = logits.argmax(dim=-1)[0].tolist()
token_ids = inputs['input_ids'][0].tolist()


torch.Size([1, 10, 9])
Logits for one token:  tensor([ 8.4647, -0.5055, -1.1914, -0.5454, -1.4987, -1.2059, -1.9625, -1.5004,
        -1.2675], grad_fn=<SelectBackward0>)


In [44]:

def assign_labels_to_tokens(pred_ids, token_ids, text, verbose=False):
    entities = []
    current_label = None
    current_start = None
    current_end = None


    for i, (start, end) in enumerate(offsets):
        tok = tokenizer_bert_ner.convert_ids_to_tokens([token_ids[i]])[0]
        if tok == '[CLS]':
            continue
        if verbose: print(f"Token: {tok}")
        # Skip special tokens or tokens without character span
        if (start == 0 and end == 0):
            if verbose: print(f"  not an entity")
            pred_lbl = 'O'
        else:
            raw_lbl = id2label[pred_ids[i]]
            pred_lbl = 'O' if raw_lbl == 'O' else raw_lbl
            if verbose: print(f"  found label: {raw_lbl}")
        if pred_lbl != 'O':
            if current_label == pred_lbl and current_end == start:
                # extend current span
                current_end = end
                if verbose: print(f" this token is part of an already found entity!")
            else:
                # close any previous span
                if current_label is not None:
                    span_text = text[current_start:current_end]
                    entities.append({'text': span_text, 'label': current_label})
                    if verbose: print(f"  end of the entity")
                # start new span
                current_label = pred_lbl
                current_start = start
                current_end = end
        else:
            if current_label is not None:
                span_text = text[current_start:current_end]
                entities.append({'text': span_text, 'label': current_label})
                current_label = None
                current_start = None
                current_end = None
    # close tail span if any
    if current_label is not None:
        span_text = text[current_start:current_end]
        entities.append({'text': span_text, 'label': current_label})

    return entities

In [47]:
entities = assign_labels_to_tokens(pred_ids, token_ids, text, verbose=True)
display(entities)

Token: Barack
  found label: B-PER
Token: Obama
  found label: I-PER
  end of the entity
Token: visited
  found label: O
Token: Stanford
  found label: B-ORG
Token: University
  found label: I-ORG
  end of the entity
Token: in
  found label: O
Token: California
  found label: B-LOC
Token: .
  not an entity


[{'text': 'Barack', 'label': 'B-PER'},
 {'text': 'Obama', 'label': 'I-PER'},
 {'text': 'Stanford', 'label': 'B-ORG'},
 {'text': 'University', 'label': 'I-ORG'},
 {'text': 'California', 'label': 'B-LOC'}]

## Part 2: Gemma prompting for NER
We prompt a chat model to extract entities and return JSON with keys `PERSON`, `ORG`, `LOC`.
We demonstrate zero-shot and few-shot (3-shot) prompting.

In [48]:
GEMMA_ID = 'unsloth/gemma-3-1B-it'

def load_chat_model(model_id):
    tok = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
    mdl = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=CACHE_DIR, torch_dtype=torch.float16, device_map='auto')
    return tok, mdl

tokenizer_chat, model_chat = load_chat_model(GEMMA_ID)

print('Chat model:', GEMMA_ID)
print('Has chat template?', bool(getattr(tokenizer_chat, 'chat_template', None)))

SYSTEM_PROMPT = (
    'You extract named entities from text. Return a JSON object with keys PERSON, ORG, LOC, each mapped to an array of strings. Use exact surface forms from the text and avoid duplicates.'
)

def build_messages_zero_shot(text):
    return [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': 'Text:' + text + 'Return only JSON.'}
    ]

# Few-shot examples
with open(os.path.join(DATA_DIR, 'few_shot_ner_examples.json'), 'r', encoding='utf-8') as f:
    FEW = json.load(f)

def build_messages_three_shot(text):
    msgs = [{'role': 'system', 'content': SYSTEM_PROMPT}]
    for ex in FEW:
        msgs.append({'role': 'user', 'content': ex['input']})
        msgs.append({'role': 'assistant', 'content': json.dumps(ex['output'])})
    msgs.append({'role': 'user', 'content': text})
    return msgs

def generate_json_entities(text, few_shot=False):
    messages = build_messages_three_shot(text) if few_shot else build_messages_zero_shot(text)
    if getattr(tokenizer_chat, 'apply_chat_template', None):
        input_ids = tokenizer_chat.apply_chat_template(messages, return_tensors='pt').to(model_chat.device)
    else:
        prompt = SYSTEM_PROMPT + 'Text:' + text + 'Return only JSON.'
        input_ids = tokenizer_chat(prompt, return_tensors='pt').input_ids.to(model_chat.device)
    if RUN_GEMMA:
        gen = model_chat.generate(input_ids, max_new_tokens=128, temperature=0.0)
        out = tokenizer_chat.decode(gen[0], skip_special_tokens=True)
        return out
    else:
        return tokenizer_chat.decode(input_ids[0])

# Preview prompts (no generation)
print('Zero-shot preview:', generate_json_entities(DATA[0]['text'], few_shot=False)[:400])
print('Three-shot preview:', generate_json_entities(DATA[0]['text'], few_shot=True)[:400])

Chat model: unsloth/gemma-3-1B-it
Has chat template? True
Zero-shot preview: <bos><start_of_turn>user
You extract named entities from text. Return a JSON object with keys PERSON, ORG, LOC, each mapped to an array of strings. Use exact surface forms from the text and avoid duplicates.

Text:Barack Obama visited Stanford University in California.Return only JSON.<end_of_turn>

Three-shot preview: <bos><start_of_turn>user
You extract named entities from text. Return a JSON object with keys PERSON, ORG, LOC, each mapped to an array of strings. Use exact surface forms from the text and avoid duplicates.

Barack Obama spoke at Stanford University in California.<end_of_turn>
<start_of_turn>model
{"PERSON": ["Barack Obama"], "ORG": ["Stanford University"], "LOC": ["California"]}<end_of_turn>
<st


## Part 3: Evaluation
We evaluate entity-level precision/recall/F1 by exact text match per label.
For the BERT pipeline, we map model labels to `PERSON/ORG/LOC`. For the chat model, we parse its JSON output.

In [49]:
def normalize_text(s):
    return re.sub(r'\s+', ' ', s.strip()).lower()

def gold_sets(entry):
    by_label = {}
    for ent in entry['entities']:
        by_label.setdefault(ent['label'], set()).add(normalize_text(ent['text']))
    return by_label

def pred_sets_from_bert(text):
    ents = extract_entities_bert(text)
    by_label = {}
    for e in ents:
        by_label.setdefault(e['label'], set()).add(normalize_text(e['text']))
    return by_label

def parse_chat_json(raw):
    # Attempt to parse JSON object with keys PERSON/ORG/LOC
    try:
        obj = json.loads(raw)
        return {k: set(normalize_text(x) for x in v) for k, v in obj.items() if isinstance(v, list)}
    except Exception:
        # Fallback: extract content between first '{' and last '}'
        m = re.search(r'\{.*\}', raw, re.S)
        if m:
            try:
                obj = json.loads(m.group(0))
                return {k: set(normalize_text(x) for x in v) for k, v in obj.items() if isinstance(v, list)}
            except Exception:
                pass
        return {}

def pred_sets_from_chat(text, few_shot=False):
    raw = generate_json_entities(text, few_shot=few_shot)
    return parse_chat_json(raw)

def prf(gold, pred, labels=('PERSON','ORG','LOC')):
    metrics = {}
    tp=fp=fn=0
    for lbl in labels:
        g = gold.get(lbl, set())
        p = pred.get(lbl, set())
        t = len(g & p)
        f_p = len(p - g)
        f_n = len(g - p)
        prec = t / (t + f_p) if (t + f_p) else 0.0
        rec = t / (t + f_n) if (t + f_n) else 0.0
        f1 = 2*prec*rec/(prec+rec) if (prec+rec) else 0.0
        metrics[lbl] = {'precision': prec, 'recall': rec, 'f1': f1}
        tp += t; fp += f_p; fn += f_n
    micro_prec = tp / (tp + fp) if (tp + fp) else 0.0
    micro_rec = tp / (tp + fn) if (tp + fn) else 0.0
    micro_f1 = 2*micro_prec*micro_rec/(micro_prec+micro_rec) if (micro_prec+micro_rec) else 0.0
    return metrics, {'precision': micro_prec, 'recall': micro_rec, 'f1': micro_f1}

# Evaluate both methods
results = {'bert': [], 'chat_zero': [], 'chat_three': []}
for entry in DATA:
    gold = gold_sets(entry)
    pred_b = pred_sets_from_bert(entry['text'])
    pred_z = pred_sets_from_chat(entry['text'], few_shot=False)
    pred_3 = pred_sets_from_chat(entry['text'], few_shot=True)
    m_b, micro_b = prf(gold, pred_b)
    m_z, micro_z = prf(gold, pred_z)
    m_3, micro_3 = prf(gold, pred_3)
    results['bert'].append(micro_b)
    results['chat_zero'].append(micro_z)
    results['chat_three'].append(micro_3)

def summarize(ms):
    arr_p = [x['precision'] for x in ms]
    arr_r = [x['recall'] for x in ms]
    arr_f = [x['f1'] for x in ms]
    return {'precision': float(np.mean(arr_p)), 'recall': float(np.mean(arr_r)), 'f1': float(np.mean(arr_f))}

print('BERT micro (avg):', summarize(results['bert']))
print('Chat zero-shot micro (avg):', summarize(results['chat_zero']))
print('Chat three-shot micro (avg):', summarize(results['chat_three']))

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


BERT micro (avg): {'precision': 0.875, 'recall': 0.9166666666666666, 'f1': 0.8923809523809524}
Chat zero-shot micro (avg): {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
Chat three-shot micro (avg): {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}


## Exercise
Extend the dataset and experiments:
- Add entity types (e.g., `DATE`, `EVENT`) and 10 more sentences to `lab3/data/ner_examples.json`.
- Compare BERT vs Gemma in zero-shot vs three-shot settings.
- Improve chat prompts to reduce false positives and duplicates.
- Report micro and per-label F1, and discuss error cases.