# In [None]:
import json
import os
import re
import bisect
from pathlib import Path

import torch
import numpy as np
import pandas as pd
from datasets import Dataset
from spacy.lang.en import English
from transformers.models.deberta_v2 import DebertaV2ForTokenClassification, DebertaV2TokenizerFast
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments
from transformers.data.data_collator import DataCollatorForTokenClassification

# Passed as `max_length` to the tokenizer (with truncation enabled), so longer
# documents are cut off at this many tokenizer tokens.
INFERENCE_MAX_LENGTH = 3500
# When the predicted probability of the "O" label is below this value, the
# best non-"O" label is used instead of the plain argmax.
CONF_THRESH = 0.90
# A URL token is reported as "B-URL_PERSONAL" when the mean probability of
# that label over its sub-word tokens exceeds this value.
URL_THRESH = 0.1
# Passed as `fp16` to TrainingArguments.
AMP = True
# Location the model and tokenizer are loaded from via `from_pretrained`.
MODEL_PATH = '/kaggle/input/trained-mdeberta-v3-base/mdeberta-v3-base-tok/fold-3'
# Directory that contains `test.json`.
DATA_DIR = '/kaggle/input/pii-detection-removal-from-educational-data/'

# spaCy English language object; only its tokenizer is used in this file
# (URL matching and tokenizing phone-number matches).
nlp = English()

def find_span(target: list[str], document: list[str]) -> list[list[int]]:
    """Locate every non-overlapping occurrence of ``target`` in ``document``.

    The search scans left to right. After a full match it resumes right after
    the matched tokens, so occurrences never overlap. After a failed
    comparison it advances by a single position, so a match that begins
    inside a failed partial match is still found.

    Args:
        target: Token sequence to look for.
        document: Token sequence to search in.

    Returns:
        One list of consecutive ``document`` indices per occurrence, in order
        of appearance. Empty when ``target`` is empty or never occurs.
    """
    needle = list(target)
    width = len(needle)
    spans: list[list[int]] = []
    if width == 0:
        # An empty pattern has no meaningful location.
        return spans

    start = 0
    last_start = len(document) - width
    while start <= last_start:
        if list(document[start:start + width]) == needle:
            spans.append(list(range(start, start + width)))
            start += width
        else:
            start += 1

    return spans

def spacy_to_hf(data: dict, idx: int) -> slice:
    """Return the slice of tokenizer tokens that cover word ``idx``.

    ``data["token_map"]`` holds, for every character of the joined text, the
    index of the word that character belongs to (``-1`` for the separating
    space, as built by ``CustomTokenizer``). ``data["offset_mapping"]`` holds
    one ``(start, end)`` character pair per tokenizer token, with ``end``
    exclusive.

    A tokenizer token is selected when its character range overlaps the
    characters of word ``idx``. Tokens with an empty range (``start == end``,
    e.g. the ``(0, 0)`` entries of special tokens) are never selected; they
    also make the end offsets non-monotonic, which is why a linear scan is
    used here instead of a binary search.

    Args:
        data: Mapping with the ``"token_map"`` and ``"offset_mapping"`` keys.
        idx: Index of the word to look up.

    Returns:
        A slice over the tokenizer-token axis. The slice is empty when no
        tokenizer token overlaps the word (for example when the word lies
        beyond the truncation point).

    Raises:
        ValueError: If ``idx`` does not occur in ``data["token_map"]``.
    """
    char_positions = np.where(np.array(data["token_map"]) == idx)[0]
    if char_positions.size == 0:
        raise ValueError(f"token index {idx} has no characters in token_map")

    # np.where returns indices in ascending order.
    first_char = int(char_positions[0])
    last_char = int(char_positions[-1])

    offsets = data["offset_mapping"]
    start_idx = None
    end_idx = None
    for position, (tok_start, tok_end) in enumerate(offsets):
        if tok_start == tok_end:
            # Zero-width entry (special token); it covers no character.
            continue
        if tok_end <= first_char:
            # Ends before the word begins (end offset is exclusive).
            continue
        if tok_start > last_char:
            # Past the word; later tokens start even further right.
            break
        if start_idx is None:
            start_idx = position
        end_idx = position

    if start_idx is None:
        return slice(len(offsets), len(offsets))
    return slice(start_idx, end_idx + 1)

