# Resume NER with BERT-BiLSTM-CRF (FYP)

This notebook trains a **BERT-BiLSTM-CRF** model for Named Entity Recognition on resumes. Entity types: **NAME**, **EMAIL**, **SKILL**, **OCCUPATION**, **EDUCATION**, **EXPERIENCE**. Data should be prepared with `prepare_data.py` in `resume_ner_pipeline/`.

## Dependencies

Run this cell once per runtime to install the required packages.

In [None]:
!pip install -q torch transformers pytorch-crf seqeval

## Mount Google Drive

In [None]:
from google.colab import drive
drive.mount("/content/drive")

## 1. Data loading

In [None]:
import json
import os

# Use whichever spelling of the Drive root exists; fall back to "My Drive".
if os.path.exists("/content/drive/MyDrive"):
    _drive_base = "/content/drive/MyDrive"
else:
    _drive_base = "/content/drive/My Drive"
DATA_PATH = os.path.join(_drive_base, "merged_resume_ner.json")
if not os.path.exists(DATA_PATH):
    raise FileNotFoundError("JSON not found. Mount Google Drive and place merged_resume_ner.json in My Drive root.")

# The file is read as JSON Lines: one JSON document per non-blank line.
with open(DATA_PATH, "r", encoding="utf-8") as f:
    data = [json.loads(row) for row in (raw.strip() for raw in f) if row]
print(f"Loaded {len(data)} resumes")

# Maps raw annotation labels (and already-normalised ones) onto the six entity
# types used by the model; "O" means "not an entity". Unknown labels fall back to "O".
LABEL_MAPPING = {
    "Name": "NAME", "Email Address": "EMAIL", "Skills": "SKILL", "Designation": "OCCUPATION",
    "Degree": "EDUCATION", "College Name": "EDUCATION", "Graduation Year": "EDUCATION",
    "Companies worked at": "EXPERIENCE", "Years of Experience": "EXPERIENCE", "Location": "O", "UNKNOWN": "O",
    "NAME": "NAME", "EMAIL": "EMAIL", "SKILL": "SKILL", "OCCUPATION": "OCCUPATION", "EDUCATION": "EDUCATION", "EXPERIENCE": "EXPERIENCE", "O": "O",
}
# Normalise labels in place. Records whose "annotation" is missing or null, and
# annotations whose "label" is missing or null, are tolerated here because the
# downstream preprocessing already skips them instead of failing.
for item in data:
    for ann in (item.get("annotation") or []):
        ann["label"] = [LABEL_MAPPING.get(raw_label, "O") for raw_label in (ann.get("label") or [])]

## 2. Preprocessing and train/val/test split

In [None]:
import re
import random

def tokenize_with_positions(text):
    """Split text on whitespace and return (token, start, end) triples.

    start/end are character offsets into text, with end exclusive.
    """
    spans = []
    for match in re.finditer(r"\S+", text):
        spans.append((match.group(), match.start(), match.end()))
    return spans

def create_bio_tags_fixed(tokens, annotations):
    """Produce one BIO tag per token from character-span annotations.

    tokens are (text, start, end) triples; each annotation carries a "label"
    list (only its first entry is used) and a list of "points" with "start"
    and "end" offsets. Annotations without a label, or labelled "O", are
    skipped, so no B-O / I-O tags are ever emitted. Later annotations
    overwrite earlier ones on overlapping tokens.
    """
    tags = ["O"] * len(tokens)
    for annotation in annotations:
        label_list = annotation.get("label")
        if not label_list or label_list[0] == "O":
            continue
        entity = label_list[0]
        for span in annotation.get("points", []):
            span_start, span_end = span["start"], span["end"]
            prefix = "B"  # the first overlapping token opens the entity
            for index, (_, tok_start, tok_end) in enumerate(tokens):
                if tok_end > span_start and tok_start < span_end:
                    tags[index] = f"{prefix}-{entity}"
                    prefix = "I"
    return tags

