In [13]:
# Import libraries

import plotly.graph_objs as go
import numpy as np
import pandas as pd
import os
from ipywidgets import Video, Layout, VBox, HBox, HTML, widgets, interactive

# Name -> integer code for every task label; a name's code is its position
# in the enumerated list.
task2idx_dict = {
    name: code
    for code, name in enumerate([
        "vself", "vmin", "vmax", "headneck",
        "dtspeech", "dtmath", "dtcarry", "ec",
    ])
}

# Name -> integer code for every phenotype label.
# NOTE(review): "sensory ataxis" looks like a typo of "sensory ataxia"; it is
# kept verbatim because it is the displayed label and the lookup key.
pheno2idx_dict = {
    name: code
    for code, name in enumerate([
        "ataxia", "episodic", "hs", "hypokinetic", "normal", "nph", "paretic",
        "phobic", "ppv", "psychogenic", "sensory ataxis", "spastic", "suspectnph",
    ])
}

# Reverse lookups: integer code -> label name.
idx2pheno_dict = {code: name for name, code in pheno2idx_dict.items()}
idx2task_dict = {code: name for name, code in task2idx_dict.items()}


def idx2task(idx):
    """Return the task name for integer code `idx` (raises KeyError if unknown)."""
    return idx2task_dict[idx]


def idx2pheno(idx):
    """Return the phenotype name for integer code `idx` (raises KeyError if unknown)."""
    return idx2pheno_dict[idx]


def _tick_vals_and_texts(code2name):
    """Return ([codes...], [names...]) in the iteration order of `code2name`."""
    codes = list(code2name)
    return codes, [code2name[code] for code in codes]


def tick_val_text_tasks():
    """Return (codes, task names), used for colorbar ticks and dropdown options."""
    return _tick_vals_and_texts(idx2task_dict)


def tick_val_text_phenos():
    """Return (codes, phenotype names), used for colorbar ticks and dropdown options."""
    return _tick_vals_and_texts(idx2pheno_dict)


In [14]:
# initialize values

num_points = 3517  # number of leading samples to visualise from each array
data_dir = "data/"
# NOTE(review): not referenced by the visible cells (the video loader builds
# its own path) — confirm whether this placeholder is still needed.
videos_data_dir = os.path.join(data_dir, "videos/dirname/")

# Load the 2-D embedding and the per-sample integer label codes, keeping only
# the first `num_points` rows of each file.
motion_z_umap = np.load(os.path.join(data_dir, "motion_z_umap.npy"))[:num_points]
x, y = motion_z_umap[:, 0], motion_z_umap[:, 1]
# `np.int` was deprecated in NumPy 1.20 and removed in 1.24; the builtin `int`
# is the documented drop-in replacement.
tasks = np.load(os.path.join(data_dir, "tasks_labels.npy"))[:num_points].astype(int)
phenos = np.load(os.path.join(data_dir, "pheno_labels.npy"))[:num_points].astype(int)

# Slicing silently returns fewer rows when a file is shorter than num_points;
# fail early instead of plotting a mismatched subset.
assert x.shape[0] == tasks.shape[0] == phenos.shape[0] == num_points, (
    "expected {} rows per array, got x={}, tasks={}, phenos={}".format(
        num_points, x.shape[0], tasks.shape[0], phenos.shape[0]
    )
)

# Human-readable labels for hover text and the details panel.
tasks_text = [idx2task(i) for i in tasks]
phenos_text = [idx2pheno(i) for i in phenos]
df_info = pd.DataFrame({"Tasks": tasks_text, "Phenotypes": phenos_text})

print("x's shape = {}\ny's shape = {}\ntasks's shape = {}\nphenos's shape = {}\n".format(
    x.shape, y.shape, tasks.shape, phenos.shape
))
df_info.head()

x's shape = (3517,)
y's shape = (3517,)
tasks's shape = (3517,)
phenos's shape = (3517,)

      Tasks  Phenotypes
0     vself  suspectnph
1  dtspeech     spastic
2        ec  suspectnph
3    dtmath  suspectnph
4  headneck      ataxia


In [15]:
# # This section is for widget initialization

# Load video data and widget
def _read_sample_video(idx):
    """Return the raw bytes of the rendered clip for sample `idx`.

    The path is built from the configured `data_dir` rather than a hard-coded
    "data/" prefix.
    """
    path = os.path.join(data_dir, "videos", "equal_phenos",
                        "equal_phenos_{}.mp4".format(idx))
    with open(path, "rb") as f:
        return f.read()