class CustomTokenizer:
    """Callable that rebuilds a document's text from its words and tokenizes it.

    Besides the tokenizer output, each call returns a ``"token_map"`` list
    with one entry per character of the rebuilt text: the index of the word
    the character came from, or ``-1`` for a space inserted after a word.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, max_length: int) -> None:
        """Store the tokenizer and the truncation length passed to it."""
        self.max_length = max_length
        self.tokenizer = tokenizer

    def __call__(self, example: dict) -> dict:
        """Tokenize one example.

        Args:
            example: Mapping with ``"tokens"`` (word strings) and
                ``"trailing_whitespace"`` (one flag per word).

        Returns:
            The tokenizer's fields plus ``"token_map"``.
        """
        pieces = []
        char_to_word = []
        words_and_flags = zip(example["tokens"], example["trailing_whitespace"])

        for word_index, (word, has_space) in enumerate(words_and_flags):
            pieces.append(word)
            char_to_word += [word_index] * len(word)
            if not has_space:
                continue
            # The separating space belongs to no word.
            pieces.append(" ")
            char_to_word.append(-1)

        joined_text = "".join(pieces)
        encoded = self.tokenizer(
            joined_text,
            return_offsets_mapping=True,
            truncation=True,
            max_length=self.max_length,
        )

        return {**encoded, "token_map": char_to_word}

# Load the test documents. The encoding is given explicitly so that reading
# does not depend on the platform's locale default.
with open(Path(DATA_DIR) / "test.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# Keep only the fields the rest of the script reads.
ds = Dataset.from_dict({
    "full_text": [x["full_text"] for x in data],
    "document": [x["document"] for x in data],
    "tokens": [x["tokens"] for x in data],
    "trailing_whitespace": [x["trailing_whitespace"] for x in data],
})

# Adds the tokenizer fields (including `offset_mapping`) and `token_map`.
tokenizer = DebertaV2TokenizerFast.from_pretrained(MODEL_PATH)
ds = ds.map(CustomTokenizer(tokenizer=tokenizer, max_length=INFERENCE_MAX_LENGTH), num_proc=os.cpu_count())

model = DebertaV2ForTokenClassification.from_pretrained(MODEL_PATH)
collator = DataCollatorForTokenClassification(tokenizer)
args = TrainingArguments(".", per_device_eval_batch_size=1, report_to="none", fp16=AMP)
trainer = Trainer(
    model=model, args=args, data_collator=collator, tokenizer=tokenizer,
)

# NOTE(review): `predictions` is indexed below as a 3-D array whose last axis
# is the label axis -- presumably (documents, padded tokens, labels); confirm.
predictions = trainer.predict(ds).predictions
pred_softmax = torch.softmax(torch.from_numpy(predictions), dim=2).numpy()
id2label = model.config.id2label
o_index = model.config.label2id["O"]
# Plain most-likely label per token.
preds = predictions.argmax(-1)
# Most-likely label per token once "O" is ruled out.
preds_without_o = pred_softmax.copy()
preds_without_o[:,:,o_index] = 0
preds_without_o = preds_without_o.argmax(-1)
o_preds = pred_softmax[:,:,o_index]
# Where "O" is not confident enough, fall back to the best non-"O" label.
preds_final = np.where(o_preds < CONF_THRESH, preds_without_o , preds)

# Rows for the submission, and the (document, word index) keys already emitted.
processed = []
pairs = set()

# Labels that are dropped from the model output in this loop.
MODEL_SKIPPED_LABELS = ("O", "B-EMAIL", "B-URL_PERSONAL", "B-PHONE_NUM", "I-PHONE_NUM")

per_document = zip(
    preds_final, ds["token_map"], ds["offset_mapping"], ds["tokens"], ds["document"]
)
for doc_preds, char_map, hf_offsets, words, doc_id in per_document:
    n_chars = len(char_map)

    # zip stops at the shorter sequence, so predictions beyond the document's
    # own offsets are never visited.
    for label_id, (char_pos, char_end) in zip(doc_preds, hf_offsets):
        label = id2label[label_id]

        # Both offsets zero: the entry covers no text.
        if char_pos + char_end == 0:
            continue

        # Step over a separating space at the start of the sub-word.
        if char_map[char_pos] == -1:
            char_pos += 1

        # Step over characters that belong to whitespace-only words.
        # NOTE(review): a -1 map entry here indexes words[-1]; confirm intended.
        while char_pos < n_chars and words[char_map[char_pos]].isspace():
            char_pos += 1

        if char_pos >= n_chars:
            break

        word_idx = char_map[char_pos]
        if word_idx == -1 or label in MODEL_SKIPPED_LABELS:
            continue

        key = (doc_id, word_idx)
        if key in pairs:
            # Only the first sub-word of a word contributes a row.
            continue

        pairs.add(key)
        processed.append(
            {"document": doc_id, "token": word_idx, "label": label, "token_str": words[word_idx]}
        )

# URLs containing any of these strings are never reported.
url_whitelist = [
    "wikipedia.org",
    "coursera.org",
    "google.com",
    ".gov",
]
# Entries are literal text: escape them so "." matches only a dot and not
# any character (otherwise ".gov" would also match e.g. "xgov").
url_whitelist_regex = re.compile("|".join(re.escape(entry) for entry in url_whitelist))
# Looked up once; it does not change between tokens.
url_label_id = model.config.label2id["B-URL_PERSONAL"]

for row_idx, _data in enumerate(ds):
    for token_idx, token in enumerate(_data["tokens"]):
        if not nlp.tokenizer.url_match(token):
            continue
        print(f"Found URL: {token}")
        if url_whitelist_regex.search(token) is not None:
            print("The above is in the whitelist")
            continue
        # Probability of the URL label on every sub-word token of this word.
        input_idxs = spacy_to_hf(_data, token_idx)
        probs = pred_softmax[row_idx, input_idxs, url_label_id]
        # An empty selection has no mean (it would be NaN); treat it as
        # "not PII", which is also what a NaN comparison yields.
        if probs.size and probs.mean() > URL_THRESH:
            print("The above is PII")
            processed.append(
                {
                    "document": _data["document"],
                    "token": token_idx,
                    "label": "B-URL_PERSONAL",
                    "token_str": token,
                }
            )
            pairs.add((_data["document"], token_idx))
        else:
            print("The above is not PII")

email_regex = re.compile(r'[\w.+-]+@[\w-]+\.[\w.-]+')
phone_num_regex = re.compile(r"(\(\d{3}\)\d{3}\-\d{4}\w*|\d{3}\.\d{3}\.\d{4})\s")
emails = []
phone_nums = []

for _data in ds:
    doc_tokens = _data["tokens"]

    # A word is an e-mail address only when the whole word matches.
    for token_idx, token in enumerate(doc_tokens):
        if email_regex.fullmatch(token) is not None:
            emails.append(
                {"document": _data["document"], "token": token_idx, "label": "B-EMAIL", "token_str": token}
            )

    # findall returns the captured group (the number without the trailing
    # whitespace). dict.fromkeys drops repeated numbers while keeping order,
    # because find_span already returns every occurrence of a number.
    matches = dict.fromkeys(phone_num_regex.findall(_data["full_text"]))
    for match in matches:
        target = [t.text for t in nlp.tokenizer(match)]
        # Label the spans of *every* match, not only those of the last one.
        for matched_span in find_span(target, doc_tokens):
            for intermediate, token_idx in enumerate(matched_span):
                # First word of a span is "B-", the following ones "I-".
                prefix = "I" if intermediate else "B"
                phone_nums.append(
                    {"document": _data["document"], "token": token_idx, "label": f"{prefix}-PHONE_NUM", "token_str": doc_tokens[token_idx]}
                )

# Name the columns explicitly so the frame has them even when no row was
# produced; otherwise the column selection below raises KeyError.
df = pd.DataFrame(
    processed + emails + phone_nums,
    columns=["document", "token", "label", "token_str"],
)
df["row_id"] = list(range(len(df)))
# Preview for interactive (notebook) use; has no effect when run as a script.
df.head(100)

# `token_str` is kept in `df` for inspection only and is not written out.
df[["row_id", "document", "token", "label"]].to_csv("submission.csv", index=False)