In [None]:
import numpy as np
import plotly
from plotly.offline import iplot, plot
from plotly.offline import init_notebook_mode
from plotly import graph_objs as go
init_notebook_mode()

import pandas as pd
import pickle
import os
import seaborn as sns
from ipywidgets import Image, HTML, Layout, Output, VBox, HBox
from IPython.display import Audio, display

import pickle

In [25]:
# Audio sample rate in Hz; passed as `rate=` to IPython.display.Audio when a point is clicked.
SR = 48_000

In [4]:
# Main DataFrame; later cells read its label, UMAP1-UMAP3, raw_audio, callID
# and project_id columns.
# NOTE(review): unpickling can execute arbitrary code - only load trusted files.
SPEC_DF_PATH = 'spec_df.pkl'
df = pd.read_pickle(SPEC_DF_PATH)

In [5]:
# Shorten 'duration_s' to 'dur' and lower-case 'Start' (hover_fn reads 'start').
# Missing source columns are silently skipped by rename.
df = df.rename(columns={'duration_s': 'dur', 'Start': 'start'})

In [7]:
# Correct the misspelt label 'squal' -> 'squeal'; all other labels are kept unchanged.
df['label'] = ['squeal' if lbl == 'squal' else lbl for lbl in df['label']]

In [5]:
# Drop 'start_s', which none of the cells below use.
# errors='ignore' keeps the cell idempotent: re-running it after the column is
# already gone is a no-op instead of raising KeyError.
df = df.drop(columns=['start_s'], errors='ignore')

In [6]:
# Marker colour palette. Colours are assigned to label types in sorted-label
# order and reused cyclically when there are more label types than entries.
# Despite the name it holds 22 entries.
# NOTE(review): the 21st entry, '#ffffff' (white), may be hard to see against a
# light plot background.
distinct_colors_20 = [
    '#e6194b', '#3cb44b', '#ffe119', '#4363d8', '#f58231',
    '#911eb4', '#46f0f0', '#f032e6', '#bcf60c', '#fabebe',
    '#008080', '#e6beff', '#9a6324', '#fffac8', '#800000',
    '#aaffc3', '#808000', '#ffd8b1', '#000075', '#808080',
    '#ffffff', '#000000',
]

In [8]:
# Deserialize the image lookup table. hover_fn indexes it by a row's callID and
# assigns the result to an ipywidgets.Image value.
# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
with open('image_data.pickle', mode='rb') as pickle_file:
    image_data = pickle.load(pickle_file)

In [9]:
# One 3-D scatter trace per label type, in sorted-label order.
# Trace k corresponds to labeltypes[k]. hover_fn / click_fn depend on
# sub_df_dict and audio_dict being keyed by that same trace index k.
labeltypes = sorted(set(df['label']))

# Palette is reused cyclically when there are more label types than colours.
color_dict = {
    lbl: distinct_colors_20[k % len(distinct_colors_20)]
    for k, lbl in enumerate(labeltypes)
}

sub_df_dict = {}
audio_dict = {}
traces = []
for k, lbl in enumerate(labeltypes):
    rows = df.loc[df.label == lbl, :]
    sub_df_dict[k] = rows
    audio_dict[k] = rows['raw_audio']
    traces.append(
        go.Scatter3d(
            x=rows.UMAP1,
            y=rows.UMAP2,
            z=rows.UMAP3,
            mode='markers',
            marker=dict(size=4, color=color_dict[lbl], opacity=0.8),
            name=lbl,
            # per-point hover text: the point's label string
            hovertemplate=list(rows.label),
        )
    )

# height/width are Layout arguments (not Scene ones).
layout = go.Layout(
    scene=go.layout.Scene(
        xaxis=go.layout.scene.XAxis(title='UMAP1'),
        yaxis=go.layout.scene.YAxis(title='UMAP2'),
        zaxis=go.layout.scene.ZAxis(title='UMAP3'),
    ),
    height=700,
    width=700,
)

figure = go.Figure(data=traces, layout=layout)

In [10]:
# Wrap the static figure in a FigureWidget so per-trace on_hover / on_click
# callbacks can be registered on fig.data below.
fig = go.FigureWidget(data=figure)

In [13]:
# Side-panel image. It starts on one hard-coded callID (a key of image_data)
# and hover_fn replaces it with the hovered point's image.
# NOTE(review): raises KeyError if this callID is absent from image_data.
image_widget = Image(
    value=image_data['7.3.2004.PQ.d3b_35.655'],
    layout=Layout(width='300px', height='189px'),
)

In [14]:
# HTML pane, initially empty, that hover_fn fills with the hovered point's fields as a table.
details = HTML(value='')

In [20]:
def hover_fn(trace, points, state):
    """Plotly hover callback: show the hovered point's image and details.

    `points.trace_index` selects the per-label sub-frame (traces were built in
    the same order as sub_df_dict). The first entry of `points.point_inds` is
    the row position within that sub-frame. Does nothing when no point index
    is reported.
    """
    if not points.point_inds:
        return

    row = sub_df_dict[points.trace_index].iloc[points.point_inds[0]]

    # image_data is keyed by callID
    image_widget.value = image_data[row['callID']]

    # One-column HTML table of the selected fields
    shown = ['UMAP1', 'UMAP2', 'UMAP3', 'label', 'project_id', 'start']
    details.value = row[shown].to_frame().to_html()


for trace_obj in fig.data:
    trace_obj.on_hover(hover_fn)

In [27]:
def play_audio(ind, i):
    """Display an autoplaying Audio player for one row of a label group.

    Parameters
    ----------
    ind : int
        Row position within the group's raw_audio Series.
    i : int
        Trace index, used as the key into audio_dict.

    click_fn calls this inside `with audio_widget:`, so the player is rendered
    into that (hidden) Output widget.
    """
    data = audio_dict[i]
    display(Audio(data.iloc[ind], rate=SR, autoplay=True))


# Hidden sink for audio players. It is never visible, but autoplay still makes
# the clicked sound play.
audio_widget = Output()
audio_widget.layout.visibility = 'hidden'


def click_fn(trace, points, selector):
    """Plotly click callback: play the clicked point's raw audio.

    Clears the previous player first. Without this, every click appends
    another hidden autoplaying <audio> element to audio_widget, so old players
    are never removed and earlier clips can overlap with the new one.
    """
    if not points.point_inds:
        return

    # get the trace and the row position of the point being clicked
    trace_ind = points.trace_index
    ind = points.point_inds[0]

    audio_widget.clear_output()
    with audio_widget:
        play_audio(ind, trace_ind)


for i in range(len(traces)):
    fig.data[i].on_click(click_fn)

In [28]:
# Dashboard layout:
#   left column  - hover details, call image and the hidden audio output
#   right column - the interactive 3-D UMAP scatter
# Left as the cell's last expression so the notebook renders it.
HBox(children=[VBox(children=[details, image_widget, audio_widget]), fig])

HBox(children=(VBox(children=(HTML(value='<table border="1" class="dataframe">\n  <thead>\n    <tr style="text…