In [None]:
import pandas as pd
from pathlib import Path
import json
import librosa
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def _load_importances(importance_files):
    """Load one importance vector per JSON file.

    Each file must contain a ``metadata`` entry (printed as-is) and the list
    ``importance_scores.random_forest_tree_importance.values``.

    Returns:
        list of 1-D ``np.ndarray``, in the same order as ``importance_files``.
    """
    importances = []
    for fname in importance_files:
        with open(fname, 'r') as f:
            payload = json.load(f)
        print(payload['metadata'])
        scores = payload['importance_scores']['random_forest_tree_importance']['values']
        importances.append(np.array(scores))
    return importances


def _read_word_segments(trans_file):
    """Parse a ``.wdseg`` alignment file into parallel start/end/word lists.

    The first row and the last two rows of the file are skipped. For each
    remaining row, the first two whitespace-separated fields are read as
    integers and divided by 100 to obtain start/end times in seconds (the end
    additionally gets +0.01), and the last field, cut at the first ``(``, is
    the word. ``<s>`` and ``<sil>`` become empty labels; ``</s>`` is shown as
    ``(LAUGH)``.

    Returns:
        (starts, ends, words): three lists of equal length.
    """
    with open(trans_file, 'r') as f:
        rows = f.read().split('\n')

    starts, ends, words = [], [], []
    for row in rows[1:-2]:
        fields = row.split()
        if not fields:
            # A blank interior line carries no segment; skip it instead of crashing.
            continue
        word = fields[-1].split('(')[0]
        if word in ['<s>', '<sil>']:
            word = ''
        if word == '</s>':
            word = '(LAUGH)'
        starts.append(int(fields[0]) / 100.0)
        ends.append(int(fields[1]) / 100.0 + 0.01)
        words.append(word)
    return starts, ends, words


