In [None]:
from dash import Dash, html, dcc, callback, Output, Input
import dash_bootstrap_components as dbc
import plotly.express as px
import plotly.graph_objects as go
from copy import deepcopy 
import plotly.graph_objects as go 

import pandas as pd
import numpy as np
import math

In [None]:
# Load the DESeq2 differential-expression results, the normalized count matrix
# and the GO-term annotation table. Every path uses forward slashes: pandas
# accepts them on all platforms, whereas a backslash inside a normal string
# literal is an escape character ('\R' is an invalid escape sequence) and is
# not a path separator outside Windows.
df = pd.read_csv("Data/Ruhland2016.DESeq2.results.csv", index_col=0)
df_count = pd.read_csv("Data/Ruhland2016.norm_counts.csv", index_col=0)
df_Go = pd.read_csv("Data/Ruhland2016.Mm.GOterm.csv", index_col=0)

# Transpose so samples are rows and genes are columns; a gene id can then be
# used directly as the y column of a strip plot.
df_count_t = df_count.transpose()
# Condition label = sample name without its last character.
# NOTE(review): assumes a single trailing replicate character per sample name — confirm naming scheme.
df_count_t['condition'] = [sample[:-1] for sample in df_count_t.index]

display(df.head(), df_count_t.head(), df_Go.head())

In [None]:
import dash
from dash import dcc, html, Input, Output, callback
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import plotly.express as px
from dash.dash_table import DataTable
import dash_bootstrap_components as dbc
from copy import deepcopy


def make_volcano_plot(df, p_value_magnitude, log_fold_change):
    """Build a volcano plot (log2 fold change vs. adjusted p-value).

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide ``log2FoldChange``, ``padj`` and ``gene_name`` columns;
        the index is attached to every point as ``customdata``.
    p_value_magnitude : float
        Base-10 exponent of the adjusted p-value cut-off (``-2`` -> ``0.01``).
    log_fold_change : float
        Absolute log2 fold-change cut-off.

    Returns
    -------
    plotly.graph_objects.Figure
        One scatter trace of all genes on a reversed log-scale y axis. Points
        passing both cut-offs are red (positive fold change) or blue
        (negative); all others are light grey. Dashed lines mark the cut-offs.
    """
    padj_cutoff = 10 ** p_value_magnitude
    lfc = df.log2FoldChange

    # Significant = clears BOTH thresholds. NaN values compare False and
    # therefore stay grey.
    passes = (df.padj < padj_cutoff) & (lfc.abs() > log_fold_change)
    point_colors = np.where(passes, np.where(lfc > 0, 'red', 'blue'), 'lightgrey')

    scatter = go.Scatter(
        x=lfc,
        y=df.padj,
        mode='markers',
        marker_color=point_colors,
        hovertemplate="gene name\t: %{text}<br>logFC\t\t: %{x:.1f}",
        text=df.gene_name,
        customdata=df.index,  # read back by the click callback
    )
    fig = go.Figure(data=[scatter])

    # Reversed log axis puts the smallest p-values at the top.
    fig.update_yaxes(type='log', autorange="reversed", exponentformat='power', title_text='adjusted p-value')
    fig.update_xaxes(title_text='log2 Fold Change')

    # Threshold guides: one horizontal p-value line, two symmetric LFC lines.
    fig.add_hline(y=padj_cutoff, line_dash="dash")
    for edge in (-log_fold_change, log_fold_change):
        fig.add_vline(x=edge, line_dash="dash")

    fig.update_layout(clickmode='event+select')
    return fig

# Initialize the Dash app
external_stylesheets = [dbc.themes.CERULEAN]
app = dash.Dash("Volcano Plot App", external_stylesheets=external_stylesheets)

# Upper bound of the |log2FC| threshold slider. The threshold is applied to the
# ABSOLUTE fold change, so down-regulated genes must be able to reach it too.
# Series.max() skips NaN, whereas builtin max() returns NaN when the first
# value is NaN (DESeq2 reports NaN fold changes for some genes).
LOGFC_SLIDER_MAX = float(df.log2FoldChange.abs().max())

