# Loading Real-world Graphs with Node Attributes

### Libs

In [None]:
## libs 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

## Keras
from keras.layers import Lambda, Input, Dense, Conv2D, Conv2DTranspose, Flatten, Reshape
from keras.models import Model
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model
from keras import backend as K

## Basic
from tqdm import tqdm_notebook as tqdm
import argparse
import os
import random
import itertools
import time

# Computation
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mutual_info_score

import scipy
from scipy.stats.stats import pearsonr 

## Visualization
import matplotlib.pyplot as plt
import seaborn as sns

## Network Processing
import networkx as nx
from networkx.generators import random_graphs

## node colour: single-entry colour list, passed as node_color to nx.draw when the sampled graph is plotted
color_map = ["steelblue"]

### Supporting Functions

In [None]:
## supporting functions
from support.preprocessing import sort_adj, reshape_A, calculate_A_shape, reconstruct_adjacency, pad_matrix, unpad_matrix, prepare_in_out
from support.metrics import compute_mig, compute_mi
from support.graph_generating import generate_single_features, generate_manifold_features, generate_topol_manifold, generate_topol_manifold
from support.latent_space import vis2D, visDistr
from support.comparing import compare_manifold_adjacency, compare_topol_manifold

## graph sampling
from sampling import ForestFire, Metropolis_Hastings, Random_Walk, Snowball, Ties, Base_Samplers

## Load Graph Data

In [None]:
## pandas is not part of the library imports of this notebook, so it is imported here
import pandas as pd

## complete graph, read from a comma-separated edge list with integer node ids
g_complete = nx.read_edgelist("/Users/niklasstoehr/Programming/thesis/4_real_attr/data/mag/edges_py.csv", nodetype=int, delimiter=",")

## node table without a header row; every row is unpacked below as (num, name, c)
df_nodes = pd.read_csv("/Users/niklasstoehr/Programming/thesis/4_real_attr/data/mag/nodes_py.csv", header=None)
nodes = df_nodes.values.tolist()
nodes_num = [row[0] for row in nodes]

## add every listed node to the complete graph, so nodes that appear in no edge are kept as well
## (this used to target an undefined variable `g`)
g_complete.add_nodes_from(nodes_num)

## adjacency matrix handed to the samplers further down; it is built only after the
## node list has been added so that it covers every node of g_complete
a_complete = nx.adjacency_matrix(g_complete)

## map node id -> third column of the node table (stored as "citations" below)
node_attr_dict = {num: c for num, name, c in nodes}

## sanity check: attribute value of node 0
print(node_attr_dict[0])

nx.set_node_attributes(g_complete, node_attr_dict, "citations")

## Sample Subgraph

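**exact_n** (the sampler is given the target node count `n`): bfs, biased_random_walk, forestfire, random_walk_induced_graph_sampling, random_walk_sampling_with_fly_back, adjacency, select

**approx_n** (the sampler is steered by other parameters such as `source_starts`, `source_returns`, `depth` or `p`, so the node count is not set directly): snowball, standard_bfs, walk, jump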

In [None]:
## sampling configuration consumed by get_graph
sampleArgs = {
    "sample": "biased_random_walk",                     # name of the sampling strategy
    "jump_bias": "random_walk_induced_graph_sampling",  # read by "jump" only
    "n": 1000,                                          # read by the samplers that take a node count
    "p": 20.0,                                          # read by biased_random_walk, fly_back, walk, jump
    "q": 100.0,                                         # read by biased_random_walk only
    "source_starts": 2,                                 # read by snowball, standard_bfs, walk, jump
    "source_returns": 4,                                # read by snowball, walk
    "depth": 2,                                         # read by standard_bfs only
}

## samplers that take "n": bfs, biased_random_walk, forestfire, random_walk_induced_graph_sampling, random_walk_sampling_with_fly_back, adjacency, select
## samplers that do not take "n": snowball, standard_bfs, walk, jump

def get_graph(sampleArgs, g_complete, a_complete):
    """Sample a subgraph from the complete graph with the configured strategy.

    Args:
        sampleArgs: dict with the key "sample" naming the strategy, plus the
            parameters that strategy reads ("n", "p", "q", "source_starts",
            "source_returns", "depth", "jump_bias").
        g_complete: the complete graph; handed unchanged to the sampler class.
        a_complete: adjacency matrix of the complete graph; handed unchanged
            to the sampler class.

    Returns:
        Whatever the selected sampler method returns (the sampled graph).

    Raises:
        ValueError: if sampleArgs["sample"] is not one of the supported names.
            Previously an unknown name fell through every branch and failed
            with an UnboundLocalError at the return statement.
    """
    method = sampleArgs["sample"]

    if method == "biased_random_walk":
        sampler = Base_Samplers.Base_Samplers(g_complete, a_complete)
        g = sampler.biased_random_walk(sampleArgs["n"], sampleArgs["p"], sampleArgs["q"])

    elif method == "forestfire":
        sampler = ForestFire.ForestFire(g_complete, a_complete)
        g = sampler.forestfire(sampleArgs["n"])

    elif method == "snowball":
        sampler = Snowball.Snowball(g_complete, a_complete)
        g = sampler.snowball(sampleArgs["source_starts"], sampleArgs["source_returns"])

    elif method == "random_walk_induced_graph_sampling":
        sampler = Random_Walk.Random_Walk(g_complete, a_complete)
        g = sampler.random_walk_induced_graph_sampling(sampleArgs["n"])

    elif method == "random_walk_sampling_with_fly_back":
        sampler = Random_Walk.Random_Walk(g_complete, a_complete)
        g = sampler.random_walk_sampling_with_fly_back(sampleArgs["n"], sampleArgs["p"])

    elif method == "standard_bfs":
        sampler = Base_Samplers.Base_Samplers(g_complete, a_complete)
        g = sampler.standard_bfs(sampleArgs["source_starts"], sampleArgs["depth"])

    elif method == "bfs":
        sampler = Base_Samplers.Base_Samplers(g_complete, a_complete)
        g = sampler.bfs(sampleArgs["n"])

    elif method == "walk":
        sampler = Base_Samplers.Base_Samplers(g_complete, a_complete)
        g = sampler.walk(sampleArgs["source_starts"], sampleArgs["source_returns"], sampleArgs["p"])

    elif method == "jump":
        sampler = Base_Samplers.Base_Samplers(g_complete, a_complete)
        g = sampler.jump(sampleArgs["source_starts"], sampleArgs["p"], sampleArgs["jump_bias"])

    elif method in ("adjacency", "select"):
        # NOTE(review): "select" is routed to sampler.adjacency exactly like
        # "adjacency" -- confirm whether a dedicated select sampler was intended.
        sampler = Base_Samplers.Base_Samplers(g_complete, a_complete)
        g = sampler.adjacency(sampleArgs["n"])

    else:
        raise ValueError(
            "unknown sample method %r; expected one of: biased_random_walk, "
            "forestfire, snowball, random_walk_induced_graph_sampling, "
            "random_walk_sampling_with_fly_back, standard_bfs, bfs, walk, "
            "jump, adjacency, select" % (method,)
        )

    return g

## draw one sample with the configured strategy and time the call
start_time = time.time()
g = get_graph(sampleArgs, g_complete, a_complete)

## report the size of the sample and the wall-clock seconds elapsed since start_time
print("-- n_max should be >=", len(g), "--")
print("-- function get_graph takes %s secs --" % (round(time.time() - start_time, 5),))

## the sample is only drawn when it has at most 200 nodes
if not len(g) > 200:
    nx.draw(g, with_labels=False, node_color=color_map)