In [1]:
!pip install gudhi
!pip install hdbscan
!pip install genieclust
from scipy.spatial import distance
import plotly.graph_objects as go
import networkx as nx
import matplotlib.pyplot as plt
from pprint import pprint
from numpy.linalg import matrix_rank
import numpy as np
import gudhi as gd
import matplotlib
import scipy.spatial as spatial
import bisect
import matplotlib
from scipy.cluster import hierarchy
from urllib.request import urlopen
from io import BytesIO
from gudhi import AlphaComplex
from scipy.spatial.distance import cdist
from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import ward, fcluster
from sklearn.cluster import KMeans
from sklearn.neighbors import KernelDensity
from sklearn.metrics.cluster import adjusted_rand_score
from scipy.cluster.hierarchy import dendrogram
import time
from scipy.spatial import KDTree
from sklearn.datasets import make_blobs
from sklearn.metrics import calinski_harabasz_score
import pandas as pd
import random
from sklearn.cluster import DBSCAN
import hdbscan
from gudhi.clustering.tomato import Tomato
import genieclust
from scipy.spatial.distance import pdist
from sklearn import metrics
from sklearn.mixture import GaussianMixture



In [2]:
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.neighbors import KDTree, NearestNeighbors
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import PolynomialFeatures
from scipy.cluster.hierarchy import dendrogram
import matplotlib.pyplot as plt

