In [1]:
import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from functools import partial
import uuid
from collections import Counter
from seqeval.metrics import classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np
import logging
import re
from faker import Faker

In [None]:
class BiLSTMCRFForTokenClassification(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, num_labels, dropout=0.1):
        super(BiLSTMCRFForTokenClassification, self).__init__()
        self.embedding_dim = embedding_dim
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, embeddings, attention_mask=None, labels=None):
        lstm_out, _ = self.lstm(embeddings)
        lstm_out = self.dropout(lstm_out)
        emissions = self.fc(lstm_out)

        mask = attention_mask.bool()

        if labels is not None:
            labels = torch.where(labels == -100, 0, labels)
            log_likelihood = self.crf(emissions, labels, mask=mask)
            return -log_likelihood
        else:
            prediction = self.crf.decode(emissions, mask=mask)
            return prediction

In [4]:
    # Define labels
labels = [
        "BOD", "BUILDING", "CITY", "COUNTRY", "DATE", "DRIVERLICENSE",
        "EMAIL", "GEOCOORD", "GIVENNAME", "IDCARD", "IP", "LASTNAME", "PASS", "PASSPORT", "POSTCODE", "SECADDRESS", "SEX",
        "SOCIALNUMBER", "STATE", "STREET", "TEL", "TIME", "TITLE", "USERNAME"
    ]

labels = [f"I-{label}" for label in labels] + [f"B-{label}" for label in labels] + ["O"]

In [5]:
label2id = {label: i for i, label in enumerate(labels)}
id2label = {v: k for k, v in label2id.items()}

In [9]:
  # Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
model = BiLSTMCRFForTokenClassification(
        bert_model.config.hidden_size,
        hidden_dim=128,
        num_labels=len(labels)
    ).to(device)

In [None]:
# Load the saved model
model.load_state_dict(torch.load("C:/Users/harsh/Downloads/Project/bertembed.pth"))



In [22]:
def predict_custom_sentences(model, tokenizer, text, id2label, device):
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        predictions = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    label_predictions = [id2label[p] for p in predictions[0]]

    # Merge tokens and labels, considering B and I tags
    merged_tokens = []
    merged_labels = []
    current_word = ""
    current_label = "O"

    for token, label in zip(tokens, label_predictions):
        if token.startswith("##"):
            current_word += token[2:]
            if label.startswith("I-") and current_label.endswith(label[2:]):
                continue
            elif label != "O":
                current_label = label.replace("I-", "").replace("B-", "")
        else:
            if current_word:
                merged_tokens.append(current_word)
                merged_labels.append(current_label)
            current_word = token
            if label.startswith("B-") or (label != "O" and current_label == "O"):
                current_label = label.replace("B-", "").replace("I-", "")
            elif label.startswith("I-") and current_label.endswith(label[2:]):
                continue
            else:
                current_label = "O"

    if current_word:
        merged_tokens.append(current_word)
        merged_labels.append(current_label)

    # Group consecutive tokens with the same label
    results = []
    current_entity = ""
    current_label = "O"

    for token, label in zip(merged_tokens, merged_labels):
        if label != "O":
            if current_label == label:
                current_entity += " " + token
            else:
                if current_entity:
                    results.append((current_entity.strip(), current_label))
                current_entity = token
                current_label = label
        else:
            if current_entity:
                results.append((current_entity.strip(), current_label))
                current_entity = ""
                current_label = "O"

    if current_entity:
        results.append((current_entity.strip(), current_label))

    print("Predictions:")
    for entity, label in results:
        print(f"{entity}: {label}")

    print("\nMasked text:")
    masked_text = text

    # Replace entities in the text with their corresponding labels
    for entity, label in reversed(results):
        # Escape special characters and handle spaces
        entity_pattern = re.escape(entity)
        # Handle spaces explicitly
        entity_pattern = entity_pattern.replace(r'\ ', r'\s*')
        # Replace all occurrences of the entity pattern in the text
        masked_text = re.sub(entity_pattern, f"[{label}]", masked_text, flags=re.IGNORECASE)

    print(masked_text)


In [26]:
# Testing loop
while True:
    custom_text = input("Please enter the document with sensitive information for masking (or type 'q' to quit): ")
    if custom_text.lower() == 'q':
        break
    predict_custom_sentences(model, tokenizer, custom_text, id2label, device)
    print("")

Please enter the document with sensitive information for masking (or type 'q' to quit):  The owner resides at 12 Kensington High Street, Apt 4C, London, W8 5NP. For any official communications, the preferred contact is via john.doe@legalmail.co.uk, or by phone at +44-20-7946-5678. The driver's license number associated with this individual is L875643210, issued by the DVLA in the United Kingdom.


Predictions:
12: BUILDING
Kensington High Street: STREET
Apt 4C: SECADDRESS
London: CITY
W8 5NP: POSTCODE
john . doe @ legalmail . co . uk: EMAIL
+ 44 - 20 - 7946 - 5678: TEL
L875643210: DRIVERLICENSE
United Kingdom: COUNTRY

Masked text:
The owner resides at [BUILDING] [STREET], [SECADDRESS], [CITY], [POSTCODE]. For any official communications, the preferred contact is via [EMAIL], or by phone at [TEL]. The driver's license number associated with this individual is [DRIVERLICENSE], issued by the DVLA in the [COUNTRY].



Please enter the document with sensitive information for masking (or type 'q' to quit):  Mr. Alexander Green, Director of Operations at GlobalTech Solutions, is based in London, United Kingdom. His passport number is GB789456123, and it was issued on 2019-03-12. Mr. Green’s current residential address is Flat 5B, 17 King's Road, Chelsea, and his social security number is 345-67-8912. For urgent matters, Mr. Green may be reached at +44-20-7946-1234 or via his professional email at alex.green@globaltech.co.uk.


Predictions:
GB789456123: PASSPORT
2019 - 03 - 12: DATE
Flat 5B: SECADDRESS
17: BUILDING
King ' s Road: STREET
Chelsea: CITY
345 - 67 - 8912: SOCIALNUMBER
+ 44 - 20 - 7946 - 1234: TEL

Masked text:
Mr. Alexander Green, Director of Operations at GlobalTech Solutions, is based in London, United Kingdom. His passport number is [PASSPORT], and it was issued on [DATE]. Mr. Green’s current residential address is [SECADDRESS], [BUILDING] [STREET], [CITY], and his social security number is [SOCIALNUMBER]. For urgent matters, Mr. Green may be reached at [TEL] or via his professional email at alex.green@globaltech.co.uk.



KeyboardInterrupt: Interrupted by user