# Turn every usable resume into parallel word / BIO-tag sequences.
# Resumes without text, without annotations, or without any token are skipped.
all_sents, all_labels = [], []
for item in data:
    content, anns = item.get("content", ""), item.get("annotation", [])
    toks = tokenize_with_positions(content) if content and anns else []
    if not toks:
        continue
    all_sents.append([tok for tok, _, _ in toks])
    all_labels.append(create_bio_tags_fixed(toks, anns))

# Seeded shuffle, then 80% train / 10% validation / remainder test.
n = len(all_sents)
random.seed(42)
idx = list(range(n))
random.shuffle(idx)
n_train, n_val = int(0.8 * n), int(0.1 * n)
_val_end = n_train + n_val
_train_idx, _val_idx, _test_idx = idx[:n_train], idx[n_train:_val_end], idx[_val_end:]
train_sents = [all_sents[i] for i in _train_idx]
train_labels = [all_labels[i] for i in _train_idx]
val_sents = [all_sents[i] for i in _val_idx]
val_labels = [all_labels[i] for i in _val_idx]
test_sents = [all_sents[i] for i in _test_idx]
test_labels = [all_labels[i] for i in _test_idx]
print(f"Train {len(train_sents)} Val {len(val_sents)} Test {len(test_sents)}")

## 3. BERT tokenizer and dataset

In [None]:
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import torch

# Tag inventory: "O" first (id 0), then a B-/I- pair per entity type in this order.
_ENTITY_TYPES = ("NAME", "EMAIL", "SKILL", "OCCUPATION", "EXPERIENCE", "EDUCATION")
TAGS = ["O"]
for _entity in _ENTITY_TYPES:
    TAGS += [f"B-{_entity}", f"I-{_entity}"]
# Forward and reverse lookups between tag strings and integer ids.
LABEL2ID = dict(zip(TAGS, range(len(TAGS))))
ID2LABEL = dict(enumerate(TAGS))
NUM_LABELS = len(TAGS)

def align_to_bert(words, word_labels, tokenizer, max_len=512):
    """Encode words as BERT subword ids and project word labels onto them.

    Returns (ids, mask, aligned): ids is [CLS] + subwords + [SEP], mask is
    all ones, and aligned holds the label id at the first subword of each
    word and -100 everywhere else. Labels missing from LABEL2ID map to 0.
    Sequences longer than max_len are cut to max_len, with the last slot
    replaced by the tokenizer's [SEP] id and labelled -100.
    """
    pieces = ["[CLS]"]
    word_starts = []
    for word in words:
        word_starts.append(len(pieces))
        # A word the tokenizer reduces to nothing still occupies one slot.
        pieces += tokenizer.tokenize(word) or [tokenizer.unk_token]
    pieces.append("[SEP]")

    ids = tokenizer.convert_tokens_to_ids(pieces)
    total = len(ids)
    mask = [1] * total
    aligned = [-100] * total
    for start, label in zip(word_starts, word_labels):
        if start < total:
            aligned[start] = LABEL2ID.get(label, 0)

    if total <= max_len:
        return ids, mask, aligned
    keep = max_len - 1
    return (
        ids[:keep] + [tokenizer.sep_token_id],
        mask[:keep] + [1],
        aligned[:keep] + [-100],
    )

class BertNERDataset(Dataset):
    """Pre-aligned NER samples, each an (input_ids, attention_mask, labels) tuple of lists.

    Sentence/label pairs whose lengths differ are dropped at construction time.
    """

    def __init__(self, sents, labels, tokenizer, max_len=512):
        self.samples = []
        for words, word_labels in zip(sents, labels):
            if len(words) != len(word_labels):
                continue
            self.samples.append(align_to_bert(words, word_labels, tokenizer, max_len))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        return self.samples[i]

# Tokenizer for "bert-base-uncased", the same checkpoint name BertBiLSTMCRF uses by default.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Pre-align the train and validation splits; the test split is wrapped later, in the evaluation section.
train_ds = BertNERDataset(train_sents, train_labels, tokenizer)
val_ds = BertNERDataset(val_sents, val_labels, tokenizer)

