# 10_interactive_vis.ipynb

In [10]:
%load_ext autoreload
%autoreload 2

import sys
# NOTE(review): absolute, machine-specific path — the notebook will not run on any
# other machine. Presumably this is the PoincareMaps source dir that provides the
# `main` / `poincare_maps` modules star-imported below; consider a repo-relative
# path or an environment-variable override.
sys.path.append('/Users/cchu/Desktop/phd_work/hyperChromatin/src/PoincareMaps')
import os
# Output directory for this notebook's results; created if it does not exist yet.
workdir = '../results/10'
os.makedirs(workdir, exist_ok=True)

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
import plotly.subplots as sp
from refs import celltype_colors
from main import *
from poincare_maps import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Per-cell Poincaré coordinates for every embedding (presumably written by
# notebook 05). Columns, per the displayed output below: x, y, labels,
# dist_from_origin, dist_from_TAC-1, name.
dist_from_origins_fn = "../results/05/dist_from_origins.csv"
dist_from_origins_df = pd.read_csv(
    dist_from_origins_fn,
    delimiter=",",  # alias of sep=','
)

In [6]:
# Rows belonging to the RNA PCA embedding.
dist_from_origins_df[dist_from_origins_df['name'].eq('RNA PCA')]

Unnamed: 0,x,y,labels,dist_from_origin,dist_from_TAC-1,name
0,-0.041638,-0.006655,TAC-1,0.084382,0.059359,RNA PCA
1,-0.107620,0.002094,TAC-1,0.216119,0.176808,RNA PCA
2,0.031848,-0.027072,TAC-1,0.083649,0.117622,RNA PCA
3,-0.092457,0.012742,TAC-1,0.187207,0.160049,RNA PCA
4,-0.052169,-0.081695,TAC-1,0.194473,0.111875,RNA PCA
...,...,...,...,...,...,...
6431,0.098845,0.023981,Hair Shaft-cuticle.cortex,0.204131,0.275847,RNA PCA
6432,0.153514,0.008973,Hair Shaft-cuticle.cortex,0.310012,0.372287,RNA PCA
6433,0.028187,0.018011,Hair Shaft-cuticle.cortex,0.066925,0.148843,RNA PCA
6434,0.160933,-0.086035,Hair Shaft-cuticle.cortex,0.369108,0.394774,RNA PCA


In [28]:
# Per-embedding cell-type labels, keyed by the 'name' column. Computed here with
# the same expression as the Dash cell below, so this cell no longer depends on
# that later cell having been run first (Restart & Run All would hit NameError).
datasets_labels = {name: grp['labels'].values for name, grp in dist_from_origins_df.groupby('name')}
datasets_labels

{'ATAC PCA (standard)': array(['TAC-1', 'TAC-1', 'TAC-1', ..., 'Hair Shaft-cuticle.cortex',
        'Hair Shaft-cuticle.cortex', 'Hair Shaft-cuticle.cortex'],
       dtype=object),
 'ATAC Simba': array(['TAC-2', 'Medulla', 'IRS', ..., 'TAC-1', 'Medulla', 'Medulla'],
       dtype=object),
 'RNA PCA': array(['TAC-1', 'TAC-1', 'TAC-1', ..., 'Hair Shaft-cuticle.cortex',
        'Hair Shaft-cuticle.cortex', 'Hair Shaft-cuticle.cortex'],
       dtype=object),
 'RNA Simba': array(['TAC-2', 'IRS', 'Hair Shaft-cuticle.cortex', ..., 'TAC-1',
        'Medulla', 'TAC-1'], dtype=object),
 'Simba Multi': array(['TAC-1', 'TAC-1', 'TAC-2', ..., 'TAC-2', 'TAC-1', 'Medulla'],
       dtype=object)}

In [32]:
import numpy as np
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
from dash import ctx

# Initial data



def complex_to_list(z_array):
    """Convert a sequence of complex numbers into a list of ``[real, imag]`` pairs.

    The nested-list form is what the app keeps in ``dcc.Store``; see
    ``list_to_complex`` for the inverse.

    Parameters
    ----------
    z_array : iterable of complex
        Points on the disk, typically a 1-D complex numpy array.

    Returns
    -------
    list of [float, float]
        One ``[x, y]`` pair per input point, in input order.
    """
    pairs = []
    for point in z_array:
        pairs.append([point.real, point.imag])
    return pairs

def list_to_complex(z_list):
    """Convert ``[x, y]`` pairs into a 1-D complex numpy array (``x + iy``).

    Inverse of ``complex_to_list``. Accepts either the nested list read back
    from ``dcc.Store`` or an ``(n, 2)`` numeric array such as
    ``df[['x', 'y']].values``.

    Parameters
    ----------
    z_list : sequence of (x, y) pairs, or array-like of shape (n, 2)

    Returns
    -------
    numpy.ndarray
        complex128 array of length n. An empty input yields an empty complex
        array, so the dtype is the same regardless of length.

    Raises
    ------
    ValueError
        If the input is not a collection of 2-element numeric pairs.
    """
    pairs = np.asarray(z_list, dtype=float)
    if pairs.ndim == 1 and pairs.size == 0:
        pairs = pairs.reshape(0, 2)  # `[]` has no second axis to validate
    if pairs.ndim != 2 or pairs.shape[1] != 2:
        raise ValueError(
            f"expected an (n, 2) collection of [x, y] pairs, got shape {pairs.shape}"
        )
    # Fill the real/imag parts directly instead of computing x + 1j * y, which
    # would turn an infinite y into a NaN real part.
    z = np.empty(pairs.shape[0], dtype=complex)
    z.real = pairs[:, 0]
    z.imag = pairs[:, 1]
    return z

def mobius_to_origin(z, z0):
    """Re-centre the Poincaré disk on ``z0`` via ``T(z) = (z - z0) / (1 - conj(z0) * z)``.

    ``T(z0) == 0``. For ``|z0| < 1`` the map sends the unit disk onto itself.
    For ``|z0| == 1`` it is degenerate: the denominator vanishes at ``z = z0``.
    Works element-wise when ``z`` is a numpy array.

    Parameters
    ----------
    z : complex or numpy.ndarray of complex
        Point(s) to transform.
    z0 : complex
        The point that is moved to the origin.

    Returns
    -------
    complex or numpy.ndarray
        Transformed point(s), same shape as ``z``.
    """
    shifted = z - z0
    scale = 1 - np.conj(z0) * z
    return shifted / scale

def plot_poincare_disk(z, labels=None, colors=None, marker_size=3):
    """Draw points on the Poincaré disk as a Plotly figure.

    Parameters
    ----------
    z : array-like of complex
        Point positions; real/imaginary parts are the x/y coordinates.
    labels : sequence of str, optional
        Per-point hover text.
    colors : sequence of colour values, optional
        Per-point marker colours; all points are drawn 'blue' when omitted.
    marker_size : int, default 3
        Marker size in pixels.

    Returns
    -------
    plotly.graph_objects.Figure
        Trace 0 is the unit-circle boundary, trace 1 the points. The figure is
        500x500 with hidden axes and equal x/y scaling.
    """
    z = np.asarray(z)
    t = np.linspace(0, 2 * np.pi, 200)
    fig = go.Figure()

    # Disk boundary. hoverinfo='skip' stops this trace from firing hover *and*
    # click events, so a click on the circle can never become a Möbius centre
    # (|z0| = 1 makes the re-centring map degenerate).
    fig.add_trace(go.Scatter(
        x=np.cos(t), y=np.sin(t),
        mode='lines',
        name='Boundary',
        hoverinfo='skip',
    ))

    # Points, with labels as hover text and optional per-point colours.
    fig.add_trace(go.Scatter(
        x=z.real, y=z.imag,
        mode='markers',
        marker=dict(
            size=marker_size,
            color=colors if colors is not None else 'blue',
        ),
        text=labels,
        hoverinfo='text',
        name='Points',
    ))

    fig.update_layout(
        xaxis=dict(scaleanchor="y", range=[-1.1, 1.1], visible=False),
        yaxis=dict(range=[-1.1, 1.1], visible=False),
        showlegend=False,
        width=500,
        height=500,
    )

    return fig

# Dash app
app = Dash(__name__)

# Lookups keyed by embedding name (the 'name' column), filled in one groupby pass:
#   datasets        -> complex coordinates after linear_scale (star-imported helper)
#   datasets_labels -> per-point cell-type labels
#   datasets_colors -> labels mapped through celltype_colors
datasets, datasets_labels, datasets_colors = {}, {}, {}
for emb_name, emb_df in dist_from_origins_df.groupby('name'):
    datasets[emb_name] = list_to_complex(np.array(linear_scale(emb_df[['x', 'y']].values)))
    datasets_labels[emb_name] = emb_df['labels'].values
    datasets_colors[emb_name] = emb_df['labels'].replace(celltype_colors)

# RNA PCA rows and a copy with linearly scaled x/y.
# NOTE(review): coor_df / zoom1_coor_df / z_init are not read by the app in the
# visible cells; kept in case other cells depend on them.
coor_df = dist_from_origins_df[dist_from_origins_df['name'] == 'RNA PCA'].copy()
scaled_xy = np.array(linear_scale(coor_df[['x', 'y']].values))
zoom1_coor_df = coor_df.assign(x=scaled_xy[:, 0], y=scaled_xy[:, 1])
z_init = list_to_complex(zoom1_coor_df[['x', 'y']].values)

dataset_names = list(datasets)

app.layout = html.Div(children=[
    html.H4("Choose Dataset"),
    # One option per embedding; the first key is selected on load.
    dcc.Dropdown(
        id='dataset-selector',
        options=[dict(label=ds, value=ds) for ds in dataset_names],
        value=dataset_names[0],
    ),
    # Clicking a point re-centres the disk on it (handled by update_poincare_plot).
    dcc.Graph(id='poincare-plot', config=dict(displayModeBar=False)),
    html.Button("Reset View", id="reset-button"),
    # Client-side stores for the current (possibly re-centred) and the original
    # [x, y] coordinate lists.
    dcc.Store(id='z-coords'),
    dcc.Store(id='z-original'),
])

@app.callback(
    Output('poincare-plot', 'figure'),
    Output('z-coords', 'data'),
    Output('z-original', 'data'),
    Input('dataset-selector', 'value'),
    Input('poincare-plot', 'clickData'),
    Input('reset-button', 'n_clicks'),
    State('z-coords', 'data'),
    State('z-original', 'data'),
)
def update_poincare_plot(dataset_name, clickData, reset_clicks, z_data, z_original_data):
    """Redraw the disk after a dataset change, a reset, or a click-to-centre.

    Parameters
    ----------
    dataset_name : str or None
        Selected key of ``datasets``; None when the dropdown is cleared.
    clickData : dict or None
        Plotly click payload; ``clickData['points'][0]`` carries the clicked x/y.
    reset_clicks : int or None
        Reset-button click count (only the fact that it changed matters).
    z_data, z_original_data : list of [x, y] or None
        Current and original coordinates held in the two ``dcc.Store``s.

    Returns
    -------
    tuple
        (figure, current coordinates, original coordinates).

    Raises
    ------
    dash.exceptions.PreventUpdate
        When no valid dataset is selected, or the click lies on/outside the unit
        circle (Möbius re-centring is undefined there). Outputs stay unchanged.
    """
    from dash.exceptions import PreventUpdate  # local import keeps this cell self-contained

    # dcc.Dropdown is clearable by default; a cleared selection arrives as None
    # and would otherwise raise KeyError in the lookups below.
    if dataset_name not in datasets:
        raise PreventUpdate

    triggered_id = ctx.triggered_id
    labels = datasets_labels[dataset_name]
    colors = datasets_colors[dataset_name]

    # Load fresh data on first render or when the dataset changes.
    if triggered_id == 'dataset-selector' or z_data is None:
        z = datasets[dataset_name]
        z_list = complex_to_list(z)
        fig = plot_poincare_disk(z, labels=labels, colors=colors)
        return fig, z_list, z_list

    if triggered_id == 'reset-button':
        # Reset view to the coordinates stored when the dataset was loaded.
        z = list_to_complex(z_original_data)
    else:
        # Click-to-centre on the clicked point.
        z = list_to_complex(z_data)
        points = (clickData or {}).get('points')
        if points:
            z0 = complex(points[0]['x'], points[0]['y'])
            # |z0| >= 1 (e.g. a click on the boundary circle) makes the map
            # degenerate: every point lands on the boundary and z0 itself divides
            # by zero. The small tolerance absorbs cos/sin rounding on the circle.
            if abs(z0) >= 1 - 1e-9:
                raise PreventUpdate
            z = mobius_to_origin(z, z0)

    # Plot transformed data.
    fig = plot_poincare_disk(z, labels=labels, colors=colors)
    return fig, complex_to_list(z), z_original_data


# Start the Dash server.
# NOTE(review): inside a Jupyter kernel __name__ is '__main__', so simply running
# this cell launches the app; debug=True turns on Dash's dev tools — confirm that
# is intended for a notebook session.
if __name__ == '__main__':
    app.run(debug=True)