In [1]:
from collections import defaultdict

from plotly.subplots import make_subplots
import networkx as nx
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import re

```sparql
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>

SELECT ?sc ?p ?oc (COUNT(?s) AS ?count)
WHERE {
  ?s ?p ?o .
  OPTIONAL { ?s <https://w3id.org/biolink/vocab/category> ?sc . }
  OPTIONAL { ?o <https://w3id.org/biolink/vocab/category> ?oc . }
  FILTER(CONTAINS(STR(?p),"https://w3id.org/biolink/vocab/"))
}
GROUP BY ?sc ?p ?oc
order by ?sc ?p ?oc

# EXCLUDES
#http://purl.org/dc/terms/description
#https://www.example.org/UNKNOWN/meta
#https://www.example.org/UNKNOWN/primary_knowledge_source.1
#https://www.example.org/UNKNOWN/subsets
#rdf:object
#rdf:predicate
#rdf:subject
#rdf:type
#rdfs:label
```

In [2]:
# Relative path to the relationship-count table (CSV) loaded below.
# NOTE(review): "bioloink" in the file name looks like a typo for "biolink" —
# confirm against the file on disk before renaming anything.
relations_data = "../assets/kg-microbe-types-bioloink-relations.csv"

In [3]:
# Read the CSV file into a DataFrame (columns used below: sc, p, oc, count)
df = pd.read_csv(relations_data)

In [4]:
# Sanity check: report the (rows, columns) shape of the loaded table.
print("Data shape:", df.shape)

Data shape: (203, 4)


In [5]:
# Clean the count column (remove RDF literal formatting)
def extract_count(count_str):
    """Extract an integer count from one cell of the ``count`` column.

    Two spellings are understood:

    * an RDF typed literal such as '"123"^^xsd:integer' (digits wrapped in
      double quotes), and
    * a bare whole number such as 123, '123' or 123.0, which is what the cell
      holds when the export did not wrap the value in a literal.

    Parameters
    ----------
    count_str : Any
        Raw cell value; may be a string, a number, or a missing value.

    Returns
    -------
    int
        The parsed count, or 0 when the value is missing or not a
        non-negative whole number.
    """
    if pd.isna(count_str):
        return 0

    text = str(count_str).strip()

    # Typed literal: the digits sit between double quotes.
    quoted = re.search(r'"(\d+)"', text)
    if quoted:
        return int(quoted.group(1))

    # Bare number: without this branch a plain numeric cell would silently
    # become 0. A trailing ".0" is tolerated for values stored as floats.
    bare = re.fullmatch(r'(\d+)(?:\.0+)?', text)
    return int(bare.group(1)) if bare else 0

In [6]:
# Parse every raw count cell into a plain integer column.
df['count_clean'] = df['count'].map(extract_count)

In [7]:
# Keep only complete sc-p-oc patterns: both the subject class and the object
# class must be present. The copy detaches the result from df.
network_df = df[df[['sc', 'oc']].notna().all(axis=1)].copy()

In [8]:
# Report how many rows have both sc and oc present.
print(f"\nRows with complete sc-p-oc patterns: {len(network_df)}")


Rows with complete sc-p-oc patterns: 139


In [9]:
# Short label of an IRI = everything after its last '/'
def get_label(iri):
    """Return the final '/'-separated segment of an IRI.

    Missing values (NaN/None) yield an empty string; a value without any
    slash is returned unchanged (as a string).
    """
    if not pd.isna(iri):
        return str(iri).rpartition('/')[2]
    return ''

In [10]:
# Derive a short label column for the subject class, object class and predicate.
network_df['sc_label'] = network_df['sc'].map(get_label)
network_df['oc_label'] = network_df['oc'].map(get_label)
network_df['p_label'] = network_df['p'].map(get_label)

