In [1]:
import numpy as np
import networkx as nx
from sklearn.semi_supervised import LabelPropagation
from sklearn.metrics import pairwise_distances
# from sklearn.neighbors import NearestNeighbors
import embed_utils
from sklearn.model_selection import train_test_split
from copy import deepcopy
from collections import Counter
from matplotlib import pyplot as plt

In [2]:
def show_graph(G):
    # G = nx.Graph()
    # G.add_edge(1,2)
    # G.add_edge(2,3)
    # for v in G.nodes():
    #     G.node[v]['state']='X'
    # G.node[1]['state']='Y'
    # G.node[2]['state']='Y'

    # for n in G.edges_iter():
    #     G.edge[n[0]][n[1]]['state']='X'
    # G.edge[2][3]['state']='Y'

    node_labels = nx.get_node_attributes(G,embed_utils.CLASS_NAME)
    node_attr = nx.get_node_attributes(G,embed_utils.SENSATTR)

    color_map = []
    for node in G:
        if node_attr[node] == 0:
            color_map.append('blue')
        else: 
            color_map.append('green')   

    # pos = nx.spring_layout(G)

    # nx.draw(G, pos)
    nx.draw(G, node_color=color_map, labels = node_labels)
    # nx.draw_networkx_labels(G, pos, labels = node_labels)
    # edge_labels = nx.get_edge_attributes(G,'state')
    # nx.draw_networkx_edge_labels(G, pos, labels = edge_labels)
    # plt.savefig('this.png')
    plt.show()
    return

In [12]:
def get_metrics(classifier, test_nodes, embeddings, label_dict):
    test_node_embeddings = embeddings[test_nodes]
    test_node_labels = [label_dict[node] for node in test_nodes]
    pred = classifier.predict(test_node_embeddings)
    accuracy = sum(test_node_labels == pred) / len(test_node_labels)
    
    return accuracy


def check_classification_params(nodes, labels, embeddings):
    assert nodes == list(range(len(nodes)))
    assert len(labels) == len(nodes) == len(embeddings), f"{len(labels)}, {len(nodes)}, {len(embeddings)}"


def make_classification_model(nodes, labels, embeddings):
    """
    Makes the classification model
    :param nodes: should be the nodes as list of consecutive integers
    :param labels: should be a list of labels where a value of -1
        indicates a missing label
    :param embeddings: is the embeddings of all nodes obtained using a
        (modified) random walk

    :returns: a model with a predict() function that predicts the label
        from embeddings
    """
    check_classification_params(nodes, labels, embeddings)

    g = np.mean(pairwise_distances(embeddings))
    # TODO is this the right kernel?
    clf = LabelPropagation(kernel="knn", gamma = g).fit(embeddings, labels)

    return clf


def run_classification(graph, embeddings, name_ds=""):
    # show_graph(graph)

    # print("Graph info")
    # print(f"Number of nodes: {len(graph.nodes())}")
    # print(f"Number of edges: {len(graph.edges())}")
    # print(f"Length label dict: {len(nx.get_node_attributes(graph, embed_utils.CLASS_NAME))}")
    # print(f"Length attri dict: {len(nx.get_node_attributes(graph, embed_utils.SENSATTR))}")
    # print("Label dict:", nx.get_node_attributes(graph, embed_utils.CLASS_NAME))
    # print("Attri dict:", nx.get_node_attributes(graph, embed_utils.SENSATTR))

    # Get labels and attributes of test nodes to other classes
    label_dict = deepcopy(nx.get_node_attributes(graph, embed_utils.CLASS_NAME))
    attr_dict = deepcopy(nx.get_node_attributes(graph, embed_utils.SENSATTR))
    
    # Split in to equal sized train and test nodes
    nodes = list(graph.nodes())
    train_nodes, test_nodes = train_test_split(nodes, test_size=0.5, shuffle=True)
    train_nodes = set(train_nodes)

    # Get semi-supervised labels
    n_unique_labels = len(set(label_dict.values()))
    unique_labels = (x for x in np.arange(n_unique_labels+1, len(train_nodes) + n_unique_labels + 2))
    semi_supervised_y = [label_dict[node] if node in train_nodes else next(unique_labels) for node in nodes]

    # Train the classifier
    clf = make_classification_model(nodes, semi_supervised_y, embeddings)

    # test_node_embeddings = [emb for i, emb in enumerate(embeddings) if i in test_nodes]
    # test_node_labels = [label_dict[node] for node in test_nodes]
    # pred = classifier.predict(test_node_embeddings)
    # accuracy = sum(test_node_labels == pred) / len(test_node_labels)

    # Get test node embeddings, labels and find accuracy on test nodes
    c0_nodes = [node for node in test_nodes if attr_dict[node] == 0]
    acc_c0 = get_metrics(clf, c0_nodes, embeddings, label_dict) * 100

    c1_nodes = [node for node in test_nodes if attr_dict[node] == 1]
    acc_c1 = get_metrics(clf, c1_nodes, embeddings, label_dict) * 100

    accuracy = get_metrics(clf, test_nodes, embeddings, label_dict) * 100
    disparity = np.var([acc_c0, acc_c1])
    # print(f"Accuracy c0: {acc_c0}")
    # print(f"Accuracy c1: {acc_c1}")
    # print(f"Disparity: {disparity}") 
    # print()

    # print(f"Total accuracy: {accuracy}")
    # print()
    # print(f"Counter training lables: {Counter([label_dict[node] for node in train_nodes]).most_common(3)}")
    # print(f"Counter real lables: {Counter([label_dict[node] for node in test_nodes]).most_common(3)}")
    # print(
    #     f"Counter prediction: {Counter(clf.predict([emb for i, emb in enumerate(embeddings) if i in test_nodes])).most_common(3)}"
    #     )
    # print()

    return accuracy, disparity


n_runs = 200
for dataset, reweight_method, embed_method in [["rice", "crosswalk", "deepwalk"]]:
    # get graph from data
    graph = embed_utils.data2graph(dataset)

    # get embedding from graph
    embed = embed_utils.graph2embed(
                                        graph, 
                                        reweight_method, 
                                        embed_method
                                        )

    accs = []
    disparities = []
    for i in range(n_runs):
        acc, disparity = run_classification(graph, embed)
        accs.append(acc)
        disparities.append(disparity)

In [None]:
np.mean(accs), np.mean(disparities)

(59.28699551569507, 95.03498626227366)

In [None]:
np.var(accs), np.var(disparities)

(14.073458143135795, 4760.090769273912)