# In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from scipy.spatial.distance import cdist
from sklearn.covariance import LedoitWolf
from sklearn.cluster import AgglomerativeClustering
import concurrent.futures
from itertools import product, combinations

# ---------------------------------------------------------
# 1. Existing Helper Functions (Dependencies)
# ---------------------------------------------------------


def getMahalanobisDistances(vectors_a, vectors_b):
	"""
	Pairwise Mahalanobis distances between the L2-normalised rows of vectors_a.

	A LedoitWolf covariance estimator is fitted on the normalised rows and its
	precision matrix is passed to scipy's cdist as VI.

	NOTE: vectors_b is accepted so the signature matches the other distance
	processors, but it is never read. The result is always vectors_a measured
	against itself (i.e. the caller is assumed to pass vectors_a == vectors_b).

	Returns: (dist_matrix, precision_matrix)
	"""
	row_lengths = np.linalg.norm(vectors_a, axis=1, keepdims=True)
	# The epsilon keeps all-zero rows from triggering a division by zero.
	unit_rows = vectors_a / (row_lengths + 1e-10)

	precision_matrix = LedoitWolf().fit(unit_rows).precision_

	pairwise = cdist(unit_rows, unit_rows, metric="mahalanobis", VI=precision_matrix)
	return pairwise, precision_matrix


# Registry: metric name -> callable(emb_a, emb_b) producing a pairwise matrix.
# emb_a is (N, D), emb_b is (M, D); the result is (N, M).
# Two entries deviate from a plain "distance matrix" contract:
#   - "dot" returns the raw inner product, i.e. a similarity (larger = closer).
#     NOTE(review): consumers that take argmin over this matrix will pick the
#     least similar row; confirm that is intended before using "dot" there.
#   - "mahalanobis" returns a (dist_matrix, precision_matrix) tuple and
#     ignores emb_b (see getMahalanobisDistances).
Distance_Processors = {
	"cosine": lambda emb_a, emb_b: 1.0
	- (emb_a @ emb_b.T)
	/ (
		np.linalg.norm(emb_a, axis=1, keepdims=True)
		@ np.linalg.norm(emb_b, axis=1, keepdims=True).T
		+ 1e-10
	),
	# cdist keeps peak memory at O(N*M). Broadcasting emb_a[..., None] - emb_b.T
	# would materialise an (N, D, M) temporary, which is tens of GB for a few
	# thousand 256-d vectors. cdist also always returns a float matrix.
	"l1": lambda emb_a, emb_b: cdist(emb_a, emb_b, metric="cityblock"),
	"l2": lambda emb_a, emb_b: cdist(emb_a, emb_b, metric="euclidean"),
	"dot": lambda emb_a, emb_b: emb_a @ emb_b.T,
	"mahalanobis": lambda emb_a, emb_b: getMahalanobisDistances(emb_a, emb_b),
}


def _prepareModelArtifact(
	raw_vectors,
	semantic_data,
	truncation_dim=256,
	distance_metric="mahalanobis",
	debug=True,
):
	"""
	Builds the artifact dict for one set of embedding vectors.

	Steps: keep the first truncation_dim columns, compute the pairwise matrix
	with the processor registered under distance_metric, then record each row's
	nearest neighbour (argmin with the diagonal masked out).

	Returns a dict with keys:
	    "dist_matrix", "vectors", "precision", "semantic_data", "metric", "nn_indices"
	"precision" is None unless distance_metric == "mahalanobis".
	"""
	full_matrix = np.array(raw_vectors)
	input_dim = full_matrix.shape[1]

	if input_dim < truncation_dim and debug:
		print(
			f"Warning: Vector dimension ({input_dim}) is smaller than truncation limit ({truncation_dim}). Proceeding without truncation."
		)

	# Slicing past the last column is harmless, so short vectors pass through as-is.
	head_vectors = full_matrix[:, :truncation_dim]

	metric_fn = Distance_Processors[distance_metric]
	outcome = metric_fn(head_vectors, head_vectors)

	# Only the Mahalanobis processor hands back a (distances, precision) tuple.
	if distance_metric == "mahalanobis":
		pairwise, precision = outcome
	else:
		pairwise, precision = outcome, None

	# Mask the diagonal so no row can select itself, take the argmin, then
	# write zeros onto the diagonal. All of this mutates `pairwise` in place.
	np.fill_diagonal(pairwise, np.inf)
	nearest = np.argmin(pairwise, axis=1)
	np.fill_diagonal(pairwise, 0.0)

	artifact = {}
	artifact["dist_matrix"] = pairwise
	artifact["vectors"] = head_vectors
	artifact["precision"] = precision
	artifact["semantic_data"] = semantic_data
	artifact["metric"] = distance_metric
	artifact["nn_indices"] = nearest
	return artifact


