# **Understanding DGL: A Deep Dive**
## **Introduction**
In this notebook, I am diving deep into how the **Deep Graph Library (DGL)** works, particularly focusing on the `update_all` function. This function is a key operation in DGL that enables efficient **message passing** and **feature aggregation** in graph neural networks (GNNs).

It also covers `local_scope()`, a context manager that keeps feature changes made inside a block from persisting on the graph.
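To make the idea concrete before the DGL examples, here is a minimal sketch in plain Python (no DGL) of what one round of message passing with aggregation amounts to. The function name `message_passing_sketch` and its arguments are made up for illustration; the sketch models the semantics only, not how DGL implements `update_all` internally.

```python
from collections import defaultdict


def message_passing_sketch(src, dst, h, weight, reduce_fn, num_nodes):
    """Toy model of one message-passing round (hypothetical helper, not a DGL API)."""
    # 1. Message phase: one message per edge, built from the source node
    #    feature and the edge weight.
    messages = [h[u] * w for u, w in zip(src, weight)]

    # 2. Each destination node collects the messages of its incoming edges.
    mailbox = defaultdict(list)
    for v, m in zip(dst, messages):
        mailbox[v].append(m)

    # 3. Reduce phase: aggregate each mailbox; nodes without incoming
    #    edges get 0.0 in this sketch.
    return [reduce_fn(mailbox[v]) if v in mailbox else 0.0 for v in range(num_nodes)]


def mean(values):
    return sum(values) / len(values)


# Same toy graph as the examples below: edges 0->1, 0->1, 2->1, 3->2, 3->2.
src = [0, 0, 2, 3, 3]
dst = [1, 1, 1, 2, 2]
h = [1.0, 0.0, 2.0, 3.0]
weight = [0.5, 1.5, 1.0, 1.5, 2.0]

h_neigh = message_passing_sketch(src, dst, h, weight, mean, num_nodes=4)
print(h_neigh)
```

Passing `sum` or `max` as `reduce_fn` gives the other aggregations discussed in Example 1.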


## `update_all` Function

### Example 1: Simple graph with float data

In [None]:
import dgl
import torch
import dgl.function as fn

# Step 1: Create a graph (with multi-edges)
g = dgl.graph(([0, 0, 2, 3, 3], [1, 1, 1, 2, 2]))  # Multi-edges exist (0→1 and 3→2)

# Step 2: Assign node features
g.ndata['h'] = torch.tensor([1.0, 0.0, 2.0, 3.0])  # Node feature
g.ndata['nid'] = torch.tensor([0, 1, 2, 3])  # Node IDs

# Step 3: Assign edge features
g.edata['weight'] = torch.tensor([0.5, 1.5, 1.0, 1.5, 2.0])  # Edge weights
g.edata['id'] = torch.tensor([0, 1, 2, 3, 4])  # Edge IDs



In [8]:
# Step 4: Define a message function (using edge weight)
def message_func(edges):
    msg = edges.src['h'] * edges.data['weight']  # Compute message
    g.edata['m'] = msg  # Side effect, for inspection only: `update_all` does not keep messages in edata
    return {'m': msg}  # Also return for update_all

# Step 5: Apply `update_all` with mean aggregation
g.update_all(message_func, fn.mean('m', 'h_neigh'))

# Print results
print("Messages on edges (m):")
print(g.edata['m'])  # Edge messages before aggregation

print("\nUpdated node features after aggregation (h_neigh):")
print(g.ndata['h_neigh'])  # Aggregated values at nodes


Messages on edges (m):
tensor([0.5000, 1.5000, 2.0000, 4.5000, 6.0000])

Updated node features after aggregation (h_neigh):
tensor([0.0000, 1.3333, 5.2500, 0.0000])


Here is the explanation of the results:

| Edge | Source → Destination | Source h | Weight | Message (m = h * weight) |
|------|----------------------|----------|--------|--------------------------|
| 0    | 0 → 1                | 1.0      | 0.5    | **0.5**                  |
| 1    | 0 → 1                | 1.0      | 1.5    | **1.5**                  |
| 2    | 2 → 1                | 2.0      | 1.0    | **2.0**                  |
| 3    | 3 → 2                | 3.0      | 1.5    | **4.5**                  |
| 4    | 3 → 2                | 3.0      | 2.0    | **6.0**                  |

Aggregation Results:

| Node | Incoming Messages  | Mean Aggregation (h_neigh) |
|------|------------------|--------------------------|
| 0    | None            | **0.0000**                |
| 1    | {0.5, 1.5, 2.0} | **1.3333**                |
| 2    | {4.5, 6.0}      | **5.2500**                |
| 3    | None            | **0.0000**                |


If we use `fn.sum('m', 'h_neigh')` instead:

| Node | Sum Aggregation (h_neigh) |
|------|--------------------------|
| 0    | **0.0000**                |
| 1    | **4.0000**                |
| 2    | **10.5000**               |
| 3    | **0.0000**                |

If we use `fn.max('m', 'h_neigh')` instead:

| Node | Max Aggregation (h_neigh) |
|------|--------------------------|
| 0    | **0.0000**                |
| 1    | **2.0000**                |
| 2    | **6.0000**                |
| 3    | **0.0000**                |



### Example 2: Custom reduce function

In [38]:
# Step 1: Create a graph (with multi-edges)
g = dgl.graph(([0, 0, 2, 3, 3], [1, 1, 1, 2, 2]))  # Multi-edges exist (0→1 and 3→2)

# Step 2: Assign node features
g.ndata['h'] = torch.tensor([1.0, 0.0, 2.0, 3.0])  # Node feature
g.ndata['nid'] = torch.tensor([0, 1, 2, 3])  # Node IDs

# Step 3: Assign edge features
g.edata['weight'] = torch.tensor([0.5, 1.5, 1.0, 1.5, 2.0])  # Edge weights
g.edata['id'] = torch.tensor([0, 1, 2, 3, 4])  # Edge IDs

In [49]:
# Step 4: Define a message function that forwards the source node feature
def message_func(edges):
    msg = edges.src['h']  # Compute message
    g.edata['m'] = msg  # Manually store the message
    print(f"==>> msg: {msg}")
    return {'m': msg}  # Also return for update_all

# Step 5: Define a reduce function that computes the product of all messages
def reduce_func(nodes):
    print(f"==>> nodes: {nodes.mailbox['m']}")

    # mailbox['m'] has shape (num_nodes_in_batch, in_degree); reduce over dim=1
    mul = torch.prod(nodes.mailbox['m'], dim=1, keepdim=True)  # Shape: (num_nodes_in_batch, 1)

    print(f"==>> mul shape: {mul.shape}")  # Debugging
    return {'m_prod': mul}

# Step 6: Apply `update_all`
g.update_all(message_func, reduce_func)

# Print results
print("Messages on edges (m):")
print(g.edata['m'])  # Edge messages before aggregation

print("\nUpdated node features after aggregation (m_prod):")
print(g.ndata['m_prod'])  # Aggregated values at nodes

==>> msg: tensor([1., 1., 2., 3., 3.])
==>> nodes: tensor([[3., 3.]])
==>> mul shape: torch.Size([1, 1])
==>> nodes: tensor([[1., 1., 2.]])
==>> mul shape: torch.Size([1, 1])
Messages on edges (m):
tensor([1., 1., 2., 3., 3.])

Updated node features after aggregation (m_prod):
tensor([[0.],
        [2.],
        [9.],
        [0.]])
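The debug output shows `reduce_func` being called twice, each time with a mailbox of a different width. This is consistent with DGL batching nodes by in-degree when the reducer is a user-defined function, so that every mailbox tensor passed to a call is rectangular. A rough plain-Python sketch of that grouping follows; the helper name `reduce_by_degree` is hypothetical, and the order in which buckets are processed here is not meant to match DGL's.

```python
from collections import defaultdict
from math import prod


def reduce_by_degree(dst, messages, num_nodes):
    """Toy degree bucketing (hypothetical helper, not DGL's implementation)."""
    # Collect incoming messages per destination node.
    mailbox = defaultdict(list)
    for v, m in zip(dst, messages):
        mailbox[v].append(m)

    # Group nodes that share the same in-degree into one bucket.
    buckets = defaultdict(list)
    for v, msgs in mailbox.items():
        buckets[len(msgs)].append(v)

    # Nodes with no incoming edges never reach the reducer and stay at 0.0.
    out = [0.0] * num_nodes
    for degree, nodes in sorted(buckets.items()):
        # Each bucket is a rectangular (len(nodes), degree) batch.
        batch = [mailbox[v] for v in nodes]
        print(f"degree {degree}: nodes {nodes}, mailbox {batch}")
        for v, row in zip(nodes, batch):
            out[v] = prod(row)  # product of all messages, as in reduce_func
    return out


dst = [1, 1, 1, 2, 2]
messages = [1.0, 1.0, 2.0, 3.0, 3.0]  # edges.src['h'] for each edge
m_prod = reduce_by_degree(dst, messages, num_nodes=4)
print(m_prod)  # [0.0, 2.0, 9.0, 0.0]
```

Nodes 0 and 3 have no incoming edges, so the reducer is never invoked for them, which matches the zeros in the `m_prod` output above.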


### Example 3: Simple graph with 2D data

In [13]:
# Step 1: Create a simple directed graph (with multi-edges)
g = dgl.graph(([0, 0, 2, 3, 3], [1, 1, 1, 2, 2]))  # Multi-edges exist (0→1 and 3→2)

# Step 2: Assign 2D node features
g.ndata['h'] = torch.tensor([
    [1.0, 0.5],  # Node 0
    [0.0, 0.5],  # Node 1
    [2.0, 0.5],  # Node 2
    [3.0, 0.5]   # Node 3
])  # Shape: (num_nodes, 2)

g.ndata['nid'] = torch.tensor([0, 1, 2, 3])  # Node IDs

# Step 3: Assign 2D edge features
g.edata['weight'] = torch.tensor([
    [0.5, 0.5],  # Edge 0→1
    [1.5, 0.5],  # Edge 0→1 (multi-edge)
    [2.0, 0.0],  # Edge 2→1
    [1.5, 2.0],  # Edge 3→2
    [2.0, 1.0]   # Edge 3→2 (multi-edge)
])  # Shape: (num_edges, 2)

g.edata['id'] = torch.tensor([0, 1, 2, 3, 4])  # Edge IDs

In [17]:
# Step 4: Define a message function that reduces each edge's features to one value
def message_func(edges):
    # (a) Concatenate source node features and edge features
    msg_temp = torch.cat([edges.src['h'], edges.data['weight']], dim=1)  # Shape: (num_edges, 4)
    g.edata['msg_temp'] = msg_temp  # Stored for inspection only
    # (b) Sum across the feature dimension to reduce to (num_edges, 1)
    msg_final = msg_temp.sum(dim=1, keepdim=True)  # Shape: (num_edges, 1)
    g.edata['msg_final'] = msg_final  # Stored for inspection only

    return {'m': msg_final}  # Messages handed to the reducer

# Step 5: Apply `update_all` with mean aggregation
g.update_all(message_func, fn.mean('m', 'h_neigh'))

# Print results
print("Messages on edges (msg_temp):")
print(g.edata['msg_temp'])

# Print results
print("Messages on edges (msg_final):")
print(g.edata['msg_final'])

print("\nUpdated node features after aggregation (h_neigh):")
print(g.ndata['h_neigh'])  # Aggregated values at nodes

Messages on edges (msg_temp):
tensor([[1.0000, 0.5000, 0.5000, 0.5000],
        [1.0000, 0.5000, 1.5000, 0.5000],
        [2.0000, 0.5000, 2.0000, 0.0000],
        [3.0000, 0.5000, 1.5000, 2.0000],
        [3.0000, 0.5000, 2.0000, 1.0000]])
Messages on edges (msg_final):
tensor([[2.5000],
        [3.5000],
        [4.5000],
        [7.0000],
        [6.5000]])

Updated node features after aggregation (h_neigh):
tensor([[0.0000],
        [3.5000],
        [6.7500],
        [0.0000]])


Message Computation:

Each edge computes a **single value** message by:

1. **Concatenating features** `[h1, h2, w1, w2]`
2. **Summing all values** across `dim=1` to output **one scalar per edge**.

|Edge|Concatenated Values|Summed (`m`)|
|---|---|---|
|0 → 1|`[1.0, 0.5, 0.5, 0.5]`|**2.5**|
|0 → 1|`[1.0, 0.5, 1.5, 0.5]`|**3.5**|
|2 → 1|`[2.0, 0.5, 2.0, 0.0]`|**4.5**|
|3 → 2|`[3.0, 0.5, 1.5, 2.0]`|**7.0**|
|3 → 2|`[3.0, 0.5, 2.0, 1.0]`|**6.5**|


Aggregation (Mean):

| Node | Incoming Messages | Mean Aggregation (`h_neigh`)       |
| ---- | ----------------- | ---------------------------------- |
| 0    | None              | **0.0000** (no incoming messages)  |
| 1    | {2.5, 3.5, 4.5}   | **(2.5 + 3.5 + 4.5) / 3 = 3.5000** |
| 2    | {7.0, 6.5}        | **(7.0 + 6.5) / 2 = 6.7500**       |
| 3    | None              | **0.0000** (no incoming messages)  |
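The per-edge sums and per-node means can be re-derived by hand with a few lines of plain Python (no DGL), which is a quick way to cross-check the printed tensors. Variable names are chosen here for illustration.

```python
# Source-node features and edge weights, one row per edge, as in Example 3.
src_h = [[1.0, 0.5], [1.0, 0.5], [2.0, 0.5], [3.0, 0.5], [3.0, 0.5]]
weight = [[0.5, 0.5], [1.5, 0.5], [2.0, 0.0], [1.5, 2.0], [2.0, 1.0]]
dst = [1, 1, 1, 2, 2]

# Concatenate [h1, h2, w1, w2] per edge, then sum to one scalar per edge.
msg_final = [sum(h + w) for h, w in zip(src_h, weight)]
print(msg_final)  # [2.5, 3.5, 4.5, 7.0, 6.5]

# Mean of incoming messages per node; nodes with no incoming edge get 0.0.
h_neigh = []
for node in range(4):
    incoming = [m for m, v in zip(msg_final, dst) if v == node]
    h_neigh.append(sum(incoming) / len(incoming) if incoming else 0.0)
print(h_neigh)  # [0.0, 3.5, 6.75, 0.0]
```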


## `local_scope()` Function

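`local_scope()` is a context manager on a DGL graph that opens a temporary scope: node features (`ndata`) or edge features (`edata`) that are assigned or reassigned inside the `with` block do not persist once the block exits. This applies to out-of-place changes only; in-place operations on an existing feature tensor are still reflected in the original graph.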

In [18]:
# Step 1: Create a simple graph
g = dgl.graph(([0, 1, 2], [1, 2, 3]))  # A small directed graph

# Step 2: Assign node features
g.ndata['h'] = torch.tensor([1.0, 2.0, 3.0, 4.0])  # Node features

print("Before local_scope - Original node features:")
print(g.ndata['h'])  # Print original node features

# Step 3: Use local_scope() to temporarily modify the graph
with g.local_scope():
    # Apply a transformation inside the local scope
    g.ndata['h'] = g.ndata['h'] * 2  # Multiply node features by 2
    print("\nInside local_scope - Modified node features:")
    print(g.ndata['h'])  # ✅ Temporary change only inside this block

# Step 4: Print the node features again (should remain unchanged)
print("\nAfter local_scope - Node features remain unchanged:")
print(g.ndata['h'])  # ✅ Original values restored

Before local_scope - Original node features:
tensor([1., 2., 3., 4.])

Inside local_scope - Modified node features:
tensor([2., 4., 6., 8.])

After local_scope - Node features remain unchanged:
tensor([1., 2., 3., 4.])
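One way to picture this behaviour: on entry the graph swaps in a shallow copy of its feature storage, and on exit the original storage is put back. The toy class below is plain Python, not DGL's implementation, and `ToyGraph` and its fields are invented for illustration. It also shows a caveat noted in DGL's documentation: rebinding a feature inside the scope is discarded, while mutating the stored object in place is still visible afterwards.

```python
from contextlib import contextmanager


class ToyGraph:
    """Toy stand-in for a graph with node features (hypothetical, not a DGL class)."""

    def __init__(self):
        self.ndata = {}

    @contextmanager
    def local_scope(self):
        saved = self.ndata
        self.ndata = dict(saved)  # shallow copy: same feature objects, new dict
        try:
            yield
        finally:
            self.ndata = saved  # drop whatever was (re)assigned inside the scope


g = ToyGraph()
g.ndata['h'] = [1.0, 2.0, 3.0, 4.0]

with g.local_scope():
    # Out-of-place: binds a new list to 'h' in the scoped dict only.
    g.ndata['h'] = [x * 2 for x in g.ndata['h']]
    g.ndata['tmp'] = [0.0] * 4
print(g.ndata)  # {'h': [1.0, 2.0, 3.0, 4.0]}

with g.local_scope():
    # In-place: mutates the shared list object, so the change survives.
    g.ndata['h'][0] = 100.0
print(g.ndata)  # {'h': [100.0, 2.0, 3.0, 4.0]}
```

The practical takeaway for DGL code is to write results to new or reassigned fields inside the scope rather than modifying existing feature tensors in place.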
