In [None]:
import os
import pandas as pd
import sys
from plotly.colors import DEFAULT_PLOTLY_COLORS

# project_path = "."  # alternative: treat the current working directory as the project root
# The project root is taken to be the parent of the notebook's working directory; it is put on
# sys.path so that the `import BraiAn` below can find the project's package.
project_path = os.path.dirname(os.path.abspath(os.getcwd()))
# Guard so that re-running this cell does not append a duplicate entry to sys.path every time.
if project_path not in sys.path:
    sys.path.append(project_path)
import BraiAn

In [None]:
# Experiment folder handed to BraiAn.remote_dirs, which yields the data and plots root folders.
EXPERIMENT_DIRECTORY = "Cariplo_NRe/IEGs Experiment"
# NOTE(review): USE_REMOTE_DATA is not read anywhere in this notebook; remote_dirs receives a
# literal False as second argument. Confirm whether that argument was meant to be USE_REMOTE_DATA.
USE_REMOTE_DATA = True
DATA_ROOT, PLOTS_ROOT = BraiAn.remote_dirs(EXPERIMENT_DIRECTORY, False, "")

In [None]:
# Folder the per-group cell-count CSV files are read from (see the group-loading cell).
data_input_path = os.path.join(
    DATA_ROOT,
    "BraiAn_output",
)

In [None]:
# Allen Mouse Brain ontology (StructureGraph id=1), as documented at
# https://help.brain-map.org/display/api/Downloading+an+Ontology%27s+Structure+Graph
# BraiAn.cache receives the local path and the download URL (presumably it downloads only if the
# file is not already present - TODO confirm).
path_to_allen_json = os.path.join(project_path, "data", "AllenMouseBrainOntology.json")
BraiAn.cache(path_to_allen_json, "http://api.brain-map.org/api/v2/structure_graph_download/1.json")
# NOTE(review): the list presumably names regions excluded from the hierarchy - confirm in
# BraiAn.AllenBrainHierarchy.
brain_onthology = BraiAn.AllenBrainHierarchy(
    path_to_allen_json,
    ["fiber tracts", "VS", "grv", "retina", "CB"],
    version="ccfv3",
)

In [None]:
REGIONS_TO_PLOT_SELECTION_METHOD = "structural level 8"
# REGIONS_TO_PLOT_SELECTION_METHOD = "major divisions"
match REGIONS_TO_PLOT_SELECTION_METHOD:
    case "summary structures":
        # selects the Summary Strucutures
        path_to_summary_structures = os.path.join(project_path, "data", "AllenSummaryStructures.csv")
        brain_onthology.select_from_csv(path_to_summary_structures)
    case "major divisions":
        brain_onthology.select_regions(BraiAn.MAJOR_DIVISIONS)
    case "smallest":
        brain_onthology.select_leaves()
    case s if s.startswith("depth"):
        n = REGIONS_TO_PLOT_SELECTION_METHOD.split(" ")[-1]
        try:
            depth = int(n)
        except Exception:
            raise Exception("Could not retrieve the <n> parameter of the 'depth' method for 'REGIONS_TO_PLOT_SELECTION_METHOD'")
        brain_onthology.select_at_depth(depth)
    case s if s.startswith("structural level"):
        n = REGIONS_TO_PLOT_SELECTION_METHOD.split(" ")[-1]
        try:
            level = int(n)
        except Exception:
            raise Exception("Could not retrieve the <n> parameter of the 'structural level' method for 'REGIONS_TO_PLOT_SELECTION_METHOD'")
        brain_onthology.select_at_structural_level(level)
    case _:
        raise Exception(f"Invalid value '{REGIONS_TO_PLOT_SELECTION_METHOD}' for REGIONS_TO_PLOT_SELECTION_METHOD")
selected_regions = brain_onthology.get_selected_regions()
print(f"You selected {len(selected_regions)} regions to plot.")

In [None]:
# Load the two groups to compare. The arguments are presumably (group name, folder, CSV file name)
# - TODO confirm against BraiAn.AnimalGroup.from_csv.
group_1: BraiAn.AnimalGroup = BraiAn.AnimalGroup.from_csv("CTX", data_input_path, "cell_counts_CTX_density.csv")
group_2: BraiAn.AnimalGroup = BraiAn.AnimalGroup.from_csv("FC", data_input_path, "cell_counts_FC_density.csv")
if not group_1.is_comparable(group_2):
    # ValueError rather than ImportError: the files were read, it is their content that is incompatible.
    raise ValueError(
        "Group 1 and Group 2 are not comparable!\n"
        "Please check that you're reading two groups that were normalized on the same brain regions and on the same marker"
    )

