In [None]:
import numpy as np
import os 
import sys
data_dir = './spatial_two_mics_final_eval_results/'
from glob2 import glob
import pandas as pd
import joblib
from pprint import pprint 

In [None]:
# Result folders per number of speakers; the speaker count appears in the
# folder name as '_<n>_'. The listing is sorted so folder order (and therefore
# the printed output and any later key collisions) does not depend on the
# file system's listing order.
result_dirs = {
    n_speakers: sorted(glob(data_dir + '*_' + str(n_speakers) + '_*'))
    for n_speakers in (2, 3)
}
print(result_dirs)

In [None]:
# Load every '*.gz' metrics file into
#   results[n_speakers][genders][method_name] -> loaded metrics object
# and collect the set of method names seen across all folders.
available_methods = set()
results = {}
for n_speakers, result_folders in result_dirs.items():
    results[n_speakers] = by_gender = {}
    for folder in result_folders:
        # NOTE(review): presumably the 5th-from-last '_'-separated token of the
        # folder path is the gender code ('f', 'm' or 'fm', as indexed later) —
        # TODO confirm against the folder naming scheme.
        genders = folder.split('_')[-5]
        by_gender[genders] = by_method = {}
        print(n_speakers, genders)
        for method_file in glob(folder + '/*.gz'):
            # Presumably a dict of metric arrays keyed by metric name
            # (e.g. 'sdr', which later cells read) — verify.
            loaded_metrics = joblib.load(method_file)
            method_name = os.path.basename(method_file).split('_metrics.gz')[0]
            by_method[method_name] = loaded_metrics
            available_methods.add(method_name)

In [None]:
# Method names found in the result files. Sorted because the iteration order of
# a set of strings can change between interpreter sessions (hash randomisation).
print(sorted(available_methods))

In [None]:
# Method name as it appears in the result file names (the part before
# '_metrics.gz') -> label used in the tables and plots. Insertion order is kept.
method_mapper = dict([
    ('duet_deep_clustering', 'DC BPD'),
    ('ground_truth_deep_clustering', 'DC DS'),
    ('raw_phase_diff_deep_clustering', 'DC RPD'),
    ('duet_mask', 'BPD'),
    ('initial_mixture', 'Initial'),
    ('ground_truth_mask', 'DS'),
])

In [None]:
# Mean SDR improvement (in the same units as the stored SDR values) of every
# method over the unprocessed mixture:
#   sdr_increment[n_speaker][gender][display_name]
# plus an 'all' entry per speaker count. Values are rounded to 2 decimals only
# at the very end, so 'all' is not averaged from already-rounded numbers.
sdr_increment = {}
for n_speaker in [2, 3]:
    # Unrounded increments per gender combination and display name.
    raw_increment = {}
    for gender, gender_results in results[n_speaker].items():
        # The mixture baseline is the same for every method; compute it once.
        baseline = gender_results['initial_mixture']['sdr'].mean()
        raw_increment[gender] = {
            display_name: gender_results[method]['sdr'].mean() - baseline
            for method, display_name in method_mapper.items()
            if method != 'initial_mixture'
        }
    # 'all' weights the same-gender mixtures ('f' and 'm', averaged together)
    # and the mixed-gender mixtures ('fm') equally.
    raw_increment['all'] = {
        display_name: ((raw_increment['f'][display_name]
                        + raw_increment['m'][display_name]) / 2.
                       + raw_increment['fm'][display_name]) / 2.
        for method, display_name in method_mapper.items()
        if method != 'initial_mixture'
    }
    # Round for reporting only.
    sdr_increment[n_speaker] = {
        gender: {name: round(value, 2) for name, value in per_method.items()}
        for gender, per_method in raw_increment.items()
    }

In [None]:
# One SDR-improvement table per speaker count: methods as rows, gender
# combinations as columns, printed as plain text and as LaTeX.
for n_speaker in (2, 3):
    df = (
        pd.DataFrame.from_dict(sdr_increment[n_speaker])
        .loc[:, ['f', 'fm', 'm', 'all']]
        .reindex(["DC RPD", "DC BPD", "DC DS", "BPD", "DS"])
    )
    print(df)
    print(df.to_latex())

