In [None]:
import logging
logging.getLogger("speechbrain").setLevel(logging.WARNING)

import torch
from torchaudio.io import StreamReader
from asr import load_asr_model, transcribe
from decoder import build_decoder

In [109]:
import numpy as np
from IPython.display import Audio

In [1052]:
CHUNK_SIZE = 8
CHUNK_SAMPLES = 639
SAMPLE_FILE = "https://upload.wikimedia.org/wikipedia/commons/transcoded/9/97/Spoken_Wikipedia_-_One_Times_Square.ogg/Spoken_Wikipedia_-_One_Times_Square.ogg.mp3"
DECODE_BUFFER_LEN = 64

asr_model, context = load_asr_model(CHUNK_SIZE, 2)
decoder = build_decoder()
tokenizer = asr_model.hparams.tokenizer

sr = asr_model.hparams.sample_rate
frames_per_chunk = CHUNK_SIZE * CHUNK_SAMPLES

INFO:pyctcdecode.alphabet:Alphabet determined to be of BPE style.


In [1053]:
streamer = StreamReader(src=SAMPLE_FILE)
streamer.add_basic_audio_stream(
    frames_per_chunk=frames_per_chunk, sample_rate=sr, num_channels=1
)
stream = streamer.stream(timeout=-1)

logits_buffer = np.empty((0, 1000))
decoded_beams = []

global_frame = 0
saved_audio_chunks = torch.tensor([])

In [1054]:
def greedy_decode(logits, tokenizer, blank_id=0, decode_frame=0):
    preds = logits.argmax(-1)
    pred_frames = np.where(preds != blank_id)[-1]
    filtered_preds = preds[pred_frames]
    
    tokens = [tokenizer.decode(pred) for pred in filtered_preds.tolist()]
    
    return  tokens, filtered_preds, pred_frames + decode_frame

In [1055]:
output_text = []
output_frames_start = np.array([])
output_frames_end = np.array([])
decoding_start = 0
current_frame = 0

correction_start_idx = 0
correction_end_idx = 0

In [1056]:
# #####
# LOOP
# #####

In [1057]:
# get chunk
(chunk,) = next(stream)
chunk = chunk.squeeze(-1).unsqueeze(0).float()

# save the tail of the audio
saved_audio_chunks = torch.cat((saved_audio_chunks, chunk), -1)
saved_audio_chunks = saved_audio_chunks[:, -min(saved_audio_chunks.size(-1), DECODE_BUFFER_LEN * CHUNK_SAMPLES):]

# run inference
with torch.no_grad():
    logits, words = transcribe(asr_model, context, chunk)

# update buffer
logits_buffer = np.concatenate((logits_buffer, logits[0].numpy()))
logits_buffer = logits_buffer[-min(logits_buffer.shape[1], DECODE_BUFFER_LEN):, :]

# decode new tokens
tokens, _, frames_nums = greedy_decode(logits[0], tokenizer, blank_id=0, decode_frame=current_frame)

# update output
output_text += tokens
output_frames_start = np.concatenate((output_frames_start, frames_nums))
output_frames_end = np.concatenate((output_frames_end, frames_nums+1))

# update frame tracker
current_frame += logits.size(1)

In [1058]:
print(f"Output: {output_text}")
print(f"Starts: {output_frames_start}")
print(f"Ends:   {output_frames_end}")
print()
print(f"Current frame: {current_frame}") 
print(f"Decoding start: {decoding_start}")

Output: []
Starts: []
Ends:   []

Current frame: 8
Decoding start: 0


In [1059]:
_, _, decoded_text_frames, *_ = decoder.decode_beams(logits_buffer)[0]
decoded_text, decoded_frames = zip(*decoded_text_frames)
decoded_start_frames, decoded_end_frames = zip(*decoded_frames)

decoded_start_frames = np.array(decoded_start_frames, np.int64) + decoding_start
decoded_end_frames = np.array(decoded_end_frames, np.int64) + decoding_start

print(f"Decoded output: {decoded_text}")
print(f"Decoded starts: {decoded_start_frames}")
print(f"Decoded end:    {decoded_end_frames}")

