In [12]:
import csv
import os
import time

import pandas as pd
import sounddevice as sd
import soundfile as sf
from scipy.io import wavfile

### Record/play functions

In [5]:
# record audio
FS = 44100  # default sample rate (Hz)
REC_S = 3  # default duration of recording (s)

def record(fn='output.wav', duration=REC_S, fs=FS):
    """Record mono audio from the default input device and save it as a WAV file.

    Parameters
    ----------
    fn : str
        Output path of the WAV file.
    duration : float
        Length of the recording in seconds.
    fs : int
        Sample rate in Hz.
    """
    recording = sd.rec(int(duration * fs), samplerate=fs, channels=1)
    # sd.rec() runs in the background (sd.wait() below blocks until it ends),
    # so show a progress bar whose dots are spread evenly over `duration`.
    # The sleep is duration / n_dots so the bar takes exactly `duration`
    # instead of overshooting it.
    n_dots = 10
    print('Recording', end='', flush=True)
    for _ in range(n_dots):
        time.sleep(duration / n_dots)
        print('.', end='', flush=True)

    sd.wait()  # Wait until recording is finished
    print(' Done')
    wavfile.write(fn, fs, recording)  # Save as WAV file
    
def play(fn):
    """Play an audio file through the default output device and block until done.

    Parameters
    ----------
    fn : str
        Path to an audio file readable by soundfile (e.g. a WAV written by ``record``).
    """
    # Extract data and sampling rate from file; using the file's own rate
    # keeps playback speed correct whatever rate the file was written at.
    data, fs = sf.read(fn, dtype='float32')
    sd.play(data, fs)
    sd.wait()  # Wait until file is done playing

### wav2vec

Tested on:
- **transformers==4.4.0.dev** (installed from source with `pip install -e '.[dev]'` in a clone of the [transformers repo](https://github.com/huggingface/transformers) at its then-latest commit)
- **torch==1.7.1**

In [56]:
import transformers
import torch

# report library versions so results can be tied to a specific environment
for lib in (torch, transformers):
    print(lib.__version__)

1.7.1
4.4.0.dev0


In [57]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer

MODEL_NAME = "facebook/wav2vec2-base-960h"

# load pretrained model and its tokenizer.
# Wav2Vec2ForCTC is the speech-recognition head this checkpoint is published
# with; Wav2Vec2ForMaskedLM is deprecated in transformers in its favour and
# only emits a deprecation warning on top of the same `.logits` output.
tokenizer = Wav2Vec2Tokenizer.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)

We'll test inference on short clips (1–10 s) from the LJ Speech dataset (download it [here](https://keithito.com/LJ-Speech-Dataset/); see also the [TensorFlow Datasets LJSpeech entry](https://www.tensorflow.org/datasets/catalog/ljspeech)).

Alternatively, we can test on a pre-recorded file or on our own recording made with `record()`. Note that the `wav2vec2` model expects **16 kHz, single-channel** audio.

#### Get sample from LJ Speech

In [63]:
LJ_DIR = 'LJSpeech-1.1'
if not os.path.exists(LJ_DIR):
    # FileNotFoundError is still an Exception subclass, so existing handlers keep working.
    raise FileNotFoundError('Download LJ Speech first.')

# metadata.csv is '|'-separated with no header row (column names given below).
# Quote handling is disabled: with pandas' default quoting, a transcript field
# that starts with a '"' character would be parsed as a quoted field and could
# swallow the following separators/rows.
ds = pd.read_csv(os.path.join(LJ_DIR, 'metadata.csv'), sep='|',
                 names=['id', 'text', 'text_normalized'], quoting=csv.QUOTE_NONE)

MAX_LEN = 75  # keep only utterances whose normalized transcript is shorter than this (characters)
SAMPLE_SEED = None  # set to an int to make the random pick reproducible
ds_short = ds[ds.text_normalized.str.len() < MAX_LEN]
assert not ds_short.empty, f'no transcripts shorter than {MAX_LEN} characters'

samp = ds_short.sample(random_state=SAMPLE_SEED).iloc[0]
print(samp.id, samp.text_normalized)

LJ019-0328 Stringent rules were prescribed for the prison surgeons;


In [64]:
# LJ Speech is not 16kHz, so we re-sample.
# resample_poly (polyphase FIR filtering) is used rather than the FFT-based
# resample(): it does not treat the clip as periodic (no wrap-around edge
# artifacts), and its speed does not depend on the clip length being
# FFT-friendly. It reduces the target_fs/orig_fs ratio by its gcd internally.
from scipy.signal import resample_poly

target_fs = 16000  # wav2vec2 expects 16 kHz, single-channel input
test_file = 'output.wav'

raw_file = os.path.join(LJ_DIR, 'wavs', samp.id + '.wav')
x, orig_fs = sf.read(raw_file)
x = resample_poly(x, target_fs, orig_fs)
# sf.read returns float64, which wavfile.write would store as a 64-bit float
# WAV; write 32-bit float samples instead, a more widely supported format.
wavfile.write(test_file, target_fs, x.astype('float32'))  # Save as WAV file

play(test_file)

#### Record your own audio / use a pre-recorded file

In [41]:
# Alternative input: record 3 s from the microphone (or, with the flag left
# False, point `test_file` at a pre-recorded 16 kHz mono WAV instead).
# Disabled by default so "Restart & Run All" neither blocks on the microphone
# nor overwrites the LJ Speech sample written to 'output.wav' above.
RECORD_OWN_AUDIO = False
if RECORD_OWN_AUDIO:
    test_file = 'output.wav'
    record(fn=test_file, fs=16000, duration=3)  # wav2vec2 expects 16 kHz
    play(test_file)

#### Inference

In [65]:
# load audio; wav2vec2 expects 16 kHz, single-channel input, so check it
# rather than silently transcribing audio at the wrong rate.
audio_input, input_fs = sf.read(test_file)
assert input_fs == 16000, f'expected 16 kHz audio, got {input_fs} Hz'
assert audio_input.ndim == 1, f'expected single-channel audio, got shape {audio_input.shape}'

# transcribe; inference only, so skip building the autograd graph
input_values = tokenizer(audio_input, return_tensors="pt").input_values
with torch.no_grad():
    logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)  # greedy: most likely token per frame
transcription = tokenizer.batch_decode(predicted_ids)[0]
print(transcription)

STRINGENT RULES WERE PRESCRIBED FOR THE PRISON SURGEONS
