In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import tensorflow as tf
from sklearn.decomposition import PCA

from algomorphism.figures.nn import multiple_models_history_figure
from algomorphism.methods.graphs import a2g
from algomorphism.models import GAE

from datasets import GaeDataset

In [None]:
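# Optional, GPU only: uncomment to enable memory growth, so TensorFlow allocates GPU memory on demand instead of reserving it all up front.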

# for gpu in tf.config.list_physical_devices('GPU'):
#     print(gpu)
#     tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
# Build the GAE dataset; its `train` split is iterated in the next cell and the
# object itself is passed to GAE(...) and gae.train(...) further down.
gaed = GaeDataset()

In [None]:
# Keep the last (x, atld, a) triple yielded by the training split; later cells
# read A, X and Atld from it. Fail loudly here if the split is empty, instead of
# surfacing as a NameError several cells later.
last_batch = None
for last_batch in gaed.train:
    pass
if last_batch is None:
    raise ValueError("gaed.train yielded no batches")
X, Atld, A = last_batch

In [None]:
# Weighting terms passed to GAE, presumably to balance edge vs. non-edge entries
# in the reconstruction loss (TODO confirm against GAE's loss):
#   w_p  = #non-edges / #edges
#   norm = #entries / (2 * #non-edges)
# Both terms are counted over every entry of A, so they stay mutually consistent
# for any batch size. (Previously w_p used batch*n*n while norm used n*n.)
n_entries = float(tf.size(A))
edge_sum = tf.reduce_sum(A)
n_edges = float(edge_sum)
assert 0 < n_edges < n_entries, "A must contain at least one edge and one non-edge"
n_non_edges = n_entries - n_edges

w_p = n_non_edges / edge_sum  # scalar tensor, same type as before
norm = n_entries / (n_non_edges * 2)

# Layer widths; the first entry is A.shape[1], presumably the node count used as
# the input dimension for featureless X — TODO confirm it should not be X.shape[-1].
df_list = [A.shape[1], 32, 64]

gae = GAE(gaed, df_list, w_p, norm=norm, optimizer="Adam", learning_rate=1e-4, ip_weights=True)

In [None]:
# Training hyper-parameters. NOTE(review): set_lr_rate is called after the GAE was
# constructed with learning_rate=1e-4, so this value presumably overrides it.
LEARNING_RATE = 1e-2
CLIP_NORM = 0.0   # presumably disables gradient clipping — TODO confirm in GAE
EPOCHS = 200      # second positional argument of gae.train; presumably epochs

gae.set_lr_rate(LEARNING_RATE)
gae.set_clip_norm(CLIP_NORM)
gae.train(gaed, EPOCHS)

In [None]:
# Reconstructed adjacency: squash the model's first output through a sigmoid
# (presumably decoder logits) and take index 0 of the leading axis
# (presumably the first graph of the batch).
recon_out = gae([X, Atld])[0]
ahat = tf.nn.sigmoid(recon_out).numpy()[0]
plt.imshow(ahat, vmin=0, vmax=1)
ahat

In [None]:
# Binarise the edge probabilities at 0.5, convert to a networkx graph and draw it.
adj_pred = (ahat > 0.5).astype(int)
g = a2g(adj_pred)
nx.draw(g, with_labels=True, font_size=20, node_size=600)

In [None]:
# Only plot training curves when at least one entry in gae.history is non-empty.
has_history = any(gae.history.values())
if has_history:
    multiple_models_history_figure([gae])

In [None]:
# idx_save is the suffix of the saved embedding file name ("gae-node-embs<idx_save>");
# idx_load is not used by any visible cell.
idx_save, idx_load = "64", ""

In [None]:
# Encode the nodes and key each embedding row by its index as a string.
# [0] drops the leading axis (presumably batch) — TODO confirm encoder output shape.
emb = gae.encoder(X, Atld)[0]
# Derive labels from the number of embedding rows instead of hard-coding six,
# so zip() cannot silently drop nodes when the graph size changes.
Labels = [str(i) for i in range(emb.shape[0])]
emb_dict = dict(zip(Labels, emb.numpy()))
emb.shape

In [None]:
# Save the node-embedding dict to ../data/gae/gae-node-embs<idx_save>.npy
# (np.save appends the .npy suffix). A dict is stored as a pickled 0-d object
# array, so reload it with np.load(path, allow_pickle=True).item() — and only
# load such files from sources you trust.
emb_dir = Path("../data/gae")
emb_dir.mkdir(parents=True, exist_ok=True)  # np.save does not create missing directories
np.save(emb_dir / "gae-node-embs{}".format(idx_save), emb_dict)

In [None]:
# Project the node embeddings onto their first two principal components for plotting.
N_PCA_COMPONENTS = 2
pca = PCA(n_components=N_PCA_COMPONENTS)
pca_emb = pca.fit_transform(emb)

In [None]:
# Scatter the 2-D PCA projection, annotating each point with its node label.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(pca_emb[:, 0], pca_emb[:, 1], 'o', label='node embeddings', markersize=15)
for point, label in zip(pca_emb, Labels):
    ax.text(point[0], point[1], label, fontsize=20)
ax.set(title='GAE node embeddings (2-D PCA)', xlabel='PC 1', ylabel='PC 2')
ax.legend()
# To export: fig.savefig('DataFigures/gae/node-embs.eps', format='eps')
plt.show()

In [None]:
################################################################
# from datasets import ZeroShotDataset
#
# embs_id = '64'
# embsD = 64
#
# dz = ZeroShotDataset(embs_id)
# labels = ['0','1','2']

In [None]:
# seen_labels = ['1', '2']
# unseen_labels = ['0']
#
# r_embs_dict_pr = dz.r_emb_disct_preprosesing(seen_labels, unseen_labels)
# r_emb_dict = {}
# for k in seen_labels:
#     r_emb_dict[k] = emb_dict[k]
#
# for k in unseen_labels:
#     r_emb_dict[k] = emb_dict[k]
#
# r_embs_pr = np.array([v for v in r_embs_dict_pr.values()])
# r_embs = np.array([v for v in r_emb_dict.values()])
#
# d0 = 4
# d1 = 3
# plt.plot(r_embs[:, d0], r_embs[:, d1], 'o')
# plt.plot(r_embs_pr[:, d0], r_embs_pr[:, d1], '+')
#
#
# tsne = PCA(n_components=2)
# tsne_emb = tsne.fit_transform(r_embs_pr)
#
# plt.figure(figsize=(10,5))
# plt.plot(tsne_emb[:,0], tsne_emb[:,1],'o', label='node embeddings', markersize=15)
#
# for (v,l) in zip(tsne_emb, r_emb_dict.keys()):
#     plt.text(v[0],v[1],l,fontsize=20)
# plt.legend()


