# M2 Resume Extractor Training (Kaggle + P100 GPU)

Fine-tunes **yashpwr/resume-ner-bert-v2** for BIO NER on resumes with 14 entity types (29 labels).
Two-phase training: BERT encoder layers 0-8 frozen for 2 epochs, then all layers unfrozen for the remaining 6 epochs (8 in total).

## Setup
1. Upload `m2_training_data.zip` as a Kaggle Dataset (it will auto-extract)
2. Add the dataset to this notebook via the sidebar **Add Data** button
3. Select **GPU P100** in notebook settings (Settings > Accelerator)
4. Enable **Internet** (Settings > Internet > On)
5. Run all cells
6. Download the trained model zip from the output

## Data Sources
| Dataset | Type | Expected Sequences |
|---------|------|-------------------|
| Mehyaar | Gold NER annotations | ~25K |
| DataTurks | Gold NER annotations | ~1K |
| Djinni | Weak supervision from structured fields | ~16K |
| DatasetMaster | Structured field synthesis | varies |

In [None]:
# Cell 1: Suppress TF/CUDA warnings + install dependencies
#
# Kaggle pre-loads both TensorFlow and PyTorch, causing duplicate CUDA factory
# registration messages (cuFFT, cuDNN, cuBLAS). These are cosmetic only and do
# NOT affect training. Setting TF_CPP_MIN_LOG_LEVEL=3 before any imports is
# intended to suppress them.
#
# NOTE(review): environment variables only affect libraries imported after they
# are set, so this cell must run first in a fresh kernel. Some registration
# lines may still be printed on current images -- confirm.

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'           # Suppress TF C++ logs (INFO/WARNING/ERROR)
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'           # Suppress oneDNN messages
os.environ['GRPC_VERBOSITY'] = 'ERROR'               # Suppress gRPC logs
os.environ['ABSL_MIN_LOG_LEVEL'] = '2'               # Suppress abseil warnings

# These filters only match messages emitted through Python's `warnings` module;
# text written directly to stderr by C++ libraries is not affected by them.
import warnings
warnings.filterwarnings('ignore', message='.*computation placer already registered.*')
warnings.filterwarnings('ignore', message='.*Unable to register.*factory.*')

# IPython shell escape (notebook-only syntax, not plain Python). Requires
# Internet access. Without -U, pip keeps any already-installed package at its
# preinstalled version.
# NOTE(review): Cell 7 passes `eval_strategy=` to TrainingArguments, which older
# transformers releases do not accept -- confirm the installed version.
!pip install -q transformers datasets seqeval accelerate pandas pyarrow pyyaml

In [None]:
# Cell 2: Check GPU
# Report whether PyTorch sees a CUDA device (name and total memory of device 0),
# so a missing accelerator is noticed before any long-running cell starts.
import torch

cuda_ok = torch.cuda.is_available()
print(f"GPU available: {cuda_ok}")
if not cuda_ok:
    print("WARNING: No GPU detected! Enable GPU in Settings > Accelerator.")
else:
    device_name = torch.cuda.get_device_name(0)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {device_name}")
    print(f"Memory: {total_gb:.1f} GB")

In [None]:
# Cell 3: Locate training data
#
# Kaggle auto-extracts uploaded zip files. When you upload m2_training_data.zip
# as a dataset, the contents are extracted into /kaggle/input/<dataset-name>/.
# Kaggle converts underscores to hyphens in dataset names, so "m2_training_data"
# becomes "/kaggle/input/m2-training-data/". The actual folder structure may be
# nested further depending on how the zip was created.
#
# This cell searches /kaggle/input/ for the expected data folders and sets DATA_DIR.

import os
from pathlib import Path

# The 5 expected data subdirectories
TARGET_FOLDERS = {
    "yashpwr_resume_ner", "dataturks_resume_ner", "mehyaar_ner_cvs",
    "datasetmaster_resumes", "djinni_candidates"
}

print("Available datasets in /kaggle/input/:")
for d in os.listdir("/kaggle/input/"):
    print(f"  /kaggle/input/{d}/")

# Search for the directory containing our data folders
DATA_DIR = None
for root, dirs, files in os.walk("/kaggle/input/"):
    if any(d in TARGET_FOLDERS for d in dirs):
        DATA_DIR = Path(root)
        break
    # Don't search too deep
    if root.count(os.sep) - "/kaggle/input/".count(os.sep) > 4:
        break

if DATA_DIR:
    print(f"\nData root found: {DATA_DIR}")
    print("Contents:")
    for item in sorted(os.listdir(DATA_DIR)):
        full = DATA_DIR / item
        if full.is_dir():
            count = sum(1 for _ in full.rglob("*") if _.is_file())
            print(f"  {item}/ ({count} files)")
        else:
            print(f"  {item} ({full.stat().st_size / 1e6:.1f} MB)")
else:
    # Fallback: show full tree for debugging
    print("\nERROR: Could not find expected data folders!")
    print("Full /kaggle/input/ tree:")
    for root, dirs, files in os.walk("/kaggle/input/"):
        depth = root.replace("/kaggle/input/", "").count(os.sep)
        indent = "  " * depth
        print(f"{indent}{os.path.basename(root)}/")
        if depth < 4:
            for f in files[:10]:
                print(f"{indent}  {f}")
            if len(files) > 10:
                print(f"{indent}  ... and {len(files) - 10} more files")
    print("\nPlease add m2_training_data.zip as a Dataset in the notebook sidebar.")

In [None]:
# Cell 4: Data preparation - load and unify all resume NER datasets

import json
import logging
import random
import re
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict, Features, Sequence, Value

logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Cell 3 discovers DATA_DIR; stop here with a clear message if it found nothing.
assert DATA_DIR is not None, "DATA_DIR not set! Re-run Cell 3."
print("Using DATA_DIR: " + str(DATA_DIR))

# --- Entity types and BIO labels ---
ENTITY_TYPES = [
    "NAME", "EMAIL", "PHONE", "LOCATION", "DESIGNATION", "COMPANY",
    "DEGREE", "GRADUATION_YEAR", "COLLEGE_NAME", "YEARS_OF_EXPERIENCE",
    "SKILLS", "CERTIFICATION", "PROJECT_NAME", "PROJECT_TECHNOLOGY",
]

LABELS = ["O"]
for etype in ENTITY_TYPES:
    LABELS.append(f"B-{etype}")
    LABELS.append(f"I-{etype}")

LABEL2ID = {label: idx for idx, label in enumerate(LABELS)}
ID2LABEL = {idx: label for idx, label in enumerate(LABELS)}

print(f"Entity types: {len(ENTITY_TYPES)}")
print(f"Total labels (BIO): {len(LABELS)}")

# --- Label normalization map ---
_LABEL_NORMALIZE = {
    "Name": "NAME", "name": "NAME",
    "EMAIL": "EMAIL", "Email Address": "EMAIL", "email": "EMAIL",
    "Phone": "PHONE", "phone": "PHONE", "PHONE": "PHONE",
    "Location": "LOCATION", "location": "LOCATION", "LOCATION": "LOCATION",
    "Designation": "DESIGNATION", "designation": "DESIGNATION", "DESIGNATION": "DESIGNATION",
    "Companies worked at": "COMPANY", "Company": "COMPANY", "company": "COMPANY", "COMPANY": "COMPANY",
    "Degree": "DEGREE", "degree": "DEGREE", "DEGREE": "DEGREE",
    "Graduation Year": "GRADUATION_YEAR", "graduation_year": "GRADUATION_YEAR",
    "College Name": "COLLEGE_NAME", "college_name": "COLLEGE_NAME", "COLLEGE": "COLLEGE_NAME",
    "Years of Experience": "YEARS_OF_EXPERIENCE", "years_of_experience": "YEARS_OF_EXPERIENCE", "Experience": "YEARS_OF_EXPERIENCE",
    "Skills": "SKILLS", "skills": "SKILLS", "SKILLS": "SKILLS",
    "Certification": "CERTIFICATION", "certification": "CERTIFICATION", "CERTIFICATION": "CERTIFICATION",
    "Project": "PROJECT_NAME", "project": "PROJECT_NAME", "PROJECT": "PROJECT_NAME",
    "Technology": "PROJECT_TECHNOLOGY", "technology": "PROJECT_TECHNOLOGY",
}


def _normalize_label(raw_label):
    return _LABEL_NORMALIZE.get(raw_label)


def _tokenize_with_offsets(text):
    tokens, offsets = [], []
    for m in re.finditer(r"\S+", text):
        tokens.append(m.group())
        offsets.append((m.start(), m.end()))
    return tokens, offsets


def _bio_tags_from_char_spans(tokens, char_offsets, spans):
    tags = ["O"] * len(tokens)
    for span in spans:
        s_start, s_end, label = span["start"], span["end"], span["label"]
        first = True
        for idx, (t_start, t_end) in enumerate(char_offsets):
            if t_end <= s_start or t_start >= s_end:
                continue
            prefix = "B" if first else "I"
            tag = f"{prefix}-{label}"
            if tag in LABEL2ID:
                tags[idx] = tag
                first = False
    return tags


def _label_to_id(tags):
    """Translate BIO tag strings to ids via LABEL2ID; unknown tags become 0 ("O")."""
    lookup = LABEL2ID.get
    return [lookup(tag, 0) for tag in tags]


# --- Dataset Loaders ---

def load_dataturks():
    """Load the DataTurks "Entity Recognition in Resumes" JSON-lines file.

    Each non-blank line is parsed as a JSON object. Its "content" holds the
    resume text. Its "annotation" is a list of {"label": str | [str],
    "points": [{"start", "end"}, ...]}. Unparseable lines, entries without
    content or annotations, and texts shorter than 3 tokens are skipped.
    Labels that _normalize_label cannot map are ignored. The code adds 1 to
    each "end" offset, i.e. it treats DataTurks end offsets as inclusive.

    Returns:
        Records {"tokens", "ner_tags" (label ids), "source": "dataturks"},
        one per 128-token window; windows shorter than 3 tokens are dropped.
        An empty list is returned when the file is missing.
    """
    source_file = DATA_DIR / "dataturks_resume_ner" / "Entity Recognition in Resumes.json"
    if not source_file.exists():
        logger.warning("DataTurks not found -- skipping.")
        return []

    window = 128
    out = []
    with open(source_file, "r", encoding="utf-8") as fh:
        for raw_line in fh:
            stripped = raw_line.strip()
            if not stripped:
                continue
            try:
                doc = json.loads(stripped)
            except json.JSONDecodeError:
                continue

            text = doc.get("content", "")
            anns = doc.get("annotation", [])
            if not (text and anns):
                continue

            tokens, offsets = _tokenize_with_offsets(text)
            if len(tokens) < 3:
                continue

            char_spans = []
            for ann in anns:
                labels = ann.get("label", [])
                if isinstance(labels, str):
                    labels = [labels]
                points = ann.get("points", [])
                if not points:
                    continue
                for raw in labels:
                    canonical = _normalize_label(raw)
                    if canonical is None:
                        continue
                    char_spans.extend(
                        {"start": pt.get("start"), "end": pt.get("end") + 1, "label": canonical}
                        for pt in points
                        if pt.get("start") is not None and pt.get("end") is not None
                    )

            tags = _bio_tags_from_char_spans(tokens, offsets, char_spans)
            for lo in range(0, len(tokens), window):
                tok_window = tokens[lo:lo + window]
                if len(tok_window) < 3:
                    continue
                out.append({
                    "tokens": tok_window,
                    "ner_tags": _label_to_id(tags[lo:lo + window]),
                    "source": "dataturks",
                })
    logger.info("DataTurks: %d sequences", len(out))
    return out


def load_mehyaar():
    base_dir = DATA_DIR / "mehyaar_ner_cvs" / "ResumesJsonAnnotated" / "ResumesJsonAnnotated"
    if not base_dir.exists():
        logger.warning("Mehyaar not found -- skipping.")
        return []
    records = []
    for jf in sorted(base_dir.glob("*.json")):
        try:
            with open(jf, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            continue
        text = data.get("text", data.get("content", ""))
        annotations = data.get("annotations", data.get("annotation", []))
        if not text or not annotations:
            continue
        tokens, offsets = _tokenize_with_offsets(text)
        if len(tokens) < 3:
            continue
        spans = []
        for ann in annotations:
            if isinstance(ann, dict):
                label = ann.get("label", ann.get("type", ""))
                start = ann.get("start", ann.get("startOffset"))
                end = ann.get("end", ann.get("endOffset"))
            elif isinstance(ann, (list, tuple)) and len(ann) >= 3:
                start, end, label = ann[0], ann[1], ann[2]
            else:
                continue
            if start is None or end is None or not label:
                continue
            norm = _normalize_label(label)
            if norm is None:
                if label.upper() in [e.upper() for e in ENTITY_TYPES]:
                    norm = label.upper()
                else:
                    continue
            spans.append({"start": int(start), "end": int(end), "label": norm})
        tags = _bio_tags_from_char_spans(tokens, offsets, spans)
        for i in range(0, len(tokens), 128):
            ct, ctags = tokens[i:i+128], tags[i:i+128]
            if len(ct) >= 3:
                records.append({"tokens": ct, "ner_tags": _label_to_id(ctags), "source": "mehyaar"})
    logger.info("Mehyaar: %d sequences", len(records))
    return records


def load_datasetmaster():
    parquet_path = DATA_DIR / "datasetmaster_resumes" / "train.parquet"
    if not parquet_path.exists():
        logger.warning("DatasetMaster not found -- skipping.")
        return []
    df = pd.read_parquet(parquet_path)
    records = []
    text_col = None
    for candidate in ["resume_text", "text", "content", "resume", "Resume"]:
        if candidate in df.columns:
            text_col = candidate
            break
    if text_col is None:
        for _, row in df.iterrows():
            tokens_all, tags_all = [], []
            for col, label in [("skills", "SKILLS"), ("education", "DEGREE"), ("projects", "PROJECT_NAME")]:
                val = row.get(col)
                if isinstance(val, str) and val.strip():
                    toks, _ = _tokenize_with_offsets(val)
                    tokens_all.extend(toks)
                    tags_all.extend([f"B-{label}"] + [f"I-{label}"] * (len(toks) - 1))
                elif isinstance(val, list):
                    for item in val:
                        if isinstance(item, str) and item.strip():
                            toks, _ = _tokenize_with_offsets(item)
                            tokens_all.extend(toks)
                            tags_all.extend([f"B-{label}"] + [f"I-{label}"] * (len(toks) - 1))
            if len(tokens_all) >= 5:
                records.append({"tokens": tokens_all, "ner_tags": _label_to_id(tags_all), "source": "datasetmaster"})
        logger.info("DatasetMaster (structured): %d sequences", len(records))
        return records
    field_label_map = {
        "skills": "SKILLS", "education": "DEGREE", "company": "COMPANY",
        "designation": "DESIGNATION", "college": "COLLEGE_NAME", "degree": "DEGREE",
        "projects": "PROJECT_NAME", "certification": "CERTIFICATION", "certifications": "CERTIFICATION",
    }
    for _, row in df.iterrows():
        text = row[text_col]
        if not isinstance(text, str) or len(text) < 30:
            continue
        tokens, offsets = _tokenize_with_offsets(text[:2000])
        if len(tokens) < 5:
            continue
        spans = []
        for col, label in field_label_map.items():
            val = row.get(col)
            if val is None:
                continue
            search_terms = []
            if isinstance(val, str) and val.strip():
                search_terms = [val.strip()]
            elif isinstance(val, list):
                search_terms = [str(v).strip() for v in val if isinstance(v, str) and v.strip()]
            for term in search_terms[:10]:
                try:
                    for m in re.finditer(re.escape(term[:100]), text[:2000], re.IGNORECASE):
                        spans.append({"start": m.start(), "end": m.end(), "label": label})
                        break
                except re.error:
                    continue
        tags = _bio_tags_from_char_spans(tokens, offsets, spans)
        for i in range(0, len(tokens), 128):
            ct, ctags = tokens[i:i+128], tags[i:i+128]
            if len(ct) >= 3:
                records.append({"tokens": ct, "ner_tags": _label_to_id(ctags), "source": "datasetmaster"})
    logger.info("DatasetMaster: %d sequences", len(records))
    return records


def load_djinni():
    """Load Djinni candidates. Uses 'CV' column for text, 'Position' for designation,
    'Primary Keyword' for skills."""
    parquet_path = DATA_DIR / "djinni_candidates" / "train.parquet"
    if not parquet_path.exists():
        logger.warning("Djinni not found -- skipping.")
        return []
    df = pd.read_parquet(parquet_path)
    if len(df) > 15000:
        df = df.sample(n=15000, random_state=42)
    records = []
    exp_pattern = re.compile(r"\b(\d+)\+?\s*years?\b", re.IGNORECASE)

    for _, row in df.iterrows():
        text = ""
        for col in ["CV", "Moreinfo", "Highlights", "description", "text", "bio", "summary", "content"]:
            val = row.get(col, "")
            if isinstance(val, str) and len(val) > 20:
                text = val
                break
        if not text:
            continue

        text = text[:1500]
        tokens, offsets = _tokenize_with_offsets(text)
        if len(tokens) < 5:
            continue

        spans = []

        for m in exp_pattern.finditer(text):
            spans.append({"start": m.start(), "end": m.end(), "label": "YEARS_OF_EXPERIENCE"})

        for skill_col in ["Primary Keyword", "skills", "keywords", "technologies"]:
            skills_val = row.get(skill_col, "")
            if isinstance(skills_val, str) and skills_val.strip():
                skill_list = [s.strip() for s in skills_val.split(",") if s.strip()]
                for skill in skill_list[:15]:
                    try:
                        for m in re.finditer(re.escape(skill), text, re.IGNORECASE):
                            spans.append({"start": m.start(), "end": m.end(), "label": "SKILLS"})
                            break
                    except re.error:
                        continue
                break

        for pos_col in ["Position", "position", "title", "designation", "job_title"]:
            pos_val = row.get(pos_col, "")
            if isinstance(pos_val, str) and pos_val.strip():
                try:
                    for m in re.finditer(re.escape(pos_val.strip()), text, re.IGNORECASE):
                        spans.append({"start": m.start(), "end": m.end(), "label": "DESIGNATION"})
                        break
                except re.error:
                    pass
                break

        tags = _bio_tags_from_char_spans(tokens, offsets, spans)
        if any(t != "O" for t in tags):
            for i in range(0, len(tokens), 128):
                ct, ctags = tokens[i:i+128], tags[i:i+128]
                if len(ct) >= 3:
                    records.append({"tokens": ct, "ner_tags": _label_to_id(ctags), "source": "djinni"})

    logger.info("Djinni: %d sequences", len(records))
    return records


# --- Load all datasets ---
# Loaders run independently. A failing loader is reported (full traceback in
# the log) and the remaining datasets are still loaded.
print("Loading datasets...")
all_records = []
for name, loader in [("DataTurks", load_dataturks), ("Mehyaar", load_mehyaar),
                     ("DatasetMaster", load_datasetmaster), ("Djinni", load_djinni)]:
    try:
        recs = loader()
    except Exception as e:  # notebook boundary: keep going with the other datasets
        logger.exception("%s loader failed", name)
        print(f"  {name}: FAILED - {e}")
        continue
    all_records.extend(recs)
    print(f"  {name}: {len(recs)} sequences (total: {len(all_records)})")

print(f"\nTotal: {len(all_records)} sequences")
if not all_records:
    print("WARNING: no sequences were loaded -- check DATA_DIR and the loader messages above.")

In [None]:
# Cell 5: Clean and split data
#
# Records are type-normalized, and malformed records are dropped. Exact
# duplicate token sequences are removed, keeping the first occurrence in
# loader order, so an identical window cannot appear in both train and test
# and inflate the test metrics. The rest is shuffled with a fixed seed and
# split 80/10/10.
#
# NOTE(review): the loaders cut each document into 128-token windows, and this
# split is per window. Different windows of the same resume can therefore still
# land in different splits. A document-level split needs a document id per record.

cleaned = []
seen_token_seqs = set()
n_duplicates = 0
for rec in all_records:
    tokens = [str(t) if not isinstance(t, str) else t for t in rec.get("tokens", [])]
    tags = [int(t) if not isinstance(t, int) else t for t in rec.get("ner_tags", [])]
    if not tokens or len(tokens) != len(tags):
        continue
    # Drop windows containing stringified missing values or empty tokens.
    if any(t in ("nan", "None", "") for t in tokens):
        continue
    # Replace characters that cannot be UTF-8 encoded (e.g. lone surrogates) with "?".
    tokens = [t.encode("utf-8", errors="replace").decode("utf-8") for t in tokens]
    source = str(rec.get("source", "unknown")).encode("utf-8", errors="replace").decode("utf-8")
    key = tuple(tokens)
    if key in seen_token_seqs:
        n_duplicates += 1
        continue
    seen_token_seqs.add(key)
    cleaned.append({"tokens": tokens, "ner_tags": tags, "source": source})
del seen_token_seqs  # only needed for de-duplication; free the memory

print(f"Cleaned: {len(all_records)} -> {len(cleaned)} records (dropped {len(all_records) - len(cleaned)})")
print(f"  of which exact duplicates: {n_duplicates}")

TRAIN_RATIO = 0.8
VAL_RATIO = 0.1

rng = random.Random(42)
rng.shuffle(cleaned)

n = len(cleaned)
n_train = int(n * TRAIN_RATIO)
n_val = int(n * VAL_RATIO)

splits = {
    "train": cleaned[:n_train],
    "validation": cleaned[n_train:n_train + n_val],
    "test": cleaned[n_train + n_val:],
}

features = Features({
    "tokens": Sequence(Value("string")),
    "ner_tags": Sequence(Value("int32")),
    "source": Value("string"),
})

dd = DatasetDict()
for split_name, split_records in splits.items():
    dd[split_name] = Dataset.from_dict(
        {
            "tokens": [r["tokens"] for r in split_records],
            "ner_tags": [r["ner_tags"] for r in split_records],
            "source": [r["source"] for r in split_records],
        },
        features=features,
    )
    print(f"  {split_name}: {len(split_records)} examples")

train_ds, val_ds, test_ds = dd["train"], dd["validation"], dd["test"]
print(f"\nTrain: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

In [None]:
# Cell 6: Tokenize and align labels
from transformers import AutoTokenizer

BASE_MODEL = "yashpwr/resume-ner-bert-v2"
MAX_LENGTH = 512

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)


def tokenize_and_align(examples):
    """Tokenize pre-split words and project word-level label ids onto sub-tokens.

    Only the first sub-token of each word carries that word's id. Later
    sub-tokens of the same word, plus special and padding tokens, get -100
    (PyTorch CrossEntropyLoss's default ignore_index), so they are not scored.
    Every sequence is truncated/padded to exactly MAX_LENGTH.
    """
    encoded = tokenizer(
        examples["tokens"],
        is_split_into_words=True,
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
    )

    def _align(batch_idx, word_labels):
        aligned, prev = [], None
        for wid in encoded.word_ids(batch_index=batch_idx):
            starts_word = wid is not None and wid != prev
            aligned.append(word_labels[wid] if starts_word else -100)
            prev = wid
        return aligned

    encoded["labels"] = [
        _align(batch_idx, word_labels)
        for batch_idx, word_labels in enumerate(examples["ner_tags"])
    ]
    return encoded


def _tokenize_split(split_label, ds):
    """Announce and tokenize one split, dropping its original columns."""
    print(f"Tokenizing {split_label}...")
    return ds.map(tokenize_and_align, batched=True, remove_columns=ds.column_names)


train_tok = _tokenize_split("train", train_ds)
val_tok = _tokenize_split("val", val_ds)
test_tok = _tokenize_split("test", test_ds)
print(f"Tokenized: train={len(train_tok)}, val={len(val_tok)}, test={len(test_tok)}")

In [None]:
# Cell 7: Model setup + two-phase training

import transformers
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from seqeval.metrics import f1_score, precision_score, recall_score

label_list = LABELS
id2label = ID2LABEL
label2id = LABEL2ID

# The base checkpoint's classifier head is sized for its own label set, not our
# len(LABELS) labels. ignore_mismatched_sizes=True re-initializes the head; this
# is intentional. (The original note says the base model has 25 labels /
# 11 entity types -- NOTE(review): confirm against its config.json.)
# Transformers logs a size-mismatch warning for this. Silence it for the load
# only, and restore the previous verbosity even if loading fails.
_prev_verbosity = transformers.logging.get_verbosity()
transformers.logging.set_verbosity_error()
try:
    model = AutoModelForTokenClassification.from_pretrained(
        BASE_MODEL,
        num_labels=len(label_list),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True,
    )
finally:
    transformers.logging.set_verbosity(_prev_verbosity)
print(f"Model loaded: {BASE_MODEL} with {len(label_list)} labels (classifier head re-initialized)")


def compute_metrics(eval_pred):
    """Entity-level precision / recall / F1 (seqeval) for the Trainer.

    *eval_pred* unpacks to (logits, label_ids). The arg-max is taken over axis
    2, the label dimension. Positions whose gold id is -100 are discarded
    before both sequences are decoded to BIO strings via label_list.
    """
    logits, gold_ids = eval_pred
    pred_ids = np.argmax(logits, axis=2)
    gold_seqs, pred_seqs = [], []
    for pred_row, gold_row in zip(pred_ids, gold_ids):
        kept = [(g, p) for p, g in zip(pred_row, gold_row) if g != -100]
        gold_seqs.append([label_list[g] for g, _ in kept])
        pred_seqs.append([label_list[p] for _, p in kept])
    scores = {}
    for key, scorer in (("precision", precision_score), ("recall", recall_score), ("f1", f1_score)):
        scores[key] = scorer(gold_seqs, pred_seqs)
    return scores


def freeze_bert_layers(model, layer_indices):
    """Turn off gradients for the selected encoder layers of a BERT-style model.

    Only ``model.bert.encoder.layer[i]`` for each i in *layer_indices* is
    changed. Embeddings and the classification head keep their current
    ``requires_grad`` state.
    """
    for layer_idx in layer_indices:
        model.bert.encoder.layer[layer_idx].requires_grad_(False)
    print(f"Froze BERT layers: {layer_indices}")


def unfreeze_all(model):
    """Make every parameter of *model* trainable again."""
    for tensor in model.parameters():
        tensor.requires_grad_(True)
    print("Unfroze all layers")


# --- Config ---
import torch  # also imported in Cell 2; repeated so this cell does not depend on it

OUTPUT_DIR = "/kaggle/working/m2_resume_extractor"
LEARNING_RATE = 2e-5
BATCH_SIZE = 16
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1
FREEZE_LAYERS = [0, 1, 2, 3, 4, 5, 6, 7, 8]
FREEZE_EPOCHS = 2
TOTAL_EPOCHS = 8
# fp16 mixed precision needs an accelerator; TrainingArguments raises when
# fp16=True on a CPU-only runtime, so fall back to fp32 there.
USE_FP16 = torch.cuda.is_available()

if TOTAL_EPOCHS <= FREEZE_EPOCHS:
    raise ValueError(
        f"TOTAL_EPOCHS ({TOTAL_EPOCHS}) must be greater than FREEZE_EPOCHS ({FREEZE_EPOCHS}) "
        "so that phase 2 runs at least one epoch."
    )
if not USE_FP16:
    print("WARNING: CUDA not available -- training in fp32 on CPU (this will be very slow).")

# ============================
# Phase 1: Frozen layers 0-8
# ============================
print(f"\n{'='*60}")
print(f"Phase 1: Frozen layers {FREEZE_LAYERS} for {FREEZE_EPOCHS} epochs")
print(f"{'='*60}")

freeze_bert_layers(model, FREEZE_LAYERS)

phase1_args = TrainingArguments(
    output_dir=OUTPUT_DIR + "/phase1",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=FREEZE_EPOCHS,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_RATIO,
    fp16=USE_FP16,
    logging_steps=50,
    save_total_limit=1,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=phase1_args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    compute_metrics=compute_metrics,
)

trainer.train()
phase1_results = trainer.evaluate()
print(f"Phase 1 results: {phase1_results}")

# ============================
# Phase 2: All layers unfrozen
# ============================
# A fresh Trainer builds a new optimizer, now over all parameters, at half the
# learning rate. Phase 2 continues from the in-memory phase-1 weights.
remaining_epochs = TOTAL_EPOCHS - FREEZE_EPOCHS

print(f"\n{'='*60}")
print(f"Phase 2: All layers unfrozen for {remaining_epochs} epochs")
print(f"{'='*60}")

unfreeze_all(model)

phase2_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=LEARNING_RATE * 0.5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=remaining_epochs,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_RATIO,
    fp16=USE_FP16,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    logging_steps=50,
    save_total_limit=2,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=phase2_args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    compute_metrics=compute_metrics,
)

trainer.train()
print("Training complete!")

In [None]:
# Cell 8: Evaluate on test set
# Score the held-out test split with the trainer's current model. Metric keys
# carry the Trainer's default "eval_" prefix; missing keys print as 0.
print("Evaluating on test set...")
results = trainer.evaluate(test_tok)

print("\nTest Results:")
for caption, key in (
    ("Precision:", "eval_precision"),
    ("Recall:   ", "eval_recall"),
    ("F1:       ", "eval_f1"),
):
    print(f"  {caption} {results.get(key, 0):.4f}")

TARGET_F1 = 0.90
test_f1 = results.get('eval_f1', 0)
verdict = "ACHIEVED" if test_f1 >= TARGET_F1 else "NOT MET"
print(f"\nTarget F1 {TARGET_F1:.2f} {verdict} (got {test_f1:.4f})")

In [None]:
# Cell 9: Save model and download
import shutil
import zipfile
from pathlib import Path

# Export the trainer's current model (with load_best_model_at_end=True in
# phase 2, the best-F1 checkpoint) plus the tokenizer into OUTPUT_DIR.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"Model saved to {OUTPUT_DIR}")

# Only the top-level files of OUTPUT_DIR make up the exported model. Its
# subdirectories are Trainer checkpoints (phase-2 checkpoint-* folders and the
# phase1/ run), which typically include optimizer state. Archiving all of
# OUTPUT_DIR would multiply the download size, so only top-level files are zipped.
model_files = sorted(f for f in Path(OUTPUT_DIR).iterdir() if f.is_file())
for f in model_files:
    size_mb = f.stat().st_size / 1e6
    print(f"  {f.name}: {size_mb:.1f} MB")

# Zip for download (files at the archive root, as before)
zip_path = Path("/kaggle/working/m2_resume_extractor_trained.zip")
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    for f in model_files:
        zf.write(f, arcname=f.name)
zip_size = zip_path.stat().st_size / 1e6
print(f"\nZipped to {zip_path} ({zip_size:.1f} MB)")
print("\nDownload from the Output tab on the right sidebar.")

---

## Optional: yashpwr Weak Supervision NER

**Skip this section unless you want to add ~20K extra weakly-labeled sequences.**

The yashpwr dataset has a `messages` column (chat/instruction format) instead of BIO-tagged
`tokens`/`ner_tags`. This cell extracts resume text from the user message and applies regex-based
weak supervision to auto-generate BIO tags.

**To use:** This cell sits at the end of the notebook. Run it after Cell 4 (which builds `all_records`) and before Cell 5 (Clean & split), then re-run Cells 5-9.

**Entity detection:** EMAIL, PHONE, LOCATION, DEGREE, SKILLS (section-aware), DESIGNATION,
COMPANY, COLLEGE_NAME, GRADUATION_YEAR, CERTIFICATION, YEARS_OF_EXPERIENCE.

**Label quality:** roughly 70-80% accurate (a rough estimate; regex-based labels are noisy). Mixing with the gold-standard data improves entity coverage.

---

## Appendix: Using yashpwr Data for Other Models

The yashpwr dataset (22,855 resumes in chat format) contains rich structured resume text that
can be reformatted for other models in the pipeline beyond M2.

---

### M3 (Skills Comparator) - Skill Co-occurrence Pairs

**What M3 needs:** Triplets of (anchor_skill, positive_skill, negative_skill) for contrastive learning.

**How yashpwr helps:** Each resume's Skills section lists skills that co-occur in real professionals.
Skills listed together = positive pair. Skills from different domains = negative.

```python
# Pseudocode: extract skill co-occurrence from yashpwr
triplets = []
for resume in yashpwr_resumes:
    skills = extract_skills_section(resume)  # ["Python", "SQL", "ML"]
    for i, anchor in enumerate(skills):
        for j, positive in enumerate(skills):
            if i != j:
                negative = random_skill_from_different_domain()
                triplets.append({"anchor": anchor, "positive": positive, "negative": negative})
```

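**Expected yield:** ~50K-100K skill triplets from real resume co-occurrence, after capping or sampling pairs per resume. The all-pairs loop above produces n·(n−1) triplets per resume and would yield far more if run as written.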

---

### M4 (Exp/Edu Comparator) - Resume Feature Extraction

**What M4 needs:** (resume_features, jd_features) pairs with match scores.

**How yashpwr helps:** The assistant summary + structured resume text provides:
- Years of experience, education level/field, job titles, domain
- Can synthesize matching/mismatching JD pairs with varying quality scores

```python
# Pseudocode: generate M4 pairs from yashpwr
for resume in yashpwr_resumes:
    features = extract_features(resume)  # years, edu, titles, skills
    jd_match = synthesize_matching_jd(features)      # label ~0.85
    jd_mismatch = synthesize_mismatching_jd(features) # label ~0.2
```

**Expected yield:** ~45K pairs.

---

### M5 (Judge) - Score Calibration

**What M5 needs:** Combined M3+M4 scores mapped to overall match quality.

**How yashpwr helps:** Assistant summaries provide implicit quality signals.
Best used AFTER M3/M4 are trained (needs their outputs as input features).

---

### Priority Table

| Model | Effort | Impact | Priority |
|-------|--------|--------|----------|
| **M2** (done above) | Low | High - adds ~20K sequences | Already implemented |
| **M3** | Medium | Medium - adds skill co-occurrence | Second priority |
| **M4** | Medium | Low - already has good data | Nice to have |
| **M5** | High | Low - needs M3/M4 first | Future work |

In [None]:
# Optional: yashpwr Weak Supervision - extract NER from chat/instruction format
# Run this BEFORE Cell 5 (Clean & split), then re-run Cells 5-9.

def load_yashpwr_weak_supervision():
    """Parse yashpwr chat messages into weakly-labeled BIO NER sequences.

    Reads ``DATA_DIR / "yashpwr_resume_ner" / "train.parquet"`` and takes the
    resume text from the first user message of each row. That text is labeled
    in two ways:

    * regex heuristics: email, phone, years of experience, location, company,
      degree and job title;
    * section-aware rules for the Skills, Education and Certifications sections.

    Each resume is truncated to 3000 characters. Overlapping spans are resolved
    greedily (earliest start, then longest span wins). Tagged tokens are emitted
    in 128-token chunks.

    Returns:
        A list of records ``{"tokens": [...], "ner_tags": [...],
        "source": "yashpwr_weak"}``, where ``ner_tags`` is the output of
        ``_label_to_id``. Only chunks with at least 5 tokens and at least one
        non-``"O"`` tag are kept. Returns ``[]`` when the parquet file is
        missing or has no ``messages`` column.
    """
    parquet_path = DATA_DIR / "yashpwr_resume_ner" / "train.parquet"
    if not parquet_path.exists():
        print("yashpwr parquet not found -- skipping weak supervision.")
        return []

    df = pd.read_parquet(parquet_path)
    if "messages" not in df.columns:
        print(f"yashpwr has no 'messages' column (cols: {df.columns.tolist()}) -- skipping.")
        return []

    print(f"yashpwr: {len(df)} rows with chat messages. Applying weak supervision...")

    # Section headers are anchored at the start of a line. "experience" has no
    # labeling branch below; it is collected only so it ends the preceding section.
    SECTION_PATTERNS = {
        "skills": re.compile(
            r"^(Skills|Technical Skills|Core Competencies|Areas of Expertise|"
            r"Skill Highlights|Additional Skills|Computer Skills|Software Skills)\b",
            re.IGNORECASE | re.MULTILINE
        ),
        "education": re.compile(
            r"^(Education|Academic Background|Educational Background|Qualifications)\b",
            re.IGNORECASE | re.MULTILINE
        ),
        "experience": re.compile(
            r"^(Professional Experience|Work Experience|Experience|Employment|"
            r"Employment History|Work History|Professional Background|Career History)\b",
            re.IGNORECASE | re.MULTILINE
        ),
        "certifications": re.compile(
            r"^(Certifications?|Licenses?|Professional Certifications?|"
            r"Licenses? and Certifications?)\b",
            re.IGNORECASE | re.MULTILINE
        ),
    }

    # TLD class is [A-Za-z]: a "|" inside [...] is a literal pipe, not alternation.
    EMAIL_RE = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
    PHONE_RE = re.compile(r"(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}")
    YEARS_EXP_RE = re.compile(r"\b(\d{1,2})\+?\s*(?:years?|yrs?)\s*(?:of\s+)?(?:experience)?\b", re.IGNORECASE)
    YEAR_RE = re.compile(r"\b(19[89]\d|20[0-2]\d)\b")
    LOCATION_RE = re.compile(r"\b([A-Z][a-z]+(?:\s[A-Z][a-z]+)?)\s*,\s*([A-Z]{2})\b")
    COMPANY_RE = re.compile(r"Company\s+Name\s+(.+?)(?:\s+City\s*,|\s*$)", re.IGNORECASE | re.MULTILINE)
    DEGREE_RE = re.compile(
        r"\b(Ph\.?D\.?|M\.?D\.?|J\.?D\.?|M\.?B\.?A\.?|M\.?S\.?|B\.?S\.?|B\.?A\.?|M\.?A\.?|"
        r"Bachelor(?:'?s)?(?:\s+of\s+\w+)?|Master(?:'?s)?(?:\s+of\s+\w+)?|"
        r"Associate(?:'?s)?(?:\s+of\s+\w+)?|Doctorate|Doctor of)\b",
        re.IGNORECASE
    )
    TITLE_RE = re.compile(
        r"\b((?:Senior|Junior|Lead|Chief|Head|Principal|Staff|Associate|Assistant|Executive|"
        r"Vice President|VP|Director|Manager|Coordinator|Specialist|Analyst|Engineer|"
        r"Developer|Designer|Consultant|Administrator|Supervisor|Officer|Architect|"
        r"Technician|Representative|Advisor|Strategist|Planner)"
        r"(?:\s+(?:of|for))?"
        r"(?:\s+\w+){0,3})"
        r"(?=\s+(?:January|February|March|April|May|June|July|August|September|October|November|December|\d{4}|Company|$))",
        re.IGNORECASE | re.MULTILINE
    )
    # Institution names must not cross a line break: [^\S\n] is "whitespace except newline".
    COLLEGE_RE = re.compile(
        r"((?:University|College|Institute|School|Academy|"
        r"Polytechnic|Conservatory)(?:[^\S\n]+of)?[^\S\n]+(?:\w|[^\S\n])+?)(?:[^\S\n]*[-,\n]|$)",
        re.IGNORECASE
    )
    SKILL_SPLIT_RE = re.compile(r"[,;|]|\band\b")
    CERT_YEAR_PREFIX_RE = re.compile(r"^\d{4}")

    CHUNK_TOKENS = 128

    records = []
    skipped = 0

    # Iterate the column directly; the row index and the rest of the row are unused.
    for messages in df["messages"]:
        # Parquet list columns may be read back as numpy arrays rather than lists.
        if not isinstance(messages, (list, np.ndarray)):
            skipped += 1
            continue

        # Take the resume from the FIRST user message only: the text after a
        # known "resume:" lead-in, else the whole message if it is long enough.
        resume_text = ""
        for msg in messages:
            if isinstance(msg, dict) and msg.get("role") == "user":
                content = msg.get("content", "")
                for separator in [
                    "following resume:\n\n", "following resume:\n",
                    "this resume:\n\n", "this resume:\n",
                    "resume:\n\n", "resume:\n",
                ]:
                    if separator in content:
                        resume_text = content.split(separator, 1)[1]
                        break
                if not resume_text and len(content) > 200:
                    resume_text = content
                break

        if len(resume_text) < 50:
            skipped += 1
            continue

        resume_text = resume_text[:3000]
        tokens, offsets = _tokenize_with_offsets(resume_text)
        if len(tokens) < 10:
            skipped += 1
            continue

        spans = []

        for m in EMAIL_RE.finditer(resume_text):
            spans.append({"start": m.start(), "end": m.end(), "label": "EMAIL"})
        for m in PHONE_RE.finditer(resume_text):
            spans.append({"start": m.start(), "end": m.end(), "label": "PHONE"})
        for m in YEARS_EXP_RE.finditer(resume_text):
            spans.append({"start": m.start(), "end": m.end(), "label": "YEARS_OF_EXPERIENCE"})
        for m in LOCATION_RE.finditer(resume_text):
            spans.append({"start": m.start(), "end": m.end(), "label": "LOCATION"})
        for m in COMPANY_RE.finditer(resume_text):
            company = m.group(1).strip()
            if len(company) > 2:
                spans.append({"start": m.start(1), "end": m.start(1) + len(company), "label": "COMPANY"})
        for m in DEGREE_RE.finditer(resume_text):
            spans.append({"start": m.start(), "end": m.end(), "label": "DEGREE"})

        # Section-aware tagging: each section runs from its header to the next header.
        sections = []
        for sec_name, sec_re in SECTION_PATTERNS.items():
            for m in sec_re.finditer(resume_text):
                sections.append((m.start(), sec_name))
        sections.sort(key=lambda x: x[0])

        for i, (sec_start, sec_name) in enumerate(sections):
            sec_end = sections[i + 1][0] if i + 1 < len(sections) else len(resume_text)
            sec_text = resume_text[sec_start:sec_end]
            # Body = everything after the header line (empty if the header is the last line).
            header_end = sec_text.find("\n")
            body_start = header_end + 1 if header_end != -1 else len(sec_text)

            if sec_name == "skills":
                content = " ".join(sec_text[body_start:].split("\n")).strip()
                # Items appear in body order. Searching forward from the previous match
                # keeps each item at its own position: it never lands in the header,
                # and "SQL" after "MySQL" is not matched inside "MySQL".
                cursor = body_start
                for item in SKILL_SPLIT_RE.split(content):
                    item = item.strip().strip(".")
                    if not (2 < len(item) < 50) or item[0].isdigit():
                        continue
                    pos = sec_text.find(item, cursor)
                    if pos == -1:
                        # e.g. an item joined across a line break; it has no exact match.
                        continue
                    spans.append({"start": sec_start + pos, "end": sec_start + pos + len(item), "label": "SKILLS"})
                    cursor = pos + len(item)

            elif sec_name == "education":
                for m in YEAR_RE.finditer(sec_text):
                    spans.append({"start": sec_start + m.start(), "end": sec_start + m.end(), "label": "GRADUATION_YEAR"})
                for m in COLLEGE_RE.finditer(sec_text):
                    name = m.group(1).strip()
                    if len(name) > 5:
                        spans.append({"start": sec_start + m.start(1), "end": sec_start + m.start(1) + len(name), "label": "COLLEGE_NAME"})

            elif sec_name == "certifications":
                # Label each body line at its real offset (capped at 80 chars) instead
                # of re-searching the whole resume, which could hit text outside
                # this section.
                line_start = body_start
                for raw_line in sec_text[body_start:].split("\n"):
                    line = raw_line.strip()
                    if len(line) > 5 and not CERT_YEAR_PREFIX_RE.match(line):
                        start = sec_start + line_start + (len(raw_line) - len(raw_line.lstrip()))
                        spans.append({"start": start, "end": start + min(len(line), 80), "label": "CERTIFICATION"})
                    line_start += len(raw_line) + 1  # +1 for the "\n" removed by split

        for m in TITLE_RE.finditer(resume_text):
            title = m.group(1).strip()
            if len(title) > 3:
                spans.append({"start": m.start(1), "end": m.start(1) + len(title), "label": "DESIGNATION"})

        if not spans:
            skipped += 1
            continue

        # Greedy overlap resolution: earliest start first, longer span wins ties.
        spans.sort(key=lambda s: (s["start"], -(s["end"] - s["start"])))
        filtered_spans = []
        last_end = -1
        for s in spans:
            if s["start"] >= last_end:
                filtered_spans.append(s)
                last_end = s["end"]

        tags = _bio_tags_from_char_spans(tokens, offsets, filtered_spans)

        for chunk_start in range(0, len(tokens), CHUNK_TOKENS):
            ct = tokens[chunk_start:chunk_start + CHUNK_TOKENS]
            ctags = tags[chunk_start:chunk_start + CHUNK_TOKENS]
            if len(ct) < 5 or all(t == "O" for t in ctags):
                continue
            # An entity cut by the chunk boundary would otherwise open with I-,
            # which is invalid BIO; restart it with B- of the same type.
            if ctags[0].startswith("I-"):
                ctags = ["B-" + ctags[0][2:]] + list(ctags[1:])
            records.append({"tokens": ct, "ner_tags": _label_to_id(ctags), "source": "yashpwr_weak"})

    # Count entities by their B- tags (one per entity occurrence).
    entity_counts = {}
    for rec in records:
        for tag_id in rec["ner_tags"]:
            label = ID2LABEL.get(tag_id, "O")
            if label.startswith("B-"):
                etype = label[2:]
                entity_counts[etype] = entity_counts.get(etype, 0) + 1

    print(f"\nyashpwr weak supervision results:")
    print(f"  Processed: {len(df) - skipped} resumes (skipped {skipped})")
    print(f"  Generated: {len(records)} sequences")
    print(f"  Entity counts (B- tags):")
    for etype, count in sorted(entity_counts.items(), key=lambda x: -x[1]):
        print(f"    {etype}: {count}")

    return records


# --- Uncomment below to add yashpwr weak supervision to all_records ---
# yashpwr_weak = load_yashpwr_weak_supervision()
# all_records.extend(yashpwr_weak)
# print(f"\nTotal after yashpwr: {len(all_records)} sequences")
# # Then re-run Cells 5-9