In [None]:
# Resampling counts passed to pls_regions: larger values are slower but give more stable results.
PLS_N_PERMUTATIONS = 5000
PLS_N_BOOTSTRAP = 5000
# NOTE(review): no random seed is set before resampling. If BraiAn draws from numpy's global RNG,
# results will differ slightly between runs - TODO confirm and seed for reproducibility.
pls = group_1.pls_regions(group_2, selected_regions, n_permutations=PLS_N_PERMUTATIONS, n_bootstrap=PLS_N_BOOTSTRAP)

In [None]:
# Preview group_1's mean "cFos" data after ordering its regions by the ontology (without modifying group_1).
# NOTE(review): "cFos" is hard-coded; this presumably fails if the groups use other marker names.
(
    group_1
    .sort_by_onthology(brain_onthology, fill=True, inplace=False)
    .mean["cFos"]
    .data
)

In [None]:
# Sanity check: True when the first marker's PLS results are already in ontology order.
# The marker is looked up from group_1 because `marker1` is only defined in the plotting cell below,
# so referencing it here would raise NameError on a fresh kernel.
_first_marker = group_1.markers[0]
pls[_first_marker].sort_by_onthology(brain_onthology, fill=False, inplace=False).data.equals(pls[_first_marker].data)

In [None]:
# Sanity check: True when sorting group_1 by the ontology (fill=True) leaves its mean "cFos" data unchanged.
(
    group_1
    .sort_by_onthology(brain_onthology, fill=True, inplace=False)
    .mean["cFos"]
    .data
    .equals(group_1.mean["cFos"].data)
)

In [None]:
import matplotlib.colors as mplc
import numpy as np
import plotly.graph_objects as go
import plotly

# Subplot layout, left to right: [marker1 bars | marker1 heatmap | spacer | marker2 heatmap | marker2 bars].
# Each marker gets half of the non-spacer width, split between bars and heatmap by bar_to_heatmap_ratio.
bar_to_heatmap_ratio = np.array([0.7,0.3])
space_between_markers = 0.01
marker_ratio = bar_to_heatmap_ratio*((1-space_between_markers)/2)
column_widths = [*marker_ratio, space_between_markers, *marker_ratio[::-1]] # 5 subplots, with the middle one being a spacer
fig = plotly.subplots.make_subplots(rows=1, cols=5, horizontal_spacing=0, column_widths=column_widths, shared_yaxes=True)

# The figure compares the first two markers of the groups.
assert len(group_1.markers) >= 2, "This figure needs at least two markers, but group_1 has fewer than two."
marker1: str = group_1.markers[0]
marker2: str = group_1.markers[1]
if brain_onthology is not None:
    # Order the rows following the ontology.
    # NOTE: group_1/group_2 are rebound to the sorted copies, so later cells see the sorted groups.
    group_1 = group_1.sort_by_onthology(brain_onthology, fill=True, inplace=False)
    group_2 = group_2.sort_by_onthology(brain_onthology, fill=True, inplace=False)
    pls_marker1 = pls[marker1].sort_by_onthology(brain_onthology, fill=False, inplace=False).data
    pls_marker2 = pls[marker2].sort_by_onthology(brain_onthology, fill=False, inplace=False).data
else:
    pls_marker1 = pls[marker1].data
    pls_marker2 = pls[marker2].data
metric = str(group_1.metric)
# One region x animal frame per (group, marker), restricted to the selected regions.
# To plot only some of group_2's animals, index the columns, e.g. [['367FC', '368FC', '369FC', '426FC']].
group1_marker1 = group_1.to_pandas(marker=marker1).loc[selected_regions]
group2_marker1 = group_2.to_pandas(marker=marker1).loc[selected_regions]
group1_marker2 = group_1.to_pandas(marker=marker2).loc[selected_regions]
group2_marker2 = group_2.to_pandas(marker=marker2).loc[selected_regions]
# All traces share one y axis and the bars are coloured from the PLS scores, so everything plotted
# together must list the same regions in the same order. The check uses the *plotted* (possibly
# re-sorted) PLS series. Index.equals returns False instead of raising when lengths differ.
assert all(group1_marker1.index.equals(other.index)
           for other in (pls_marker1, pls_marker2, group2_marker1, group1_marker2, group2_marker2)), \
    "BrainGroups' data and PLS results must be on the same brain regions!"

