In [None]:
import os
import sys
from dotenv import load_dotenv
load_dotenv() 

# Name of the repository root directory. Every relative path used later in
# this notebook is resolved against it.
target_folder = "phate-for-text"

# Start from the kernel's working directory and climb one level at a time
# until a directory with the target name is reached.
current_dir = os.getcwd()
while os.path.basename(current_dir) != target_folder:
    parent_dir = os.path.dirname(current_dir)
    if parent_dir == current_dir:
        # The parent of the filesystem root is the root itself, so reaching
        # this point means the whole ancestry was searched without a match.
        raise FileNotFoundError(f"{target_folder} not found in the directory tree.")
    current_dir = parent_dir

# Make the repository root the working directory so relative data paths work.
os.chdir(current_dir)

# Make modules in the repository root importable. The membership check keeps
# the cell idempotent: re-running it no longer stacks duplicate entries onto
# sys.path.
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

In [None]:
# Standard library imports
import warnings

# Third-party imports
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import plotly.io as pio
from plotly.subplots import make_subplots
import seaborn as sns
from sklearn.metrics import adjusted_rand_score
import numpy as np
# Suppress warnings
warnings.filterwarnings('ignore')


In [None]:
# Run a full garbage-collection pass; gc.collect() returns the number of
# unreachable objects it found, which the notebook shows as the cell output.
import gc
gc.collect()

In [None]:
def align_labels_by_ari(labels, reference_labels, palette="tab10"):
    """Colour predicted clusters so that they line up with reference clusters.

    Each (predicted label, reference label) pair is scored with the adjusted
    Rand index between the two one-vs-rest membership masks. Pairs are then
    matched greedily from the highest score down, and a matched predicted
    label inherits the colour of its reference label.

    Args:
        labels: Iterable of predicted labels, one per sample. ``None`` entries
            are treated as the missing label ``"None"``; everything else is
            converted with ``str``.
        reference_labels: Iterable of reference labels with the same length
            and the same ``None`` convention.
        palette: Seaborn palette name used to build the colour pool.

    Returns:
        tuple: ``(label_to_color, final_colors)`` where ``label_to_color``
        maps each predicted label (plus ``"None"``) to an ``rgba(...)`` string
        and ``final_colors`` holds one colour per sample. A sample whose
        reference label is ``"None"`` always gets the ``"None"`` colour.
    """
    missing = "None"
    missing_color = "rgba(0,0,0,0.2)"
    fallback_color = "rgba(0,0,0,0.8)"

    def _stringify(values):
        # None becomes the sentinel string; anything else is str()-ed.
        return np.array([missing if v is None else str(v) for v in values])

    labels = _stringify(labels)
    reference_labels = _stringify(reference_labels)

    # Only samples labelled on BOTH sides take part in the ARI scoring.
    keep = (labels != missing) & (reference_labels != missing)
    pred_kept = labels[keep]
    ref_kept = reference_labels[keep]

    pred_ids = sorted(set(pred_kept))
    ref_ids = sorted(set(ref_kept))

    # Colour pool: five spare entries beyond the larger of the two label sets.
    pool_size = max(len(pred_ids), len(ref_ids)) + 5
    color_pool = []
    for r, g, b in sns.color_palette(palette, pool_size):
        color_pool.append(f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, 1.0)")

    # Reference clusters take the first colours of the pool, in sorted order.
    ref_color_map = dict(zip(ref_ids, color_pool))
    used_colors = set(ref_color_map.values())

    # Score every predicted/reference pair, best score first.
    scores = [
        (adjusted_rand_score(pred_kept == pred_id, ref_kept == ref_id), pred_id, ref_id)
        for pred_id in pred_ids
        for ref_id in ref_ids
    ]
    scores.sort(reverse=True)

    # Greedy one-to-one matching: each label and each reference is used once.
    label_to_color = {}
    claimed_refs = set()
    for _, pred_id, ref_id in scores:
        if pred_id not in label_to_color and ref_id not in claimed_refs:
            label_to_color[pred_id] = ref_color_map[ref_id]
            claimed_refs.add(ref_id)

    # Predicted labels left without a partner get the next unused pool colour,
    # or the dark fallback once the pool is exhausted.
    for pred_id in pred_ids:
        if pred_id in label_to_color:
            continue
        spare = next((c for c in color_pool if c not in used_colors), fallback_color)
        label_to_color[pred_id] = spare
        used_colors.add(spare)

    # The missing label is always rendered in transparent grey.
    label_to_color["None"] = missing_color

    # Per-sample colours; a missing reference overrides the predicted colour.
    final_colors = []
    for pred, ref in zip(labels, reference_labels):
        if ref == "None":
            final_colors.append(label_to_color["None"])
        else:
            final_colors.append(label_to_color.get(pred, fallback_color))

    return label_to_color, final_colors



