# ArchetypAX: Biarchetypal Analysis

This notebook demonstrates the application of Biarchetypal Analysis, a technique that simultaneously learns archetypes for both rows (observations) and columns (features).


## 1. Importing Required Libraries


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_blobs, make_moons
from sklearn.preprocessing import StandardScaler

# Import BiarchetypalAnalysis from archetypax
from archetypax.models.biarchetypes import BiarchetypalAnalysis

## 2. Generating Sample Data


In [None]:
# Draw 300 two-feature points from 3 blobs, then standardize them with StandardScaler
X_raw, y = make_blobs(n_samples=300, centers=3, n_features=2, random_state=42)
X = StandardScaler().fit_transform(X_raw)

# Scatter plot of the standardized points, colored by their blob label
fig, ax = plt.subplots(figsize=(10, 6))
blob_points = ax.scatter(X[:, 0], X[:, 1], c=y, cmap="viridis", alpha=0.7)
ax.set_title("Sample Data (3 Clusters)")
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
fig.colorbar(blob_points, ax=ax, label="Cluster")
ax.grid(alpha=0.3)
plt.show()

## 3. Initializing and Training the Biarchetypal Analysis Model


In [None]:
# Hyperparameters for the blob model, gathered in one place so they are easy to tune
blob_model_params = {
    "n_row_archetypes": 3,  # archetypes over rows (observations)
    "n_col_archetypes": 2,  # archetypes over columns (features)
    "max_iter": 1000,  # upper bound on optimization iterations
    "tol": 1e-6,  # convergence tolerance
    "random_seed": 42,  # fixed seed for reproducibility
    "learning_rate": 0.001,  # optimizer step size
    "lambda_reg": 0.01,  # regularization strength
}
model = BiarchetypalAnalysis(**blob_model_params)

# Fit on the blob data, asking the model to normalize its input during training
model.fit(X, normalize=True)

## 4. Retrieving and Visualizing Row Archetypes and the Biarchetypes Matrix


In [None]:
# Fetch the learned row archetypes and report their shape
row_archetypes = model.get_row_archetypes()
print(f"Shape of row archetypes: {row_archetypes.shape}")

# Overlay the archetypes (red stars) on the gray data cloud
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(X[:, 0], X[:, 1], c="lightgray", alpha=0.5, label="Data points")
ax.scatter(
    row_archetypes[:, 0],
    row_archetypes[:, 1],
    c="red",
    s=100,
    marker="*",
    label="Row archetypes",
)

# Tag each archetype as A1, A2, ... slightly offset from its marker
for idx, point in enumerate(row_archetypes):
    ax.annotate(
        f"A{idx + 1}",
        (point[0], point[1]),
        fontsize=12,
        xytext=(10, 10),
        textcoords="offset points",
    )

ax.set_title("Data and Row Archetypes")
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.legend()
ax.grid(alpha=0.3)
plt.show()

In [None]:
# Fetch the biarchetypes matrix and report its shape
biarchetypes = model.get_biarchetypes()
print(f"Shape of biarchetypes matrix: {biarchetypes.shape}")

# Tick labels: C1.. for column archetypes, R1.. for row archetypes
col_tick_labels = [f"C{idx + 1}" for idx in range(model.n_col_archetypes)]
row_tick_labels = [f"R{idx + 1}" for idx in range(model.n_row_archetypes)]

# Annotated heatmap of the matrix
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    biarchetypes,
    annot=True,
    cmap="viridis",
    fmt=".2f",
    xticklabels=col_tick_labels,
    yticklabels=row_tick_labels,
    ax=ax,
)
ax.set_title("Biarchetypes Matrix")
ax.set_xlabel("Column Archetypes")
ax.set_ylabel("Row Archetypes")
fig.tight_layout()
plt.show()

## 5. Data Transformation and Reconstruction


In [None]:
# transform() yields a pair: weights over row archetypes and weights over column archetypes
row_weights, col_weights = model.transform(X)
print(
    f"Shape of row weights: {row_weights.shape}",
    f"Shape of column weights: {col_weights.shape}",
    sep="\n",
)

