In [20]:
# Modify display of clustermap

def create_clustermap(data, row_count):
    """Draw an annotated seaborn clustermap sized in proportion to `row_count`.

    Args:
        data: Matrix-like input handed straight to `sns.clustermap`.
        row_count: Number of rows; the square figure gets 0.6 inches per row.

    Returns:
        The grid object returned by `sns.clustermap`, with the heatmap
        background set to gray and the colorbar axes hidden.
    """
    # NOTE: this changes seaborn's global font scale for later plots as well.
    sns.set(font_scale=.7)

    side = .6 * row_count
    grid = sns.clustermap(
        data,
        annot=True,
        fmt=".2f",
        cmap="viridis",
        cbar=False,
        dendrogram_ratio=(.05, .05),
        figsize=(side, side),
    )

    # Gray background; presumably shows through wherever a cell is not drawn.
    grid.ax_heatmap.set_facecolor("gray")
    grid.cax.set_visible(False)
    return grid


In [39]:
import polars as pl
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import io
import base64
import pandas
import re
from collections import defaultdict
from natsort import natsorted
import dash
from dash import dcc, html
from dash.dependencies import Input, Output

# Read the tab-separated HiCRep results table from a fixed local path.
data = pl.read_csv(
    "~/Documents/hich/results/hicrep/hicrep.tsv",
    separator="\t",
)
# Distinct combinations of the parameter columns present in the table.
states = data.select(
    "resolution",
    "h",
    "dBPMax",
    "bDownSample",
    "chrom",
).unique()

def prefix(strings):
    """Return the longest prefix shared by every string in `strings`.

    Args:
        strings: Iterable of strings (may be empty).

    Returns:
        The common leading substring. An empty string is returned when
        `strings` is empty or the strings share no leading characters.
    """
    strings = list(strings)
    if not strings:
        return ""

    shared = []
    # zip stops at the shortest string, so no explicit minimum length is needed.
    for chars in zip(*strings):
        if len(set(chars)) != 1:
            break
        shared.append(chars[0])
    return "".join(shared)

class Model:
    """Holds the HiCRep score table in the long format used for clustering."""

    # Columns that describe one pairwise comparison; every other column is
    # treated as a selectable option by the controller.
    cluster_cols = ["file1", "file2", "score"]

    def __init__(self):
        # Populated by read_csv(); None until a table has been loaded.
        self.data = None

    def read_csv(self, filename, separator = "\t"):
        """Load a score table and store it, stacked with a distance-based copy.

        The input's "scc" column is renamed to "score" and tagged with a
        "cluster by" value of "scc". A second copy of the rows replaces the
        score with sqrt(0.5 * (1 - scc)) and carries a descriptive
        "cluster by" label. Both copies are concatenated into `self.data`.

        Args:
            filename: Source passed on to `pl.read_csv`.
            separator: Field separator, tab by default.
        """
        raw = pl.read_csv(filename, separator = separator)
        n_rows = len(raw)

        def label_column(text):
            # Constant column naming the metric used for clustering.
            return pl.Series([text] * n_rows).alias("cluster by")

        scc_frame = raw.with_columns(label_column("scc")).rename({"scc": "score"})

        distance = ((.5 * (1 - scc_frame["score"])) ** .5).alias("score")
        distance_frame = scc_frame.with_columns(
            distance,
            label_column("distance: 0 ≤ √(.5(1-scc)) ≤ 1"),
        )

        self.data = pl.concat([scc_frame, distance_frame])
        

