# Transcription-based model for echolalia detection

The following function, `transcription_model_echolalia`, takes as input a transcription of a multi-speaker audio recording in **.TextGrid** format (expected to be UTF-16 encoded). The speech of each speaker should be transcribed in a separate tier, and the tier of the child under study should be named **'AC'** (Autistic Child). Please remove from the input file any tiers that contain content other than transcriptions. For every utterance of the child, the function considers each utterance of another speaker that starts at most 10 seconds before it. It returns a **dataframe** containing the prediction ('echolalic' or 'non-echolalic') for each such utterance pair, along with additional information such as the timestamps of the utterances and the speaker of the source utterance. Moreover, a **new .TextGrid** is produced from the input .TextGrid. It contains annotations of the echolalic utterance pairs in two new tiers: 'source' and 'echolalia'. The path to this new file must be specified in the parameter `output_path`. The other parameters are:
- the language (supported: 'nl' (Dutch) and 'fr' (French)). The corresponding spaCy model `<language>_core_news_sm` must be installed. For other languages, the user needs to specify a new list of words to be disregarded in `to_delete`.
- the threshold, i.e. the minimum number of shared content words required for an utterance pair to be labelled echolalic (default: 1).
- `allow_overlap`, *True* by default. It specifies whether the model may consider, as source utterances, utterances that begin before the start of the AC's utterance but continue while the AC is already speaking.