# Every clip (one per plotted point) is read into memory up front so the hover
# callback can swap videos without touching the disk.
# NOTE(review): with thousands of clips this can be memory-heavy; consider
# reading lazily in the hover callback if it becomes a problem.
video_data = {i: _read_sample_video(i) for i in range(x.shape[0])}
video_widget = Video(
    value=video_data[0],
    layout=Layout(height='252px', width='400px'),
)

# HTML panel with the Tasks/Phenotypes labels of the selected sample; it
# starts on sample 0 and hover_fn updates it.
details = HTML(value=df_info.iloc[0].to_frame().to_html())

# Chooses which label set colours the scatter plot.
plot_type_widget = widgets.Dropdown(
    options=["Tasks", "Phenotypes"],
    value="Tasks",
    description="Plot Types:",
)

# Chooses one label to highlight ("ALL" = no highlighting). It is seeded with
# the task names because "Tasks" is the initial plot type.
focus_label_widget = widgets.Dropdown(
    options=["ALL", *tick_val_text_tasks()[1]],
    value="ALL",
    description="Label focus:",
)

# Slider for opacity
def set_opacity(focus_opacity, nonfocus_opacity):
    """Slider callback: re-apply point opacities for the current dropdown choices.

    `scatter` and `set_relevant_opacity` are defined in a later cell; they
    only have to exist by the time this callback actually runs.
    """
    current_plot_type = plot_type_widget.value
    current_focus = focus_label_widget.value
    set_relevant_opacity(scatter, current_plot_type, current_focus,
                         focus_opacity, nonfocus_opacity)


# `interactive` builds one slider per keyword from its (min, max, step) tuple
# and calls set_opacity with the slider values when they change.
opacity_slider = interactive(
    set_opacity,
    focus_opacity=(0.0, 1.0, 0.01),
    nonfocus_opacity=(0.0, 1.0, 0.01),
)

In [16]:
def scatter_plot_type(scatter, plot_type):
    """Colour `scatter` by task or phenotype labels and rebuild its colorbar.

    Parameters
    ----------
    scatter : plotly scatter trace
        Trace whose ``text`` and ``marker`` color/colorbar/colorscale are set.
    plot_type : {"tasks", "phenos"}
        Which label set to colour by; also used as the colorbar title.

    Raises
    ------
    ValueError
        If `plot_type` is not "tasks" or "phenos".
    """
    if plot_type == "tasks":
        target_text, target_color, tick_func = tasks_text, tasks, tick_val_text_tasks
    elif plot_type == "phenos":
        target_text, target_color, tick_func = phenos_text, phenos, tick_val_text_phenos
    else:
        # Without this branch the locals below would be unbound.
        raise ValueError(
            "plot_type must be 'tasks' or 'phenos', got {!r}".format(plot_type))

    tick_vals, tick_texts = tick_func()  # computed once for both tick fields

    scatter.text = target_text
    scatter.marker.color = target_color
    scatter.marker.colorbar = dict(
        title=plot_type,
        tickvals=tick_vals,
        ticktext=tick_texts,
        ticks='outside',
    )
    scatter.marker.colorscale = "Jet"
    scatter.marker.colorscale = "Jet"
    
def set_relevant_opacity(scatter, plot_type, focus_label, focus_opacity, nonfocus_opacity):
    """Set per-point marker opacity, highlighting the points of one label.

    Parameters
    ----------
    scatter : plotly scatter trace
        Its ``marker.opacity`` is replaced with a float array (one per point).
    plot_type : {"Tasks", "Phenotypes"}
        Which label array (`tasks` or `phenos`) `focus_label` refers to.
    focus_label : str
        A label name, or "ALL" to give every point `focus_opacity`.
    focus_opacity, nonfocus_opacity : float
        Opacity for points with / without `focus_label`.

    Raises
    ------
    ValueError
        If `plot_type` is not "Tasks" or "Phenotypes".
    KeyError
        If `focus_label` is neither "ALL" nor a name of the selected label set.
    """
    if plot_type == "Tasks":
        labels, label2idx = tasks, task2idx_dict
    elif plot_type == "Phenotypes":
        labels, label2idx = phenos, pheno2idx_dict
    else:
        # Previously this fell through and made every point fully transparent.
        raise ValueError(
            "plot_type must be 'Tasks' or 'Phenotypes', got {!r}".format(plot_type))

    if focus_label == "ALL":
        alphas = np.full(len(labels), focus_opacity, dtype=float)
    else:
        # Use the `focus_label` argument (not the dropdown widget) so the
        # function is driven only by its parameters for both label sets.
        is_focus = labels == label2idx[focus_label]
        alphas = np.where(is_focus, focus_opacity, nonfocus_opacity).astype(float)
    scatter.marker.opacity = alphas
    
    