In [None]:
def plot_topic_map(embedding, labels, topic_names=None, palette="tab10", plot_3d=True, title="Topic Map", reference_labels=None, pt_size=4):
    """Build a Plotly scatter of an embedding, coloured by label.

    Args:
        embedding: 2-D array indexed as ``embedding[:, k]``; columns 0 and 1
            are plotted, plus column 2 when ``plot_3d`` is true.
        labels: Iterable with one label per embedding row. ``None`` becomes
            the string ``"None"``; everything else is converted with ``str``.
        topic_names: Optional iterable with one topic name per row, shown in
            the hover text next to the label.
        palette: Seaborn palette name used for the label colours.
        plot_3d: If true build a ``Scatter3d`` trace, otherwise a 2-D
            ``Scatter`` trace.
        title: Figure title, centred horizontally.
        reference_labels: Optional reference labels. When given, colours are
            produced by ``align_labels_by_ari`` so that clusters matching a
            reference cluster share its colour.
        pt_size: Marker size.

    Returns:
        tuple: ``(fig, label_to_color)`` with the Plotly figure and the
        mapping from label string to ``rgba(...)`` colour.
    """
    labels = [str(label) if label is not None else "None" for label in labels]

    # Color assignment
    if reference_labels is not None:
        # Use the per-sample colours from the alignment as they are. They
        # contain the grey override for samples whose reference label is
        # "None"; recomputing colours from label_to_color (as the code
        # previously did unconditionally) discarded that override.
        label_to_color, colors = align_labels_by_ari(labels, reference_labels, palette=palette)
    else:
        unique_labels = sorted(set(label for label in labels if label != "None"))
        palette_colors = sns.color_palette(palette, len(unique_labels))
        label_to_color = {
            label: f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, 1.0)"
            for label, (r, g, b) in zip(unique_labels, palette_colors)
        }
        label_to_color["None"] = "rgba(0,0,0,0.2)"
        colors = [label_to_color.get(label, "rgba(0,0,0,0.2)") for label in labels]

    # Hover text, one entry per point.
    if topic_names is not None:
        text = [f"Topic: {t}<br>Label: {l}" for t, l in zip(topic_names, labels)]
    else:
        text = [f"Label: {l}" for l in labels]

    # Trace options shared by the 2-D and the 3-D variant.
    trace_kwargs = dict(
        mode='markers',
        marker=dict(size=pt_size, color=colors),
        text=text,
        hoverinfo='text',
    )
    hidden_ticks = dict(showticklabels=False)

    if plot_3d:
        trace = go.Scatter3d(
            x=embedding[:, 0], y=embedding[:, 1], z=embedding[:, 2],
            **trace_kwargs
        )
        layout = go.Layout(
            title=dict(text=title, x=0.5),
            scene=dict(
                xaxis=dict(hidden_ticks),
                yaxis=dict(hidden_ticks),
                zaxis=dict(hidden_ticks)
            ),
            width=800, height=600, showlegend=False
        )
    else:
        trace = go.Scatter(
            x=embedding[:, 0], y=embedding[:, 1],
            **trace_kwargs
        )
        layout = go.Layout(
            title=dict(text=title, x=0.5),
            xaxis=dict(hidden_ticks),
            yaxis=dict(hidden_ticks),
            width=800, height=600,
            showlegend=False
        )

    fig = go.Figure(data=[trace], layout=layout)
    return fig, label_to_color


