# Imports


In [1]:
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import to_hex

In [51]:
# 2. Constants and configuration

# --- Inputs ---------------------------------------------------------------
# Clustered graph (GraphML).
INPUT_GRAPH_PATH = "../data/07-clustered-graphs/alpha0.3_k10_res0.002.graphml"
# Excel sheet with one row per cluster (number, label, category flags).
CLUSTER_INFO_LABEL_TREE = "../output/cluster-qualifications/ClusterInfoLabelTree.xlsx"
# Nested cluster hierarchy (JSON).
CLUSTER_TREE_PATH = "../output/cluster-qualifications/ClusterHierarchy_noComments.json"
# The legend is built from the same hierarchy file as CLUSTER_TREE_PATH.
CLUSTER_HIERARCHY_FOR_LEGEND_PATH = CLUSTER_TREE_PATH

# --- Outputs --------------------------------------------------------------
OUTPUT_DIR = "../data/99-testdata/"
CLUSTER_LABEL_DICT_PATH = "../data/99-testdata/cluster_label_dict.json"
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
THREEJS_OUTPUT_DIR = "/Users/jlq293/Projects/Random Projects/LW-ThreeJS/2d_ssrinetworkviz/src/data/"

In [53]:
class ClusterColorAssigner:
    """
    Assign display colors to clusters based on their category flags.

    TO DO:
    1. clusters not mutually exclusive - need to assign to multiple categories
    2. colors too similar - need to assign more distinct colors

    Each cluster row is mapped to one color palette according to the
    "pharmacology", "indications", "safety" and "other" columns, and then
    receives one shade sampled from that palette's matplotlib colormap.
    Clusters that match no category are colored "gray".

    Attributes:
        colormaps (dict): A dictionary mapping color names to matplotlib colormaps.
        condition_names (dict): A dictionary mapping color names to condition names.

    Methods:
        assign_color_categories(clust_hierarchy): Assigns color categories to clusters.
        print_color_mapping(): Prints the mapping of conditions to color palettes.
        assign_colors(df, colormap): Assigns specific colors to clusters within a palette.
        create_color_dataframes(clust_hierarchy): Creates separate DataFrames for each color category.
        process_cluster_hierarchy(clust_hierarchy): Processes the entire cluster hierarchy.
        save_dict_to_json(dict, path): Writes a dictionary to a JSON file.

    Usage:
        color_assigner = ClusterColorAssigner()
        processed_hierarchy, color_dict = color_assigner.process_cluster_hierarchy(cluster_hierarchy_df)
    """

    def __init__(self):
        # Palette name -> matplotlib colormap used to sample shades.
        self.colormaps = {
            "blue": plt.get_cmap("Blues"),
            "red": plt.get_cmap("Reds"),
            "green": plt.get_cmap("Greens"),
            "purple": plt.get_cmap("Purples"),
        }
        # Palette name -> category column it represents.
        self.condition_names = {
            "blue": "pharmacology",
            "red": "indications",
            "green": "safety",
            "purple": "other",
        }

    def assign_color_categories(self, clust_hierarchy):
        """
        Add a "color_pal" column naming the palette of each cluster.

        The first category column equal to 1 wins, checked in the order
        pharmacology, indications, safety, other; rows matching none get
        an empty string. The column is written into ``clust_hierarchy``
        itself, which is also returned.
        """
        conditions = [
            clust_hierarchy["pharmacology"] == 1,
            clust_hierarchy["indications"] == 1,
            clust_hierarchy["safety"] == 1,
            clust_hierarchy["other"] == 1,
        ]
        choices = ["blue", "red", "green", "purple"]
        clust_hierarchy["color_pal"] = np.select(conditions, choices, default="")
        return clust_hierarchy

    def print_color_mapping(self):
        """Print which palette is used for which category."""
        print("Mapping of conditions to color palettes:")
        for color, condition in self.condition_names.items():
            print(f"{condition.capitalize()}: {color}")

    @staticmethod
    def assign_colors(df, colormap):
        """
        Add a "color" column with one hex color per row of ``df``.

        The colormap is sampled at evenly spaced positions between 0.1 and
        0.9, in row order. ``df`` is modified in place and returned.
        """
        num_colors = df.shape[0]
        colors = [to_hex(colormap(x)) for x in np.linspace(0.1, 0.9, num_colors)]
        df["color"] = colors
        return df

    def create_color_dataframes(self, clust_hierarchy):
        """
        Split the hierarchy by palette and color each part.

        Returns a dict mapping palette name to a colored copy of the rows
        whose "color_pal" equals that name; palettes with no rows are
        omitted.
        """
        color_dfs = {}
        for color_name, colormap in self.colormaps.items():
            df_color = clust_hierarchy[
                clust_hierarchy["color_pal"] == color_name
            ].copy()
            if not df_color.empty:
                color_dfs[color_name] = self.assign_colors(df_color, colormap)
        return color_dfs

    def process_cluster_hierarchy(self, clust_hierarchy):
        """
        Run the full pipeline: categorize, color, and build the color map.

        Returns:
            tuple: ``(colored_hierarchy, cluster_color_dict)`` where
            ``colored_hierarchy`` holds the colored rows grouped by palette
            (uncolored clusters last, colored "gray") with a fresh index,
            and ``cluster_color_dict`` maps each "cluster" value to its
            "color".
        """
        # Assign color categories
        clust_hierarchy = self.assign_color_categories(clust_hierarchy)
        self.print_color_mapping()

        # Create color dataframes
        frames = list(self.create_color_dataframes(clust_hierarchy).values())

        # Find any clusters that did not receive a color
        colored_ids = {cid for frame in frames for cid in frame["cluster"]}
        # .copy() so the assignment below does not write into a slice of the
        # input frame (avoids pandas' SettingWithCopyWarning).
        missing_clusters = clust_hierarchy.loc[
            ~clust_hierarchy["cluster"].isin(colored_ids)
        ].copy()
        if not missing_clusters.empty:
            missing_clusters["color"] = "gray"  # Assign a default color (e.g., gray)
            frames.append(missing_clusters)

        # pd.concat raises ValueError on an empty list, so only concatenate
        # when there is something to combine.
        if frames:
            colored_hierarchy = pd.concat(frames, ignore_index=True)
        else:
            colored_hierarchy = clust_hierarchy.assign(
                color=pd.Series(dtype="object")
            )

        # Create cluster-color dictionary
        cluster_color_dict = dict(
            zip(colored_hierarchy["cluster"], colored_hierarchy["color"])
        )

        return colored_hierarchy, cluster_color_dict

    def save_dict_to_json(self, dict, path):
        """
        Write ``dict`` to ``path`` as JSON and print a confirmation.

        The confirmation text always reads "Cluster color dictionary",
        whichever dictionary is being saved. The parameter name shadows the
        builtin ``dict``; it is kept so keyword callers keep working.
        """
        with open(path, "w", encoding="utf-8") as f:
            json.dump(dict, f)
        print(f"Cluster color dictionary saved to {path}")

