## Transcription of speech segments (diarized single- or multi-speaker audio)

In this script, we use the **'Whisper-timestamped' model** (https://github.com/linto-ai/whisper-timestamped) to transcribe the speech segments of a **multi-speaker audio file**. Whisper-timestamped is an adaptation of OpenAI's Whisper that provides start and end timestamps for each word and each segment (roughly, each utterance). We use these timestamps to write the Whisper output to a **.textgrid** file. *Note*: to transcribe a **single-speaker audio file without previous diarization**, please use the function `transcribe_file_total()` defined in the following section.
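
As a rough illustration of the data this section works with, the sketch below walks over a hand-written dictionary shaped like a transcription result: a list of segments, each holding a list of words with `text`, `start` and `end`. The sample values and the helper name `shift_words` are invented for illustration. The point is only that timestamps are relative to the transcribed clip, so they have to be offset by the clip's start time in the full recording.

```python
# Hand-written stand-in for the dictionary returned by the transcription call.
# The values are invented; only the nesting (segments -> words) matters here.
sample_result = {
    "segments": [
        {
            "text": " bonjour tout",
            "words": [
                {"text": "bonjour", "start": 0.25, "end": 0.5},
                {"text": "tout", "start": 0.5, "end": 0.75},
            ],
        }
    ]
}


def shift_words(result, clip_start):
    """Return (start, end, text) tuples with times relative to the full recording."""
    shifted = []
    for segment in result["segments"]:
        # A segment may come without word-level entries, hence the default.
        for word in segment.get("words", []):
            shifted.append(
                (clip_start + word["start"], clip_start + word["end"], word["text"])
            )
    return shifted


# Pretend the clip was cut from the recording starting at 10 seconds.
words = shift_words(sample_result, clip_start=10.0)
for start, end, text in words:
    print(f"{start:.2f}-{end:.2f}  {text}")
```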

The algorithm takes as input a **diarized .textgrid** file, i.e. a file in which the speech segments of different speakers are indicated on different tiers. Importantly, the speech segments need to be filled with some characters (e.g. 'speech') to distinguish them from the intervals between the speech segments. By default (`speakers_diarized='all'`), the algorithm expects the .textgrid to contain one tier per speaker and no other tiers (for example, annotation tiers). If you have only diarized the speech of **one speaker**, set **`speakers_diarized='AC'`** and name that speaker's tier 'AC' (Autistic Child); the empty intervals of this tier are then transcribed as a second speaker named 'other'. Indicate the spoken language in the `language` parameter, using its English name (e.g. 'French'). The code should work for all languages supported by Whisper. If you don't know which language is spoken, set `language='unknown'`: Whisper will then predict the language from the audio and use the corresponding language model.
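
The distinction between filled and empty intervals can be sketched as follows. This is a minimal stand-alone illustration, not the notebook's code: the `Interval` tuple only mimics the (start, end, label) entries of a tier, and `split_intervals` is a hypothetical helper name.

```python
from collections import namedtuple

# Stand-in for the interval entries of a TextGrid tier (start, end, label).
Interval = namedtuple("Interval", ["start", "end", "label"])


def split_intervals(entries):
    """Separate labelled (speech) intervals from empty ones."""
    speech = [(e.start, e.end) for e in entries if e.label]
    empty = [(e.start, e.end) for e in entries if not e.label]
    return speech, empty


tier_entries = [
    Interval(0.0, 1.2, ""),
    Interval(1.2, 3.4, "speech"),
    Interval(3.4, 4.0, ""),
    Interval(4.0, 6.5, "speech"),
]

speech, empty = split_intervals(tier_entries)

# With speakers_diarized='AC', labelled intervals belong to the 'AC' speaker
# and the empty intervals are treated as everything else.
all_intervals = {"AC": speech, "other": empty}
print(all_intervals)
```

An interval whose label is an empty string counts as "no speech" here, which is why the segments must be filled with some text before running the transcription.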

The output of the model is a **new .textgrid** file, written to the path given in the `output_path` parameter. The model transcribes the speech of every speaker on three different tiers: **an 'utterance' tier, a 'word' tier and a 'warning' tier**. On the 'warning' tier, an interval is created when the utterance contains only one word, or when it contains a word from a list of trigger words that Whisper often produces when it hallucinates text on noise. The list currently covers French and Dutch. If you are using this model for another language, check which words Whisper generates on silent or noisy parts and add them to the `trigger_list`. The name of each tier is the name of the tier in the diarization textgrid, followed by a space and 'utterance', 'word' or 'WARNING', e.g. 'AC utterance', 'AC word' and 'AC WARNING'.
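
The warning rule described above can be sketched in isolation. This is a hedged variant rather than the notebook's exact check: the helper name `needs_warning` is hypothetical, and it matches trigger words case-insensitively as substrings of the joined utterance, which is an assumption made here to also catch words with attached punctuation.

```python
TRIGGER_WORDS = ["Amara.org", "d'Amara.org", "sous-titres", "ondertitels", "ondertiteling"]


def needs_warning(words, trigger_words=TRIGGER_WORDS):
    """Flag an utterance that has a single word or contains a trigger word."""
    if len(words) == 1:
        return True
    # Assumption: substring match on the lower-cased utterance text.
    text = " ".join(words).lower()
    return any(trigger.lower() in text for trigger in trigger_words)


print(needs_warning(["oui"]))
print(needs_warning(["Sous-titres", "par", "la", "communaute", "d'Amara.org."]))
print(needs_warning(["il", "fait", "beau"]))
```

A flagged utterance is not removed from the transcription; the warning only marks a stretch that is worth checking by hand.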

Part of this notebook is based on Oriane Martin's 'Whispgrid': https://github.com/orianemartin/WhispGrid

In [2]:
def transcribe_file_segments(language, selected_model, audio_file, diarization_textgrid, output_path, speakers_diarized='all'):
    
    ## Import libraries
    from pydub import AudioSegment
    from praatio import textgrid
    import whisper_timestamped as whisper
    from whisper.tokenizer import get_tokenizer
    import torch
    import tgt

    
    ## Load audio via 'AudioSegment' library => used to calculate duration of the audio file
    audio = AudioSegment.from_wav(audio_file)
        
    ## Configure Whisper timestamped model
    
    # Choose tokenizer
    tokenizer = get_tokenizer(multilingual=True)
    
    # The variable 'number_tokens' collects the ids of all tokens that consist only of digits.
    # These tokens are suppressed during transcription, so that Whisper spells out numbers
    # in words instead of writing them in digits
    number_tokens = [
        i 
        for i in range(tokenizer.eot)
        if all(c in "0123456789" for c in tokenizer.decode([i]).strip())]
    
    # Load selected Whisper timestamped model
    model = whisper.load_model(selected_model, device="cpu")
    
    whisper_languages= whisper.tokenizer.LANGUAGES
    
    # Get language codes from 'language' attribute input for model configuration
    if language=='French':
        language_code='fr'
    elif language=="Flemish":
        language_code='nl'
    # For other languages:
    elif language.lower() in whisper_languages.values():
        language_code= str(list(whisper_languages.keys())[list(whisper_languages.values()).index(language.lower())])
    # If language unknown: let Whisper detect the language
    elif language== 'unknown':
        w_audio = whisper.load_audio(audio_file)
        w_audio = whisper.pad_or_trim(w_audio)

        # make log-Mel spectrogram and move to the same device as the model
        # (use the loaded model here: 'selected_model' is only the model name)
        mel = whisper.log_mel_spectrogram(w_audio).to(model.device)

        # detect the spoken language
        _, probs = model.detect_language(mel)
        language_code= str(max(probs, key=probs.get))
    # Any other value: stop early instead of failing later on an undefined language code
    else:
        raise ValueError(f"Unsupported language: {language!r}")
    
    # Warning-triggering list
    trigger_list= ["Amara.org", "d'Amara.org","sous-titres","ondertitels","ondertiteling"]
    
    
    ## Transcribe the audio file using Whisper timestamped => based on 'Whispgrid' script
    ## Use configurations defined above
    def transcribe_segment(start_int, end_int):
        audio[start_int*1000: end_int*1000].export('segment.wav', format='wav')
        
        result = whisper.transcribe(
            model,
            'segment.wav',
            #vad= 'auditok', # gives errors and does not seem to improve the alignment
            detect_disfluencies= True, 
            language=language_code,
            beam_size=5,
            best_of=5,
            temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
            trust_whisper_timestamps=False,
            suppress_tokens=[-1] + number_tokens)
        #os.remove('segment.wav')
        
        return result
    
    # Output 'result' has dictionary structure
    
    def get_speech_intervals(tier, empty=False):
        entries= diar_tg.getTier(tier).entries
        intervals= []
        for entry in entries:
            if (entry.label and empty==False) or (not entry.label and empty==True):
                intervals.append((entry.start, entry.end))
        return intervals
    
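    # Get speech intervals
    diar_tg= textgrid.openTextgrid(diarization_textgrid, includeEmptyIntervals= True)
    
    # Dictionary {speaker name: list of (start, end) intervals}, filled in either mode
    all_intervals= {}
    
    if speakers_diarized== 'all':
        all_tiers= diar_tg.tierNames
        for tier in all_tiers:
            all_intervals[tier]= get_speech_intervals(tier)
    
    elif speakers_diarized== 'AC':
        # Labelled intervals = speech of 'AC'; empty intervals = everything else
        all_intervals['AC']= get_speech_intervals('AC')
        all_intervals['other']= get_speech_intervals('AC', empty=True)
    
    else:
        raise ValueError("speakers_diarized must be 'all' or 'AC'")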
        
    
    ## Define functions to check for proper alignment of the time intervals predicted by Whisper
    
    # check_overlap function: checks if the current interval overlaps with a previous interval
    def check_overlap(intervals, time_seconds):
        # Iterate through intervals in the tier
        for start, end in intervals:
            # Check if the time falls within any interval
            if start <= time_seconds < end:
                return True
        # If no interval found:
        return False
    
    # check_alignment function: returns True if the start timestamp differs from the end timestamp,
    # and False if they are equal (zero-length interval)
    # (if Whisper makes up words, the start timestamps are often the same as the end timestamps)
    def check_alignment(start_interval, end_interval):
        return start_interval != end_interval
        
    ## Open a new textgrid for the transcription
    tg = tgt.TextGrid()
    
    ## Iterate over speakers in the diarization textgrid, create output tiers and transcribe their speech
    for speaker in all_intervals:
        
    
        # Initialize interval tiers
        # (tier times are in seconds; len(audio) would give the duration in milliseconds)
        sentence_tier = tgt.IntervalTier(start_time=0, end_time=audio.duration_seconds, name=f"{speaker} utterance")
        word_tier = tgt.IntervalTier(start_time=0, end_time=audio.duration_seconds, name=f"{speaker} word")
        # Intervals will be indicated on the 'warning tier' when it is probable that Whisper's prediction is incorrect
        # (for example, if an utterance only contains one word)
        warning_tier = tgt.IntervalTier(start_time=0, end_time=audio.duration_seconds, name=f"{speaker} WARNING")
        
        # Initialize lists of word and sentence time intervals.
        # If a predicted time interval overlaps an interval that has been predicted for a previous word
        # or sentence, then the current word/ sentence will not be added to the Textgrid
        word_intervals=[]
        sentence_intervals=[]
        
        
        for (start_int, end_int) in all_intervals[speaker]:
            
            result= transcribe_segment(start_int, end_int)
            
        
            ## While checking for proper alignment, write the model's predictions to the Textgrid

            for segment in result["segments"]: # 'segment' attribute of Whisper = +- utterance

                sentence=[] # Initialize sentence text => renews for each whisper segment
                sentence_start_s=-1 # Initialize integer value for start of sentence (will be replaced by start time of first word)


                # Write words of a segment to textgrid
                if "words" in segment and segment["words"]: # Check that 'segment' has non-empty attribute 'words'

                    for word in segment["words"]:
                        
                        word_start= start_int + word['start']
                        word_end= start_int+ word['end']

                        # Check overlap and alignment of word timestamps and add word interval to word tier
                        if check_alignment(word_start, word_end) \
                        and not check_overlap(word_intervals, word_start)\
                        and word['text'] != '[*]': # Do not transcribe '[*]': Whisper's indication of pauses

                            ## tgt.Interval takes as arguments the start and end time of an interval and the text of the speech,
                            ## and creates an interval that can be added to the Praat tier

                            # Create the interval
                            word_interval = tgt.Interval(start_time=float(word_start), 
                                                         end_time=float(word_end), text=word["text"])
                            # Write the interval to the tier
                            word_tier.add_interval(word_interval)
                            # Add start and end times to word_intervals, so that future time intervals can be checked for overlap
                            word_intervals.append((word_start, word_end))

                            # Add the transcribed word to sentence text
                            sentence.append(word['text'])

                            # Define sentence start timestamp
                            if sentence_start_s== -1: # = if the sentence_start has not been modified yet
                                sentence_start_s= word_start # change sentence start to start of first word



                ## Write sentence to Textgrid

                if sentence: # if words have been transcribed for the current segment
                    sentence_end_s= word_intervals[-1][1] # end timestamp of sentence = end timestamp of last word

                    # Check alignment and add sentence interval to sentence tier (same procedure as above)
                    if check_alignment(sentence_start_s, sentence_end_s) \
                    and not check_overlap(sentence_intervals, sentence_start_s):

                        sentence_interval= tgt.Interval(start_time=sentence_start_s, 
                                                        end_time=sentence_end_s, text=' '.join(sentence))
                        sentence_tier.add_interval(sentence_interval)
                        sentence_intervals.append((sentence_start_s, sentence_end_s))

                    # Add an interval on the warning tier if there are signs of nonsense by Whisper
                    if len(sentence)==1 or any(trigger_word in sentence for trigger_word in trigger_list):
                        warning_interval= tgt.Interval(start_time=sentence_start_s, 
                                                        end_time=sentence_end_s, text= 'WARNING')
                        warning_tier.add_interval(warning_interval)


    
        ## When all speech intervals are transcribed, add tiers to output textgrid 
        tg.add_tier(warning_tier)
        tg.add_tier(sentence_tier)
        tg.add_tier(word_tier)

    tgt.write_to_file(tg, output_path, format='short')

## Transcription of entire audio (non-diarized single-speaker audio)

If you have a **single-speaker audio file** that is not yet diarized, you can use the following function to let Whisper **segment and transcribe** the whole file. A diarization file is therefore not required (nor supported) by this function. Its workings and output are the same as those of the previous function; the tiers are named 'utterance', 'word' and 'WARNING', without indication of the speaker.
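
Both functions apply the same checks before a word is written to the word tier. The sketch below restates them on invented sample data. The helper names `keep_word` and `filter_words` are hypothetical, and the word dictionaries only imitate the `text`, `start` and `end` fields of a transcription result.

```python
def keep_word(word, accepted):
    """Decide whether a word should be written to the word tier."""
    if word["text"] == "[*]":          # marker inserted when detect_disfluencies is enabled
        return False
    if word["start"] == word["end"]:   # zero-length interval, often a made-up word
        return False
    # Reject a word whose start time falls inside an already accepted interval.
    return not any(start <= word["start"] < end for start, end in accepted)


def filter_words(words):
    """Return the texts of the words that pass all checks, in order."""
    accepted = []  # (start, end) of the words kept so far
    kept = []
    for word in words:
        if keep_word(word, accepted):
            accepted.append((word["start"], word["end"]))
            kept.append(word["text"])
    return kept


sample_words = [
    {"text": "hello", "start": 0.0, "end": 0.5},
    {"text": "[*]", "start": 0.5, "end": 0.75},
    {"text": "ghost", "start": 1.0, "end": 1.0},
    {"text": "echo", "start": 0.25, "end": 0.75},
    {"text": "world", "start": 0.75, "end": 1.25},
]
print(filter_words(sample_words))
```

Only the start time of a candidate word is compared against earlier intervals, so this is a lightweight guard against duplicated or misplaced words rather than a full overlap test.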

In [None]:
def transcribe_file_total(language, selected_model, audio_file, output_path):
    
    ## Import libraries
    from pydub import AudioSegment
    from praatio import textgrid
    import whisper_timestamped as whisper
    from whisper.tokenizer import get_tokenizer
    import torch
    import tgt
    
    ## Load audio via 'AudioSegment' library => used to calculate duration of the audio file
    audio = AudioSegment.from_wav(audio_file)
        
    ## Configure Whisper timestamped model
    
    # Choose tokenizer
    tokenizer = get_tokenizer(multilingual=True)
    
    # The variable 'number_tokens' collects the ids of all tokens that consist only of digits.
    # These tokens are suppressed during transcription, so that Whisper spells out numbers
    # in words instead of writing them in digits
    number_tokens = [
        i 
        for i in range(tokenizer.eot)
        if all(c in "0123456789" for c in tokenizer.decode([i]).strip())]
    
    # Load selected Whisper timestamped model
    model = whisper.load_model(selected_model, device="cpu")
    
    whisper_languages= whisper.tokenizer.LANGUAGES
    
    # Get language codes from 'language' attribute input for model configuration
    if language=='French':
        language_code='fr'
    elif language=="Flemish":
        language_code='nl'
    # For other languages:
    elif language.lower() in whisper_languages.values():
        language_code= str(list(whisper_languages.keys())[list(whisper_languages.values()).index(language.lower())])
    # If language unknown: let Whisper detect the language
    elif language== 'unknown':
        w_audio = whisper.load_audio(audio_file)
        w_audio = whisper.pad_or_trim(w_audio)

        # make log-Mel spectrogram and move to the same device as the model
        # (use the loaded model here: 'selected_model' is only the model name)
        mel = whisper.log_mel_spectrogram(w_audio).to(model.device)

        # detect the spoken language
        _, probs = model.detect_language(mel)
        language_code= str(max(probs, key=probs.get))
    # Any other value: stop early instead of failing later on an undefined language code
    else:
        raise ValueError(f"Unsupported language: {language!r}")
    
    
    ## Transcribe the audio file using Whisper timestamped => based on 'Whispgrid' script
    ## Use configurations defined above
    result = whisper.transcribe(
        model,
        audio_file,
        #vad= 'auditok', # gives errors and does not seem to improve the alignment
        detect_disfluencies= True, 
        language=language_code,
        beam_size=5,
        best_of=5,
        temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
        trust_whisper_timestamps=False,
        suppress_tokens=[-1] + number_tokens)
    # Output 'result' has dictionary structure
    
    # Warning-triggering list
    trigger_list= ["Amara.org", "d'Amara.org","sous-titres","ondertitels","ondertiteling"]
    
    
    ## Open a new textgrid for the transcription
    tg = tgt.TextGrid()
    
    # Initialize interval tiers
    # (tier times are in seconds; len(audio) would give the duration in milliseconds)
    sentence_tier = tgt.IntervalTier(start_time=0, end_time=audio.duration_seconds, name="utterance")
    word_tier = tgt.IntervalTier(start_time=0, end_time=audio.duration_seconds, name="word")
    # Intervals will be indicated on the 'warning tier' when it is probable that Whisper's prediction is wrong
    # (for example, if an utterance only contains one word)
    warning_tier = tgt.IntervalTier(start_time=0, end_time=audio.duration_seconds, name="WARNING")
        
        
    ## Define functions to check for proper alignment of the time intervals predicted by Whisper
    
    # Initialize lists of word and sentence time intervals.
    # If a predicted time interval overlaps an interval that has been predicted for a previous word
    # or sentence, then the current word/ sentence will not be added to the Textgrid
    word_intervals=[]
    sentence_intervals=[]

    # check_overlap function: checks if the current interval overlaps with a previous interval
    def check_overlap(intervals, time_seconds):
        # Iterate through intervals in the tier
        for start, end in intervals:
            # Check if the time falls within any interval
            if start <= time_seconds < end:
                return True
        # If no interval found:
        return False
    
    # check_alignment function: returns True if the start timestamp differs from the end timestamp,
    # and False if they are equal (zero-length interval)
    # (if Whisper makes up words, the start timestamps are often the same as the end timestamps)
    def check_alignment(start_interval, end_interval):
        return start_interval != end_interval


    ## While checking for proper alignment, write the model's predictions to the Textgrid

    for segment in result["segments"]: # 'segment' attribute of Whisper = +- utterance

        sentence=[] # Initialize sentence text => renews for each whisper segment
        sentence_start_s=-1 # Initialize integer value for start of sentence (will be replaced by start time of first word)


        # Write words of a segment to textgrid
        if "words" in segment and segment["words"]: # Check that 'segment' has non-empty attribute 'words'
            
            for word in segment["words"]:

                # Check overlap and alignment of word timestamps and add word interval to word tier
                if check_alignment(word['start'], word['end']) \
                and not check_overlap(word_intervals,word["start"])\
                and word['text'] != '[*]': # Do not transcribe '[*]': Whisper's indication of pauses
                    
                    ## tgt.Interval takes as arguments the start and end time of an interval and the text of the speech,
                    ## and creates an interval that can be added to the Praat tier
                    
                    # Create the interval
                    word_interval = tgt.Interval(start_time=float(word['start']), 
                                                 end_time=float(word['end']), text=word["text"])
                    # Write the interval to the tier
                    word_tier.add_interval(word_interval)
                    # Add start and end times to word_intervals, so that future time intervals can be checked for overlap
                    word_intervals.append((word['start'], word['end']))

                    # Add the transcribed word to sentence text
                    sentence.append(word['text'])
                    
                    # Define sentence start timestamp
                    if sentence_start_s== -1: # = if the sentence_start has not been modified yet
                        sentence_start_s= word['start'] # change sentence start to start of first word
                        
                    # If the alignment and overlap conditions are not fulfilled: do not add word


        
        ## Write sentence to Textgrid

        if sentence: # if words have been transcribed for the current segment
            sentence_end_s= word_intervals[-1][1] # end timestamp of sentence = end timestamp of last word
            
            # Check alignment and add sentence interval to sentence tier (same procedure as above)
            if check_alignment(sentence_start_s, sentence_end_s) \
            and not check_overlap(sentence_intervals, sentence_start_s):

                sentence_interval= tgt.Interval(start_time=sentence_start_s, 
                                                end_time=sentence_end_s, text=' '.join(sentence))
                sentence_tier.add_interval(sentence_interval)
                sentence_intervals.append((sentence_start_s, sentence_end_s))
                
            # Add an interval on the warning tier if there are signs of nonsense by Whisper
            if len(sentence)==1 or any(trigger_word in sentence for trigger_word in trigger_list): 
                warning_interval= tgt.Interval(start_time=sentence_start_s, 
                                                end_time=sentence_end_s, text= 'WARNING')
                warning_tier.add_interval(warning_interval)


    
    ## When all speech intervals are transcribed, add tiers to output textgrid and output it
    tg.add_tier(warning_tier)
    tg.add_tier(sentence_tier)
    tg.add_tier(word_tier)

    tgt.write_to_file(tg, output_path, format='short')