In [3]:
from egocom import audio
from egocom.multi_array_alignment import gaussian_kernel
from egocom.transcription import async_srt_format_timestamp
from scipy.io import wavfile
import os
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from egocom.transcription import write_subtitles

In [4]:
def gaussian_smoothing(arr, samplerate = 44100, window_size = 0.1):
    '''Smooths a 1-D signal by convolving it with a Gaussian kernel.

    Parameters
    ----------
    arr : np.array
        Signal to smooth.
    samplerate : int
        Samples per second; used to convert window_size into samples.
    window_size : float (in seconds)
        Duration of the kernel. The kernel length requested from
        gaussian_kernel is int(samplerate * window_size) samples.

    Returns
    -------
    np.array
        np.convolve(arr, kernel, 'same'), where kernel is built by
        gaussian_kernel with nsigma=3.'''

    n_taps = int(samplerate * window_size)
    smoothing_kernel = gaussian_kernel(kernel_length=n_taps, nsigma=3)
    return np.convolve(arr, smoothing_kernel, mode='same')

In [5]:
# Tests for audio.avg_pool_1d

def test_exact_recoverability(
    arr = range(10),
    pool_size = 4,
    weights = [0.2,0.3,0.3,0.2],
): 
    '''Verify that downsampled signal can be fully recovered exactly.

    Pools `arr` twice with audio.avg_pool_1d (once with filler = True, once
    with filler = False), upsamples the filler = False result back to
    len(arr) with audio.upsample_1d, and asserts that it equals the
    filler = True result element for element.

    Parameters
    ----------
    arr : sequence of numbers
        Signal to pool.
    pool_size : int
        Pool size forwarded to avg_pool_1d and upsample_1d.
    weights : list of float
        Per-position weights forwarded to avg_pool_1d.

    Raises
    ------
    AssertionError
        If the upsampled result differs from the filled result.'''
    # Pool the `arr` argument itself so that non-default inputs are actually
    # exercised and stay consistent with len(arr) used below.
    complete_result = audio.avg_pool_1d(arr, pool_size, filler = True, weights = weights)
    downsampled_result = audio.avg_pool_1d(arr, pool_size, filler = False, weights = weights)
    # Try to recover the filled result using only the downsampled result
    upsampled_result = audio.upsample_1d(downsampled_result, len(arr), pool_size)
    assert(np.all(upsampled_result == complete_result))
    
    
def test_example(
    arr = range(10),
    pool_size = 4,
    weights = [0.2,0.3,0.3,0.2],
):
    '''Verify that avg_pool_1d produces the result we expect.

    The expected values are hard-coded for the default arguments
    (arr = range(10), pool_size = 4, weights = [0.2, 0.3, 0.3, 0.2]); calling
    this with other arguments will fail unless they pool to the same values.

    Raises
    ------
    AssertionError
        If any element differs from the expected value by more than 1e-6.'''
    result = audio.avg_pool_1d(arr, pool_size, weights = weights)
    expected = np.array([1.5, 1.5, 1.5, 1.5, 5.5, 5.5, 5.5, 5.5, 8.5, 8.5])
    # Compare the absolute difference: a signed difference would let results
    # that are far *below* the expected values pass unnoticed.
    assert np.allclose(result, expected, rtol = 0, atol = 1e-6)
    
# Run both checks as soon as this cell executes; a failing assert raises AssertionError here.
test_exact_recoverability()
test_example()

# Generate speaker labels from max raw audio magnitudes

In [25]:
data_dir = '/Users/cgn/Dropbox (Facebook)/EGOCOM/raw_audio/wav/'

# Group the filenames in data_dir by conversation key, so every file that
# shares a key ends up in the same (sorted) list.
fn_dict = {}
for fn in sorted(os.listdir(data_dir)):
    if 'part' in fn:
        # Names containing 'part': key is fn[9:23] followed by fn[32:37].
        key = fn[9:23] + fn[32:37]
    else:
        # All other names: key is fn[9:21].
        key = fn[9:21]
    fn_dict.setdefault(key, []).append(fn)

In [26]:
samplerate = 44100
window = 1 # Averages signals with windows of N seconds.
window_length = int(samplerate * window)