In [54]:
# Read the cluster label table and map every cluster number to its label.
clust_hierarchy = pd.read_excel(CLUSTER_INFO_LABEL_TREE)

cluster_label_dict = {
    cluster: label
    for cluster, label in zip(
        clust_hierarchy["cluster"], clust_hierarchy["clusterlabel"]
    )
}

In [55]:
# Display the cluster-number -> label mapping built in the previous cell.
cluster_label_dict

{0: 'Serotonin Receptor Studies',
 1: 'Aquatic Ecotoxicology',
 2: 'Risks of Prenatal Exposure',
 3: 'Quantification of SSRIs in Biological Samples',
 4: 'SSRIs for Obsessive-Compulsive Disorder (OCD)',
 5: 'SSRIs and the Cytochrome P450 System',
 6: 'SSRI Neuroscience',
 7: 'Pediatric Depression',
 8: 'The Chronic Unpredictable Mild Stress Model of Depression',
 9: 'Fluvoxamine for Depression',
 10: 'Paroxetine Bindeing',
 11: 'SSRIs Effect on Neural Processing of Emotional Cues',
 12: 'Risks of Prenatal Exposure (Rodents)',
 13: 'SSRIs for PTSD',
 14: 'SSRIs in Forced Swimming Test',
 15: 'Serotonin Syndrome',
 16: 'Sexual Dysfunction',
 17: 'Sequenced Depression Treatment',
 18: 'Post-Stroke SSRI Use',
 19: 'SSRIs Effect on Fear',
 20: 'Bleeding Risk',
 21: 'Serotonin Transporter Gene and Antidepressant Response',
 22: 'SSRIs and Inflammation',
 23: 'Escitalopram for Depression',
 24: 'SSRIs for Pain',
 25: 'SSRI Utilization Patterns',
 26: 'SSRIs in Dementias',
 27: 'SSRIs for Weig

In [61]:
# Load the cluster table ordered by cluster number.
clust_hierarchy = pd.read_excel(CLUSTER_INFO_LABEL_TREE).sort_values(by="cluster")

# Color every cluster; the returned frame is grouped by palette, so the
# dictionaries built from it follow that order rather than cluster order.
color_assigner = ClusterColorAssigner()
clust_hierarchy, cluster_color_dict = color_assigner.process_cluster_hierarchy(
    clust_hierarchy
)
color_assigner.save_dict_to_json(
    cluster_color_dict, f"{OUTPUT_DIR}cluster_color_dict.json"
)

# Cluster number -> label, saved next to the color dictionary.
cluster_label_dict = {
    cluster: label
    for cluster, label in zip(
        clust_hierarchy["cluster"], clust_hierarchy["clusterlabel"]
    )
}
color_assigner.save_dict_to_json(cluster_label_dict, CLUSTER_LABEL_DICT_PATH)

# Preview the first five entries of each dictionary.
print("\nCluster color dictionary (first 5 items):")
print({key: cluster_color_dict[key] for key in list(cluster_color_dict)[:5]})
print("\nCluster label dictionary (first 5 items):")
print({key: cluster_label_dict[key] for key in list(cluster_label_dict)[:5]})

Mapping of conditions to color palettes:
Pharmacology: blue
Indications: red
Safety: green
Other: purple
Cluster color dictionary saved to ../data/99-testdata/cluster_color_dict.json
Cluster color dictionary saved to ../data/99-testdata/cluster_label_dict.json

Cluster color dictionary (first 5 items):
{0: '#e3eef9', 2: '#dfebf7', 3: '#dbe9f6', 5: '#d6e6f4', 6: '#d3e3f3'}

Cluster label dictionary (first 5 items):
{0: 'Serotonin Receptor Studies', 2: 'Risks of Prenatal Exposure', 3: 'Quantification of SSRIs in Biological Samples', 5: 'SSRIs and the Cytochrome P450 System', 6: 'SSRI Neuroscience'}


In [69]:
# Spot-check: look up the label stored for cluster 31.
cluster_label_dict[31]

'SSRIs for Panic Disorder'

In [62]:
# All cluster numbers as integers, in ascending order.
ks = sorted(int(key) for key in cluster_label_dict)
ks

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148]

