In [1]:
import dash
import dash_core_components as dcc
import dash_html_components as html
import numpy as np
import pandas as pd
import plotly.express as px
import tensorflow as tf
import wandb

from braivest.analysis.wandb_utils import load_wandb_model
from braivest.model.emgVAE import emgVAE
from braivest.model.emgVAE import emgVAE  # NOTE(review): duplicate of the line above
from braivest.preprocess.dataset_utils import bin_data
from braivest.preprocess.wavelet_utils import get_wavelet_freqs
from braivest.utils import load_data
# NOTE(review): presumably supplies plot_encodings / get_feature_color; names imported
# after this line take precedence over same-named exports from the star import.
from braivest.analysis.plotting_utils import *

# Kept after the star import (as originally) so this `welch` always wins.
from scipy.signal import welch

Connecting julwang@pioneer.cshl.edu:3306


In [2]:
# Dash application object; its layout and callbacks are attached in later cells.
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
app = dash.Dash(
    __name__,
    external_stylesheets=external_stylesheets,
)

In [None]:
# Fetch the analysis artifact (session data and hypnogram) from W&B; the download
# run ends when the `with` block exits.
with wandb.init(project="braivest_tutorial", job_type="download") as run:
    artifact = run.use_artifact("analysis_set:v0")
    artifact_dir = artifact.download()

# NOTE(review): allow_pickle=True presumably reaches np.load, which unpickles the file and
# can execute arbitrary code -- only use with artifacts from a trusted project.
subject0_sess, hypno = (
    load_data(artifact_dir, file_name, allow_pickle=True)
    for file_name in ('sess_datas.npy', "hypno.npy")
)

In [3]:
model = load_wandb_model("juliahwang/lfp_VAE/v2l9tltt")
# Embed the first entry of the session data and turn the encoder output into a plain
# NumPy array (via a TF tensor) so the callbacks below can use NumPy ops on it.
encodings = (
    tf.convert_to_tensor(model.encode(subject0_sess[0]))
    .numpy()
)

Instructions for updating:
Do not pass `graph_parents`.  They will  no longer be used.


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f6d61a2f580>

In [4]:
# Fetch the wavelet-transformed LFP for session 0 and bin it.
with wandb.init(project="braivest_tutorial", job_type="download") as run:
    wavelet_artifact = run.use_artifact("wavelet_data:v0")
    wavelet_artifact_dir = wavelet_artifact.download()

# NOTE(review): allow_pickle=True on downloaded data can execute arbitrary code if the
# file is untrusted (assuming it is forwarded to np.load).
wavelet = load_data(
    wavelet_artifact_dir,
    "lfp_wave_session{}.npy".format(0),
    allow_pickle=True,
)
binned_wavelet = bin_data(wavelet, sampling_rate=0.5, original_sample=1000)

In [7]:
# Fetch raw LFP and EMG for session 0, bin the LFP, and compute a Welch PSD per bin
# (f = frequency axis, Pxx = power) for the band-colouring in create_scatter.
with wandb.init(project="braivest_tutorial", job_type="download") as run:
    raw_artifact = run.use_artifact("raw_data:v0")
    raw_artifact_dir = raw_artifact.download()

# NOTE(review): allow_pickle=True on downloaded data can execute arbitrary code if the
# file is untrusted (assuming it is forwarded to np.load).
lfp, emg = (
    load_data(raw_artifact_dir, template.format(0), allow_pickle=True)
    for template in ("lfp_session{}.npy", "emg_session{}.npy")
)

binned_lfp = bin_data(lfp, original_sample=1000, sampling_rate=0.5)
f, Pxx = welch(binned_lfp, fs=1000)

In [8]:
@app.callback(dash.dependencies.Output('encodings', 'figure'),
              [dash.dependencies.Input('color', 'value')])
