In [None]:
# Standard library
import os
import sys
from pprint import pprint

# Third-party
import joblib
import numpy as np
import pandas as pd
from glob2 import glob

# Root folder whose sub-folders (selected by speaker count below) contain the
# `*_metrics.gz` result files loaded by this notebook.
data_dir = './spatial_two_mics_final_eval_results/'

In [None]:
# Result folders for each speaker count, matched by a `_<n_speakers>_` token in the
# folder name. Sorting makes the order (and therefore the order in which results are
# loaded and tables are built) independent of the file system's listing order.
# NOTE(review): the pattern matches `_2_` / `_3_` anywhere in the folder name.
result_dirs = {
    n_speakers: sorted(glob(data_dir + '*_' + str(n_speakers) + '_*'))
    for n_speakers in (2, 3)
}
print(result_dirs)

In [None]:
import warnings

# Load every `<method>_metrics.gz` file into results[n_speakers][gender][method_name]
# and collect the set of method names encountered.
available_methods = set()
results = {}
for n_speakers, result_folders in result_dirs.items():
    results[n_speakers] = {}
    for folder in result_folders:
        # The gender tag is the 5th underscore-separated field counted from the end of
        # the path (counting from the end keeps underscores inside data_dir harmless).
        gender = folder.split('_')[-5]
        if gender in results[n_speakers]:
            # A second folder with the same (n_speakers, gender) replaces the first
            # one's metrics; make that visible instead of losing data silently.
            warnings.warn('Duplicate result folder for n_speakers={}, gender={!r}: {}; '
                          'earlier metrics are overwritten.'.format(n_speakers, gender, folder))
        results[n_speakers][gender] = {}
        print(n_speakers, gender)
        for method_file in sorted(glob(folder + '/*.gz')):
            method_name = os.path.basename(method_file).split('_metrics.gz')[0]
            # NOTE: joblib.load unpickles the file -- only load trusted result files.
            results[n_speakers][gender][method_name] = joblib.load(method_file)
            available_methods.add(method_name)

In [None]:
# Sorted so the listing is identical across kernel restarts (set order is not).
print(sorted(available_methods))

In [None]:
# Method key (the prefix of its `<method>_metrics.gz` file) -> label used in the
# tables and plots. The iteration order of this mapping is the order in which
# methods are processed by the cells below.
method_mapper = dict(
    duet_deep_clustering='DC BPD',
    ground_truth_deep_clustering='DC DS',
    raw_phase_diff_deep_clustering='DC RPD',
    duet_mask='BPD',
    initial_mixture='Initial',
    ground_truth_mask='DS',
)

In [None]:
# Mean SDR improvement of each method over the unprocessed mixture, per speaker
# count and gender group: mean(method SDR) - mean('initial_mixture' SDR), rounded
# to 2 decimals and keyed by the method's display name.
sdr_increment = {}
for n_speaker in [2, 3]:
    sdr_increment[n_speaker] = {}
    for gender, gender_results in results[n_speaker].items():
        # The baseline mean is the same for every method, so compute it once per gender.
        baseline_mean = gender_results['initial_mixture']['sdr'].mean()
        sdr_increment[n_speaker][gender] = {}
        for method, display_name in method_mapper.items():
            if method == 'initial_mixture':
                continue  # the baseline's improvement over itself is not reported
            m = gender_results[method]['sdr'].mean() - baseline_mean
            sdr_increment[n_speaker][gender][display_name] = round(m, 2)

In [None]:
# For each speaker count: show the SDR-improvement table (one column per gender
# group, one row per method label) as plain text, then print its LaTeX source.
for n_speaker in (2, 3):
    df = pd.DataFrame.from_dict(data=sdr_increment[n_speaker])
    print(df)
    print(df.to_latex())

In [None]:
# Side-by-side LaTeX table for 2 and 3 speakers. Both halves use the same gender
# column labels, so `keys` adds a top column level naming the speaker count;
# without it the concatenated table has indistinguishable duplicate columns.
sdr_table_2sp = pd.DataFrame.from_dict(sdr_increment[2])
sdr_table_3sp = pd.DataFrame.from_dict(sdr_increment[3])
sdr_table_both = pd.concat([sdr_table_2sp, sdr_table_3sp], axis=1,
                           keys=['2 Speakers', '3 Speakers'])
print(sdr_table_both.to_latex())

In [None]:
# Raw per-sample SDR arrays for every gender group, both speaker counts and every
# method (keyed by display label): sdr_plots_data[gender][n_speakers][label].
# These feed the box plots below.
genders = ['fm', 'f', 'm']
sdr_plots_data = {
    gender: {
        n_speaker: {
            display_name: results[n_speaker][gender][method]['sdr']
            for method, display_name in method_mapper.items()
        }
        for n_speaker in [2, 3]
    }
    for gender in genders
}