# Legend JSON creation


In [58]:
import json


def transform_dict_to_legend(cluster_hierarchy_dict, cluster_label_dict):
    """
    Transforms the cluster hierarchy dictionary by adding cluster labels to create a legend.

    The hierarchy is walked recursively through dicts and lists. Every
    integer leaf that is a known cluster number is replaced by a
    single-entry dict ``{cluster_number: label}``; all other values are
    returned unchanged.

    Args:
        cluster_hierarchy_dict: Nested dicts/lists whose leaves include
            cluster numbers.
        cluster_label_dict: Mapping of cluster number to label. Keys may be
            ints or numeric strings (as produced by a JSON round trip);
            they are converted with ``int()``.

    Returns:
        A new nested structure of the same shape with labels attached.

    Raises:
        ValueError: If a key of ``cluster_label_dict`` cannot be converted
            to an integer.
    """
    # Ensure keys in cluster_label_dict are integers
    labels_by_cluster = {int(k): v for k, v in cluster_label_dict.items()}

    def transform(item):
        if isinstance(item, dict):
            return {k: transform(v) for k, v in item.items()}
        if isinstance(item, list):
            return [transform(i) for i in item]
        # bool is a subclass of int (True == 1, False == 0); exclude it so a
        # boolean leaf is never mistaken for cluster 1 or 0.
        is_cluster_number = isinstance(item, int) and not isinstance(item, bool)
        if is_cluster_number and item in labels_by_cluster:
            return {item: labels_by_cluster[item]}
        return item

    return transform(cluster_hierarchy_dict)


