[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gbatt55/ai-ontology-map/blob/notebooks/visualize_graph.ipynb)

# AI Supply Chain Ontology — Graph Visualizer

This notebook:
- Downloads `ontology.json` from the GitHub repo
- Builds a directed graph of the AI supply chain ontology
- Colors nodes by **layer**
- Uses node **shapes** for different node types
- Colors edges by **relation type** (dependency, leverage, competition, coupling)
- Scales edge thickness by **weight**
- Provides a legend for layers, node types, and edge types

You can edit `selected_layers` (a list of layer numbers, e.g. `[2]`; an empty list `[]` shows all layers) to control which layers are displayed.

In [None]:
!pip install networkx

import json, urllib.request, os, numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines

cmap = plt.cm.get_cmap('tab10')
print("Imports ready.")

In [None]:
# Download ontology.json from GitHub RAW.
# Raw file contents are served from raw.githubusercontent.com; a
# github.com/<owner>/<repo>/<branch>/<path> URL (no /raw/ or /blob/ segment)
# does not resolve to the file.
RAW_URL = "https://raw.githubusercontent.com/gbatt55/ai-ontology-map/main/ontology.json"
local_file = "ontology.json"

urllib.request.urlretrieve(RAW_URL, local_file)

# JSON text is UTF-8; be explicit so non-ASCII names decode identically on
# every platform regardless of the locale's default encoding.
with open(local_file, "r", encoding="utf-8") as f:
    data = json.load(f)

# Missing top-level keys are treated as empty lists.
nodes = data.get("nodes", [])
edges = data.get("edges", [])

print(f"Loaded {len(nodes)} nodes and {len(edges)} edges.")

In [None]:
# Layers to draw. A non-empty list keeps only nodes whose "layer" is listed,
# e.g. [2] shows only Layer 2 (Compute Fabric); an empty list [] shows every layer.
selected_layers = [2]

print("Visualizing layers:", selected_layers or "ALL")

In [None]:
# Build a directed NetworkX graph from the ontology. Every JSON field of a
# node / edge record is copied onto the graph as an attribute; nodes are keyed
# by their "id" and edges run from "source" to "target".
G = nx.DiGraph()
G.add_nodes_from((node["id"], node) for node in nodes)
G.add_edges_from((edge["source"], edge["target"], edge) for edge in edges)

print("Graph built:", G.number_of_nodes(), "nodes,", G.number_of_edges(), "edges.")

In [None]:
# ==============================================
# UPGRADED GRAPH VISUALIZER
# ==============================================

# Marker per node "type". Iteration order also sets the order in which node
# groups are drawn and listed in the legend.
NODE_SHAPES = dict([
    ("company", "o"),
    ("technology", "s"),
    ("platform", "D"),
    ("institution", "^"),
    ("concept", "v"),
    ("infrastructure", "P"),
    ("dataset", "h"),
    ("supply-chain", "8"),
])

# Line colour per edge "type" (relation kind); iteration order is legend order.
EDGE_COLORS = dict([
    ("dependency", "blue"),
    ("leverage", "purple"),
    ("competition", "red"),
    ("coupling", "orange"),
])

def edge_width(weight):
    """Map an edge weight to a line width: 1 at weight 0, plus 4 per unit of weight."""
    scaled = 4 * weight
    return scaled + 1

# === FILTER BY LAYER ===
# H is the graph that gets drawn. With an empty selection it is a full copy of
# G; otherwise it is the subgraph induced by nodes whose "layer" is selected
# (an edge survives only when both of its endpoints do).
if not selected_layers:
    H = G.copy()
else:
    kept_nodes = [
        node for node, attrs in G.nodes(data=True)
        if attrs.get("layer") in selected_layers
    ]
    H = G.subgraph(kept_nodes).copy()

if H.number_of_nodes() == 0:
    print("No nodes present for selected layers.")

# === LAYER-AWARE LAYOUT ===
# Start each node near (layer, -layer * y_scaling) with a little jitter, then
# let spring_layout relax from those starting points. No `fixed` argument is
# passed, so spring_layout may move every node: layers bias the layout rather
# than pin it.
layer_positions = {}
y_scaling = 1.5

# Every node gets an initial position, so spring_layout's seed=42 does not
# control these starting points. The jitter must be seeded itself, or the
# picture changes on every run.
layout_rng = np.random.default_rng(42)