def prepareModelArtifacts(
	data_set, vector_keys, truncation_dim=256, distance_metric="mahalanobis", debug=True
):
	"""
	Builds one artifact dict per vector key, running the keys in parallel threads.

	Args:
	    data_set: Mapping item -> mapping of vector key -> vector. The items
	              (in data_set's iteration order) become "semantic_data".
	    vector_keys: Iterable of keys to extract from every item.
	    truncation_dim, distance_metric, debug: Forwarded to _prepareModelArtifact.
	Returns:
	    Dict {vector_key: artifact}. Empty when vector_keys is empty.
	Raises:
	    Whatever _prepareModelArtifact raises for a key is re-raised here when
	    that key's result is collected.
	"""
	vector_keys = list(vector_keys)
	if not vector_keys:
		# ThreadPoolExecutor rejects max_workers=0, and there is nothing to do anyway.
		return {}

	semantic_data = list(data_set.keys())
	raw_vectors = {
		key: [data_set[item][key] for item in semantic_data] for key in vector_keys
	}

	futures = {}
	# The context manager waits for all tasks and always releases the worker
	# threads, even if a submit() call fails part-way through.
	with concurrent.futures.ThreadPoolExecutor(max_workers=len(vector_keys)) as executor:
		for key in vector_keys:
			if debug:
				print(f"Processing {key}...")

			futures[key] = executor.submit(
				_prepareModelArtifact,
				raw_vectors[key],
				semantic_data,
				truncation_dim,
				distance_metric,
				debug,
			)

	return {key: future.result() for key, future in futures.items()}


def getGroupsFromLabels(labels):
	"""
	Buckets positions by their label and returns the buckets holding at least
	two positions, in order of each label's first appearance.
	"""
	members_by_label = {}
	for position, cluster_id in enumerate(labels):
		if cluster_id in members_by_label:
			members_by_label[cluster_id].append(position)
		else:
			members_by_label[cluster_id] = [position]

	multi_member = []
	for members in members_by_label.values():
		if len(members) >= 2:
			multi_member.append(members)
	return multi_member


def getNNPairsFromGroups(groups, nn_indices):
	"""
	Collects the (smaller, larger) index pairs (idx, nn_indices[idx]) for which
	both ends belong to the same group. Groups with fewer than two members
	contribute nothing.

	Returns: set of tuples
	"""
	found = set()
	for members in groups:
		if len(members) >= 2:
			member_lookup = set(members)
			for point in members:
				neighbour = nn_indices[point]
				if neighbour not in member_lookup:
					continue
				if neighbour < point:
					found.add((neighbour, point))
				else:
					found.add((point, neighbour))
	return found


def clusterAndGetArtifacts(dist_matrix, threshold):
	"""
	Runs complete-linkage agglomerative clustering on a precomputed distance
	matrix, with the cluster count left open and distance_threshold=threshold.

	Returns: (groups, labels) where groups comes from getGroupsFromLabels(labels)
	and labels is the per-row output of fit_predict.
	"""
	clustering_options = dict(
		n_clusters=None,
		distance_threshold=threshold,
		metric="precomputed",
		linkage="complete",
	)
	labels = AgglomerativeClustering(**clustering_options).fit_predict(dist_matrix)
	groups = getGroupsFromLabels(labels)
	return groups, labels