def hover_fn(trace, points, state):
    """Plotly hover callback: show the labels and video of the hovered sample.

    Parameters
    ----------
    trace : plotly trace
        The trace that received the event (unused).
    points : plotly ``Points``
        ``points.point_inds`` lists the indices of the hovered points; only
        the first one is shown.
    state : plotly ``InputDeviceState``
        Keyboard/mouse state (unused).
    """
    # Nothing to show for an event without points (indexing [0] would raise).
    if len(points.point_inds) == 0:
        return
    ind = points.point_inds[0]

    # Update entry in shown dataframe
    details.value = df_info.iloc[ind].to_frame().to_html()

    # Swap in the clip for the hovered sample
    video_widget.value = video_data[ind]

def plot_type_response(change):
    """Observer for the plot-type dropdown.

    Recolours the scatter for the selected label set, then refills the focus
    dropdown with that set's names and resets it to "ALL". Any other value
    is ignored. `change` is not inspected; the widget's current value is read.
    """
    selection = {
        "Tasks": ("tasks", tick_val_text_tasks),
        "Phenotypes": ("phenos", tick_val_text_phenos),
    }.get(plot_type_widget.value)
    if selection is None:
        return

    colour_mode, tick_source = selection
    scatter_plot_type(scatter, colour_mode)
    focus_label_widget.options = ["ALL"] + tick_source()[1]
    focus_label_widget.value = "ALL"

def focus_label_response(change):
    """Observer for the focus dropdown: re-apply opacities with the current
    plot type, focus label and the two opacity-slider values.

    `change` is not inspected; all values are read from the widgets.
    """
    focus_slider, other_slider = opacity_slider.children[0], opacity_slider.children[1]
    set_relevant_opacity(
        scatter,
        plot_type_widget.value,
        focus_label_widget.value,
        focus_slider.value,
        other_slider.value,
    )


# Widen the two sliders built by `interactive` and give them short labels.
for slider_widget, slider_label in zip(opacity_slider.children[:2],
                                       ('Focus alpha', 'Other alpha')):
    slider_widget.layout.width = '300px'
    slider_widget.description = slider_label

        
# Interactive figure: a single WebGL scatter trace of the 2-D embedding.
# Colours, hover text and opacity are filled in below and by the widget
# callbacks.
fig = go.FigureWidget(
    data=[
        dict(
            type='scattergl',
            x=x,
            y=y,
            mode='markers',
        )
    ],
)

fig.layout.title = 'Latent space visualization'
# NOTE(review): `titlefont` is a deprecated alias of `title.font` in newer
# plotly releases — confirm it still works with the installed version.
fig.layout.titlefont.size = 12
fig.layout.xaxis.title = "x"
fig.layout.yaxis.title = "y"
# Fixed 600x600 px figure instead of autosizing to the container.
fig.layout.autosize = False
fig.layout.width, fig.layout.height = 600, 600
# Hover events target the single nearest point rather than all points
# sharing an x value.
fig.layout.hovermode = 'closest'

# Handle to the trace inside the widget; callbacks mutate it in place.
scatter = fig.data[0]
# Hover labels show only the per-point text set by scatter_plot_type.
scatter.hoverinfo = "text"
scatter_plot_type(scatter, "tasks")
scatter.marker.size = 8


# Wire the callbacks.
# NOTE(review): re-running only this cell registers the dropdown observers a
# second time on the widgets from the earlier cell, so every change would fire
# the handlers repeatedly — re-run the widget-initialisation cell as well.
scatter.on_hover(hover_fn)
plot_type_widget.observe(plot_type_response, names="value")
focus_label_widget.observe(focus_label_response, names="value")
# Last expression: figure on the left, controls / video / details on the right.
HBox([fig,
      VBox([plot_type_widget, focus_label_widget, opacity_slider, video_widget, details])])

HBox(children=(FigureWidget({
    'data': [{'hoverinfo': 'text',
              'marker': {'color': array([0, 4…

In [6]:
# Number of samples per phenotype label among the plotted points.
df_info["Phenotypes"].value_counts()

spastic        158
ataxia         154
normal         148
nph            148
suspectnph     147
hypokinetic    142
hs             100
ppv              3
Name: Phenotypes, dtype: int64