In [1]:
import numpy as np
import pandas as pd
import pacmap
import json
from bokeh.plotting import figure, show
from bokeh.transform import factor_cmap
from bokeh.models import Div, Slider, TabPanel, Tabs, Legend, ColumnDataSource
from bokeh.layouts import layout
from bokeh.io import curdoc, output_notebook
from bokeh.embed import json_item
from bokeh.models import Column
from bokeh.layouts import column

In [2]:
class BokehPlotter:
    """Build a tabbed Bokeh scatter plot of a 2-D embedding, one tab per column.

    Every tab shows the same points (columns ``"PaCMAP 1"`` / ``"PaCMAP 2"`` of
    ``df``), colored by the distinct values of one column from ``cols``. A
    clickable legend is placed below each figure and a single slider controls
    the marker size of every scatter renderer across all tabs.

    Args:
        df: DataFrame holding ``"PaCMAP 1"``, ``"PaCMAP 2"``, every column named
            in ``cols`` and the column named by ``tooltip_dx_cols``.
        cols: Column names; one tab is created per entry, in this order.
        custom_color_palette: Sequence of colors. Categories are assigned colors
            in order of decreasing frequency; the palette is cycled when a
            column has more categories than colors.
        title: Figure title shared by all tabs.
        x_range, y_range: Axis ranges; ``(-50, 50)`` is used when falsy.
        datapoint_size: Initial marker size; ``5`` is used when falsy
            (note that this includes ``0``).
        tooltip_dx_cols: Name of the column displayed in the hover tooltip.
        width, height: Figure size in pixels. ``EXTRA_FIGURE_HEIGHT`` is added
            to ``height`` when the figure is created.
    """
    from bokeh.themes import Theme

    # Pixels added to the requested height in create_figure. The legend is laid
    # out below the plot area (see finalize_tabs), presumably this is the room
    # reserved for it -- TODO confirm.
    EXTRA_FIGURE_HEIGHT = 500

    # Theme with black Arial text for axis labels, legend labels and titles.
    white_theme = Theme(json={
        "attrs": {
            "Axis": {
                "major_label_text_color": 'black',
                "major_label_text_font": 'Arial',
            },
            "Legend": {
                "label_text_color": 'black',
                "label_text_font": 'Arial',
            },
            "Title": {
                "text_color": 'black',
                "text_font": 'Arial',
            },
        }
    })

    def __init__(self, df, cols, custom_color_palette, title=None,
                x_range=None, y_range=None, datapoint_size=5, 
                tooltip_dx_cols='WHO 2022 Diagnosis', width=1000, height=800):
        self.df = df
        self.cols = cols
        self.custom_color_palette = custom_color_palette
        self.title = title
        self.x_range = x_range or (-50, 50)
        self.y_range = y_range or (-50, 50)
        # The attributes below are populated by the create_* methods.
        self.tabs = None
        self.points = None
        self.slider = None
        self.layout = None
        self.datapoint_size = datapoint_size or 5
        self.tooltip_dx_cols = tooltip_dx_cols
        self.width = width
        self.height = height

    def create_figure(self):
        """Return a new, empty figure configured with this plotter's settings.

        Side effect: sets the current document's theme to ``white_theme``.
        """
        p = figure(title=self.title,
                width=self.width, height=self.height + self.EXTRA_FIGURE_HEIGHT,
                sizing_mode="inherit",
                x_axis_label='Longitude (PaCMAP 1)', y_axis_label='Latitude (PaCMAP 2)',
                x_range=self.x_range, y_range=self.y_range,
                tools="pan,wheel_zoom,reset,save", active_drag=None,
                active_scroll="auto",
                # Braces allow column names that contain spaces: "@{col name}".
                tooltips=[("Dx", f"@{{{self.tooltip_dx_cols}}}")])
        curdoc().theme = BokehPlotter.white_theme
        return p

    def create_scatters(self, p, hue):
        """Add one scatter renderer per distinct non-null value of ``df[hue]``.

        Categories are drawn in order of decreasing frequency and colored from
        ``custom_color_palette`` (cycled if needed).

        Args:
            p: Figure to draw on.
            hue: Name of the column whose values define the groups.

        Returns:
            ``(renderers, items)`` where ``items`` is a list of
            ``(label, [renderer])`` tuples suitable for ``Legend(items=...)``.
        """
        labelled = self.df[self.df[hue].notna()]  # rows without a hue value are not drawn
        counts = labelled[hue].value_counts().sort_values(ascending=False)
        # value_counts reports unused levels of a categorical column with a
        # count of 0; skip them so no empty renderer/legend entry is created.
        categories = counts[counts > 0].index.to_list()

        palette = self.custom_color_palette
        renderers = []
        items = []
        for i, category in enumerate(categories):
            source = ColumnDataSource(labelled[labelled[hue] == category])
            r = p.scatter(x="PaCMAP 1", y="PaCMAP 2", source=source,
                         fill_alpha=0.8, size=self.datapoint_size,
                         color=palette[i % len(palette)])
            renderers.append(r)
            # Legend labels must be strings; numeric/boolean categories are converted.
            items.append((str(category), [r]))

        return renderers, items

    def create_tabs_and_points(self):
        """Create one tab (with its own figure) per column and fill it with scatters.

        Sets ``self.tabs`` and ``self.points``; ``self.points[i]`` is the
        ``(renderers, items)`` pair returned by ``create_scatters`` for tab ``i``.
        """
        panels = [TabPanel(child=self.create_figure(), title=col) for col in self.cols]
        self.tabs = Tabs(tabs=panels, tabs_location='left')

        self.points = [self.create_scatters(tab.child, hue=col)
                       for tab, col in zip(self.tabs.tabs, self.cols)]

    def finalize_tabs(self):
        """Configure toolbar, legend and output backend of every tab's figure."""
        for tab, (_, items) in zip(self.tabs.tabs, self.points):
            fig = tab.child
            fig.toolbar.logo = None
            fig.toolbar_location = 'above'

            # Vertical legend below the plot, titled after the tab; clicking an
            # entry hides the corresponding renderer.
            legend = Legend(items=items, location='center',
                            orientation="vertical", click_policy='hide',
                            title=tab.title)
            fig.add_layout(legend, 'below')

            fig.output_backend = "svg"

    def create_slider(self):
        """Create the marker-size slider and link it to every scatter glyph.

        The initial value is ``self.datapoint_size``, the size every renderer
        was created with, so this also works when no renderer exists (empty
        ``cols`` or a first column without any non-null value).
        """
        self.slider = Slider(title="Adjust datapoint size", start=0, end=10, step=1,
                             value=self.datapoint_size)
        for renderers, _ in self.points:
            for r in renderers:
                self.slider.js_link("value", r.glyph, "size")

    def create_layout(self):
        """Stack a small spacer, the tabs and the slider into ``self.layout``."""
        div = Div(text="""<br>""", width=1000, height=10)
        self.layout = layout([[[div, self.tabs, self.slider]]])

    def _build_layout(self):
        """Run the full build pipeline (tabs, legends, slider, layout) from scratch."""
        self.create_tabs_and_points()
        self.finalize_tabs()
        self.create_slider()
        self.create_layout()

    def plot(self):
        """Build a fresh layout and display it with ``show``."""
        self._build_layout()
        show(self.layout)

    def create_json_item(self):
        """Return the ``json_item`` representation of the layout for web embedding.

        The layout is built first if neither ``plot`` nor ``save_plot_as_json``
        (nor the individual create_* steps) has produced one yet.
        """
        if self.layout is None:
            self._build_layout()
        return json_item(self.layout)

    def save_plot_as_json(self, filename="plot.json"):
        """Build a fresh layout and write its ``json_item`` representation to a file.

        Args:
            filename: Path of the JSON file to write (overwritten if present).
        """
        self._build_layout()

        # Generate the JSON representation of the plot
        plot_json = json_item(self.layout)

        # Write the JSON to the specified file
        with open(filename, 'w') as f:
            json.dump(plot_json, f)

        print(f"Plot JSON saved to {filename}")