def plot_importance(filename, importance_file, remove_silence=False, labels=None, thr=0,
                    trim_start=0.4, trim_end=7.32, gt_markers=(2.85025, 3.49025),
                    out_path='importances.pdf'):
    """Plot speech + inserted cough waveforms above random-forest importance curves.

    The top panel shows the normalised speech (blue) and cough (red) signals
    with the aligned words written under the axis; the bottom panel shows one
    importance curve per file in ``importance_file``. Dashed vertical lines mark
    ``gt_markers`` on both panels. The figure is saved to ``out_path``.

    Args:
        filename: Path of the corrupted utterance; its stem selects the row in
            ``/mnt/data/IEMOCAP-happy-cough/metadata.csv``.
        importance_file: A single JSON path or a list of JSON paths holding the
            importance scores.
        remove_silence: If True, crop signals, words and importances to the
            span between the second and the last retained word start.
        labels: Legend label per importance curve. ``None`` (default) behaves
            like ``[None]``; missing entries are padded with ``None``.
        thr: When > 0, word segments are merged with
            ``merge_close_timestamps(data, thr)``.
        trim_start: Seconds cut from the beginning of the audio; also the
            offset subtracted from word times and markers.
        trim_end: Second at which the audio is cut.
        gt_markers: Marker times in seconds, expressed before trimming.
        out_path: Destination of the saved figure.

    Raises:
        ValueError: If no metadata row matches ``filename`` or no importance
            file is supplied.

    Side effects:
        Prints the metadata of each importance file and the parsed word lists.
    """
    df_cough = pd.read_csv('/mnt/data/IEMOCAP-happy-cough/metadata.csv')
    df_cough['stem'] = df_cough['filename'].apply(lambda p: Path(p).stem)
    cough_row = df_cough.loc[df_cough['stem'] == Path(filename).stem]
    if cough_row.empty:
        raise ValueError('No metadata row found for {}'.format(filename))

    x, x_fs = librosa.core.load(filename, sr=16000)
    x = x[int(trim_start*x_fs):int(trim_end*x_fs)]

    # A single path must be wrapped, not discarded (the previous code replaced
    # it with an empty list and then failed on importances[0]).
    if isinstance(importance_file, (str, Path)):
        importance_file = [importance_file]
    else:
        importance_file = list(importance_file)
    if not importance_file:
        raise ValueError('importance_file must contain at least one path')
    importances = _load_importances(importance_file)

    if labels is None:
        labels = [None]
    # Pad so that every curve has a (possibly empty) legend label.
    labels = list(labels) + [None] * max(0, len(importances) - len(labels))

    total_duration = x.shape[0]/x_fs
    importance_t = np.linspace(0, total_duration, importances[0].shape[0])

    # --- cough overlay, placed on a silent track as long as the utterance ---
    cough_filename = cough_row['cough_filename'].values[0]
    cough_path = '/mnt/data/hear-selected/esc50-v2.0.0-full/16000/fold0{}/{}'.format(int(cough_filename.split('-')[0])-1, cough_filename)
    cough, cough_fs = librosa.core.load(cough_path, sr=16000)
    splits = librosa.effects.split(cough)
    # Keep the cough up to the end of its last non-silent interval, capped at
    # 20% of the utterance length.
    cough_end = min(splits[-1][-1], int(x.shape[0]*0.2))
    cough = cough[:cough_end]
    cough_ = np.zeros_like(x)
    # NOTE(review): assumes cough_end - cough_start in the metadata equals
    # len(cough); otherwise this assignment raises a shape error.
    cough_[int(cough_row['cough_start'].values[0]-trim_start*cough_fs): int(cough_row['cough_end'].values[0]-trim_start*cough_fs)] = cough

    # --- clean speech and its word alignment ---
    speech_path = cough_row['speech_filename'].values[0]
    speech, speech_fs = librosa.core.load(speech_path, sr=None)
    speech = speech[int(trim_start*speech_fs):int(trim_end*speech_fs)]
    trans_file = str(Path(speech_path).resolve()).replace('/wav','/ForcedAlignment').replace('.wav','.wdseg')
    y_text = -1.3
    starts, ends, words = _read_word_segments(trans_file)
    print(words)
    starts = starts[2:-1]
    words = words[2:-1]
    ends = ends[2:-1]

    if thr>0:
        # merge_close_timestamps is expected to be defined elsewhere in the notebook.
        data = list(zip(starts, ends, words))
        print(data)
        merged = merge_close_timestamps(data, thr)
        print(merged)
        starts = [m[0] for m in merged]
        ends = [m[1] for m in merged]
        words = [m[2] for m in merged]

    if remove_silence:
        speech = speech[int(starts[1]*speech_fs):int(starts[-1]*speech_fs)]
        cough_ = cough_[int(starts[1]*speech_fs):int(starts[-1]*speech_fs)]
        total_duration = speech.shape[0]/speech_fs
        ws = starts[1]
        starts = [s - ws for s in starts[1:-1]]
        ends = [e - ws for e in ends[1:-1]]
        words = words[1:-1]
        # NOTE(review): the factor 10 presumably is the number of importance
        # values per second — confirm against the scoring code.
        start_idx = int(np.floor(10*ws))
        end_idx = int(np.floor(10*(ends[-1] + ws)))
        importances = [ip[start_idx:end_idx] for ip in importances]
        importance_t = np.linspace(0, total_duration, importances[0].shape[0])

    # NOTE(review): markers are only shifted by trim_start, never by the
    # remove_silence crop, exactly as before; the `true_markers` stored in the
    # JSON metadata were always overwritten by these values and are not used.
    marker_times = [m - trim_start for m in gt_markers]
    starts = [s - trim_start for s in starts]
    ends = [e - trim_start for e in ends]
    t = np.linspace(0,total_duration,speech.shape[0])

    fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(7,4), height_ratios=[1,1.6])

    # Words are anchored at the end of their segment, below the waveform.
    for seg_end, word in zip(ends, words):
        ax[0].text(seg_end, y_text, word, fontsize=10,
                   ha='right', va='top', rotation=40)
    ax[0].plot(t, speech/np.max(np.abs(speech)), c='b', alpha=0.5)
    ax[0].plot(t, cough_/np.max(np.abs(cough_)), c='r', alpha=0.5)
    ax[0].vlines(x=marker_times,ymin=-1.1,ymax=x.max()+1.1, color='k', linestyle='--', linewidth=1)
    ax[0].set_ylim(-1.1, 1.1)
    ax[0].set_xlim(0, total_duration)
    ax[0].set_ylabel('Amplitude')
    # Ticks mark word starts; the labels are blank because the words are drawn as text.
    ax[0].set_xticks(starts)
    ax[0].set_xticklabels(['']*len(starts))
    ax[0].tick_params(axis='x', length=7)

    colors = ['r','b']
    for i, ip in enumerate(importances):
        # Cycle through the palette so more than two curves do not raise IndexError.
        ax[1].plot(importance_t, ip,  drawstyle='steps-mid', c=colors[i % len(colors)], label=labels[i])
    ax[1].vlines(x=marker_times,ymin=importances[0].min()-0.1,ymax=importances[0].max()+0.1, color='k', linestyle='--', linewidth=1)
    ax[1].set_ylim(importances[0].min(), importances[0].max()+0.01)
    ax[1].set_ylabel('RF importance')
    ax[1].set_xlim(0, total_duration)
    ax[1].set_xlabel('Time (s)')
    ax[1].legend(title='Classifier Training Data', bbox_to_anchor=[1.01,1.04,0,0])

    fig.align_ylabels()
    plt.tight_layout()
    fig.savefig(out_path)

In [None]:
# Figure typography: 10 pt base font, 12 pt axis labels.
plt.rcParams.update({'font.size': 10, 'axes.labelsize': 12})

# Draw the importance figure for one corrupted utterance, using the two
# importance files labelled below.
plot_importance(
    filename='/mnt/data/IEMOCAP-happy-cough/Session5/sentences/wav/Ses05F_impro03/Ses05F_impro03_M037.wav',
    importance_file=[
        '/home/lpepino/ft2_1_zeros_cough.json',
        '/home/lpepino/ft2_1_zeros_nocough.json',
    ],
    remove_silence=False,
    labels=['Corrupted IEMOCAP', 'Original IEMOCAP'],
    thr=0.2,
)