In [None]:
# Load IPython's autoreload extension so edited project modules can be
# re-imported without restarting the kernel (mode is set in the next cell).
%load_ext autoreload

In [None]:
# Mode 2: reload all modules (except those excluded via %aimport) before
# executing each cell, so changes to the project's .py files are picked up.
%autoreload 2

In [None]:
import os
import sys

# Make the notebook's parent directory importable; the project-local packages
# used by this notebook (e.g. `cluster_filter`, `External`) presumably live there.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import networkit as nk
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import random

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import pairwise_distances
from IPython.display import display, Math, Latex, Markdown
from tqdm.notebook import tqdm

from cluster_filter import cfilter, cassign

from External.ICT.calculate_ICT import calculate_ICT, calculate_sub_ICTs
from External.clustering import centers, k_means_pp
from External.generation import create_graph
from External.plotting import plot_points, plot_graph, no_intersections
from External.create_k_nearest import patch_together

# Apply the project's matplotlib style sheet. The path is relative, so it is
# resolved against the kernel's current working directory.
plt.style.use('standard.mplstyle')

In [None]:
# ---- Hyperparameters --------------------------------------------------------

# ICT construction
mode = "Full+Exp-Triangle"        # passed as `mode` to calculate_sub_ICTs
ICT_algorithm = "cluster_all"     # passed as `algorithm_type` to calculate_ICT
metric = "euclidean"              # distance metric name
dataset = "non_convex"            # dataset label (NOTE(review): not used in the visible cells)

# Clusters with fewer than `min_cluster_size` points are handled per
# `small_behavior`: "remove" filters them out via cfilter (shrinking the point
# set), "reassign" calls cassign (presumably relabels their points).
min_cluster_size = 12
small_behavior = "reassign"

# Image sampling
image_name = "image"              # read from f"{image_name}.png"
number_of_nodes = 1000
n = number_of_nodes               # number of foreground pixels sampled
Random = True                     # False -> sampling is seeded with 42

# Output naming
parameters = "2-0_5-1000"         # tag meant for output filenames (NOTE(review): not used in the visible cells)

In [None]:
def load_image(filename, threshold=0):
    """Load ``<filename>.png`` and return it as a boolean foreground mask.

    Colour images are reduced to luminance with the Rec. 601 luma weights
    (0.2989 R + 0.5870 G + 0.1140 B); any alpha channel is ignored.
    Single-channel (grayscale) PNGs are used as-is.

    Parameters
    ----------
    filename : str
        Path to the image *without* the ``.png`` extension.
    threshold : float, default 0
        Pixels whose luminance is strictly greater than this value are
        foreground. The default keeps the previous behaviour: anything that
        is not pure black counts.

    Returns
    -------
    numpy.ndarray of bool, shape (H, W)
        True where the pixel is foreground.
    """
    pixels = plt.imread(filename + '.png')
    if pixels.ndim == 2:
        # Grayscale PNGs come back as an (H, W) array. Slicing [..., :3] here
        # would take the first three *columns*, not colour channels.
        luminance = pixels
    else:
        # (H, W, 3) or (H, W, 4): weight R, G, B and drop alpha if present.
        rgb_weights = [0.2989, 0.5870, 0.1140]
        luminance = np.dot(pixels[..., :3], rgb_weights)
    return luminance > threshold
    
    
def sample_points_from_image(n, img, Random=True):
    """Draw ``n`` distinct non-zero pixel coordinates from ``img``.

    Parameters
    ----------
    n : int
        Number of pixels to sample, without replacement.
    img : array-like
        Image or mask. Every entry that is ``!= 0`` is a candidate pixel.
    Random : bool, default True
        If False, the *global* ``random`` generator is seeded with 42 first,
        so the sample is reproducible. This also affects any later code that
        draws from ``random``.

    Returns
    -------
    (x_coord, y_coord) : tuple of numpy.ndarray
        Indices along axis 0 and axis 1 of the sampled pixels. NOTE: for a 2-D
        image, ``x_coord`` holds *row* indices (the vertical axis) and
        ``y_coord`` holds column indices. A scatter of (x_coord, y_coord) is
        therefore the image transposed. The names are kept for compatibility.

    Raises
    ------
    ValueError
        If ``n`` is negative or larger than the number of non-zero pixels.
    """
    if not Random:
        random.seed(42)
    non_zero = np.nonzero(np.asarray(img) != 0)
    n_candidates = len(non_zero[0])
    if not 0 <= n <= n_candidates:
        raise ValueError(
            f"cannot sample {n} points: image has only {n_candidates} non-zero pixels"
        )

    # Same call as before, so seeded runs reproduce the same sample.
    idx = random.sample(range(n_candidates), n)

    x_coord = non_zero[0][idx]
    y_coord = non_zero[1][idx]
    return x_coord, y_coord