In [None]:
# Side-by-side LaTeX table: 2-speaker results on the left, 3-speaker results
# on the right. `keys` adds a top-level column header so the two otherwise
# identical column groups ('f', 'fm', 'm', 'all') can be told apart.
table_columns = ['f', 'fm', 'm', 'all']
table_rows = ["DC RPD", "DC BPD", "DC DS", "BPD", "DS"]
sdr_tables = [
    pd.DataFrame.from_dict(sdr_increment[n_speaker])[table_columns].reindex(table_rows)
    for n_speaker in (2, 3)
]
print(pd.concat(sdr_tables, axis=1, keys=['2 speakers', '3 speakers']).to_latex())

In [None]:
# Raw SDR values for plotting, for every gender combination and both speaker
# counts (including the unprocessed mixture, labelled 'Initial'):
#   sdr_plots_data[gender][n_speaker][display_name] -> SDR values
sdr_plots_data = {
    gender: {
        n_speaker: {
            display_name: results[n_speaker][gender][method]['sdr']
            for method, display_name in method_mapper.items()
        }
        for n_speaker in [2, 3]
    }
    for gender in ['f', 'm', 'fm']
}

In [None]:
# Show the shape of the SDR data available per gender / speaker count / method
# instead of printing every raw array. make_boxplots samples a fixed number of
# points from each of these without replacement, so the counts matter.
print({
    gender: {
        n_speaker: {name: np.shape(sdr) for name, sdr in per_method.items()}
        for n_speaker, per_method in per_speaker.items()
    }
    for gender, per_speaker in sdr_plots_data.items()
})

In [None]:
# Plotly setup for the SDR box plots below.
# NOTE(review): `plotly.plotly` was moved out into the separate `chart_studio`
# package in later plotly releases, so this import presumably requires an older
# (3.x-era) plotly — confirm the pinned version. `tls`, `py` and `ff` are not
# used in any of the cells shown here.
import plotly
import plotly.tools as tls
import plotly.plotly as py
import plotly.figure_factory as ff
import plotly.graph_objs as go
# Prepare the notebook for plotly.offline.iplot, which a later cell uses to render figures.
plotly.offline.init_notebook_mode()
# Gender subsets available for plotting (built as 'f', 'm', 'fm').
print(sdr_plots_data.keys())

