# KATE (kNN-Augmented in-conText Example selection)
This notebook demonstrates the KATE method, which improves few-shot learning by selecting in-context examples that are semantically similar to a given test sample. KATE was proposed in ["What Makes Good In-Context Examples for GPT-3?" by Liu et al. (2021)](https://arxiv.org/pdf/2101.06804). Note that KATE dates from January 2021, and both models and RAG methods have advanced since then. See documentation here for a detailed description of KATE.

KATE works best when the retrieved examples significantly boost the model's ability to generate appropriate responses. In this example, we embed a training dataset with masked personally identifiable information (PII). During inference, we draw a random sentence from the validation dataset and use KATE to find its k nearest neighbors in the training data. These neighbors are compiled into a prompt which, along with the inference sentence, is used to generate a response from the language model (Llama3 8B in this example). This application of KATE to PII masking is our own and was not proposed in the original paper.

This is not meant to be a full evaluation or benchmark of KATE on a PII masking task. Instead, we look at a single inference sentence as a simple working example of the KATE method in practice.
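Before the full pipeline, the retrieval step described above can be illustrated on toy data. The following is a minimal sketch, assuming hand-made 2-D vectors in place of real sentence embeddings; `top_k_neighbors` is a hypothetical helper written for this illustration and is not the function used later in the notebook.

```python
import numpy as np

def top_k_neighbors(query, candidates, k=2):
    """Return indices of the k candidates most cosine-similar to query, best first."""
    query = np.asarray(query, dtype=float)
    candidates = np.asarray(candidates, dtype=float)
    # Cosine similarity is the dot product of L2-normalised vectors
    q = query / np.linalg.norm(query)
    c = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
    scores = c @ q
    # argsort is ascending, so reverse it and keep the first k entries
    order = np.argsort(scores)[::-1][:k]
    return order, scores

# Toy "embeddings" (made up for illustration only)
train_embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
test_embedding = [1.0, 0.2]

indices, scores = top_k_neighbors(test_embedding, train_embeddings, k=2)
print(indices)  # indices of the two training vectors pointing most nearly the same way
```

The selected indices would then be used to pull the corresponding training examples into the prompt, nearest neighbor first.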

## Load dataset from Hugging Face
For this example I'm using pii-masking-300k, which can be found [here](https://huggingface.co/datasets/ai4privacy/pii-masking-300k?row=0).

In [4]:
from datasets import load_dataset
dataset = load_dataset("ai4privacy/pii-masking-300k")

## Load custom functions

In [6]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from transformers import RobertaTokenizer, RobertaModel
from tqdm import tqdm
import torch
import ollama

# Initialize RoBERTa tokenizer and model
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')  # output_attentions is a model option, not a tokenizer option
model = RobertaModel.from_pretrained('roberta-base', output_attentions=False, output_hidden_states=False)

# Function to generate embeddings for a given text using RoBERTa
def get_roberta_embedding(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()  # Average pooling of last hidden state

# Embed all entries in the dataset, showing progress with tqdm
def embed_data(data):
    for entry in tqdm(data, desc="Embedding entries"):
        embedding = get_roberta_embedding(entry["original"]) # Embed the original sentence to be masked for later comparison during inference
        entry["embedding"] = embedding
    return data

# Calculate similarity between a new text and the dataset, returning top k similar indices
def calculate_similarity(new_text_embedding, data, k=5):
    similarity_scores = cosine_similarity([new_text_embedding], [d["embedding"] for d in data])[0]
    top_k_indices = np.argpartition(similarity_scores, -k)[-k:] # Partial sort: find the top k without sorting the entire dataset
    top_k_indices = top_k_indices[np.argsort(similarity_scores[top_k_indices])][::-1] # Fully sort only the top k, most similar first
    return top_k_indices

# Generate example masked sentences from the most similar entries
def get_example_masked_sentences(top_k_indices, data):
    example_masked_sentences = "\n\n".join(
        f"Original sentence:{data[index]['original']}\n\nMasked sentence:{data[index]['masked']}"
        for index in top_k_indices
    )
    return example_masked_sentences

# Build a prompt for the language model using examples and the sentence to mask
def build_model_prompt(example_masked_sentences, sentence_to_mask):
    template = """
        Please mask the PII in the given sentence by following the examples:

        {examples}

        Here is the sentence to mask. Respond only with the masked sentence and no additional explanation or commentary:
        Sentence to mask: {sentence}
        Masked sentence:
        """
    
    return template.format(
        examples=example_masked_sentences,
        sentence=sentence_to_mask
    )

# Call the language model to perform PII masking during inference
def call_model(prompt, llm_model):
    messages = [
        {"role": "user", "content": prompt}
    ]
    response = ollama.chat(model=llm_model, messages=messages, stream=False)
    result = response['message']['content']
    return result

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Create and embed data
Create a simplified list from the original dataset. For simplicity I'm not including the mask list.

In [7]:
from tqdm import tqdm
data = []
for index in tqdm(range(len(dataset['train'])), desc="Creating dataset:"):
    data.append({"original": dataset['train'][index]['source_text'], "masked":dataset['train'][index]['target_text']})

Creating dataset:: 100%|██████████| 177677/177677 [00:59<00:00, 2989.48it/s]


Embed the data. Here I'm choosing to embed the training dataset, which is 79% of the total dataset. This can take a while (~3 hours).

In [8]:
embedded_data = embed_data(data)

Embedding entries: 100%|██████████| 177677/177677 [3:15:49<00:00, 15.12it/s]  
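The loop above runs one sentence per forward pass, so no padding is ever present and a plain mean over tokens is safe. If the sentences were instead batched to shorten the run, padded positions would need to be excluded from the average. The following is a minimal sketch of that masked mean pooling, assuming NumPy arrays shaped like a model's last hidden state and attention mask; `masked_mean_pool` is a hypothetical helper, and any speed-up from batching is untested here.

```python
import numpy as np

def masked_mean_pool(hidden_states, attention_mask):
    """Average token vectors per sequence, ignoring padded positions.

    hidden_states: (batch, seq_len, dim)
    attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding
    """
    hidden_states = np.asarray(hidden_states, dtype=float)
    mask = np.asarray(attention_mask, dtype=float)[:, :, None]  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)      # sum over real tokens only
    counts = np.clip(mask.sum(axis=1), 1.0, None)    # guard against division by zero
    return summed / counts

# Two toy sequences; the second has one padded position that must be ignored
hidden = [[[1.0, 2.0], [3.0, 4.0]],
          [[5.0, 6.0], [100.0, 100.0]]]
mask = [[1, 1],
        [1, 0]]

pooled = masked_mean_pool(hidden, mask)
print(pooled)
```

The same computation carries over to PyTorch tensors by multiplying `last_hidden_state` by the unsqueezed `attention_mask` before summing.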


## Perform inference
Perform inference on a test sentence. We'll use the validation set to draw a random sentence.

In [9]:
# Pick a random sentence from the validation set to use as a test of the approach
np.random.seed(11) # Set seed for reproducibility (np.random.choice uses NumPy's generator, not Python's random module)
rand_index = int(np.random.choice(len(dataset['validation']))) # I forgot to set a random seed initially :/ , but I used index 1919 for this example.
rand_index = 1919 # If you want to follow my example

test_sentence = dataset['validation'][rand_index]['source_text']
target_sentence = dataset['validation'][rand_index]['target_text']

# Get similar sentence for our test sentence in our embedded data
embedded_test_sentence = get_roberta_embedding(test_sentence)
top_k_indices = calculate_similarity(embedded_test_sentence, embedded_data)

# Build the model prompt
example_masked_sentences = get_example_masked_sentences(top_k_indices, embedded_data)
prompt = build_model_prompt(example_masked_sentences, test_sentence)

# Get the masked sentence
masked_sentence = call_model(prompt, "llama3")

# Compare the masked sentence from our model with the ground truth from the data
print(f"Masked sentence from model: {masked_sentence}")
print(f"Masked ground-truth sentence from data: {target_sentence}")

Masked sentence from model: - Immunization_Certification:
    individuals:
      - [TITLE]
      - [BOD]
      - [TEL]
      - [COUNTRY]
      - [BUILDING]
      - [STREET]
      - [CITY]
      - [STATE]
      - [POSTCODE]
      - [SECADDRESS]
      - [TIME]
      - [LASTNAME1]
    background:
      - [DATE]
Masked ground-truth sentence from data: ```yaml
- Immunization_Certification:
    individuals:
      - [TITLE]
      - [USERNAME]
      - [TEL]
      - [COUNTRY]
      - [BUILDING]
      - [STREET]
      - [CITY]
      - [STATE]
      - [POSTCODE]
      - [SECADDRESS]
      - [TIME]
      - [LASTNAME1]
    background:
      - [DATE]
```
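In the output above, the model's answer differs from the ground truth in a single position: `[BOD]` where the data has `[USERNAME]`. Comparing by eye does not scale, so the following is a minimal sketch of a position-wise comparison of mask labels. It assumes masks look like `[LABEL]` or `[LABEL1]`; `extract_masks` and `mask_agreement` are hypothetical helpers and not a standard metric for this dataset.

```python
import re

# Assumed mask format: upper-case letters/underscores with an optional trailing number
MASK_PATTERN = re.compile(r"\[[A-Z_]+\d*\]")

def extract_masks(text):
    """Return all mask labels in order of appearance, ignoring everything else."""
    return MASK_PATTERN.findall(text)

def mask_agreement(predicted, target):
    """Count positions where predicted and target masks agree.

    Returns (matches, total), where total is the longer of the two mask lists.
    """
    pred, gold = extract_masks(predicted), extract_masks(target)
    matches = sum(p == g for p, g in zip(pred, gold))
    return matches, max(len(pred), len(gold))

# Toy strings standing in for masked_sentence and target_sentence
predicted = "- [TITLE]\n- [BOD]\n- [TEL]"
target = "- [TITLE]\n- [USERNAME]\n- [TEL]"

print(mask_agreement(predicted, target))
```

Because only bracketed labels are extracted, differences in indentation or surrounding code fences between the model output and the ground truth do not affect the count.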


## Additional investigations
If desired, we can look at the top examples to see how qualitatively similar they are to the test sentence.

In [124]:
top_k_indices

array([ 37315, 137934,  23227, 137935,  23229])

In [131]:
print(data[23229]['original'])
print(data[23229]['masked'])

00:00
      - 8207886065
      - 974312500
      - +132 289 676-9075
      - United States
      - 332
      - Rochelle Street
      - New York
      - NY
      - 10464
      - Flat 298
      - ~`4teF
      - Langmeier
      - COMMENTS_C: "Conduct IP audit, update trademark portfolio, support patent applications process."
    background:
      6:30 AM
      31st October 2027
```
00:00
      - [IDCARD]
      - [PASSPORT]
      - [TEL]
      - [COUNTRY]
      - [BUILDING]
      - [STREET]
      - [CITY]
      - [STATE]
      - [POSTCODE]
      - [SECADDRESS]
      - [PASS]
      - [LASTNAME1]
      - COMMENTS_C: "Conduct IP audit, update trademark portfolio, support patent applications process."
    background:
      [TIME]
      [DATE]
```


## Alternative approaches
First, we can just try asking the model to mask the sentence. This is also a good check to see if there is data contamination. If the model memorized the PII dataset it might achieve strong zero-shot masking without much direction.

In [127]:
prompt = f"Please mask the personally identifiable information in this sentence: {test_sentence}"

messages = [
    {"role": "user", "content": prompt}
]
response = ollama.chat(model="llama3", messages=messages, stream=False)
result = response['message']['content']

print(result)

I'd be happy to help!

Here is the modified sentence with personally identifiable information masked:

```
yaml
- Immunization_Certification:
  individuals:
    - Princess
    - XXXXXXXXXXXXXXXXXX
    - XXXXXXXX
    - United Kingdom
    - XXXX
    - Fleming Way
    - Swindon
    - ENG
    - SN1 2NN
    - Townhouse 90
    - 05:59
    - Morag
  background:
    - XXXX/09/1972
```

I replaced the following personally identifiable information:

* Phone number: XXXXXXXX
* Date of birth: XXXX/09/1972 (only the year is masked, as the rest is publicly available information)
* Person's name: hlfiadjwhwi16966 and Morag are likely names that should be kept confidential. I replaced them with generic placeholder text.

Let me know if you have any further requests!


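Alternatively, we can be more sophisticated and give the model the list of available masks. This achieves decent performance, but it is still not quite as good as providing explicit examples: in the output below, the title and last name are left unmasked and several address fields are dropped.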

In [128]:
masks = "[BOD], [BUILDING], [CITY], [COUNTRY], [DATE], [DRIVERLICENSE], [EMAIL], [GEOCOORD], [GIVENNAME1], [GIVENNAME2], [IDCARD], [IP], [LASTNAME1], [LASTNAME2], [LASTNAME3], [PASS], [POSTCODE], [SECADDRESS], [SEX], [SOCIALNUMBER], [STATE], [STREET], [TEL], [TIME], [TITLE], [USERNAME]"
prompt = f"You have the following list of masks available: {masks}.\n\n Please mask the personally identifiable information in this sentence using the appropriate mask: {test_sentence}"

messages = [
    {"role": "user", "content": prompt}
]
response = ollama.chat(model="llama3", messages=messages, stream=False)
result = response['message']['content']

print(result)

Here is the masked sentence:

```
Immunization_Certification:
  individuals:
    - Princess
    - [SOCIALNUMBER]
    - +[TEL]
    - [COUNTRY]
    - [DATE]
    - [STREET]
    - [POSTCODE]
    - [SECADDRESS]
    - [TIME]
    - Morag
  background:
    - [DATE]
```

I used the following masks:

* [SOCIALNUMBER] to mask the social number (hlfiadjwhwi16966)
* +[TEL] to mask the phone number (+22-252 178-7818)
* [COUNTRY] to mask the country (United Kingdom)
* [DATE] to mask the date of birth (23/09/1972)
* [STREET] to mask the street address (Fleming Way)
* [POSTCODE] to mask the postcode (SN1 2NN)
* [SECADDRESS] to mask the secondary address (Townhouse 90)


We can also test whether KATE outperforms random examples drawn from the dataset instead of the k nearest neighbors.

In [18]:
k = 5
random_top_k_indices = np.random.choice(len(data), k, replace=False) # Sample without replacement so no example is repeated
example_masked_sentences = get_example_masked_sentences(random_top_k_indices, data)
prompt = build_model_prompt(example_masked_sentences, test_sentence)
masked_sentence = call_model(prompt, "llama3")

# Compare the masked sentence from our model with the ground truth from the data
print(f"Masked sentence from model: {masked_sentence}")
print(f"Masked ground-truth sentence from data: {target_sentence}")

Masked sentence from model: Here is the masked sentence:

- Immunization_Certification:
  individuals:
    - [NAME1]
    - [USERNAME]
    - [PHONE_NUMBER]
    - [COUNTRY]
    - [NUMBER]
    - [STREET_ADDRESS]
    - [CITY]
    - [STATE]
    - [POSTAL_CODE]
    - [ADDRESS_LINE2]
    - [TIME]
    - [NAME2]
  background:
    - [DATE]
Masked ground-truth sentence from data: ```yaml
- Immunization_Certification:
    individuals:
      - [TITLE]
      - [USERNAME]
      - [TEL]
      - [COUNTRY]
      - [BUILDING]
      - [STREET]
      - [CITY]
      - [STATE]
      - [POSTCODE]
      - [SECADDRESS]
      - [TIME]
      - [LASTNAME1]
    background:
      - [DATE]
```
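With random examples, the model invents labels such as `[NAME1]` and `[PHONE_NUMBER]` that do not appear in the mask list used earlier in this notebook. The following is a minimal sketch of a check that flags such labels. It assumes that list is the intended vocabulary and that masks follow the `[LABEL]` pattern; `unknown_masks` is a hypothetical helper.

```python
import re

# Mask vocabulary copied from the mask-list prompt earlier in this notebook
ALLOWED_MASKS = {
    "[BOD]", "[BUILDING]", "[CITY]", "[COUNTRY]", "[DATE]", "[DRIVERLICENSE]",
    "[EMAIL]", "[GEOCOORD]", "[GIVENNAME1]", "[GIVENNAME2]", "[IDCARD]", "[IP]",
    "[LASTNAME1]", "[LASTNAME2]", "[LASTNAME3]", "[PASS]", "[POSTCODE]",
    "[SECADDRESS]", "[SEX]", "[SOCIALNUMBER]", "[STATE]", "[STREET]", "[TEL]",
    "[TIME]", "[TITLE]", "[USERNAME]",
}

def unknown_masks(text, allowed=ALLOWED_MASKS):
    """Return mask labels in text that are not in the allowed vocabulary.

    Order of first appearance is preserved and duplicates are dropped.
    """
    found = re.findall(r"\[[A-Z_]+\d*\]", text)
    return [m for m in dict.fromkeys(found) if m not in allowed]

# Toy excerpt resembling the random-example output above
sample_output = "- [NAME1]\n- [USERNAME]\n- [PHONE_NUMBER]\n- [COUNTRY]"
print(unknown_masks(sample_output))
```

An empty result means every label in the output belongs to the vocabulary; it does not by itself mean the right label was chosen for each field.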
