In [1]:
# Identify ASR transcriptions that are similar to hot-words

In [2]:
# Imports

import csv

import torch
import tqdm
from sentence_transformers import SentenceTransformer # InstructorEmbedding.INSTRUCTOR does not work, therefore fall back to the generic sentence_transformers.SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

In [3]:
# Hyper-parameters

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (more slowly) on machines where CUDA is not available.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# Define method to read reference CSV file

def read_ref_csv(path):
    """Read reference word sequences from a CSV file.

    Only rows whose first column contains ".mp3" are kept; every other row
    (e.g. a header line or a blank line) is skipped. The first column is used
    as the utterance name and the second column as its text.

    Args:
        path: Path to a UTF-8 encoded CSV file.

    Returns:
        dict mapping utterance name -> upper-cased text.

    Raises:
        AssertionError: if the same utterance name appears more than once.
        IndexError: if a kept row has no second column.
    """
    ref = dict()
    with open(path, "r", encoding="utf-8") as file:
        # Trailing whitespace is stripped per line and the csv module does the
        # splitting, so quoted fields that contain commas are kept in one piece
        # instead of being cut at the first comma.
        for row in csv.reader(line.rstrip() for line in file):
            # csv.reader yields an empty list for blank lines.
            if not row or ".mp3" not in row[0]:
                continue
            utt_name = row[0]
            assert utt_name not in ref, "duplicate utterance name: {}".format(utt_name)
            ref[utt_name] = row[1].upper() # wav2vec2-large-960h only supports upper case
    return ref

In [5]:
# Load text embedding model
# The 'hkunlp/instructor-large' checkpoint is loaded through the generic
# SentenceTransformer class.
model = SentenceTransformer('hkunlp/instructor-large')
# Move the model to the configured device. Being the last expression of the
# cell, the returned model is also displayed as the cell output.
model.to(device)

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: T5EncoderModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': False})
  (2): Dense({'in_features': 1024, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (3): Normalize()
)

In [6]:
# Load transcription
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
transcriptions_filename = "/home/jeremy/htx_test1234/asr-train/cv-valid-dev.csv"
# Dict of utterance name -> upper-cased text, as returned by read_ref_csv.
transcriptions = read_ref_csv(transcriptions_filename)

In [7]:
# Compute embeddings of hot words
prompt = "Represent the meaning of the following sentence:"
hot_words = ["be careful", "destroy", "stranger"]
# Each hot word is upper-cased before encoding so that its casing matches the
# transcriptions; the dict is keyed by the original (lower-case) hot word.
# NOTE(review): the prompt is passed as the first element of a [prompt, text]
# pair; confirm this is the intended way to supply the instruction to this model.
hot_words_embeddings = {
    phrase: model.encode([[prompt, phrase.upper()]])
    for phrase in hot_words
}

In [8]:
# Compute embeddings of the transcriptions
# One encode call per utterance, visited in sorted name order with a progress
# bar; the result maps utterance name -> embedding.
transcriptions_embeddings = {
    name: model.encode([[prompt, transcriptions[name]]])
    for name in tqdm.tqdm(sorted(transcriptions))
}

100%|███████████████████████████████████████████████████████████████████████████████████| 4076/4076 [00:49<00:00, 81.66it/s]


In [9]:
# Find utterances whose embeddings are within a cosine similarity threshold of the hot words
threshold = 0.95 # If given more time, this threshold should be tuned on the training set.
utterances_with_hotwords = set()
# No torch.no_grad() context is needed: cosine_similarity comes from scikit-learn
# and the embeddings have already been computed.
for utt_name in tqdm.tqdm(sorted(transcriptions_embeddings)):
    utt_embedding = transcriptions_embeddings[utt_name]
    for hot_word_embedding in hot_words_embeddings.values():
        # Each stored embedding holds a single row, so the pairwise result is a
        # 1x1 array; take its only entry to compare a plain scalar.
        similarity = cosine_similarity(hot_word_embedding, utt_embedding)[0, 0]
        if similarity >= threshold:
            utterances_with_hotwords.add(utt_name)
            # One matching hot word is enough; stop so the utterance is not
            # examined (or counted) again for the remaining hot words.
            break
# Count distinct utterances so the number agrees with the message below.
num_detected = len(utterances_with_hotwords)
print("{} out of {} utterances found to be similar to hot-words".format(num_detected, len(transcriptions)))

100%|█████████████████████████████████████████████████████████████████████████████████| 4076/4076 [00:02<00:00, 1552.98it/s]

22 out of 4076 utterances found to be similar to hot-words





In [10]:
# Write detected utterances to file
# One CSV row per utterance (sorted by name): name, transcription text, and
# "true"/"false" depending on whether the utterance was flagged.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
output_filename = "/home/jeremy/htx_test1234/hotword-detection/cv-valid-dev.csv"
with open(output_filename, "w", encoding="utf-8") as file:
    # NOTE(review): the header spells "utternace_name"; kept unchanged because
    # downstream readers of this file may rely on the existing column name.
    print("utternace_name,generated_text,similarity", file=file)
    for name in sorted(transcriptions):
        detected = name in utterances_with_hotwords
        row = (name, transcriptions[name], "true" if detected else "false")
        print("{},{},{}".format(*row), file=file)