In [None]:
import sys
import os

import numpy as np

sys.path.insert(0, os.path.abspath("../"))

from archetypax.models import SparseArchetypalAnalysis, ImprovedArchetypalAnalysis

In [None]:
# Data preparation: a seeded, local Generator keeps this demo reproducible under
# Restart & Run All without touching NumPy's global random state.
rng = np.random.default_rng(42)
X = rng.random((100, 10))  # Sample data (100 samples × 10 features)

# Creating a sparse archetypal analysis model with L1 regularization
# NOTE(review): `normalize` is passed to the constructor here, whereas the comparison
# cell below passes `normalize=True` to `fit()`. Confirm which one archetypax honours.
model = SparseArchetypalAnalysis(
    n_archetypes=3,         # Number of archetypes
    lambda_sparsity=0.1,    # Strength of sparsity constraint
    sparsity_method="l1",   # Method of sparsity ("l1", "l0_approx", "feature_selection")
    max_iter=200,           # Maximum number of iterations
    normalize=True          # Normalize the data
)

# Fitting the model and obtaining the weights that mix the archetypes for each sample
weights = model.fit_transform(X)

# Utilizing the results of the training: rebuild the data from weights × archetypes
# and measure how closely it matches the input.
X_reconstructed = np.dot(weights, model.archetypes)
reconstruction_mse = np.mean((X - X_reconstructed) ** 2)
sparsity_scores = model.get_archetype_sparsity()  # Sparsity scores of the archetypes