for n, d in H.nodes(data=True):
    layer = d.get("layer")
    if layer is None:
        layer = 4  # fallback for nodes with no (or a null) layer
    layer_positions[n] = (
        layer + 0.2 * layout_rng.standard_normal(),
        -layer * y_scaling + 0.1 * layout_rng.standard_normal()
    )

pos = nx.spring_layout(H, pos=layer_positions, seed=42, k=1.0)

# === DRAW ===
plt.figure(figsize=(15, 11))


def _layer_color(node, default_layer=4):
    """Return the colormap colour for a node's layer.

    Nodes with no "layer" (or a null one) use ``default_layer`` (4, the same
    fallback the layout uses) instead of raising KeyError/TypeError.
    """
    layer = H.nodes[node].get("layer")
    if layer is None:
        layer = default_layer
    return cmap((layer - 1) % 10)


# NODES (by type)
# Each known type is drawn with its own marker, in NODE_SHAPES order. Nodes
# whose "type" is missing or not in NODE_SHAPES are drawn last as circles.
# Otherwise they would be silently omitted while their edges and labels are
# still drawn.
node_groups = [
    (shape, [n for n, d in H.nodes(data=True) if d.get("type") == node_type])
    for node_type, shape in NODE_SHAPES.items()
]
node_groups.append(
    ("o", [n for n, d in H.nodes(data=True) if d.get("type") not in NODE_SHAPES])
)

for shape, group in node_groups:
    if group:
        nx.draw_networkx_nodes(
            H, pos,
            nodelist=group,
            node_color=[_layer_color(n) for n in group],
            node_shape=shape,
            node_size=1100,
            alpha=0.92
        )

# EDGES (by type + weight)
# All edges go in a single draw call with per-edge widths and colours.
# Colour comes from the edge's "type" (default "dependency", unknown types ->
# gray) and width from edge_width("weight", default 0.4).
# node_size must equal the node_size used for the node markers (1100). Arrows
# are shortened by the node radius, and with the default (300) the arrowheads
# end up hidden underneath the larger markers.
edge_list = list(H.edges(data=True))

if edge_list:
    nx.draw_networkx_edges(
        H, pos,
        edgelist=[(u, v) for u, v, _ in edge_list],
        arrowstyle="-|>",
        arrowsize=16,
        width=[edge_width(d.get("weight", 0.4)) for _, _, d in edge_list],
        edge_color=[
            EDGE_COLORS.get(d.get("type", "dependency"), "gray")
            for _, _, d in edge_list
        ],
        alpha=0.85,
        node_size=1100
    )

# LABELS
# Show the human-readable "name". Fall back to the node id when it is missing
# or empty, e.g. for nodes NetworkX created implicitly as edge endpoints, which
# carry no attributes. Indexing ["name"] directly would raise KeyError for them.
nx.draw_networkx_labels(
    H, pos,
    labels={n: d.get("name") or n for n, d in H.nodes(data=True)},
    font_size=9
)

# TITLE
plt.title(
    "AI Supply Chain Ontology — Layers " + (
        str(selected_layers) if selected_layers else "ALL"
    ),
    fontsize=14
)
plt.axis("off")

# === LEGEND ===
# One swatch per layer actually present in H, coloured with the same
# (layer - 1) % 10 mapping used for the nodes. Deriving the list from the data
# keeps the legend in sync with `selected_layers` and covers layers beyond 7.
present_layers = sorted(
    {d["layer"] for _, d in H.nodes(data=True) if d.get("layer") is not None}
)
layer_patches = [
    mpatches.Patch(color=cmap((layer - 1) % 10), label=f"Layer {layer}")
    for layer in present_layers
]

# Marker-only handles (no line) for node types.
node_handles = [
    mlines.Line2D([], [], color="black", marker=shape, linestyle="None", markersize=10, label=typ)
    for typ, shape in NODE_SHAPES.items()
]

# Line handles for edge relation types.
edge_handles = [
    mlines.Line2D([], [], color=color, linewidth=3, label=etype)
    for etype, color in EDGE_COLORS.items()
]

# Anchored just outside the axes' upper-right corner so it doesn't cover the graph.
plt.legend(
    handles=layer_patches + node_handles + edge_handles,
    loc="upper left",
    bbox_to_anchor=(1.05, 1.0),
    fontsize=9
)

plt.tight_layout()
plt.show()

# Distinct layers that were drawn. Nodes without a layer are skipped so the
# sort never has to compare None with an int (TypeError).
print("Layers present:", sorted({d["layer"] for _, d in H.nodes(data=True) if d.get("layer") is not None}))