In [None]:
import re
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from typing import List, Tuple

# 1. Label mapping
# BIO tag set: the outside tag "O" followed by a B-/I- pair for each entity
# type (name, resident registration number, phone number, e-mail, card number).
LABEL_LIST = ["O"] + [
    f"{prefix}-{entity}"
    for entity in ("이름", "주민번호", "전화번호", "이메일", "카드번호")
    for prefix in ("B", "I")
]
# Class index -> label string, in LABEL_LIST order.
id2label = dict(enumerate(LABEL_LIST))

# 2. Load the model and tokenizer
model_path = "./ner_model"
# Module.eval() returns the module itself, so the model can be switched to
# evaluation mode in the same expression that loads it.
model = AutoModelForTokenClassification.from_pretrained(model_path).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path)

# 3. NER 추론 함수
def ner_predict(text: str) -> List[Tuple[str, str]]:
    """Tag each character of ``text`` with a label predicted by the NER model.

    The text is split into single characters, which are handed to the
    tokenizer as pre-split words. Every character that yields at least one
    sub-token receives the label predicted for its first sub-token.

    Args:
        text: Raw input string.

    Returns:
        ``(character, label)`` pairs in text order. Characters for which the
        tokenizer emits no sub-token, and characters cut off by truncation,
        are absent from the result, so the list may be shorter than ``text``.
        An empty ``text`` yields an empty list.
    """
    if not text:
        return []

    chars = list(text)
    encoded = tokenizer(chars, is_split_into_words=True, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**encoded).logits

    # Select the single batch row explicitly. squeeze() would also collapse a
    # length-1 sequence dimension and turn the predictions into a bare int.
    label_ids = logits.argmax(dim=-1)[0].tolist()

    # NOTE(review): labels come from the module-level id2label table; this
    # assumes it matches the label order the model was trained with - confirm.
    tagged = []
    previous_word_id = None
    for position, word_id in enumerate(encoded.word_ids()):
        # None marks a token that maps to no input character; a repeated id is
        # a later sub-token of a character that has already been labelled.
        if word_id is None or word_id == previous_word_id:
            continue
        tagged.append((chars[word_id], id2label[label_ids[position]]))
        previous_word_id = word_id
    return tagged