In [None]:
# Visualize the row-archetype weights of the first few samples.
# N_PREVIEW is the single knob for how many samples are shown. The slice, the
# y tick labels and the title are all derived from it, so they cannot drift
# apart, and the cell still works when fewer than N_PREVIEW samples exist.
N_PREVIEW = 10
row_weights_preview = row_weights[:N_PREVIEW]
n_shown = len(row_weights_preview)  # may be smaller than N_PREVIEW for tiny datasets

plt.figure(figsize=(10, 5))
sns.heatmap(
    row_weights_preview,
    annot=True,
    cmap="Blues",
    fmt=".2f",
    xticklabels=[f"A{i + 1}" for i in range(model.n_row_archetypes)],
    yticklabels=[f"Sample {i + 1}" for i in range(n_shown)],
)
plt.title(f"Row Archetype Weights for First {n_shown} Samples")
plt.xlabel("Row Archetypes")
plt.ylabel("Samples")
plt.tight_layout()
plt.show()

In [None]:
# Rebuild the data from the fitted model and report the result's shape
X_reconstructed = model.reconstruct(X)
print(f"Shape of reconstructed data: {X_reconstructed.shape}")

# Side-by-side panels: original on the left, reconstruction on the right
fig, (ax_original, ax_rebuilt) = plt.subplots(1, 2, figsize=(15, 6))
panels = (
    (ax_original, X, "Original Data"),
    (ax_rebuilt, X_reconstructed, "Reconstructed Data"),
)
for ax, coords, heading in panels:
    ax.scatter(coords[:, 0], coords[:, 1], c=y, cmap="viridis", alpha=0.7)
    ax.set_title(heading)
    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")
    ax.grid(alpha=0.3)

fig.tight_layout()
plt.show()

## 6. Experiment with a Different Dataset (Moon-shaped Data)


In [None]:
# Two noisy interleaving half-moons (300 points), standardized with StandardScaler
X_moons_raw, y_moons = make_moons(n_samples=300, noise=0.1, random_state=42)
X_moons = StandardScaler().fit_transform(X_moons_raw)

# Scatter plot of the standardized moons, colored by class label
fig, ax = plt.subplots(figsize=(10, 6))
moon_points = ax.scatter(X_moons[:, 0], X_moons[:, 1], c=y_moons, cmap="coolwarm", alpha=0.7)
ax.set_title("Moon-shaped Data")
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
fig.colorbar(moon_points, ax=ax, label="Class")
ax.grid(alpha=0.3)
plt.show()

In [None]:
# Hyperparameters for the moons model; unspecified options keep the library defaults
moons_model_params = {
    "n_row_archetypes": 4,  # one more row archetype than the blob model, for the more complex shape
    "n_col_archetypes": 2,
    "max_iter": 500,
    "random_seed": 42,
}
model_moons = BiarchetypalAnalysis(**moons_model_params)

# Fit on the moons data, asking the model to normalize its input during training
model_moons.fit(X_moons, normalize=True)

In [None]:
# Fetch the row archetypes learned on the moons data
row_archetypes_moons = model_moons.get_row_archetypes()

# Overlay the archetypes (red stars) on the class-colored moons
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(X_moons[:, 0], X_moons[:, 1], c=y_moons, cmap="coolwarm", alpha=0.5)
ax.scatter(
    row_archetypes_moons[:, 0],
    row_archetypes_moons[:, 1],
    c="red",
    s=100,
    marker="*",
    label="Row archetypes",
)

# Tag each archetype as A1, A2, ... slightly offset from its marker
for idx, point in enumerate(row_archetypes_moons):
    ax.annotate(
        f"A{idx + 1}",
        (point[0], point[1]),
        fontsize=12,
        xytext=(10, 10),
        textcoords="offset points",
    )

ax.set_title("Moon-shaped Data and Row Archetypes")
ax.set_xlabel("Feature 1")
ax.set_ylabel("Feature 2")
ax.legend()
ax.grid(alpha=0.3)
plt.show()

In [None]:
# Rebuild the moons data from the fitted model
X_moons_reconstructed = model_moons.reconstruct(X_moons)