In [None]:
# Build the point cloud: sample `n` foreground pixels from the image, stack
# them into an (n, 2) array, and standardise each column (zero mean, unit variance).
img = load_image(image_name)
position = StandardScaler().fit_transform(
    np.column_stack(sample_points_from_image(n, img, Random=Random))
)

In [None]:
# Sweep the number of k-means++ clusters k. For each k:
#   1. cluster the points and handle clusters smaller than `min_cluster_size`;
#   2. build one sub-ICT per component and assemble them into a forest;
#   3. patch the forest into one graph and compute the final ICT on it;
#   4. save a two-panel figure to Output/triangle2/<k zero-padded to 5>.png.
output_dir = "Output/triangle2"
os.makedirs(output_dir, exist_ok=True)  # savefig raises if the folder is missing

# The full point set does not change between iterations, so compute its
# pairwise distances once. They are recomputed only when "remove" drops points.
# NOTE(review): other helpers (k_means_pp, patch_together) may use their own metric.
full_distances = pairwise_distances(position, position, metric=metric)

for k in tqdm(range(1, 61), desc="k"):
    # Use a per-iteration name so "remove" does not shrink the global
    # `position` cumulatively across k (or across re-runs of this cell).
    points = position
    distances = full_distances

    cluster_centers, cluster_labels = k_means_pp(k, points, return_labels=True)

    if small_behavior == "remove":
        cluster_centers, cluster_labels, (points, ) = cfilter(
            cluster_centers, cluster_labels, t=min_cluster_size, position_likes=[points])
        number_of_nodes = len(points)
        distances = pairwise_distances(points, points, metric=metric)
    elif small_behavior == "reassign":
        cluster_centers, cluster_labels = cassign(
            cluster_centers, cluster_labels, points, t=min_cluster_size)

    sub_ICTs, components = calculate_sub_ICTs(
        points, cluster_centers, cluster_labels, t=min_cluster_size, mode=mode)

    # Assemble the sub-ICTs into a single forest over all points. `component`
    # maps a sub-ICT's local node ids to indices into `points`. Each edge is
    # re-weighted with the actual distance; the sub-ICT weight is ignored.
    ICT_forest = nk.graph.Graph(n=len(points), weighted=True)
    for component, sub_ICT in zip(components, sub_ICTs):
        for u, v, _w in sub_ICT.iterEdgesWeights():
            node_a, node_b = component[u], component[v]
            ICT_forest.addEdge(node_a, node_b, distances[node_a, node_b])
    ICT_forest.indexEdges()

    good_edges = [[u, v] for u, v in ICT_forest.iterEdges()]

    G = patch_together(ICT_forest, points, bridges=4)
    ICT = calculate_ICT(G, algorithm_type=ICT_algorithm, cluster_centers=cluster_centers,
                        zeros_stay_zeros=True, update_G=1.1, good_edges=good_edges)
    ICT.indexEdges()

    # Left: clustered nodes plus ICT edges. Right: ICT edges only.
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    plot_points(points, f"ICT with nodes ({k} clusters)", axs[0],
                labels=np.array(cluster_labels), node_size=5)
    legend = axs[0].get_legend()
    if legend is not None:  # get_legend() returns None when no legend was drawn
        legend.remove()
    plot_graph(ICT, points, f"ICT with nodes ({k} clusters)", axs[0], node_size=0, edge_scale=0.5)
    plot_graph(ICT, points, f"ICT without nodes ({k} clusters)", axs[1], node_size=0, edge_scale=0.5)
    name = str(k)
    fig.tight_layout()
    fig.savefig(f"{output_dir}/{name.zfill(5)}.png")
    plt.close(fig)  # free the figure; 60 open figures would accumulate otherwise