# Load the nested hierarchy and the cluster-number -> label mapping.
with open(CLUSTER_HIERARCHY_FOR_LEGEND_PATH, "r") as hierarchy_file:
    cluster_hierarchy_dict = json.load(hierarchy_file)

with open(CLUSTER_LABEL_DICT_PATH, "r") as label_file:
    cluster_label_dict = json.load(label_file)

# Attach the labels to the hierarchy leaves.
legend = transform_dict_to_legend(cluster_hierarchy_dict, cluster_label_dict)

# Pretty-print the resulting legend tree.
print(json.dumps(legend, indent=2))

# Optionally, save the result to a file
# with open("path/to/output_legend.json", "w") as f:
#     json.dump(legend, f, indent=2)

{
  "Pharmacology": {
    "Pharmacodynamics": {
      "Mechanism of action": [
        {
          "0": "Serotonin Receptor Studies"
        },
        {
          "10": "Paroxetine Bindeing"
        },
        {
          "19": "SSRIs Effect on Fear"
        },
        {
          "11": "SSRIs Effect on Neural Processing of Emotional Cues"
        },
        {
          "30": "Serotonin Binding and Receptor Studies"
        },
        {
          "81": "Astrocyte Receptors"
        },
        {
          "88": "Repeated SSRIs Exposures Effects on Dopamine Receptors"
        },
        {
          "90": "Neurochemical and Electrophysiological Correlates of SSRIs"
        },
        {
          "94": "SSRIs for OCD-like behaviors"
        },
        {
          "112": "SSRIs in Model Organisms ( C. Elegans and Drosophilia)"
        },
        {
          "120": "Tryptophan Depletion in Depression"
        }
      ],
      "Animal Models of Disorders": [
        {
          "8": "The Chr

In [66]:
# Save the legend as JSON: once to the local output folder and once to the
# Three.js project's data folder.
for legend_path in (
    OUTPUT_DIR + "legend_full_label_tree_clusternr.json",
    THREEJS_OUTPUT_DIR + "legend_tree.json",
):
    with open(legend_path, "w") as json_file:
        json.dump(legend, json_file, indent=4)

In [68]:
# Inspect the cluster-number -> label mapping currently held in memory.
cluster_label_dict

{0: 'Serotonin Receptor Studies',
 2: 'Risks of Prenatal Exposure',
 3: 'Quantification of SSRIs in Biological Samples',
 5: 'SSRIs and the Cytochrome P450 System',
 6: 'SSRI Neuroscience',
 8: 'The Chronic Unpredictable Mild Stress Model of Depression',
 10: 'Paroxetine Bindeing',
 11: 'SSRIs Effect on Neural Processing of Emotional Cues',
 14: 'SSRIs in Forced Swimming Test',
 19: 'SSRIs Effect on Fear',
 21: 'Serotonin Transporter Gene and Antidepressant Response',
 22: 'SSRIs and Inflammation',
 30: 'Serotonin Binding and Receptor Studies',
 54: '(Sertraline) Drug Delivery',
 60: 'SSRI Synthesis',
 64: 'SSRIs Effects on Ion Channels',
 65: 'Methylenedioxymethamphetamine (MDMA) Effects on Serotonin',
 66: 'Genotype Mediated Response to SSRIs',
 67: 'Pharmacological Perspectives on Antidepressants',
 70: 'SSRIs Effects on Neuroendocrine System',
 80: 'Pulmonary Hypertension',
 81: 'Astrocyte Receptors',
 88: 'Repeated SSRIs Exposures Effects on Dopamine Receptors',
 90: 'Neurochemica