def collate(batch):
    """Pad a list of (ids, mask, labels) samples to a common length.

    Returns three torch.long tensors: ids padded with 0, mask padded with 0,
    and labels padded with -100. The target length is the longest ids list.
    """
    width = max(len(sample[0]) for sample in batch)

    def padded(field, fill):
        rows = [sample[field] + [fill] * (width - len(sample[field])) for sample in batch]
        return torch.tensor(rows, dtype=torch.long)

    return padded(0, 0), padded(1, 0), padded(2, -100)

# Resumes containing at least one of these tags are sampled with double weight.
rare_tags = {"B-EDUCATION", "I-EDUCATION", "B-EXPERIENCE", "I-EXPERIENCE", "B-OCCUPATION", "I-OCCUPATION"}
train_weights = []
for _words, _tags in zip(train_sents, train_labels):
    # Same length filter as BertNERDataset, so weights line up with dataset samples.
    if len(_words) != len(_tags):
        continue
    train_weights.append(1.0 if rare_tags.isdisjoint(_tags) else 2.0)
from torch.utils.data import WeightedRandomSampler
train_sampler = WeightedRandomSampler(weights=train_weights, num_samples=len(train_weights))
train_loader = DataLoader(train_ds, batch_size=8, sampler=train_sampler, collate_fn=collate)
val_loader = DataLoader(val_ds, batch_size=8, collate_fn=collate)
print("Datasets ready")

## 4. Model: BERT-BiLSTM-CRF

In [None]:
import torch.nn as nn
from transformers import BertModel
from torchcrf import CRF

class BertBiLSTMCRF(nn.Module):
    """Sequence tagger: BERT encoder, one bidirectional LSTM layer, linear projection, CRF.

    forward() returns the mean negative CRF log-likelihood when labels are
    given, otherwise the decoded tag-id paths from CRF.decode.
    """

    def __init__(self, bert_name="bert-base-uncased", hidden_dim=256, num_labels=NUM_LABELS, dropout=0.3):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        # Each direction gets hidden_dim // 2 units, so for even hidden_dim the
        # concatenated output matches the width expected by self.fc.
        self.lstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=hidden_dim // 2,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        self.drop = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_dim, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        lstm_out, _ = self.lstm(self.drop(encoded))
        emissions = self.fc(self.drop(lstm_out))
        token_mask = attention_mask.bool()
        if labels is None:
            return self.crf.decode(emissions, mask=token_mask)
        # NOTE(review): positions labelled -100 are rewritten to id 0 and stay
        # inside the CRF mask (the mask is the attention mask), so they count
        # towards the loss as label 0 rather than being ignored.
        crf_labels = labels.clone().masked_fill(labels == -100, 0)
        return -self.crf(emissions, crf_labels, mask=token_mask, reduction="mean")

# Pick the compute device: CUDA first, then Apple MPS when this torch build exposes it, else CPU.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    _device_name = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
device = torch.device(_device_name)

model = BertBiLSTMCRF(dropout=0.3).to(device)

# Parameters whose name contains one of these substrings are excluded from weight decay.
no_decay = ["bias", "LayerNorm.weight"]


def _skips_decay(param_name):
    """Return True when the parameter name matches a no-decay pattern."""
    return any(nd in param_name for nd in no_decay)