In [11]:
# Summarise the spread of the parsed counts, one statistic per line.
clean_counts = network_df['count_clean']
print(
    "\nCount statistics:",
    f"Min: {clean_counts.min()}",
    f"Max: {clean_counts.max()}",
    f"Median: {clean_counts.median()}",
    sep="\n",
)


Count statistics:
Min: 1
Max: 1006639
Median: 435.0


In [12]:
# Keep only rows where neither the subject class nor the object class is NamedThing
network_df = network_df[
    (network_df['sc_label'] != 'NamedThing') & (network_df['oc_label'] != 'NamedThing')
]


In [13]:
# Drop the subclass_of relationships; every other predicate is kept.
network_df = network_df.loc[network_df['p_label'].ne('subclass_of')]


In [14]:
# Display the filtered relationship table as the cell's rich output.
network_df

Unnamed: 0,sc,p,oc,count,count_clean,sc_label,oc_label,p_label
6,https://w3id.org/biolink/vocab/AnatomicalEntity,https://w3id.org/biolink/vocab/location_of,https://w3id.org/biolink/vocab/OrganismTaxon,"""1763""^^xsd:integer",1763,AnatomicalEntity,OrganismTaxon,location_of
8,https://w3id.org/biolink/vocab/AnatomicalEntity,https://w3id.org/biolink/vocab/related_to,https://w3id.org/biolink/vocab/AnatomicalEntity,"""631""^^xsd:integer",631,AnatomicalEntity,AnatomicalEntity,related_to
9,https://w3id.org/biolink/vocab/AnatomicalEntity,https://w3id.org/biolink/vocab/related_to,https://w3id.org/biolink/vocab/EnvironmentalFe...,"""8""^^xsd:integer",8,AnatomicalEntity,EnvironmentalFeature,related_to
16,https://w3id.org/biolink/vocab/BiologicalProcess,https://w3id.org/biolink/vocab/capable_of,https://w3id.org/biolink/vocab/BiologicalProcess,"""2""^^xsd:integer",2,BiologicalProcess,BiologicalProcess,capable_of
19,https://w3id.org/biolink/vocab/BiologicalProcess,https://w3id.org/biolink/vocab/enabled_by,https://w3id.org/biolink/vocab/BiologicalProcess,"""435""^^xsd:integer",435,BiologicalProcess,BiologicalProcess,enabled_by
...,...,...,...,...,...,...,...,...
194,https://w3id.org/biolink/vocab/PhenotypicQuality,https://w3id.org/biolink/vocab/consumes,https://w3id.org/biolink/vocab/ChemicalEntity,"""44""^^xsd:integer",44,PhenotypicQuality,ChemicalEntity,consumes
195,https://w3id.org/biolink/vocab/PhenotypicQuality,https://w3id.org/biolink/vocab/consumes,https://w3id.org/biolink/vocab/ChemicalSubstance,"""14""^^xsd:integer",14,PhenotypicQuality,ChemicalSubstance,consumes
197,https://w3id.org/biolink/vocab/PhenotypicQuality,https://w3id.org/biolink/vocab/is_assessed_by,https://w3id.org/biolink/vocab/PhenotypicQuality,"""98""^^xsd:integer",98,PhenotypicQuality,PhenotypicQuality,is_assessed_by
199,https://w3id.org/biolink/vocab/PhenotypicQuality,https://w3id.org/biolink/vocab/related_to,https://w3id.org/biolink/vocab/BiologicalProcess,"""65""^^xsd:integer",65,PhenotypicQuality,BiologicalProcess,related_to


In [16]:
# Write the filtered table as tab-separated text to "network_df.tsv" (relative
# path), leaving out the index column.
network_df.to_csv("network_df.tsv", sep='\t', index=False)

In [None]:
# Create nodes and edges for the network
# Collect every class that appears as a subject or as an object, under common
# column names (iri, label), and keep each (iri, label) pair once.
class_frames = [
    network_df[[iri_col, label_col]].rename(columns={iri_col: 'iri', label_col: 'label'})
    for iri_col, label_col in (('sc', 'sc_label'), ('oc', 'oc_label'))
]
all_classes = pd.concat(class_frames).drop_duplicates()