def create_scatter(color):
    """Build the embedding scatter, coloured according to the dropdown choice.

    Parameters
    ----------
    color : str or None
        Dropdown value. ``'expert'`` colours points by the hypnogram ``hypno``;
        ``'delta'``, ``'theta'``, ``'beta'`` or ``'gamma'`` colour them by the
        band feature from ``get_feature_color`` on the Welch PSD (``Pxx``, ``f``).
        None when the dropdown has been cleared.

    Returns
    -------
    The figure produced by ``plot_encodings``.

    Raises
    ------
    dash.exceptions.PreventUpdate
        When ``color`` is None, so the current figure is kept.
    """
    if color is None:
        raise dash.exceptions.PreventUpdate

    if color == 'expert':
        # Use separate local names: rebinding `encodings` here made it a local
        # variable and raised UnboundLocalError in both branches.
        # Plot only the epochs present in both arrays so colours line up with points.
        n_points = min(len(encodings), len(hypno))
        expert_encodings = encodings[:n_points]
        expert_hypno = hypno[:n_points]
        # np.unique returns sorted labels; they are mapped to stage names in that order.
        # NOTE(review): assumes the sorted label order is REM, SWS, Wake, X -- confirm.
        legend = dict(zip(np.unique(hypno), ['REM', 'SWS', 'Wake', 'X']))
        # 'X' has no explicit colour and falls back to the plot's default palette.
        color_map = {'REM': "#0000ff", "Wake": "#ff0000", "SWS": "#00ff00"}
        return plot_encodings(encodings=expert_encodings,
                              color=[legend[label] for label in expert_hypno],
                              color_map=color_map, x_range=(-6, 3))

    # NOTE(review): these edges are passed to get_feature_color together with the Welch
    # frequency axis `f`; they do not match conventional EEG band limits in Hz -- confirm
    # the units get_feature_color expects.
    bands = {'delta': (0, 2), 'theta': (2, 4), 'beta': (4, 8), 'gamma': (8, 13)}
    start, stop = bands[color]
    band_color = get_feature_color(Pxx, f, start, stop)
    return plot_encodings(encodings, color=band_color, x_range=(-6, 3),
                          scatter_kwargs={'color_continuous_scale': 'portland', 'range_color': (-3, 3)})

In [9]:
# Render the scatter once outside the server as a sanity check, using the dropdown's
# default value ('expert'); the original call omitted the required `color` argument.
# The Dash decorator may return a wrapper that expects callback-only kwargs, so call the
# undecorated function (exposed by functools.wraps as __wrapped__) when it exists.
fig = getattr(create_scatter, '__wrapped__', create_scatter)('expert')
fig

In [10]:
# Two columns: left holds the colour-mode dropdown and the embedding scatter; right holds
# the three panels filled by the click callbacks (raw LFP, EMG, binned wavelet values).
COLOR_CHOICES = None  # placeholder removed below
del COLOR_CHOICES
left_column = None
del left_column
app.layout = html.Div([
    html.Div([
        dcc.Dropdown(
            id='color',
            options=[{'label': choice, 'value': choice}
                     for choice in ('expert', 'delta', 'theta', 'beta', 'gamma')],
            value='expert'
        ),
        dcc.Graph(id='encodings'),
        # NOTE(review): '0 20' has no CSS units, so browsers likely ignore this padding;
        # switching to '0 20px' may need box-sizing: border-box to keep both 49% columns on one row.
    ], style={'width': '49%', 'display': 'inline-block', 'padding': '0 20'}),
    html.Div(
        [dcc.Graph(id=graph_id) for graph_id in ('raw', 'emg', 'original')],
        style={'display': 'inline-block', 'width': '49%'},
    )
])

In [11]:
# Frequency axis for the 'original' panel (plotted against binned_wavelet values).
# NOTE(review): arguments presumably are (low, high, count) -- confirm against get_wavelet_freqs.
freqs = get_wavelet_freqs(
    0.5,
    50,
    30,
)

In [12]:
@app.callback(
    dash.dependencies.Output('raw', 'figure'),
    [dash.dependencies.Input('encodings', 'clickData')])
def update_raw(hoverData):
    """Plot the raw LFP around the bin of the clicked embedding point.

    Parameters
    ----------
    hoverData : dict or None
        ``clickData`` of the 'encodings' graph (despite the name, the callback is
        wired to clicks). Dash passes None before the first click.

    Returns
    -------
    plotly.graph_objects.Figure
        ``lfp`` from 5 s before to 7 s after the start of the clicked bin (clipped to
        the recording), with the bin itself (0-2 s) shaded.

    Raises
    ------
    dash.exceptions.PreventUpdate
        When nothing has been clicked yet, so the panel is left unchanged.
    """
    if not hoverData or not hoverData.get('points'):
        raise dash.exceptions.PreventUpdate
    point = hoverData['points'][0]
    # Nearest embedding row to the click. The former exact test
    # `np.where(encodings == [x, y])[0][0]` matched rows where *either* coordinate
    # was equal, and exact float equality is fragile.
    sq_dist = np.sum((encodings - [point['x'], point['y']]) ** 2, axis=1)
    index = int(np.argmin(sq_dist))

    fs = 1000          # raw samples per second; the x axis is in seconds
    bin_len = 2 * fs   # samples per bin (bin i starts at sample i * 2000)
    raw_index = index * bin_len
    # Clip the [-5 s, +7 s) window to the recording so bins near either end do not
    # produce a negative slice start or an x/y length mismatch.
    start = max(raw_index - 5 * fs, 0)
    stop = min(raw_index + 7 * fs, len(lfp))
    data = lfp[start:stop]
    time_s = (np.arange(start, stop) - raw_index) / fs
    fig = px.line(x=time_s, y=data, title="Raw Data")
    fig.add_vrect(x0=0, x1=2, fillcolor="LightSalmon", opacity=0.5, layer="below", line_width=0)
    fig.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})
    return fig