In [17]:
def transcription_model_echolalia(textgrid_file, output_path, language='nl', threshold=1,
                                  allow_overlap=True, to_delete=None):
    """Predict echolalia for utterance pairs in a multi-speaker TextGrid transcription.

    Every intelligible utterance of the child tier 'AC' is paired with every
    intelligible utterance of each other tier that starts at most 10 seconds
    before it. A pair is labelled 'echolalic' when both utterances share at
    least `threshold` content-word lemmas (identical lemma AND part of speech).
    Echolalic pairs are additionally annotated in a copy of the input TextGrid.

    Parameters
    ----------
    textgrid_file : str
        Path to the input TextGrid. One tier per speaker; the child's tier must
        be named 'AC'. The annotated copy is read (via tgt) as UTF-16.
    output_path : str
        Destination of the annotated TextGrid, which gains two tiers:
        'source' and 'echolalia'. Written in short TextGrid format.
    language : str
        Language code. Selects the spaCy model '<language>_core_news_sm' and,
        for 'nl' and 'fr', the built-in list of fillers to disregard.
    threshold : int
        Minimum number of shared content lemmas for an 'echolalic' label.
    allow_overlap : bool
        If False, only other-speaker utterances that have ended by the time the
        child's utterance starts are considered as candidate sources.
    to_delete : list of str, optional
        Words (fillers, etc.) removed before lemmatisation. Required when
        `language` is neither 'nl' nor 'fr'; overrides the built-in list
        otherwise.

    Returns
    -------
    pandas.DataFrame
        One row per considered utterance pair, with columns 'file', 's2_tier',
        'AC_int', 's2_int', 'AC_trans', 's2_trans', 'child_lemmas',
        's2_lemmas', 'predicted_n_lemmas' (number of shared lemmas for
        echolalic pairs, NaN otherwise) and 'predicted_label'.

    Raises
    ------
    ValueError
        If `language` has no built-in filler list and `to_delete` is None.
    """
    # Import libraries (kept local so the notebook cell is self-contained)
    import re

    import numpy as np
    import pandas as pd

    # Built-in filler / disfluency lists per supported language.
    default_fillers = {
        'fr': ['euhm', 'euh', 'uhm', 'mmh', 'xxx', 'eh', 'ben', 'hein', 'ah', 'bah', 'oh', 'bon'],
        'nl': ['euhm', 'euh', 'uhm', 'mmh', 'xxx', 'he', 'hè', 'hé', 'ah', 'oh'],
    }
    # Validate before loading heavy libraries / models so misuse fails fast.
    if to_delete is None:
        if language not in default_fillers:
            raise ValueError(
                f"No built-in filler list for language {language!r}; "
                "pass the words to disregard via `to_delete`."
            )
        to_delete = default_fillers[language]
    fillers = set(to_delete)

    import spacy
    import tgt
    from praatio import textgrid

    # Regex that identifies unintelligible utterances; adapt if necessary
    unintelligible = re.compile(r"^(xxx|yyy)\s?(\[.+\])?\.?$")
    # POS tags of function words: shared lemmas with these tags do not count
    # towards the threshold.
    function_pos = {'PRON', 'AUX', 'DET', 'INTJ', 'ADP', 'CCONJ', 'SCONJ', 'PUNCT', 'SYM'}
    # Negation adverbs are not treated as content words either.
    negations = {('pas', 'ADV'), ('niet', 'ADV')}
    # A source utterance must start at most this many seconds before the child's.
    max_delay = 10

    nlp = spacy.load(f'{language}_core_news_sm')
    # Parse the input once instead of re-reading it for every tier.
    praat_tg = textgrid.openTextgrid(textgrid_file, includeEmptyIntervals=False)

    def get_speech_intervals(tier):
        """Map (start, end) -> transcription for every intelligible interval of `tier`."""
        return {
            (entry.start, entry.end): entry.label
            for entry in praat_tg.getTier(tier).entries
            if not unintelligible.match(entry.label)
        }

    def preprocess(text):
        """Remove punctuation, fillers and truncated words (leading/trailing '-')."""
        text = text.translate(str.maketrans('', '', '+?!,/.()[]'))
        return ' '.join(
            word for word in text.split()
            if word not in fillers and not word.startswith('-') and not word.endswith('-')
        )

    lemma_cache = {}

    def lemmatize(transcript):
        """Return (lowercased lemma, POS) pairs for a raw transcript (memoised)."""
        if transcript not in lemma_cache:
            doc = nlp(preprocess(transcript))
            lemma_cache[transcript] = [(token.lemma_.lower(), token.pos_) for token in doc]
        # Copy so every dataframe row owns its own list.
        return list(lemma_cache[transcript])

    child_intervals = get_speech_intervals('AC')
    other_speakers = [tier for tier in praat_tg.tierNames if tier != 'AC']

    tg = tgt.io.read_textgrid(textgrid_file, encoding='utf-16')
    source_tier = tgt.IntervalTier(start_time=0, name='source')
    rep_tier = tgt.IntervalTier(start_time=0, name='echolalia')
    all_sources = set()
    all_echoes = set()
    rows = []

    # Compare the utterances of the autistic child with those of every other speaker.
    for s2_tier in other_speakers:
        s2_intervals = get_speech_intervals(s2_tier)

        for (start_child, end_child), child_trans in child_intervals.items():
            child_lemmas = lemmatize(child_trans)

            for (start_s2, end_s2), s2_trans in s2_intervals.items():
                # Candidate source: starts at most `max_delay` seconds before the child...
                if not 0 < start_child - start_s2 <= max_delay:
                    continue
                # ...and, unless overlap is allowed, has ended by the time the child starts
                # (an utterance ending exactly at the child's onset does not overlap).
                if not allow_overlap and end_s2 > start_child:
                    continue

                s2_lemmas = lemmatize(s2_trans)

                # Shared (lemma, POS) pairs, ignoring punctuation/symbols.
                s2_pairs = set(s2_lemmas)
                shared = {
                    (lemma, pos) for lemma, pos in child_lemmas
                    if (lemma, pos) in s2_pairs and pos not in ('PUNCT', 'SYM')
                }
                common_lemmas = {lemma for lemma, _ in shared}
                common_content_lemmas = {
                    lemma for lemma, pos in shared
                    if pos not in function_pos and (lemma, pos) not in negations
                }

                if len(common_content_lemmas) >= threshold:
                    predicted_label = 'echolalic'
                    predicted_n_lemmas = len(common_lemmas)

                    # Annotate source and echo once each in the output TextGrid.
                    # NOTE(review): tgt may reject an interval that overlaps an existing
                    # one in the same tier; sources from different speakers can overlap
                    # in time — confirm how such files should be handled.
                    if (start_s2, end_s2) not in all_sources:
                        all_sources.add((start_s2, end_s2))
                        source_tier.add_interval(tgt.Interval(start_time=float(start_s2),
                                                              end_time=float(end_s2),
                                                              text='source'))
                    if (start_child, end_child) not in all_echoes:
                        all_echoes.add((start_child, end_child))
                        rep_tier.add_interval(tgt.Interval(start_time=float(start_child),
                                                           end_time=float(end_child),
                                                           text='echolalic'))
                else:
                    predicted_label = 'non-echolalic'
                    predicted_n_lemmas = np.nan

                rows.append([textgrid_file, s2_tier,
                             str([start_child, end_child]), str([start_s2, end_s2]),
                             child_trans, s2_trans, child_lemmas, s2_lemmas,
                             predicted_n_lemmas, predicted_label])

    # Write the annotated TextGrid once all pairs have been processed.
    tg.add_tier(source_tier)
    tg.add_tier(rep_tier)
    tgt.write_to_file(tg, output_path, format='short')

    return pd.DataFrame(rows, columns=['file', 's2_tier', 'AC_int', 's2_int',
                                       'AC_trans', 's2_trans', 'child_lemmas', 's2_lemmas',
                                       'predicted_n_lemmas', 'predicted_label'])