In [None]:
# Create node mapping: ordered node IRIs, IRI -> position, IRI -> short label
node_list = all_classes['iri'].unique().tolist()
node_indices = dict(zip(node_list, range(len(node_list))))
node_labels = dict(zip(all_classes['iri'], all_classes['label']))

In [None]:
# Report the number of class nodes and relationship rows.
print(
    "\nNetwork size:",
    f"Nodes: {len(node_list)}",
    f"Edges: {len(network_df)}",
    sep="\n",
)

In [None]:
# Create NetworkX graph for better layout.
# nx.Graph is undirected and holds at most one edge per node pair.
G = nx.Graph()

In [None]:
# Add nodes: one per class IRI, carrying its short label as an attribute
G.add_nodes_from((node, {'label': node_labels[node]}) for node in node_list)


In [None]:
# Add edges (for layout purposes, we'll use undirected but visualize as directed).
# G holds one edge per node pair, so when several rows connect the same two
# classes the attributes of the last such row are the ones that remain.
G.add_edges_from(
    (row['sc'], row['oc'], {'predicate': row['p_label'], 'weight': row['count_clean']})
    for _, row in network_df.iterrows()
)

In [None]:
# Get node positions using spring layout.
# The fixed seed keeps the layout reproducible from run to run.
pos = nx.spring_layout(G, k=3, iterations=50, seed=42)

In [None]:
# Prepare data for Plotly: per-node x, y and display text, all in node_list order
node_xy = [pos[node] for node in node_list]
node_x = [xy[0] for xy in node_xy]
node_y = [xy[1] for xy in node_xy]
node_text = list(map(node_labels.__getitem__, node_list))

In [None]:
# Calculate edge traces: start from empty accumulators
edge_x, edge_y, edge_info = [], [], []

In [None]:

# Smallest and largest count; these bound the (log-scaled) edge-width mapping,
# which is logarithmic because the counts span a wide range.
edge_counts = network_df['count_clean']
min_count, max_count = edge_counts.min(), edge_counts.max()

In [None]:
def scale_edge_width(count, min_width=0.5, max_width=10, count_min=None, count_max=None):
    """Map a relationship count to a line width on a logarithmic scale.

    Parameters
    ----------
    count : number
        Relationship count for one edge.
    min_width, max_width : float
        Widths assigned to the smallest and the largest count.
    count_min, count_max : number, optional
        Bounds of the count range. When omitted, the module-level
        ``min_count`` / ``max_count`` are used.

    Returns
    -------
    float
        ``min_width`` for non-positive counts; otherwise a width between
        ``min_width`` and ``max_width``. When the count range is degenerate
        (all counts equal) the midpoint of the two widths is returned.
    """
    if count <= 0:
        return min_width

    low = min_count if count_min is None else count_min
    high = max_count if count_max is None else count_max

    # log10 is undefined for a non-positive bound (it would turn every width
    # into NaN). The counts in this notebook are whole numbers, so 1 is the
    # smallest positive value a bound can sensibly take.
    if low <= 0:
        low = 1

    # A single distinct count gives a zero-width log range; avoid the 0/0.
    if high <= low:
        return (min_width + max_width) / 2

    log_low = np.log10(low)
    log_high = np.log10(high)

    # Normalize to 0-1 range, clamping counts that fall outside the bounds.
    normalized = (np.log10(count) - log_low) / (log_high - log_low)
    normalized = min(max(normalized, 0.0), 1.0)
    return min_width + normalized * (max_width - min_width)

In [None]:

# Build the Plotly pieces for every relationship row: a line trace, an arrow
# annotation (when the edge has non-zero length) and a predicate text label.
edge_traces, edge_annotations, edge_labels = [], [], []