def to_rgba(color: str, alpha) -> str:
    """Convert any colour matplotlib understands (e.g. a CSS name) to a plotly "rgba(r, g, b, alpha)" string."""
    red, green, blue = plotly.colors.convert_to_RGB_255(mplc.to_rgb(color))
    return "rgba({}, {}, {}, {})".format(red, green, blue, alpha)

# Colour of each (group, marker) pair: g1/g2 = group_1/group_2, m1/m2 = marker1/marker2.
# Alternative palettes tried before (same order g1m1, g2m1, g1m2, g2m2):
#   MidnightBlue, IndianRed, LightSkyBlue, LightSalmon
#   ForestGreen, CornFlowerBlue, DarkGreen, RoyalBlue
color_g1m1 = "LightCoral"
color_g2m1 = "SandyBrown"
color_g1m2 = "IndianRed"
color_g2m2 = "Orange"

def bar_ht(marker):
    """Plotly hover template for a bar: bold trace name (meta), mean value and region.

    Reads the module-level `metric` at call time.
    """
    return f"<b>%{{meta}}</b><br>{marker} {metric}: %{{x}}<br>region: %{{y}}<br><extra></extra>"
def heatmap_ht(marker):
    """Plotly hover template for a heatmap cell: animal, region and value with two decimals.

    Reads the module-level `metric` at call time.
    """
    return f"animal: %{{x}}<br>region: %{{y}}<br>{marker} {metric}: %{{z:.2f}}<extra></extra>"

def bar(group_df: pd.DataFrame, salience_scores: pd.Series, threshold: float, group_name: str, marker: str, color: str):
    """Build the horizontal bar trace of one group/marker and a legend-only companion trace.

    Each bar is the mean of `group_df` across its columns (animals) for one region, with the SEM
    as error bar. Bars are fully opaque where |salience| >= `threshold`, faded (alpha 0.2) below
    it, and almost transparent, outline included (alpha 0.1), where the salience score is missing.

    Parameters
    ----------
    group_df : region x animal values of one group for one marker.
    salience_scores : PLS salience per region; matched to `group_df` rows by region label.
    threshold : minimum |salience| for a bar to be drawn fully opaque.
    group_name, marker : used for the trace name, legend entry and hover text.
    color : any colour accepted by `to_rgba` (e.g. a CSS colour name).

    Returns
    -------
    (go.Bar, go.Scatter) : the bar trace (hidden from the legend) and an empty scatter that only
    draws a square legend swatch. Both share a legendgroup, so the legend entry toggles the bars.
    """
    alpha_below_thr = 0.2
    alpha_undefined = 0.1
    # Plotly applies per-bar colours positionally, so align the scores to the bar order by region
    # label. Regions without a score become NaN and are styled as undefined.
    scores = salience_scores.reindex(group_df.index)
    is_undefined = scores.isna()
    is_salient = scores.abs().ge(threshold, fill_value=0)  # NaN scores are compared as 0
    fill_color = pd.Series(np.where(is_salient, color, to_rgba(color, alpha_below_thr)), index=scores.index)
    fill_color[is_undefined] = to_rgba(color, alpha_undefined)
    line_color = pd.Series(np.where(is_undefined, to_rgba(color, alpha_undefined), color), index=scores.index)
    trace_name = f"{group_name} [{marker}]"
    trace = go.Bar(x=group_df.mean(axis=1), y=group_df.index,
                   error_x=dict(type="data", array=group_df.sem(axis=1), thickness=1),
                   marker=dict(line_color=line_color, line_width=1, color=fill_color), orientation="h",
                   # use this trace's own marker: bars of marker2 must not be labelled as marker1
                   hovertemplate=bar_ht(marker), showlegend=False,
                   name=trace_name, legendgroup=trace_name, meta=trace_name)
    trace_legend = go.Scatter(x=[None], y=[None], mode="markers", marker=dict(color=color, symbol="square", size=15),
                              name=trace_name, showlegend=True, legendgroup=trace_name)
    return trace, trace_legend
