In [None]:
import time
from pathlib import Path

import matplotlib.pyplot as plt
import networkit as nk
import numpy as np
import phate
from sklearn.preprocessing import StandardScaler
from tqdm.notebook import tqdm

from External.WS_Tree_dist import EMD_dist_tree_discretized, EMD_dist_tree_edge, plot_flow_discretized
from External.ICT.calculate_ICT import calculate_ICT, compute_widths
from External.clustering import k_means_pp
from External.generation import create_graph

plt.style.use('standard.mplstyle')

# Hyperparameters
mode = "Full+Exp-Triangle"   # mode for the graph construction (passed to create_graph)
gamma = 5                    # gamma for the triangle break (passed to create_graph)

metric = "euclidean"         # metric for the k-means clustering (passed to k_means_pp)
ε, δ = 0.03, 0.1             # values used to estimate the number of clusters k for k-means

# Output settings
# Title of the ICT panel in the final figure; must be a string (an empty right-hand
# side here is a SyntaxError that stops the whole notebook).
title_of_ICT = f"ICT ({mode}, gamma={gamma})"
# File stem of the saved figure: written to ./Output/Images/<name_of_image>.png.
# No trailing "/", otherwise the figure is saved as a hidden file named ".png".
name_of_image = "triangle_break"

In [None]:
# Load the data: synthetic tree-shaped point cloud from PHATE's `gen_dla` generator
# (2-D, 3 branches). `labels` is used below as the ground-truth grouping of the points
# (presumably the branch each point belongs to — confirm against the PHATE docs).
position, labels = phate.tree.gen_dla(
    n_dim=2,
    n_branch=3,
    branch_length=2500,
    rand_multiplier=2,
    seed=37,
    sigma=1,
)
labels_unique = np.unique(labels)

# Standardise every coordinate (zero mean, unit variance) before building the graph.
position = StandardScaler().fit_transform(position)
number_of_nodes = position.shape[0]

In [None]:
# 2-D coordinates used for every plot in this notebook: the standardised positions
# from the loading cell. NOTE(review): `position` is rebound by `create_graph` in the
# graph cell; `embedding` keeps pointing at the array assigned here.
embedding = position

In [None]:
# Overview of the raw data: one trace per ground-truth label.
fig, ax = plt.subplots(1, 1, figsize=(24, 24))
for label in labels_unique:
    label_rows = np.flatnonzero(labels == label)   # indices of points carrying this label
    ax.plot(*embedding[label_rows].T, label=label)
ax.axis("equal")
ax.legend()
plt.show()

In [None]:
# Create the graph from the standardised positions. `create_graph` also returns a
# position array, which replaces `position` from here on; the elapsed wall time is printed.
start = time.time()
G, position = create_graph(number_of_nodes, mode, position=position, gamma=gamma)
G.indexEdges()   # assign networkit edge ids on the new graph
elapsed = time.time() - start
print(elapsed)

In [None]:
start = time.time()
# calculate the clusters:
#   r = (1 / ε²) · (floor(log2(n - 1)) + 1 + ln(1 / δ)),  k = floor(sqrt(r)),
# with k capped at the number of graph nodes; centres are then picked by k_means_pp.
log_nodes_term = int(np.log2(number_of_nodes - 1)) + 1
r = 1 / (ε**2) * (log_nodes_term + np.log(1/δ))
k = np.min((int(np.sqrt(r)), G.numberOfNodes()))
cluster_centers = k_means_pp(k, position, metric=metric, G=G)
print(time.time() - start)

In [None]:
# calculate the cluster ICT with all aim nodes ("cluster_all" mode), index the edges
# of the result, then derive the per-edge widths used when drawing it.
start = time.time()
ICT = calculate_ICT(
    G,
    algorithm_type="cluster_all",
    cluster_centers=cluster_centers,
    zeros_stay_zeros=True,
    update_G=1.1,
)
ICT.indexEdges()
widths = compute_widths(ICT)
print(time.time() - start)

In [None]:
# plotting: the ICT with its cluster centres (left) next to the ground-truth labels (right)

names = [title_of_ICT, "Ground truth"]
number_of_plots = len(names)

# Create the figure: one 24x24-inch panel per entry in `names`
fig, ax = plt.subplots(1, number_of_plots, figsize=(24 * number_of_plots, 24))

# Plot the ICT (edge widths from compute_widths) and the cluster centres.
# NOTE(review): assumes ICT node ids index rows of `embedding` — confirm create_graph
# keeps the node order of the loaded data.
nk.viztasks.drawGraph(ICT, pos=embedding, ax=ax[0], width=widths, node_size=10)
# linestyle="none": draw the centres as separate markers; the default solid line
# would join them in index order.
ax[0].plot(*embedding[cluster_centers].T, marker="o", linestyle="none", color="Red")

# Plot the ground truth, one trace per label, with a legend for the labels
for label in labels_unique:
    ax[1].plot(*embedding[np.argwhere(labels == label).T[0]].T, label=label)
ax[1].legend()

# General stuff shared by all panels
for axis, name in zip(ax, names):
    axis.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
    axis.set_axis_on()
    axis.set_title(name)
    axis.axis("equal")

plt.tight_layout()

# Save to ./Output/Images/<name_of_image>.png. Trailing "/" are stripped so the result
# is a real file name rather than a hidden ".png", and the folder is created if missing.
image_path = Path("./Output/Images") / f"{name_of_image.rstrip('/')}.png"
image_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(image_path)
plt.show()