In [None]:
def create_legend_only_figure(label_to_color, marker_size=10, font_size=12):
    """Build a small figure that shows nothing but a colour legend.

    One empty marker trace is created per entry of ``label_to_color`` so that
    Plotly draws a legend item for it; entries whose label stringifies to
    ``'nan'`` are left out. The axes are hidden and the margins are zero.

    Args:
        label_to_color: Mapping from label to a Plotly colour string.
        marker_size: Size of the legend markers.
        font_size: Size used for both the legend font and the global font.

    Returns:
        plotly.graph_objs.Figure: The legend-only figure.
    """
    def _font():
        # A fresh dict per use: the legend and the global font share settings.
        return dict(family='Times New Roman', size=font_size, color='black')

    # The points carry no coordinates; the traces exist only for the legend.
    legend_entries = [
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(size=marker_size, color=color),
            name=str(label),
            showlegend=True,
        )
        for label, color in label_to_color.items()
        if str(label) != 'nan'
    ]

    hidden_axis = dict(visible=False)
    layout = go.Layout(
        showlegend=True,
        legend=dict(x=0, y=1, font=_font()),
        font=_font(),  # global font (e.g., title, annotations)
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        width=100,
        height=100,
        margin=dict(l=0, r=0, t=0, b=0),
        plot_bgcolor='white',
        paper_bgcolor='white',
    )

    return go.Figure(data=legend_entries, layout=layout)

In [None]:
# Embedding model whose reduced embeddings are visualised in the cells below.
embedding_model = "text-embedding-3-large"

# PHATE embedding and the matching generated topic table for the deep
# "Energy, Ecosystems, and Humans" hierarchy.
embed = np.load(f'{embedding_model}_reduced_embeddings/phate_embedding_Energy, Ecosystems, and Humans_hierarchy_t1.0_maxsub3_depth5_synonyms10_noise0.25_random_decay20_n_components300_tauto.npy')
topic_df = pd.read_csv("data_generation/generated_data/Energy, Ecosystems, and Humans_hierarchy_t1.0_maxsub3_depth5_synonyms10_noise0.25_random.csv")
labels = topic_df['topic']

# Fixed-seed permutation of the row positions.
# NOTE(review): presumably this reproduces the row order the embedding was
# computed in - confirm against the embedding-generation code.
shuffle_idx = np.random.RandomState(seed=42).permutation(len(topic_df))

# Top-level category per row, reordered by POSITION. `.iloc` states that
# intent explicitly; plain `[]` with an integer array looks rows up by index
# label and only coincides with positions for a default RangeIndex. Both the
# placeholder value 'was' and missing values are mapped to the string 'None',
# which plot_topic_map draws in transparent grey.
colors = (
    topic_df['category 0']
    .replace('was', 'None')
    .iloc[shuffle_idx]
    .replace(np.nan, 'None')
)

In [None]:
# `colors` was permuted with `shuffle_idx` in the previous cell, and
# plot_topic_map pairs topic_names with labels position by position when it
# builds the hover text. The topics therefore need the same permutation,
# otherwise each point shows a topic taken from a different row than its label.
fig, _ = plot_topic_map(
    embed,
    colors,
    topic_names=topic_df['topic'].iloc[shuffle_idx],
    palette="tab10",
    plot_3d=False,
    title='PHATE',
)

In [None]:
# Display the 2-D PHATE scatter built by the plot_topic_map call above.
fig.show()

In [None]:
# fig.write_html('Ecosystem(d).html')