def heatmap(group_df: pd.DataFrame, marker: str):
    """Build the value heatmap of `group_df` plus an overlay that paints missing (NaN) cells silver.

    The value trace uses the shared "coloraxis". The overlay is transparent where data exists and
    skips hover, so only the value trace reports on hover.
    Returns (value_trace, nan_overlay_trace).
    """
    grid = dict(x=group_df.columns, y=group_df.index)
    values_trace = go.Heatmap(z=group_df, **grid, hoverongaps=True, coloraxis="coloraxis",
                              hovertemplate=heatmap_ht(marker))
    missing_mask = np.isnan(group_df).astype(int)
    missing_trace = go.Heatmap(z=missing_mask, **grid, hoverinfo='skip', showscale=False,
                               colorscale=[[0, "rgba(0,0,0,0)"], [1, "silver"]])
    return values_trace, missing_trace

# Salience threshold: bars of regions with |salience| below it are drawn faded (see bar()).
threshold = BraiAn.PLS.norm_threshold(nsigma=2) # use the μ ± 2σ of the normal as threshold
bar_g1m1, bar_g1m1_legend = bar(group1_marker1, pls_marker1, threshold, group_1.name, marker1, color_g1m1)
bar_g2m1, bar_g2m1_legend = bar(group2_marker1, pls_marker1, threshold, group_2.name, marker1, color_g2m1)
bar_g1m2, bar_g1m2_legend = bar(group1_marker2, pls_marker2, threshold, group_1.name, marker2, color_g1m2)
bar_g2m2, bar_g2m2_legend = bar(group2_marker2, pls_marker2, threshold, group_2.name, marker2, color_g2m2)

heatmap_g1m1, heatmap_g1m1_nan = heatmap(group1_marker1, marker1)
heatmap_g2m1, heatmap_g2m1_nan = heatmap(group2_marker1, marker1)
heatmap_g1m2, heatmap_g1m2_nan = heatmap(group1_marker2, marker2)
heatmap_g2m2, heatmap_g2m2_nan = heatmap(group2_marker2, marker2)

# Shared x-range of the two bar subplots (cols 1 and 5).
# NOTE(review): the range uses mean + SEM/2, while error bars extend to mean + SEM, so they may be
# clipped. Its lower bound is the smallest mean + SEM/2 rather than 0, which shortens every bar
# when all values are positive. Also, on col 1 `autorange="reversed"` may take precedence over
# `range`, so the two sides might not actually share this range. TODO confirm the intent.
all_values = pd.concat([df.mean(axis=1) + df.sem(axis=1)/2
                        for df in (group1_marker1, group2_marker1, group1_marker2, group2_marker2)])
bar_range=(all_values.min(), all_values.max())

# MARKER1 - left side: bars (col 1, x-axis reversed so they grow leftwards) and heatmap (col 2)
fig.add_traces([bar_g1m1, bar_g1m1_legend, bar_g2m1, bar_g2m1_legend], rows=1, cols=1)
fig.update_xaxes(autorange="reversed", range=bar_range, row=1, col=1)
fig.add_traces([heatmap_g1m1, heatmap_g1m1_nan, heatmap_g2m1, heatmap_g2m1_nan], rows=1, cols=2)
# white separator between group_1's and group_2's animal columns
fig.add_vline(x=group1_marker1.shape[1]-.5, line_color="white", row=1, col=2)
fig.update_xaxes(tickangle=45,  row=1, col=2)

# MARKER2 - right side: heatmap (col 4) and bars (col 5); col 3 is the empty spacer
fig.add_traces([heatmap_g1m2, heatmap_g1m2_nan, heatmap_g2m2, heatmap_g2m2_nan], rows=1, cols=4)
# the separator goes after group_1's marker2 columns
fig.add_vline(x=group1_marker2.shape[1]-.5, line_color="white", row=1, col=4)
fig.update_xaxes(tickangle=45,  row=1, col=4)
fig.add_traces([bar_g1m2, bar_g1m2_legend, bar_g2m2, bar_g2m2_legend], rows=1, cols=5)
fig.update_xaxes(range=bar_range, row=1, col=5)

fig.update_xaxes(side="top")
fig.update_yaxes(autorange="reversed")  # first region at the top
# NOTE(review): the height is fixed; consider scaling it with len(selected_regions).
fig.update_layout(height=5000, plot_bgcolor="rgba(0,0,0,0)", legend=dict(tracegroupgap=0),
    coloraxis=dict(colorscale="deep_r", colorbar=dict(lenmode="pixels", len=500, thickness=15, outlinewidth=1, y=0.95, yanchor="top"))
)

fig.show()