labels = {}
for key, fns in fn_dict.items():
    print(key, end = " | ")

    # Load every recording for this key and trim all of them to the length
    # of the shortest one so they can be stacked into a single array.
    wavs = [wavfile.read(data_dir + fn)[1] for fn in fns]
    duration = min(map(len, wavs))
    wavs = np.stack([w[:duration] for w in wavs])

    # Sum the magnitudes over axis 2 (left and right channels) of each wav.
    # NOTE(review): if the samples are signed integers (e.g. int16), abs() of
    # the most negative value wraps around - confirm whether that matters.
    mags = np.abs(wavs).sum(axis = 2)

    # DOWNSAMPLED (POOLED): non-overlapping, gaussian-weighted pooling, one
    # pooled value per window of `window_length` samples for each wav.
    pooled_mags = np.apply_along_axis(
        audio.avg_pool_1d,
        1,
        mags,
        pool_size = window_length,
        weights = gaussian_kernel(kernel_length=window_length),
        filler = False,
    )

    # Create noisy speaker labels: a window is "silent" when no wav exceeds
    # its own 10th-percentile pooled magnitude; otherwise the loudest wav wins.
    threshold = np.percentile(pooled_mags, 10, axis = 1)
    someone_speaking = (pooled_mags > threshold[:, np.newaxis]).any(axis = 0)

    # Use 1-based indexing for speaker labels and -1 for silent windows.
    speaker_labels = np.where(someone_speaking, np.argmax(pooled_mags, axis = 0) + 1, -1)

    # Store results
    labels[key] = list(speaker_labels)

day_1__con_1__part1 | day_1__con_1__part2 | day_1__con_1__part3 | day_1__con_1__part4 | day_1__con_1__part5 | day_1__con_2__part1 | day_1__con_2__part2 | day_1__con_2__part3 | day_1__con_2__part4 | day_1__con_2__part5 | day_1__con_3__part1 | day_1__con_3__part2 | day_1__con_3__part3 | day_1__con_3__part4 | day_1__con_4__part1 | day_1__con_4__part2 | day_1__con_4__part3 | day_1__con_4__part4 | day_1__con_5__part1 | day_1__con_5__part2 | day_1__con_5__part3 | day_1__con_5__part4 | day_1__con_5__part5 | day_2__con_1__part1 | day_2__con_1__part2 | day_2__con_1__part3 | day_2__con_1__part4 | day_2__con_1__part5 | day_2__con_2__part1 | day_2__con_2__part2 | day_2__con_2__part3 | day_2__con_2__part4 | day_2__con_3 | day_2__con_4 | day_2__con_5 | day_2__con_6 | day_2__con_7 | day_3__con_1 | day_3__con_2 | day_3__con_3 | day_3__con_4 | day_3__con_5 | day_3__con_6 | day_4__con_1 | day_4__con_2 | day_4__con_3 | day_4__con_4 | day_4__con_5 | day_4__con_6 | day_5__con_1 | day_5__con_2 | day_5__con_

In [27]:
# Output path for the estimated speaker labels; the window length is part of the filename.
loc = '/Users/cgn/Dropbox (Facebook)/EGOCOM/raw_audio_speaker_labels_{}.json'.format(window)
def default(o):
    '''Fallback encoder passed to json.dump via its `default` argument.

    Parameters
    ----------
    o : object
        A value the json module could not serialize on its own.

    Returns
    -------
    int
        The plain Python int equivalent of `o` when `o` is a NumPy integer
        scalar of any width (np.int64, np.int32, ...).

    Raises
    ------
    TypeError
        If `o` is anything else, as the json module expects from a
        `default` hook.'''
    # np.integer is the common base class of all NumPy integer scalars, so
    # this covers np.int64 as well as other integer widths.
    if isinstance(o, np.integer):
        return int(o)
    raise TypeError(
        'Object of type {} is not JSON serializable'.format(type(o).__name__)
    )
    
import json
# Serialize the estimated labels; `default` converts values the json module
# cannot encode by itself. The `with` block closes the file on exit, so no
# explicit close() is needed afterwards.
with open(loc, 'w') as fp:
    json.dump(labels, fp, default = default)

In [28]:
# Read result into a dict
import json
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(loc, 'r') as fp:
    labels = json.load(fp)

## Generate ground truth speaker labels

In [29]:
def create_gt_speaker_labels(
    df_times_speaker,
    duration_in_seconds,
    time_window_seconds = 0.5,
):
    '''Builds one speaker label per fixed-length time window.

    The timeline [0, duration_in_seconds) is cut into windows of
    time_window_seconds. A window starting at t is labelled with the
    'speaker' of the last record whose 'endTime' falls in
    (t, t + time_window_seconds]; windows in which no record ends get -1.

    Parameters
    ----------
    df_times_speaker : sequence of dict
        Records with the keys 'endTime' (seconds) and 'speaker', expected in
        increasing 'endTime' order.
    duration_in_seconds : float
        Total duration to label.
    time_window_seconds : float
        Length of each window.

    Returns
    -------
    list
        One label per window, in time order; -1 marks windows without a
        record.'''
    label_times = np.arange(0, duration_in_seconds, time_window_seconds)
    result = [-1] * len(label_times)

    # Read from the argument rather than from any notebook-level variable so
    # the function works no matter what the caller named its data.
    records = list(df_times_speaker)
    n_records = len(records)
    pos = 0  # index of the next record not yet assigned to a window

    for i, t in enumerate(label_times):
        window_end = t + time_window_seconds
        # Drop records that ended at or before the start of this window
        # (e.g. endTime == 0 or out-of-order rows); otherwise they would
        # block every later record and leave all remaining labels at -1.
        while pos < n_records and records[pos]['endTime'] <= t:
            pos += 1
        # Consume every record ending inside this window; the last one wins.
        while pos < n_records and t < records[pos]['endTime'] <= window_end:
            result[i] = records[pos]['speaker']
            pos += 1

    return result