In [None]:
# Maps "<dataset>_<method>" to the .npy file holding that dataset's reduced
# embedding for that method. The grid-building cell splits each key on "_":
# the dataset part selects the subplot row (and the label file in
# `data_files`), the method part selects the subplot column. Row and column
# order in the figure follow the insertion order of this dict.
dataset_dict = {
    # Web of Science
    "Web of science_PHATE": f"{embedding_model}_reduced_embeddings/PHATE_WOS_embed.npy",
    "Web of science_PCA": f"{embedding_model}_reduced_embeddings/PCA_WOS_embed.npy",
    "Web of science_UMAP": f"{embedding_model}_reduced_embeddings/UMAP_WOS_embed_new.npy",
    "Web of science_T-SNE": f"{embedding_model}_reduced_embeddings/tSNE_WOS_embed.npy",

    # DBpedia
    "DBpedia_PHATE": f"{embedding_model}_reduced_embeddings/dbpedia_phate_embed.npy",
    "DBpedia_PCA": f"{embedding_model}_reduced_embeddings/PCA_dbpedia_embed.npy",
    "DBpedia_UMAP": f"{embedding_model}_reduced_embeddings/UMAP_dbpedia_embed_new.npy",
    "DBpedia_T-SNE": f"{embedding_model}_reduced_embeddings/tSNE_dbpedia_embed.npy",

    # Amazon
    "Amazon_PHATE": f"{embedding_model}_reduced_embeddings/PHATE_amz_embed.npy",
    "Amazon_PCA": f"{embedding_model}_reduced_embeddings/PCA_amz_embed.npy",
    "Amazon_UMAP": f"{embedding_model}_reduced_embeddings/UMAP_amz_embed_new.npy",
    "Amazon_T-SNE": f"{embedding_model}_reduced_embeddings/tSNE_amz_embed.npy",

    # Fisheries (deep)
    "Fisheries (d)_PHATE": f"{embedding_model}_reduced_embeddings/phate_embedding_Offshore energy impacts on fisheries_hierarchy_t1.0_maxsub3_depth5_synonyms10_noise0.25_random_decay20_n_components300_tauto.npy",
    "Fisheries (d)_PCA": f"{embedding_model}_reduced_embeddings/PCA_embedding_Offshore energy impacts on fisheries_hierarchy_t1.0_maxsub3_depth5_synonyms10_noise0.25_random_n_components300.npy",
    "Fisheries (d)_UMAP": f"{embedding_model}_reduced_embeddings/UMAP_embedding_Offshore energy impacts on fisheries_hierarchy_t1.0_maxsub3_depth5_synonyms10_noise0.25_random_n_components300.npy",
    "Fisheries (d)_T-SNE": f"{embedding_model}_reduced_embeddings/tSNE_embedding_Offshore energy impacts on fisheries_hierarchy_t1.0_maxsub3_depth5_synonyms10_noise0.25_random_n_components300.npy",

    # Fisheries (shallow)
    "Fisheries (s)_PHATE": f"{embedding_model}_reduced_embeddings/phate_embedding_Offshore energy impacts on fisheries_hierarchy_t1.0_maxsub5_depth3_synonyms10_noise0.25_random_decay20_n_components300_tauto.npy",
    "Fisheries (s)_PCA": f"{embedding_model}_reduced_embeddings/PCA_embedding_Offshore energy impacts on fisheries_hierarchy_t1.0_maxsub5_depth3_synonyms10_noise0.25_random_n_components300.npy",
    "Fisheries (s)_UMAP": f"{embedding_model}_reduced_embeddings/UMAP_embedding_Offshore energy impacts on fisheries_hierarchy_t1.0_maxsub5_depth3_synonyms10_noise0.25_random_n_components300.npy",
    "Fisheries (s)_T-SNE": f"{embedding_model}_reduced_embeddings/tSNE_embedding_Offshore energy impacts on fisheries_hierarchy_t1.0_maxsub5_depth3_synonyms10_noise0.25_random_n_components300.npy",

    # Ecosystems (deep)
    "Ecosystems (d)_PHATE": f"{embedding_model}_reduced_embeddings/phate_embedding_Energy, Ecosystems, and Humans_hierarchy_t1.0_maxsub3_depth5_synonyms10_noise0.25_random_decay20_n_components300_tauto.npy",
    "Ecosystems (d)_PCA": f"{embedding_model}_reduced_embeddings/PCA_embedding_Energy, Ecosystems, and Humans_hierarchy_t1.0_maxsub3_depth5_synonyms10_noise0.25_random_n_components300.npy",
    "Ecosystems (d)_UMAP": f"{embedding_model}_reduced_embeddings/UMAP_embedding_Energy, Ecosystems, and Humans_hierarchy_t1.0_maxsub3_depth5_synonyms10_noise0.25_random_n_components300.npy",
    "Ecosystems (d)_T-SNE": f"{embedding_model}_reduced_embeddings/tSNE_embedding_Energy, Ecosystems, and Humans_hierarchy_t1.0_maxsub3_depth5_synonyms10_noise0.25_random_n_components300.npy",

    # Ecosystems (shallow)
    "Ecosystems (s)_PHATE": f"{embedding_model}_reduced_embeddings/phate_embedding_Energy, Ecosystems, and Humans_hierarchy_t1.0_maxsub5_depth3_synonyms10_noise0.25_random_decay20_n_components300_tauto.npy",
    "Ecosystems (s)_PCA": f"{embedding_model}_reduced_embeddings/PCA_embedding_Energy, Ecosystems, and Humans_hierarchy_t1.0_maxsub5_depth3_synonyms10_noise0.25_random_n_components300.npy",
    "Ecosystems (s)_UMAP": f"{embedding_model}_reduced_embeddings/UMAP_embedding_Energy, Ecosystems, and Humans_hierarchy_t1.0_maxsub5_depth3_synonyms10_noise0.25_random_n_components300.npy",
    "Ecosystems (s)_T-SNE": f"{embedding_model}_reduced_embeddings/tSNE_embedding_Energy, Ecosystems, and Humans_hierarchy_t1.0_maxsub5_depth3_synonyms10_noise0.25_random_n_components300.npy",
}


