# ArchetypAX: Basic Usage Example

This notebook demonstrates the fundamental capabilities of ArchetypAX, a GPU-accelerated implementation of Archetypal Analysis using JAX.


In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.spatial import ConvexHull

sys.path.insert(0, os.path.abspath("../"))

from archetypax.models import ImprovedArchetypalAnalysis as ArchetypalAnalysis
from archetypax.tools.evaluation import ArchetypalAnalysisEvaluator

## 1. Generating Synthetic Data

We'll start by creating a synthetic dataset with clear cluster structure to demonstrate the effectiveness of archetypal analysis.


In [None]:
# Build a reproducible 2-D toy dataset made of three Gaussian blobs
np.random.seed(42)
n_samples = 5000
n_centers = 3  # Number of clusters

# Integer division: every blob gets 5000 // 3 = 1666 points, so the final
# dataset holds 4998 rows rather than exactly n_samples.
n_per_cluster = n_samples // n_centers
cluster_centers = np.array([[2, 2], [-2, 2], [0, -2]])

# Draw the blobs one after another (standard normal noise scaled by 0.5,
# shifted to each center)
clusters = [np.random.randn(n_per_cluster, 2) * 0.5 + center for center in cluster_centers]

# Tag every point with the index of the blob it was drawn from (0.0, 1.0, 2.0)
labels = [np.full(n_per_cluster, float(k)) for k in range(n_centers)]

X_ordered = np.vstack(clusters)
y_ordered = np.concatenate(labels)

# A single shared permutation keeps each row of X aligned with its label
shuffled_idx = np.arange(X_ordered.shape[0])
np.random.shuffle(shuffled_idx)
X = X_ordered[shuffled_idx]
y = y_ordered[shuffled_idx]

print(f"Data shape: {X.shape}")

# Scatter plot of the dataset, colored by the generating cluster
fig, ax = plt.subplots(figsize=(10, 6))
points = ax.scatter(X[:, 0], X[:, 1], c=y, cmap="viridis", alpha=0.7, s=50)
ax.set_title("Synthetic Dataset with 3 Clusters", fontsize=14)
ax.set_xlabel("Feature 1", fontsize=12)
ax.set_ylabel("Feature 2", fontsize=12)
fig.colorbar(points, ax=ax, label="Cluster")
ax.grid(alpha=0.3)
plt.show()

## 2. Fitting the Archetypal Analysis Model

Now we'll apply ArchetypAX to identify the archetypes in our data. We'll set the number of archetypes to match our known number of clusters.


In [None]:
# Hyperparameters for the archetypal analysis model, collected in one place
# so they are easy to tweak. Three archetypes are requested to mirror the
# three clusters generated above.
model_params = {
    "n_archetypes": 3,
    "max_iter": 1000,
    "tol": 1e-10,
    "learning_rate": 0.001,
    "lambda_reg": 0.01,
    "normalize": False,
    "projection_alpha": 0.2,
    "projection_method": "cbap",
    "archetype_init_method": "directional",
    "random_seed": 42,
    "verbose_level": 1,
}
model = ArchetypalAnalysis(**model_params)

# Fit on X and obtain the per-sample archetype weights in one call.
# NOTE(review): max_iter is given both to the constructor and to
# fit_transform — confirm which value the library actually uses.
weights = model.fit_transform(X, method="adam", max_iter=1000)

# Plot the recorded loss values to inspect convergence
loss_history = model.get_loss_history()
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(loss_history)
ax.set_title("Convergence of Archetypal Analysis", fontsize=14)
ax.set_xlabel("Iteration", fontsize=12)
ax.set_ylabel("Loss", fontsize=12)
ax.grid(alpha=0.3)
plt.show()

In [None]:
# Convex hull spanned by the learned archetypes. The Qhull option "QJ"
# joggles the input so the hull can still be built from degenerate points.
archetype_points = np.asarray(model.archetypes)
hull = ConvexHull(archetype_points, qhull_options="QJ")

# For 2-D input SciPy reports the enclosed area under the `volume` attribute
hull_volume = hull.volume
print(f"Convex hull volume of the archetypes: {hull_volume:.4f}")

In [None]:
# Only rows whose largest weight exceeds this threshold are shown
thresh = 0.0

# Two weight matrices are compared: the one stored on the model after
# fitting and the one returned by fit_transform.
heatmap_inputs = [
    (model.weights, "Sample of archetype weights from the last iteration"),
    (weights, "Sample of archetype weights from the transformed data"),
]

for weight_matrix, heatmap_title in heatmap_inputs:
    _weights = weight_matrix[np.max(weight_matrix, axis=1) > thresh]

    # Annotated heatmap of the first 30 retained rows
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(_weights[:30], cmap="viridis", annot=True, fmt=".2f", ax=ax)
    ax.set_title(heatmap_title)
    plt.show()

## 3. Visualizing the Archetypes

Let's visualize the identified archetypes in relation to our data points.


In [None]:
# Learned archetype coordinates (one row per archetype)
archetypes = model.archetypes

fig, ax = plt.subplots(figsize=(12, 8))

# Data points colored by their generating cluster, archetypes drawn on top
ax.scatter(X[:, 0], X[:, 1], c=y, cmap="viridis", alpha=0.6, s=40, label="Data points")
ax.scatter(
    archetypes[:, 0],
    archetypes[:, 1],
    c="red",
    s=200,
    marker="*",
    edgecolor="black",
    linewidth=1.5,
    label="Archetypes",
)

# Label each archetype A1, A2, ... slightly offset from its marker
for number, (arch_x, arch_y) in enumerate(archetypes, start=1):
    ax.annotate(
        f"A{number}",
        (arch_x, arch_y),
        fontsize=14,
        fontweight="bold",
        xytext=(10, 10),
        textcoords="offset points",
    )

ax.set_title("Data Points and Identified Archetypes", fontsize=14)
ax.set_xlabel("Feature 1", fontsize=12)
ax.set_ylabel("Feature 2", fontsize=12)
ax.legend(fontsize=12)
ax.grid(alpha=0.3)
plt.show()

## 4. Analyzing Archetype Weights

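Each data point is represented as a convex combination of the archetypes. Let's visualize these weights by coloring each point according to its dominant archetype, i.e. the archetype with the largest weight.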


In [None]:
# Index of the archetype carrying the largest weight for every sample
dominant_archetypes = np.argmax(weights, axis=1)

fig, ax = plt.subplots(figsize=(12, 8))

# Points are colored by dominant archetype; archetypes are overlaid as stars
ax.scatter(X[:, 0], X[:, 1], c=dominant_archetypes, cmap="Set1", alpha=0.7, s=50)
ax.scatter(
    archetypes[:, 0],
    archetypes[:, 1],
    c="black",
    s=150,
    marker="*",
    edgecolor="white",
    linewidth=1.5,
    label="Archetypes",
)

# Label each archetype A1, A2, ... slightly offset from its marker
for number, (arch_x, arch_y) in enumerate(archetypes, start=1):
    ax.annotate(
        f"A{number}",
        (arch_x, arch_y),
        fontsize=14,
        fontweight="bold",
        xytext=(10, 10),
        textcoords="offset points",
    )

ax.set_title("Data Points Colored by Dominant Archetype", fontsize=14)
ax.set_xlabel("Feature 1", fontsize=12)
ax.set_ylabel("Feature 2", fontsize=12)
ax.legend(fontsize=12)
ax.grid(alpha=0.3)
plt.show()

## 5. Evaluating the Model

Let's use the ArchetypalAnalysisEvaluator to assess the quality of our model.


In [None]:
# Wrap the fitted model in the evaluation helper
evaluator = ArchetypalAnalysisEvaluator(model)

# Reconstruction error of X under the three metrics used in this notebook,
# evaluated in the order frobenius -> relative -> mse
frobenius_error, relative_error, mse_error = (
    evaluator.reconstruction_error(X, metric=metric_name) for metric_name in ("frobenius", "relative", "mse")
)

# Remaining quality measures: explained variance, purity and separation
explained_var = evaluator.explained_variance(X)
purity_results = evaluator.dominant_archetype_purity()
separation_results = evaluator.archetype_separation()

# Report every metric with four decimals
metric_summary = [
    ("Reconstruction Error (Frobenius)", frobenius_error),
    ("Reconstruction Error (Relative)", relative_error),
    ("Reconstruction Error (MSE)", mse_error),
    ("Explained Variance", explained_var),
    ("Overall Archetype Purity", purity_results["overall_purity"]),
    ("Archetype Separation (Mean Distance)", separation_results["mean_distance"]),
]
for metric_label, metric_value in metric_summary:
    print(f"{metric_label}: {metric_value:.4f}")

## 6. Visualizing the Reconstruction

Let's compare the original data with its reconstruction using the archetypes.


In [None]:
# Reconstruct every sample as the combination of archetypes given by its
# weights (the weights returned by fit_transform). The earlier extra call to
# model.reconstruct() was removed: its result was overwritten immediately and
# never used.
X_reconstructed = weights @ archetypes

# Side-by-side comparison: original data (left) vs. reconstruction (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

ax1.scatter(X[:, 0], X[:, 1], c=y, cmap="viridis", alpha=0.7, s=50)
ax2.scatter(X_reconstructed[:, 0], X_reconstructed[:, 1], c=y, cmap="viridis", alpha=0.7, s=50)

# Overlay the archetypes on both panels with identical styling
for ax in (ax1, ax2):
    ax.scatter(
        archetypes[:, 0],
        archetypes[:, 1],
        c="red",
        s=200,
        marker="*",
        edgecolor="black",
        linewidth=1.5,
        label="Archetypes",
    )

# Outline the convex hull of the archetypes on the reconstruction panel.
# "QJ" matches the option used for the hull-volume computation earlier in the
# notebook, so degenerate archetypes do not make only this cell fail.
hull = ConvexHull(archetypes, qhull_options="QJ")
for simplex in hull.simplices:
    # In 2-D each simplex is a pair of point indices, i.e. one hull edge
    ax2.plot(archetypes[simplex, 0], archetypes[simplex, 1], "r-", alpha=0.6, linewidth=2)

# Shared axis decoration; only the titles differ between the panels
for ax, panel_title in ((ax1, "Original Data"), (ax2, "Reconstructed Data")):
    ax.set_title(panel_title, fontsize=14)
    ax.set_xlabel("Feature 1", fontsize=12)
    ax.set_ylabel("Feature 2", fontsize=12)
    ax.grid(alpha=0.3)
    ax.legend(fontsize=12)

plt.tight_layout()
plt.show()

## 7. Visualizing Weight Distribution

Let's examine how the weights are distributed across the archetypes.


In [None]:
# Tabular view of the weights: one column per archetype, one row per sample
archetype_columns = [f"Archetype {number}" for number in range(1, model.n_archetypes + 1)]
weights_df = pd.DataFrame(weights, columns=archetype_columns)

# Box plot showing how the weight values spread for each archetype
fig, ax = plt.subplots(figsize=(12, 6))
sns.boxplot(data=weights_df, ax=ax)
ax.set_title("Distribution of Archetype Weights", fontsize=14)
ax.set_ylabel("Weight Value", fontsize=12)
ax.grid(axis="y", alpha=0.3)
plt.show()

# Pairwise correlation of the weight columns, on a fixed [-1, 1] color scale
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(weights_df.corr(), annot=True, cmap="coolwarm", vmin=-1, vmax=1, center=0, ax=ax)
ax.set_title("Correlation Between Archetype Weights", fontsize=14)
plt.show()

## 8. Conclusion

This notebook has demonstrated the basic usage of ArchetypAX for archetypal analysis. We've shown how to:

1. Fit an archetypal analysis model to data
2. Visualize the identified archetypes
3. Analyze the weight distributions
4. Evaluate the model's performance
5. Reconstruct the data using the archetypes

Archetypal analysis provides an interpretable representation of data by identifying extreme, yet representative points (archetypes) and expressing each data point as a mixture of these archetypes.