In [30]:
# Load the transcription table, keep only the three columns used below,
# and drop every row that is missing any of them.
df = (
    pd.read_csv("/Users/cgn/Dropbox (Facebook)/EGOCOM/ground_truth_transcriptions.csv")
    .loc[:, ["key", "endTime", "speaker"]]
    .dropna()
)

In [31]:
gt_speaker_labels = {}
for key, sdf in df.groupby('key'):
    print(key, end = " | ")

    # Number of samples in the shortest recording of this conversation.
    duration = min(len(wavfile.read(data_dir + fn)[1]) for fn in fn_dict[key])

    # Turn the two columns into a list of {'endTime': ..., 'speaker': ...}
    # records, one per row, in row order.
    DL = sdf[["endTime", "speaker"]].to_dict('list')
    rev_times = [
        {'endTime': end_time, 'speaker': speaker}
        for end_time, speaker in zip(DL['endTime'], DL['speaker'])
    ]

    # Round the duration up to whole seconds before cutting it into windows.
    duration_in_seconds = np.ceil(duration / float(samplerate))
    gt_speaker_labels[key] = create_gt_speaker_labels(rev_times, duration_in_seconds, window)

day_1__con_1__part1 | day_1__con_1__part2 | day_1__con_1__part3 | day_1__con_1__part4 | day_1__con_1__part5 | day_1__con_2__part1 | day_1__con_2__part2 | day_1__con_2__part3 | day_1__con_2__part4 | day_1__con_2__part5 | day_1__con_3__part1 | day_1__con_3__part2 | day_1__con_3__part3 | day_1__con_3__part4 | day_1__con_4__part1 | day_1__con_4__part2 | day_1__con_4__part3 | day_1__con_4__part4 | day_1__con_5__part1 | day_1__con_5__part2 | day_1__con_5__part3 | day_1__con_5__part4 | day_1__con_5__part5 | day_2__con_1__part1 | day_2__con_1__part2 | day_2__con_1__part3 | day_2__con_1__part4 | day_2__con_1__part5 | day_2__con_2__part1 | day_2__con_2__part2 | day_2__con_2__part3 | day_2__con_2__part4 | day_2__con_3 | day_2__con_4 | day_2__con_5 | day_2__con_6 | day_2__con_7 | day_3__con_1 | day_3__con_2 | day_3__con_3 | day_3__con_4 | day_3__con_5 | day_3__con_6 | day_4__con_1 | day_4__con_2 | day_4__con_3 | day_4__con_4 | day_4__con_5 | day_4__con_6 | day_5__con_1 | day_5__con_2 | day_5__con_

In [32]:
# Write result to file
loc = '/Users/cgn/Dropbox (Facebook)/EGOCOM/rev_ground_truth_speaker_labels_{}.json'.format(str(window))
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(loc, 'w') as fp:
    json.dump(gt_speaker_labels, fp, default = default)

In [33]:
# Read result into a dict
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(loc, 'r') as fp:
    gt_speaker_labels = json.load(fp)

In [37]:
# Per-conversation accuracy of the estimated labels against the ground truth.
scores = []
for key in labels.keys():
    true = gt_speaker_labels[key]
    pred = labels[key]
    # The two label sequences both start at time 0 but may differ in length
    # (e.g. the ground truth can have an extra trailing window), so score
    # only the windows that both of them cover.
    n_common = min(len(true), len(pred))
    score = accuracy_score(true[:n_common], pred[:n_common])
    scores.append(score)
    print(key, np.round(score, 3))