def calculateNTrue(labels_array, target_pairs):
	"""
	Counts how many distinct labels are carried by the indices that occur in
	target_pairs. Returns 0 when target_pairs is empty or labels_array is None.
	"""
	if not target_pairs:
		return 0
	if labels_array is None:
		return 0

	# Flatten the pairs into the set of indices they mention.
	touched = set()
	for pair in target_pairs:
		touched.update(pair)

	distinct_labels = set()
	for member in touched:
		distinct_labels.add(labels_array[member])
	return len(distinct_labels)


def getPairsFromLablesCombinations(labels):
	"""
	Expands cluster labels into every unordered index pair that shares a label.
	Returns: set of tuples {(min_id, max_id), ...}
	"""
	buckets = {}
	for position, cluster_id in enumerate(labels):
		if cluster_id not in buckets:
			buckets[cluster_id] = []
		buckets[cluster_id].append(position)

	# Singletons produce no pairs; every larger bucket yields all its 2-combinations.
	return {
		pair
		for members in buckets.values()
		if len(members) > 1
		for pair in combinations(sorted(members), 2)
	}


# ---------------------------------------------------------
# 2. New Functionality: Noise Generation & Stability
# ---------------------------------------------------------


def generateDensityBasedNoise(vectors, nn_indices, k=1, rng=None):
	"""
	Generates zero-mean Gaussian noise whose standard deviation, per point, is
	the Euclidean distance from that point to its selected neighbour.

	sigma_i = || vectors[i] - vectors[neighbour_i] ||

	Args:
	    vectors: (N, D) array-like of points.
	    nn_indices: Either a 1-D array-like of length N (one stored neighbour
	                per point; k is ignored) or a 2-D array-like, in which
	                case column k is used as the neighbour index.
	    k: Column to read when nn_indices is 2-D.
	    rng: Optional random source exposing .normal(loc, scale, size=...),
	         e.g. np.random.default_rng(seed), for reproducible draws.
	         When None the global np.random state is used.
	Returns:
	    Array with the same shape as vectors.
	"""
	# Accept plain lists as well as ndarrays.
	vectors = np.asarray(vectors)
	nn_indices = np.asarray(nn_indices)

	# 1. Pick the neighbour index for every point
	if nn_indices.ndim == 1:
		target_neighbor_indices = nn_indices	# only one neighbour stored per point
	else:
		target_neighbor_indices = nn_indices[:, k]

	# 2. Local sigma = raw Euclidean distance to that neighbour, kept as (N, 1)
	#    so it broadcasts across the D columns below.
	neighbor_vectors = vectors[target_neighbor_indices]
	local_sigmas = np.linalg.norm(vectors - neighbor_vectors, axis=1, keepdims=True)

	# Safety: duplicates would give sigma == 0; use a tiny positive value instead.
	local_sigmas[local_sigmas == 0] = 1e-9

	# 3. Draw N(0, 1) and scale each row to N(0, sigma_i^2)
	if rng is None:
		unit_noise = np.random.normal(0, 1, size=vectors.shape)
	else:
		unit_noise = rng.normal(0, 1, size=vectors.shape)

	return unit_noise * local_sigmas


def _jaccardAgainstBaseline(current_pairs, baseline_pairs):
	"""
	Jaccard similarity of current_pairs against baseline_pairs.

	With an empty (or None) baseline the score is 1.0 when current_pairs is
	empty as well, otherwise 0.
	"""
	if baseline_pairs:
		shared = current_pairs.intersection(baseline_pairs)
		either = current_pairs.union(baseline_pairs)
		return len(shared) / len(either)
	return 0 if current_pairs else 1.0