class FiltrationBuilder:
    """Turn a point cloud into a one-parameter vertex/edge filtration.

    Each ``get_filtration_from_*`` method derives an ``(n, n)`` value matrix
    from the points and hands it to :meth:`get_filtration`, which emits one
    entry ``([i], value)`` per point and one entry ``([i, j], value)`` per
    pair ``i < j``, sorted by value with ties broken by
    :meth:`transform_filtration`.
    """

    # Step added to a value that ties with its predecessor.
    _TIE_STEP = 1e-15

    def __init__(self, data):
        # NOTE(review): the methods below index data as a 2-D array
        # (curvature reads columns 0 and 1); presumably (n_points, n_features).
        self.data = data

    def get_filtration(self, dist):
        """Build a sorted, tie-broken filtration from a square value matrix.

        Args:
            dist: ``(n, n)`` nested sequence or array indexed as
                ``dist[i][j]`` with ``n == len(self.data)``. Only the diagonal
                and the upper triangle are read.

        Returns:
            List of ``(simplex, value)``: ``([i], dist[i][i])`` for every
            vertex and ``([i, j], dist[i][j])`` for every ``i < j``, ordered by
            value then lexicographically by simplex, with equal values nudged
            apart by :meth:`transform_filtration`.
        """
        n_points = len(self.data)
        filtration = []
        for i in range(n_points):
            filtration.append(([i], dist[i][i]))
            filtration.extend(([i, j], dist[i][j]) for j in range(i + 1, n_points))
        filtration.sort(key=lambda entry: (entry[1], entry[0]))
        return self.transform_filtration(filtration)

    def transform_filtration(self, filtration):
        """Make the values of an already-sorted filtration strictly increasing.

        An entry whose value equals the previous *input* value is given the
        previous *output* value plus ``1e-15``; every other entry is kept
        unchanged. From roughly 16 upwards the float spacing exceeds
        ``1e-15`` and that addition would be a no-op, so the next
        representable float is used instead.

        Args:
            filtration: list of ``(simplex, value)`` sorted by value.

        Returns:
            A new list of the same length. Empty and single-entry inputs are
            returned as (shallow) copies instead of raising.
        """
        if not filtration:
            return []
        result = [filtration[0]]
        for i in range(1, len(filtration)):
            item = filtration[i]
            if item[1] != filtration[i - 1][1]:
                result.append(item)
                continue
            previous = result[-1][1]
            bumped = previous + self._TIE_STEP
            if not bumped > previous:
                # 1e-15 was absorbed by rounding at this magnitude.
                bumped = np.nextafter(previous, np.inf)
            # NOTE(review): a very long run of ties could in principle push
            # past the next distinct value; not guarded here.
            result.append((item[0], bumped))
        return result

    @staticmethod
    def _pairwise_matrix(values, combine, diagonal=0):
        """Return an ``(n, n)`` array with ``combine(values[i], values[j])`` off the diagonal.

        ``combine`` must be a vectorised binary function; it is applied to a
        column view and a row view of *values* via broadcasting. The diagonal
        is then overwritten with *diagonal* (a scalar or a length-n sequence).
        """
        column = np.asarray(values)
        matrix = np.array(combine(column[:, None], column[None, :]))
        np.fill_diagonal(matrix, diagonal)
        return matrix

    def get_filtration_from_scipy_dist(self, dist_type):
        """Filtration by a ``scipy.spatial.distance.cdist`` metric (e.g. ``'euclidean'``)."""
        dist = cdist(self.data, self.data, dist_type)
        return self.get_filtration(dist)

    def get_filtration_from_density(self, kernel_type, bandwidth):
        """Filtration by absolute difference of kernel-density estimates.

        Vertices enter at 0; edge ``(i, j)`` enters at ``|p_i - p_j|`` where
        ``p`` is the KDE (``kernel_type``, ``bandwidth``) evaluated at each point.
        """
        kde = KernelDensity(kernel=kernel_type, bandwidth=bandwidth).fit(self.data)
        point_density = np.exp(kde.score_samples(self.data))
        dist = self._pairwise_matrix(point_density, lambda a, b: np.abs(a - b))
        return self.get_filtration(dist)

    def get_not_normalized_filtration_from_density(self, r):
        """Filtration by raw neighbour counts within radius *r*.

        Edge ``(i, j)`` enters at ``max + second_max - (c_i + c_j)`` where
        ``c`` counts points (including the point itself) within distance
        ``r``, so denser pairs enter earlier. Vertices enter at 0.
        """
        # The module-level name KDTree is sklearn's (imported after scipy's),
        # which has no query_ball_point; use scipy's explicitly.
        from scipy.spatial import KDTree as ScipyKDTree

        tree = ScipyKDTree(self.data)
        neighbourhoods = tree.query_ball_point(self.data, r)
        point_density = np.array([len(neighbours) for neighbours in neighbourhoods])
        maximum = np.max(point_density)
        second_maximum = np.sort(point_density)[-2] if len(point_density) >= 2 else maximum
        dist = self._pairwise_matrix(
            point_density, lambda a, b: second_maximum + maximum - (a + b)
        )
        return self.get_filtration(dist)

    def get_filtration_from_curvature(self):
        """Filtration by summed discrete curvature of consecutive points.

        Curvature is estimated with ``np.gradient`` along the row order of the
        first two columns. NOTE(review): this assumes the points are ordered
        along a curve; repeated consecutive points give 0/0 = nan — confirm.
        """
        points = np.asarray(self.data)
        x, y = points[:, 0], points[:, 1]
        dx, dy = np.gradient(x), np.gradient(y)
        d2x, d2y = np.gradient(dx), np.gradient(dy)
        curvatures = np.abs(d2x * dy - dx * d2y) / (dx**2 + dy**2)**1.5
        dist = self._pairwise_matrix(curvatures, lambda a, b: a + b)
        return self.get_filtration(dist)

    def get_filtration_from_knn(self, k):
        """Filtration by k-NN distance: pairs of points in tight neighbourhoods enter first.

        NOTE(review): kneighbors() is called with the training data passed
        explicitly, in which case sklearn returns each point as its own first
        neighbour — so the value used is the distance to the (k-1)-th other
        point. Kept as-is for compatibility.
        """
        nbrs = NearestNeighbors(n_neighbors=k).fit(self.data)
        distances, _ = nbrs.kneighbors(self.data)
        knn_distances = distances[:, -1]
        maximum = np.max(knn_distances)
        second_maximum = np.sort(knn_distances)[-2] if len(knn_distances) >= 2 else maximum
        dist = self._pairwise_matrix(
            knn_distances, lambda a, b: second_maximum + maximum - (a + b)
        )
        return self.get_filtration(dist)

    def get_filtration_from_christoffel(self, degree):
        """Filtration by the empirical Christoffel density of polynomial features.

        Each point is mapped to its monomials up to *degree* (``v_i``); with
        ``M = (1/N) * sum_i v_i v_i^T`` the density is ``s / (v_i^T M^+ v_i)``
        where ``s`` is the feature count. Edge ``(i, j)`` enters at
        ``max + second_max - (d_i + d_j)``; vertices at 0.
        """
        n_points = len(self.data)
        poly = PolynomialFeatures(degree=degree, include_bias=True)
        V = poly.fit_transform(self.data)
        s = V.shape[1]
        # Sum of outer products v_i v_i^T, computed as a single matrix product.
        M_d = (1.0 / n_points) * (V.T @ V)
        M_d_inv = np.linalg.pinv(M_d)
        Q = np.sum(V * (V @ M_d_inv), axis=1)
        christoffel_density = s * (1.0 / Q)
        max_density = np.max(christoffel_density)
        second_max = np.sort(christoffel_density)[-2] if n_points >= 2 else max_density
        dist = self._pairwise_matrix(
            christoffel_density, lambda a, b: second_max + max_density - (a + b)
        )
        return self.get_filtration(dist)

    def get_filtration_from_intrinsic_dim(self, k_neighbors=10, var_threshold=0.95):
        """Filtration by local PCA intrinsic dimension.

        For every point, PCA on its ``k_neighbors`` nearest other points gives
        the number of components needed to explain ``var_threshold`` of the
        variance. Vertex ``i`` enters at that dimension; edge ``(i, j)`` at the
        absolute difference of the two dimensions.
        NOTE(review): edges can therefore enter before their vertices (e.g.
        two points of equal dimension give an edge at 0) — confirm intended.
        """
        data = np.asarray(self.data)
        nbrs = NearestNeighbors(n_neighbors=k_neighbors + 1).fit(data)
        _, indices = nbrs.kneighbors(data)
        intrinsic_dims = []
        for idx_list in indices:
            # Skip column 0: it is the query point itself.
            local_pca = PCA().fit(data[idx_list[1:]])
            cum_var = np.cumsum(local_pca.explained_variance_ratio_)
            intrinsic_dims.append(np.searchsorted(cum_var, var_threshold) + 1)
        intrinsic_dims = np.array(intrinsic_dims)
        dist = self._pairwise_matrix(
            intrinsic_dims, lambda a, b: np.abs(a - b), diagonal=intrinsic_dims
        )
        return self.get_filtration(dist)