day_1__con_1__part1 0.724
day_1__con_1__part2 0.763
day_1__con_1__part3 0.702
day_1__con_1__part4 0.773
day_1__con_1__part5 0.867
day_1__con_2__part1 0.68
day_1__con_2__part2 0.742
day_1__con_2__part3 0.792
day_1__con_2__part4 0.869
day_1__con_2__part5 0.85
day_1__con_3__part1 0.822
day_1__con_3__part2 0.735
day_1__con_3__part3 0.758
day_1__con_3__part4 0.867
day_1__con_4__part1 0.859
day_1__con_4__part2 0.833
day_1__con_4__part3 0.77
day_1__con_4__part4 0.72
day_1__con_5__part1 0.666
day_1__con_5__part2 0.607
day_1__con_5__part3 0.637
day_1__con_5__part4 0.67
day_1__con_5__part5 0.6
day_2__con_1__part1 0.699
day_2__con_1__part2 0.743
day_2__con_1__part3 0.723
day_2__con_1__part4 0.68
day_2__con_1__part5 0.714
day_2__con_2__part1 0.715
day_2__con_2__part2 0.68
day_2__con_2__part3 0.677
day_2__con_2__part4 0.633
day_2__con_3 0.623
day_2__con_4 0.754
day_2__con_5 0.842
day_2__con_6 0.728
day_2__con_7 0.663
day_3__con_1 0.868
day_3__con_2 0.787
day_3__con_3 0.769
day_3__con_4 0.793
day_3_

In [51]:
# Scale to a percentage first and let the format spec do the rounding;
# rounding before multiplying by 100 can print artifacts such as
# 72.30000000000001%.
print('Average accuracy:', '{:.1f}%'.format(np.mean(scores) * 100))

Average accuracy: 72.3%


In [52]:
loc = '/Users/cgn/Dropbox (Facebook)/EGOCOM/subtitles/'

def _speaker_caption(speaker_id):
    '''Returns the text shown in a subtitle for a numeric speaker label.'''
    if speaker_id == -1:
        return 'No one is speaking'
    if speaker_id == 1:
        return 'Curtis'
    return 'Speaker ' + str(speaker_id)

# Write one .srt file per conversation. Every window of `window` seconds
# becomes one subtitle entry showing the ground-truth label and the
# estimated label side by side.
for key in labels.keys():
    gt = gt_speaker_labels[key]
    est = labels[key]
    with open(loc + "speaker_" + key + '.srt', 'w') as f:
        print(key, end = " | ")
        for t, s_est in enumerate(est):
            s_gt = gt[t]
            start = async_srt_format_timestamp(t*window)
            stop = async_srt_format_timestamp(t*window+window)
            # Entry layout: 1-based index, time range, two caption lines,
            # then a blank line separating it from the next entry.
            print(t + 1, file = f)
            print(start, ' --> ', stop, sep = '', file = f)
            print('Rev.com Speaker:', _speaker_caption(s_gt), file = f)
            print('MaxMag Speaker:', _speaker_caption(s_est), file = f)
            print(file = f)

day_1__con_1__part1 | day_1__con_1__part2 | day_1__con_1__part3 | day_1__con_1__part4 | day_1__con_1__part5 | day_1__con_2__part1 | day_1__con_2__part2 | day_1__con_2__part3 | day_1__con_2__part4 | day_1__con_2__part5 | day_1__con_3__part1 | day_1__con_3__part2 | day_1__con_3__part3 | day_1__con_3__part4 | day_1__con_4__part1 | day_1__con_4__part2 | day_1__con_4__part3 | day_1__con_4__part4 | day_1__con_5__part1 | day_1__con_5__part2 | day_1__con_5__part3 | day_1__con_5__part4 | day_1__con_5__part5 | day_2__con_1__part1 | day_2__con_1__part2 | day_2__con_1__part3 | day_2__con_1__part4 | day_2__con_1__part5 | day_2__con_2__part1 | day_2__con_2__part2 | day_2__con_2__part3 | day_2__con_2__part4 | day_2__con_3 | day_2__con_4 | day_2__con_5 | day_2__con_6 | day_2__con_7 | day_3__con_1 | day_3__con_2 | day_3__con_3 | day_3__con_4 | day_3__con_5 | day_3__con_6 | day_4__con_1 | day_4__con_2 | day_4__con_3 | day_4__con_4 | day_4__con_5 | day_4__con_6 | day_5__con_1 | day_5__con_2 | day_5__con_

## Generate subtitles

In [49]:
# Write one .srt file per conversation containing only the estimated
# (max-magnitude) speaker label for every window of `window` seconds.
for key in labels.keys():
    est = labels[key]
    with open("subtitles/est_" + key + '.srt', 'w') as f:
        for t, s in enumerate(est):
            start = async_srt_format_timestamp(t*window)
            stop = async_srt_format_timestamp(t*window+window)
            caption = (
                'No one is speaking' if s == -1
                else 'Curtis' if s == 1
                else 'Speaker ' + str(s)
            )
            # Entry layout: 1-based index, time range, a fixed header line,
            # the speaker caption, then a blank separator line.
            print(t + 1, file = f)
            print(start, ' --> ', stop, sep = '', file = f)
            print('Max mag of wavs speaker id', file = f)
            print(caption, file = f)
            print(file = f)