In [1]:
"""
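Multi-parameter screening plots: an interactive scatter plot of the screening
data with user-selectable x-axis, y-axis and colour variables.

To save the widget state with the notebook: Widgets > Embed Widgets
Environment: conda env base2

TODO:
    Do we also need a maximum-F0 threshold in the data conditioning?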
    
"""

import numpy as np
import plotly.express as px
import pandas as pd
import plotly.io as pio
import plotly.graph_objects as go
from ipywidgets import widgets

# Default renderer used by plotly's fig.show().
pio.renderers.default = 'browser'

# --- data conditioning -------------------------------------------------------
# Constructs failing any of these criteria are filtered out further down.
min_dff = 0        # lower bound on the DF/F columns (1, 3, 10 and 160 AP)
max_rise = 4       # upper bound on 'Rise (3 AP)' only
min_rise = 0.1     # lower bound on the Rise columns (1, 3, 10 and 160 AP)
min_decay = 0.01   # lower bound on the Decay columns (1, 3, 10 and 160 AP)

# Screening results table (raw string, so backslashes are taken literally).
csv_dir = r'Z:\ilya\code\fastGCaMP_analysis\python-plots\data/data_all_20200325_GCaMP96uf.csv'


# Load the screening table; '#NUM!' cells (a spreadsheet error value) become NaN.
data = pd.read_csv(csv_dir, na_values='#NUM!')

# Index rows by construct ID. set_index is given the Series, so 'Construct'
# also stays a regular column; it is used later for hover text and labels.
data = data.set_index(data['Construct'])

# Drop the auto-named 'Unnamed: N' columns pandas creates for blank headers.
data = data.drop(columns=data.columns[data.columns.str.contains('Unnamed:')])

# Drop control/placeholder rows. The TE-only control is listed under several
# spellings so whichever one the sheet uses is removed; errors='ignore' keeps
# the spellings that are absent from raising KeyError.
data = data.drop(index=['TEOnly', 'TE only', 'TE-only', 'none', '376.13'],
                 errors='ignore')

# Reference constructs: display name -> construct ID (the value found in the
# 'Construct' column). These constructs are labelled by name on the plot.
mapping = dict([
    ('GCaMP6s', '10.641'),
    ('jGCaMP8f', '500.456'),
    ('jGCaMP8m', '500.686'),
    ('jGCaMP8s', '500.688'),
    ('jGCaMP8.712', '500.712'),
    ('GCaMP6f', '10.693'),
    ('jGCaMP7f', '10.921'),
    ('jGCaMP7s', '10.1473'),
    ('jGCaMP7c', '10.1513'),
    ('jGCaMP7b', '10.1561'),
    ('XCaMP-Gf', '538.1'),
    ('XCaMP-G', '538.2'),
    ('XCaMP-Gf0', '538.3'),
])

# Rename columns to user-friendly names:
#   '<n>FP'      -> '<n> AP'         for n in 1, 3, 10, 160
#   'TimeToPeak' -> 'Time to peak'
# and then label the bare amplitude columns ('<n> AP') as 'DF/F (<n> AP)'.
def _friendly_column_name(name):
    """Return the user-friendly version of one raw column header."""
    for n_ap in (1, 3, 10, 160):
        name = name.replace(str(n_ap) + 'FP', str(n_ap) + ' AP')
    return name.replace('TimeToPeak', 'Time to peak')


data.columns = data.columns.map(_friendly_column_name)
data = data.rename(columns={'1 AP': 'DF/F (1 AP)',
                            '3 AP': 'DF/F (3 AP)',
                            '10 AP': 'DF/F (10 AP)',
                            '160 AP': 'DF/F (160 AP)'})

# NaNs: `data` itself is not NaN-filtered, so rows with NaNs still count toward
# 'Total constructs'. The filter below needs no explicit dropna, because any
# comparison against NaN is False. A construct with a NaN in one of the
# filtered columns is therefore excluded automatically.

## condition data to remove erroneous results
# Keep a construct only if, for every stimulus (1, 3, 10 and 160 AP),
# DF/F > min_dff, Rise > min_rise and Decay > min_decay, and in addition
# Rise (3 AP) < max_rise (the upper rise bound is applied at 3 AP only).
keep = data['Rise (3 AP)'] < max_rise
for n_ap in [1, 3, 10, 160]:
    keep &= ((data['DF/F ({} AP)'.format(n_ap)] > min_dff)
             & (data['Rise ({} AP)'.format(n_ap)] > min_rise)
             & (data['Decay ({} AP)'.format(n_ap)] > min_decay))
data_filt = data[keep]


# Per-point text labels for the scatter plot: the reference name from
# `mapping` for highlighted constructs, '' for every other construct.
mapping_swapped = {construct_id: name for name, construct_id in mapping.items()}
hightlight_txt_array = [mapping_swapped.get(construct_id, '')
                        for construct_id in data_filt['Construct']]

# Report how many constructs survived the data-conditioning filters.
n_total = len(data)
n_kept = len(data_filt)
print('Total constructs: ' + str(n_total))
print('Filtered constructs: ' + str(n_kept))
print('Filtered out: ' + str(n_total - n_kept))

# Variables shown before any widget is touched. Keep these in sync with the
# dropdown defaults so the initial figure matches the selectors. Data, axis
# titles and hover text are all derived from these names, so they cannot
# drift apart (previously the y-axis title said 'DF/F (1 AP)' while the
# y data was 'Rise (1 AP)').
init_x, init_y, init_color = 'DF/F (1 AP)', 'Rise (1 AP)', 'Decay (1 AP)'

# Plasma colour scale, colours at evenly spaced stops 0, 1/9, ..., 1.
plasma = ['#0d0887', '#46039f', '#7201a8', '#9c179e', '#bd3786',
          '#d8576b', '#ed7953', '#fb9f3a', '#fdca26', '#f0f921']

g = go.FigureWidget({
    'data': [{'customdata': data_filt['Construct'],
              'hovertemplate': ('<b>%{customdata}</b><br>'
                                + init_x + '=%{x:.3f}<br>'
                                + init_y + '=%{y:.3f}<br>'
                                + init_color + '=%{marker.color:.3f}'),
              'legendgroup': '',
              'marker': {'color': data_filt[init_color],
                         'coloraxis': 'coloraxis',
                         # constant size; sizemode/sizeref are leftovers from
                         # sizing points by a data column
                         'size': 10,
                         'sizemode': 'area',
                         'sizeref': 0.02,
                         'symbol': 'circle'},
              # names of the reference constructs, '' for all others
              'text': hightlight_txt_array,
              'textposition': 'top center',
              'mode': 'text+markers',
              'name': '',
              'orientation': 'v',
              'showlegend': False,
              'type': 'scatter',
              'x': data_filt[init_x],
              'xaxis': 'x',
              'y': data_filt[init_y],
              'yaxis': 'y'}],
    'layout': {'coloraxis': {'colorbar': {'title': {'text': init_color}},
                             'colorscale': [[i / (len(plasma) - 1), colour]
                                            for i, colour in enumerate(plasma)]},
               'legend': {'itemsizing': 'constant', 'tracegroupgap': 0},
               'margin': {'t': 60},
               'height': 500,
               'width': 700,
               'xaxis': {'anchor': 'y', 'domain': [0.0, 1.0], 'title': {'text': init_x}},
               'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0], 'title': {'text': init_y}}}
})

# Columns offered in the x / y / colour dropdowns: each metric at each
# stimulus (1, 3, 10 and 160 AP), metric-major order, followed by 'Norm. F0'.
plottableVars = ([metric + ' (' + n_ap + ' AP)'
                  for metric in ('DF/F', 'Rise', 'Time to peak', 'Decay', 'SNR')
                  for n_ap in ('1', '3', '10', '160')]
                 + ['Norm. F0'])

# GUI elements
# Three variable selectors sharing one option list: (default value, label).
x_dropdown, y_dropdown, color_dropdown = (
    widgets.Dropdown(options=plottableVars, value=default, description=label)
    for default, label in (('DF/F (1 AP)', 'X axis:'),
                           ('Rise (1 AP)', 'Y axis:'),
                           ('Decay (1 AP)', 'color:'))
)

# Linear/log toggles for the two axes.
xscale_radio, yscale_radio = (
    widgets.RadioButtons(options=['linear', 'log'], description=label, disabled=False)
    for label in ('X scale:', 'Y scale:')
)

# Output area that shows the details of a clicked point.
out = widgets.Output(layout={'border': '1px solid black'})

def response(change):
    """Redraw the scatter plot from the current widget selections.

    Registered with ``observe`` on every dropdown and scale radio button.
    ``change`` is not inspected: the full state is re-read from the widgets
    on each call, and all figure updates are sent as one batch.
    """
    x_name = x_dropdown.value
    y_name = y_dropdown.value
    color_name = color_dropdown.value
    trace = g.data[0]

    with g.batch_update():
        # point data
        trace['x'] = data_filt[x_name]
        trace['y'] = data_filt[y_name]
        trace.marker.color = data_filt[color_name]

        # titles follow the selected variables
        g.layout.xaxis.title.text = x_name
        g.layout.yaxis.title.text = y_name
        g.layout.coloraxis.colorbar.title.text = color_name

        # 'linear' or 'log', straight from the radio buttons
        g.layout.xaxis.type = xscale_radio.value
        g.layout.yaxis.type = yscale_radio.value

        # hover text names the selected variables
        trace.hovertemplate = ('<b>%{customdata}</b><br>'
                               + x_name + '=%{x:.3f}<br>'
                               + y_name + '=%{y:.3f}<br>'
                               + color_name + '=%{marker.color:.3f}')
        
# click behavior (https://plotly.com/python/click-events/)
def update_point(trace, points, selector):
    """Show every plottable value of the clicked construct in the `out` widget.

    Click callback for the scatter trace. ``points.point_inds`` holds positional
    indices into the trace data. That data was built from ``data_filt`` in row
    order, so these indices are also row positions in ``data_filt``. Only the
    first clicked point is shown. ``trace`` and ``selector`` are unused.
    """
    if not points.point_inds:
        # No point of this trace was clicked. Keep the current table instead
        # of raising IndexError on an empty index list.
        return
    clicked_row = data_filt.iloc[points.point_inds[0]]
    out.clear_output()
    out.append_display_data(pd.DataFrame(clicked_row[plottableVars]))


g.data[0].on_click(update_point)

# Redraw the figure whenever any selector or scale toggle changes value.
for control in (x_dropdown, y_dropdown, color_dropdown, xscale_radio, yscale_radio):
    control.observe(response, names="value")

# Layout: variable selectors, scale toggles, the figure, then the click table.
dropdown_wdgts = widgets.HBox([x_dropdown, y_dropdown, color_dropdown])
scale_wdgets = widgets.HBox([xscale_radio, yscale_radio])

v = widgets.VBox([dropdown_wdgts, scale_wdgets, g, out])
v


Total constructs: 776
Filtered constructs: 716
Filtered out: 60


VBox(children=(HBox(children=(Dropdown(description='X axis:', options=('DF/F (1 AP)', 'DF/F (3 AP)', 'DF/F (10…