# Maps each dataset name (the part of a `dataset_dict` key before the "_")
# to the table holding its per-row labels. The grid-building cell reads
# entries ending in '.csv' with pd.read_csv and everything else with
# pd.read_excel.
data_files = {
    # Hierarchical: Ecosystems
    "Ecosystems (d)": "data_generation/generated_data/Energy, Ecosystems, and Humans_hierarchy_t1.0_maxsub3_depth5_synonyms10_noise0.25_random.csv",
    "Ecosystems (s)": "data_generation/generated_data/Energy, Ecosystems, and Humans_hierarchy_t1.0_maxsub5_depth3_synonyms10_noise0.25_random.csv",

    # Hierarchical: Fisheries
    "Fisheries (d)": "data_generation/generated_data/Offshore energy impacts on fisheries_hierarchy_t1.0_maxsub3_depth5_synonyms10_noise0.25_random.csv",
    "Fisheries (s)": "data_generation/generated_data/Offshore energy impacts on fisheries_hierarchy_t1.0_maxsub5_depth3_synonyms10_noise0.25_random.csv",

    # Flat data sources
    "Amazon": "data/amazon/amz_data.csv",
    "DBpedia": "data/dbpedia/DBP_wiki_data.csv",
    "Web of science": "data/WebOfScience/Meta-data/Data.xlsx"
}


In [None]:

# Derive the grid layout from the dataset_dict keys ("<dataset>_<method>"):
# datasets become subplot rows, methods become subplot columns. Lists are
# used instead of sets so that first-seen order is preserved.
row_labels = []
col_labels = []

for k in dataset_dict:
    parts = k.split("_")  # split once, then pick dataset and method
    row_label, col_label = parts[0], parts[1]
    if row_label not in row_labels:
        row_labels.append(row_label)
    if col_label not in col_labels:
        col_labels.append(col_label)

content_cols = len(col_labels)

# Every column gets the same share of the figure width; the shares are
# normalised so that they sum to 1 before being handed to make_subplots.
equal_width = 1.0
column_widths = [equal_width] * content_cols
total_width = sum(column_widths)
normalized_widths = [w / total_width for w in column_widths]

# Initialize subplot grid
fig = make_subplots(
    rows=len(row_labels),
    cols=content_cols,
    column_widths=normalized_widths,
    horizontal_spacing=0.01,
    vertical_spacing=0.03,
)

fig.update_layout(
    width=800 * len(col_labels),   # 800 px of figure width per column
    height=600 * len(row_labels),  # 600 px of figure height per row
    showlegend=False,
    margin=dict(l=20, r=20, t=20, b=20),
)

