### What:

Trying to get node classification with deepwalk+Label Propagation to work

### TODO:

- improve documentation/comments
- add a train/test split to actually test classification performance
    - follow the authors' code for the split (see main.py in the 'classifier' folder)
        - same splits and hyperparameters
    - add code for LabelPropagation() to perform the task
- compare Perozzi's embeddings with karateclub's
    - make their shapes compatible

In [1]:
import numpy as np
import networkx as nx
from karateclub import LabelPropagation
from sklearn.neighbors import NearestNeighbors
import embed_utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def k_nearest_graph(embed, original_graph, k=5):
    """
    Construct a graph from an embedding using the k-nearest graph algorithm.

    embed: embedding, one row per node. Row indices are used directly as node
        ids when adding edges, so the nodes of original_graph are assumed to be
        labelled 0..n-1 in the same order as the rows (TODO confirm for other
        embedders than karateclub).
    original_graph: needed for preserving the nodes and their "class" attribute
    k: number of nearest neighbors in the embedding space (the point itself
        excluded) that every node is linked to in the returned graph

    Returns an undirected nx.Graph. Because edges are undirected, a node can
    end up with more than k neighbors when it is among the k nearest of nodes
    that are not among its own k nearest.
    """
    # create new graph based on original nodes and attributes
    knn_graph = nx.Graph()
    knn_graph.add_nodes_from(original_graph)
    nx.set_node_attributes(knn_graph, nx.get_node_attributes(original_graph, "class"), "class")
    # find the k nearest neighbors for each node and add edges between the node
    # and these neighbors in the new graph
    nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(embed)
    # Querying without an explicit X returns the neighbors of every fitted
    # point *without* the point itself. Querying with `embed` instead would
    # return each point as its own nearest hit, leaving only k-1 real neighbors,
    # and with duplicated rows the first hit is not even guaranteed to be the
    # query point.
    _, indices = nbrs.kneighbors()
    # the row position identifies the source node; tolist() yields plain ints
    knn_links = [
        (node, neighbor)
        for node, neighbors in enumerate(indices.tolist())
        for neighbor in neighbors
    ]
    knn_graph.add_edges_from(knn_links)
    return knn_graph

In [3]:
# Build the graph from the synthetic dataset on disk.
# (presumably the second argument names the node attribute that stores the
# label -- verify against embed_utils.data2graph)
graph_synth2 = embed_utils.data2graph(
    "./data/synthetic_n500_Pred0.7_Phom0.025_Phet0.001",
    "class",
)
print(
    f"#nodes: {graph_synth2.number_of_nodes()}, "
    f"#edges (bidirectional): {graph_synth2.number_of_edges()}"
)

# Embed the graph with DeepWalk (via embed_utils.graph2embed).
embed_synth2_karateclub = embed_utils.graph2embed(graph_synth2)
print(f"shape embedding: {embed_synth2_karateclub.shape}")

# Class balance. Summing the labels only counts class 1 if the labels are 0/1.
labels_list = [label for label in nx.get_node_attributes(graph_synth2, "class").values()]
n_class_1 = sum(labels_list)
n_class_0 = len(labels_list) - n_class_1
print(f"#nodes with class 0: {n_class_0}, #nodes with class 1: {n_class_1}")

# Rebuild a graph from the embedding by linking each node to its nearest
# neighbors in embedding space (default k of k_nearest_graph).
k_nearest_graph_synth2 = k_nearest_graph(embed_synth2_karateclub, graph_synth2)

# Not enabled yet: label propagation on the new graph for classification.
# NOTE(review): karateclub's LabelPropagation looks like an unsupervised
# community-detection model that ignores the "class" attribute -- confirm it
# fits the classification task before enabling.
# lp = LabelPropagation()
# lp.fit(k_nearest_graph_synth2)

#nodes: 500, #edges (bidirectional): 1834
shape embedding: (500, 128)
#nodes with class 0: 150, #nodes with class 1: 350