class BifiltrationProcessor:
    """Combine two filtrations over the same simplices into a bifiltration.

    ``filtration1`` and ``filtration2`` are lists of ``(simplex, value)``
    pairs, each sorted by value; every simplex of ``filtration1`` is expected
    to occur in ``filtration2`` as well.
    """

    def __init__(self, filtration1, filtration2):
        self.filtration1 = filtration1
        self.filtration2 = filtration2

    def critical_points(self):
        """Mark the bifiltration grid cells where simplices enter.

        Returns:
            ``(grid, filtration1_dict_rev, filtration2_dict_rev,
            filtration1_dict, filtration2_dict)``. ``grid[i][j]`` is ``0``
            when the i-th simplex of filtration1 is the j-th simplex of
            filtration2, ``1`` when cells ``(i-1, j)`` and ``(i, j-1)`` are
            both ``0``, and ``-1`` otherwise. The ``*_dict`` maps send
            ``tuple(simplex) -> (value, index)``; the ``*_dict_rev`` maps send
            ``index -> simplex``.

        Raises:
            KeyError: if a simplex of filtration1 is missing from filtration2.

        Note:
            The grid has ``len(filtration1) * len(filtration2)`` cells, i.e. it
            is quadratic in the number of simplices.
        """
        filtration1_dict, filtration2_dict = dict(), dict()
        filtration1_dict_rev, filtration2_dict_rev = dict(), dict()
        for i in range(len(self.filtration1)):
            filtration1_dict[tuple(self.filtration1[i][0])] = (self.filtration1[i][1], i)
            filtration1_dict_rev[i] = self.filtration1[i][0]
        for i in range(len(self.filtration2)):
            filtration2_dict[tuple(self.filtration2[i][0])] = (self.filtration2[i][1], i)
            filtration2_dict_rev[i] = self.filtration2[i][0]

        # Bifiltration grid: mark the cell where each simplex appears in both filtrations.
        grid = [[-1 for i in range(len(self.filtration2))] for j in range(len(self.filtration1))]
        for item in filtration1_dict.items():
            grid[item[1][1]][filtration2_dict[item[0]][1]] = 0

        # Secondary critical cells: both the left and lower neighbours are primary.
        for i in range(len(grid)):
            for j in range(len(grid[0])):
                if (i != 0) and (j != 0) and grid[i - 1][j] == 0 and grid[i][j - 1] == 0:
                    grid[i][j] = 1
        return grid, filtration1_dict_rev, filtration2_dict_rev, filtration1_dict, filtration2_dict

    def get_slice(self, f, f_invers):
        """Slice the bifiltration along the curve ``y = f(x)`` using the full grid.

        Every critical cell (marked 0 or 1 by :meth:`critical_points`) is
        projected onto the curve: horizontally if it lies on/above the curve,
        vertically otherwise. Its position along the slice is ``x + y`` of the
        projected point.

        Args:
            f: the curve, applied to filtration1 values.
            f_invers: its inverse, applied to filtration2 values.

        Returns:
            List of ``(simplex, value)`` sorted by value. Each simplex appears
            once, at its earliest projected value (ties broken by distance to
            the curve); distinct simplices sharing a value are all kept.
        """
        grid, filtration1_dict_rev, filtration2_dict_rev, _, _ = self.critical_points()

        # (simplex, projected_value, distance_to_curve)
        prepared_filtration = []
        for x_grid in range(len(self.filtration1)):
            for y_grid in range(len(self.filtration2)):
                cell = grid[x_grid][y_grid]
                if cell == -1:
                    continue  # not a critical cell
                x_grid_val = self.filtration1[x_grid][1]
                y_grid_val = self.filtration2[y_grid][1]
                y_line = f(x_grid_val)
                x_line = f_invers(y_grid_val)
                if y_line <= y_grid_val:
                    # On/above the curve (left of it for increasing f): project horizontally.
                    simplex = filtration1_dict_rev[x_grid if cell == 0 else x_grid - 1]
                    prepared_filtration.append((simplex, y_grid_val + x_line, x_line - x_grid_val))
                else:
                    # Below the curve: project vertically.
                    simplex = filtration2_dict_rev[y_grid if cell == 0 else y_grid - 1]
                    prepared_filtration.append((simplex, y_line + x_grid_val, y_line - y_grid_val))

        # Keep only the earliest appearance of each simplex. De-duplicating by
        # value instead would silently drop distinct simplices with equal values.
        prepared_filtration.sort(key=lambda entry: (entry[1], entry[2]))
        filtration = []
        seen = set()
        for simplex, value, _ in prepared_filtration:
            key = tuple(simplex)
            if key not in seen:
                seen.add(key)
                filtration.append((simplex, value))
        return filtration

    def get_slice_optimized(self, f, f_inverse):
        """Slice the bifiltration along ``y = f(x)`` using only primary critical points.

        Each simplex's point ``(value1, value2)`` is projected onto the curve
        (horizontally if on/above it, vertically otherwise) and placed at the
        ``x + y`` of the projected point.

        Returns:
            List of ``(simplex_tuple, projected_value)`` sorted by value, ties
            broken by distance to the curve. Every simplex present in both
            filtrations appears exactly once, including simplices whose
            projected values coincide (previously such ties were dropped).
        """
        f1_dict = {tuple(simplex): value for simplex, value in self.filtration1}
        f2_dict = {tuple(simplex): value for simplex, value in self.filtration2}

        prepared_filtration = []
        for simplex, value1 in f1_dict.items():
            if simplex not in f2_dict:
                continue
            value2 = f2_dict[simplex]
            x_line = f_inverse(value2)
            y_line = f(value1)
            if y_line <= value2:
                projected_value = value2 + x_line
                dist = x_line - value1
            else:
                projected_value = y_line + value1
                dist = y_line - value2
            prepared_filtration.append((simplex, projected_value, dist))

        # Each simplex occurs once already (dict keys), so no de-duplication is needed.
        prepared_filtration.sort(key=lambda entry: (entry[1], entry[2]))
        return [(simplex, value) for simplex, value, _ in prepared_filtration]

    @staticmethod
    def _prefix_sets(filtration):
        """Return, for every index k, the set of simplices among the first k+1 entries."""
        prefixes = []
        current = set()
        for simplex, _ in filtration:
            current.add(tuple(simplex))
            prefixes.append(current.copy())
        return prefixes

    def get_path(self):
        """Walk the grid on a staircase (x, y, x, y, ...) and record new simplices.

        At each step the simplices present in both prefixes that were not
        present at the previous step are emitted (the first one only) with
        value ``filtration1[x] + filtration2[y]``.

        Returns:
            List of ``(simplex_list, value)`` in walk order.
        """
        set_list1 = self._prefix_sets(self.filtration1)
        set_list2 = self._prefix_sets(self.filtration2)

        prepared_filtration = []
        step_x = True
        x_grid = 0
        y_grid = 0
        curr_set = set()
        while x_grid < len(self.filtration1) and y_grid < len(self.filtration2):
            inter = set_list1[x_grid].intersection(set_list2[y_grid])
            simplex = inter.difference(curr_set)
            if simplex:
                prepared_filtration.append((list(list(simplex)[0]), self.filtration1[x_grid][1] + self.filtration2[y_grid][1]))
            curr_set = inter
            if step_x:
                x_grid += 1
            else:
                y_grid += 1
            step_x = not step_x
        if not step_x:
            # The walk ended right after an x-step: flush the final corner.
            inter = set_list1[-1].intersection(set_list2[-1])
            simplex = inter.difference(curr_set)
            if simplex:
                prepared_filtration.append((list(list(simplex)[0]), self.filtration1[-1][1] + self.filtration2[-1][1]))
        return prepared_filtration

    def get_raw_critical_points(self):
        """Return ``(value1, value2, simplex_list)`` for every simplex present in both filtrations."""
        f1_dict = {tuple(simplex): val for simplex, val in self.filtration1}
        f2_dict = {tuple(simplex): val for simplex, val in self.filtration2}
        return [
            (f1_dict[simplex], f2_dict[simplex], list(simplex))
            for simplex in f1_dict
            if simplex in f2_dict
        ]

    def plot_critical_points_2d(self, title="Critical Points",
                               color='blue', marker_size=50,
                               grid=True, show_labels=False,
                               plot_function=None,
                               func_x_range=None,
                               func_style=None):
        """Scatter the critical points ``(value1, value2)`` and optionally a curve.

        Args:
            plot_function: vectorised callable drawn over ``func_x_range``
                (``np.linspace`` arguments) or, by default, over the x-range of
                the points. A scalar result is broadcast to a flat line.
            func_style: ``plt.plot`` keyword arguments; defaults to a red
                dashed line labelled ``'Function'``.
        """
        if func_style is None:
            func_style = {'color': 'red', 'linestyle': '--', 'label': 'Function'}

        points = self.get_raw_critical_points()
        x_points = [p[0] for p in points]
        y_points = [p[1] for p in points]
        labels = [str(p[2]) for p in points]

        plt.figure(figsize=(10, 6))
        plt.scatter(x_points, y_points,
                    s=marker_size,
                    c=color,
                    edgecolors='black',
                    linewidths=0.5,
                    label='Critical Points')

        if plot_function is not None:
            if func_x_range is not None:
                x_vals = np.linspace(*func_x_range)
            elif x_points:
                x_vals = np.linspace(min(x_points), max(x_points), 100)
            else:
                x_vals = None  # no points to infer a range from
            if x_vals is not None:
                y_vals = np.broadcast_to(np.asarray(plot_function(x_vals)), x_vals.shape)
                plt.plot(x_vals, y_vals, **func_style)

        if show_labels:
            for xi, yi, label in zip(x_points, y_points, labels):
                plt.text(xi, yi, label,
                         fontsize=8,
                         ha='center',
                         va='bottom',
                         bbox=dict(boxstyle="round,pad=0.2",
                                   facecolor='white',
                                   edgecolor='none',
                                   alpha=0.8))

        plt.title(title, fontsize=14)
        plt.xlabel("Filtration 1 Value", fontsize=12)
        plt.ylabel("Filtration 2 Value", fontsize=12)
        plt.legend()

        if grid:
            plt.grid(True, linestyle=':', color='grey', alpha=0.7)

        plt.tight_layout()
        plt.show()

    def plot_bifiltration_grid(self, filtration1=None, filtration2=None):
        """Plot primary (0) and secondary (1) critical cells of the bifiltration grid.

        Args:
            filtration1, filtration2: filtrations to plot; each defaults to the
                one stored on this processor. The grid is computed from the
                same filtrations that label the axes.
        """
        if filtration1 is None:
            filtration1 = self.filtration1
        if filtration2 is None:
            filtration2 = self.filtration2
        # critical_points() takes no arguments; build a processor for the requested pair.
        grid = type(self)(filtration1, filtration2).critical_points()[0]

        plt.figure(figsize=(10, 10))

        for i in range(len(grid)):
            for j in range(len(grid[0])):
                if grid[i][j] == 0:
                    color, marker, size = 'red', 'o', 50
                elif grid[i][j] == 1:
                    color, marker, size = 'blue', 's', 30
                else:
                    continue
                plt.scatter(i, j, c=color, marker=marker, s=size, edgecolors='black')

        plt.xticks(range(len(filtration1)),
                   labels=[f"{filtration1[i][0]}" for i in range(len(filtration1))],
                   rotation=45)
        plt.yticks(range(len(filtration2)),
                   labels=[f"{filtration2[i][0]}" for i in range(len(filtration2))])

        plt.xlabel("Filtration 1 Index")
        plt.ylabel("Filtration 2 Index")
        plt.title("Bifiltration Critical Points Grid")

        plt.grid(True, linestyle='--', alpha=0.7)
        plt.gca().set_axisbelow(True)

        plt.scatter([], [], c='red', marker='o', s=50, label='Primary Critical Points')
        plt.scatter([], [], c='blue', marker='s', s=30, label='Secondary Critical Points')
        plt.legend()

        plt.tight_layout()
        plt.show()

