In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import pickle

from ipywidgets import Image, HTML, Layout, Output, VBox, HBox
from IPython.display import Audio, display
from path import CWD, DATA
from plotly import graph_objs as go


LABEL_COL = "label"       # Column that contains label - calls in the plot will be colored by this column
ID_COL = "callID"         # Column that contains call identifier
AUDIO_COL = "raw_audio"   # Column that contains audio data, which is played back.
                          # could also use filtered_audio

HOVER_COLS = [LABEL_COL, ID_COL] # hover_details are those data that will
                                 # be provided in the info table when hovering
                                 # over the plot. Add any columns of df that you like.

In [2]:
distinct_colors_20 = [
    '#e6194b',
    '#3cb44b',
    '#ffe119',
    '#4363d8',
    '#f58231',
    '#911eb4',
    '#46f0f0',
    '#f032e6',
    '#bcf60c',
    '#fabebe',
    '#008080',
    '#e6beff',
    '#9a6324',
    '#fffac8',
    '#800000',
    '#aaffc3',
    '#808000',
    '#ffd8b1',
    '#000075',
    '#808080',
    '#ffffff',
    '#000000'
]

P_DIR = str(CWD)

DF_NAME = DATA.joinpath('df_umap.pkl')

# Load dataframe
df = pd.read_pickle(DF_NAME)

# Load image data (deserialize)
with open(DATA.joinpath('image_data.pkl'), 'rb') as handle:
    image_data = pickle.load(handle)

if ID_COL not in df.columns:
    print("Missing identifier column: ", ID_COL)
    raise

labeltypes = sorted(
    list(
        set(df[LABEL_COL])
    )
)

if len(labeltypes) <= len(distinct_colors_20):
    color_dict = dict(
        zip(
            labeltypes,
            distinct_colors_20[0:len(labeltypes)]
        )
    )
else:
    # if > 20 different labels, some will have the same color
    distinct_colors = distinct_colors_20 * len(labeltypes)
    color_dict = dict(
        zip(
            labeltypes,
            distinct_colors[0:len(labeltypes)]
        )
    )

In [3]:
# hover_details are those data that will be provided in the info table when hovering
# over datapoints

hover_details = HOVER_COLS

# Everything here is separated by labeltype, so that all datapoints from one specific label have their own trace

audio_dict = {} # dictionary that contains audio data for each labeltype
sr_dict = {} # dictionary that contains samplerate data for each labeltype
sub_df_dict = {} # dictionary that contains the dataframe for each labeltype

# build dictionary
for i, labeltype in enumerate(labeltypes):
    sub_df = df.loc[df.label == labeltype, :]
    sub_df_dict[i] = sub_df
    audio_dict[i] = sub_df[AUDIO_COL]
    sr_dict[i] = sub_df['samplerate_hz']

# build traces
traces = []

for i, labeltype in enumerate(labeltypes):
    sub_df = sub_df_dict[i]
    trace = go.Scatter3d(
            x=sub_df.UMAP1,
            y=sub_df.UMAP2,
            z=sub_df.UMAP3,
            mode='markers',
            marker=dict(
                size=4,
                color=color_dict[labeltype],
                opacity=0.8
            ),
            name=labeltype,
            hovertemplate=[x for x in sub_df[LABEL_COL]]
    )
    traces.append(trace)

layout = go.Layout(
    scene=go.layout.Scene(
            xaxis=go.layout.scene.XAxis(title='UMAP1'),
            yaxis=go.layout.scene.YAxis(title='UMAP2'),
            zaxis=go.layout.scene.ZAxis(title='UMAP3')
    ),
    height=1000,
    width=1000
)

figure = go.Figure(data=traces, layout=layout)

fig = go.FigureWidget(figure)

# Initialize with any image (taking the first in the dictionary)

print(image_data[list(image_data.keys())[0]])

image_widget = Image(
    value=image_data[list(image_data.keys())[0]],
    layout=Layout(height='189px', width='300px')
)