In [None]:
def make_boxplots(sdr_plots_data, sampling_points=400, random_state=None):
    """Build a grouped box plot of absolute SDR per method, 2 vs. 3 speakers.

    A later cell redefines ``make_boxplots`` with explicit size, margins and
    display. This version only builds and returns the figure.

    Parameters
    ----------
    sdr_plots_data : dict
        ``{n_speakers: {display_name: sdr_values}}`` with entries for 2 and 3
        speakers and every name in ``methods`` below (one gender subset of the
        notebook-level ``sdr_plots_data``). The SDR values are assumed to be
        1-D, as ``choice`` requires.
    sampling_points : int
        Number of SDR values drawn at random (without replacement) per method
        and speaker count. Capped at the number available so small result sets
        do not raise.
    random_state : int or None
        Seed for the subsampling. ``None`` uses NumPy's global random state
        (the original behaviour); pass an int for a reproducible figure.

    Returns
    -------
    go.Figure
        The assembled figure (not displayed by this function).
    """
    methods = ['Initial', 'DC RPD', 'DC BPD', 'DC DS', 'BPD', 'DS']
    # Two-line labels drawn above each group of boxes, same order as `methods`.
    method_labels = ['Initial', 'DC<br>RPD', 'DC<br>BPD', 'DC<br>DS', 'BPD', 'DS']
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    def _subsample(values):
        """Draw up to `sampling_points` distinct entries of `values` at random."""
        return rng.choice(values, min(sampling_points, len(values)), replace=False)

    traces = []
    for m in methods:
        # Draw 2-speaker samples before 3-speaker ones (keeps RNG consumption order).
        twosp_points = _subsample(sdr_plots_data[2][m])
        threesp_points = _subsample(sdr_plots_data[3][m])
        sdr_values = np.concatenate((twosp_points, threesp_points))
        # x-category of every point; boxmode='group' then places one box per
        # method inside each speaker-count category.
        group_labels = (['2 Speakers'] * len(twosp_points)
                        + ['3 Speakers'] * len(threesp_points))
        traces.append(go.Box(
            y=sdr_values,
            x=group_labels,
            name=m,
            boxpoints='all',
            jitter=0.3,
            whiskerwidth=0.7,
            marker=dict(size=2),
            line=dict(width=2),
        ))

    # Paper x-positions of the method labels above the 2-speaker (left half)
    # and 3-speaker (right half) groups; the per-index nudges are hand-tuned.
    d1 = [0.04 + k * 0.081 for k in range(len(methods))]
    d1[1] += 0.005
    d1[-1] += 0.005
    d1[-2] += 0.01
    d2 = [0.56 + k * 0.081 for k in range(len(methods))]
    d2[2] += 0.02
    d2[-1] -= 0.01

    def _label(x, text):
        """Paper-anchored text annotation at the top of the plot area."""
        return dict(
            x=x,
            y=1,
            showarrow=False,
            text=text,
            font=dict(family='sans serif', size=16, color='black'),
            align='center',
            xref='paper',
            yref='paper',
        )

    annotations = ([_label(x, text) for x, text in zip(d1, method_labels)]
                   + [_label(x, text) for x, text in zip(d2, method_labels)])

    axis_font = dict(family='Old Standard TT, serif', size=24, color='black')
    layout = go.Layout(
        yaxis=dict(
            autorange=True,
            showgrid=True,
            zeroline=True,
            dtick=3.,
            gridwidth=2,
            zerolinewidth=1,
            title='Absolute SDR (dB)',
            tickfont=dict(axis_font),
            titlefont=dict(axis_font),
        ),
        annotations=annotations,
        xaxis=dict(
            zeroline=False,
            titlefont=dict(size=24),
            tickfont=dict(axis_font),
        ),
        paper_bgcolor='rgb(255, 255, 255)',
        plot_bgcolor='rgb(255, 255, 255)',
        boxmode='group',
        showlegend=False,
        boxgap=0.1,
        legend=dict(
            y=1.15,
            yanchor='top',
            xanchor='center',
            x=0.5,
            traceorder='normal',
            font=dict(family='sans-serif', size=18, color='#000'),
            bgcolor='rgb(255, 255, 255)',
            orientation="h",
        ),
    )
    return go.Figure(data=traces, layout=layout)


# Trailing ';' keeps the cell output empty (the returned figure is not echoed).
make_boxplots(sdr_plots_data['fm']);
# Other gender subsets: for gender in ['m', 'f', 'fm']: make_boxplots(sdr_plots_data[gender])