def runStabilityTest(
	artifacts,
	optimal_taus,
	n_iterations=200,
	k_neighbor=1,
	active_tests=None,
	consensus_pairs=None,
):
	"""
	Runs noise injection iterations and tracks cluster stability and ACGC metrics.

	For every model the clean distance matrix is clustered once as a baseline.
	Each iteration then adds density-based noise to the vectors, recomputes the
	distance matrix, re-clusters at the same threshold and scores the result.

	Args:
	    artifacts: Dict {model_key: artifact}; each artifact must provide
	               "dist_matrix", "vectors", "nn_indices", "metric" and, for
	               the "mahalanobis" metric, "precision".
	    optimal_taus: Dict {model_key: clustering distance threshold}.
	    n_iterations: Number of noisy repetitions per model.
	    k_neighbor: Forwarded to generateDensityBasedNoise as k.
	    active_tests: List of strings. Options: "NN", "Combinations".
	                  None (the default) enables both.
	    consensus_pairs: (Set of tuples, optional) The "P_True" set from findConsensusStructure.
	                     If provided, the ACGC component (N_Target) is calculated against this.
	                     If None, it is calculated against the model's own clean NN pairs.
	Returns:
	    plot_payload: Dict of per-iteration score lists keyed by
	                  "<model> (ACGC_N_Target)", "<model> (NN_Stability)",
	                  "<model> (Comb_Stability)".
	Raises:
	    KeyError: if optimal_taus has no entry for a model key.
	"""
	if active_tests is None:
		# Built per call so the default list is never shared between calls.
		active_tests = ["NN", "Combinations"]
	run_nn = "NN" in active_tests
	run_comb = "Combinations" in active_tests

	# Data structure for the plotter:
	plot_payload = {}

	print(f"--- Starting Stability Test ({n_iterations} iterations, k={k_neighbor}) ---")
	print(f"Modes: {active_tests}")

	for key, art in artifacts.items():
		print(f"Processing Model: {key}")

		# 1. Establish Clean Baseline
		tau = optimal_taus[key]
		vectors = art["vectors"]
		nn_indices = art["nn_indices"]
		metric = art["metric"]
		clean_groups, clean_labels = clusterAndGetArtifacts(art["dist_matrix"], tau)

		# The clean NN pairs serve both as the NN-stability baseline and as the
		# fallback ACGC target, so compute them at most once.
		clean_nn_pairs = None
		if run_nn or consensus_pairs is None:
			clean_nn_pairs = getNNPairsFromGroups(clean_groups, nn_indices)

		# Target for the ACGC metric (N_Target): global consensus if given,
		# otherwise the model's own clean pairs.
		target_acgc_pairs = (
			consensus_pairs if consensus_pairs is not None else clean_nn_pairs
		)

		clean_comb_pairs = (
			getPairsFromLablesCombinations(clean_labels) if run_comb else None
		)

		# Metric Logs
		nn_jaccard_scores = []
		comb_jaccard_scores = []
		acgc_scores = []	# N_True: number of groups the target pairs fall into

		# 2. Iteration Loop
		for _ in range(n_iterations):
			# A. Generate Noise based on Density
			noisy_vectors = vectors + generateDensityBasedNoise(
				vectors, nn_indices, k=k_neighbor
			)

			# B. Calculate Distances
			if metric == "mahalanobis":
				# Reuse the precision matrix stored in the artifact instead of
				# refitting a covariance on the noisy data.
				norms = np.linalg.norm(noisy_vectors, axis=1, keepdims=True)
				cleaned_v = noisy_vectors / (norms + 1e-10)
				d_matrix = cdist(cleaned_v, cleaned_v, metric="mahalanobis", VI=art["precision"])
			else:
				d_matrix = Distance_Processors[metric](noisy_vectors, noisy_vectors)

			np.fill_diagonal(d_matrix, 0.0)

			# C. Cluster
			groups, labels = clusterAndGetArtifacts(d_matrix, tau)

			# D. Metric 1: ACGC Component (N_Target)
			# "How many groups do the True Pairs span?"
			acgc_scores.append(calculateNTrue(labels, target_acgc_pairs))

			# E. Metric 2: Stability (Jaccard vs Clean) - NN
			# The neighbour indices are the clean ones; only the grouping changes.
			if run_nn:
				current_nn = getNNPairsFromGroups(groups, nn_indices)
				nn_jaccard_scores.append(_jaccardAgainstBaseline(current_nn, clean_nn_pairs))

			# F. Metric 3: Stability (Jaccard vs Clean) - Combinations
			if run_comb:
				current_comb = getPairsFromLablesCombinations(labels)
				comb_jaccard_scores.append(
					_jaccardAgainstBaseline(current_comb, clean_comb_pairs)
				)

		# Store in payload
		plot_payload[f"{key} (ACGC_N_Target)"] = acgc_scores

		if run_nn:
			plot_payload[f"{key} (NN_Stability)"] = nn_jaccard_scores
		if run_comb:
			plot_payload[f"{key} (Comb_Stability)"] = comb_jaccard_scores

	return plot_payload