for _, row in network_df.iterrows():
    (x0, y0), (x1, y1) = pos[row['sc']], pos[row['oc']]
    dx, dy = x1 - x0, y1 - y0
    width = scale_edge_width(row['count_clean'])

    # The line trace and the label trace show the same hover text.
    hover_html = f"<b>{row['sc_label']}</b> --{row['p_label']}--> <b>{row['oc_label']}</b><br>Count: <b>{row['count_clean']:,}</b>"

    # Straight segment from the subject class to the object class.
    edge_traces.append(
        go.Scatter(
            x=[x0, x1, None],
            y=[y0, y1, None],
            mode='lines',
            line={'width': width, 'color': 'rgba(100, 150, 200, 0.6)'},
            hoverinfo='text',
            hovertext=hover_html,
            showlegend=False,
        )
    )

    # Arrow head pointing at the target: tail at 80% and tip at 90% of the
    # way along the edge. Zero-length edges get no arrow.
    if np.sqrt(dx * dx + dy * dy) > 0:
        edge_annotations.append({
            'x': x0 + 0.9 * dx,
            'y': y0 + 0.9 * dy,
            'ax': x0 + 0.8 * dx,
            'ay': y0 + 0.8 * dy,
            'xref': 'x',
            'yref': 'y',
            'axref': 'x',
            'ayref': 'y',
            'arrowhead': 2,
            'arrowsize': 1.5,
            'arrowwidth': width / 2,
            'arrowcolor': 'rgba(80, 120, 180, 0.8)',
            'showarrow': True,
            'hovertext': f"{row['sc_label']} → {row['oc_label']}<br>{row['p_label']}: {row['count_clean']:,}",
        })

    # Predicate name at the midpoint of the edge, for ALL relationships.
    edge_labels.append(
        go.Scatter(
            x=[(x0 + x1) / 2],
            y=[(y0 + y1) / 2],
            mode='text',
            text=[row['p_label']],
            textfont={'size': 9, 'color': 'darkblue', 'family': 'Arial Black'},
            textposition="middle center",
            hoverinfo='text',
            hovertext=hover_html,
            showlegend=False,
        )
    )

# Create the node trace: one marker with a centred text label per class
node_marker = {
    'size': 15,
    'color': 'lightcoral',
    'line': {'width': 2, 'color': 'darkred'},
    'symbol': 'circle',
}
node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode='markers+text',
    marker=node_marker,
    text=node_text,
    textposition="middle center",
    textfont={'size': 11, 'color': 'black', 'family': 'Arial Black'},
    name='Classes',
)

# Create the figure. Trace order is the draw order: edge lines first, then the
# predicate labels, then the class nodes on top.
fig = go.Figure(data=[*edge_traces, *edge_labels, node_trace])

# Footer hint shown under the plot. Scatter points in a Plotly figure cannot be
# dragged, so the hint only advertises interactions that actually work
# (drag-to-zoom, double-click reset, hover) instead of promising draggable nodes.
usage_hint = dict(
    text="💡 <b>Drag to zoom, double-click to reset</b> • Hover for details • Thicker edges = more frequent relationships",
    showarrow=False,
    xref="paper", yref="paper",
    x=0.5, y=-0.05,
    xanchor='center', yanchor='bottom',
    font=dict(color='gray', size=12)
)

fig.update_layout(
    title=dict(
        text="<b>RDF Class Relationship Network</b><br><sub>Edge width ∝ relationship frequency (log scale) • Arrows show direction </sub>",
        x=0.5,
        font=dict(size=18, color='darkslategray')
    ),
    showlegend=False,
    hovermode='closest',
    margin=dict(b=40,l=40,r=40,t=100),
    height=800,
    width=1200,
    # Direction arrows for every edge, followed by the footer hint.
    annotations=edge_annotations + [usage_hint],
    # Layout coordinates carry no meaning, so hide grid, zero line and ticks.
    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    plot_bgcolor='white',
    paper_bgcolor='#fafafa'
)

# Show the plot
fig.show()
