# Testing the Entity-Aware ASR

In [2]:
import os

# The notebook's relative paths ('data/...', 'code/...') are resolved from the
# parent of the notebook's directory. Remember the notebook's directory the first
# time this cell runs, so re-running it does not keep climbing up the tree
# (a bare os.chdir('..') moves one level higher on every execution).
if '_NOTEBOOK_DIR' not in globals():
    _NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.dirname(_NOTEBOOK_DIR))
os.getcwd()

'/Users/farhan/Desktop/Research'

### [Optional] Free cached GPU memory — run the cell matching your backend (MPS or CUDA) when you run out of memory

mps

In [60]:
import torch

import gc

# empty_cache() only returns cached memory that no live tensor still uses, so first
# collect unreachable Python objects (e.g. reference cycles) that may hold tensors.
gc.collect()
# Guarded so that running this cell on a machine without MPS is a harmless no-op.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

cuda

In [None]:
import torch

import gc

# empty_cache() only releases cached blocks that no live tensor still uses, so first
# collect unreachable Python objects (e.g. reference cycles) that may hold tensors.
gc.collect()
# Guarded so that running this cell on a machine without CUDA is a harmless no-op.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

### Load the test data (Audio, Reference Transcripts, and Identified Entities)

In [9]:
# Directory containing the test-set audio files (relative to the project root).
test_audio_path = "data/Audio_Files_for_testing"
# JSON Lines file with the reference transcripts and their identified entities.
test_transcript_entities = "data/true_data_150.jsonl"

In [30]:
def retrieve_key(file: str) -> int:
    """
    Return the numeric id encoded in an audio file name such as ``id7.wav``.

    Used as a sort key so files come out in numeric order (id2 before id10).
    The name must be a two-character prefix (``id``) followed by the digits and
    then the extension; any number of digits is accepted.

    Keyword arguments:
    file (str) -- File name or path, e.g. ``'id42.wav'`` or ``'data/.../id42.wav'``.

    Return: The integer id (int).
    Raises: ValueError if no numeric id can be extracted (e.g. ``.DS_Store``).
    """
    # Keep only the part of the base name before the first '.', then drop the prefix.
    stem = os.path.basename(file).partition('.')[0]
    digits = stem[2:]
    if not digits.isdecimal():
        raise ValueError(f'Cannot extract a numeric id from file name {file!r}')
    return int(digits)

In [33]:
# Collect the test audio files in numeric id order (id1, id2, ..., id10, ...).
# Hidden entries such as macOS's .DS_Store are skipped because they carry no
# numeric id and would make retrieve_key fail.
files = [name for name in os.listdir(test_audio_path) if not name.startswith('.')]
files = sorted(files, key=retrieve_key)
files = [os.path.join(test_audio_path, name) for name in files]
print(files)
print(len(files))

['data/Audio_Files_for_testing/id1.wav', 'data/Audio_Files_for_testing/id2.wav', 'data/Audio_Files_for_testing/id3.wav', 'data/Audio_Files_for_testing/id4.wav', 'data/Audio_Files_for_testing/id5.wav', 'data/Audio_Files_for_testing/id6.wav', 'data/Audio_Files_for_testing/id7.wav', 'data/Audio_Files_for_testing/id8.wav', 'data/Audio_Files_for_testing/id9.wav', 'data/Audio_Files_for_testing/id10.wav', 'data/Audio_Files_for_testing/id11.wav', 'data/Audio_Files_for_testing/id12.wav', 'data/Audio_Files_for_testing/id13.wav', 'data/Audio_Files_for_testing/id14.wav', 'data/Audio_Files_for_testing/id15.wav', 'data/Audio_Files_for_testing/id16.wav', 'data/Audio_Files_for_testing/id17.wav', 'data/Audio_Files_for_testing/id18.wav', 'data/Audio_Files_for_testing/id19.wav', 'data/Audio_Files_for_testing/id20.wav', 'data/Audio_Files_for_testing/id21.wav', 'data/Audio_Files_for_testing/id22.wav', 'data/Audio_Files_for_testing/id23.wav', 'data/Audio_Files_for_testing/id24.wav', 'data/Audio_Files_for_te

### Load the best model

Currently, the best model is `whisper-small_en_seed_gretel_similar0.3`. We load this model and test it on a subset of the test data to determine:

1. how the model behaves, and
2. whether the model is able to identify the PII entities in the transcripts.

In [6]:
import torch

# Pick the accelerator to run on: Apple-silicon GPU (MPS) first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(device)

mps


In [8]:
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

# Processor (feature extractor + tokenizer) and model come from the same local checkpoint.
checkpoint_dir = 'code/whisper-small_en_seed_gretel_similar0.3'
processor = AutoProcessor.from_pretrained(checkpoint_dir)
model = AutoModelForSpeechSeq2Seq.from_pretrained(checkpoint_dir).to(device)

### Creating a Pandas DataFrame for logging the transcripts

In [34]:
import pandas as pd

# One row per test audio file, in the sorted order of `files`.
test_df = pd.DataFrame({'file_name': files})
test_df.head()

Unnamed: 0,file_name
0,data/Audio_Files_for_testing/id1.wav
1,data/Audio_Files_for_testing/id2.wav
2,data/Audio_Files_for_testing/id3.wav
3,data/Audio_Files_for_testing/id4.wav
4,data/Audio_Files_for_testing/id5.wav


### Test the model with one sample

In [53]:
import librosa

def transcribe(audioPath: str, model: AutoModelForSpeechSeq2Seq, processor: AutoProcessor, best_n: int = 1) -> list[str]:
    """
    Transcribe one audio file and return the ``best_n`` beam-search hypotheses.

    The fine-tuned model is expected to emit the transcript together with the
    PII entities it identified, so each returned string may contain both.

    Keyword arguments:
    audioPath (str) -- Path to the audio file; it is loaded and resampled to 16 kHz.
    model (AutoModelForSpeechSeq2Seq) -- The ASR model; inputs are moved to ``model.device``.
    processor (AutoProcessor) -- Bundles the feature extractor and the tokenizer.
    best_n (int) -- Beam width and number of hypotheses returned (must be >= 1).
        Defaults to 1, i.e. only the single best transcript.

    Return: A list of ``best_n`` decoded transcripts (list[str]).
    Raises: ValueError if ``best_n`` is smaller than 1.
    """
    if best_n < 1:
        raise ValueError(f'best_n must be >= 1, got {best_n}')

    waveform, sr = librosa.load(audioPath, sr=16000)
    # Send the features to wherever the model lives instead of relying on a global `device`.
    inputs = processor(waveform, sampling_rate=sr, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_features=inputs["input_features"],
            temperature=0.0,
            # best_n sets both the beam width and how many beams are returned
            # (num_return_sequences may not exceed num_beams).
            num_beams=best_n,
            num_return_sequences=best_n,
        )
    # batch_decode yields one string per returned sequence, i.e. a list, not a str.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)

In [61]:
# Sanity check on one file: row 16 is the 17th file in sorted order
# (id17.wav in the listing above), decoded with 5 beams; print every hypothesis.
sample_path = test_df['file_name'].iloc[16]
test1 = transcribe(sample_path, model, processor, best_n=5)
for transcript in test1:
    print(transcript)

RuntimeError: MPS backend out of memory (MPS allocated: 17.79 GB, other allocations: 13.22 MB, max allowed: 18.13 GB). Tried to allocate 514.98 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

### Testing the model across all samples in the test set