# Layout, top to bottom:
#   1. gene-set dropdown and gene search box, side by side
#   2. volcano plot (left) and strip plot of the clicked gene (right)
#   3. p-value exponent and log fold-change threshold sliders
#   4. table of genes passing both thresholds
app.layout = dbc.Container([
    dbc.Row([
        dbc.Col([
            html.H5('Select Gene Set:'),
            dcc.Dropdown(
                id='gene_set-dropdown',
                options=[
                    {'label': 'None', 'value': 'None'},
                    {'label': 'Inflammatory Response', 'value': 'GO:0006954'},
                    {'label': 'Chemotaxis', 'value': 'GO:0006935'},
                    {'label': 'JAK-STAT Cascade', 'value': 'GO:0007259'}
                ],
                value='None',  # default value
                clearable=False)
        ], width=6),
        dbc.Col([
            html.H5('Search for a Gene'),
            dcc.Input(
                id='gene-search-bar',
                type='text',
                placeholder='Search for a gene...',
                debounce=True,  # Only update on typing stop
                style={'width': '100%'}
            ),
        ], width=6)
    ]),
    dbc.Row([
        dbc.Col(dcc.Graph(id='volcano-plot'), width=6),  # Volcano plot takes half width
        dbc.Col(dcc.Graph(id='graph-content'), width=6)  # Second plot showing selected gene data
    ]),
    dbc.Row([
        dbc.Col([
            html.H5('Adjust p-value and log Fold Change'),
            html.Label('P-value Magnitude:'),
            dcc.Slider(
                id='pvalue-slider',
                min=-40,
                max=0,
                step=0.5,
                value=-2,  # default value
                # range stop is exclusive: 1 makes the 0 endpoint get a mark too
                marks={i: f'{i}' for i in range(-40, 1, 5)},
                tooltip={"placement": "bottom", "always_visible": True}
            ),
            html.Label('log Fold Change threshold:'),
            dcc.Slider(
                id='logfc-slider',
                min=0,
                max=LOGFC_SLIDER_MAX,
                step=0.1,
                value=min(1, LOGFC_SLIDER_MAX),  # default 1, kept inside the slider range
                # integer marks from 0 up to the largest integer not above the max
                marks={i: f'{i}' for i in range(0, math.floor(LOGFC_SLIDER_MAX) + 1)},
                tooltip={"placement": "bottom", "always_visible": True}
            )
        ], width={'size': 10, 'offset': 1})
    ]),
    dbc.Row([  # Table to show significant genes
        dbc.Col([
            html.H5('Significant Genes'),
            DataTable(
                id='significant-genes-table',
                columns=[
                    {'name': 'Gene Name', 'id': 'gene_name'},
                    {'name': 'log2 Fold Change', 'id': 'log2FoldChange'},
                    {'name': 'Adjusted p-value', 'id': 'padj'}
                ],
                style_table={'height': '300px', 'overflowY': 'auto'},
                style_cell={'textAlign': 'center', 'padding': '5px'},
                style_header={'backgroundColor': 'lightgrey', 'fontWeight': 'bold'}
            )
        ], width={'size': 10, 'offset': 1})
    ])
], fluid=True)

def _genes_in_go_term(go_id):
    """Return the set of ENSEMBL ids annotated with ``go_id`` in ``df_Go``.

    A GO id absent from the annotation table yields an empty set instead of
    raising KeyError.
    """
    return set(df_Go.loc[df_Go['GO'] == go_id, 'ENSEMBL'])


def _locked_axis_ranges(frame):
    """Return ``(x_range, y_range)`` framing every gene of ``frame``.

    The volcano y axis is log-type and reversed. Plotly expects the range of a
    log axis in log10 units, and the larger bound must come first so small
    p-values stay at the top. Zero/NaN p-values cannot be placed on a log axis
    and are ignored. Either range is ``None`` when no usable data exists.
    """
    lfc = frame['log2FoldChange']
    lfc = lfc[np.isfinite(lfc)]
    padj = frame['padj']
    padj = padj[padj > 0]  # drops zeros (log10 -> -inf) and NaN
    x_range = [float(lfc.min()), float(lfc.max())] if len(lfc) else None
    y_range = [float(np.log10(padj.max())), float(np.log10(padj.min()))] if len(padj) else None
    return x_range, y_range