def get_custom_color_palette():
    """Return the 30 hex color codes used to color plot categories.

    A new list is built on every call, so callers may modify the result freely.
    """
    named_colors = (
        ('Vivid blue', '#1f77b4'),
        ('Vivid orange', '#ff7f0e'),
        ('Vivid green', '#2ca02c'),
        ('Vivid red', '#d62728'),
        ('Vivid purple', '#9467bd'),
        ('Medium gray', '#7f7f7f'),
        ('Pink', '#e377c2'),
        ('Light orange', '#e7ba52'),
        ('Olive', '#bcbd22'),
        ('Light blue', '#17becf'),
        ('Dark blue', '#393b79'),
        ('Brown', '#8c564b'),
        ('Light pink', '#f7b6d2'),
        ('Light brown', '#c49c94'),
        ('Soft purple', '#a2769e'),
        ('Pale yellow', '#dbdb8d'),
        ('Pale cyan', '#9edae5'),
        ('Pale purple', '#c5b0d5'),
        ('Light gray', '#c7c7c7'),
        ('Light red', '#ff9896'),
        ('Dark olive', '#637939'),
        ('Light blue', '#aec7e8'),
        ('Light orange', '#ffbb78'),
        ('Light green', '#98df8a'),
        ('Dark red', '#7c231e'),
        ('Dark green', '#3d6a3d'),
        ('Deep orange', '#f96502'),
        ('Deep purple', '#6d3f7d'),
        ('Dark brown', '#6b4423'),
        ('Hot pink', '#d956a6'),
    )
    # Only the hex codes are returned; the names are documentation.
    return [hex_code for _, hex_code in named_colors]