class ClusteringUtils:
    """Helpers for turning an edge filtration into a SciPy linkage matrix."""

    @staticmethod
    def format_filtration(filtration):
        """Keep only edges, flattened to ``(u, v, value)`` triples in filtration order."""
        return [(simplex[0], simplex[1], value)
                for simplex, value in (entry[:2] for entry in filtration)
                if len(simplex) == 2]

    @staticmethod
    def get_linkage_matrix(filtration, K):
        """Build a SciPy-style linkage matrix by merging clusters along edges in order.

        Points ``0..K-1`` start as singleton clusters with ids ``0..K-1``. Each
        edge joining two different clusters emits a row
        ``[id_a, id_b, edge_value, merged_size]`` and the merged cluster gets
        the next id (``K``, ``K+1``, ...). Edges inside one cluster are skipped.

        Args:
            filtration: list of ``(simplex, value)``; vertices are ignored.
            K: number of points.

        Returns:
            ``np.ndarray`` with one row per merge (empty if there are none).

        Raises:
            KeyError: if an edge references a point outside ``0..K-1``.
        """
        edges = ClusteringUtils.format_filtration(filtration)
        members = {i: [i] for i in range(K)}
        # point -> id of the cluster currently containing it (O(1) lookup).
        owner = {i: i for i in range(K)}
        next_id = K
        linkage_matrix = []

        for u, v, height in edges:
            cluster_u, cluster_v = owner[u], owner[v]
            if cluster_u == cluster_v:
                continue
            merged = members.pop(cluster_u) + members.pop(cluster_v)
            linkage_matrix.append([cluster_u, cluster_v, height, len(merged)])
            for point in merged:
                owner[point] = next_id
            members[next_id] = merged
            next_id += 1
        return np.array(linkage_matrix)

    @staticmethod
    def plot_merge_tree(linkage_matrix, threshold):
        """Draw the dendrogram of *linkage_matrix* with a horizontal line at *threshold*."""
        plt.figure(figsize=(10, 5))
        dendrogram(linkage_matrix)
        plt.title('Dendrogram')
        plt.xlabel('Data points')
        plt.ylabel('Height')
        plt.axhline(y=threshold, color='r', linestyle='-')
        plt.show()

    @staticmethod
    def merge_clusters(data, start_clusters, linkage_matrix, min_size, distance_matrix):
        """Reassign every cluster smaller than *min_size* to its nearest large cluster.

        Nearness is the minimum entry of *distance_matrix* between the small
        cluster's points and the candidate's points. Small clusters with no
        eligible target are left unchanged.

        Args:
            data, linkage_matrix: unused; kept for interface compatibility.
            start_clusters: 1-D label array; **modified in place**.
            min_size: minimum cluster size to be kept as a target.
            distance_matrix: ``(n, n)`` array of pairwise distances.

        Returns:
            *start_clusters* (the same array object).
        """
        sizes = {cluster: np.sum((start_clusters == cluster)) for cluster in np.unique(start_clusters)}
        new_clusters = {cluster for cluster, size in sizes.items() if size < min_size}
        while new_clusters:
            c = new_clusters.pop()
            points = np.where(start_clusters == c)[0]
            min_distance = np.inf
            closest_cluster = None
            for cluster in sizes:
                if cluster != c and sizes[cluster] >= min_size:
                    for point in points:
                        distances = distance_matrix[point, start_clusters == cluster]
                        if np.min(distances) < min_distance:
                            min_distance = np.min(distances)
                            closest_cluster = cluster
            if closest_cluster is None:
                continue
            start_clusters[points] = closest_cluster
            sizes[closest_cluster] += len(points)
            sizes[c] = 0
        return start_clusters