bert_params = list(model.bert.named_parameters())
_head_params = [(name, p) for name, p in model.named_parameters() if not name.startswith("bert.")]
# Four groups: BERT at lr 2e-5 and everything else at lr 1e-4, each split by weight decay.
optimizer_grouped = [
    {"params": [p for name, p in bert_params if not _skips_decay(name)], "lr": 2e-5, "weight_decay": 0.01},
    {"params": [p for name, p in bert_params if _skips_decay(name)], "lr": 2e-5, "weight_decay": 0.0},
    {"params": [p for name, p in _head_params if not _skips_decay(name)], "lr": 1e-4, "weight_decay": 0.01},
    {"params": [p for name, p in _head_params if _skips_decay(name)], "lr": 1e-4, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optimizer_grouped)
print(f"Model on {device}")

## 5. Training (with early stopping and validation F1)

In [None]:
from seqeval.metrics import f1_score
from torch.optim.lr_scheduler import LinearLR, SequentialLR, ConstantLR

def run_validation(model, val_loader, device, id2label, num_labels):
    """Decode val_loader with model and score it.

    Returns (f1, true_all, pred_all). The two lists hold one tag sequence per
    sample, restricted to positions whose gold label is not -100. f1 is the
    seqeval F1 over those sequences, or 0.0 when nothing was collected.
    """
    model.eval()
    true_all, pred_all = [], []
    with torch.no_grad():
        for inp, mask, labels in val_loader:
            inp, mask = inp.to(device), mask.to(device)
            preds = model(inp, mask)
            mask_rows = mask.cpu().tolist()
            label_rows = labels.cpu().tolist()
            for b in range(inp.size(0)):
                mask_row, label_row, pred_row = mask_rows[b], label_rows[b], preds[b]
                gold_tags, pred_tags = [], []
                for i, keep in enumerate(mask_row):
                    if keep == 0:
                        break  # padding starts here
                    # Out-of-range or missing predictions fall back to "O".
                    usable = i < len(pred_row) and pred_row[i] < num_labels
                    pred_tag = id2label[pred_row[i]] if usable else "O"
                    gold_id = label_row[i]
                    if gold_id == -100:
                        continue
                    gold_tags.append(id2label[gold_id])
                    pred_tags.append(pred_tag)
                if gold_tags and pred_tags:
                    true_all.append(gold_tags)
                    pred_all.append(pred_tags)
    if not true_all:
        return 0.0, true_all, pred_all
    return f1_score(true_all, pred_all, zero_division=0), true_all, pred_all

# Training budget: at most EPOCHS epochs, stopping early after PATIENCE
# consecutive epochs without a new best validation F1.
EPOCHS = 60
PATIENCE = 12
best_f1 = 0.0
best_state = None  # CPU copy of the best state_dict seen so far
epochs_no_improve = 0

# Epoch-level LR schedule: hold the LR at 0.1x for warmup_epochs epochs,
# then scale it linearly from 1.0x down to 0.2x over the remaining epochs.
warmup_epochs = max(1, EPOCHS // 10)
scheduler = SequentialLR(optimizer, [
    ConstantLR(optimizer, factor=0.1, total_iters=warmup_epochs),
    LinearLR(optimizer, start_factor=1.0, end_factor=0.2, total_iters=EPOCHS - warmup_epochs),
], milestones=[warmup_epochs])

for epoch in range(EPOCHS):
    model.train()
    total = 0  # running sum of batch losses for this epoch
    for inp, mask, lab in train_loader:
        inp, mask, lab = inp.to(device), mask.to(device), lab.to(device)
        optimizer.zero_grad()
        loss = model(inp, mask, lab)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total += loss.item()
    # The scheduler is stepped once per epoch, not per batch.
    scheduler.step()
    val_f1, _, _ = run_validation(model, val_loader, device, ID2LABEL, NUM_LABELS)
    if val_f1 > best_f1:
        best_f1 = val_f1
        # Snapshot the weights on CPU so the checkpoint is independent of later updates.
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
    print(f"Epoch {epoch+1}/{EPOCHS} Loss: {total/len(train_loader):.4f} Val F1: {val_f1:.4f} Best: {best_f1:.4f}")
    if epochs_no_improve >= PATIENCE:
        print(f"Early stopping (no improvement for {PATIENCE} epochs).")
        break
# best_state stays None when validation F1 never rose above 0.0; the final weights are kept then.
if best_state is not None:
    model.load_state_dict(best_state)
    print("Restored best checkpoint (by val F1).")

## 6. Save model and tokenizer

In [None]:
# Output directory: RESUME_NER_SAVE_DIR when set, otherwise "resume_ner" under the working directory.
SAVE_DIR = os.environ.get("RESUME_NER_SAVE_DIR", "resume_ner")
os.makedirs(SAVE_DIR, exist_ok=True)

# Model weights.
_weights_path = os.path.join(SAVE_DIR, "bert_bilstm_crf_state.pt")
torch.save(model.state_dict(), _weights_path)

# Tag inventory and encoder name, needed to rebuild the model at load time.
config = {"tags": TAGS, "bert_name": "bert-base-uncased", "num_labels": NUM_LABELS}
with open(os.path.join(SAVE_DIR, "ner_config.json"), "w", encoding="utf-8") as f:
    f.write(json.dumps(config, indent=2))
tokenizer.save_pretrained(SAVE_DIR)

print("Saved:", SAVE_DIR)
for _artifact in ("bert_bilstm_crf_state.pt", "ner_config.json", "tokenizer files"):
    print("  - " + _artifact)

## 7. Load saved model and run inference

In [None]:
# Directory to load from: RESUME_NER_LOAD_DIR when set, otherwise "resume_ner".
LOAD_DIR = os.environ.get("RESUME_NER_LOAD_DIR", "resume_ner")
with open(os.path.join(LOAD_DIR, "ner_config.json"), "r", encoding="utf-8") as f:
    load_config = json.load(f)

# Rebuild the tag lookups from the saved config; this overwrites the globals
# of the same names defined in the dataset section.
TAGS = load_config["tags"]
LABEL2ID = {t: i for i, t in enumerate(TAGS)}
ID2LABEL = {i: t for i, t in enumerate(TAGS)}
NUM_LABELS = load_config["num_labels"]

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(LOAD_DIR)
# The model is re-created from the pretrained encoder name, then the saved weights are loaded over it.
model = BertBiLSTMCRF(bert_name=load_config["bert_name"], num_labels=NUM_LABELS).to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted
# source, or pass weights_only=True if the installed torch version supports it.
model.load_state_dict(torch.load(os.path.join(LOAD_DIR, "bert_bilstm_crf_state.pt"), map_location=device))
model.eval()

print("Model loaded from", LOAD_DIR)

def parse_resume(text, tokenizer, model, device, id2label, max_len=512):
    """Run the NER model on raw resume text.

    The text is split on whitespace, each word is split into BERT subwords,
    and the tag predicted at a word's first subword becomes that word's tag.

    Args:
        text: Resume text.
        tokenizer: Object providing tokenize(), convert_tokens_to_ids(),
            unk_token and sep_token_id.
        model: Callable taking (input_ids, attention_mask) tensors and
            returning one list of tag ids per sequence; must provide eval().
        device: Device the input tensors are moved to.
        id2label: Mapping from tag id to tag string; unknown ids become "O".
        max_len: Maximum number of subword positions, special tokens included.
            Words that do not fit are dropped from the result.

    Returns:
        (words, pred_tags, entities): the words that were tagged, their tags,
        and a dict mapping entity type to the phrases found. A phrase is a
        B- tag followed by any run of I- tags of the same type.
    """
    words = re.findall(r"\S+", text)
    if not words:
        return [], [], {}

    first_idx, toks = [], ["[CLS]"]
    for w in words:
        sub = tokenizer.tokenize(w) or [tokenizer.unk_token]
        first_idx.append(len(toks))
        toks.extend(sub)
    toks.append("[SEP]")
    ids = tokenizer.convert_tokens_to_ids(toks)

    if len(ids) > max_len:
        ids = ids[: max_len - 1] + [tokenizer.sep_token_id]
        # The last slot now holds [SEP]. Keep only words whose first subword
        # lies strictly before it; otherwise the word sitting on that slot
        # would be tagged with the prediction made for [SEP].
        sep_pos = len(ids) - 1
        first_idx = [i for i in first_idx if i < sep_pos]
        words = words[: len(first_idx)]

    mask = [1] * len(ids)
    inp = torch.tensor([ids], dtype=torch.long).to(device)
    mask_t = torch.tensor([mask], dtype=torch.long).to(device)
    model.eval()
    with torch.no_grad():
        preds = model(inp, mask_t)
    path = preds[0]
    pred_tags = [id2label.get(path[i], "O") for i in first_idx]

    # Group B-/I- runs into phrases; an I- tag with no opening B- is ignored.
    entities = {}
    n_words = len(words)
    i = 0
    while i < n_words:
        tag = pred_tags[i]
        if not tag.startswith("B-"):
            i += 1
            continue
        entity_type = tag[2:]
        j = i + 1
        while j < n_words and pred_tags[j] == f"I-{entity_type}":
            j += 1
        entities.setdefault(entity_type, []).append(" ".join(words[i:j]))
        i = j
    return words, pred_tags, entities

# Pragmatic e-mail pattern: local part, "@", domain labels, a literal dot, then a
# TLD of two or more letters. The dot is escaped once; in a raw string a doubled
# backslash would demand a literal backslash and never match a real address.
EMAIL_RE = re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", re.IGNORECASE)
def extract_email_rules(text):
    """Return the e-mail addresses found in text, de-duplicated, in order of first appearance."""
    return list(dict.fromkeys(EMAIL_RE.findall(text)))

def extract_name_heuristic(text):
    """Guess the candidate's name from the first four non-empty lines of text.

    A line qualifies when it has no "@", "http" or "www.", has one to four
    whitespace-separated parts, at least one part starts with a letter, every
    part that starts with a letter is capitalised, and the joined line is
    shorter than 80 characters and does not end with a period.

    Returns a one-element list with the first qualifying line (internal
    whitespace collapsed to single spaces), or [] when no line qualifies.
    """
    lines = [ln.strip() for ln in text.strip().split("\n") if ln.strip()]
    for line in lines[:4]:
        lowered = line.lower()
        if "@" in line or "http" in lowered or "www." in lowered:
            continue
        parts = line.split()
        if not 1 <= len(parts) <= 4:
            continue
        initials = [p[0] for p in parts if p[0].isalpha()]
        # Require at least one alphabetic part: without this, a line made only of
        # digits or symbols (phone number, date) passes the capitalisation check vacuously.
        if not initials or not all(ch.isupper() for ch in initials):
            continue
        c = " ".join(parts)
        if len(c) < 80 and not c.endswith("."):
            return [c]
    return []

def parse_resume_hybrid(text, tokenizer, model, device, id2label, max_len=512):
    """Combine rule-based NAME/EMAIL extraction with the model's predictions.

    The model supplies all entity types; whenever the name heuristic or the
    e-mail regex finds something, its result replaces the model's NAME or
    EMAIL entry. Returns (words, pred_tags, entities) as parse_resume does.
    """
    cleaned = text.strip()
    rule_names = extract_name_heuristic(cleaned)
    rule_emails = extract_email_rules(cleaned)
    words, pred_tags, entities = parse_resume(cleaned, tokenizer, model, device, id2label, max_len)
    for entity_type, rule_hits in (("NAME", rule_names), ("EMAIL", rule_emails)):
        if rule_hits:
            entities[entity_type] = rule_hits
    return words, pred_tags, entities

In [None]:
# Small hand-written resume used as a smoke test for the hybrid parser.
RESUME_TEXT = """
John Doe
john.doe@email.com
Software Engineer with 5 years of experience.
Skills: Python, Java, Machine Learning.
Education: BSc Computer Science, University of Colombo 2020.
Worked at Tech Corp and Data Inc.
"""
# Parse with rules for NAME/EMAIL and the loaded model for everything else.
words, tags, entities = parse_resume_hybrid(RESUME_TEXT.strip(), tokenizer, model, device, ID2LABEL)
print("Entities (hybrid):")
for k, v in entities.items():
    print(f"  {k}: {v}")
# Show the per-word tags for the first 30 words.
print("\nWord-level tags (first 30):", list(zip(words[:30], tags[:30])))

## Example texts – check extracted values

Run the cell below with different resume snippets to see extracted entities (NAME, EMAIL, SKILL, OCCUPATION, EDUCATION, EXPERIENCE).

In [None]:
# Sample resume snippets; continuation lines keep their leading indentation inside the strings.
EXAMPLE_TEXTS = [
    """Jane Smith
    jane.smith@gmail.com
    Data Scientist | 4 years experience
    Skills: Python, SQL, TensorFlow, NLP.
    MSc Data Science, University of Moratuwa 2019.
    Previous: Analytics Ltd, BigData Inc.""",
    """Kamal Perera
    kamal.p@company.lk
    Senior Software Engineer with 8+ years. Java, Spring, AWS.
    BSc Eng (Hons) Computer Science, University of Peradeniya 2014.
    Worked at Virtusa and WSO2.""",
    """Maria Garcia
    maria.garcia@outlook.com
    Product Manager. Agile, Jira, user research.
    MBA, Colombo Business School 2021. BA Economics 2016.
    Experience: StartupXYZ, Tech Solutions Pvt Ltd."""
]

# Run the hybrid parser on each snippet and print its entities plus the first 20 word tags.
for i, text in enumerate(EXAMPLE_TEXTS, 1):
    words, tags, entities = parse_resume_hybrid(text.strip(), tokenizer, model, device, ID2LABEL)
    print(f"{'='*60}\nExample {i}\n{'='*60}")
    print("Extracted entities:")
    for k, v in entities.items():
        print(f"  {k}: {v}")
    print("\nWord-level tags (first 20):", list(zip(words[:20], tags[:20])))
    print()

## 8. Evaluation (validation and test)

### Validation evaluation

In [None]:
from seqeval.metrics import classification_report, f1_score

# Decode the validation set and collect gold/predicted tag sequences,
# keeping only positions whose gold label is not -100.
model.eval()
true_all, pred_all = [], []
with torch.no_grad():
    for inp, mask, labels in val_loader:
        inp, mask = inp.to(device), mask.to(device)
        preds = model(inp, mask)
        mask_rows, label_rows = mask.cpu().tolist(), labels.cpu().tolist()
        for b in range(inp.size(0)):
            mask_row, label_row, pred_b = mask_rows[b], label_rows[b], preds[b]
            tlist, plist = [], []
            for i, keep in enumerate(mask_row):
                if keep == 0:
                    break  # padding starts here
                # Out-of-range or missing predictions fall back to "O".
                usable = i < len(pred_b) and pred_b[i] < NUM_LABELS
                p = ID2LABEL[pred_b[i]] if usable else "O"
                gold_id = label_row[i]
                if gold_id == -100:
                    continue
                tlist.append(ID2LABEL[gold_id])
                plist.append(p)
            if tlist and plist:
                true_all.append(tlist)
                pred_all.append(plist)

print(classification_report(true_all, pred_all, zero_division=0))
val_f1 = f1_score(true_all, pred_all, zero_division=0)
print("Val F1 (entity-level):", val_f1)
# Token accuracy over the collected (labelled) positions only.
total_tok = sum(len(t) for t in true_all)
correct_tok = sum(gold == pred for t, p in zip(true_all, pred_all) for gold, pred in zip(t, p))
token_acc = correct_tok / total_tok if total_tok else 0.0
print("Token accuracy: {:.2%}".format(token_acc))

### Test set evaluation

In [None]:
# Wrap the held-out test split and decode it the same way as the validation set.
test_ds = BertNERDataset(test_sents, test_labels, tokenizer)
test_loader = DataLoader(test_ds, batch_size=8, collate_fn=collate)

model.eval()
true_test, pred_test = [], []
with torch.no_grad():
    for inp, mask, labels in test_loader:
        inp, mask = inp.to(device), mask.to(device)
        preds = model(inp, mask)
        mask_rows, label_rows = mask.cpu().tolist(), labels.cpu().tolist()
        for b in range(inp.size(0)):
            mask_row, label_row, pred_b = mask_rows[b], label_rows[b], preds[b]
            tlist, plist = [], []
            for i, keep in enumerate(mask_row):
                if keep == 0:
                    break  # padding starts here
                # Out-of-range or missing predictions fall back to "O".
                usable = i < len(pred_b) and pred_b[i] < NUM_LABELS
                p = ID2LABEL[pred_b[i]] if usable else "O"
                gold_id = label_row[i]
                if gold_id == -100:
                    continue
                tlist.append(ID2LABEL[gold_id])
                plist.append(p)
            if tlist and plist:
                true_test.append(tlist)
                pred_test.append(plist)

print("--- Test set results ---")
print(classification_report(true_test, pred_test, zero_division=0))
print("Test F1:", f1_score(true_test, pred_test, zero_division=0))