class Controller:
    """Connects a `Model` to the Dash UI.

    Tracks the currently selected combination of option columns, the rows
    of the model that match it, and builds the page layout (upload box,
    sliders for options with several values, and the clustermap image).
    """

    def __init__(self, model):
        """Store `model` and run `setup()` if it already holds data."""
        self.model = model
        self.options = {}
        self.static_options = {}
        self.multi_options = {}
        # Polars DataFrames raise TypeError when truth-tested, so the
        # "has data" check must compare against None explicitly.
        if self.model.data is not None:
            self.setup()

    def load_file(self, filename):
        """Load `filename` into the model (does not rebuild the layout)."""
        self.model.read_csv(filename)

    def setup(self):
        """Derive option values from the model's data and rebuild the layout.

        When the model has no data, only the layout (upload header) is built.
        """
        if self.model.data is not None:
            # Every column that is not part of a pairwise score is an option.
            self.option_cols = [col
                                for col in self.model.data.columns
                                if col not in Model.cluster_cols]
            self.states = self.model.data.select(self.option_cols).unique()
            self.state = self.states.row(0, named = True)
            self.update_state(self.state)

            self.options = {}
            for col in self.option_cols:
                self.options[col] = self.states[col].unique()

            self.sort_options()

            # Options with a single value are shown as a static label;
            # options with several values get a slider.
            self.static_options = {}
            self.multi_options = {}
            for option, values in self.options.items():
                if len(values) == 1:
                    self.static_options[option] = values[0]
                elif len(values) > 1:
                    self.multi_options[option] = values
            self.static_label = " ".join([
                f"{option}: {value}" for option, value in self.static_options.items()
            ])

        self.build_layout()

    def sort_options(self):
        """Replace each option's values with a naturally sorted list."""
        for option, values in self.options.items():
            self.options[option] = natsorted(values)

    def update_state(self, update):
        """Apply `update` (option -> value) to the current state.

        If the merged state is not one of the known states compatible with
        `update`, the first compatible state is used instead. The cluster
        data is refreshed afterwards.

        Raises:
            AssertionError: If no known state matches `update`.
        """
        new_state = self.state.copy()
        new_state.update(update)
        compatible_states = self.states.join(pl.DataFrame(update), on = list(update))
        assert not compatible_states.is_empty(), f"No compatible states for {update}"

        if not any(new_state == row for row in compatible_states.to_dicts()):
            self.state = compatible_states.row(0, named = True)
        else:
            self.state = new_state

        self.set_cluster_data()

    def set_cluster_data(self):
        """Select the model rows matching the current state (score columns only)."""
        state_df = pl.DataFrame(self.state)
        self.cluster_data = self.model.data.join(state_df,
                                                 on = list(self.state)) \
                                           .select(Model.cluster_cols)

    def clustermap_src(self):
        """Render the current cluster data and return it as a PNG data URI."""
        cluster_data = self.cluster_data.to_pandas().pivot(
            values = "score",
            index = "file1",
            columns = "file2"
        )
        # Fill missing cells from the transpose so the matrix is symmetric.
        cluster_data = cluster_data.combine_first(cluster_data.T)
        clustermap = create_clustermap(cluster_data, len(cluster_data))

        # Save and close this specific figure rather than whatever pyplot
        # currently considers "current"; always close it, even on failure.
        figure = clustermap.ax_heatmap.figure
        buffer = io.BytesIO()
        try:
            figure.savefig(buffer, format="png", bbox_inches='tight')
        finally:
            plt.close(figure)

        # Convert the image to base64
        image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{image_base64}"

    def build_layout(self):
        """Assemble the page and assign it to the Dash app's `layout`.

        Uses `self.app` when it has been attached to the controller and
        falls back to the module-level `app` otherwise.
        """
        header = html.Div([
            html.H1("Hich HicRep"),
            dcc.Upload(
                id='upload-data',
                children=html.Div([
                    'Drag and Drop or ',
                    html.A('Select a File')
                ]),
                style={
                    'width': '100%',
                    'height': '60px',
                    'lineHeight': '60px',
                    'borderWidth': '1px',
                    'borderStyle': 'dashed',
                    'borderRadius': '5px',
                    'textAlign': 'center',
                    'margin': '10px'
                },
                multiple=False  # Only allow one file at a time
            )], style={"textAlign":"center"})

        gui_elements = [header]

        if self.model.data is not None:
            static_label = html.Div([
                html.Label(self.static_label)
            ], style={"textAlign":"center"})

            clustermap_image = html.Div(
                [html.Img(id='clustermap-image', src=self.clustermap_src())],
                style = {"textAlign":"center"}
            )

            slider_divs = []
            for option, values in self.multi_options.items():
                current_value = self.state[option]
                current_value_index = values.index(current_value)
                # Rough width estimate: characters in all mark labels plus
                # padding per mark, at a fixed number of pixels per character.
                space_size = 4
                char_count = sum([len(str(val)) for val in values]) + len(values)*space_size
                char_px = 5
                width = char_count*char_px
                slider_div = html.Div([
                    html.Label(option),
                    dcc.Slider(
                        id = f"{option}",
                        min = 0,
                        max = len(values)-1,
                        marks = {i: str(value) for i, value in enumerate(values)},
                        value = current_value_index,
                        step = 1
                    )],
                    style = {
                        "width": f"{width}px",
                        "margin": "10px auto",
                    }
                )
                slider_divs.append(slider_div)

            gui_elements = gui_elements + [static_label] + slider_divs + [clustermap_image]

        # Define the layout with a slider selector and an image display
        target_app = getattr(self, "app", None)
        if target_app is None:
            target_app = app
        target_app.layout = html.Div(gui_elements)

