# GCN Depth Training Example

In [None]:
import os
import tensorflow as tf
import numpy as np
from algomorphism.models import GCNClassifier
from algomorphism.methods.graphs import a2g, graphs_stats
from algomorphism.datasets.generated_data import SimpleGraphsDataset
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelBinarizer
from itertools import product

In [None]:
# Optional: uncomment to enable memory growth on every visible GPU

# for gpu in tf.config.list_physical_devices('GPU'):
#     print(gpu)
#     tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
NewData = False  # True: generate a fresh dataset; False: load a saved one
DataId = 0  # index of the saved dataset file loaded when NewData is False
saveData = False  # persist a newly generated dataset (only used when NewData)
path = '../data/gcn-depth/GraphsDataset_{}.npy'
if NewData:
    # NOTE(review): graph_types is not passed to SimpleGraphsDataset; it
    # presumably documents the graph families the generator produces — confirm.
    graph_types = [
        'cycle',
        'hypercube',
        'circular_ladder',
        'grid'
    ]
    # Generator settings, presumably (number of examples, min nodes, max nodes)
    # — confirm against SimpleGraphsDataset. 300 matches the value the
    # generator has actually been called with.
    examples = 300
    minN = 6
    maxN = 24

    sgsd = SimpleGraphsDataset(examples, minN, maxN)
    sgsd.generate_dataset()
    x, a, atld, y = sgsd.get_train_data()
    x = tf.cast(x, tf.float32)
    a = tf.cast(a, tf.float32)
    atld = tf.cast(atld, tf.float32)
    y = tf.cast(y, tf.float32)
    if saveData:
        # Make sure the target directory exists, then write to the first
        # unused index so existing datasets are never overwritten.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        di = 0
        while os.path.exists(path.format(di)):
            di += 1
        np.save(path.format(di), {"a": a, "atld": atld, "x": x, "y": y})

else:
    # The file stores a pickled dict, hence allow_pickle=True. Only load files
    # written by this notebook: unpickling untrusted data can execute code.
    data = np.load(path.format(DataId), allow_pickle=True).item()
    # Look the arrays up by key instead of relying on the dict's value order.
    a, atld, x, y = data["a"], data["atld"], data["x"], data["y"]

maxD, depth_dist, maxDs, edgesN = graphs_stats(a)
train = tf.data.Dataset.from_tensor_slices((atld, x, y)).batch(128)

class DummyDataset:
    """Minimal holder that exposes a dataset split through ``train``.

    In this notebook it wraps the batched training ``tf.data.Dataset`` so it
    can be handed to ``GCNClassifier``.
    """

    def __init__(self, train):
        self.train = train

# Wrap the batched training split in the object handed to GCNClassifier below.
dd = DummyDataset(train=train)

In [None]:
save_fig = False

# Bar chart of the max-depth histogram returned by graphs_stats.
depths = list(maxDs.keys())
vals = np.array(list(maxDs.values()))
# Scale to percentages so the bars agree with the '%' y-axis label.
vals = 100 * vals / np.sum(vals)

plt.figure(figsize=(16, 4))
plt.bar(depths, vals)
plt.xticks(depths)
plt.xlabel(r'$Max \; Depths \; \#$', fontsize=20)
plt.ylabel(r'$Vertices \; \%$', fontsize=20)
if save_fig:
    # savefig does not create missing directories.
    fig_dir = 'DataFigures/gcn-depth'
    os.makedirs(fig_dir, exist_ok=True)
    plt.savefig(os.path.join(fig_dir, 'max-depth-dist.eps'), format='eps')

In [None]:
# Draw the first graph found for each class label (argmax of the y row).
ymax = y.shape[1]  # number of classes (columns of y)
drawn = set()  # labels that already have a figure
for i in range(a.shape[0]):
    label = int(tf.argmax(y[i]))
    # Track seen labels rather than a running counter, so every class is found
    # even when the dataset is not ordered by label.
    if label in drawn:
        continue
    drawn.add(label)
    plt.figure(figsize=(8, 8))
    plt.title('class {}'.format(label))
    nx.draw(a2g(a[i]))
    if len(drawn) == ymax:
        break

## Grid Search over GCN Depth and Training Epochs

In [None]:
# Candidate GCN depths: every integer from the smallest to the largest key of maxDs.
R = list(range(min(maxDs), max(maxDs) + 1))
# Values passed as the second argument of GCNClassifier.train (presumably epochs).
E = [100, 200, 300]

In [None]:
di = 0  # results index loaded, and first index tried when saving
load_gcn_depth_results = False
# One path for both loading and saving, so saved results can be reloaded.
results_path = "../data/gcn-depth/GScost_{}.npy"
if load_gcn_depth_results:
    NGScost = np.load(results_path.format(di), allow_pickle=True)[()]
    NGScost = NGScost["GScost"]

else:
    Nexpr = 10  # repetitions of every (depth, epochs) setting
    n_classes = y.shape[1]  # classifier output size: number of columns of y

    # NGScost[i, j, w]: cost on dd.train for repetition i, depth R[j], E[w].
    NGScost = np.zeros((Nexpr, len(R), len(E)))

    # Same iteration order as product(range(Nexpr), R, E), with the grid
    # indices coming from enumerate instead of a parallel index product.
    for i, (j, r), (w, e) in product(range(Nexpr), enumerate(R), enumerate(E)):
        # Presumably the layer widths: a.shape[1] followed by r layers of 8
        # — confirm against GCNClassifier.
        df_list = [a.shape[1]] + [8] * r
        mygcn = GCNClassifier(dd, df_list, n_classes)
        mygcn.train(dd, e, print_types=None)
        cost = mygcn.cost_mtr.metric_dataset(dd.train)
        NGScost[i, j, w] = cost
        print('n exp: {}, r: {}, e: {}, cost: {}'.format(i, r, e, cost))

    # Make sure the directory exists and save under the first unused index,
    # so earlier results are kept.
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    while os.path.exists(results_path.format(di)):
        di += 1
    np.save(results_path.format(di), {"GScost": NGScost})

In [None]:
# Collapse repetitions (axis 0) and epoch settings (axis 2) of the cost grid,
# presumably shaped (Nexpr, len(R), len(E)), leaving one value per depth in R.
cost_mean = np.mean(NGScost, axis=(0, 2))
cost_std = np.std(NGScost, axis=(0, 2))

In [None]:
save_fig = False

# Depth with the lowest mean cost, highlighted with a red star.
cost_mean_min = np.min(cost_mean)
cost_mean_argmin = np.argmin(cost_mean)

# NOTE(review): only the top panel of a 2-row subplot grid is used here.
plt.figure(figsize=(16, 8))
plt.subplot(2, 1, 1)
plt.plot(R, cost_mean, 'o-')
plt.plot(R[cost_mean_argmin], cost_mean_min, '*r', markersize=12, label=r'$min \; cost$')
plt.legend(fontsize=12)
plt.xlabel(r'$GCN \; Depth \; \#$', fontsize=20)
plt.ylabel(r'$Mean \; Cost$', fontsize=20)
# Shade one standard deviation around the mean cost.
plt.fill_between(R, cost_mean - cost_std,
                 cost_mean + cost_std, alpha=0.1,
                 color="b")

if save_fig:
    # savefig does not create missing directories.
    fig_dir = 'DataFigures/gcn-depth'
    os.makedirs(fig_dir, exist_ok=True)
    plt.savefig(os.path.join(fig_dir, 'max-depth-grid-train.eps'), format='eps')