print(f"Reconstruction MSE: {reconstruction_mse:.6f}")
print(f"Sparsity scores of the archetypes: {sparsity_scores}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE


def generate_synthetic_data(n_samples=500, n_features=50, n_true_archetypes=5, noise_level=0.1, seed=42):
    """Generate non-negative data mixed from sparse, known archetypes.

    Each true archetype owns a contiguous block of ``n_features // n_true_archetypes``
    features. It activates about half of that block (at least one feature) with values
    in [0.5, 1.0]. It also leaks a few values in [0.1, 0.3] into features outside its
    block. Every archetype is then scaled to unit L2 norm.

    Each sample mixes 1 (p=0.7) or 2 (p=0.3) dominant archetypes with weights in
    [0.5, 1.0], plus weights in [0.01, 0.2] on a random subset of the remaining
    archetypes. The weights are normalised to sum to 1. Gaussian noise is added and
    negative entries are clipped to zero.

    Parameters
    ----------
    n_samples : int
        Number of rows to generate.
    n_features : int
        Number of columns; must be >= ``n_true_archetypes`` so every archetype owns
        at least one feature.
    n_true_archetypes : int
        Number of ground-truth archetypes (>= 1).
    noise_level : float
        Standard deviation of the additive Gaussian noise.
    seed : int or None, default 42
        Seed for a local ``np.random.RandomState``. It draws the same stream as
        ``np.random.seed(seed)`` followed by ``np.random.*`` calls. The global NumPy
        RNG is left untouched.

    Returns
    -------
    X : ndarray, shape (n_samples, n_features)
        Noisy, non-negative observations.
    true_archetypes : ndarray, shape (n_true_archetypes, n_features)
        Unit-norm ground-truth archetypes.
    weights : ndarray, shape (n_samples, n_true_archetypes)
        Ground-truth mixing weights; each row sums to 1.

    Raises
    ------
    ValueError
        If ``n_true_archetypes < 1`` or ``n_features < n_true_archetypes``.
    """
    if n_true_archetypes < 1:
        raise ValueError("n_true_archetypes must be at least 1")
    if n_features < n_true_archetypes:
        raise ValueError("n_features must be >= n_true_archetypes so each archetype owns a feature block")

    # Local RNG: reproducible without clobbering the caller's global np.random state
    rng = np.random.RandomState(seed)

    features_per_archetype = n_features // n_true_archetypes
    all_features = np.arange(n_features)

    # Generate true archetypes (each archetype activates only a few features)
    true_archetypes = np.zeros((n_true_archetypes, n_features))
    for i in range(n_true_archetypes):
        own_block = np.arange(i * features_per_archetype, (i + 1) * features_per_archetype)

        # Activate about half of the block; at least one feature so no row is all zeros
        # (an all-zero row would make the unit-norm scaling below divide by zero).
        n_active = max(1, features_per_archetype // 2)
        active_indices = rng.choice(own_block, size=n_active, replace=False)
        true_archetypes[i, active_indices] = rng.uniform(0.5, 1.0, size=n_active)

        # Add some noisy features (not completely isolated). Capped so sampling without
        # replacement never asks for more features than exist outside the block.
        outside_block = np.setdiff1d(all_features, own_block)
        n_noise = min(features_per_archetype // 5, len(outside_block))
        noise_indices = rng.choice(outside_block, size=n_noise, replace=False)
        true_archetypes[i, noise_indices] = rng.uniform(0.1, 0.3, size=n_noise)

    # Normalize each archetype (norms are > 0 because every row has an active feature)
    true_archetypes = true_archetypes / np.linalg.norm(true_archetypes, axis=1, keepdims=True)

    # Generate weights for each data point
    weights = np.zeros((n_samples, n_true_archetypes))
    archetype_ids = np.arange(n_true_archetypes)
    for i in range(n_samples):
        # Each sample is primarily a combination of 1-2 archetypes (capped by availability)
        n_dominant = min(rng.choice([1, 2], p=[0.7, 0.3]), n_true_archetypes)
        dominant_archetypes = rng.choice(n_true_archetypes, size=n_dominant, replace=False)

        row = np.zeros(n_true_archetypes)
        row[dominant_archetypes] = rng.uniform(0.5, 1.0, size=n_dominant)

        # Assign a small amount of weight to a random subset of the other archetypes
        non_dominant = np.setdiff1d(archetype_ids, dominant_archetypes)
        if len(non_dominant) > 0:
            n_minor = rng.randint(0, len(non_dominant) + 1)
            if n_minor > 0:
                minor_archetypes = rng.choice(non_dominant, size=n_minor, replace=False)
                row[minor_archetypes] = rng.uniform(0.01, 0.2, size=n_minor)

        # Normalize weights (sum > 0 because dominant weights are >= 0.5)
        weights[i] = row / row.sum()

    # Generate data, add Gaussian noise, and clip to keep the data non-negative
    X_clean = np.dot(weights, true_archetypes)
    X = X_clean + rng.normal(0, noise_level, size=X_clean.shape)
    X = np.maximum(X, 0)

    return X, true_archetypes, weights


def _gini_sparsity(archetypes):
    """Return one Gini-style sparsity score per archetype row (higher means sparser).

    Computes ``1 - 2 * sum(cumsum(sort(|a|))) / (n * sum(|a|))`` for each row ``a``.
    This is the expression this notebook previously inlined for the standard model.
    All-zero rows score 0.0 instead of producing NaN from 0/0.

    NOTE(review): this omits the ``+ 1/n`` term of the textbook Gini coefficient, so a
    uniform row scores ``-1/n`` rather than 0. It is presumably kept to match
    ``get_archetype_sparsity()`` in archetypax. Confirm before changing it here alone,
    or the three models' scores stop being comparable.
    """
    magnitudes = np.abs(np.asarray(archetypes, dtype=float))
    scores = np.zeros(magnitudes.shape[0])
    for i, row in enumerate(magnitudes):
        sorted_values = np.sort(row)
        total = sorted_values.sum()
        if total > 0:
            n = len(sorted_values)
            scores[i] = 1 - 2 * np.cumsum(sorted_values).sum() / (n * total)
    return scores


def _match_to_reference(reference, candidates):
    """Return indices that reorder ``candidates`` rows to best match ``reference`` rows.

    Archetypes are only identifiable up to permutation, so learned archetype ``i`` need
    not correspond to true archetype ``i``. Rows are paired by maximum total cosine
    similarity with the Hungarian algorithm.

    Assumes ``len(candidates) >= len(reference)``. The result then has length
    ``len(reference)``, and ``candidates[result[k]]`` is the match for ``reference[k]``.
    """
    from scipy.optimize import linear_sum_assignment

    def _unit_rows(matrix):
        matrix = np.asarray(matrix, dtype=float)
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        return matrix / np.where(norms > 0, norms, 1.0)  # leave all-zero rows as zeros

    similarity = _unit_rows(reference) @ _unit_rows(candidates).T
    # linear_sum_assignment minimises total cost, so negate similarity to maximise it
    _, candidate_idx = linear_sum_assignment(-similarity)
    return candidate_idx


def _plot_data_overview(X, true_archetypes, true_weights):
    """Plot the first 50 samples, the true archetypes, and a 2-D PCA view of X.

    Points in the PCA view are coloured by their largest true weight.
    """
    fig, (ax_data, ax_true) = plt.subplots(1, 2, figsize=(14, 6))
    data_image = ax_data.imshow(X[:50], aspect='auto', cmap='viridis')
    ax_data.set_title("Data Sample (First 50 Entries)")
    fig.colorbar(data_image, ax=ax_data)
    true_image = ax_true.imshow(true_archetypes, aspect='auto', cmap='viridis')
    ax_true.set_title("True Archetypes")
    fig.colorbar(true_image, ax=ax_true)
    fig.tight_layout()
    plt.show()

    X_pca = PCA(n_components=2).fit_transform(X)
    dominant_archetypes = np.argmax(true_weights, axis=1)
    fig, ax = plt.subplots(figsize=(8, 6))
    scatter = ax.scatter(X_pca[:, 0], X_pca[:, 1], c=dominant_archetypes, cmap='viridis', alpha=0.7)
    fig.colorbar(scatter, ax=ax, label='Dominant Archetypes')
    ax.set(
        title="Synthetic Data Visualized with PCA",
        xlabel="First Principal Component",
        ylabel="Second Principal Component",
    )
    ax.grid(True, alpha=0.3)
    plt.show()


def _fit_models(X, n_archetypes):
    """Fit the standard, L1-sparse and feature-selection-sparse models, returned in that order."""
    print(f"\nRunning standard archetypal analysis (number of archetypes = {n_archetypes})...")
    model_standard = ImprovedArchetypalAnalysis(n_archetypes=n_archetypes, random_seed=42)
    model_standard.fit(X, normalize=True)

    sparse_models = []
    for sparsity_method, label in (("l1", "L1"), ("feature_selection", "feature selection")):
        print(f"\nRunning {label} sparse archetypal analysis (number of archetypes = {n_archetypes})...")
        model = SparseArchetypalAnalysis(
            n_archetypes=n_archetypes,
            random_seed=42,
            lambda_sparsity=0.15,
            sparsity_method=sparsity_method
        )
        model.fit(X, normalize=True)
        sparse_models.append(model)

    return [model_standard, *sparse_models]


def _plot_archetype_heatmaps(true_archetypes, aligned_sets, model_names, mean_sparsities):
    """Show a heatmap of the true archetypes, then one heatmap per model.

    Each model's archetypes are plotted in their aligned order.
    """
    fig, axes = plt.subplots(len(aligned_sets) + 1, 1, figsize=(15, 12))
    sns.heatmap(true_archetypes, cmap='viridis', cbar_kws={'label': 'Value'}, ax=axes[0])
    axes[0].set_title("True Archetypes")
    axes[0].set_ylabel("Archetypes")
    for ax, archetypes, name, sparsity in zip(axes[1:], aligned_sets, model_names, mean_sparsities):
        sns.heatmap(archetypes, cmap='viridis', cbar_kws={'label': 'Value'}, ax=ax)
        ax.set_title(f"Archetypes of the {name} (Average Sparsity: {sparsity:.4f})")
        ax.set_ylabel("Archetypes")
    axes[-1].set_xlabel("Features")
    fig.tight_layout()
    plt.show()


def _plot_feature_profiles(true_archetypes, aligned_sets, model_names):
    """Overlay each true archetype with its matched archetype from every model.

    At most 6 panels are drawn, in a 3x2 grid.
    """
    line_styles = ['b-', 'r-', 'g-']
    n_panels = min(len(true_archetypes), 6)  # Display a maximum of 6 archetypes
    fig, axes = plt.subplots(3, 2, figsize=(15, 12))
    panel_axes = axes.flat
    for i in range(n_panels):
        ax = panel_axes[i]
        ax.plot(true_archetypes[i], 'k-', label='True Archetype', alpha=0.7)
        for archetypes, name, style in zip(aligned_sets, model_names, line_styles):
            if i < len(archetypes):
                ax.plot(archetypes[i], style, label=name)
        ax.set_title(f"Feature Distribution of Archetype {i+1}")
        ax.set_xlabel("Features")
        ax.set_ylabel("Value")
        ax.grid(True, alpha=0.3)
        if i == 0:  # Show legend only on the first subplot
            ax.legend()
    for i in range(n_panels, 6):
        panel_axes[i].set_visible(False)  # hide unused grid cells
    fig.tight_layout()
    plt.show()


def _plot_tradeoff(mean_sparsities, mses, model_names):
    """Scatter each model's average sparsity against its reconstruction MSE."""
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(mean_sparsities, mses, s=100)
    for name, sparsity, mse in zip(model_names, mean_sparsities, mses):
        ax.annotate(name, (sparsity, mse), xytext=(5, 5), textcoords='offset points')
    ax.set(
        xlabel='Average Sparsity (Gini Coefficient)',
        ylabel='Reconstruction MSE',
        title='Trade-off between Sparsity and Reconstruction Error',
    )
    ax.grid(True, alpha=0.3)
    plt.show()


def main():
    """Compare standard and sparse archetypal analysis on data with known sparse archetypes.

    Generates synthetic data and fits three models: standard, L1-sparse and
    feature-selection-sparse. Prints each model's mean Gini sparsity and reconstruction
    MSE, and plots archetypes, feature profiles and the sparsity/error trade-off.
    Learned archetypes are aligned to the true ones before side-by-side plots, because
    archetypes are only recovered up to permutation.
    """
    # Generating synthetic data
    print("Generating synthetic data...")
    n_samples = 300
    n_features = 50
    n_true_archetypes = 5
    X, true_archetypes, true_weights = generate_synthetic_data(
        n_samples=n_samples,
        n_features=n_features,
        n_true_archetypes=n_true_archetypes,
        noise_level=0.15
    )
    print(f"Data shape: {X.shape}")

    _plot_data_overview(X, true_archetypes, true_weights)

    # Comparing standard archetypal analysis with sparse archetypal analysis
    n_archetypes = n_true_archetypes  # Set to the true number of archetypes
    models = _fit_models(X, n_archetypes)
    model_names = ['Standard Model', 'L1 Sparse Model', 'Feature Selection Model']
    archetype_sets = [np.asarray(model.archetypes) for model in models]

    # The standard model's score is computed locally; the sparse models expose
    # get_archetype_sparsity().
    sparsities = [
        _gini_sparsity(archetype_sets[0][:n_archetypes]),
        np.asarray(models[1].get_archetype_sparsity()),
        np.asarray(models[2].get_archetype_sparsity()),
    ]
    mean_sparsities = [scores.mean() for scores in sparsities]

    print("\nSparsity scores of the archetypes (Gini coefficient, higher is sparser):")
    for name, sparsity in zip(model_names, mean_sparsities):
        print(f"{name}: {sparsity:.4f}")

    # Reorder each model's archetypes to best match the true ones, so row/panel k
    # compares like with like instead of an arbitrary permutation.
    aligned_sets = [archetypes[_match_to_reference(true_archetypes, archetypes)] for archetypes in archetype_sets]

    _plot_archetype_heatmaps(true_archetypes, aligned_sets, model_names, mean_sparsities)

    # Comparing reconstruction errors
    mses = [np.mean((X - np.asarray(model.reconstruct())) ** 2) for model in models]
    print("\nReconstruction Mean Squared Error (MSE):")
    for name, mse in zip(model_names, mses):
        print(f"{name}: {mse:.6f}")

    _plot_feature_profiles(true_archetypes, aligned_sets, model_names)
    _plot_tradeoff(mean_sparsities, mses, model_names)


main()