def plotNoiseStability(
	data_dict, title="Cluster Stability & ACGC under Density-Based Noise"
):
	"""
	Draws one line trace per entry of data_dict and adds filter buttons that
	toggle trace visibility by metric family.

	Args:
	    data_dict: { "Method_Name": [val1, val2, ... val200] }
	               The family of a trace is detected from the substrings
	               "(ACGC_N_Target)", "(NN_Stability)" and "(Comb_Stability)"
	               in its key.
	    title: Figure title.

	Prints a message and returns without plotting when data_dict is empty.
	"""
	fig = go.Figure()

	series_names = list(data_dict.keys())
	if len(series_names) == 0:
		print("No data to plot.")
		return

	# One trace per series; mean and std go into the legend label.
	for series_name in series_names:
		values = data_dict[series_name]
		avg, std = np.mean(values), np.std(values)
		trace = go.Scatter(
			x=list(range(len(values))),
			y=values,
			mode="lines",
			name=f"{series_name} (μ={avg:.2f}, σ={std:.3f})",
			opacity=0.8,
		)
		fig.add_trace(trace)

	# --- Filter buttons: "All" first, then one per family that is present ---
	buttons = [
		dict(label="All", method="update", args=[{"visible": [True] * len(series_names)}])
	]
	families = (
		("ACGC (N_Target)", "(ACGC_N_Target)"),
		("NN Stability", "(NN_Stability)"),
		("Comb Stability", "(Comb_Stability)"),
	)
	for button_label, marker in families:
		visibility = [marker in series_name for series_name in series_names]
		if any(visibility):
			buttons.append(
				dict(label=button_label, method="update", args=[{"visible": visibility}])
			)

	button_bar = dict(
		type="buttons", direction="left", x=1.0, y=1.1, showactive=True, buttons=buttons
	)
	fig.update_layout(
		title=title,
		xaxis_title="Iteration",
		yaxis_title="Metric Value",
		updatemenus=[button_bar],
		hovermode="x unified",
	)

	fig.show()


# ---------------------------------------------------------
# 3. Usage Example
# ---------------------------------------------------------

# Note: This block assumes 'artifacts' and 'consensus_ACGC' are present in the scope.
# In a real script, run 'prepareModelArtifacts' and 'findConsensusStructure' first.

if __name__ == "__main__":
	# Example configuration: the global "P_true" set from the consensus step is
	# the ACGC target, so the run shows whether noise preserves the consensus.
	#
	# Expected inputs, computed earlier in the session:
	# global_consensus_pairs = consensus_ACGC["P_true"]
	# optimal_taus = consensus_ACGC["optimal_taus"]

	# Placeholder values so the block is safe to run standalone (replace with actuals).
	global_consensus_pairs = None
	if "artifacts" in locals():
		optimal_taus_mock = dict.fromkeys(artifacts.keys(), 0.5)
	else:
		optimal_taus_mock = {}

	# The real run only happens when both prerequisites already exist in scope.
	if "artifacts" in locals() and "consensus_ACGC" in locals():

		# Phase 1: NN stability plus the ACGC component
		stability_data = runStabilityTest(
			artifacts,
			consensus_ACGC["optimal_taus"],
			n_iterations=200,
			k_neighbor=1,
			active_tests=["NN"],
			consensus_pairs=consensus_ACGC["P_true"],
		)

		plotNoiseStability(stability_data)