# Callback that redraws the volcano plot and the significant-genes table
# whenever a threshold slider, the gene-set dropdown or the gene search changes.
@callback(
    [Output('volcano-plot', 'figure'),
     Output('significant-genes-table', 'data')],
    [Input('pvalue-slider', 'value'),
     Input('logfc-slider', 'value'),
     Input('gene_set-dropdown', 'value'),
     Input('gene-search-bar', 'value')]
)
def update_volcano_plot(pvalue, logfc, gene_set, gene_search):
    """Rebuild the volcano plot and the table of significant genes.

    Parameters
    ----------
    pvalue : float
        Base-10 exponent of the adjusted p-value cut-off (slider value).
    logfc : float
        Absolute log2 fold-change cut-off.
    gene_set : str
        GO term id restricting the genes shown, or ``'None'`` for all genes.
    gene_search : str or None
        Case-insensitive literal substring. When given, only genes whose name
        contains it are plotted and listed, and the first match gets a star.

    Returns
    -------
    tuple
        ``(figure, table_records)`` for the two outputs.
    """
    # 1. Restrict to the selected GO term, if any.
    if gene_set in (None, 'None'):
        gene_set_df = df
    else:
        gene_set_df = df[df.index.isin(_genes_in_go_term(gene_set))]

    # 2. Apply the gene-name search. regex=False: typed text is matched
    #    literally, so input such as '(' or '[' cannot raise re.error.
    if gene_search:
        name_match = gene_set_df['gene_name'].str.contains(
            gene_search, case=False, na=False, regex=False)
        sub_df = gene_set_df[name_match]
    else:
        sub_df = gene_set_df

    fig = make_volcano_plot(sub_df, p_value_magnitude=pvalue, log_fold_change=logfc)

    # 3. Mark the first search hit with a star. It carries text/customdata like
    #    the main trace so that clicking it still drives the strip-plot callback.
    if gene_search and not sub_df.empty:
        hit = sub_df.iloc[0]
        fig.add_trace(go.Scatter(
            x=[hit['log2FoldChange']],
            y=[hit['padj']],
            mode='markers',
            marker=dict(size=10, color='yellow', symbol='star'),
            name=f"Selected Gene: {gene_search}",
            text=[hit['gene_name']],
            customdata=[sub_df.index[0]],
            hovertemplate="Gene Name: %{text}<br>logFC: %{x:.2f}<br>p-value: %{y:.2e}<extra></extra>",
        ))

    # 4. Lock the axes to the gene-set view (computed BEFORE the search filter)
    #    so a search shows its hits in context instead of zooming onto them.
    x_range, y_range = _locked_axis_ranges(gene_set_df)
    if x_range is not None:
        fig.update_xaxes(range=x_range)
    if y_range is not None:
        fig.update_yaxes(range=y_range, autorange=False)

    # 5. Table: plotted genes passing both thresholds, using the same rule as
    #    make_volcano_plot. The mask is built on sub_df itself, so no boolean
    #    index re-alignment against the full frame is needed.
    significant = (sub_df['padj'] < 10 ** pvalue) & (sub_df['log2FoldChange'].abs() > logfc)
    table_data = sub_df.loc[significant, ['gene_name', 'log2FoldChange', 'padj']].to_dict('records')

    return fig, table_data



# Gene whose counts are shown before anything has been clicked on the volcano plot.
DEFAULT_STRIP_GENE = 'ENSMUSG00000100480'


def _strip_plot(gene_id, title=None, y_label=None):
    """Strip plot of the normalized counts of ``gene_id`` per condition.

    Returns an empty figure with an explanatory title when ``gene_id`` has no
    column in ``df_count_t`` (px.strip would raise otherwise).
    """
    if gene_id not in df_count_t.columns:
        fig = go.Figure()
        fig.update_layout(title=f"No normalized counts for {gene_id}")
        return fig
    fig = px.strip(df_count_t, x='condition', y=gene_id, title=title)
    if y_label is not None:
        fig.update_yaxes(title_text=y_label)
    return fig


# Callback to update the strip plot based on volcano plot click
@callback(
    Output('graph-content', 'figure'),
    Input('volcano-plot', 'clickData')
)
def update_graph(clickData):
    """Show the per-condition counts of the gene clicked on the volcano plot.

    Parameters
    ----------
    clickData : dict or None
        Plotly click payload; ``None`` before the first click, in which case
        ``DEFAULT_STRIP_GENE`` is shown.

    Returns
    -------
    plotly.graph_objects.Figure or dash.no_update
        ``dash.no_update`` when the clicked point carries no gene id, so the
        current strip plot is kept instead of the callback raising KeyError.
    """
    if clickData is None:
        return _strip_plot(DEFAULT_STRIP_GENE)

    point = clickData['points'][0]
    gid = point.get('customdata')
    if gid is None:
        return dash.no_update

    title = "log2-FC {:.1f} - adjusted p-value {:.1e}".format(point['x'], point['y'])
    # Fall back to the gene id when the point has no gene-name text.
    return _strip_plot(gid, title=title, y_label=point.get('text', gid))

# Run the app. `app.run` is used because `app.run_server` is only a deprecated
# alias of it in Dash 2.x and is no longer available in Dash 3.
if __name__ == '__main__':
    app.run(debug=True, host='127.0.0.1', port=8053)