# Create the application objects. The model starts without data, so the
# controller does not run its setup at this point.
model = Model()
controller = Controller(model=model)
#controller.load_file("~/Documents/hich/results/hicrep/hicrep.tsv")



In [44]:
import dash
from dash import html, dcc, callback_context
from dash.dependencies import Input, Output
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import io
import base64

# Create a Dash app
# Create the Dash application. Callback exceptions are suppressed because
# some callback targets only exist in the layout after data is loaded.
app = dash.Dash(
    __name__,
    suppress_callback_exceptions=True,
)
# Attach the app to the controller so it can be reached from there.
controller.app = app

@app.callback(
    [Input('upload-data', 'contents')],
    [dash.dependencies.State('upload-data', 'filename'),
     dash.dependencies.State('upload-data', 'last_modified')]
)
def update_output(contents, filename=None, date=None):
    """Load an uploaded file into the controller.

    Dash calls this with one value per registered Input/State, in order:
    the upload's `contents`, its `filename`, and its `last_modified` stamp.

    Args:
        contents: Uploaded file as a data URL ("<header>,<base64 payload>"),
            or None when nothing has been uploaded.
        filename: Name of the uploaded file as reported by the browser
            (not a path readable on the server; unused).
        date: Last-modified timestamp of the upload (unused).

    NOTE(review): a later callback in this file reuses the name
    `update_output`; both stay registered with Dash, but the module-level
    name ends up bound to the later one.
    """
    if not contents:
        return

    # Keep only the base64 payload that follows the data-URL header.
    payload = contents.split(",", 1)[-1]
    decoded = base64.b64decode(payload)
    # Hand the bytes over as a file-like object; the model forwards it to
    # pl.read_csv, which is expected to accept binary file objects.
    controller.load_file(io.BytesIO(decoded))

# NOTE(review): the Input list is built once, when this decorator runs, from
# whatever `controller.multi_options` holds at that moment.
@app.callback(
    Output('clustermap-image', 'src'),
    [Input(name, 'value') for name in controller.multi_options]
)
def update_output(*slider_values):
    """Re-render the clustermap after a slider change.

    The first entry of the callback context's `triggered` list identifies
    the slider (by component id) and its new index; that index is mapped to
    the option's value and applied to the controller state.

    Returns:
        The clustermap image as a PNG data URI.
    """
    events = callback_context.triggered

    if events:
        first = events[0]
        # prop_id has the form "<component id>.<property>".
        option = first['prop_id'].split('.')[0]
        chosen = controller.options[option][first['value']]
        controller.update_state({option: chosen})

    return controller.clustermap_src()

if __name__ == "__main__":
    # Build the initial layout before the server starts.
    controller.setup()
    # `Dash.run_server` was superseded by `Dash.run` and is no longer
    # available in recent Dash releases; prefer `run` and fall back to
    # `run_server` only on older versions that lack it.
    run = getattr(app, "run", None)
    if run is None:
        run = app.run_server
    run(debug=True)