In [None]:
from math import log
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances, cosine_distances
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, fcluster
import math

In [None]:
class HCACLC:
    """Agglomerative hierarchical clustering with optional user-guided merges.

    The algorithm repeatedly merges the closest pair of live clusters.  A
    merge whose confidence (see ``confidence``) falls strictly below a
    cut-off precomputed by ``threshold`` may instead be chosen by the user
    from a small pool of nearby candidates.  Only merges joining two
    non-singleton clusters are eligible, and at most ``interventions`` such
    user choices are requested.

    Clusters are nested lists: a leaf is ``[i]`` and merging ``A`` with ``B``
    yields ``[A, B]``.  A live cluster is represented in the distance matrix
    by the smaller index of each merged pair; rows and columns of absorbed
    indices are set to ``inf``.
    """

    def __init__(self, dataset, interventions: int, number_pool: int, dist_function: str = 'cosine',
                 linkage: str = "complete", M=None, set_of_constraints=None):
        """Prepare the distance matrix, the merge mask and the confidence cut-off.

        Args:
            dataset: array-like with one observation per row (``shape[0]`` is
                the number of points).
            interventions: maximum number of user-chosen merges; clamped by
                ``validate_interventions``.
            number_pool: maximum number of candidate merges offered to the user.
            dist_function: 'cosine' or 'euclidean'; anything else -> 'cosine'.
            linkage: 'single', 'complete' or 'average'; anything else -> 'complete'.
            M: optional precomputed square distance matrix.  It is copied, so
                the caller's array is never modified.
            set_of_constraints: currently unused.
        """
        self.dataset = dataset
        self.size = self.dataset.shape[0]
        self.dist_function = self.validate_distance(dist_function)
        self.interventions = self.validate_interventions(interventions)
        if M is None:
            self.distance_matrix = self.create_distance_matrix()
        else:
            # Copy as float (so np.inf can be stored and the caller's matrix is
            # untouched) and blank the diagonal like create_distance_matrix;
            # a zero self-distance would otherwise dominate every confidence.
            self.distance_matrix = np.array(M, dtype=float)
            np.fill_diagonal(self.distance_matrix, np.inf)
        self.linkage = self.validate_linkage(linkage)
        self.clusters = [[i] for i in range(self.size)]
        # True = pair may still be proposed by get_clusters.
        self.mask = np.ones_like(self.distance_matrix, dtype=bool)
        np.fill_diagonal(self.mask, False)
        self.number_pool = number_pool
        self.threshold_array = self.threshold()
        self.counter = 0
        self.current_id = [i for i in range(self.size)]

    def validate_distance(self, distance_function: str) -> str:
        """Return ``distance_function`` if supported, otherwise 'cosine'."""
        if distance_function not in ['cosine', 'euclidean']:
            return 'cosine'
        return distance_function

    def validate_linkage(self, linkage_method: str) -> str:
        """Return ``linkage_method`` if supported, otherwise 'complete'."""
        if linkage_method in ('single', 'complete', 'average'):
            return linkage_method
        return 'complete'

    def validate_interventions(self, interventions: int) -> int:
        """Clamp ``interventions`` to the range ``[0, size // 2 - 1]``.

        Interventions are only offered for merges of two non-singleton
        clusters, and a binary merge tree over ``size`` leaves contains at most
        ``size // 2 - 1`` such merges, so a larger budget can never be spent.
        """
        limit_interventions = max(self.size // 2 - 1, 0)
        return min(max(interventions, 0), limit_interventions)

    def create_distance_matrix(self):
        """Return the pairwise distance matrix of ``dataset`` with an ``inf`` diagonal."""
        matrix = squareform(pdist(self.dataset, metric=self.dist_function))
        np.fill_diagonal(matrix, np.inf)
        return matrix

    def flatten_list(self, nested_list):
        """Return the leaves of an arbitrarily nested list, left to right."""
        flat_list = []
        for item in nested_list:
            if isinstance(item, list):
                flat_list.extend(self.flatten_list(item))
            else:
                flat_list.append(item)
        return flat_list

    def flatten_and_count(self, lst):
        """Return the number of non-list leaves inside ``lst``."""
        count = 0
        stack = [lst]
        while stack:
            current = stack.pop()
            if isinstance(current, list):
                stack.extend(current)
            else:
                count += 1
        return count

    def _find_cluster(self, clusters, point):
        """Return the cluster in ``clusters`` that contains ``point``, or None."""
        return next((cluster for cluster in clusters if point in self.flatten_list(cluster)), None)

    def _update_linkage(self, matrix, a, b):
        """Fold row/column ``b`` into ``a`` in place per ``self.linkage``, then retire ``b``."""
        others = np.ones(matrix.shape[0], dtype=bool)
        others[[a, b]] = False
        if self.linkage == 'complete':
            merged = np.maximum(matrix[a], matrix[b])
        elif self.linkage == 'single':
            merged = np.minimum(matrix[a], matrix[b])
        else:  # 'average': unweighted mean of the two rows
            merged = (matrix[a] + matrix[b]) / 2.0
        matrix[a, others] = merged[others]
        matrix[others, a] = merged[others]
        # b no longer represents a cluster.
        matrix[:, b] = np.inf
        matrix[b, :] = np.inf

    def threshold(self):
        """Simulate the unassisted run and return the confidence cut-off.

        The greedy merge sequence is replayed on a copy of the distance matrix
        (with the same linkage update as ``merge_clusters``).  The confidence
        of every merge joining two non-singleton clusters is recorded.  The
        cut-off is the score at 0-based position ``interventions`` in
        ascending order, so the ``interventions`` least confident eligible
        merges of the simulated run fall strictly below it.  When there are
        not more scores than ``interventions``, ``inf`` is returned: every
        eligible merge qualifies, still limited by the intervention budget.
        """
        matrix = self.distance_matrix.copy()
        mask = np.ones_like(matrix, dtype=bool)
        np.fill_diagonal(mask, False)
        confidence_scores = []
        cluster_threshold = self.clusters.copy()
        while True:
            masked_distances = np.where(mask, matrix, np.inf)
            pair = np.unravel_index(np.argmin(masked_distances), matrix.shape)
            minimal_distance = masked_distances[pair]
            if minimal_distance == np.inf:  # nothing left to merge
                break
            a, b = int(min(pair)), int(max(pair))
            mask[a, b] = False
            mask[b, a] = False
            cluster_a = self._find_cluster(cluster_threshold, a)
            cluster_b = self._find_cluster(cluster_threshold, b)
            if cluster_a is None or cluster_b is None or cluster_a is cluster_b:
                continue
            if self.flatten_and_count(cluster_a) > 1 and self.flatten_and_count(cluster_b) > 1:
                confidence_scores.append(self.confidence((a, b), minimal_distance, matrix))
            self._update_linkage(matrix, a, b)
            cluster_threshold = [c for c in cluster_threshold if c is not cluster_a and c is not cluster_b]
            cluster_threshold.append([cluster_a, cluster_b])

        confidence_scores.sort()
        if self.interventions < len(confidence_scores):
            return confidence_scores[self.interventions]
        return np.inf

    def create_pool(self, cluster_pair):
        """Return up to ``number_pool`` candidate merges around ``cluster_pair``.

        The pool starts with ``cluster_pair`` and is extended greedily with the
        nearest remaining neighbour of either endpoint.  Pairs at infinite
        distance (absorbed clusters) are never offered.  Every pair placed in
        the pool is masked, so ``get_clusters`` will not propose it again.
        """
        a, b = int(cluster_pair[0]), int(cluster_pair[1])
        pool_clusters = [(a, b)]
        self.mask[a, b] = False
        self.mask[b, a] = False
        distances_from_a = self.distance_matrix[a, :].copy()
        distances_from_b = self.distance_matrix[b, :].copy()
        distances_from_a[b] = np.inf
        distances_from_b[a] = np.inf

        for _ in range(1, self.number_pool):
            min_index_a = int(np.argmin(distances_from_a))
            min_index_b = int(np.argmin(distances_from_b))
            if distances_from_a[min_index_a] < distances_from_b[min_index_b]:
                selected_pair = (a, min_index_a)
                distances_from_a[min_index_a] = np.inf
            elif distances_from_b[min_index_b] < np.inf:
                selected_pair = (b, min_index_b)
                distances_from_b[min_index_b] = np.inf
            else:
                break  # no finite candidates remain
            pool_clusters.append(selected_pair)
            self.mask[selected_pair[0], selected_pair[1]] = False
            self.mask[selected_pair[1], selected_pair[0]] = False
        return pool_clusters

    def confidence(self, pair: tuple, distance, matrix=None) -> float:
        """Return how clearly ``pair`` beats the alternatives of its endpoints.

        The score is the smallest distance from either endpoint to any cluster
        other than its partner, minus ``distance``.  Low values indicate an
        ambiguous merge.  ``matrix`` defaults to ``self.distance_matrix``; it
        is read only, never modified.
        """
        if matrix is None:
            matrix = self.distance_matrix
        a, b = pair
        row_a = matrix[a].copy()  # copy: a plain row index would be a view
        row_a[b] = np.inf
        row_b = matrix[b].copy()
        row_b[a] = np.inf
        return min(np.min(row_a), np.min(row_b)) - distance

    def get_clusters(self) -> tuple:
        """Choose the next pair of cluster representatives to merge.

        Returns ``(i, j)`` with ``i <= j``.  ``(0, 0)`` means that no unmasked
        finite distance remains.  When the closest pair joins two
        non-singleton clusters, its confidence is below the cut-off and the
        intervention budget is not exhausted, the user picks the merge from a
        candidate pool.
        """
        masked_distances = np.where(self.mask, self.distance_matrix, np.inf)
        minimal = np.unravel_index(np.argmin(masked_distances), self.distance_matrix.shape)
        minimal_distance = masked_distances[minimal]
        if minimal_distance == np.inf:
            return 0, 0
        a, b = int(minimal[0]), int(minimal[1])
        minimal = (a, b)
        confidence = self.confidence(minimal, minimal_distance)
        cluster_a = [cluster for cluster in self.clusters if a in self.flatten_list(cluster)]
        cluster_b = [cluster for cluster in self.clusters if b in self.flatten_list(cluster)]
        self.mask[a, b] = False
        self.mask[b, a] = False
        if (confidence < self.threshold_array and self.counter < self.interventions
                and self.flatten_and_count(cluster_a) > 1 and self.flatten_and_count(cluster_b) > 1):
            pool = self.create_pool(minimal)
            chosen_index = self.get_user_choice(pool)
            minimal = pool[chosen_index]
            self.mask[minimal[0], minimal[1]] = False
            self.mask[minimal[1], minimal[0]] = False
            self.counter += 1

        return min(minimal), max(minimal)

    def get_user_choice(self, pool: list) -> int:
        """Print the pool and prompt until the user enters a valid option index."""
        print("Available merge options:")
        for i, option in enumerate(pool):
            print(f"{i}: Merge clusters {option[0]} and {option[1]}")

        while True:
            try:
                choice = int(input("Enter the number of the desired merge option: "))
            except ValueError:
                choice = -1  # non-numeric input is handled like an out-of-range index
            if 0 <= choice < len(pool):
                return choice
            print("Invalid choice. Please try again.")

    def merge_clusters(self, pair: tuple):
        """Merge representative ``pair[1]`` into ``pair[0]`` in the live distance matrix."""
        a, b = pair
        self._update_linkage(self.distance_matrix, a, b)

    def clustering(self) -> None:
        """Merge clusters until two remain or no admissible pair is left."""
        count = len(self.clusters)
        while count > 2:
            index = self.get_clusters()
            if index[0] == 0 and index[1] == 0:
                break
            cluster_a = self._find_cluster(self.clusters, index[0])
            cluster_b = self._find_cluster(self.clusters, index[1])

            if not cluster_a or not cluster_b or cluster_a is cluster_b:
                continue
            merged_cluster = [cluster_a, cluster_b]
            self.clusters = [c for c in self.clusters if c is not cluster_a and c is not cluster_b]
            self.clusters.append(merged_cluster)

            self.merge_clusters(index)

            self.current_id = [c[0] for c in self.clusters]
            count = len(self.clusters)

    def get_output(self):
        """Return ``(clusters, current_id)``."""
        return self.clusters, self.current_id