import pandas as pd
import sys
from pathlib import Path
sys.path.append('../')

# Directory containing the input CSV files. The recorded run of this cell
# failed with FileNotFoundError because the files were not in the working
# directory; point DATA_DIR at the folder that holds them.
DATA_DIR = Path('.')

# Load clinical data of the discovery cohort
discovery_clinical_data = pd.read_csv(DATA_DIR / 'discovery_clinical_data.csv',
                                      low_memory=False, index_col=0)

# Load clinical data of the validation cohort
validation_clinical_data = pd.read_csv(DATA_DIR / 'validation_clinical_data.csv',
                                        low_memory=False, index_col=0)

# Label each cohort so the samples can be colored by train/test membership
discovery_clinical_data['Train Test'] = 'Discovery (train) Samples'
validation_clinical_data['Train Test'] = 'Validation (test) Samples'

# Constant column: gives a tab in which all samples share one color
discovery_clinical_data['PaCMAP Output'] = 'Patient Samples'
validation_clinical_data['PaCMAP Output'] = 'Patient Samples'

# Set the theme for the plot.
# NOTE: BokehPlotter.create_figure assigns BokehPlotter.white_theme to
# curdoc().theme, so this setting is replaced as soon as figures are built.
curdoc().theme = 'light_minimal' # or 'dark_minimal'

# Not used further down in this cell; kept in case later cells rely on them.
clinical_trials = ['AAML0531', 'AAML1031', 'AAML03P1', 'CCG2961', 'Japanese AML05']

sample_types = ['Diagnosis', 'Primary Blood Derived Cancer - Bone Marrow', 'Bone Marrow Normal',
                'Primary Blood Derived Cancer - Peripheral Blood', 'Blood Derived Normal']

# 2-D PaCMAP embedding, one row per sample
df2 = pd.read_csv(DATA_DIR / 'pacmap_2d_model_peds_dx_aml.csv', index_col=0)

# Concatenate discovery and validation clinical data
clinical_data = pd.concat([discovery_clinical_data, validation_clinical_data])

# Select columns to plot (one tab per column, in this order)
cols = ['PaCMAP Output','Hematopoietic Group','WHO AML 2022 Diagnosis','ELN AML 2022 Diagnosis', 'FAB', 'FLT3 ITD', 'Age (group years)',
        'Complex Karyotype', 'Primary Cytogenetic Code' , 'Sex', 'MRD 1 Status',
        'Leucocyte counts (10⁹/L)', 'Risk Group', 'Race or ethnic group',
        'Clinical Trial', 'Vital Status','First Event','Sample Type', 'Train Test']

# Join clinical data to the embedding: the 'index' column of df2 is matched
# against the index of clinical_data (assumes the embedding CSV provides an
# 'index' column holding the sample identifiers -- TODO confirm).
df2 = df2.join(clinical_data[cols], rsuffix='_copy', on='index')

plotter2 = BokehPlotter(df2, cols, get_custom_color_palette(),
                        title='',
                        x_range=(-45, 45), y_range=(-45, 45),
                        datapoint_size=3, tooltip_dx_cols='ELN AML 2022 Diagnosis',
                        width=1000, height=500)

# Export for embedding in the web app, then display inline
plotter2.save_plot_as_json('plot.json') 

plotter2.plot()

FileNotFoundError: [Errno 2] No such file or directory: 'discovery_clinical_data.csv'