# 4. 엔티티 병합 함수
def merge_entities(tagged_tokens: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Collapse BIO-tagged pieces into ``(entity_text, entity_type)`` spans.

    A ``B-<type>`` tag opens a new span. An ``I-<type>`` tag extends the open
    span only when its type matches; any other tag (including a mismatched or
    orphan ``I-``) closes the open span and is itself discarded.

    Args:
        tagged_tokens: ``(text_piece, tag)`` pairs in text order.

    Returns:
        The merged spans in order of appearance, with the ``B-``/``I-`` prefix
        stripped from the type.
    """
    spans = []
    open_label = None
    pieces = []
    for piece, tag in tagged_tokens:
        # Continuation of the span currently being built.
        if tag.startswith("I-") and open_label == tag[2:]:
            pieces.append(piece)
            continue

        # Every other tag ends whatever span was in progress.
        if open_label:
            spans.append(("".join(pieces), open_label))

        if tag.startswith("B-"):
            open_label, pieces = tag[2:], [piece]
        else:
            open_label, pieces = None, []

    # Flush a span that runs to the end of the input.
    if open_label:
        spans.append(("".join(pieces), open_label))
    return spans

# 5. 마스킹 포맷 정의
def mask_entity(text: str, label: str) -> str:
    """Return ``text`` masked in the format defined for its entity ``label``.

    Formats by label:
        ``"이름"`` (name): first character followed by ``**``.
        ``"주민번호"`` (resident registration number): ``NNNNNN-*******``.
        ``"전화번호"`` (phone number): ``<area code>-****-<last four digits>``.
        ``"이메일"`` (e-mail): first character of the local part, then
            ``***@`` and the domain.
        ``"카드번호"`` (card number): ``NNNN-****-****-NNNN``.

    The numeric formats are applied to hyphen-separated digit groups found
    inside ``text``; when ``text`` contains no such group it is returned
    unchanged.

    Args:
        text: The entity text to mask.
        label: Entity type without the ``B-``/``I-`` prefix.

    Returns:
        The masked string. Empty ``text`` and unknown labels are returned
        unchanged.
    """
    if not text:
        # Nothing to mask, and indexing text[0] below would raise IndexError.
        return text
    if label == "이름":
        return text[0] + "**"
    if label == "주민번호":
        return re.sub(r"(\d{6})-\d{7}", r"\1-*******", text)
    if label == "전화번호":
        # Keep the captured area code rather than the first three characters,
        # which for a 2-digit area code would include the hyphen ("02-").
        return re.sub(r"(\d{2,3})-\d{3,4}-(\d{4})", r"\1-****-\2", text)
    if label == "이메일":
        local, _, domain = text.partition("@")
        # Slicing tolerates an empty local part (text starting with "@").
        return local[:1] + "***@" + domain
    if label == "카드번호":
        return re.sub(r"(\d{4})-\d{4}-\d{4}-(\d{4})", r"\1-****-****-\2", text)
    return text

# 6. 정규표현식 보완 마스킹 (NER 탐지 누락 대비)
def regex_based_mask(text: str) -> str:
    """Mask PII that matches fixed textual patterns, independent of the NER model.

    Applies, in order, masks for resident registration numbers
    (``NNNNNN-*******``), card numbers (``NNNN-****-****-NNNN``), phone
    numbers (``<area code>-****-<last four digits>``) and e-mail addresses
    (first character of the local part, ``***@``, then the domain).

    Args:
        text: Text to scan; may already contain masked entities.

    Returns:
        ``text`` with every pattern match replaced by its masked form.
    """
    # Resident registration number.
    text = re.sub(r"(\d{6})-\d{7}", r"\1-*******", text)

    # Card number. This must run before the phone pattern, which would
    # otherwise match three of the four digit groups and leave most of the
    # card number exposed. Digit lookarounds are used instead of \b because
    # \b treats Hangul as word characters, so a number directly followed by a
    # Korean particle would never match.
    text = re.sub(
        r"(?<!\d)(\d{4})-\d{4}-\d{4}-(\d{4})(?!\d)",
        r"\1-****-****-\2",
        text,
    )

    # Phone number. The captured area code is reused so that a 2-digit code
    # does not drag its hyphen along and produce "02--****-...".
    text = re.sub(r"(\d{2,3})-\d{3,4}-(\d{4})", r"\1-****-\2", text)

    # E-mail address. re.ASCII restricts \b to ASCII word characters, so an
    # address immediately followed by Hangul still ends on a word boundary.
    text = re.sub(
        r"\b([a-zA-Z0-9._%+-]+)@([a-zA-Z0-9.-]+\.[a-zA-Z]{2,})\b",
        lambda m: m.group(1)[0] + "***@" + m.group(2),
        text,
        flags=re.ASCII,
    )
    return text

# 7. 통합 마스킹 함수 (NER + 정규표현식)
def mask_text(text: str) -> str:
    """Mask PII in ``text`` using NER predictions, then regex patterns.

    Entities found by the NER model are masked first; the regex pass then
    catches pattern-shaped PII that the model missed.

    Args:
        text: Raw input text.

    Returns:
        The text with detected entities replaced by their masked forms.
    """
    if not text:
        return text

    entities = merge_entities(ner_predict(text))

    # Walk the original text left to right. Entities arrive in text order, so
    # each one is searched for from the end of the previous match. This masks
    # every detected occurrence (including repeats of the same string) and
    # never re-scans output that has already been masked.
    pieces = []
    cursor = 0
    for original, label in entities:
        if not original:
            continue
        start = text.find(original, cursor)
        if start == -1:
            # The merged entity text does not occur verbatim in the remaining
            # input; leave this stretch for the regex pass.
            continue
        pieces.append(text[cursor:start])
        pieces.append(mask_entity(original, label))
        cursor = start + len(original)
    pieces.append(text[cursor:])

    return regex_based_mask("".join(pieces))