In [4]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, fcluster
from scipy.spatial.distance import pdist
from ipywidgets import interact, FloatText, Text, Button, Output, Dropdown, VBox
from IPython.display import display, Math
import sympy as sp
from sympy.utilities.lambdify import lambdify

class DendrogramApp:
    """Interactive notebook UI for clustering along a slice of a bifiltration.

    The user picks two filtrations and a curve ``y = f(x)``; the app projects
    the bifiltration onto that curve, builds a merge tree from the order in
    which edges appear, and cuts it at the chosen threshold.
    """

    def __init__(self):
        self.filtration1 = None
        self.filtration2 = None
        self.data = None
        self.labels = None
        self.fb = None  # FiltrationBuilder, created by set_data()
        self.max_dist = 10  # initial value, replaced by update_threshold_value()

        # Display name -> identifier understood by get_filtration().
        self.filtration_options = {
            'Евклидово расстояние': 'euclidean',
            'Манхэттенское расстояние': 'cityblock',
            'Косинусное расстояние': 'cosine',
            'Ядерная плотность': 'density',
            'Кривизна': 'curvature',
            'KNN': 'knn',
            'Христоффель': 'christoffel',
            'Инт. размерность': 'intrinsic_dim'
        }

        # Widgets
        self.output = Output()
        self.filt1_dropdown = Dropdown(
            options=list(self.filtration_options.keys()),
            value='Евклидово расстояние',
            description='Фильтрация 1:'
        )
        self.filt2_dropdown = Dropdown(
            options=list(self.filtration_options.keys()),
            value='Ядерная плотность',
            description='Фильтрация 2:'
        )
        self.f_input = Text(value="x", description="f(x):")
        # NOTE(review): FloatText is unbounded in ipywidgets (BoundedFloatText
        # enforces min/max), so `min` here is probably ignored — confirm.
        self.threshold_input = FloatText(
            value=5.0,
            min=0.1,
            step=0.1,
            description="Порог:",
            continuous_update=False
        )
        self.plot_button = Button(description="Построить дендрограмму и кластеризацию",
                                  button_style='success')
        self.plot_button.on_click(self.plot_dendrogram)

    def set_data(self, data):
        """Set the point cloud to analyse and refresh the suggested threshold."""
        self.data = data
        self.fb = FiltrationBuilder(data)
        self.update_threshold_value()

    def update_threshold_value(self):
        """Set the threshold to half the largest pairwise distance of the data."""
        if self.data is not None:
            self.max_dist = np.max(pdist(self.data))
            self.threshold_input.value = self.max_dist / 2
            self.threshold_input.min = 0.1
            self.threshold_input.max = self.max_dist * 2

    def show_ui(self):
        """Display the widgets."""
        display(VBox([
            self.filt1_dropdown,
            self.filt2_dropdown,
            self.f_input,
            self.threshold_input,
            self.plot_button,
            self.output
        ]))

    def get_filtration(self, filt_type):
        """Build the filtration identified by *filt_type*.

        Raises:
            ValueError: for an unknown identifier.
        """
        if filt_type == 'euclidean':
            return self.fb.get_filtration_from_scipy_dist('euclidean')
        elif filt_type == 'cityblock':
            return self.fb.get_filtration_from_scipy_dist('cityblock')
        elif filt_type == 'cosine':
            return self.fb.get_filtration_from_scipy_dist('cosine')
        elif filt_type == 'density':
            return self.fb.get_filtration_from_density('gaussian', 0.3)
        elif filt_type == 'curvature':
            return self.fb.get_filtration_from_curvature()
        elif filt_type == 'knn':
            return self.fb.get_filtration_from_knn(5)
        elif filt_type == 'christoffel':
            return self.fb.get_filtration_from_christoffel(degree=2)
        elif filt_type == 'intrinsic_dim':
            return self.fb.get_filtration_from_intrinsic_dim()
        else:
            raise ValueError(f"Unknown filtration type: {filt_type}")

    @staticmethod
    def _build_forward_function(f_expr):
        """Parse *f_expr* with sympy (the same parser used for the inverse) into a numpy callable.

        NOTE: sympify evaluates the string; only acceptable because the input
        comes from the local notebook user.
        """
        x = sp.symbols('x')
        return lambdify(x, sp.sympify(f_expr), modules=['numpy'])

    @staticmethod
    def _select_inverse_branch(solutions, f, x, y, probes=(0.5, 1.0, 2.0)):
        """Pick the solution g of ``y = f(x)`` with ``g(f(p)) == p`` at positive probes.

        Filtration values passed to the inverse are expected to be
        non-negative, so this favours e.g. ``sqrt(y)`` over ``-sqrt(y)`` for
        ``f = x**2``. Falls back to the first solution if none matches.
        """
        for candidate in solutions:
            try:
                matches = all(
                    abs(complex(candidate.subs(y, f.subs(x, p)).evalf()) - p) < 1e-9
                    for p in probes
                )
            except (TypeError, ValueError, ZeroDivisionError):
                continue
            if matches:
                return candidate
        return solutions[0]

    def compute_inverse_function(self, f_expr):
        """Solve ``y = f(x)`` for x with sympy and return the inverse as a numpy callable.

        The chosen inverse is displayed as LaTeX.

        Raises:
            ValueError: if the expression cannot be parsed or inverted.
        """
        try:
            x = sp.symbols('x')
            y = sp.symbols('y')
            f = sp.sympify(f_expr)
            solutions = sp.solve(sp.Eq(y, f), x)
            if not solutions:
                raise ValueError("Не удалось найти обратную функцию")
            inverse_expr = self._select_inverse_branch(solutions, f, x, y)
            display(Math(f"f^{{-1}}(y) = {sp.latex(inverse_expr)}"))
            return lambdify(y, inverse_expr, modules=['numpy'])
        except Exception as e:
            raise ValueError(f"Ошибка при вычислении обратной функции: {str(e)}") from e

    def plot_dendrogram(self, b):
        """Button callback: build both filtrations, slice them, cluster and plot.

        Args:
            b: the clicked Button (unused; required by ``on_click``).
        """
        with self.output:
            self.output.clear_output()

            # UI boundary: report any failure in the output area instead of raising.
            try:
                if self.fb is None:
                    raise ValueError("данные не заданы — вызовите set_data()")

                filt1_type = self.filtration_options[self.filt1_dropdown.value]
                filt2_type = self.filtration_options[self.filt2_dropdown.value]

                self.filtration1 = self.get_filtration(filt1_type)
                self.filtration2 = self.get_filtration(filt2_type)

                # Rescale filtration2 to share filtration1's maximum (skip if degenerate).
                max1 = max(item[1] for item in self.filtration1)
                max2 = max(item[1] for item in self.filtration2)
                if max2 != 0:
                    self.filtration2 = [
                        (item[0], (item[1] / max2) * max1)
                        for item in self.filtration2
                    ]

                # f and its inverse g are parsed by the same (sympy) parser.
                f_expr = self.f_input.value
                f = self._build_forward_function(f_expr)
                g = self.compute_inverse_function(f_expr)

                self.plot_individual_dendrograms(filt1_type, filt2_type)

                bf_processor = BifiltrationProcessor(self.filtration1, self.filtration2)
                filtration = bf_processor.get_slice_optimized(f, g)

                K = len(set().union(*[set(item[0]) for item in filtration]))
                linkage_matrix = ClusteringUtils.get_linkage_matrix(filtration, K)
                self.labels = fcluster(linkage_matrix, t=self.threshold_input.value, criterion='distance')

                self.visualize_results(linkage_matrix)
                bf_processor.plot_critical_points_2d(plot_function=f)

            except Exception as e:
                print(f"Ошибка: {str(e)}")

    def plot_individual_dendrograms(self, filt1_type, filt2_type):
        """Draw side-by-side dendrograms of filtration1 and filtration2 alone."""
        plt.figure(figsize=(15, 5))

        plt.subplot(1, 2, 1)
        K1 = len(set().union(*[set(item[0]) for item in self.filtration1]))
        linkage_matrix1 = ClusteringUtils.get_linkage_matrix(self.filtration1, K1)
        dendrogram(linkage_matrix1)
        plt.title(f"Дендрограмма для {filt1_type}")
        plt.xlabel("Объекты")
        plt.ylabel("Расстояние")

        plt.subplot(1, 2, 2)
        K2 = len(set().union(*[set(item[0]) for item in self.filtration2]))
        linkage_matrix2 = ClusteringUtils.get_linkage_matrix(self.filtration2, K2)
        dendrogram(linkage_matrix2)
        plt.title(f"Дендрограмма для {filt2_type}")
        plt.xlabel("Объекты")
        plt.ylabel("Расстояние")

        plt.tight_layout()
        plt.show()

    def visualize_results(self, linkage_matrix):
        """Draw the sliced dendrogram with the threshold and, for 2-D data, the clustered points."""
        is_2d = self.data.shape[1] == 2
        plt.figure(figsize=(15, 6 if is_2d else 5))

        plt.subplot(1, 2 if is_2d else 1, 1)
        dendrogram(linkage_matrix)
        plt.axhline(y=self.threshold_input.value, color='r', linestyle='--')
        plt.title(f"Дендрограмма кластеризации (Порог = {self.threshold_input.value:.1f})")
        plt.xlabel("Объекты")
        plt.ylabel("Расстояние")

        if is_2d:
            plt.subplot(1, 2, 2)
            plt.scatter(self.data[:, 0], self.data[:, 1], c=self.labels, cmap='tab20', s=50)
            plt.title(f"2D Визуализация кластеров (Порог = {self.threshold_input.value:.1f})")
            plt.xlabel("Первая компонента")
            plt.ylabel("Вторая компонента")

        plt.tight_layout()
        plt.show()

if __name__ == "__main__":
    import random
    data = np.load('X_hdbscan.npy')
    # np.random.seed does not seed the stdlib `random` module used by
    # random.sample, so seed both to make the subsample reproducible.
    np.random.seed(0)
    random.seed(0)
    pool = data.tolist()[400:]
    # Never request more rows than the pool holds (random.sample would raise).
    random_sample = random.sample(pool, min(500, len(pool)))
    data = np.array(random_sample)

    app = DendrogramApp()
    app.set_data(data)
    app.show_ui()

VBox(children=(Dropdown(description='Фильтрация 1:', options=('Евклидово расстояние', 'Манхэттенское расстояни…