# Advanced Tutorial: Interpreting Biarchetypal Analysis Models

This notebook demonstrates how to interpret and evaluate biarchetypal analysis models using the `BiarchetypalAnalysisInterpreter` class from the `archetypax` library. Biarchetypal analysis extends traditional archetypal analysis by identifying archetypes in both observations (rows) and features (columns) simultaneously, so structure along both axes of the data matrix can be examined together.
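
To make the idea of simultaneous row and column archetypes concrete, here is a minimal NumPy sketch of one common formulation, in which the data matrix is approximated as `alpha @ (beta @ X @ theta) @ gamma`. The names `alpha`, `beta`, `theta`, and `gamma`, the stochasticity constraints, and the random (unfitted) factors are assumptions made for illustration; the parameterization used inside `archetypax` may differ.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_row_stochastic(n_rows, n_cols, rng):
    """Return a non-negative matrix whose rows each sum to 1."""
    m = rng.random((n_rows, n_cols))
    return m / m.sum(axis=1, keepdims=True)

n, m = 6, 5   # observations, features
k, c = 3, 2   # row archetypes, column archetypes
X_toy = rng.normal(size=(n, m))

# Random placeholders; a fitted model would optimize these factors.
alpha = random_row_stochastic(n, k, rng)      # rows as mixtures of row archetypes (n x k)
beta = random_row_stochastic(k, n, rng)       # row archetypes as mixtures of rows (k x n)
theta = random_row_stochastic(c, m, rng).T    # column archetypes as mixtures of columns (m x c)
gamma = random_row_stochastic(m, c, rng).T    # columns as mixtures of column archetypes (c x m)

Z = beta @ X_toy @ theta        # biarchetypes, shape (k, c)
X_hat = alpha @ Z @ gamma       # reconstruction, shape (n, m)
error = np.linalg.norm(X_toy - X_hat) ** 2
print(Z.shape, X_hat.shape, error)
```

Fitting a model amounts to choosing the four factors so that the squared reconstruction error is as small as possible under the constraints.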


## 1. Importing Essential Libraries


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

# Import archetypax components
from archetypax.models.biarchetypes import BiarchetypalAnalysis
from archetypax.tools.interpret import BiarchetypalAnalysisInterpreter
from archetypax.tools.visualization import BiarchetypalAnalysisVisualizer

# Configure visualization settings
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_context("notebook", font_scale=1.2)
plt.rcParams["figure.figsize"] = [12, 8]

## 2. Generating Synthetic Data with Dual Structure

We'll create a synthetic dataset with inherent structure in both rows and columns, making it ideal for demonstrating biarchetypal analysis. This approach allows us to evaluate the model's ability to recover known patterns.


In [None]:
# Set random seed for reproducibility
np.random.seed(42)

# Define dataset parameters
n_samples = 10
n_features = 10
n_row_clusters = 3
n_col_clusters = 4

# Generate data with row clusters
X_raw, row_labels = make_blobs(
    n_samples=n_samples, n_features=n_features, centers=n_row_clusters, random_state=42, cluster_std=1.5
)

# Add column structure by creating feature groups with distinct correlation patterns
feature_groups = np.array_split(np.arange(n_features), n_col_clusters)
col_labels = np.zeros(n_features, dtype=int)

# Apply different correlation structures to each feature group
for i, group in enumerate(feature_groups):
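    # Generate a random symmetric matrix with positive off-diagonal entries
    corr_matrix = np.random.uniform(0.5, 0.9, size=(len(group), len(group)))
    corr_matrix = (corr_matrix + corr_matrix.T) / 2  # Ensure symmetry
    np.fill_diagonal(corr_matrix, 1.0)  # Set diagonal to 1

    # Symmetry alone does not guarantee positive-definiteness, which Cholesky requires;
    # shift the spectrum if the smallest eigenvalue is not safely positive
    min_eig = np.linalg.eigvalsh(corr_matrix).min()
    if min_eig < 1e-3:
        corr_matrix += (1e-3 - min_eig) * np.eye(len(group))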

    # Apply correlation structure using Cholesky decomposition
    L = np.linalg.cholesky(corr_matrix)
    uncorrelated = X_raw[:, group]
    correlated = uncorrelated @ L.T

    # Apply group-specific scaling
    scale_factor = np.random.uniform(0.5, 2.0)
    X_raw[:, group] = correlated * scale_factor

    # Assign column labels
    col_labels[group] = i

# Standardize the data
scaler = StandardScaler()
X = scaler.fit_transform(X_raw)

print(f"Dataset dimensions: {X.shape}")
print(f"Number of row clusters: {n_row_clusters}")
print(f"Number of column clusters: {n_col_clusters}")
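
The Cholesky step in the cell above only succeeds for positive-definite matrices, and a symmetric matrix with a unit diagonal and off-diagonal entries in [0.5, 0.9] is not always positive-definite once a group has three or more features. The sketch below shows one such matrix and one simple remedy, a diagonal shift. The helper names `is_positive_definite` and `shift_to_positive_definite` are hypothetical and not part of `archetypax` or NumPy.

```python
import numpy as np

def is_positive_definite(matrix):
    """Return True if a Cholesky factorization succeeds."""
    try:
        np.linalg.cholesky(matrix)
        return True
    except np.linalg.LinAlgError:
        return False

def shift_to_positive_definite(matrix, eps=1e-3):
    """Add a multiple of the identity so the smallest eigenvalue is at least eps."""
    min_eig = np.linalg.eigvalsh(matrix).min()
    if min_eig >= eps:
        return matrix.copy()
    return matrix + (eps - min_eig) * np.eye(matrix.shape[0])

# Entries lie in [0.5, 0.9] with a unit diagonal, yet the determinant is negative
bad = np.array([
    [1.0, 0.9, 0.9],
    [0.9, 1.0, 0.5],
    [0.9, 0.5, 1.0],
])
fixed = shift_to_positive_definite(bad)
print(is_positive_definite(bad), is_positive_definite(fixed))
```

The shift raises the diagonal slightly above 1, so the result is a covariance-like matrix rather than a strict correlation matrix. That is harmless here because each group is rescaled and the data standardized afterwards.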

## 3. Visualizing the Dual Structure

Before applying biarchetypal analysis, let's visualize the inherent structure in our synthetic dataset to establish a baseline for comparison.


In [None]:
# Create a heatmap of the data matrix with rows and columns sorted by their respective clusters
plt.figure(figsize=(12, 10))

# Sort indices by cluster labels
row_idx = np.argsort(row_labels)
col_idx = np.argsort(col_labels)

# Generate heatmap with sorted data
sorted_data = X[row_idx][:, col_idx]
sns.heatmap(sorted_data, cmap="viridis", center=0)
plt.title("Data Matrix Heatmap (Sorted by True Cluster Labels)")
plt.xlabel("Features (Columns)")
plt.ylabel("Samples (Rows)")
plt.show()
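
As a numerical counterpart to the heatmap, the block structure can be summarized by averaging the entries within each (row cluster, column group) block. This is a small sketch; `block_means` is a hypothetical helper written for this tutorial, not an `archetypax` function.

```python
import numpy as np

def block_means(data, row_labels, col_labels):
    """Average the entries of `data` within each (row cluster, column group) block."""
    row_labels = np.asarray(row_labels)
    col_labels = np.asarray(col_labels)
    row_ids = np.unique(row_labels)
    col_ids = np.unique(col_labels)
    means = np.empty((len(row_ids), len(col_ids)))
    for i, r in enumerate(row_ids):
        for j, c in enumerate(col_ids):
            # np.ix_ selects the sub-matrix for this row cluster and column group
            means[i, j] = data[np.ix_(row_labels == r, col_labels == c)].mean()
    return means

# Toy matrix for illustration: two row clusters and two column groups
toy = np.array([
    [1.0, 1.0, 5.0],
    [1.0, 1.0, 5.0],
    [9.0, 9.0, 0.0],
])
toy_means = block_means(toy, [0, 0, 1], [0, 0, 1])
print(toy_means)
```

In this notebook, `block_means(X, row_labels, col_labels)` would condense the heatmap into a small table with one row per row cluster and one column per feature group.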

## 4. Training Multiple Biarchetypal Models

We'll train a suite of biarchetypal models with varying numbers of row and column archetypes to identify the optimal configuration for our dataset.


In [None]:
# Define the range of archetype numbers to explore
row_archetypes_range = range(2, 6)  # From 2 to 5 row archetypes
col_archetypes_range = range(2, 6)  # From 2 to 5 column archetypes

# Dictionary to store trained models
models_dict = {}

# Train models for each combination of row and column archetypes
for n_row in row_archetypes_range:
    for n_col in col_archetypes_range:
        print(f"Training model with {n_row} row archetypes and {n_col} column archetypes...")

        # Initialize the model with fixed hyperparameters shared across all configurations
        model = BiarchetypalAnalysis(
            n_row_archetypes=n_row,
            n_col_archetypes=n_col,
            max_iter=100,  # Cap iterations to keep training time short
            random_seed=42,
            learning_rate=0.01,
            lambda_reg=0.01,  # Regularization parameter for improved stability
        )

        # Fit the model to our data
        model.fit(X)

        # Store the trained model
        models_dict[(n_row, n_col)] = model

print("Model training complete for all archetype combinations.")

## 5. Initializing and Evaluating with BiarchetypalAnalysisInterpreter

Now we'll leverage the `BiarchetypalAnalysisInterpreter` to systematically evaluate our trained models using multiple interpretability metrics.


In [None]:
# Initialize the interpreter
interpreter = BiarchetypalAnalysisInterpreter()

# Add all trained models to the interpreter
for (n_row, n_col), model in models_dict.items():
    interpreter.add_model(n_row, n_col, model)

# Evaluate all models using comprehensive metrics
results = interpreter.evaluate_all_models(X)

# Calculate information gain for model comparison
interpreter.compute_information_gain(X)

print("Comprehensive evaluation completed for all models.")
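
A generic reconstruction-quality measure that often accompanies this kind of evaluation is explained variance. The sketch below is an assumption-laden illustration: `explained_variance` is a hypothetical helper, and whether `evaluate_all_models` reports this exact quantity is not stated in this tutorial.

```python
import numpy as np

def explained_variance(data, reconstruction):
    """Fraction of variance around the column means captured by the reconstruction."""
    residual = np.sum((data - reconstruction) ** 2)
    total = np.sum((data - data.mean(axis=0)) ** 2)
    return 1.0 - residual / total

rng = np.random.default_rng(0)
data = rng.normal(size=(8, 4))

# Two reference points: a perfect reconstruction and a column-means-only baseline
ev_perfect = explained_variance(data, data)
ev_means = explained_variance(data, np.tile(data.mean(axis=0), (8, 1)))
print(ev_perfect, ev_means)
```

A fitted model's reconstruction would normally fall between these two reference values, with higher values indicating a closer fit.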

## 6. Determining Optimal Archetype Configurations

We'll employ multiple methodologies to identify the optimal number of archetypes, comparing their recommendations.


In [None]:
# Determine optimal configuration using balance method
optimal_balance = interpreter.suggest_optimal_biarchetypes(method="balance")
print(
    f"Optimal configuration via balance method: {optimal_balance[0]} row archetypes, {optimal_balance[1]} column archetypes"
)

# Determine optimal configuration using interpretability method
optimal_interpretability = interpreter.suggest_optimal_biarchetypes(method="interpretability")
print(
    f"Optimal configuration via interpretability method: {optimal_interpretability[0]} row archetypes, {optimal_interpretability[1]} column archetypes"
)

# Determine optimal configuration using information gain method
try:
    optimal_info_gain = interpreter.suggest_optimal_biarchetypes(method="information_gain")
    print(
        f"Optimal configuration via information gain method: {optimal_info_gain[0]} row archetypes, {optimal_info_gain[1]} column archetypes"
    )
except ValueError as e:
    print(f"Information gain method unavailable: {e}")
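
To illustrate what a trade-off between fit and interpretability can look like, here is a sketch of one possible selection rule: min-max normalize both quantities and pick the configuration with the largest difference. This rule, the helper `pick_balanced`, and the toy numbers are all assumptions for illustration; they are not the rule implemented by `suggest_optimal_biarchetypes`.

```python
import numpy as np

def pick_balanced(errors, interpretability):
    """Pick the key maximizing normalized interpretability minus normalized error."""
    keys = sorted(errors)
    err = np.array([errors[k] for k in keys], dtype=float)
    interp = np.array([interpretability[k] for k in keys], dtype=float)

    def min_max(values):
        span = values.max() - values.min()
        return np.zeros_like(values) if span == 0 else (values - values.min()) / span

    score = min_max(interp) - min_max(err)
    return keys[int(np.argmax(score))]

# Toy numbers for illustration only, keyed by (n_row_archetypes, n_col_archetypes)
toy_errors = {(2, 2): 0.9, (3, 3): 0.4, (4, 4): 0.35}
toy_interp = {(2, 2): 0.8, (3, 3): 0.7, (4, 4): 0.3}
best = pick_balanced(toy_errors, toy_interp)
print(best)
```

When different selection methods disagree, comparing their underlying scores side by side in this way can help explain why.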

## 7. Visualizing Interpretability Metrics

Let's create heatmaps to visualize how interpretability metrics vary across different archetype configurations.


In [None]:
# Generate interpretability heatmaps
fig = interpreter.plot_interpretability_heatmap()
plt.tight_layout()
plt.show()
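
If a custom heatmap is needed, scores keyed by `(n_row, n_col)` can be arranged into a grid first. The sketch below assumes the scores are available as a plain dictionary of floats; `scores_to_grid` is a hypothetical helper and the values are toy numbers.

```python
import numpy as np

def scores_to_grid(scores):
    """Arrange scores keyed by (n_row, n_col) into a 2-D array plus its axis values."""
    row_values = sorted({key[0] for key in scores})
    col_values = sorted({key[1] for key in scores})
    grid = np.full((len(row_values), len(col_values)), np.nan)
    for (n_row, n_col), value in scores.items():
        grid[row_values.index(n_row), col_values.index(n_col)] = value
    return grid, row_values, col_values

# Toy scores for illustration only
toy_scores = {(2, 2): 0.1, (2, 3): 0.2, (3, 2): 0.3, (3, 3): 0.4}
grid, row_values, col_values = scores_to_grid(toy_scores)
print(grid)
```

The resulting array can be passed to `sns.heatmap(grid, xticklabels=col_values, yticklabels=row_values, annot=True)`. Configurations missing from the dictionary stay as `NaN` and are left blank in the plot.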

## 8. In-Depth Analysis of the Optimal Model

Now we'll conduct a detailed examination of the model identified as optimal by our evaluation.


In [None]:
# Select the optimal model (using the balance method recommendation)
optimal_model = models_dict[optimal_balance]

# Display comprehensive evaluation metrics for the optimal model
optimal_results = results[optimal_balance]
print("Evaluation metrics for the optimal model:")
for metric, value in optimal_results.items():
    if isinstance(value, (int, float, np.number)):  # np.number also covers NumPy scalars such as float32
        print(f"{metric}: {value:.4f}")
    else:
        print(f"{metric}: {type(value)}")

## 9. Analyzing Feature Distinctiveness and Sparsity

We'll examine how distinctive and sparse each archetype is, providing insights into their interpretability.


In [None]:
# Extract row and column archetypes from the optimal model
row_archetypes, col_archetypes = optimal_model.get_all_archetypes()

# Calculate feature distinctiveness for row archetypes
row_distinctiveness = interpreter.feature_distinctiveness(row_archetypes)
print("Feature distinctiveness scores for row archetypes:")
for i, score in enumerate(row_distinctiveness):
    print(f"Archetype {i + 1}: {score:.4f}")

# Calculate feature distinctiveness for column archetypes
col_distinctiveness = interpreter.feature_distinctiveness(col_archetypes)
print("\nFeature distinctiveness scores for column archetypes:")
for i, score in enumerate(col_distinctiveness):
    print(f"Archetype {i + 1}: {score:.4f}")

# Calculate sparsity coefficients for row archetypes
row_sparsity = interpreter.sparsity_coefficient(row_archetypes)
print("\nSparsity coefficients for row archetypes:")
for i, score in enumerate(row_sparsity):
    print(f"Archetype {i + 1}: {score:.4f}")

# Calculate sparsity coefficients for column archetypes
col_sparsity = interpreter.sparsity_coefficient(col_archetypes)
print("\nSparsity coefficients for column archetypes:")
for i, score in enumerate(col_sparsity):
    print(f"Archetype {i + 1}: {score:.4f}")
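
For intuition about what a sparsity score measures, here is a sketch based on the Gini coefficient of absolute values, a common sparsity measure. It is an assumed stand-in chosen for illustration and may not match the definition used by `sparsity_coefficient`; `gini_sparsity` is a hypothetical helper.

```python
import numpy as np

def gini_sparsity(vector):
    """Gini coefficient of absolute values: 0 when all entries are equal, larger when few dominate."""
    values = np.sort(np.abs(np.asarray(vector, dtype=float)))
    n = values.size
    total = values.sum()
    if total == 0:
        return 0.0
    ranks = np.arange(1, n + 1)
    return float(2.0 * np.sum(ranks * values) / (n * total) - (n + 1) / n)

g_uniform = gini_sparsity([1, 1, 1, 1])   # every feature contributes equally
g_spike = gini_sparsity([0, 0, 0, 1])     # a single feature carries everything
print(g_uniform, g_spike)
```

Assuming each archetype is stored as one row of a 2-D array, `[gini_sparsity(a) for a in row_archetypes]` would give one score per archetype. Sparser archetypes are usually easier to describe because fewer features drive them.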

## 10. Evaluating Cluster Purity

We'll assess how well the identified archetypes correspond to natural clusters in the data.


In [None]:
# Extract weights for rows and columns
row_weights, col_weights = optimal_model.get_all_weights()

# Calculate cluster purity for row weights
row_dominant, row_purity = interpreter.cluster_purity(row_weights)
print(f"Cluster purity for row archetypes: {row_purity:.4f}")

# Calculate cluster purity for column weights
col_dominant, col_purity = interpreter.cluster_purity(col_weights)
print(f"Cluster purity for column archetypes: {col_purity:.4f}")

# Visualize the distribution of dominant archetypes
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
sns.countplot(x=row_dominant)
plt.title("Distribution of Dominant Row Archetypes")
plt.xlabel("Archetype ID")
plt.ylabel("Number of Samples")

plt.subplot(1, 2, 2)
sns.countplot(x=col_dominant)
plt.title("Distribution of Dominant Column Archetypes")
plt.xlabel("Archetype ID")
plt.ylabel("Number of Features")

plt.tight_layout()
plt.show()
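
One simple way to read a purity score is as the average weight each item places on its dominant archetype. The sketch below implements that reading with NumPy. It is an assumption about what such a metric can look like, not the documented behavior of `cluster_purity`; `dominant_and_purity` is a hypothetical helper that expects one item per row.

```python
import numpy as np

def dominant_and_purity(weights):
    """Return each item's dominant archetype and the mean of its largest weight."""
    weights = np.asarray(weights, dtype=float)
    dominant = weights.argmax(axis=1)
    purity = float(weights.max(axis=1).mean())
    return dominant, purity

# Toy weights for illustration: each row sums to 1
toy_weights = np.array([
    [0.9, 0.1, 0.0],
    [0.2, 0.7, 0.1],
    [0.1, 0.1, 0.8],
])
dominant, purity = dominant_and_purity(toy_weights)
print(dominant, purity)
```

Under this reading, a purity near 1 means most items sit close to a single archetype, while lower values indicate items that are genuine mixtures.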

## 11. Advanced Visualization of Biarchetypal Structure

We'll employ the `BiarchetypalAnalysisVisualizer` to create sophisticated visualizations of our model's results.


In [None]:
# Initialize the visualizer; the model is passed to each plotting call
visualizer = BiarchetypalAnalysisVisualizer()

# Generate dual membership heatmap
visualizer.plot_dual_membership_heatmap(model=optimal_model)

## 12. Comparing with Ground Truth

Finally, we'll compare the archetypes identified by our model with the known ground truth clusters.


In [None]:
# Compare dominant row archetypes with true row cluster labels
plt.figure(figsize=(10, 6))
plt.scatter(X[:, 0], X[:, 1], c=row_dominant, cmap="viridis", s=100, alpha=0.7, edgecolors="k")
plt.title("Data Points Colored by Dominant Row Archetype")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.colorbar(label="Archetype ID")
plt.show()

plt.figure(figsize=(10, 6))
plt.scatter(X[:, 0], X[:, 1], c=row_labels, cmap="plasma", s=100, alpha=0.7, edgecolors="k")
plt.title("Data Points Colored by True Cluster Label")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.colorbar(label="Cluster ID")
plt.show()

# Compare dominant column archetypes with true column cluster labels
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# Plot dominant column archetypes
ax1.bar(range(n_features), col_dominant)
ax1.set_title("Dominant Column Archetypes")
ax1.set_xlabel("Feature Index")
ax1.set_ylabel("Archetype ID")

# Plot true column cluster labels
ax2.bar(range(n_features), col_labels)
ax2.set_title("True Column Cluster Labels")
ax2.set_xlabel("Feature Index")
ax2.set_ylabel("Cluster ID")

plt.tight_layout()
plt.show()
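
The visual comparison can be backed by a number. Because archetype IDs and cluster IDs are arbitrary, the sketch below uses a contingency table and a label-permutation-invariant purity score. Both helpers are hypothetical, written with NumPy only, and the toy labels are for illustration.

```python
import numpy as np

def contingency_table(true_labels, predicted_labels):
    """Count co-occurrences of each (true cluster, predicted archetype) pair."""
    true_ids, true_idx = np.unique(true_labels, return_inverse=True)
    pred_ids, pred_idx = np.unique(predicted_labels, return_inverse=True)
    table = np.zeros((true_ids.size, pred_ids.size), dtype=int)
    np.add.at(table, (true_idx, pred_idx), 1)
    return table

def label_purity(true_labels, predicted_labels):
    """Share of items whose true cluster is the majority one within their predicted group."""
    table = contingency_table(true_labels, predicted_labels)
    return table.max(axis=0).sum() / table.sum()

toy_true = np.array([0, 0, 1, 1, 2, 2])
toy_pred = np.array([2, 2, 0, 0, 1, 0])
toy_table = contingency_table(toy_true, toy_pred)
toy_purity = label_purity(toy_true, toy_pred)
print(toy_table)
print(toy_purity)
```

In this notebook, `label_purity(row_labels, row_dominant)` and `label_purity(col_labels, col_dominant)` would give one agreement score per axis. Note that the score tends to rise as the number of archetypes grows, so it is best compared across models with the same configuration.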

## 13. Conclusion and Key Insights

This tutorial has shown how to use the `BiarchetypalAnalysisInterpreter` to evaluate and interpret biarchetypal models. It covered:

1. Generation of synthetic data with dual structure in both rows and columns
2. Training of multiple biarchetypal models with varying archetype configurations
3. Systematic evaluation and selection of optimal archetype numbers
4. Visualization of interpretability metrics across model configurations
5. In-depth analysis of feature distinctiveness and sparsity
6. Assessment of cluster purity and archetype distribution
7. Advanced visualization of biarchetypal structures
8. Comparison with ground truth to validate model performance

Biarchetypal analysis captures structure in observations and features simultaneously, and the `BiarchetypalAnalysisInterpreter` supplies the metrics and plots needed to compare archetype configurations and interpret the selected model.