In [13]:
@app.callback(
    dash.dependencies.Output('emg', 'figure'),
    [dash.dependencies.Input('encodings', 'clickData')])
def update_raw_emg(hoverData):
    """Plot the raw EMG around the bin of the clicked embedding point.

    Parameters
    ----------
    hoverData : dict or None
        ``clickData`` of the 'encodings' graph (despite the name, the callback is
        wired to clicks). Dash passes None before the first click.

    Returns
    -------
    plotly.graph_objects.Figure
        ``emg`` from 5 s before to 7 s after the start of the clicked bin (clipped to
        the recording), y fixed to (-1.1, 1.1), with the bin itself (0-2 s) shaded.

    Raises
    ------
    dash.exceptions.PreventUpdate
        When nothing has been clicked yet, so the panel is left unchanged.
    """
    if not hoverData or not hoverData.get('points'):
        raise dash.exceptions.PreventUpdate
    point = hoverData['points'][0]
    # Nearest embedding row to the click (the former exact `==` test matched rows where
    # either coordinate was equal and depended on exact float equality).
    sq_dist = np.sum((encodings - [point['x'], point['y']]) ** 2, axis=1)
    index = int(np.argmin(sq_dist))

    fs = 1000          # raw samples per second; the x axis is in seconds
    bin_len = 2 * fs   # samples per bin (bin i starts at sample i * 2000)
    raw_index = index * bin_len
    # Clip the [-5 s, +7 s) window to the recording to avoid a negative slice start or
    # an x/y length mismatch near either end.
    start = max(raw_index - 5 * fs, 0)
    stop = min(raw_index + 7 * fs, len(emg))
    data = emg[start:stop]
    time_s = (np.arange(start, stop) - raw_index) / fs
    # Title previously said "Raw Data" (copied from the LFP panel).
    fig = px.line(x=time_s, y=data, title="Raw EMG", range_y=(-1.1, 1.1))
    fig.add_vrect(x0=0, x1=2, fillcolor="LightSalmon", opacity=0.5, layer="below", line_width=0)
    fig.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})
    return fig

In [14]:
@app.callback(
    dash.dependencies.Output('original', 'figure'),
    [dash.dependencies.Input('encodings', 'clickData')])
def update_original(hoverData):
    """Plot the unscaled binned wavelet values of the clicked embedding point.

    Parameters
    ----------
    hoverData : dict or None
        ``clickData`` of the 'encodings' graph (despite the name, the callback is
        wired to clicks). Dash passes None before the first click.

    Returns
    -------
    plotly.graph_objects.Figure
        Line plot of ``binned_wavelet[index]`` against ``freqs``.

    Raises
    ------
    dash.exceptions.PreventUpdate
        When nothing has been clicked yet, so the panel is left unchanged.
    """
    if not hoverData or not hoverData.get('points'):
        raise dash.exceptions.PreventUpdate
    point = hoverData['points'][0]
    # Nearest embedding row to the click (the former exact `==` test matched rows where
    # either coordinate was equal and depended on exact float equality).
    sq_dist = np.sum((encodings - [point['x'], point['y']]) ** 2, axis=1)
    index = int(np.argmin(sq_dist))
    # One value per wavelet frequency; slicing to len(freqs) keeps x and y the same
    # length (identical to the former [:30] when freqs has 30 entries).
    data = binned_wavelet[index, :len(freqs)]
    fig = px.line(x=freqs, y=data, title="Original No Scaling")
    fig.update_layout(height=225, margin={'l': 20, 'b': 30, 'r': 10, 't': 10})
    return fig

In [None]:
# Start the Dash server (this cell blocks while it runs). The Flask reloader is turned
# off because it re-launches the Python process, which does not work from a notebook kernel.
app.run_server(use_reloader=False, debug=True)

Dash is running on http://127.0.0.1:8050/

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: on
