In [197]:
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import ipywidgets as widgets
from IPython.display import display
from IPython.display import clear_output

def create_label_dictionary(nodes_df):
    """Map every node ID to the name used as its label.

    Args:
        nodes_df: DataFrame with "Node ID" and "Name" columns.

    Returns:
        dict of {Node ID: Name}. If an ID appears on several rows, the
        last row's name is the one kept.
    """
    labels = {}
    for _, node in nodes_df.iterrows():
        labels[node["Node ID"]] = node["Name"]
    return labels
    
def create_position_dictionary(nodes_df):
    """Assign an (x, y) plotting coordinate to every node.

    Nodes are stacked downwards in the order they appear, with a separate
    stack per (layer, type) pair. Heads sit at x = layer * 2 and are shifted
    down by half a step; every other type sits one column to the right at
    x = layer * 2 + 1.

    Args:
        nodes_df: DataFrame with "Node ID", "Layer" and "Type" columns.

    Returns:
        dict of {Node ID: (x, y)}. A repeated Node ID keeps the coordinate
        of its last row.
    """
    coordinates = {}
    slots_used = {}

    for _, node in nodes_df.iterrows():
        layer = node["Layer"]
        kind = node["Type"]
        stack_key = (layer, kind)

        # Number of nodes already placed in this layer/type stack.
        slot = slots_used.get(stack_key, 0)

        if kind == "Head":
            x, y = layer * 2, -slot - 0.5
        else:
            x, y = layer * 2 + 1, -slot
        coordinates[node["Node ID"]] = (x, y)

        slots_used[stack_key] = slot + 1

    return coordinates

def build_graph_from_dataframes(nodes_df, edges_df):
    """Build a directed graph whose nodes carry a colormap-derived colour.

    Args:
        nodes_df: DataFrame with "Node ID" and "Correlation" columns. The
            correlation is passed straight to the "rainbow" colormap, so it
            is expected to lie in [0, 1].
        edges_df: DataFrame with "Source" and "Target" columns.

    Returns:
        networkx.DiGraph with one node per row of nodes_df (attribute
        ``color`` holds the colormap output) and one edge per row of edges_df.
    """
    G = nx.DiGraph()
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was deprecated
    # in Matplotlib 3.7 and removed in 3.9.
    color_map = plt.get_cmap("rainbow")
    for _, row in nodes_df.iterrows():
        G.add_node(row["Node ID"], color=color_map(row["Correlation"]))
    for _, row in edges_df.iterrows():
        G.add_edge(row["Source"], row["Target"])
    return G

