# Advanced Tutorial: Interpreting Archetypal Analysis Models

This notebook shows how to interpret and evaluate archetypal analysis models with the `ArchetypalAnalysisInterpreter` class from the `archetypax` library. Archetypal analysis identifies extreme patterns (archetypes) in data and describes each sample as a mixture of them; the interpreter provides quantitative measures of model quality and interpretability.
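
To make the underlying model concrete, here is a minimal sketch in plain NumPy (it does not use `archetypax`). It assumes the standard formulation in which each sample is approximated by a convex combination of archetypes, so every row of the weight matrix is non-negative and sums to one. The arrays `archetypes` and `weights` are hypothetical toy values chosen for illustration.

```python
import numpy as np

# Three hypothetical archetypes in a 2-D feature space (corners of a triangle).
archetypes = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])

# Each row holds one sample's weights: non-negative and summing to 1.
weights = np.array([
    [1.0, 0.0, 0.0],    # identical to archetype 0
    [0.5, 0.5, 0.0],    # halfway between archetypes 0 and 1
    [0.25, 0.25, 0.5],  # a mixture of all three
])

# Reconstruction: every sample is a convex combination of the archetypes.
X_hat = weights @ archetypes
print(X_hat)
```

Because the weights lie on the simplex, every reconstructed point falls inside the convex hull of the archetypes, which is why archetypes tend to sit at the extremes of the data.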


## 1. Importing Essential Libraries


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

# Import archetypax components
from archetypax.models.archetypes import ImprovedArchetypalAnalysis
from archetypax.tools.interpret import ArchetypalAnalysisInterpreter
from archetypax.tools.visualization import ArchetypalAnalysisVisualizer

# Configure visualization settings
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_context("notebook", font_scale=1.2)
plt.rcParams["figure.figsize"] = [12, 8]

## 2. Generating Synthetic Data

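We'll create a synthetic dataset with clear cluster structure: 10 informative features drawn from 4 Gaussian blobs, plus 5 pure-noise features appended afterwards, for 15 features in total. Because the true cluster labels are known, this controlled setting lets us check how well the model identifies meaningful archetypes.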


In [None]:
# Set random seed for reproducibility
np.random.seed(42)

# Define dataset parameters
n_samples = 200
n_features = 10
n_clusters = 4

# Generate data with clear cluster structure
X_raw, y_true = make_blobs(
    n_samples=n_samples, n_features=n_features, centers=n_clusters, random_state=42, cluster_std=1.2
)

# Add some noise features to make the problem more challenging
noise_features = np.random.normal(0, 0.5, size=(n_samples, 5))
X_raw = np.hstack([X_raw, noise_features])

# Standardize the data
scaler = StandardScaler()
X = scaler.fit_transform(X_raw)

print(f"Dataset dimensions: {X.shape}")
print(f"Number of true clusters: {n_clusters}")

## 3. Visualizing the Data Structure

Before applying archetypal analysis, let's visualize the inherent structure in our synthetic dataset to establish a baseline for comparison.
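
The first two columns show only a slice of a 15-dimensional dataset. As a complementary view, a two-component PCA projection can be computed directly with NumPy's SVD. The sketch below is self-contained and uses a random stand-in matrix `X_demo` (an assumption for illustration); in the notebook, the standardized matrix `X` would take its place.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for the standardized matrix X: 200 samples, 15 features.
X_demo = rng.normal(size=(200, 15))

# Center the data, then use the top two right singular vectors as PCA axes.
X_centered = X_demo - X_demo.mean(axis=0)
_, singular_values, Vt = np.linalg.svd(X_centered, full_matrices=False)
X_2d = X_centered @ Vt[:2].T

# Fraction of total variance captured by each of the two components.
explained = singular_values[:2] ** 2 / np.sum(singular_values ** 2)
print(X_2d.shape, explained)
```

The two columns of `X_2d` can be passed to `plt.scatter` in the same way as `X[:, 0]` and `X[:, 1]`.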


In [None]:
# Create a scatter plot of the first two dimensions colored by true cluster
plt.figure(figsize=(10, 8))
plt.scatter(X[:, 0], X[:, 1], c=y_true, cmap="viridis", s=80, alpha=0.8, edgecolors="k")
plt.title("Data Visualization (First Two Dimensions)")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.colorbar(label="True Cluster")
plt.grid(True, alpha=0.3)
plt.show()

# Create a correlation heatmap to visualize feature relationships
plt.figure(figsize=(12, 10))
correlation_matrix = np.corrcoef(X.T)
sns.heatmap(correlation_matrix, annot=False, cmap="coolwarm", center=0)
plt.title("Feature Correlation Matrix")
plt.show()

## 4. Training Multiple Archetypal Analysis Models

We'll train a series of archetypal analysis models with varying numbers of archetypes to identify the optimal configuration for our dataset.


In [None]:
# Define the range of archetype numbers to explore
archetype_range = range(2, 9)  # From 2 to 8 archetypes

# Dictionary to store trained models
models_dict = {}

# Train models for each number of archetypes
for n_archetypes in archetype_range:
    print(f"Training model with {n_archetypes} archetypes...")

    # Initialize the model with the same hyperparameters for every run so results are comparable
    model = ImprovedArchetypalAnalysis(
        n_archetypes=n_archetypes,
        max_iter=200,  # Upper bound on optimization iterations
        random_seed=42,
        learning_rate=0.01,
    )

    # Fit the model to our data
    model.fit(X)

    # Store the trained model
    models_dict[n_archetypes] = model

print("Model training complete for all archetype configurations.")

## 5. Initializing and Evaluating with ArchetypalAnalysisInterpreter

Next, we use the `ArchetypalAnalysisInterpreter` to evaluate the trained models on several interpretability metrics.


In [None]:
# Initialize the interpreter
interpreter = ArchetypalAnalysisInterpreter()

# Add all trained models to the interpreter
for n_archetypes, model in models_dict.items():
    interpreter.add_model(n_archetypes, model)

# Evaluate all models using comprehensive metrics
results = interpreter.evaluate_all_models(X)

print("Comprehensive evaluation completed for all models.")

## 6. Analyzing Interpretability Metrics

Let's examine how interpretability metrics vary across different numbers of archetypes to identify the optimal configuration.


In [None]:
# Extract metrics for visualization
n_archetypes_list = sorted(results.keys())
distinctiveness_scores = [results[k]["avg_distinctiveness"] for k in n_archetypes_list]
sparsity_scores = [results[k]["avg_sparsity"] for k in n_archetypes_list]
purity_scores = [results[k]["avg_purity"] for k in n_archetypes_list]
interpretability_scores = [results[k]["interpretability_score"] for k in n_archetypes_list]

# Create a multi-metric plot
plt.figure(figsize=(14, 8))

plt.plot(n_archetypes_list, distinctiveness_scores, "o-", label="Distinctiveness", linewidth=2)
plt.plot(n_archetypes_list, sparsity_scores, "s-", label="Sparsity", linewidth=2)
plt.plot(n_archetypes_list, purity_scores, "^-", label="Purity", linewidth=2)
plt.plot(n_archetypes_list, interpretability_scores, "D-", label="Overall Interpretability", linewidth=3)

plt.xlabel("Number of Archetypes")
plt.ylabel("Score")
plt.title("Interpretability Metrics by Number of Archetypes")
plt.grid(True, alpha=0.3)
plt.legend(loc="best")
plt.xticks(n_archetypes_list)
plt.show()

# Identify the optimal number of archetypes based on interpretability
optimal_n_archetypes = n_archetypes_list[np.argmax(interpretability_scores)]
print(f"Optimal number of archetypes based on interpretability: {optimal_n_archetypes}")

## 7. Reconstruction Error Analysis

Let's also examine how reconstruction error varies with the number of archetypes to balance interpretability with model fit.
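
One simple way to locate the point of diminishing returns is the maximum-distance-to-chord heuristic: rescale the curve to the unit square, draw a straight line between its first and last points, and pick the point farthest from that line. The sketch below is a discrete variant that evaluates only the tested archetype counts. The function name `discrete_elbow` and the error values are hypothetical and are not taken from this notebook's results.

```python
import numpy as np


def discrete_elbow(x, y):
    """Return the x value farthest from the chord joining the first and last points."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Rescale both axes to [0, 1] so neither dominates the distance.
    x_n = (x - x.min()) / (x.max() - x.min())
    y_n = (y - y.min()) / (y.max() - y.min())
    dx, dy = x_n[-1] - x_n[0], y_n[-1] - y_n[0]
    # Perpendicular distance from each point to the chord (2-D cross product).
    dist = np.abs(dx * (y_n - y_n[0]) - dy * (x_n - x_n[0])) / np.hypot(dx, dy)
    return x[np.argmax(dist)]


# Hypothetical error curve that flattens after 4 archetypes.
ks = [2, 3, 4, 5, 6, 7, 8]
errors = [10.0, 6.0, 3.0, 2.8, 2.6, 2.5, 2.4]
print(discrete_elbow(ks, errors))
```

The heuristic is sensitive to the range of archetype counts tested, since the chord depends on the two endpoints, so it is best read alongside the plot rather than as a definitive answer.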


In [None]:
# Calculate reconstruction error for each model
reconstruction_errors = []

for n_archetypes in n_archetypes_list:
    model = models_dict[n_archetypes]
    X_reconstructed = model.reconstruct(X)
    error = np.mean(np.sum((X - X_reconstructed) ** 2, axis=1))
    reconstruction_errors.append(error)

# Plot reconstruction error
plt.figure(figsize=(12, 6))
plt.plot(n_archetypes_list, reconstruction_errors, "o-", color="crimson", linewidth=2, markersize=10)
plt.xlabel("Number of Archetypes")
plt.ylabel("Mean Squared Reconstruction Error")
plt.title("Reconstruction Error by Number of Archetypes")
plt.grid(True, alpha=0.3)
plt.xticks(n_archetypes_list)
plt.show()

# Calculate elbow point (where adding more archetypes yields diminishing returns)
from scipy.interpolate import interp1d
from scipy.optimize import minimize


# Function to find the elbow: the point on the curve farthest from the
# straight line joining its endpoints (a common proxy for maximum curvature)
def find_elbow_point(x, y):
    # Normalize data
    x_norm = (x - min(x)) / (max(x) - min(x))
    y_norm = (y - min(y)) / (max(y) - min(y))

    # Create interpolation function
    interp_func = interp1d(x_norm, y_norm, kind="cubic")

    # Endpoints of the normalized curve and the chord between them
    start = np.array([0.0, float(interp_func(0.0))])
    end = np.array([1.0, float(interp_func(1.0))])
    chord = end - start

    # Perpendicular distance from a curve point to the chord
    def distance_to_line(point):
        p = np.array([point, float(interp_func(point))])
        diff = start - p
        # Scalar 2-D cross product (np.cross on 2-D vectors is deprecated in NumPy 2.0)
        return np.abs(chord[0] * diff[1] - chord[1] * diff[0]) / np.linalg.norm(chord)

    # Find point of maximum distance
    result = minimize(lambda p: -distance_to_line(p[0]), [0.5], bounds=[(0, 1)])
    elbow_x = result.x[0] * (max(x) - min(x)) + min(x)
    return int(round(elbow_x))


# Find elbow point
elbow_n_archetypes = find_elbow_point(np.array(n_archetypes_list), np.array(reconstruction_errors))
print(f"Optimal number of archetypes based on elbow method: {elbow_n_archetypes}")

## 8. In-Depth Analysis of the Optimal Model

Now we'll examine in detail the model with the highest overall interpretability score.


In [None]:
# Select the optimal model based on interpretability
optimal_model = models_dict[optimal_n_archetypes]

# Display comprehensive evaluation metrics for the optimal model
optimal_results = results[optimal_n_archetypes]
print(f"Detailed evaluation metrics for model with {optimal_n_archetypes} archetypes:")
for metric, value in optimal_results.items():
    if isinstance(value, (int, float, np.integer, np.floating)):  # also match NumPy scalars
        print(f"{metric}: {value:.4f}")
    elif isinstance(value, np.ndarray) and value.ndim == 1:
        print(f"\n{metric} per archetype:")
        for i, val in enumerate(value):
            print(f"  Archetype {i + 1}: {val:.4f}")
    else:
        print(f"{metric}: {type(value)}")

## 9. Analyzing Feature Distinctiveness and Sparsity

We'll examine how distinctive and sparse each archetype is, providing insights into their interpretability.
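
The exact formula behind `sparsity_coefficient` is not stated in this notebook. For intuition, the sketch below implements one widely used measure, Hoyer's sparsity, which compares the L1 and L2 norms of a vector. It is an illustrative stand-in, not necessarily what `archetypax` computes, and `hoyer_sparsity` is a hypothetical helper name.

```python
import numpy as np


def hoyer_sparsity(v):
    """Hoyer sparsity: 0 for a uniform vector, 1 for a single non-zero entry."""
    v = np.asarray(v, dtype=float)
    n = v.size
    l1 = np.abs(v).sum()
    l2 = np.sqrt((v ** 2).sum())
    if l2 == 0:
        return 0.0  # convention chosen here for the all-zero vector
    return (np.sqrt(n) - l1 / l2) / (np.sqrt(n) - 1)


print(hoyer_sparsity([1, 1, 1, 1]))  # all features equally active
print(hoyer_sparsity([0, 0, 3, 0]))  # a single active feature
```

Under a measure like this, an archetype driven by a handful of features scores close to 1 and is usually easier to describe in words than one that spreads its mass across all features.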


In [None]:
# Extract archetypes from the optimal model
archetypes = np.asarray(optimal_model.archetypes)

# Calculate feature distinctiveness for archetypes
distinctiveness = interpreter.feature_distinctiveness(archetypes)
print("Feature distinctiveness scores for archetypes:")
for i, score in enumerate(distinctiveness):
    print(f"Archetype {i + 1}: {score:.4f}")

# Calculate sparsity coefficients for archetypes
sparsity = interpreter.sparsity_coefficient(archetypes)
print("\nSparsity coefficients for archetypes:")
for i, score in enumerate(sparsity):
    print(f"Archetype {i + 1}: {score:.4f}")

# Visualize archetype profiles
plt.figure(figsize=(14, 10))
for i in range(archetypes.shape[0]):
    plt.subplot(int(np.ceil(archetypes.shape[0] / 2)), 2, i + 1)
    plt.bar(range(archetypes.shape[1]), archetypes[i], color="teal", alpha=0.7)
    plt.title(f"Archetype {i + 1} Profile")
    plt.xlabel("Feature Index")
    plt.ylabel("Value")
    plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 10. Evaluating Cluster Purity

We'll assess how well the identified archetypes correspond to natural clusters in the data.
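
How `cluster_purity` defines purity is not spelled out here. As a rough intuition, one plausible notion is how strongly each sample is dominated by its top archetype. The sketch below, with the hypothetical helper `dominance_purity`, averages the dominant weight over the samples assigned to each archetype; treat it as an assumption for illustration rather than the library's definition.

```python
import numpy as np


def dominance_purity(weights):
    """Mean dominant weight of the samples assigned to each archetype."""
    weights = np.asarray(weights, dtype=float)
    dominant = weights.argmax(axis=1)
    top_weight = weights.max(axis=1)
    n_archetypes = weights.shape[1]
    # NaN marks archetypes that dominate no sample.
    per_archetype = np.full(n_archetypes, np.nan)
    for k in range(n_archetypes):
        mask = dominant == k
        if mask.any():
            per_archetype[k] = top_weight[mask].mean()
    return per_archetype, top_weight.mean()


# Toy weight matrix: three samples, two archetypes.
toy_weights = np.array([[0.9, 0.1], [0.7, 0.3], [0.2, 0.8]])
per_archetype, overall = dominance_purity(toy_weights)
print(per_archetype, overall)
```

Values near 1 indicate samples that sit close to a single archetype, while values near `1 / n_archetypes` indicate samples that are even mixtures.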


In [None]:
# Extract weights from the optimal model
weights = np.asarray(optimal_model.weights)

# Calculate cluster purity
purity_scores, avg_purity = interpreter.cluster_purity(weights)
print(f"Average cluster purity: {avg_purity:.4f}")
print("\nPurity scores per archetype:")
for i, score in enumerate(purity_scores):
    print(f"Archetype {i + 1}: {score:.4f}")

# Identify dominant archetype for each sample
dominant_archetypes = np.argmax(weights, axis=1)

# Visualize the distribution of dominant archetypes
plt.figure(figsize=(10, 6))
sns.countplot(x=dominant_archetypes, palette="viridis")
plt.title("Distribution of Dominant Archetypes")
plt.xlabel("Archetype ID")
plt.ylabel("Number of Samples")
plt.grid(True, alpha=0.3)
plt.show()

## 11. Advanced Visualization of Archetypal Structure

We'll use the `ArchetypalAnalysisVisualizer` to plot the optimal model's results: a 2D simplex view, archetype profiles, and a membership heatmap.


In [None]:
# Initialize the visualizer (the model is passed to each plotting call)
visualizer = ArchetypalAnalysisVisualizer()

# Generate simplex visualization (if data is low-dimensional or can be projected)
try:
    fig = visualizer.plot_simplex_2d(optimal_model, X)
    plt.title("2D Simplex Visualization")
    plt.show()
except Exception as e:
    print(f"Could not generate simplex visualization: {e}")

# Generate archetype profiles
fig = visualizer.plot_archetype_profiles(optimal_model, X)
plt.title("Archetype Profiles")
plt.show()

# Generate membership heatmap
fig = visualizer.plot_membership_heatmap(optimal_model, X)
plt.title("Membership Heatmap")
plt.show()

## 12. Comparing with Ground Truth

Finally, we'll compare the archetypes identified by our model with the known ground truth clusters.
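
Archetype IDs and cluster IDs are arbitrary labels, so archetype 2 may well correspond to cluster 0. Summary scores such as the Adjusted Rand Index and Normalized Mutual Information are invariant to such relabeling, but they do not show which archetype maps to which cluster. A contingency table does. The sketch below builds one with NumPy alone; `contingency_table` and the toy labels are hypothetical and chosen for illustration.

```python
import numpy as np


def contingency_table(labels_true, labels_pred):
    """Count co-occurrences of true labels (rows) and predicted labels (columns)."""
    true_ids, true_idx = np.unique(labels_true, return_inverse=True)
    pred_ids, pred_idx = np.unique(labels_pred, return_inverse=True)
    table = np.zeros((true_ids.size, pred_ids.size), dtype=int)
    # Accumulate one count per sample at its (true, predicted) cell.
    np.add.at(table, (true_idx, pred_idx), 1)
    return table


# Toy labels: the prediction is a pure relabeling of the truth.
toy_true = np.array([0, 0, 1, 1, 2, 2])
toy_pred = np.array([2, 2, 0, 0, 1, 1])
print(contingency_table(toy_true, toy_pred))
```

A table with one large entry per row and per column indicates a near one-to-one correspondence; rows spread over several columns indicate clusters split across archetypes.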


In [None]:
# Compare dominant archetypes with true cluster labels
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.scatter(X[:, 0], X[:, 1], c=dominant_archetypes, cmap="viridis", s=80, alpha=0.8, edgecolors="k")
plt.title("Data Points Colored by Dominant Archetype")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.colorbar(label="Archetype ID")
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.scatter(X[:, 0], X[:, 1], c=y_true, cmap="plasma", s=80, alpha=0.8, edgecolors="k")
plt.title("Data Points Colored by True Cluster Label")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.colorbar(label="Cluster ID")
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Calculate agreement between dominant archetypes and true clusters
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

ari = adjusted_rand_score(y_true, dominant_archetypes)
nmi = normalized_mutual_info_score(y_true, dominant_archetypes)

print(f"Adjusted Rand Index: {ari:.4f}")
print(f"Normalized Mutual Information: {nmi:.4f}")

## 13. Conclusion and Key Insights

This tutorial has shown how to use the `ArchetypalAnalysisInterpreter` to evaluate and interpret archetypal analysis models. Key aspects covered include:

1. Generation of synthetic data with clear cluster structure
2. Training of multiple archetypal analysis models with varying numbers of archetypes
3. Systematic evaluation of interpretability metrics across model configurations
4. Identification of the optimal number of archetypes using both the interpretability score and the elbow method
5. In-depth analysis of feature distinctiveness and sparsity
6. Assessment of cluster purity and archetype distribution
7. Advanced visualization of archetypal structures
8. Comparison with ground truth to validate model performance

Archetypal analysis identifies extreme patterns in data, and the `ArchetypalAnalysisInterpreter` supplies the metrics needed to compare models and judge how interpretable their archetypes are.