ValueError: not enough values to unpack (expected 2, got 0)

In [1060]:
decoding_start += 8

correction_start_idx += next((i for i, start in enumerate(output_frames_start[correction_start_idx:]) if start >= decoding_start))
correction_end_idx += next((i for i, end in enumerate(output_frames_end[correction_end_idx:]) if end > decoding_start + DECODE_BUFFER_LEN), len(output_frames_end)-correction_end_idx)

correction_start_idx, correction_end_idx

StopIteration: 

In [1061]:
output_text = output_text[:correction_start_idx] + list(decoded_text) + output_text[correction_end_idx:]
output_frames_start = np.concatenate((output_frames_start[:correction_start_idx], decoded_start_frames, output_frames_start[correction_end_idx:]))
output_frames_end = np.concatenate((output_frames_end[:correction_start_idx], decoded_end_frames, output_frames_end[correction_end_idx:]))

print(f"Updated output: {output_text}")
print(f"Updated starts: {output_frames_start}")
print(f"Updated ends:   {output_frames_end}")

Updated output: ['ENCYCLOPAEDIA']
Updated starts: [124.]
Updated ends:   [141.]


In [1063]:
correction_end_idx += (len(decoded_text) - (correction_end_idx-correction_start_idx))
correction_start_idx, correction_end_idx

(0, 1)

In [None]:
# END TEST HERE

In [813]:
# run beam search
if logits_buffer.shape[0] >= DECODE_BUFFER_LEN:
    decoded_text, _, decoded_text_frames, *_ = decoder.decode_beams(logits_buffer)[0]

    # only update what's within decoding window
    correction_start_idx = next(
    (i + correction_start_idx for i, start in enumerate(output_frames_start[correction_start_idx:])
     if start >= decoding_start),
    correction_start_idx
    )
    correction_end_idx = next(
        (i + correction_end_idx for i, end in enumerate(output_frames_end[correction_end_idx:])
         if end > decoding_start + DECODE_BUFFER_LEN),
        len(output_frames_end)
    )

print(correction_start_idx)
print(correction_end_idx)

2
4


In [814]:
if logits_buffer.shape[0] >= DECODE_BUFFER_LEN:

    decoded_text, decoded_frames = zip(*decoded_text_frames)
    
    # compare 
    print("Before overwriting:")
    print(f"Relevant output: {output_text[correction_start_idx:correction_end_idx]}")
    print(f"Start times:     {output_frames_start}\n")
    
    print(f"Corrected output: {decoded_text}")
    print(f"Start/end times:  {decoded_frames}")

Before overwriting:
Relevant output: ['SQUAR', 'E']
Start times:     [ 8. 16. 26. 49.]

Corrected output: ('ONE', 'TIME', 'SQUARE')
Start/end times:  ((0, 1), (8, 9), (18, 26))


In [815]:
if logits_buffer.shape[0] >= DECODE_BUFFER_LEN:
    # update outputs
    output_text = output_text[:correction_start_idx] + list(decoded_text) + output_text[correction_end_idx:]
    
    decoded_starts, decoded_ends = zip(*decoded_frames)
    output_frames_start = np.concatenate((output_frames_start[:correction_start_idx], list(decoded_starts), output_frames_start[correction_end_idx:]))
    output_frames_end = np.concatenate((output_frames_end[:correction_start_idx], list(decoded_ends), output_frames_end[correction_end_idx:]))

    print(f"New output: {output_text}")
    print(f"output start/end: {output_frames_start, output_frames_end}")
    
    # update correction end pointer to account for resizing
    correction_end_idx += len(decoded_text) - (correction_end_idx-correction_start_idx)
    
    # update decoding window
    decoding_start += logits.size(1)

# show beams
# print(decoded_text_frames)

# play audio that's just been rescored
Audio(saved_audio_chunks, rate=sr)

New output: ['ONE', 'TIME', 'ONE', 'TIME', 'SQUARE']
output start/end: (array([ 8., 16.,  0.,  8., 18.]), array([ 9., 17.,  1.,  9., 26.]))
