In [1]:
import soundfile as sf
from pydub import AudioSegment

from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch

from datetime import timedelta
import math

import sys
# pydub locates the `ffmpeg`/`ffprobe` executables on the executable search path (PATH).
# sys.path is Python's *module* search path and has no effect on that, so put the
# directory containing ffmpeg on PATH instead (idempotent: added at most once).
import os

FFMPEG_DIR = "/usr/local/bin"
if FFMPEG_DIR not in os.environ.get("PATH", "").split(os.pathsep):
    os.environ["PATH"] = os.pathsep.join(filter(None, [os.environ.get("PATH", ""), FFMPEG_DIR]))

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the processor (feature extractor + tokenizer) and the CTC model once,
# and move the model to the target device up front.
MODEL_NAME = "classla/wav2vec2-xls-r-parlaspeech-hr"
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME).to(device)
model.eval()  # make inference mode explicit (disables dropout)

# The audio (a wav file with sample rate 16000) is loaded once, in the configuration
# cell further down, rather than here as well.

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Get a chunk from the audio with the given start time, chunk length and audio end, in milliseconds.
# The chunk is exported to a temporary WAV file.
def get_speech_line(speech: AudioSegment, start: int, end: int, offset: int,
                    out_path: str = "data/cc_tmp.wav") -> None:
    """Export ``speech[start:min(start + offset, end)]`` (milliseconds) as a WAV file.

    Args:
        speech: Full recording to slice.
        start: Chunk start in milliseconds.
        end: End of the recording in milliseconds; the chunk never extends past it.
        offset: Maximum chunk length in milliseconds.
        out_path: Destination WAV file, overwritten on every call. The default is the
            temporary file the notebook's transcription step reads.
    """
    speech_line = speech[start:min(start + offset, end)]
    # pydub's export() returns the open output file object; close it explicitly
    # instead of leaving the handle for the garbage collector.
    speech_line.export(out_f=out_path, format="wav").close()

# Get a transcription for the audio from the temporary file
def get_text_line(model: Wav2Vec2ForCTC, processor: Wav2Vec2Processor,
                  in_path: str = "data/cc_tmp.wav") -> str:
    """Transcribe a WAV file with a CTC model using greedy (argmax) decoding.

    Args:
        model: CTC acoustic model; moved to the module-level ``device`` before use.
        processor: Matching processor used for feature extraction and decoding.
        in_path: WAV file to transcribe; defaults to the temporary chunk file.

    Returns:
        The lower-cased transcription of the single input.
    """
    speech_line, sample_rate = sf.read(in_path)
    if speech_line.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel audio; the processor
        # would treat a 2-D array as a batch, so down-mix to mono first.
        speech_line = speech_line.mean(axis=1)
    input_values = processor(speech_line, sampling_rate=sample_rate, return_tensors="pt").input_values.to(device)
    # Inference only: disable autograd so no computation graph is built for every chunk.
    with torch.no_grad():
        logits = model.to(device)(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0]).lower()

# Format milliseconds to timestamp: 00:00:00.000
def get_timestamp(t: int, sep: str = ".") -> str:
    """Format a time in milliseconds as ``HH:MM:SS<sep>mmm``.

    Args:
        t: Non-negative time in milliseconds. Floats are accepted and rounded
            to the nearest millisecond.
        sep: Separator between seconds and milliseconds. ``"."`` (default) gives
            ``00:00:02.500``; SubRip (SRT) files require ``","``.

    Returns:
        The formatted timestamp; hours are zero-padded to at least two digits.

    Raises:
        ValueError: If ``t`` is negative.
    """
    ms = int(round(t))
    if ms < 0:
        raise ValueError(f"timestamp must be non-negative, got {t!r}")
    seconds, millis = divmod(ms, 1000)
    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}{sep}{millis:03d}"

# Format time start and time end of the current text line to the SRT timestamp standard,
# which uses a comma before the milliseconds: 00:00:00,000 --> 00:00:02,500
def get_time_line(start: int, end: int) -> str:
    """Return the SRT cue timing line ``"<start> --> <end>"`` for two millisecond offsets."""
    return get_timestamp(start, sep=",") + " --> " + get_timestamp(end, sep=",")

In [4]:
# Audio and chunking configuration (all times in milliseconds).
AUDIO_PATH = "data/VRH-121 - 27.05.2022.wav"
TARGET_FRAME_RATE = 16000  # the model requires wav audio with sample rate 16000

speech = AudioSegment.from_wav(AUDIO_PATH)
# Enforce mono 16 kHz input; both calls are no-ops when the file already matches.
speech = speech.set_channels(1).set_frame_rate(TARGET_FRAME_RATE)
start = 0
end = len(speech)  # pydub lengths are integer milliseconds
offset = 2500  # length of each subtitle chunk in milliseconds

In [5]:
# Transcribe the audio chunk by chunk and write the result as SubRip (SRT) cues.
SRT_PATH = "data/VRH-121 - 27.05.2022.srt"
# wav2vec2's convolutional feature encoder cannot process clips shorter than its
# receptive field (~25 ms, 400 samples at 16 kHz); skip a trailing sliver that short.
MIN_CHUNK_MS = 100

i = 1  # SRT cue numbers start at 1
chunk_start = start  # local cursor, so the cell-4 `start` is untouched and this cell is re-runnable
# "w" (not "a"): re-running regenerates the file instead of appending duplicate cues.
# UTF-8 so Croatian diacritics are written correctly whatever the platform default is.
with open(SRT_PATH, "w", encoding="utf-8") as f:
    # `chunk_start < end` (not `chunk_start + offset < end`) so the final partial
    # chunk is transcribed too instead of being silently dropped.
    while chunk_start < end:
        chunk_end = int(min(chunk_start + offset, end))
        if chunk_end - chunk_start < MIN_CHUNK_MS:
            break
        get_speech_line(speech, chunk_start, end, offset)
        text_line = get_text_line(model, processor)
        # Silent chunks decode to an empty string; an SRT cue with no text is malformed.
        if text_line.strip():
            time_line = get_time_line(chunk_start, chunk_end)
            f.write(str(i) + "\n" + time_line + "\n" + text_line + "\n\n")
            i += 1
        chunk_start += offset

KeyboardInterrupt: 