## Download models

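Load an ASR model (the Sunbird Luganda model when `language` is `'lug'`, an English wav2vec2 model when it is `'en'`) and a generic speech enhancement model.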

In [101]:
language = 'lug' #@param['lug','en']

In [None]:
%%capture
!pip install transformers
!pip install speechbrain
!pip install pyctcdecode
!pip install https://github.com/kpu/kenlm/archive/master.zip
!pip install "rich[jupyter]"

In [102]:
%%capture
from transformers import AutoProcessor, AutoModelForCTC, AutoFeatureExtractor
from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

import librosa
import numpy as np
import torch
import torchaudio
from speechbrain.pretrained import SpectralMaskEnhancement
import glob
import os
import json
import pandas as pd
from IPython.display import Audio, display
import rich.console

if language == 'lug':
  processor = Wav2Vec2ProcessorWithLM.from_pretrained("Sunbird/sunbird-asr")
  model = AutoModelForCTC.from_pretrained("Sunbird/sunbird-asr")
elif language == 'en':
  processor = Wav2Vec2ProcessorWithLM.from_pretrained(
      "jonatasgrosman/wav2vec2-large-xlsr-53-english")
  model = AutoModelForCTC.from_pretrained(
      "jonatasgrosman/wav2vec2-large-xlsr-53-english")

enhancer = SpectralMaskEnhancement.from_hparams(
    source="speechbrain/metricgan-plus-voicebank",
    savedir="pretrained_models/metricgan-plus-voicebank",
)

In [103]:
def transcribe_audio_file(audio_path):
  # Load the recording as mono, resampled to the 16 kHz rate the model expects.
  y, _ = librosa.load(audio_path, sr=16000)
  inputs = processor(
      y, sampling_rate=16000, return_tensors="pt", padding="longest")
  with torch.no_grad():
    logits = model(inputs.input_values).logits
  if isinstance(processor, Wav2Vec2ProcessorWithLM):
    # Beam-search decoding with the n-gram language model (via pyctcdecode).
    transcription = processor.batch_decode(logits.numpy()).text[0]
  else:
    # Greedy CTC decoding: take the most likely token per frame; the
    # tokenizer then merges repeats and removes blank (pad) tokens.
    predicted = torch.argmax(logits, dim=-1)
    transcription = processor.tokenizer.batch_decode(predicted)[0]
  return transcription
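The fallback branch above uses greedy (best-path) CTC decoding: pick the highest-scoring token in every frame, merge consecutive repeats, then drop the blank token. The sketch below shows those steps in plain NumPy on a toy vocabulary. The vocabulary, the use of index 0 as the blank/pad token and `|` as the word delimiter are assumptions for illustration; they follow common wav2vec2 tokenizer conventions but are not read from the Sunbird model.

```python
import numpy as np

# Hypothetical toy vocabulary: index 0 is the CTC blank (pad) token and
# "|" marks a word boundary, as in typical wav2vec2 CTC tokenizers.
VOCAB = ["<pad>", "|", "a", "b", "k", "l", "m", "u"]
BLANK_ID = 0

def greedy_ctc_decode(frame_logits, vocab=VOCAB, blank_id=BLANK_ID):
    """Best-path CTC decoding of a (frames, vocab_size) logit array."""
    best = np.argmax(frame_logits, axis=-1)  # most likely token per frame
    # Merge repeats: keep a frame only if its token differs from the previous frame.
    keep = np.ones_like(best, dtype=bool)
    keep[1:] = best[1:] != best[:-1]
    collapsed = best[keep]
    # Drop blanks, then turn word delimiters into spaces.
    chars = [vocab[i] for i in collapsed if i != blank_id]
    return "".join(chars).replace("|", " ").strip()

def one_hot_logits(ids, vocab_size=len(VOCAB)):
    """Fake logits whose per-frame arg-max is the given token sequence."""
    fake = np.full((len(ids), vocab_size), -5.0)
    fake[np.arange(len(ids)), ids] = 5.0
    return fake

# A blank between the two "u" frames is what lets CTC emit a double letter.
print(greedy_ctc_decode(one_hot_logits([6, 6, 7, 0, 7, 1, 4, 4, 2])))  # muu ka
print(greedy_ctc_decode(one_hot_logits([6, 6, 7, 7, 1, 4, 4, 2])))     # mu ka
```

The language-model branch replaces the per-frame arg-max with a beam search that also scores candidate word sequences with the KenLM model, which is why its output can differ from the greedy result.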

Download a sample SEMA recording and listen to it, with and without speech enhancement (noise suppression).

In [121]:
print('Raw audio recording')
if language == 'lug':
  # audio_file = 'SEMA1-2022-10-19T091146-5.wav'
  audio_file = 'SEMA1-2022-10-19T091054-4.wav'
if language == 'en':
  # audio_file = 'SEMA1-2022-10-19T090905-2.wav'
  audio_file = 'SEMA1-2022-10-19T094827-4.wav'
  # audio_file = 'SEMA1-2022-11-10T090511-1.wav' # loud background noise
if not os.path.exists(audio_file):
  !wget -q https://sema-audio-files.s3.amazonaws.com/audio/SEMA1/{audio_file}
display(Audio(audio_file))

print('With speech enhancement')
enhanced = enhancer.enhance_file(audio_file)
display(Audio(data=enhanced.squeeze(0), rate=16000))
# Add a channel dimension: torchaudio.save expects a (channels, samples) tensor.
torchaudio.save('enhanced-' + audio_file, enhanced[None, :], 16000)

Raw audio recording


With speech enhancement


Transcribe both the raw and the enhanced recordings and compare the results.

In [122]:
transcription_raw = transcribe_audio_file(audio_file)
print(f'Transcription from raw audio:\n{transcription_raw}')

Transcription from raw audio:
eddagala batuwa omusirikale oba eri gye gye tulina gye batalina bagamba nti tulina na nagula ko tusobola bagamba nti taliiwo


In [123]:
transcription_enhanced = transcribe_audio_file('enhanced-' + audio_file)
print(f'Transcription from enhanced audio:\n{transcription_enhanced}')

Transcription from enhanced audio:
ne ddagala batuwa kusigala oba ali kye tulina kye batalina bagaala tetulina yawulako kyekyo oba baagisala
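Without a reference transcript we cannot compute a true word error rate (WER), but the same word-level edit distance can quantify how much enhancement changed the output. A minimal pure-Python sketch; `word_edit_distance` and `word_disagreement_rate` are hypothetical helpers, and treating the raw transcription as the "reference" is an arbitrary choice. The two strings are copied from the outputs above.

```python
def word_edit_distance(ref, hyp):
    """Minimum number of word substitutions, insertions and deletions turning ref into hyp."""
    r, h = ref.split(), hyp.split()
    prev = list(range(len(h) + 1))  # distances from an empty ref prefix
    for i, rw in enumerate(r, start=1):
        curr = [i] + [0] * len(h)
        for j, hw in enumerate(h, start=1):
            curr[j] = min(prev[j] + 1,              # deletion
                          curr[j - 1] + 1,          # insertion
                          prev[j - 1] + (rw != hw))  # substitution or match
        prev = curr
    return prev[-1]

def word_disagreement_rate(a, b):
    """Edit distance normalised by the word count of `a` (the WER formula if `a` were a reference)."""
    return word_edit_distance(a, b) / max(len(a.split()), 1)

raw_text = ("eddagala batuwa omusirikale oba eri gye gye tulina gye batalina bagamba "
            "nti tulina na nagula ko tusobola bagamba nti taliiwo")
enhanced_text = ("ne ddagala batuwa kusigala oba ali kye tulina kye batalina bagaala "
                 "tetulina yawulako kyekyo oba baagisala")
print(word_disagreement_rate(raw_text, enhanced_text))
```

In the notebook the same call works on `transcription_raw` and `transcription_enhanced` directly. With a human transcript of the recording passed as `a`, the function returns the usual WER, which would show whether enhancement actually helps rather than just changes the output.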


## Visualise word-level confidences

We can score the confidence of each word to see which detected words have high or low certainty.

In [124]:
# Get model predictions, including the frame offsets of the detected words
y, _ = librosa.load(audio_file, sr=16000)
inputs = processor(
    y, sampling_rate=16000, return_tensors="pt", padding="longest")
with torch.no_grad():
    logits = model(inputs.input_values).logits
transcription = processor.batch_decode(
    logits.numpy(), output_word_offsets=True)
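The `start_offset`/`end_offset` values returned with `output_word_offsets=True` count logit frames, not seconds. The transformers documentation converts them with `model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate`; for the standard wav2vec2 convolutional front end that ratio is 320 input samples per frame, i.e. 20 ms at 16 kHz. The sketch below assumes those default values, and the example offsets are made up; in the notebook, pass the real values from the loaded model.

```python
def offsets_to_seconds(word_offsets, inputs_to_logits_ratio=320, sampling_rate=16000):
    """Convert CTC frame offsets to (word, start_seconds, end_seconds) tuples."""
    seconds_per_frame = inputs_to_logits_ratio / sampling_rate
    return [
        (w["word"],
         round(w["start_offset"] * seconds_per_frame, 2),
         round(w["end_offset"] * seconds_per_frame, 2))
        for w in word_offsets
    ]

# Made-up offsets in the same dict format as transcription.word_offsets[0].
example_offsets = [{"word": "eddagala", "start_offset": 10, "end_offset": 42},
                   {"word": "batuwa", "start_offset": 50, "end_offset": 71}]
example_times = offsets_to_seconds(example_offsets)
print(example_times)

# In the notebook (assumed usage):
# offsets_to_seconds(transcription.word_offsets[0],
#                    model.config.inputs_to_logits_ratio,
#                    processor.feature_extractor.sampling_rate)
```

Timestamps make it possible to play back just the stretch of audio behind a low-confidence word.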

In [130]:
# Score each word's confidence as the average, over the word's time frames,
# of the highest (unnormalised) logit in each frame.
x = logits.numpy()[0, ...]
x = np.max(x, axis=1)
confidences = [
    np.mean(x[w['start_offset']:w['end_offset']])
    for w in transcription.word_offsets[0]
]
words = [w['word'] for w in transcription.word_offsets[0]]

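# Render colour-coded text to see word confidences: confident words are
# drawn in black and less confident ones fade towards white.
markup = []
c_max = 10  # confidence score mapped to black (most confident)
c_min = 5   # confidence score mapped to white (least confident)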
for w, c in zip(words, confidences):
  score = (c - c_min) / (c_max - c_min)
  score = np.clip(score, 0, 1)
  a = int(255 * (1 - score))
  markup.append(f'[rgb({a},{a},{a})]{w}[/]')

console = rich.console.Console()
console.print(' '.join(markup), width=80)
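The max-logit score depends on the unnormalised scale of the model's logits, which is why the colour scale needs hand-picked `c_min`/`c_max` bounds. One alternative, not what this notebook does, is to apply a softmax first and average the highest class probability over each word's frames, which always lies in [0, 1]. A minimal NumPy sketch; the toy logits and offsets are made up.

```python
import numpy as np

def log_softmax(frame_logits):
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    shifted = frame_logits - frame_logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def word_probabilities(frame_logits, word_offsets):
    """Mean of the per-frame top probability over each word's frames (always in [0, 1])."""
    top_prob = np.exp(log_softmax(frame_logits).max(axis=-1))
    scores = []
    for w in word_offsets:
        start = w["start_offset"]
        end = max(w["end_offset"], start + 1)  # avoid averaging an empty slice
        scores.append(float(top_prob[start:end].mean()))
    return scores

# Made-up logits: 3 sharply peaked frames, then 3 nearly flat ones.
toy_logits = np.array([[10.0, 0, 0, 0]] * 3 + [[1.0, 0, 0, 0]] * 3)
toy_offsets = [{"word": "sure", "start_offset": 0, "end_offset": 3},
               {"word": "unsure", "start_offset": 3, "end_offset": 6}]
toy_conf = word_probabilities(toy_logits, toy_offsets)
print(toy_conf)
```

With the notebook's variables this would be called as `word_probabilities(logits.numpy()[0], transcription.word_offsets[0])`; the resulting scores could replace `confidences`, with `c_min` and `c_max` set to probability bounds instead of logit values.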