# <font color='darkblue'> Part 1 : Message Passing </font>

In this notebook we will go step by step through message passing, the foundation of learning with graph neural networks (GNNs).

We will use the toy example of the Karate Club dataset, first introduced in Section 1.

We will see how information propagates and how nodes can learn from their neighbours in the simplest of settings.

In [None]:
import networkx as nx
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

'''
#################################################
               YOUR CODE HERE
#################################################

Load Karate Club graph from networkx package
'''

G = nx.

In [None]:
def gen_graph_legend(node_colours, G, attr):
    """
    Generate a legend for a graph based on node colors and attributes.

    Parameters:
    - node_colours (pd.Series): A series of node colors.
    - G (networkx.Graph): The graph object.
    - attr (str): The attribute to use for labeling.

    Returns:
    - patches (list): A list of matplotlib patches representing the legend.

    """
    
    patches = []
    for col, lab in zip(node_colours.drop_duplicates(), pd.Series(nx.get_node_attributes(G, attr)).drop_duplicates()):
        patches.append(mpatches.Patch(color=col, label=lab))

    return patches

In [None]:
'''
#################################################
               YOUR CODE HERE
#################################################

Generate a node colour for each club in karate club

Draw the network using nx.draw and the correct parameters
'''
node_colours = 

# Draw Network
np.random.seed(42) # so we produce the same network each time as the layout is a stochastic algorithm
nx.draw() # input our graph and node colours, and use with_labels = True
plt.title('Karate Club Graph')
legend_handles = gen_graph_legend(node_colours , G , 'club') # Generate legend labels from function above
plt.legend(handles = legend_handles)
plt.show()

## <font color='darkblue'>Message Passing Features </font>
We need to specify the node features that each node will pass on. In this simple use case we are going to embed the answer in the node feature by one-hot encoding each club, e.g. Mr. Hi's club will be [0,1] and the Officer club will be [1,0]. We will mask the club membership of some members by encoding them as [0.5, 0.5], making it equally likely that they belong to either club.

Can we re-identify these members' clubs by performing simple message passing over the network?
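A minimal sketch of this encoding in plain Python, outside the networkx graph. The helper `encode_club` and the five-node `clubs` dictionary are hypothetical and only illustrate the scheme. Note that the masked nodes are drawn without replacement, so exactly the requested number of distinct nodes is hidden.

```python
import random


def encode_club(club, masked=False):
    """Hypothetical helper: one-hot encode a club label.

    Mr. Hi -> [0, 1], Officer -> [1, 0]; a masked node gets the
    uninformative [0.5, 0.5], i.e. an equal chance of either club.
    """
    if masked:
        return [0.5, 0.5]
    if club == "Mr. Hi":
        return [0.0, 1.0]
    if club == "Officer":
        return [1.0, 0.0]
    raise ValueError(f"unknown club: {club!r}")


# Illustrative labels for a handful of nodes (not the real dataset)
clubs = {0: "Mr. Hi", 1: "Mr. Hi", 2: "Officer", 3: "Officer", 4: "Mr. Hi"}

rng = random.Random(42)                       # fixed seed for reproducibility
masked = set(rng.sample(sorted(clubs), k=2))  # two distinct nodes, no repeats
features = {n: encode_club(c, masked=n in masked) for n, c in clubs.items()}

for n in sorted(features):
    print(n, clubs[n], "masked" if n in masked else "", features[n])
```

With numpy, `np.random.choice(list(G.nodes), 10, replace=False)` gives the same no-repeat guarantee; the default `replace=True` can pick the same node more than once.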

In [None]:
'''
#################################################
               YOUR CODE HERE
#################################################

Add node features to our network

1. Embed the features [0,1] for Mr. Hi club, [1,0] for Mr. Officer and
[0.5, 0.5] for the masked nodes

2. Print each node and its corresponding feature
'''

unknown_nodes = np.random.choice(list(G.nodes), 10, replace=False) # randomly select ten distinct nodes for our test set
for node in G.nodes : 
    if node in unknown_nodes : 
        G.nodes[node]['feature'] = np.array() # mask node features as 50/50 chance of being Mr. Hi or Mr. Officer
        G.nodes[node]['club'] = 'Masked'
    else : 
        if G.nodes[node]['club'] == 'Mr. Hi' :
            G.nodes[node]['feature'] = np.array() # Embed Mr. Hi features as [0,1]
        elif G.nodes[node]['club'] == 'Officer' : 
            G.nodes[node]['feature'] = np.array() # Embed Mr. Officer features as [1,0]
        else : 
            print('No club or group found for node')
            
# Print each node and its corresponding feature
for node in G.nodes:
    print("Node:", , "Feature:", )

In [None]:
# Draw our updated network
node_colours = pd.DataFrame(nx.get_node_attributes(G , 'feature')).T[0].astype('str').replace('0.0' , 'r').replace('0.5' , 'grey').replace('1.0' , 'skyblue')

np.random.seed(42)
nx.draw(G , with_labels = True , node_color = node_colours) # draw our network
plt.title('Karate Club Graph with hidden memberships') 
legend_handles = gen_graph_legend(node_colours , G , 'club') # Generate legend labels from function above
plt.legend(handles = legend_handles)
plt.show()

## <font color='darkblue'>Message Passing Function </font>
As defined, message passing consists of three main steps: message propagation, aggregation and update.

- Each node passes, or propagates, its features to its neighbours.
- Each node then aggregates the messages it receives using median aggregation.
- Each node updates its node features with the aggregated message.

Can you define two functions: a first one, `message_passing`, which propagates and aggregates node messages, and a second one, `message_passing_iteration`, which then updates the node features?
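The three steps can be sketched on a tiny hand-built graph, using a plain adjacency dictionary in place of a networkx graph (the four-node graph, its features and the function names below are hypothetical). Aggregation takes the median separately for each feature dimension, which is what `np.median(messages, axis=0)` would do on a 2-D array. All new features are computed first and only then written back, so every node sees its neighbours' features from the same iteration.

```python
from statistics import median

# Hypothetical toy graph: nodes 0 and 1 are Mr. Hi, node 3 is Officer,
# node 2 is masked with the uninformative [0.5, 0.5].
adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
features = {0: [0.0, 1.0], 1: [0.0, 1.0], 2: [0.5, 0.5], 3: [1.0, 0.0]}


def aggregate(node, adj, features):
    """Propagate + aggregate: per-dimension median of neighbour features."""
    neighbours = adj[node]
    if not neighbours:                              # isolated node keeps its feature
        return features[node]
    messages = [features[nb] for nb in neighbours]  # propagate
    return [median(dim) for dim in zip(*messages)]  # aggregate each dimension


def iterate(adj, features):
    """Update: compute every new feature first, then replace them all at once."""
    return {node: aggregate(node, adj, features) for node in adj}


features = iterate(adj, features)
for node, feat in features.items():
    print(node, feat)
```

After one step the masked node 2 already holds [0.0, 1.0], the Mr. Hi encoding, because two of its three neighbours are Mr. Hi members.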

In [None]:
# Define message passing function
def message_passing(node, G):
    """
    Perform message passing for a given node in a graph.

    Parameters:
    node (int): The node for which message passing is performed.
    G (networkx.Graph): The graph containing the node and its neighbors.

    Returns:
    numpy.ndarray: The aggregated message from the neighboring nodes.

    Notes:
    - This function gathers the messages for a single node and will be used in the message_passing_iteration.
    - The function performs propagation and aggregation.
    - Propagation: Gather the node features of all neighboring nodes.
    - Aggregation: Aggregate the gathered messages using median aggregation.
    """
    neighbors = list(G.neighbors(node))
    if not neighbors:
        return G.nodes[node]['feature']
    else:
        '''
        #################################################
                       YOUR CODE HERE
        #################################################
        
        Note. This function gathers the messages for a single node and will be used
        to iterate over in the message_passing_iteration
        
        In this function you will need to : 
            1. Create a nested list of the node features of neighbouring nodes hint : 
            use list comprehension & G.nodes[neighbor]['feature'] to extract the feature
            2. Aggregate these messages for each node using np.median(..., axis=0) or similar,
            so that each feature dimension is aggregated separately
        '''

        return aggregated_message


# Perform message passing iteration
def message_passing_iteration(G):
    """
    Perform message passing iteration on a graph.

    Parameters:
    - G (networkx.Graph): The input graph.

    Returns:
    - None

    Description:
    This function performs one iteration of message passing on a graph.
    It iterates through all nodes in the graph and updates their
    features based on the aggregated message from neighboring nodes.

    Note:
    - The function `message_passing` is defined above.

    """
    
    updated_features = {}
    
    '''
    #################################################
                   YOUR CODE HERE
    #################################################
    
    We will now update our node features with the updated features dictionary
    
    In this function you will need to : 
        1. Iterate through all nodes
        2. create a dictionary where each node is a key and the aggregated message is the value
        3. Iterate through the dictionary created above. hint use updated_features.items()
        4. Update each nodes features
    '''

In [None]:
'''
#################################################
               YOUR CODE HERE
#################################################

Perform a single iteration of message passing and
print the updated node features
'''
# Example of one iteration of message passing

# Access updated node features after one iteration


#### Here we are plotting the karate club network after one iteration of message passing and comparing it to the original network

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

node_colours_message = (pd.DataFrame(nx.get_node_attributes(G , 'feature')).T[0] > 0.5).replace(True , 'skyblue').replace(False , 'r') # Assign node colours based on node features
np.random.seed(42)
nx.draw(G , ax=axes[0] , with_labels = True , node_color = node_colours_message) # draw our network
axes[0].set_title('Karate Club Graph after 1 iteration of Message Passing')

G_true = nx.karate_club_graph() # Get back original karate club graph
node_colours_true = pd.Series(nx.get_node_attributes(G_true , 'club')).replace('Mr. Hi' , 'r').replace('Officer' , 'skyblue') # Assign node colours based on true club membership
# Draw Network
np.random.seed(42)
nx.draw(G_true , ax=axes[1] , with_labels = True , node_color = node_colours_true) # draw our network
axes[1].set_title('True Karate Club Graph')
legend_handles = gen_graph_legend(node_colours_true , G_true , 'club') # Generate legend labels 
axes[0].legend(handles = legend_handles)
axes[1].legend(handles = legend_handles)

# Adjust layout
plt.tight_layout()
plt.show()

#### We are now going to run 100 more iterations of message passing to see if the node features converge to give back the original club memberships
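Rather than fixing the number of iterations in advance, one could stop once the features stop changing. The sketch below does this on a hypothetical four-node toy graph stored as a plain adjacency dictionary (not the notebook's networkx graph), with per-dimension median aggregation. The tolerance `tol` and the cap `max_iter` are illustrative choices, not values taken from this notebook.

```python
from statistics import median

# Hypothetical toy graph (adjacency lists): nodes 0 and 1 are Mr. Hi [0, 1],
# node 3 is Officer [1, 0] and node 2 is masked [0.5, 0.5].
adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
features = {0: [0.0, 1.0], 1: [0.0, 1.0], 2: [0.5, 0.5], 3: [1.0, 0.0]}


def step(adj, features):
    """One synchronous round of per-dimension median message passing."""
    new = {}
    for node, neighbours in adj.items():
        if not neighbours:                    # isolated node keeps its feature
            new[node] = features[node]
            continue
        messages = [features[nb] for nb in neighbours]
        new[node] = [median(dim) for dim in zip(*messages)]
    return new


def run_until_converged(adj, features, tol=1e-9, max_iter=100):
    """Iterate until no feature component moves by more than tol."""
    for i in range(1, max_iter + 1):
        new = step(adj, features)
        change = max(
            abs(a - b) for node in adj for a, b in zip(new[node], features[node])
        )
        features = new
        if change <= tol:
            return features, i
    return features, max_iter


features, n_iter = run_until_converged(adj, features)
print("iterations:", n_iter)
for node, feat in features.items():
    print(node, [round(x, 4) for x in feat])
```

On this toy graph every node settles near [1/6, 5/6], on the Mr. Hi side of the 0.5 threshold. Node 3, which started as Officer, is pulled over as well: the update uses only the neighbours' messages, so a node's own starting feature is not retained.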

In [None]:
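# Example of 100 iterations of message passing
for _ in range(100) : 
    message_passing_iteration(G)
    
# Access updated node features after 100 iterations
for node in G.nodes:
    print("Node:", node, "Feature:", G.nodes[node]['feature'])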

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Plotting the updated node features and the original karate club network
node_colours_message = (pd.DataFrame(nx.get_node_attributes(G , 'feature')).T[0] > 0.5).replace(True , 'skyblue').replace(False , 'r') # Assign node colours based on node features
np.random.seed(42)
nx.draw(G , ax=axes[0] , with_labels = True , node_color = node_colours_message) # draw our network
axes[0].set_title('Karate Club Graph after 100 iterations of Message Passing')

G_true = nx.karate_club_graph() # Get back original karate club graph
node_colours_true = pd.Series(nx.get_node_attributes(G_true , 'club')).replace('Mr. Hi' , 'r').replace('Officer' , 'skyblue') # Assign node colours based on true club membership
# Draw Network
np.random.seed(42)
nx.draw(G_true , ax=axes[1] , with_labels = True , node_color = node_colours_true)  # draw our network
axes[1].set_title('True Karate Club Graph')
legend_handles = gen_graph_legend(node_colours_true , G_true , 'club') # Generate legend labels
axes[0].legend(handles = legend_handles)
axes[1].legend(handles = legend_handles)

# Adjust layout
plt.tight_layout()
plt.show()

Above we can see that we are able to recover the karate club assignments by sharing each node's belief about its own club.

Although we mislabel node 8, in general the message passing algorithm is a very efficient way to update beliefs about unknown node labels.
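Instead of reading mistakes off the plots, one could score the recovered labels on the masked nodes only. This minimal sketch uses the same rule as the plotting code above (first feature component above 0.5 means Officer, otherwise Mr. Hi). The node names "a" to "d" and their feature values are made up purely for illustration and are not results from this notebook.

```python
def predict_club(feature, threshold=0.5):
    """Same rule as the plots: first component > threshold -> Officer."""
    return "Officer" if feature[0] > threshold else "Mr. Hi"


def masked_accuracy(features, true_clubs, masked):
    """Fraction of masked nodes whose predicted club matches the true club."""
    correct = sum(predict_club(features[n]) == true_clubs[n] for n in masked)
    return correct / len(masked)


# Illustrative values only (hypothetical nodes "a".."d")
true_clubs = {"a": "Mr. Hi", "b": "Officer", "c": "Mr. Hi", "d": "Officer"}
features = {"a": [0.2, 0.8], "b": [0.9, 0.1], "c": [0.7, 0.3], "d": [0.6, 0.4]}
masked = ["a", "b", "c", "d"]

wrong = [n for n in masked if predict_club(features[n]) != true_clubs[n]]
print("accuracy on masked nodes:", masked_accuracy(features, true_clubs, masked))
print("misclassified:", wrong)
```

On the notebook's graph, `true_clubs` could be taken from `nx.get_node_attributes(G_true, 'club')` and `masked` from `unknown_nodes`.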

In practice, in a supervised learning setting, we will not give the node label as the node feature. Instead, we hope that our node features will be an informative representation of each node's label. For example, our node features could be altered gene expression in a patient with cancer. For an informative cancer network, we can envision how we would be better able to identify new cancer patients if the messages they receive come from other, similar cancer patients, but more on this later.

The main difference between Graph Neural Network architectures is how they perform message passing and how they aggregate messages. Previously, we looked at median aggregation; next, we want to repeat the above using mean aggregation.

## <font color='darkblue'>Mean Aggregation </font>

What happens if we change the aggregation function from median to sum or mean? With sum, we can easily envision how our current features would grow past the 0.5 threshold for all nodes, but what about mean?
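A small numeric illustration of this question, for one hypothetical node with four neighbours: two Mr. Hi [0, 1], one Officer [1, 0] and one masked [0.5, 0.5]. Only the aggregation function changes; the classification rule (first component above 0.5 means Officer) is the one used in the median plots.

```python
from statistics import fmean, median

# Messages received by one hypothetical node (one row per neighbour)
messages = [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]


def aggregate(messages, fn):
    """Apply fn separately to each feature dimension."""
    return [fn(dim) for dim in zip(*messages)]


for name, fn in [("sum", sum), ("mean", fmean), ("median", median)]:
    agg = aggregate(messages, fn)
    club = "Officer" if agg[0] > 0.5 else "Mr. Hi"
    print(f"{name:>6}: {agg} -> {club}")
```

With sum, the first component reaches 1.5 even though most labelled neighbours are Mr. Hi, and the value scales with the number of neighbours. Mean and median outputs always stay between the smallest and largest incoming value in each dimension.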

First, we will alter our node features so that the Officer club is encoded as [-1,0]. We introduce the -1 in the hope of controlling the behaviour of message passing and reaching an equilibrium under mean aggregation.
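A sketch of why a signed encoding can work with mean aggregation. It assumes Mr. Hi stays [0, 1], Officer becomes [-1, 0], and a masked node takes the midpoint [-0.5, 0.5] (an assumption; any feature whose two components average to 0 plays the same role). The per-node score is the mean of the two components, as in the colouring code below, so Mr. Hi scores +0.5, Officer -0.5 and a masked node 0. Because the mean is linear, the score of a mean-aggregated message equals the mean of the neighbours' scores, so in the first round its sign shows which club dominates among the labelled neighbours.

```python
from statistics import fmean

# Assumed signed encoding (masked = midpoint of the two club encodings)
ENCODING = {"Mr. Hi": [0.0, 1.0], "Officer": [-1.0, 0.0], "Masked": [-0.5, 0.5]}


def score(feature):
    """Scalar score: mean of the feature's components (> 0 means Mr. Hi)."""
    return fmean(feature)


def mean_aggregate(messages):
    """Per-dimension mean of the neighbours' features."""
    return [fmean(dim) for dim in zip(*messages)]


# One hypothetical node whose neighbours are: Mr. Hi, Mr. Hi, Officer, Masked
neighbours = ["Mr. Hi", "Mr. Hi", "Officer", "Masked"]
messages = [ENCODING[c] for c in neighbours]

agg = mean_aggregate(messages)
print("aggregated feature:", agg)
print("score of aggregate:", score(agg))
print("mean of neighbour scores:", fmean(score(m) for m in messages))
```

Here two Mr. Hi neighbours outweigh one Officer neighbour, so the score is positive, while the masked neighbour contributes nothing.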

In [None]:
'''
#################################################
               YOUR CODE HERE
#################################################

Add node features to our network

1. Embed the features which you think might be best for convergence using mean aggregation

2. Print each node and its corresponding feature
'''

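G = nx.karate_club_graph() # reload the graph: the previous section overwrote some 'club' values with 'Masked' and changed every node feature
unknown_nodes = np.random.choice(list(G.nodes), 10, replace=False) # randomly select ten distinct nodes for our test set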
for node in G.nodes : 
    if node in unknown_nodes : 
        G.nodes[node]['feature'] = np.array() # mask node features as 50/50 chance of being Mr. Hi or Mr. Officer
        G.nodes[node]['club'] = 'Masked'
    else : 
        if G.nodes[node]['club'] == 'Mr. Hi' :
            G.nodes[node]['feature'] = np.array() # Embed Mr. Hi as ...
        elif G.nodes[node]['club'] == 'Officer' : 
            G.nodes[node]['feature'] = np.array() # Embed Mr. Officer as ...
        else : 
            print('No club or group found for node')
            
# Print each node and its corresponding feature


In [None]:
# Draw our updated network
node_colours = pd.DataFrame(nx.get_node_attributes(G , 'feature')).mean().astype('str').replace('0.5' , 'r').replace('0.0' , 'grey').replace('-0.5' , 'skyblue') # Assign node colours based on node features (positive mean = Mr. Hi in red, negative mean = Officer in skyblue, as in the true graph)

np.random.seed(42)
nx.draw(G , with_labels = True , node_color = node_colours)     # draw our network
plt.title('Karate Club Graph with hidden memberships')
legend_handles = gen_graph_legend(node_colours , G , 'club')    # Generate legend labels
plt.legend(handles = legend_handles)
plt.show()

In [None]:
# Define message passing function
def message_passing(node, G):
    """
    Perform message passing on a node in a graph.

    Parameters:
    node (any): The node to perform message passing on.
    G (networkx.Graph): The graph containing the node and its neighbors.

    Returns:
    any: The aggregated message from the node's neighbors.
    """

    neighbors = list(G.neighbors(node))
    if not neighbors:
        return G.nodes[node]['feature']
    else:
        '''
        #################################################
                       YOUR CODE HERE
        #################################################
        
        Re-implement the above code, but aggregate with np.mean(..., axis=0) or similar instead of the median
        '''

        return aggregated_message

In [None]:
'''
#################################################
               YOUR CODE HERE
#################################################

Perform a single iteration of message passing and
print the updated node features
'''
# Example of one iteration of message passing

# Access updated node features after one iteration


#### Here we are plotting the karate club network after one iteration of message passing and comparing it to the original network

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

node_colours_message = (pd.DataFrame(nx.get_node_attributes(G , 'feature')).mean() > 0).replace(True , 'r').replace(False , 'skyblue') # Assign node colours based on node features (positive mean = Mr. Hi, matching the true graph's colours)
np.random.seed(42)
nx.draw(G , ax=axes[0] , with_labels = True , node_color = node_colours_message) # draw our network
axes[0].set_title('Karate Club Graph after 1 iteration of Message Passing')

G_true = nx.karate_club_graph() # Get back original karate club graph
node_colours_true = pd.Series(nx.get_node_attributes(G_true , 'club')).replace('Mr. Hi' , 'r').replace('Officer' , 'skyblue') # Assign node colours based on true club membership
# Draw Network
np.random.seed(42)
nx.draw(G_true , ax=axes[1] , with_labels = True , node_color = node_colours_true) # draw our network
axes[1].set_title('True Karate Club Graph')
legend_handles = gen_graph_legend(node_colours_true , G_true , 'club') # Generate legend labels
axes[0].legend(handles = legend_handles)
axes[1].legend(handles = legend_handles)

# Adjust layout
plt.tight_layout()
plt.show()

#### We are now going to run 100 more iterations of message passing to see if the node features converge to give back the original club memberships

In [None]:
# Example of 100 iterations of message passing
for _ in range(100) : 
    message_passing_iteration(G)
    
# Access updated node features after 100 iterations
for node in G.nodes:
    print("Node:", node, "Feature:", G.nodes[node]['feature'])

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

node_colours_message = (pd.DataFrame(nx.get_node_attributes(G , 'feature')).mean() > 0).replace(True , 'r').replace(False , 'skyblue') # Assign node colours based on node features (positive mean = Mr. Hi, matching the true graph's colours)
np.random.seed(42)
nx.draw(G , ax=axes[0] , with_labels = True , node_color = node_colours_message) # draw our network
axes[0].set_title('Karate Club Graph after 100 iterations of Message Passing')

G_true = nx.karate_club_graph() # Get back original karate club graph
node_colours_true = pd.Series(nx.get_node_attributes(G_true , 'club')).replace('Mr. Hi' , 'r').replace('Officer' , 'skyblue') # Assign node colours based on true club membership
# Draw Network
np.random.seed(42)
nx.draw(G_true , ax=axes[1] , with_labels = True , node_color = node_colours_true) # draw our network
axes[1].set_title('True Karate Club Graph')
legend_handles = gen_graph_legend(node_colours_true , G_true , 'club') # Generate legend labels
axes[0].legend(handles = legend_handles)
axes[1].legend(handles = legend_handles)

# Adjust layout
plt.tight_layout()
plt.show()

## <font color='darkblue'>Message Passing Key Takeaways </font>

1. Message passing is fast.
2. Message passing is not perfect.

Can we introduce a non-linearity that allows us to use mean and sum aggregation while preventing feature saturation, and to learn node labels from features? -> GNNs!