# 1-based subplot coordinates for each dataset (row) and method (column).
row_map = {label: i + 1 for i, label in enumerate(row_labels)}
col_map = {label: j + 1 for j, label in enumerate(col_labels)}

# Colour mapping of the most recently drawn method for every dataset row.
latest_colors_by_row = {}

# Path of the label table currently held in `df`. Keys of dataset_dict are
# grouped by dataset, so remembering only the last table avoids re-reading
# the same CSV/Excel file once per embedding method without keeping every
# table in memory.
loaded_topic_file = None

for key, title in dataset_dict.items():
    print(key)
    row_key, col_key = key.split("_")
    i = row_map[row_key]
    j = col_map[col_key]

    topic_file = data_files[row_key]
    embedding_file = title
    if topic_file != loaded_topic_file:
        if topic_file.endswith('.csv'):
            df = pd.read_csv(topic_file)
        else:
            df = pd.read_excel(topic_file)
        loaded_topic_file = topic_file

    # Column holding the top-level class differs between data sources.
    if row_key == 'Amazon':
        colors = df['category_0']
    elif row_key == 'DBpedia':
        colors = df['l1']
    elif row_key == 'Web of science':
        colors = df['Domain']
    else:
        # Generated hierarchies: the placeholder 'was' and missing values both
        # mean "no category". They are mapped to the string 'None', which
        # plot_topic_map draws in transparent grey. `mask` is used instead of
        # `.replace(..., None)` because replacing with an explicit None is
        # pandas-version dependent (older releases forward-fill instead).
        raw_category = df['category 0']
        colors = raw_category.mask(raw_category.isna() | (raw_category == 'was'), 'None')

    embed = np.load(embedding_file)

    # Fixed-seed permutation of the label rows.
    # NOTE(review): presumably this reproduces the row order the embeddings
    # were computed in - confirm against the embedding-generation code.
    shuffle_idx = np.random.RandomState(seed=42).permutation(len(df))
    colors = colors.iloc[shuffle_idx]

    # Smaller markers for the large datasets so the points do not smear.
    pt_size = 3 if len(df) > 10000 else 6

    fig_topic, label_to_color = plot_topic_map(embed, colors, plot_3d=False, title="", palette="tab10", pt_size=pt_size)

    # Keep the grey 'None' entry only for the generated datasets. The paths of
    # those files contain 'generated_data' (lower case); the former check for
    # 'Generated' never matched any path, so 'None' was always removed.
    if 'generated_data' not in topic_file:
        label_to_color.pop('None', None)

    latest_colors_by_row[row_key] = label_to_color

    # Add the trace(s) from the topic map to the subplot
    for trace in fig_topic.data:
        fig.add_trace(trace, row=i, col=j)
# Widen the left and top margins so the row and column titles fit.
fig.update_layout(
    margin=dict(l=450, r=20, t=100, b=20),
)

# Column titles: centred above each column.
# The x position is given relative to the domain of the column's x axis, so
# x=0.5 is the column centre whatever the grid size. Plotly names the first
# axis "x" (not "x1"), hence the special case for the first column; this
# replaces the former hand-tuned paper coordinate for that column.
# NOTE(review): relies on make_subplots numbering axes row by row from the
# top-left cell, so the first row owns x, x2, ..., x<content_cols>.
for j, col_label in enumerate(col_labels):
    xref = "x domain" if j == 0 else f"x{j + 1} domain"
    fig.add_annotation(
        dict(
            text=col_label,
            xref=xref,
            yref="paper",
            x=0.5,
            y=1.015,
            showarrow=False,
            font=dict(size=48),
            align="center"
        )
    )

# Row titles: vertically centred on the first subplot of each row, placed in
# the left margin. Same domain-reference approach as for the columns.
for i, row_label in enumerate(row_labels):
    axis_index = i * content_cols + 1  # Index of the first subplot in the row
    yref = "y domain" if axis_index == 1 else f"y{axis_index} domain"
    fig.add_annotation(
        dict(
            text=row_label,
            xref="paper",
            yref=yref,
            x=-0.15,  # Slightly to the left of the plots
            y=0.5,  # Centered within that row's y-axis domain
            showarrow=False,
            font=dict(size=48),
            align="right"
        )
    )