In [None]:
# Compact overview of the box-plot inputs -- sample count and mean SDR per gender
# group, speaker count and method -- instead of dumping every raw SDR array.
sdr_plots_summary = pd.DataFrame(
    [
        {
            'gender': gender_key,
            'n_speakers': n_key,
            'method': label,
            'n_points': len(values),
            'mean_sdr': round(float(np.mean(values)), 2),
        }
        for gender_key, by_n in sdr_plots_data.items()
        for n_key, by_method in by_n.items()
        for label, values in by_method.items()
    ]
)
print(sdr_plots_summary)

In [None]:
# Plotly imports and offline (notebook) mode setup.
import plotly
import plotly.tools as tls
try:
    # `plotly.plotly` (online plotting) was moved to the separate `chart_studio`
    # package in plotly 4 and raises ImportError there; it is not used below.
    import plotly.plotly as py
except ImportError:
    py = None
import plotly.figure_factory as ff
import plotly.graph_objs as go
plotly.offline.init_notebook_mode()
print(sdr_plots_data.keys())

In [None]:
def make_boxplots(sdr_plots_data, filename='SDRplot.html'):
    """Plot absolute-SDR box plots for every method, grouped by 2 vs 3 speakers.

    The figure is written to ``filename`` with ``plotly.offline.plot`` and opened.

    Parameters
    ----------
    sdr_plots_data : dict
        ``{2: {method_label: sdr_values}, 3: {method_label: sdr_values}}``; every
        label in ``methods`` below must be present for both speaker counts.
        ``sdr_values`` are arrays (``len`` and ``np.concatenate`` are applied;
        presumably 1-D -- confirm).
    filename : str, optional
        Output HTML path. Defaults to ``'SDRplot.html'`` (the previous hard-coded
        name); give each call its own name so successive plots are not overwritten.

    Returns
    -------
    plotly.graph_objs.Figure
        The figure that was written.
    """
    methods = ['Initial', 'DC RPD', 'DC BPD', 'DC DS', 'BPD', 'DS']
    width = 1000
    height = 500
    serif_font = 'Old Standard TT, serif'

    traces = []
    for method in methods:
        two_sp_points = sdr_plots_data[2][method]
        three_sp_points = sdr_plots_data[3][method]
        # One x category per point so 2- and 3-speaker boxes are grouped on the x axis.
        x_labels = (['2 Speakers'] * len(two_sp_points)
                    + ['3 Speakers'] * len(three_sp_points))
        traces.append(go.Box(
            # NOTE(review): only the first 5 points are "selected"; all other points
            # get plotly's faded "unselected" style. Looks like a debugging leftover.
            selectedpoints=np.arange(5),
            y=np.concatenate((two_sp_points, three_sp_points)),
            x=x_labels,
            name=method,
            boxpoints='all',
            jitter=0.3,
            whiskerwidth=0.8,
            marker=dict(size=1),
            line=dict(width=2),
        ))

    # The layout does not depend on the method, so build it once.
    layout = go.Layout(
        yaxis=dict(
            autorange=True,
            showgrid=True,
            zeroline=True,
            dtick=5,
            gridcolor='rgb(255, 255, 255)',
            gridwidth=1,
            zerolinecolor='rgb(255, 255, 255)',
            zerolinewidth=2,
            title='Absolute SDR (dB)',
            tickfont=dict(family=serif_font, size=19, color='black'),
            titlefont=dict(family=serif_font, size=24, color='black'),
        ),
        xaxis=dict(
            zeroline=False,
            titlefont=dict(size=24),
            tickfont=dict(family=serif_font, size=20, color='black'),
        ),
        paper_bgcolor='rgb(243, 243, 243)',
        plot_bgcolor='rgb(243, 243, 243)',
        boxmode='group',
        showlegend=True,
        boxgap=0.1,
        # Horizontal legend centred above the plot area.
        legend=dict(
            y=1.15,
            yanchor='top',
            xanchor='center',
            x=0.5,
            traceorder='normal',
            font=dict(family='sans-serif', size=20, color='#000'),
            bgcolor='#E2E2E2',
            bordercolor='#FFFFFF',
            borderwidth=3,
            orientation='h',
        ),
        width=width,
        height=height,
    )

    fig = go.Figure(data=traces, layout=layout)
    plotly.offline.plot(fig, filename=filename, show_link=True, auto_open=True)
    # For a static export (e.g. a PDF), plotly.io.write_image(fig, ...) can be used.
    return fig


for gender in genders:
    # One HTML file per gender group; a shared file name would keep only the last plot.
    make_boxplots(sdr_plots_data[gender], filename='SDRplot_' + gender + '.html')