def draw_graph(G, pos, labels, nodes_df):
    """Draw the graph with heads as circles and features as squares.

    Args:
        G: networkx graph whose nodes carry a ``color`` attribute.
        pos: dict of {node: (x, y)} coordinates.
        labels: dict of {node: label text}.
        nodes_df: DataFrame with "Node ID" and "Type" columns; only nodes
            whose type is "Head" or "Feature" are drawn as markers.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    plt.figure(figsize=(15, 10))
    ax = plt.gca()

    # Resolve every node's type once instead of filtering the DataFrame for
    # each node. setdefault keeps the first row when an ID is repeated.
    type_by_node = {}
    for node_id, node_type in zip(nodes_df["Node ID"], nodes_df["Type"]):
        type_by_node.setdefault(node_id, node_type)

    for node_type, node_shape in (("Head", "o"), ("Feature", "s")):
        nodelist = [node for node in G.nodes if type_by_node.get(node) == node_type]
        if not nodelist:
            # Nothing of this type survived filtering; skip the draw call.
            continue

        # Colours are listed in the same order as nodelist.
        colors = [G.nodes[node]["color"] for node in nodelist]

        nx.draw_networkx_nodes(
            G,
            pos,
            nodelist=nodelist,
            node_shape=node_shape,
            node_size=2000,
            node_color=colors,
            ax=ax
        )

    nx.draw_networkx_edges(G, pos, arrowstyle="->", arrowsize=20, ax=ax)
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=8, ax=ax)

    # Colorbar covering the 0..1 range on the same "rainbow" colormap.
    sm = plt.cm.ScalarMappable(cmap=plt.get_cmap("rainbow"), norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    plt.colorbar(sm, ax=ax, orientation='vertical')

    plt.title("Flow Chart with Custom Node Shapes, Colors, and Labels")
    plt.show()
import ast  # Add this import

def update_correlation(nodes_df, tag):
    """Recompute the 'Correlation' column of nodes_df for one tag, in place.

    Each row's 'Indicators' value is a dict of {tag: count} or the string
    form of such a dict. Strings are parsed with ``ast.literal_eval`` and
    written back to the column. 'Correlation' is then set to the row's count
    for ``tag`` divided by the largest count for that tag in the frame.

    Args:
        nodes_df: DataFrame with an 'Indicators' column. It is modified in
            place: 'Indicators' is replaced by the parsed values and
            'Correlation' is created or overwritten.
        tag: key to look up in each indicator dict.

    Returns:
        None.

    Notes:
        Rows whose 'Indicators' value is not a dict (for example a missing
        value) count as 0. If the frame is empty or no row has a positive
        count, 'Correlation' is set to 0 for every row.
    """
    # Convert 'Indicators' from string to dictionary if it's a string
    nodes_df['Indicators'] = nodes_df['Indicators'].apply(
        lambda x: ast.literal_eval(x) if isinstance(x, str) else x
    )

    # Per-row count for this tag, computed once and reused below.
    counts = nodes_df['Indicators'].apply(
        lambda x: x.get(tag, 0) if isinstance(x, dict) else 0
    )

    # An empty frame has no maximum; treat it like "no occurrences".
    max_count = 0 if counts.empty else counts.max()

    if max_count > 0:
        nodes_df['Correlation'] = counts / max_count
    else:
        nodes_df['Correlation'] = 0
    

# Function to create the interactive graph based on the selected tag
def interactive_graph(tag, min_correlation, nodes_df, edges_df):
    """Redraw the graph for one tag, hiding nodes below a correlation cutoff.

    Args:
        tag: tag whose counts drive the 'Correlation' column.
        min_correlation: nodes with a correlation below this value are dropped.
        nodes_df: node table; its 'Correlation' column is recomputed in place
            by update_correlation before filtering.
        edges_df: edge table with 'Source' and 'Target' columns.

    Prints the position dictionary and the kept correlations, and shows the
    figure; nothing is returned.
    """
    update_correlation(nodes_df, tag)

    # IDs of the nodes that reach the threshold.
    meets_threshold = nodes_df['Correlation'] >= min_correlation
    nodes_to_keep = nodes_df.loc[meets_threshold, 'Node ID']

    # Keep the rows for those IDs, and only edges with both endpoints kept.
    kept_nodes = nodes_df[nodes_df['Node ID'].isin(nodes_to_keep)]
    source_kept = edges_df['Source'].isin(nodes_to_keep)
    target_kept = edges_df['Target'].isin(nodes_to_keep)
    kept_edges = edges_df[source_kept & target_kept]

    labels = create_label_dictionary(kept_nodes)
    pos = create_position_dictionary(kept_nodes)
    print("pos")
    print(pos)

    graph = build_graph_from_dataframes(kept_nodes, kept_edges)
    draw_graph(graph, pos, labels, kept_nodes)

    print("nodes_df_[Correlation]")
    print(kept_nodes["Correlation"])
    
# Sample tags for demonstration
sample_tags = [f"tag{number}" for number in range(1, 11)]

In [198]:
# Function to run the main interactive application

def main_interactive(nodes_df, edges_df):
    """Build and display the tag / minimum-correlation controls for the graph.

    Args:
        nodes_df: node table passed through to interactive_graph.
        edges_df: edge table passed through to interactive_graph.

    The controls are rendered with ``display``; nothing is returned.
    """
    tag_dropdown = widgets.Dropdown(
        options=sample_tags,
        value=sample_tags[0],
        description='Tag:',
        disabled=False,
    )

    # Correlation is normalised by the per-tag maximum, so it never exceeds
    # 1; a larger slider range would only filter out every node.
    min_correlation_input = widgets.FloatSlider(
        value=0.0,
        min=0.0,
        max=1.0,
        step=0.1,
        description='Min Correlation:',
        disabled=False,
        continuous_update=False
    )

    interactive_ui = widgets.interactive(
        lambda tag, min_correlation: interactive_graph(tag, min_correlation, nodes_df, edges_df),
        tag=tag_dropdown,
        min_correlation=min_correlation_input
    )

    # Constructing the widget inside a function does not render it; show it
    # explicitly so that calling this function actually produces the UI.
    display(interactive_ui)


In [199]:
# Load the node and edge tables that feed the graph.
nodes_df = pd.read_csv('data/nodes_df.csv', index_col=False)
edges_df = pd.read_csv('data/edges_df.csv', index_col=False)

# Only the last expression of a cell is rendered automatically, so the node
# table has to be shown explicitly.
display(nodes_df)

edges_df

Unnamed: 0,Source,Target
0,L1H1,L1F1
1,L1H1,L1F3
2,L1H1,L1F6
3,L1H1,L1F8
4,L1H1,L1F10
...,...,...
338,L4H3,L4F22
339,L4H3,L4F23
340,L4H3,L4F24
341,L4H3,L4F25


In [200]:
# Tag selector, starting on the first entry of sample_tags.
tag_dropdown = widgets.Dropdown(
    description='Tag:',
    options=sample_tags,
    value=sample_tags[0],
)

# Correlation cutoff between 0 and 1 in steps of 0.1.
min_correlation_input = widgets.FloatSlider(
    description='Min Correlation:',
    value=0.1,
    min=0.0,
    max=1.0,
    step=0.1,
    continuous_update=False,
)

# Kept as the final expression of the cell so the notebook renders it.
widgets.interactive(
    lambda tag, min_correlation: interactive_graph(tag, min_correlation, nodes_df, edges_df),
    tag=tag_dropdown,
    min_correlation=min_correlation_input,
)

interactive(children=(Dropdown(description='Tag:', options=('tag1', 'tag2', 'tag3', 'tag4', 'tag5', 'tag6', 't…

In [201]:
# Build a synthetic layer-1 node table whose columns all have the same length,
# then run it through create_position_dictionary.
num_heads_layer_1 = 3  # L1H1, L1H2, L1H3
num_features_layer_1 = 23 * 3  # 23 feature rows listed after each head

# Each head is followed by feature IDs L1F1..L1F23. The same feature IDs are
# repeated for every head, so in the resulting dict (keyed by Node ID) later
# rows overwrite earlier ones.
node_ids = [
    node_id
    for head_number in range(1, num_heads_layer_1 + 1)
    for node_id in [f"L1H{head_number}", *(f"L1F{i}" for i in range(1, 24))]
]

# One "Head" followed by 23 "Feature" entries, once per head.
types = [
    node_type
    for _ in range(num_heads_layer_1)
    for node_type in ["Head"] + ["Feature"] * 23
]

# Every row belongs to layer 1.
layers = [1] * (num_heads_layer_1 + num_features_layer_1)

data_corrected = {"Node ID": node_ids, "Type": types, "Layer": layers}
nodes_df_corrected = pd.DataFrame(data_corrected)

# Compute the positions and leave them as the cell's displayed result.
positions_corrected = create_position_dictionary(nodes_df_corrected)
positions_corrected


{'L1H1': (2, -0.5),
 'L1F1': (3, -46),
 'L1F2': (3, -47),
 'L1F3': (3, -48),
 'L1F4': (3, -49),
 'L1F5': (3, -50),
 'L1F6': (3, -51),
 'L1F7': (3, -52),
 'L1F8': (3, -53),
 'L1F9': (3, -54),
 'L1F10': (3, -55),
 'L1F11': (3, -56),
 'L1F12': (3, -57),
 'L1F13': (3, -58),
 'L1F14': (3, -59),
 'L1F15': (3, -60),
 'L1F16': (3, -61),
 'L1F17': (3, -62),
 'L1F18': (3, -63),
 'L1F19': (3, -64),
 'L1F20': (3, -65),
 'L1F21': (3, -66),
 'L1F22': (3, -67),
 'L1F23': (3, -68),
 'L1H2': (2, -1.5),
 'L1H3': (2, -2.5)}

# 