## jGCaMP8 interactive multiparameter plot

Created with [plotly](https://plotly.com/)

### Instructions
Click the `Voila` button on the top banner to run the notebook and display the interactive plot.

### Operation

All features can be plotted on the X axis, Y axis, or incorporated into the colormap. You can zoom in, pan, and scale axes using the pop-up menu on the top right of the plot. To return to the default view, click the "Reset axes" button on the pop-up menu. Click on any construct to show all construct features in a table at the bottom of the page. Controls, jGCaMP8 series, and XCaMP series constructs are highlighted in red.

**All features are normalized to in-plate GCaMP6s controls**. 

For example, the table below (generated by clicking on a construct) should be interpreted as: "The DF/F (1 AP) of construct 500.656 is 4.96-fold that of GCaMP6s. The half-rise time (1 AP) of construct 500.656 is 0.28-fold that of GCaMP6s (i.e. 500.656 is 3.6x faster)."

| | 500.656|
| ----------- | ----------- |
| **DF/F (1 AP)** | 4.96 |
|**Half-rise time (1 AP)** | 0.28 |
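
The reading of the table above can be reproduced with a few lines of arithmetic. This is only a sketch: the two values are copied from the example table, and the helper name `speedup` is made up for illustration and is not part of the notebook.

```python
# Normalized values copied from the example table (GCaMP6s control = 1.0).
normalized = {
    "DF/F (1 AP)": 4.96,
    "Half-rise time (1 AP)": 0.28,
}


def speedup(normalized_time):
    """Turn a normalized time into a 'times faster than control' factor."""
    return 1.0 / normalized_time


dff_fold = normalized["DF/F (1 AP)"]
rise_speedup = speedup(normalized["Half-rise time (1 AP)"])

print(f"DF/F (1 AP) is {dff_fold:.2f}-fold that of GCaMP6s")
print(f"Half-rise time (1 AP) is {rise_speedup:.1f}x faster than GCaMP6s")
```

For amplitude-like features (DF/F, SNR, dprime) the normalized value is read directly as a fold change, while for time-like features a value below 1 means the construct is faster than the control by the reciprocal factor.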



### Widget controls

* **X axis / Y axis / color**: Choose the feature plotted on each axis and the feature used for the colormap.

* **X scale, Y scale**: Switch each axis between linear and logarithmic scaling.

* **Show all construct names**: Turn on to label every construct. _Note: the construct names may not show up immediately. You may need to pan or zoom once to make them appear._


### Contact
Ilya Kolb ([email](mailto:kolbi@hhmi.org))

In [2]:
"""
multi-parameter screening plots

"""

import numpy as np
import plotly.express as px
import pandas as pd
import plotly.io as pio
import plotly.graph_objects as go
from ipywidgets import widgets
from utils import condition_df, filter_construct_range

pio.renderers.default='browser'

plottableVars = ['DF/F (1 AP)', 'DF/F (3 AP)', 'DF/F (10 AP)', 'DF/F (160 AP)',
       'Half-rise time (1 AP)', 'Half-rise time (3 AP)', 'Half-rise time (10 AP)', 'Half-rise time (160 AP)',
       'Time to peak (1 AP)', 'Time to peak (3 AP)', 'Time to peak (10 AP)',
       'Time to peak (160 AP)', 'Half-decay time (1 AP)', 'Half-decay time (3 AP)',
       'Half-decay time (10 AP)', 'Half-decay time (160 AP)', 
       'SNR (1 AP)', 'SNR (3 AP)', 'SNR (10 AP)', 'SNR (160 AP)',
        'dprime (1 AP)', 'dprime (3 AP)', 'dprime (10 AP)', 'dprime (160 AP)',
       'Norm. F0']

# data conditioning (filter out everything that does not pass these criteria)
min_dff = 0
max_rise = 4
min_rise = 0.1
min_decay = 0.01
timetopeak_max = 3

csv_dir = r'./data/data_all_GCaMP6scontrol.csv'
id_seq_dir = r'./data/jG8-id-and-seq.csv'
failed_wells_dir = r'./data/failed_wells.csv'

mapping = {'GCaMP6s': '10.641' , 'jGCaMP8f': '500.456', 'jGCaMP8m': '500.686', 'jGCaMP8s': '500.688','jGCaMP8.712': '500.712',
           'GCaMP6f': '10.693', 'jGCaMP7f': '10.921', 'jGCaMP7s': '10.1473', 'jGCaMP7c': '10.1513', 'jGCaMP7b': '10.1561',
           'XCaMP-Gf': '538.1', 'XCaMP-G': '538.2', 'XCaMP-Gf0': '538.3'}
mapping_swapped = {value: key for key, value in mapping.items()}  # construct ID -> colloquial name
data = pd.read_csv(csv_dir, na_values = 'NaN', sep='\t') # data from screen

#load and process id-sequence map table
data_id_seq = pd.read_csv(id_seq_dir)
data_id_seq['ID'] = data_id_seq['ID'].str.replace('dot', '.')
data_id_seq = data_id_seq.set_index('ID', verify_integrity=True, drop=True)

# load and process id-longname map table


# load list of failed wells
data_failed_wells = pd.read_csv(failed_wells_dir)

'''
add constructs that failed segmentation to the list, i.e. constructs in data_failed_wells that are NOT in the full dataset
remove TE (dummy wells)
concatenate to main dataset
'''
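# exact-match lookup (str.fullmatch would treat the '.' in construct IDs as a regex wildcard)
known_constructs = set(data['construct'])
failed_seg_constructs = [d for d in data_failed_wells['Construct'] if d not in known_constructs]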
failed_seg_constructs = [f for f in failed_seg_constructs if ('TE' not in f)]
data = pd.concat([data, pd.DataFrame(data={'construct':failed_seg_constructs})])

data = data.set_index('construct', drop=True, verify_integrity=True)

# condition data table
data = condition_df(data)

# drop 500.723-1000
data = data[[not filter_construct_range(i, '500', 722, 1000) for i in data.index]]

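# NOTE: dropna() returns a new DataFrame, so this call has no effect because its result is discarded.
# Assigning the result would also drop the all-NaN rows added above for constructs that failed segmentation.
data.dropna(axis = 0, how = 'any')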
data = data.drop(columns=['d_fmax_f0', 'd_fmax_f0_p', 'es50', 'first_assay_date', 'last_assay_date',
                          'variant_type'])  # remove extra columns

# add aa sequence information
data = data.join(data_id_seq) 
data['DNA sequence'] = data['DNA sequence'].replace(np.nan, 'Not available')
data['AA sequence'] = data['AA sequence'].replace(np.nan, 'Not available')
data['mutations'] = data['mutations'].replace(np.nan, 'Not available')

# pretty column names
data = data.rename(columns={'construct':'Construct', 'replicate_number': 'Number of wells', 'sequence': 'DNA sequence', 'aa_sequence': 'AA sequence'})

# add mutation name (long name)
data['Clone name'] = data['Clone name'].replace(np.nan, 'Not available')

# blank out (set to NaN) the metrics of sensors that didn't pass QC
data_filt = data.copy()
data_filt.loc[(data_filt['Half-rise time (1 AP)'] >= max_rise) |
         (data_filt['Half-rise time (3 AP)'] >= max_rise) |
         (data_filt['DF/F (1 AP)'] <= min_dff) | 
         (data_filt['DF/F (3 AP)'] <= min_dff) | 
         (data_filt['DF/F (10 AP)'] <= min_dff) | 
         (data_filt['DF/F (160 AP)'] <= min_dff) | 
         (data_filt['Half-rise time (1 AP)'] <= min_rise) |
         (data_filt['Half-rise time (3 AP)'] <= min_rise) |
         (data_filt['Half-rise time (10 AP)'] <= min_rise) |
         (data_filt['Half-rise time (160 AP)'] <= min_rise) |
         (data_filt['Half-decay time (1 AP)'] <= min_decay) |
         (data_filt['Half-decay time (3 AP)'] <= min_decay) |
         (data_filt['Half-decay time (10 AP)'] <= min_decay) |
         (data_filt['Half-decay time (160 AP)'] <= min_decay) | 
         (data_filt['Time to peak (1 AP)'] >= timetopeak_max) | 
         (data_filt['Time to peak (3 AP)'] >= timetopeak_max) |
         (data_filt['Time to peak (10 AP)'] >= timetopeak_max) | 
         (data_filt['Time to peak (160 AP)'] >= timetopeak_max), plottableVars + list(data.columns[data.columns.str.endswith('_p')].values)]= np.nan


# for public use, sort by DF/F (1 AP) in descending order and substitute colloquial names
data_filt_public = data_filt.sort_values(by='DF/F (1 AP)', ascending=False).copy()
data_filt_public = data_filt_public.rename(index=mapping_swapped)

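# add failed segmentation label
# (cast to object first so that string labels can be stored in the numeric column)
data_filt_public['DF/F (1 AP)'] = data_filt_public['DF/F (1 AP)'].astype(object)
data_filt_public.loc[data_filt_public.index.isin(failed_seg_constructs), 'DF/F (1 AP)'] = 'Failed segmentation'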

# add No detectable response label
data_filt_public.loc[data_filt_public['DF/F (1 AP)'].isna(), 'DF/F (1 AP)'] = 'No detectable response'

# save filtered data
print('Saving data_filt pickle... ', end='')
data_filt_public.to_pickle('outputs/data_filt_public.pkl')
print('Done')

# save filtered data csv
print('Saving data_filt csv... ', end='')
data_filt_public.to_csv(r'outputs/data_filt_public.csv')
print('Done')

# hightlight_txt_array: text label for each point (colloquial name for constructs in mapping, '' otherwise)
hightlight_txt_array = [mapping_swapped.get(c_id, '') for c_id in data_filt.index]
highlight_TF_array = np.array(hightlight_txt_array) != ''

print('Total constructs: ' + str(len(data)))

failure_reasons = ['Failed segmentation', 'No detectable response']

for f in failure_reasons:
    print(f + ': ' + str((data_filt_public['DF/F (1 AP)'] == f).sum()))

g = go.FigureWidget({
    'data': [{'customdata': data_filt.index,
              'hovertemplate': '<b>%{customdata}</b><br>DF/F (1 AP)=%{x:.3f}<br>Half-rise time (1 AP)=%{y:.3f}<br>Half-decay time (1 AP)=%{marker.color:.3f}', 
              # ('%{x}<br>Half-rise (1FP)=%{y}<br' ... '{customdata[0]}<extra></extra>'),
              'legendgroup': '',
              'marker': {'color': data_filt['Half-decay time (1 AP)'],
                         'coloraxis': 'coloraxis',
                         'size': 10, # data_filt['Decay (1FP)'],
                         'sizemode': 'area',
                         'sizeref': 0.02,
                         'symbol': 'circle',
                         'opacity': 0.4,
                         'line' : {
                             'color': 'red',
                             'width': 2 * highlight_TF_array,
                            }
                        },
              'text': hightlight_txt_array,
              'textfont': {'color': 'red'},
              'textposition': 'top center',
              'mode': 'text+markers',
              'name': '',
              'orientation': 'v',
              'showlegend': False,
              'type': 'scatter',
              'x': data_filt['DF/F (1 AP)'],
              'xaxis': 'x',
              'y': data_filt['Half-rise time (1 AP)'],
              'yaxis': 'y'}],
    'layout': {'coloraxis': {'colorbar': {'title': {'text': 'Half-decay time (1 AP)'}},
                             'colorscale': [[0.0, '#0d0887'], [0.1111111111111111,
                                            '#46039f'], [0.2222222222222222,
                                            '#7201a8'], [0.3333333333333333,
                                            '#9c179e'], [0.4444444444444444,
                                            '#bd3786'], [0.5555555555555556,
                                            '#d8576b'], [0.6666666666666666,
                                            '#ed7953'], [0.7777777777777778,
                                            '#fb9f3a'], [0.8888888888888888,
                                            '#fdca26'], [1.0, '#f0f921']]},
               'legend': {'itemsizing': 'constant', 'tracegroupgap': 0},
               'margin': {'t': 60},
               'height': 500,
               'width' : 700,
               'xaxis': {'anchor': 'y', 'domain': [0.0, 1.0], 'title': {'text': 'DF/F (1 AP)'}},
               'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0], 'title': {'text': 'Half-rise time (1 AP)'}}}
})


# GUI elements
x_dropdown = widgets.Dropdown(
    options=plottableVars,
    value='DF/F (1 AP)',
    description='X axis:',
)
y_dropdown = widgets.Dropdown(
    options=plottableVars,
    value='Half-rise time (1 AP)',
    description='Y axis:',
)
color_dropdown = widgets.Dropdown(
    options=plottableVars,
    value='Half-decay time (1 AP)',
    description='color:',
)
xscale_radio = widgets.RadioButtons(
    options=['linear', 'log'],
    description='X scale:',
    disabled=False
)
yscale_radio = widgets.RadioButtons(
    options=['linear', 'log'],
    description='Y scale:',
    disabled=False
)
show_names_chkbx = widgets.Checkbox(
    value=False,
    description='Show all construct names',
    disabled=False,
    indent=False
)

# construct info table
construct_table = widgets.Output(layout={'border': '1px solid black', 'width':'30%'})
construct_info = widgets.Output(layout={'border': '1px solid black', 'width':'40%'})

def response(change):
    """Redraw the scatter plot whenever any control widget changes value."""
    with g.batch_update():
        x_val = data_filt[x_dropdown.value]
        y_val = data_filt[y_dropdown.value]
        
        # log scaling is handled by the axis type below, so the raw values are passed through
        g.data[0]['x'] = x_val
        g.data[0]['y'] = y_val
        g.data[0].marker.color = data_filt[color_dropdown.value]
        
        g.layout.xaxis.title.text = x_dropdown.value
        g.layout.yaxis.title.text = y_dropdown.value

        g.layout.coloraxis.colorbar.title.text = color_dropdown.value
        g.layout.xaxis.type = xscale_radio.value
        g.layout.yaxis.type = yscale_radio.value
        
        # update construct text
        if show_names_chkbx.value:
            # show all
            hightlight_txt_array = data_filt.index
            
        else:
            # show only highlights
            hightlight_txt_array = [mapping_swapped.get(c_id, '') for c_id in data_filt.index]
        g.data[0].text = hightlight_txt_array
        
        # update hover text
        g.data[0].hovertemplate = '<b>%{customdata}</b><br>' + x_dropdown.value + '=%{x:.3f}<br>' + y_dropdown.value + '=%{y:.3f}<br>' + color_dropdown.value + '=%{marker.color:.3f}'
    
# click behavior (https://plotly.com/python/click-events/)
def update_point(trace, points, selector):
    """Show the clicked construct's features, clone name, and AA sequence below the plot."""
    if not points.point_inds:
        return  # the click did not land on a point
    row = data_filt.iloc[points.point_inds[0]]
    construct_table.clear_output()
    construct_info.clear_output()
    construct_table.append_display_data(pd.DataFrame(row[plottableVars]))
    construct_info.append_display_data('Name: ' + row['Clone name'])
    construct_info.append_display_data('Sequence: ' + row['AA sequence'])
    
    
g.data[0].on_click(update_point)

x_dropdown.observe(response, names="value")
y_dropdown.observe(response, names="value")
color_dropdown.observe(response, names="value")
xscale_radio.observe(response, names="value")
yscale_radio.observe(response, names="value")
show_names_chkbx.observe(response, names='value')

scale_wdgets = widgets.HBox([xscale_radio, yscale_radio])
dropdown_wdgts = widgets.HBox([x_dropdown, y_dropdown, color_dropdown])
out_widgts = widgets.HBox([construct_table, construct_info])

v = widgets.VBox([dropdown_wdgts, 
                  scale_wdgets, 
                  show_names_chkbx,
                  g,
                 out_widgts])
v

Saving data_filt pickle... Done
Saving data_filt csv... Done
Total constructs: 822
Failed segmentation: 48
No detectable response: 87
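
The two failure labels counted above come from different stages of the cell. The sketch below illustrates that logic on a toy table using only the standard library. It is a simplification under stated assumptions, not the notebook's implementation: it checks only the 1 AP metrics (the notebook checks 1, 3, 10 and 160 AP), the helper names `passes_qc` and `label` are invented for illustration, and the toy constructs are hypothetical.

```python
# Thresholds copied from the notebook's data-conditioning block.
MIN_DFF = 0
MAX_RISE = 4
MIN_RISE = 0.1
MIN_DECAY = 0.01
TIMETOPEAK_MAX = 3


def passes_qc(row):
    """Return True if the 1 AP metrics of one construct pass every threshold."""
    return (
        row["DF/F (1 AP)"] > MIN_DFF
        and MIN_RISE < row["Half-rise time (1 AP)"] < MAX_RISE
        and row["Half-decay time (1 AP)"] > MIN_DECAY
        and row["Time to peak (1 AP)"] < TIMETOPEAK_MAX
    )


def label(row):
    """Return the DF/F (1 AP) value, or a failure label if there is none to report."""
    if row is None:
        # Construct listed in the failed-wells table but absent from the screen data.
        return "Failed segmentation"
    if not passes_qc(row):
        # The notebook sets these metrics to NaN, then labels the NaN entries.
        return "No detectable response"
    return row["DF/F (1 AP)"]


# Hypothetical constructs: only the DF/F and half-rise values of "toy-pass"
# come from the example table in the instructions; everything else is made up.
toy = {
    "toy-pass": {"DF/F (1 AP)": 4.96, "Half-rise time (1 AP)": 0.28,
                 "Half-decay time (1 AP)": 1.0, "Time to peak (1 AP)": 1.0},
    "toy-no-response": {"DF/F (1 AP)": -0.2, "Half-rise time (1 AP)": 1.0,
                        "Half-decay time (1 AP)": 1.0, "Time to peak (1 AP)": 1.0},
    "toy-failed-seg": None,
}

labels = {name: label(row) for name, row in toy.items()}
print(labels)
```

A construct that was never segmented has no measurements at all, so it is labeled first. A construct that was measured but violates any threshold loses its metrics and is then reported as having no detectable response.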

