In [20]:
!pip install transformers
!pip install pycryptodome



In [21]:
import torch
import os
import re
from Crypto.Cipher import AES
from transformers import AutoTokenizer, AutoModelForTokenClassification

pii_mapping = {}  # store ciphertext -> key/iv for later decryption

model_name = "iiiorg/piiranha-v1-detect-personal-information"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def mask_pii(text, aggregate_redaction=True):
    # Tokenize input text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Get the model predictions
    with torch.no_grad():
        outputs = model(**inputs)

    # Get the predicted labels
    predictions = torch.argmax(outputs.logits, dim=-1)

    # Convert token predictions to word predictions
    encoded_inputs = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=True)
    offset_mapping = encoded_inputs['offset_mapping']

    masked_text = list(text)
    is_redacting = False
    redaction_start = 0
    current_pii_type = ''

    for i, (start, end) in enumerate(offset_mapping):
        if start == end:  # Special token
            continue

        label = predictions[0][i].item()
        if label != model.config.label2id['O']:  # Non-O label
            pii_type = model.config.id2label[label]
            if not is_redacting:
                is_redacting = True
                redaction_start = start
                current_pii_type = pii_type
            elif not aggregate_redaction and pii_type != current_pii_type:
                # End current redaction and start a new one
                apply_redaction(masked_text, redaction_start, start, current_pii_type, aggregate_redaction)
                current_pii_type = pii_type
        else:
            if is_redacting:
                apply_redaction(masked_text, redaction_start, end, current_pii_type, aggregate_redaction)
                is_redacting = False


    # Handle case where PII is at the end of the text
    if is_redacting:
        apply_redaction(masked_text, redaction_start, len(masked_text), current_pii_type, aggregate_redaction)

    return ''.join(masked_text)

def apply_redaction(masked_text, start, end, pii_type, aggregate_redaction):
    original_text = ''.join(masked_text[start:end])   # <-- Grab original substring here
    #print(f"Detected PII ({pii_type}): {original_text}")

    for j in range(start, end):
        masked_text[j] = ''
    if aggregate_redaction:
        key, iv, ciphertext = encrypt_text(original_text)
        if key and iv:
           #print("Text encrypted successfully.")
           #print("Key:", key)
           #print("IV:", iv)
           masked_text[start] = f"[{ciphertext}]"
           decrypt_text(key, iv, ciphertext)
           # Store mapping for later decryption
           pii_mapping[ciphertext] = {
                "key": key,
                "iv": iv,
                "pii_type": pii_type,
                "original": original_text
            }

           #print("File decrypted successfully.")

        else:
           print("Encryption failed.") #instead of redacting the pii we can encrypt it
    else:
        masked_text[start] = f'[{pii_type}]'


#=============================== Encryption and Decryption ===============================
#AES Encryption
#First, we need to pad
def pad(data):
  data = data.encode("utf-8")
  padding_length = 16 - len(data) % 16
  padding = bytes([padding_length] * padding_length)
  return data + padding

def unpad(data):
    padding_length = data[-1]
    if padding_length < 1 or padding_length > 16:
        raise ValueError("Invalid padding encountered")
    return data[:-padding_length]

#AES256 Encryption
def encrypt_text(input_text):
  key = os.urandom(32)
  iv = os.urandom(16)
  cipher = AES.new(key, AES.MODE_CBC, iv)
  try:
    plaintext = input_text
    padded_plaintext = pad(plaintext)
    ciphertext = cipher.encrypt(padded_plaintext)
    #print("The ciphertext is ", ciphertext.hex())
    encoded_key = key.hex()
    encoded_iv = iv.hex()
    return encoded_key, encoded_iv, ciphertext.hex()
  except Exception as e:
        print(f"An error occurred during encryption: {e}")
        return None, None

#AES Decryption
def decrypt_text(key, iv, input_text):
    try:
        decoded_key = bytes.fromhex(key)
        iv_bytes = bytes.fromhex(iv)
        encrypted_data = bytes.fromhex(input_text)
        if len(decoded_key) != 32:
            raise ValueError("Incorrect AES key length")
        cipher = AES.new(decoded_key, AES.MODE_CBC, iv_bytes)
        decrypted_data = unpad(cipher.decrypt(encrypted_data))
        #print("Decrypted data:", decrypted_data.decode("utf-8"))
        return decrypted_data.decode("utf-8")
    except Exception as e:
        print(f"An error occurred during decryption: {e}")


def decrypt_text(key, iv, input_text):
    try:
        decoded_key = bytes.fromhex(key)
        iv_bytes = bytes.fromhex(iv)
        encrypted_data = bytes.fromhex(input_text)
        cipher = AES.new(decoded_key, AES.MODE_CBC, iv_bytes)
        decrypted_data = unpad(cipher.decrypt(encrypted_data))
        return decrypted_data.decode("utf-8")
    except Exception as e:
        print(f"Decryption error: {e}")
        return "[DECRYPTION_FAILED]"

# ------------------- Restore full sentence -------------------
def restore_plaintext(masked_sentence):
    def decrypt_match(match):
        ciphertext = match.group(1)
        if ciphertext in pii_mapping:
            key = pii_mapping[ciphertext]["key"]
            iv = pii_mapping[ciphertext]["iv"]
            return decrypt_text(key, iv, ciphertext)
        else:
            return "[UNKNOWN]"
    pattern = r"\[([0-9a-fA-F]+)\]"
    restored_sentence = re.sub(pattern, decrypt_match, masked_sentence)
    return restored_sentence



In [22]:
# Example usage
example_text = "My name is Obee Nobi and I live at 432423 Deka St, Tanooti. My phone number is 5455-123-4567."

print("Aggregated Encryption:")
masked_example_aggregated = mask_pii(example_text, aggregate_redaction=True)
print(masked_example_aggregated)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Aggregated Encryption:
My name is[3a70ad0d4b429f784bc24b21d4ae4d6d] I live at[ffe4d8458aa84beeef33df3ad73e8e9a],[4c21bf15edcef61a743c232bffcd8fdd] My phone number is[252206487bac3d1a130c42d4f575bfd3]


In [24]:
#Decryption
restored = restore_plaintext("My name is[3a70ad0d4b429f784bc24b21d4ae4d6d] I live at[ffe4d8458aa84beeef33df3ad73e8e9a],[4c21bf15edcef61a743c232bffcd8fdd] My phone number is[252206487bac3d1a130c42d4f575bfd3]")
print("\nRestored Plaintext:", restored)


Restored Plaintext: My name is Obee Nobi and I live at 432423 Deka St, Tanooti. My phone number is 5455-123-4567.
