In [None]:
!pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cpu
!pip install dgl -f https://data.dgl.ai/wheels/torch-2.1/repo.html
!pip install git+https://github.com/masaponto/dglex.git

In [None]:
import dgl
import torch
import numpy as np
from dglex.visualisation import plot_graph, plot_subgraph_with_neighbors

## For Homogeneous Graphs

Here we define a homogeneous graph using DGL as follows:

In [None]:
# A small directed homogeneous graph: edge i goes from src[i] to dst[i].
src = np.array([0, 1, 1, 3, 2, 5, 3, 4, 5, 1])
dst = np.array([1, 2, 3, 4, 4, 1, 2, 5, 0, 0])
assert len(src) == len(dst), "every edge needs both a source and a destination"
homo_graph = dgl.graph((src, dst))

# One weight per edge, in the same order as the (src, dst) edge list.
# torch.tensor is the recommended factory for building a tensor from data
# (the legacy torch.Tensor constructor is discouraged).
homo_edge_weight = torch.tensor(
    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], dtype=torch.float32
)
assert len(homo_edge_weight) == homo_graph.num_edges()
homo_graph.edata["weight"] = homo_edge_weight

A simple way to plot a homogeneous graph is to use the `plot_graph` function.

In [None]:
# Plot the homogeneous graph with the default settings.
plot_graph(
    homo_graph,
    title="homogeneous graph",
)

To plot a graph with edge weights, pass the name of the edge-data field that holds the weights via `edge_weight_name`:

In [None]:
# `edge_weight_name` selects the edge-data field ("weight", assigned when the
# graph was built) whose values are plotted as edge weights.
plot_graph(
    homo_graph, title="homogeneous graph with edge-weights", edge_weight_name="weight"
)

In addition, you can specify custom node labels. The node labels must be a dictionary with the node ID as the key and the node label as the value.

In [None]:
# Node IDs of a DGL graph are 0..N-1, so enumerating the names in ID order
# yields the {node_id: label} mapping that plot_graph expects.
# A distinct name is used because `node_labels` is later redefined as a
# nested per-node-type dict for the heterogeneous example.
homo_node_labels = dict(
    enumerate(["toh-chan", "ichiro", "jiro", "saburo", "shiro", "goro"])
)
# Guard against the label list drifting out of sync with the graph's nodes.
assert set(homo_node_labels) == set(range(homo_graph.num_nodes()))

plot_graph(
    homo_graph,
    title="homogeneous graph with edge-weights and custom defined node labels",
    edge_weight_name="weight",
    node_labels=homo_node_labels,
)

# For Heterogeneous Graphs
If a graph has multiple node types and edge types, it is called a heterogeneous graph.
The `plot_graph` function also supports heterogeneous graphs, and its usage is the same as for homogeneous graphs.

In [None]:
# Edge lists for each relation: edge i goes from *_src[i] to *_dst[i].
follow_src = np.array([6, 3, 7, 4, 6, 9, 2, 6, 7, 4])
follow_dst = np.array([3, 7, 7, 2, 5, 4, 1, 7, 5, 1])
click_src = np.array([4, 0, 9, 5, 8, 0, 9, 2, 6, 3])
click_dst = np.array([8, 2, 4, 2, 6, 4, 8, 6, 1, 3])
dislike_src = np.array([8, 1, 9, 8, 9, 4, 1, 3, 6, 7])
dislike_dst = np.array([2, 0, 3, 1, 7, 3, 1, 5, 5, 9])

# Each relation is added together with its reverse (source and destination
# swapped), keyed by the canonical (src_type, edge_type, dst_type) triple.
hetero_graph = dgl.heterograph(
    {
        ("user", "follow", "user"): (follow_src, follow_dst),
        ("user", "followed-by", "user"): (follow_dst, follow_src),
        ("user", "click", "item"): (click_src, click_dst),
        ("item", "clicked-by", "user"): (click_dst, click_src),
        ("user", "dislike", "item"): (dislike_src, dislike_dst),
        ("item", "disliked-by", "user"): (dislike_dst, dislike_src),
    }
)

# Maps each edge type to its reverse so that the plotting helpers can draw a
# relation and its reverse in the same color.
reverse_etypes = {
    "click": "clicked-by",
    "dislike": "disliked-by",
    "follow": "followed-by",
    "followed-by": "follow",
    "clicked-by": "click",
    "disliked-by": "dislike",
}
# The mapping must cover every edge type and be its own inverse.
assert set(reverse_etypes) == set(hetero_graph.etypes)
assert all(reverse_etypes[rev] == etype for etype, rev in reverse_etypes.items())

# The demo uses the same weights for every relation; build one independent
# tensor per canonical edge type instead of repeating the literal six times.
hetero_weight_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
hetero_edge_weight = {
    canonical_etype: torch.tensor(hetero_weight_values, dtype=torch.float32)
    for canonical_etype in hetero_graph.canonical_etypes
}
# Each relation needs exactly one weight per edge.
assert all(
    len(weights) == hetero_graph.num_edges(canonical_etype)
    for canonical_etype, weights in hetero_edge_weight.items()
)

hetero_graph.edata["weight"] = hetero_edge_weight

Here is an example of how to plot a heterogeneous graph. Different node types and edge types are represented by different colors.

In [None]:
# Plot the heterogeneous graph; node and edge types get distinct colors.
plot_graph(
    hetero_graph,
    figsize=(8, 6),
    title="heterogeneous graph",
)

If there are edges with reverse edge relationships, you can pass `reverse_etypes` as follows to plot each edge type and its reverse in the same color.

In [None]:
# Same heterogeneous plot, with each edge type sharing a color with its reverse.
plot_graph(
    hetero_graph,
    figsize=(8, 6),
    reverse_etypes=reverse_etypes,
    title="heterogeneous graph",
)

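Here is an example of plotting a heterogeneous graph with edge weights and node labels. For a heterogeneous graph, the node labels must be a dictionary keyed by node type, where each value maps node IDs of that type to labels.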

In [None]:
# Nested labels: {node_type: {node_id: label}}. The ID ranges are taken from
# the graph itself so the labels stay in sync if the edge lists change.
hetero_node_labels = {
    "item": {i: f"movie{i}" for i in range(hetero_graph.num_nodes("item"))},
    "user": {i: f"user{i}" for i in range(hetero_graph.num_nodes("user"))},
}

plot_graph(
    hetero_graph,
    title="heterogeneous graph with edge-weights and custom defined node labels",
    reverse_etypes=reverse_etypes,
    figsize=(8, 6),
    edge_weight_name="weight",
    node_labels=hetero_node_labels,
)

# For Subgraphs
If your graph is too large, you can plot a subgraph by specifying the node IDs you want to plot.  

When the target node IDs are specified, the subgraph contains only the target nodes and their neighbors within the given number of hops. Note that currently only subgraphs built from the incoming edges of the target nodes are supported.

In [None]:
# Plot the 2-hop neighborhood of node 0 in the homogeneous graph.
plot_subgraph_with_neighbors(
    homo_graph,
    n_hop=2,
    target_nodes=[0],
    edge_weight_name="weight",
)

The `fanouts` parameter is a list of integers that specifies the number of edges to include at each hop. The length of `fanouts` should be equal to `n_hop`.

In [None]:
# 2-hop neighborhood of node 2, limited to 2 edges per hop (one fanout per hop).
plot_subgraph_with_neighbors(
    homo_graph,
    n_hop=2,
    fanouts=[2, 2],
    target_nodes=[2],
    edge_weight_name="weight",
)

For heterogeneous graphs, `target_nodes` should be a dictionary with the node type as the key and a list of node IDs as the value.

In [None]:
# Heterogeneous case: target nodes are given per node type.
plot_subgraph_with_neighbors(
    hetero_graph,
    target_nodes={"user": [0]},
    n_hop=2,
    fanouts=[2, 2],
    reverse_etypes=reverse_etypes,
    edge_weight_name="weight",
)