details = HTML(
    layout=Layout(width='20%')
)


# define what happens when hovering over datapoint
def hover_fn(trace, points, state):
    if points.point_inds:
        # get the index of the data point being hovered
        trace_ind = points.trace_index
        sub_df = sub_df_dict[trace_ind]
        ind = points.point_inds[0]
        # Update image widget
        img_ind = sub_df.iloc[ind][ID_COL]
        image_widget.value = image_data[img_ind]

        # Update details
        details.value = sub_df.iloc[ind][hover_details].to_frame().to_html()


for i in range(len(traces)):
    fig.data[i].on_hover(hover_fn)


# audio-playback function
def play_audio(ind, i):
    data = audio_dict[i]
    srs = sr_dict[i]
    display(
        Audio(
            data.iloc[ind],
            rate=srs.iloc[ind],
            autoplay=True
        )
    )

audio_widget = Output()
audio_widget.layout.visibility = 'hidden'


# define what happens when clicking on datapoint
def click_fn(trace, points, selector):
    if points.point_inds:

        # get the index of the data point being hovered
        trace_ind = points.trace_index
        ind = points.point_inds[0]

        # play audio
        with audio_widget:
            play_audio(ind, trace_ind)


for i in range(len(traces)):
    fig.data[i].on_click(click_fn)

# Put everything together.

# This renders the plot within jupyter notebook. Install voila to convert the notebook into a standalone web app
# (see https://voila.readthedocs.io/en/stable/using.html for details)
# Once installed, navigate to the jupyter notebook file in your file system and run
# > voila <path-to-02_viz_tool.ipynb>

# adjust vertical (VBox) and horizontal (HBoxes) if readibility is not good.

VBox([
    details,
    HBox(
        [image_widget, fig],
        layout=Layout(flex="none")
    ),
    audio_widget
])

b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02\x80\x00\x00\x01\xe0\x08\x06\x00\x00\x005\xd1\xdc\xe4\x00\x00\x009tEXtSoftware\x00Matplotlib version3.5.1, https://matplotlib.org/\xd8a\xf2\xbd\x00\x00\x00\tpHYs\x00\x00\x0fa\x00\x00\x0fa\x01\xa8?\xa7i\x00\x003\xbcIDATx\x9c\xed\xddyt\x14e\xa2\xfe\xf1\xa7;\x9dtBH:\x84\x84\x84H\x80\xb0\x99(;\xca& HdS\x14e\\p\x19AD\xf42(p\x87\x9f\xe2\x02\xa3\xe3\x18\xb6+\x1ePAG\x84\xcb\x15\\pP\x06\x1d\x19\xb9 *\x10D\x10\x04T\xa2A\x10$$\x08\x98\x04B\xf6\xae\xdf\x1f\x86\xbeiAeIR\x1d\xde\xef\xe7\x9c>\xd2\xd5U\xd5O\xa5\x9a\xf0\xf8v\xd7\xdb\x0e\xcb\xb2,\x01\x00\x00\xc0\x18N\xbb\x03\x00\x00\x00\xa0fQ\x00\x01\x00\x00\x0cC\x01\x04\x00\x000\x0c\x05\x10\x00\x00\xc00\x14@\x00\x00\x00\xc3P\x00\x01\x00\x00\x0cC\x01\x04\x00\x000\x0c\x05\x10\x00\x00\xc00\x14@\x00\x00\x00\xc3P\x00\x01\x00\x00\x0cC\x01\x04\x00\x000\x0c\x05\x10\x00\x00\xc00\x14@\x00\x00\x00\xc3P\x00\x01\x00\x00\x0cC\x01\x04\x00\x000\x0c\x05\x10\x00\x00\xc00\x14@\x00\x00\x00\xc3P\x00\x01\x00\x00\x0cC\x01\x04\

VBox(children=(HTML(value='', layout=Layout(width='20%')), HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\…