# Testing the Entity Aware ASR

In [1]:
import os

os.chdir('..')
os.getcwd()

'/Users/farhan/Desktop/Research'

### [Optional] Run these cells when you run out of memory

mps

In [2]:
import torch

torch.mps.empty_cache()

cuda

In [None]:
import torch

torch.cuda.empty_cache()

### Load the data (Audio, Reference Transcripts, and Identified Entities)

test data

In [3]:
test_audio_path = 'data/Audio_Files_for_testing'
test_transcript_entities = 'data/true_data_150.jsonl'

In [4]:
def retrieve_key(file: str) -> int:
    try:
        # 3 digit
        key = int(file[2:5])
    except ValueError:
        # 1 digit
        if file[3] == '.':
            key = int(file[2])
        else:
            key = int(file[2:4])
    return key

In [5]:
files = os.listdir('data/Audio_Files_for_testing')

files = sorted(files, key=retrieve_key)
files = [f'data/Audio_Files_for_testing/{file}' for file in files]
print(files)
print(len(files))

['data/Audio_Files_for_testing/id1.wav', 'data/Audio_Files_for_testing/id2.wav', 'data/Audio_Files_for_testing/id3.wav', 'data/Audio_Files_for_testing/id4.wav', 'data/Audio_Files_for_testing/id5.wav', 'data/Audio_Files_for_testing/id6.wav', 'data/Audio_Files_for_testing/id7.wav', 'data/Audio_Files_for_testing/id8.wav', 'data/Audio_Files_for_testing/id9.wav', 'data/Audio_Files_for_testing/id10.wav', 'data/Audio_Files_for_testing/id11.wav', 'data/Audio_Files_for_testing/id12.wav', 'data/Audio_Files_for_testing/id13.wav', 'data/Audio_Files_for_testing/id14.wav', 'data/Audio_Files_for_testing/id15.wav', 'data/Audio_Files_for_testing/id16.wav', 'data/Audio_Files_for_testing/id17.wav', 'data/Audio_Files_for_testing/id18.wav', 'data/Audio_Files_for_testing/id19.wav', 'data/Audio_Files_for_testing/id20.wav', 'data/Audio_Files_for_testing/id21.wav', 'data/Audio_Files_for_testing/id22.wav', 'data/Audio_Files_for_testing/id23.wav', 'data/Audio_Files_for_testing/id24.wav', 'data/Audio_Files_for_te

Note: The gretel, seed, and similar_0.3 datasets are used to fine-tune both the entity-aware ASR and the LLaMa error correction module. The `_preprocessed.jsonl` files for all three contain the CORRECTED version, which is fed as the outputs for fine-tuning the LLM correction module. In the original ICASSP paper, the input prompt shows "Best hypothesis". This "Best hypothesis" does not refer to the actual corrected transcript - but the best transcript generated by the ASR (i.e., Best-N = 4, the top one).

gretel

In [None]:
gretel_audio_path = 'data/Gretel_preprocessed'

seed

In [6]:
# TODO

similar_0.3

In [None]:
# TODO

### Load the best model

Currently, the best model is `whisper-small_en_seed_gretel_similar0.3`. We shall load this model and test it on a subset to determine:

1. How the model works;
2. Whether the model is able to identify the PIIs from the transcripts

In [7]:
import torch

device = 'cpu'

if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'

print(device)

mps


In [7]:
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

processor = AutoProcessor.from_pretrained('code/whisper-small_en_seed_gretel_similar0.3')
model = AutoModelForSpeechSeq2Seq.from_pretrained('code/whisper-small_en_seed_gretel_similar0.3').to(device)

  from .autonotebook import tqdm as notebook_tqdm


### Investigate the model's tokenizer

`whisper-small.en` default tokens

In [20]:
default_whisper_processor = AutoProcessor.from_pretrained('openai/whisper-small.en')
default_whisper_tokens = default_whisper_processor.tokenizer.added_tokens_decoder

`whsiper-small_en_seed_gretel_similar0.3` tokens

In [22]:
pii_whisper_tokens = processor.tokenizer.added_tokens_decoder

In [24]:
pii_whisper_tokens

{50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 50257: AddedToken("<|startoftranscript|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 50258: AddedToken("<|en|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 50259: AddedToken("<|zh|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 50260: AddedToken("<|de|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 50261: AddedToken("<|es|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 50262: AddedToken("<|ru|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 50263: AddedToken("<|ko|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 50264: AddedToken("<|fr|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True

In [27]:
special_tokens = [pii_whisper_tokens[token_id] for token_id in pii_whisper_tokens if token_id not in default_whisper_tokens]
special_tokens

[AddedToken("[PERSON_START]", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 AddedToken("[PERSON_END]", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 AddedToken("[PHONE_START]", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 AddedToken("[PHONE_END]", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 AddedToken("[EMAIL_START]", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 AddedToken("[EMAIL_END]", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 AddedToken("[CREDIT_CARD_START]", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 AddedToken("[CREDIT_CARD_END]", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 AddedToken("[BANK_ACCOUNT_START]", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
 Added

So, the unique PIIs to identify are:

1. Person
2. Phone
3. Email
4. Credit Card
5. Bank Account
6. Car Plate
7. NRIC
8. Passport Number

### Creating a Pandas DataFrame for logging the transcripts

In [9]:
import pandas as pd

test_df = pd.DataFrame(data=files, columns=['file_name'])
test_df.head()

Unnamed: 0,file_name
0,data/Audio_Files_for_testing/id1.wav
1,data/Audio_Files_for_testing/id2.wav
2,data/Audio_Files_for_testing/id3.wav
3,data/Audio_Files_for_testing/id4.wav
4,data/Audio_Files_for_testing/id5.wav


### Test the model with one sample

In [None]:
import librosa
from typing import List

def transcribe(audioPath: str, model: AutoModelForSpeechSeq2Seq, processor: AutoProcessor, best_n: int = 1) -> List[str]:
    """
    A function which transcribes the audio based on a given audio file path.
    Outputs the transcript along with the identified PII entities.
    
    Keyword arguments:
    audioPath (str) -- The path to the audio\n
    model (AutoModelForSpeechSeq2Seq) -- The ASR model\n
    processor (AutoProcessor) -- The processor, which contains the feature extractor and tokenizer.\n
    best_n (int) -- The best n number. By default, return the best transcription. 

    Return: The transcription along with the identified PII entities. (str)
    """
    waveform, sr = librosa.load(audioPath, sr=16000)
    inputs = processor(waveform, sampling_rate=sr, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_features=inputs["input_features"], 
            temperature=1.0,
            num_beams=best_n,
            num_return_sequences=best_n
        )
    transcriptions = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return transcriptions

In [15]:
test1 = transcribe(test_df['file_name'].iloc[0], model, processor, 5)
for index, transcript in enumerate(test1):
    print(f"Rank {index+1} transcription: {transcript}")

Rank 1 transcription: The day before yesterday, rahm received a different email from rai.es.n my.ec at outlook.sg
Rank 2 transcription: The day before yesterday, ram received another email from [EMAIL_START] r e m y at outlook.sg [EMAIL_END]
Rank 3 transcription: The day before yesterday, [PERSON_START] Ram received another email from [EMAIL_START] R e m y at outlook.sg [EMAIL_END]
Rank 4 transcription: The day before yesterday, ram received another email from [EMAIL_START] rhemyah at outlook.sg [EMAIL_END]
Rank 5 transcription: The day before yesterday mom received another email from [EMAIL_START] R e m y at outlook.sg [EMAIL_END]


### Testing the model across all samples in the test set

In [31]:
from tqdm import tqdm

for index, row in tqdm(test_df.iterrows(), desc="Transcribing and Identifying PII from test set...", total=len(test_df)):
    transcriptions = transcribe(row['file_name'], model, processor, 5)
    for i, transcription in enumerate(transcriptions):
        test_df.at[index, f'rank_{i+1}'] = transcription   

Transcribing and Identifying PII from test set...: 100%|██████████| 150/150 [09:05<00:00,  3.64s/it]


In [32]:
test_df.head()

Unnamed: 0,file_name,rank_1,rank_2,rank_3,rank_4,rank_5
0,data/Audio_Files_for_testing/id1.wav,"The day before yesterday, Ram received another...","the day before yesterday, [PERSON_START] Ram [...","The day before yesterday, ram received another...",The day before yesterday RAM received another...,"the day before yesterday, ram received another..."
1,data/Audio_Files_for_testing/id2.wav,my date of birth is uh second september 1992 uh,my date of birth is uhm 2 september 1992,mm my date of birth is uh 2 september [NRIC_ST...,uh my date of birth is uh 2 september[PHONE_ST...,my date of birth is uh second [PERSON_START] s...
2,data/Audio_Files_for_testing/id3.wav,She handed over a crumpled piece of paper ther...,She handed over a crumpled piece of paper the...,She handed over a crumpled piece of paper ther...,She handed over a crumpled piece of paper ther...,She handed over a crumpled piece of paper ther...
3,data/Audio_Files_for_testing/id4.wav,and uh uh three three of the other one yeah r...,and uh uh three-three of the other one yeah uh,okay and uh three three of the other one yeah,and uh uh three three of the other one yeah ...,the and uh three three of the other one yeah u...
4,data/Audio_Files_for_testing/id5.wav,uh hongs email is [EMAIL_START] P X one r z a ...,hongs email is [EMAIL_START] T x 1 rz e 4 7 at...,uh Hong's email is [EMAIL_START] Px1 rz [EMAIL...,hongs email is [EMAIL_START] T x 1 rz eh 4 7 a...,hongs email is px on rza 4 7 at [PERSON_STAR...


In [34]:
test_df.to_csv('data/whisper-small_en_seed_gretel_similar0.3_test_set_transcribed_n_best_5.csv', index=False)

### Compute the Metrics

#### Load the ground truth

In [10]:
ground_truth_df = pd.read_json('data/true_data_150.jsonl', lines=True)

In [11]:
ground_truth_df.head()

Unnamed: 0,text
0,"The day before [DATE_START] yesterday, [DATE_E..."
1,um my date of birth is uh second [DATE_START] ...
2,"she handed over a crumpled piece of paper, the..."
3,aglio olio and err uh [CARDINAL_START] three t...
4,[PERSON_START] Hong' [PERSON_END]s email is [E...


#### Dictionary of all entities

In [12]:
entities = {
    'EMAIL': ['EMAIL_START', 'EMAIL_END'],
    'NRIC': ['NRIC_START', 'NRIC_END'],
    'CREDIT_CARD': ['CREDIT_CARD_START', 'CREDIT_CARD_END'],
    'PHONE': ['PHONE_START', 'PHONE_END'],
    'PASSPORT_NUM': ['PASSPORT_NUM_START', 'PASSPORT_NUM_END'],
    'BANK_ACCOUNT': ['BANK_ACCOUNT_START', 'BANK_ACCOUNT_END'],
    'CAR_PLATE': ['CAR_PLATE_START', 'CAR_PLATE_END'],
    'PERSON': ['PERSON_START', 'PERSON_END']
}

In [13]:
additional_entities = {
    'EMAIL': ['EMAIL_START', 'EMAIL_END'],
    'NRIC': ['NRIC_START', 'NRIC_END'],
    'CREDIT_CARD': ['CREDIT_CARD_START', 'CREDIT_CARD_END'],
    'PHONE': ['PHONE_START', 'PHONE_END'],
    'PASSPORT_NUM': ['PASSPORT_NUM_START', 'PASSPORT_NUM_END'],
    'BANK_ACCOUNT': ['BANK_ACCOUNT_START', 'BANK_ACCOUNT_END'],
    'CAR_PLATE': ['CAR_PLATE_START', 'CAR_PLATE_END'],
    'PERSON': ['PERSON_START', 'PERSON_END'],
    'DATE': ['DATE_START', 'DATE_END']
}

redundant_entities = [entity for entity in additional_entities if entity not in entities.keys()]

redundant_entities

['DATE']

In [14]:
additional_entities[redundant_entities[0]]

['DATE_START', 'DATE_END']

#### Clean the test set

Remove any entities identified from the test set that is not present in the `entities` dictionary.

In [16]:
from typing import Dict, List

def remove_redundant_entities(text: str, valid_entities: Dict[str, List[str]]) -> str:
    """
    Take a row from the true data and find any identified
    entity tags that are not present in the 'entities'
    dictionary and replace them with an empty string.
    
    Keyword arguments:
    text (str) -- The text to replace.
    valid_entities (Dict[str, List[str]]) -- A dictionary containing all the valid entities to be identified

    Return: Replaced text.
    """
    # Step 1: Identify redundant tags.
    curr_entity_tag = ''
    invalid_entity_tags_detected = {}
    in_entity = False
    base_entity = ''
    for c in text:
        if c == '[':
            in_entity = True
        elif c == ']':
            in_entity = False
            # Check to see if entity is a valid one (or just some string covered in parentheses)
            if 'START' in curr_entity_tag or 'END' in curr_entity_tag:
                base_entity = curr_entity_tag[:-6] if 'START' in curr_entity_tag else curr_entity_tag[:-4]
                if base_entity not in valid_entities.keys():
                    if base_entity not in invalid_entity_tags_detected.keys():
                        invalid_entity_tags_detected[base_entity] = [curr_entity_tag]
                    else:
                        invalid_entity_tags_detected[base_entity].append(curr_entity_tag)
            base_entity = ''
            curr_entity_tag = ''
        elif in_entity:
            curr_entity_tag += c

    # Step 2: Remove the redundant entity tags from the text
    for entity in invalid_entity_tags_detected:
        if entity in text:
            replaced_text_start = text.replace(f"[{invalid_entity_tags_detected[entity][0]}]", "")
            text = replaced_text_start.replace(f"[{invalid_entity_tags_detected[entity][1]}]", "")
    
    return text

Testing the function that removes redundant entities

In [17]:
remove_redundant_entities("[ah] okay um then uh how about [DATE_START] next week  [DATE_END][DATE_START] wednesday  [DATE_END]I will be, I'll come at at around [TIME_START] two-thirty [TIME_END]", entities)

"[ah] okay um then uh how about  next week   wednesday  I will be, I'll come at at around  two-thirty "

In [18]:
remove_redundant_entities("The day before [DATE_START] yesterday, [DATE_END] [PERSON_START] Ram  [PERSON_END]received another email from [EMAIL_START] r e m y at outlook dot sg [EMAIL_END]", entities)

'The day before  yesterday,  [PERSON_START] Ram  [PERSON_END]received another email from [EMAIL_START] r e m y at outlook dot sg [EMAIL_END]'

In [19]:
for index, row in ground_truth_df.iterrows():
    ground_truth_df.at[index, 'text_cleaned'] = remove_redundant_entities(row['text'], entities)

ground_truth_df.head()

Unnamed: 0,text,text_cleaned
0,"The day before [DATE_START] yesterday, [DATE_E...","The day before yesterday, [PERSON_START] Ram..."
1,um my date of birth is uh second [DATE_START] ...,um my date of birth is uh second september n...
2,"she handed over a crumpled piece of paper, the...","she handed over a crumpled piece of paper, the..."
3,aglio olio and err uh [CARDINAL_START] three t...,aglio olio and err uh three three of the oth...
4,[PERSON_START] Hong' [PERSON_END]s email is [E...,[PERSON_START] Hong' [PERSON_END]s email is [E...


#### Compute the Precision, Recall, and F1-score

Number of types of entities

In [20]:
len(entities)

8

In [26]:
print(list(entities.keys()))

['EMAIL', 'NRIC', 'CREDIT_CARD', 'PHONE', 'PASSPORT_NUM', 'BANK_ACCOUNT', 'CAR_PLATE', 'PERSON']


We can use the BIO (Before-Inside-Outside) format to gather information on the entities

Example:

```
"The day before yesterday,   Ram     received another email from     r       e       m       y       at   outlook   dot     sg     ."
  O   O     O       O      B-PERSON      O      O       O     O   B-EMAIL I-EMAIL I-EMAIL I-EMAIL I-EMAIL I-EMAIL I-EMAIL I-EMAIL  O
```

Every initial token of the entity, even if it is the only token for that entity, will be tagged with "B-<ENTITY\>"

In [35]:
import re
from typing import Tuple

def text_to_bio(text: str) -> Tuple[list, list]:
    """ 
        Converts the output of the entity-aware ASR to BIO format.

        Arguments:
            text (str) - The transcript containing the identified entities.
        
        Returns:
            A tuple of lists, one being a list of the tokens, and the other being the
            identified tags.
    """
    tokens = []
    tags = []

    # Regular expressions to identify entity annotations
    entity_pattern = re.compile(r"\[([A-Z_]+)_START\](.*?)\[([A-Z_]+)_END\]")
    
    # Tokenize and tag
    current_index = 0
    for match in entity_pattern.finditer(text):
        # Add tokens and tags for non-entity text before the entity
        pre_entity_text = text[current_index:match.start()].strip()
        if pre_entity_text:
            pre_tokens = pre_entity_text.split()
            tokens.extend(pre_tokens)
            tags.extend(["O"] * len(pre_tokens))
        
        # Add tokens and tags for the entity
        entity_type = match.group(1)
        entity_text = match.group(2).strip()
        entity_tokens = entity_text.split()
        tokens.extend(entity_tokens)
        tags.append(f"B-{entity_type}")
        tags.extend([f"I-{entity_type}"] * (len(entity_tokens) - 1))
        
        # Update the current index
        current_index = match.end()

    # Add remaining non-entity text
    post_entity_text = text[current_index:].strip()
    if post_entity_text:
        post_tokens = post_entity_text.split()
        tokens.extend(post_tokens)
        tags.extend(["O"] * len(post_tokens))
    
    return tokens, tags

Example use

In [36]:
text = "The day before yesterday, [PERSON_START] Ram [PERSON_END] received another email from [EMAIL_START] r e m y at outlook dot sg [EMAIL_END]."
tokens, tags = text_to_bio(text)

for item in zip(tokens, tags):
    print(item)

('The', 'O')
('day', 'O')
('before', 'O')
('yesterday,', 'O')
('Ram', 'B-PERSON')
('received', 'O')
('another', 'O')
('email', 'O')
('from', 'O')
('r', 'B-EMAIL')
('e', 'I-EMAIL')
('m', 'I-EMAIL')
('y', 'I-EMAIL')
('at', 'I-EMAIL')
('outlook', 'I-EMAIL')
('dot', 'I-EMAIL')
('sg', 'I-EMAIL')
('.', 'O')