# Hide tick labels and tick marks on every subplot; without row/col
# arguments the update applies to all axes of the figure.
fig.update_xaxes(showticklabels=False, ticks="")
fig.update_yaxes(showticklabels=False, ticks="")


In [None]:
# Display the dataset-by-method grid assembled in the previous cell.
fig.show()

In [None]:
# Export the dataset-by-method grid as a static PNG at 4x scale.
fig.write_image("all_embeddings_subplot.png", scale=4)

### EPA visual

In [None]:
# EPA embedding array. NOTE(review): `data` is not read again before the
# plotting loop in a later cell rebinds the same name as its loop variable.
data = np.load('data/epa/epa_embed.npy')

# Per-row EPA labels, passed to plot_topic_map by the plotting loop below.
# NOTE(review): allow_pickle=True makes NumPy unpickle the file, which can run
# arbitrary code; only load label files from a trusted source.
labels = np.load('data/epa/epa_labels.npy',allow_pickle=True)

In [None]:
# From here on the notebook works with the MiniLM embeddings; this rebinds
# the `embedding_model` name used by the earlier cells.
embedding_model = "MiniLM-L6-v2"

# Reduced EPA embeddings per method. Insertion order matters: it determines
# the subplot order and the subplot titles of the combined figure below.
# NOTE(review): the T-SNE file name has no underscore before "top20", unlike
# the other three - confirm that this matches the file on disk.
embedding_methods = {
    "PHATE": np.load(f"{embedding_model}_reduced_embeddings/PHATE_epa_embed_top20.npy"),
    "UMAP": np.load(f"{embedding_model}_reduced_embeddings/UMAP_epa_embed_top20.npy"),
    'PCA': np.load(f"{embedding_model}_reduced_embeddings/PCA_epa_embed_top20.npy"),
    'T-SNE': np.load(f"{embedding_model}_reduced_embeddings/tSNE_epa_embedtop20.npy"),
}

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Create a 2x2 subplot figure, one panel per embedding method. Each spec gets
# its own dict (the former `[[...]*2]*2` aliased one dict four times).
fig_combined = make_subplots(
    rows=2, cols=2,
    subplot_titles=list(embedding_methods.keys()),
    specs=[[{"type": "scatter"} for _ in range(2)] for _ in range(2)],
    vertical_spacing=0.05,
    horizontal_spacing=.05,
)

# Grid position of each method, in the insertion order of embedding_methods.
# zip() stops at the shorter input, so at most four methods are drawn.
row_col_map = [(1, 1), (1, 2), (2, 1), (2, 2)]

# The loop variable is deliberately not called `data`, so the EPA array
# loaded into `data` earlier is no longer overwritten by this cell.
for (row, col), (method_name, method_embedding) in zip(row_col_map, embedding_methods.items()):
    # Generate the topic map plot
    fig_topic, label_to_color = plot_topic_map(
        method_embedding, labels, plot_3d=False, title=None, palette="tab20"
    )
    # fig_topic.write_image(f'{method_name}_epa.png',scale=5)

    # Add each trace from fig_topic into the corresponding subplot
    for trace in fig_topic.data:
        fig_combined.add_trace(trace, row=row, col=col)

axis_style = dict(showticklabels=False)

fig_combined.update_layout(
    height=900,
    width=1200,
    # margin=dict(t=5),
    title_font_size=64,
    showlegend=False,
)

# Hide tick labels on every subplot axis instead of naming xaxis..xaxis4 and
# yaxis..yaxis4 one by one.
fig_combined.update_xaxes(**axis_style)
fig_combined.update_yaxes(**axis_style)

# The subplot titles are stored as layout annotations; enlarge their font.
for annotation in fig_combined['layout']['annotations']:
    annotation['font'] = dict(size=24)

In [None]:
# Display the 2x2 EPA comparison figure built in the previous cell.
fig_combined.show()