# Side-by-side panels: original on the left, reconstruction on the right
fig, (ax_original, ax_rebuilt) = plt.subplots(1, 2, figsize=(15, 6))
panels = (
    (ax_original, X_moons, "Original Moon-shaped Data"),
    (ax_rebuilt, X_moons_reconstructed, "Reconstructed Moon-shaped Data"),
)
for ax, coords, heading in panels:
    ax.scatter(coords[:, 0], coords[:, 1], c=y_moons, cmap="coolwarm", alpha=0.7)
    ax.set_title(heading)
    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")
    ax.grid(alpha=0.3)

fig.tight_layout()
plt.show()

## 7. Application to High-Dimensional Data


In [None]:
# 500 points in 10 dimensions drawn from 5 blobs, standardized with StandardScaler
X_high_raw, y_high = make_blobs(n_samples=500, centers=5, n_features=10, random_state=42)
X_high = StandardScaler().fit_transform(X_high_raw)
print(f"Shape of high-dimensional data: {X_high.shape}")

# Hyperparameters for the high-dimensional model
high_model_params = {
    "n_row_archetypes": 5,  # one per generated cluster
    "n_col_archetypes": 3,  # fewer column archetypes than the 10 input features
    "max_iter": 500,
    "random_seed": 42,
}
model_high = BiarchetypalAnalysis(**high_model_params)

# Fit on the high-dimensional data, asking the model to normalize its input during training
model_high.fit(X_high, normalize=True)

In [None]:
# Fetch the biarchetypes matrix of the high-dimensional model and report its shape
biarchetypes_high = model_high.get_biarchetypes()
print(f"Shape of biarchetypes matrix: {biarchetypes_high.shape}")

# Tick labels: C1.. for column archetypes, R1.. for row archetypes
col_tick_labels_high = [f"C{idx + 1}" for idx in range(model_high.n_col_archetypes)]
row_tick_labels_high = [f"R{idx + 1}" for idx in range(model_high.n_row_archetypes)]

# Annotated heatmap of the matrix
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    biarchetypes_high,
    annot=True,
    cmap="viridis",
    fmt=".2f",
    xticklabels=col_tick_labels_high,
    yticklabels=row_tick_labels_high,
    ax=ax,
)
ax.set_title("Biarchetypes Matrix for High-Dimensional Data")
ax.set_xlabel("Column Archetypes")
ax.set_ylabel("Row Archetypes")
fig.tight_layout()
plt.show()

In [None]:
# Use the row-archetype weights as a low-dimensional embedding; column weights are not needed here
row_weights_high, _unused_col_weights = model_high.transform(X_high)

# Plot each sample by its weights on the first two row archetypes, colored by true cluster
fig, ax = plt.subplots(figsize=(10, 8))
weight_points = ax.scatter(
    row_weights_high[:, 0],
    row_weights_high[:, 1],
    c=y_high,
    cmap="tab10",
    alpha=0.7,
)
ax.set_title("Dimensionality Reduction Using Row Archetype Weights (First 2 Dimensions)")
ax.set_xlabel("Weight of Row Archetype 1")
ax.set_ylabel("Weight of Row Archetype 2")
fig.colorbar(weight_points, ax=ax, label="Class")
ax.grid(alpha=0.3)
plt.show()

## 8. Summary and Discussion


In [None]:
# Key strengths of Biarchetypal Analysis, printed below as a numbered list
advantages = [
    "Simultaneously learns archetypes for both rows and columns",
    "Captures extreme patterns in data",
    "Can be used for both dimensionality reduction and feature extraction",
    "Highly interpretable (each data point is represented as a combination of extremes)",
    "Can capture non-linear patterns (as shown in the moon-shaped data example)",
]

# Example application areas, printed below as a second numbered list
applications = [
    "Customer segmentation (identifying extreme customer profiles)",
    "Image analysis (extracting distinctive patterns)",
    "Text analysis (simultaneous analysis of distinctive documents and words)",
    "Financial data analysis (identifying extreme market conditions)",
    "Biomedical data analysis (linking distinctive patient groups with biomarkers)",
]

# Print each list under its heading; the second heading starts with a blank line
summary_sections = (
    ("Advantages of Biarchetypal Analysis:", advantages),
    ("\nApplications of Biarchetypal Analysis:", applications),
)
for heading, entries in summary_sections:
    print(heading)
    for position, entry in enumerate(entries, start=1):
        print(f"{position}. {entry}")