In [None]:
import argparse
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import warnings
warnings.filterwarnings("ignore")
import json
from tqdm import tqdm

import librosa
import numpy as np

import torch
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import numpy as np
import pprint
import random


def find_audios(parent_dir, exts=['.wav', '.mp3', '.flac', '.webm', '.mp4', '.m4a']):
    """Recursively collect the paths of audio files below ``parent_dir``.

    Args:
        parent_dir: Directory to walk; sub-directories are included.
        exts: Accepted file extensions, each with its leading dot. Matching is
            case-insensitive, so ``song.WAV`` is accepted for ``'.wav'``.
            A single string is treated as one extension.

    Returns:
        list[str]: ``os.path.join(root, file)`` for every matching file, in
        the order ``os.walk`` yields them.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(exts, str):
        exts = [exts]
    # Normalise once; the default list is only read, never mutated.
    wanted = {ext.lower() for ext in exts}

    audio_files = []
    for root, _dirs, files in os.walk(parent_dir):
        for file in files:
            if os.path.splitext(file)[1].lower() in wanted:
                audio_files.append(os.path.join(root, file))
    return audio_files


#################### Whisper ####################
def remove_possible_overlaps(transcriptions, maximum_overlapping_tokens = 5):
    """Drop words repeated at the seam between consecutive segments.

    For every segment after the first, the tail of the previously kept
    segment is compared with the head of the current one, trying windows of
    1, 2, ... ``maximum_overlapping_tokens`` space-separated words. The first
    window size at which both sides are equal is removed from the start of
    the current segment. A single trailing ``'.'`` on the previous segment's
    last word is ignored for the comparison only.

    Args:
        transcriptions: List of dicts, each with a ``'text'`` string.
        maximum_overlapping_tokens: Largest window (in words) to try. This
            value was previously ignored in favour of a hard-coded 8.

    Returns:
        list[dict]: The same dict objects, in order. ``'text'`` of every
        segment after the first is rewritten in place.
    """
    cleaned_transcriptions = []

    for i, current in enumerate(transcriptions):
        if i == 0:
            # Nothing precedes the first segment, so keep it untouched.
            cleaned_transcriptions.append(current)
            continue

        # str.split(' ') always yields at least one item, so [-1] is safe.
        previous_text = cleaned_transcriptions[-1]['text'].split(' ')
        if previous_text[-1].endswith('.'):
            # Compare without the sentence-final period; the stored text of
            # the previous segment is left as it is.
            previous_text[-1] = previous_text[-1][:-1]

        current_text = current['text'].split(' ')

        for j in range(maximum_overlapping_tokens):
            # Slices clamp at the list bounds, so short texts are compared whole.
            if previous_text[-j-1:] == current_text[:j+1]:
                current_text = current_text[j+1:]
                break

        current['text'] = ' '.join(current_text)
        cleaned_transcriptions.append(current)

    return cleaned_transcriptions


def chunk_audio(audio_path, sample_rate=16000, segment_length=30, overlap = 1.):
    """Load an audio file and cut it into fixed-length, overlapping chunks.

    The file is read with ``librosa.load(audio_path, sr=sample_rate,
    mono=True)``. Consecutive chunks start ``segment_length - overlap``
    seconds apart, so each chunk shares ``overlap`` seconds with the next.

    Args:
        audio_path: Path of the audio file to load.
        sample_rate: Sampling rate passed to ``librosa.load``.
        segment_length: Chunk length in seconds (default 30).
        overlap: Overlap between neighbouring chunks in seconds (default 1).

    Returns:
        tuple: ``(chunks, sample_rate, chunk_timestamps)``. ``chunks`` is a
        list of slices of the loaded signal; the last one may be shorter than
        ``segment_length``. ``chunk_timestamps`` is a list of
        ``(start_sec, end_sec)`` tuples, with the end clamped to the length
        of the loaded audio.

    Raises:
        ValueError: If ``overlap`` is not smaller than ``segment_length``.
    """
    audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)

    # Work in whole samples so that slicing also works for float durations.
    num_samples = int(segment_length * sample_rate)
    overlap_samples = int(overlap * sample_rate)
    hop = num_samples - overlap_samples
    if hop <= 0:
        raise ValueError("overlap must be smaller than segment_length")

    total = len(audio)
    # Stopping `overlap_samples` early avoids a trailing chunk that holds only
    # audio already covered by the previous one. A non-empty clip shorter than
    # the overlap must still produce its first (and only) chunk.
    stop = max(total - overlap_samples, min(total, 1))
    starts = range(0, stop, hop)

    chunks = [audio[start:start + num_samples] for start in starts]

    # Clamp the end so the last chunk does not claim time past the audio end.
    chunk_timestamps = [
        (start / sample_rate, min(start + num_samples, total) / float(sample_rate))
        for start in starts
    ]

    return chunks, sample_rate, chunk_timestamps

def transcribe_chunk(chunk, processor, model, no_speech_threshold=0.6, logprob_threshold=-1.0, temperature = (0.4, 0.7)):
    """Transcribe one audio chunk and return its decoded text.

    Args:
        chunk: Audio samples for one chunk; passed to ``processor`` with
            ``sampling_rate=16000``.
        processor: Callable feature extractor that also provides
            ``batch_decode`` (a ``WhisperProcessor`` in this notebook).
        model: Model exposing ``generate`` (a
            ``WhisperForConditionalGeneration`` in this notebook).
        no_speech_threshold: Forwarded unchanged to ``model.generate``.
        logprob_threshold: Forwarded unchanged to ``model.generate``.
        temperature: Either a ``(low, high)`` pair, from which one value is
            drawn uniformly at random per call, or a single number that is
            used as given.

    Returns:
        dict: ``{"text": <decoded transcription of the first sequence>}``.
    """
    features = processor(chunk, sampling_rate=16000, return_tensors="pt").input_features

    # Put the inputs where the model lives instead of assuming a GPU; fall back
    # to the previous hard-coded target for objects without a `device`.
    device = getattr(model, "device", None)
    input_features = features.to(device if device is not None else "cuda")

    # One temperature per call; the draw is unseeded, so repeated runs may
    # pick different values.
    if isinstance(temperature, (tuple, list)):
        sampled_temperature = random.uniform(temperature[0], temperature[1])
    else:
        sampled_temperature = temperature

    # Inference only: no gradients are needed.
    with torch.no_grad():
        predicted_ids = model.generate(
            input_features,
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=400,
            no_speech_threshold=no_speech_threshold,
            logprob_threshold=logprob_threshold,
            temperature=sampled_temperature
        )

    # Only the text of the first (and only) sequence is kept; the scores
    # requested above are not returned to the caller.
    decoded_text = processor.batch_decode(predicted_ids.sequences, skip_special_tokens=True)[0]
    return {
        "text": decoded_text
    }


def transcribe_full(audio_path, processor, model, temperature = (0.4, 0.7), segment_length=30, no_speech_threshold=0.6, logprob_threshold=-1.0):
    """Transcribe a whole audio file chunk by chunk.

    The file is split with ``chunk_audio``, each chunk is passed to
    ``transcribe_chunk``, the chunk's ``[start, end]`` time is attached under
    ``'timestamp'``, and the collected segments are finally run through
    ``remove_possible_overlaps``.

    Returns:
        The list returned by ``remove_possible_overlaps``: one dict per chunk
        with ``'text'`` and ``'timestamp'`` keys.
    """
    pieces, _rate, spans = chunk_audio(audio_path, segment_length=segment_length)

    segments = []
    # chunk_audio returns one (start, end) pair per chunk, so walk them together.
    for piece, (start_sec, end_sec) in zip(pieces, spans):
        result = transcribe_chunk(piece, processor, model, no_speech_threshold, logprob_threshold, temperature)
        result['timestamp'] = [start_sec, end_sec]
        segments.append(result)

    return remove_possible_overlaps(segments)

def transcribe_and_save(whisper_model, processor, args):
    """Transcribe every audio file under ``args.input_dir`` and save JSON results.

    Each result is written to ``args.output_dir`` at the same relative path
    as the audio file, with ``.json`` appended to the original file name.
    Files whose JSON already exists are skipped, so an interrupted run can be
    resumed. A failure on one file is printed and the loop moves on to the
    next file.

    Args:
        whisper_model: Model forwarded to ``transcribe_full``.
        processor: Processor forwarded to ``transcribe_full``.
        args: Object providing ``input_dir``, ``output_dir``, ``temperature``,
            ``no_speech_threshold`` and ``logprob_threshold``.
    """
    audio_files = find_audios(args.input_dir)

    for file in tqdm(audio_files):
        output_file = os.path.join(args.output_dir, os.path.relpath(file, args.input_dir))
        json_path = output_file + '.json'
        os.makedirs(os.path.dirname(output_file), exist_ok=True)

        if os.path.exists(json_path):
            continue

        # Write to a temp name first: a crash mid-write must not leave a
        # truncated .json behind that the exists-check above would treat as done.
        tmp_path = json_path + '.tmp'
        try:
            results = transcribe_full(
                    audio_path = file,
                    processor = processor,
                    model = whisper_model,
                    temperature=args.temperature,
                    no_speech_threshold=args.no_speech_threshold,
                    logprob_threshold=args.logprob_threshold,
                )

            # ensure_ascii=False emits raw non-ASCII text, so the encoding must
            # be fixed rather than left to the platform default.
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=4, ensure_ascii=False)
            os.replace(tmp_path, json_path)

        except Exception as e:
            # Deliberate best-effort: report the file and keep going.
            print("ERROR: ", e)
            print("Please check: ", file)
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            continue


In [2]:
class args:
    # Run configuration. The class itself is passed around as a plain
    # attribute namespace; it is never instantiated in this notebook.

    # --- Paths read by transcribe_and_save ---
    input_dir = '/home/anh/Documents/vietnamese-song-scraping/out/validation-audio-100-demuc'
    output_dir = '/home/anh/Documents/vietnamese-song-scraping/out/PhoWhisper-small/validation-audio-100-demuc_nospeech-noremove'

    # --- Decoding options read by transcribe_and_save ---
    temperature = (0.4, 0.8)      # (low, high) range
    no_speech_threshold = None    # 0.6
    logprob_threshold = None      # -1.0

    # --- Model description ---
    # NOTE(review): `prompt` and `language` are not read by any code visible
    # in this notebook.
    model = "vinai/PhoWhisper-small"
    prompt = 'lời nhạc: '
    language = "vi"

    # --- Remaining settings ---
    # NOTE(review): none of these is read by any code visible in this notebook.
    n_shard = 1
    shard_rank = 0
    threshold = 0
    debug = False
    top_n_sample = 2

*Added PhoWhisper: the next cell loads the PhoWhisper processor and model used for transcription.*

In [3]:
# Take the checkpoint id from the run configuration so the model that is
# loaded cannot drift from the one recorded in `args`.
model_name = args.model
processor = WhisperProcessor.from_pretrained(model_name)
# The model is placed on the GPU; loading fails here if CUDA is unavailable.
model = WhisperForConditionalGeneration.from_pretrained(model_name).to("cuda")


In [4]:
# Run the batch job: transcribe everything under args.input_dir and write the
# JSON results under args.output_dir.
transcribe_and_save(whisper_model=model, processor=processor, args=args)

  0%|          | 0/100 [00:00<?, ?it/s]Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to