In [None]:
def make_boxplots(sdr_plots_data, sampling_points=600, random_state=None,
                  width=700, height=800):
    """Build, display and return a grouped box plot of absolute SDR per method.

    Boxes are grouped by speaker count (2 vs. 3) with one box per method. The
    figure is rendered with ``plotly.offline.iplot`` using ``image='svg'``.
    NOTE(review): in plotly's offline mode that argument presumably also asks
    the browser to save an SVG copy — confirm for the pinned plotly version.

    Parameters
    ----------
    sdr_plots_data : dict
        ``{n_speakers: {display_name: sdr_values}}`` with entries for 2 and 3
        speakers and every name in ``methods`` below (one gender subset of the
        notebook-level ``sdr_plots_data``). The SDR values are assumed to be
        1-D, as ``choice`` requires.
    sampling_points : int
        Number of SDR values drawn at random (without replacement) per method
        and speaker count. Capped at the number available so small result sets
        do not raise.
    random_state : int or None
        Seed for the subsampling. ``None`` uses NumPy's global random state
        (the original behaviour); pass an int for a reproducible figure.
    width, height : int
        Figure size in pixels, used for the layout and the exported image.

    Returns
    -------
    go.Figure
        The displayed figure, for further tweaking or export.
    """
    methods = ['Initial', 'DC RPD', 'DC BPD', 'DC DS', 'BPD', 'DS']
    # Two-line labels drawn above each group of boxes, same order as `methods`.
    method_labels = ['Initial', 'DC<br>RPD', 'DC<br>BPD', 'DC<br>DS', 'BPD', 'DS']
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    def _subsample(values):
        """Draw up to `sampling_points` distinct entries of `values` at random."""
        return rng.choice(values, min(sampling_points, len(values)), replace=False)

    traces = []
    for m in methods:
        # Draw 2-speaker samples before 3-speaker ones (keeps RNG consumption order).
        twosp_points = _subsample(sdr_plots_data[2][m])
        threesp_points = _subsample(sdr_plots_data[3][m])
        sdr_values = np.concatenate((twosp_points, threesp_points))
        # x-category of every point; boxmode='group' then places one box per
        # method inside each speaker-count category.
        group_labels = (['2 Speakers'] * len(twosp_points)
                        + ['3 Speakers'] * len(threesp_points))
        traces.append(go.Box(
            y=sdr_values,
            x=group_labels,
            name=m,
            boxpoints='all',
            jitter=0.3,
            whiskerwidth=0.7,
            marker=dict(size=2),
            line=dict(width=2),
        ))

    # Paper x-positions of the method labels above the 2-speaker (left half)
    # and 3-speaker (right half) groups; the per-index nudges are hand-tuned.
    d1 = [0.04 + k * 0.081 for k in range(len(methods))]
    d1[1] += 0.005
    d1[-1] += 0.005
    d1[-2] += 0.01
    d2 = [0.56 + k * 0.081 for k in range(len(methods))]
    d2[2] += 0.02
    d2[-1] -= 0.01

    def _label(x, text):
        """Paper-anchored text annotation just below the top of the plot area."""
        return dict(
            x=x,
            y=0.98,
            showarrow=False,
            text=text,
            font=dict(family='sans serif', size=16, color='black'),
            align='center',
            xref='paper',
            yref='paper',
        )

    annotations = ([_label(x, text) for x, text in zip(d1, method_labels)]
                   + [_label(x, text) for x, text in zip(d2, method_labels)])

    axis_font = dict(family='Old Standard TT, serif', size=24, color='black')
    layout = go.Layout(
        yaxis=dict(
            autorange=True,
            showgrid=True,
            zeroline=True,
            dtick=3.,
            gridwidth=2,
            zerolinewidth=1,
            title='Absolute SDR (dB)',
            tickfont=dict(axis_font),
            titlefont=dict(axis_font),
        ),
        annotations=annotations,
        xaxis=dict(
            zeroline=False,
            titlefont=dict(size=24),
            tickfont=dict(axis_font),
        ),
        paper_bgcolor='rgb(255, 255, 255)',
        plot_bgcolor='rgb(255, 255, 255)',
        boxmode='group',
        showlegend=False,
        boxgap=0.1,
        # Tight margins for the paper figure (left margin left at its default).
        margin=go.layout.Margin(r=0, b=31, t=5, pad=0),
        legend=dict(
            y=1.15,
            yanchor='top',
            xanchor='center',
            x=0.5,
            traceorder='normal',
            font=dict(family='sans-serif', size=18, color='#000'),
            bgcolor='rgb(255, 255, 255)',
            orientation="h",
        ),
        width=width,
        height=height,
    )

    fig = go.Figure(data=traces, layout=layout)
    plotly.offline.iplot(fig, filename='SDR plot',
                         image_width=width,
                         image_height=height,
                         image='svg')
    return fig


# Trailing ';' stops Jupyter from echoing the returned figure a second time.
make_boxplots(sdr_plots_data['fm']);
# Other gender subsets: for gender in ['m', 'f', 'fm']: make_boxplots(sdr_plots_data[gender])