In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from archetypax.models import ImprovedArchetypalAnalysis
from archetypax.tools import ArchetypalAnalysisEvaluator, ArchetypalAnalysisInterpreter, ArchetypalAnalysisVisualizer

In [None]:
# Synthetic benchmark data: 2000 samples x 30 standard-normal features.
# A dedicated, seeded Generator makes this cell reproducible under
# "Restart & Run All" and leaves NumPy's global random state untouched.
rng = np.random.default_rng(42)
X = rng.standard_normal((2000, 30))

# Train models with varying numbers of archetypes (k = 2 .. 10), keyed by k
models = {}
for k in range(2, 11):
    model = ImprovedArchetypalAnalysis(n_archetypes=k)
    model.fit(X)
    models[k] = model

In [None]:
# Evaluate interpretability of every fitted model in `models` (keyed by number of archetypes)
interpreter = ArchetypalAnalysisInterpreter(models)
interpreter.evaluate_all_models(X)
interpreter.plot_interpretability_metrics()

# Suggest optimal number of archetypes
optimal_k = interpreter.suggest_optimal_archetypes(method="balance")
print(f"Optimal number of archetypes: {optimal_k}")

# Examine detailed interpretability of the best model
# NOTE(review): assumes optimal_k is one of the keys of `models` and that
# evaluate_all_models() filled interpreter.results with the same keys — confirm.
best_model = models[optimal_k]
best_results = interpreter.results[optimal_k]
print(f"Interpretability score: {best_results['interpretability_score']:.4f}")
# 'information_gain' is read with .get(), so a missing key prints 'N/A' instead of raising.
print(f"Information gain: {best_results.get('information_gain', 'N/A')}")

In [None]:
# Usage example
# Build a small 2-D toy dataset from a fixed seed
np.random.seed(42)
n_samples = 66

# Three overlapping Gaussian blobs (std 0.5), one third of the samples each.
# The generator is consumed left to right, so the blobs are drawn in the
# order cluster1, cluster2, cluster3.
points_per_cluster = n_samples // 3
cluster_centers = [np.array([2, 2]), np.array([-2, 2]), np.array([0, -2])]
cluster1, cluster2, cluster3 = (
    np.random.randn(points_per_cluster, 2) * 0.5 + center for center in cluster_centers
)

# Stack the blobs (second blob first), then shuffle the rows in place five times
X = np.vstack([cluster2, cluster1, cluster3])
shuffle_passes = 5
for _ in range(shuffle_passes):
    np.random.shuffle(X)
print(f"Data shape: {X.shape}")

# Fit one model per candidate number of archetypes (2 .. 10), keyed by that number
models = {}
candidate_counts = range(2, 11)
for k in candidate_counts:
    model = ImprovedArchetypalAnalysis(n_archetypes=k)
    model.fit(X)
    models[k] = model

# Score every candidate and plot the interpretability metrics
interpreter = ArchetypalAnalysisInterpreter(models)
interpreter.evaluate_all_models(X)
interpreter.plot_interpretability_metrics()

# Ask the interpreter for its recommended number of archetypes
optimal_k = interpreter.suggest_optimal_archetypes(method="balance")
print(f"Optimal number of archetypes: {optimal_k}")

# Pull out the recommended model and its stored evaluation results
best_model = models[optimal_k]
best_results = interpreter.results[optimal_k]
print(f"Interpretability score: {best_results['interpretability_score']:.4f}")
print(f"Information gain: {best_results.get('information_gain', 'N/A')}")

In [None]:
# Plot results for the selected model.
# The training loop leaves `model` bound to its last candidate (k=10), so the
# plots use `best_model` (the model stored under `optimal_k`) instead.
visualizer = ArchetypalAnalysisVisualizer()
visualizer.plot_loss(best_model)
visualizer.plot_membership_weights(best_model, n_samples=10)
visualizer.plot_archetype_distribution(best_model)
visualizer.plot_archetypes_2d(best_model, X)
visualizer.plot_reconstruction_comparison(best_model, X)

# The simplex plot is only requested for a three-archetype model.
if best_model.n_archetypes == 3:
    visualizer.plot_simplex_2d(best_model)

In [None]:
# Initialize the evaluator on the selected model.
# `model` would be whatever the training loop bound last (k=10); `best_model`
# is the model stored under `optimal_k`.
evaluator = ArchetypalAnalysisEvaluator(best_model)

# Generate comprehensive evaluation report
evaluator.print_evaluation_report(X)

# Calculate individual evaluation metrics
explained_var = evaluator.explained_variance(X)
purity = evaluator.dominant_archetype_purity()

# Visualize high-dimensional data
# Left disabled: `column_names` is not defined anywhere visible in this notebook.
# evaluator.plot_feature_importance_heatmap(feature_names=column_names)
evaluator.plot_archetype_feature_comparison(top_n=10)
evaluator.plot_weight_distributions()
evaluator.plot_distance_matrix()
evaluator.plot_entropy_vs_reconstruction(X)

# Demonstration


In [None]:
import datetime
import warnings

import pandas as pd

# Blanket filter: no category or module is given, so every warning raised after
# this line is ignored.
# NOTE(review): this also hides pandas warnings from the cells below — consider narrowing it.
warnings.filterwarnings("ignore")


# Behavioural parameters for each shopper archetype:
#   visit_mean / visit_std   -> normal distribution of the typical gap (days) between visits
#   basket_size_*            -> normal distribution of the number of line items per receipt
#   weekend_prob / evening_prob -> chance that a visit falls on Sat/Sun / at 17:00 or later
#   health_prob / premium_prob  -> product-pool preferences (only read for "premium_daily")
#   discount_sensitivity     -> chance that a line item is bought at a discount
_ARCHETYPE_PROFILES = {
    "bulk_buyer": {
        "visit_mean": 14,
        "visit_std": 3,
        "basket_size_mean": 15,
        "basket_size_std": 3,
        "weekend_prob": 0.8,
        "evening_prob": 0.3,
        "health_prob": 0.4,
        "premium_prob": 0.3,
        "discount_sensitivity": 0.7,
    },
    "premium_daily": {
        "visit_mean": 3,
        "visit_std": 1,
        "basket_size_mean": 6,
        "basket_size_std": 2,
        "weekend_prob": 0.4,
        "evening_prob": 0.5,
        "health_prob": 0.7,
        "premium_prob": 0.8,
        "discount_sensitivity": 0.2,
    },
    "time_saving": {
        "visit_mean": 4,
        "visit_std": 1,
        "basket_size_mean": 4,
        "basket_size_std": 1,
        "weekend_prob": 0.2,
        "evening_prob": 0.8,
        "health_prob": 0.3,
        "premium_prob": 0.4,
        "discount_sensitivity": 0.5,
    },
    "price_sensitive": {
        "visit_mean": 7,
        "visit_std": 2,
        "basket_size_mean": 8,
        "basket_size_std": 2,
        "weekend_prob": 0.6,
        "evening_prob": 0.4,
        "health_prob": 0.3,
        "premium_prob": 0.1,
        "discount_sensitivity": 0.9,
    },
}

# Profile used when `archetype_type` is not a key of _ARCHETYPE_PROFILES.
_DEFAULT_ARCHETYPE = "price_sensitive"


def generate_customer_transactions(customer_id, n_transactions, start_date, products_df, archetype_type):
    """Simulate the purchase line items of one customer.

    All randomness comes from NumPy's global random state (``np.random``), in
    the same draw order for a given archetype, so seeding it beforehand makes
    the output reproducible.

    Args:
        customer_id: Customer identifier string. Everything after its first
            character is embedded in the generated receipt ids.
        n_transactions: Number of receipts (store visits) to simulate.
        start_date: ``datetime.datetime`` the visit timeline starts from.
        products_df: Product catalogue with the columns ``product_id``,
            ``category_id``, ``standard_price``, ``brand_type`` and
            ``health_index``.
        archetype_type: One of ``"bulk_buyer"``, ``"premium_daily"``,
            ``"time_saving"`` or ``"price_sensitive"``. Any other value is
            treated as ``"price_sensitive"``.

    Returns:
        A list with one dict per purchased line item, holding the keys
        ``customer_id``, ``purchase_datetime``, ``receipt_id``, ``product_id``,
        ``category_id``, ``quantity``, ``amount``, ``discount_amount`` and
        ``standard_price`` (the last three are totals for ``quantity`` units).
    """
    profile = _ARCHETYPE_PROFILES.get(archetype_type, _ARCHETYPE_PROFILES[_DEFAULT_ARCHETYPE])
    is_bulk_buyer = archetype_type == "bulk_buyer"
    is_premium_daily = archetype_type == "premium_daily"
    is_time_saving = archetype_type == "time_saving"

    # Customer-specific typical gap between visits, drawn once per customer.
    visit_frequency = max(1, int(np.random.normal(profile["visit_mean"], profile["visit_std"])))

    # Filtered catalogue views depend only on products_df, so each one is built
    # at most once per call instead of once per line item. Building a view
    # consumes no random numbers, so the random stream is unaffected.
    pool_cache = {}

    def cached_pool(key, build):
        if key not in pool_cache:
            pool_cache[key] = build()
        return pool_cache[key]

    def premium_products():
        premium_threshold = products_df["standard_price"].quantile(0.7)
        return products_df[products_df["standard_price"] >= premium_threshold]

    transactions = []
    current_date = start_date
    for i in range(n_transactions):
        # Advance the timeline by at least one day.
        days_increment = max(1, int(np.random.normal(visit_frequency, visit_frequency / 3)))
        current_date += datetime.timedelta(days=days_increment)

        is_weekend = np.random.random() < profile["weekend_prob"]
        day_of_week = np.random.choice([5, 6]) if is_weekend else np.random.randint(0, 5)

        is_evening = np.random.random() < profile["evening_prob"]
        hour = np.random.randint(17, 23) if is_evening else np.random.randint(9, 17)

        # Move forward (0-6 days) to the next date that falls on the chosen weekday.
        days_to_add = int((day_of_week - current_date.weekday()) % 7)
        txn_date = current_date + datetime.timedelta(days=days_to_add)
        txn_datetime = txn_date.replace(hour=hour, minute=np.random.randint(0, 60))

        receipt_id = f"R{customer_id[1:]}{i:03d}"

        basket_size = max(1, int(np.random.normal(profile["basket_size_mean"], profile["basket_size_std"])))

        for item_idx in range(basket_size):
            if is_bulk_buyer:
                # From the second item on, 40% of picks stay in the previous
                # item's category. The short-circuit keeps the first item from
                # consuming a random draw.
                if item_idx > 0 and np.random.random() < 0.4:
                    cat_focus = transactions[-1]["category_id"]
                    product_pool = cached_pool(
                        ("category", cat_focus),
                        lambda: products_df[products_df["category_id"] == cat_focus],
                    )
                else:
                    product_pool = products_df

            elif is_premium_daily:
                # Premium preference is tried first; the health draw happens
                # only when the premium draw fails.
                if np.random.random() < profile["premium_prob"]:
                    product_pool = cached_pool("premium", premium_products)
                elif np.random.random() < profile["health_prob"]:
                    product_pool = cached_pool(
                        "health", lambda: products_df[products_df["health_index"] >= 0.6]
                    )
                else:
                    product_pool = products_df

            elif is_time_saving:
                if np.random.random() < 0.6:
                    product_pool = cached_pool(
                        "ready_made", lambda: products_df[products_df["category_id"].isin(["CAT07"])]
                    )
                else:
                    product_pool = products_df

            else:  # 'price_sensitive' (and any unrecognised archetype)
                if np.random.random() < 0.6:
                    product_pool = cached_pool(
                        "private_brand", lambda: products_df[products_df["brand_type"] == "PB"]
                    )
                else:
                    product_pool = products_df

            # Fall back to the full catalogue when the preferred view is empty.
            if len(product_pool) == 0:
                product_pool = products_df

            product = product_pool.sample(1).iloc[0]
            base_price = product["standard_price"]

            discount_rate = 0
            if np.random.random() < profile["discount_sensitivity"]:
                discount_rate = np.random.choice([0.1, 0.2, 0.3, 0.5], p=[0.4, 0.3, 0.2, 0.1])

            # The discount is truncated to a whole currency unit per item.
            discount_amount = int(base_price * discount_rate)
            final_price = base_price - discount_amount

            quantity = 1
            if is_bulk_buyer and np.random.random() < 0.3:
                quantity = np.random.randint(2, 5)

            transactions.append({
                "customer_id": customer_id,
                "purchase_datetime": txn_datetime,
                "receipt_id": receipt_id,
                "product_id": product["product_id"],
                "category_id": product["category_id"],
                "quantity": quantity,
                "amount": final_price * quantity,
                "discount_amount": discount_amount * quantity,
                "standard_price": base_price * quantity,
            })

    return transactions


# Size of the simulated retail dataset
n_customers = 1000
n_transactions_per_customer = 30
n_products = 100
n_categories = 10

# Product catalogue. The random columns are drawn in a fixed order
# (categories, prices, brand types, health indices).
product_ids = [f"P{i:03d}" for i in range(1, n_products + 1)]
product_categories = [f"CAT{np.random.randint(1, n_categories + 1):02d}" for _ in range(n_products)]
product_prices = np.random.randint(100, 2000, size=n_products)
product_brands = np.random.choice(["NB", "PB"], size=n_products, p=[0.7, 0.3])
product_health = np.random.rand(n_products)

products = pd.DataFrame({
    "product_id": product_ids,
    "category_id": product_categories,
    "standard_price": product_prices,
    "brand_type": product_brands,
    "health_index": product_health,
})

# Human-readable category names, keyed CAT01 .. CAT10
category_names = [
    "Vegetables & Fruits",
    "Meat & Fish",
    "Dairy Products",
    "Beverages",
    "Confectionery",
    "Condiments",
    "Ready-made Foods",
    "Daily Necessities",
    "Health Foods",
    "Other Foods",
]
category_labels = {f"CAT{number:02d}": label for number, label in enumerate(category_names, start=1)}
products["category_name"] = products["category_id"].map(category_labels)

# Customers, each assigned one of four equally likely ground-truth archetypes
archetypes = ["bulk_buyer", "premium_daily", "time_saving", "price_sensitive"]
archetype_probs = [0.25, 0.25, 0.25, 0.25]

customer_ids = [f"C{i:04d}" for i in range(1, n_customers + 1)]
customers = pd.DataFrame({
    "customer_id": customer_ids,
    "archetype": np.random.choice(archetypes, size=n_customers, p=archetype_probs),
})

all_transactions = []
start_date = datetime.datetime(2023, 1, 1)

print("Generating transaction data...")
for row in customers.itertuples():
    # At least 10 receipts per customer, centred on n_transactions_per_customer
    customer_txn_count = max(10, int(np.random.normal(n_transactions_per_customer, 5)))

    all_transactions.extend(
        generate_customer_transactions(row.customer_id, customer_txn_count, start_date, products, row.archetype)
    )

    completed = row.Index + 1
    if completed % 50 == 0:
        print(f"{completed} / {len(customers)} customer data generation completed")

transactions_df = pd.DataFrame(all_transactions)
print(f"Generation complete: {len(transactions_df)} transactions")

print("\nProducts data sample:")
display(products.head())

print("\nTransaction data sample:")
display(transactions_df.head())


def extract_customer_features(transactions, customer_id):
    """Summarise one customer's line items into a flat feature dict.

    The input frame is never modified: weekday and hour are derived as
    separate boolean masks instead of being written as new columns onto a
    filtered slice of ``transactions``.

    Args:
        transactions: Line-item DataFrame with the columns ``customer_id``,
            ``purchase_datetime``, ``receipt_id``, ``product_id``,
            ``category_id``, ``amount`` and ``discount_amount``.
        customer_id: Value of ``customer_id`` whose rows are summarised.

    Returns:
        dict with ``customer_id`` plus these features:
        ``purchase_frequency`` (receipts per 30 days; 1 when the history spans
        less than one day), ``average_purchase_amount`` and
        ``average_basket_size`` (per receipt; NaN when the customer has no
        rows), ``weekend_purchase_ratio`` and ``evening_purchase_ratio``
        (share of receipts on Sat/Sun / at 18:00 or later),
        ``discounted_item_ratio``, ``average_discount_rate``,
        ``prepared_food_ratio`` (CAT07), ``health_food_ratio`` (CAT09),
        ``product_diversity`` (distinct products / line items) and
        ``purchase_interval_variability`` (population std of the gaps, in
        days, between consecutive receipts).
    """
    customer_txn = transactions[transactions["customer_id"] == customer_id]
    purchase_times = pd.to_datetime(customer_txn["purchase_datetime"])

    n_items = len(customer_txn)
    n_receipts = customer_txn["receipt_id"].nunique()

    features = {"customer_id": customer_id}

    # Whole days between first and last purchase; an empty history has no span.
    date_range = (purchase_times.max() - purchase_times.min()).days if n_items > 0 else 0
    if date_range > 0:
        features["purchase_frequency"] = n_receipts / (date_range / 30)
    else:
        features["purchase_frequency"] = 1

    per_receipt = customer_txn.groupby("receipt_id")
    features["average_purchase_amount"] = per_receipt["amount"].sum().mean()
    features["average_basket_size"] = per_receipt.size().mean()

    # Receipt-level timing ratios (dayofweek: Monday=0 .. Sunday=6).
    on_weekend = purchase_times.dt.dayofweek >= 5
    in_evening = purchase_times.dt.hour >= 18
    weekend_receipts = customer_txn.loc[on_weekend, "receipt_id"].nunique()
    evening_receipts = customer_txn.loc[in_evening, "receipt_id"].nunique()
    features["weekend_purchase_ratio"] = weekend_receipts / n_receipts if n_receipts > 0 else 0
    features["evening_purchase_ratio"] = evening_receipts / n_receipts if n_receipts > 0 else 0

    discount_items = customer_txn[customer_txn["discount_amount"] > 0]
    features["discounted_item_ratio"] = len(discount_items) / n_items if n_items > 0 else 0

    if len(discount_items) > 0:
        # amount is the price paid, so paid + discount restores the list price.
        original_prices = discount_items["amount"] + discount_items["discount_amount"]
        features["average_discount_rate"] = (discount_items["discount_amount"] / original_prices).mean()
    else:
        features["average_discount_rate"] = 0

    prepared_food_items = int((customer_txn["category_id"] == "CAT07").sum())
    health_food_items = int((customer_txn["category_id"] == "CAT09").sum())
    features["prepared_food_ratio"] = prepared_food_items / n_items if n_items > 0 else 0
    features["health_food_ratio"] = health_food_items / n_items if n_items > 0 else 0

    features["product_diversity"] = customer_txn["product_id"].nunique() / n_items if n_items > 0 else 0

    if n_receipts > 1:
        # One timestamp per receipt (its first line item), in chronological order.
        first_item_of_receipt = ~customer_txn["receipt_id"].duplicated()
        purchase_dates = purchase_times[first_item_of_receipt].sort_values()
        intervals = np.diff(purchase_dates.to_numpy()) / np.timedelta64(1, "D")
        features["purchase_interval_variability"] = intervals.std() if len(intervals) > 0 else 0
    else:
        features["purchase_interval_variability"] = 0

    return features


print("\nFeature engineering in progress...")

# Split the transaction table once, so each customer's rows are looked up
# directly instead of re-scanning the whole table for every customer.
transactions_by_customer = {
    grouped_id: customer_rows
    for grouped_id, customer_rows in transactions_df.groupby("customer_id", sort=False)
}
# Zero-row frame with the same columns, used for customers without any transaction.
no_transactions = transactions_df.iloc[0:0]

customer_features_list = []
for i, customer_id in enumerate(customers["customer_id"]):
    customer_rows = transactions_by_customer.get(customer_id, no_transactions)
    features = extract_customer_features(customer_rows, customer_id)
    customer_features_list.append(features)
    if (i + 1) % 50 == 0:
        print(f"{i + 1} / {len(customers)} customer feature extraction completed")

# `customer_id` stays an ordinary column, so the join key is a plain column on
# both sides of the merge.
customer_features_df = pd.DataFrame(customer_features_list).merge(
    customers[["customer_id", "archetype"]], on="customer_id", how="left"
)

print("\nCustomer feature data sample:")
display(customer_features_df.head())

In [None]:
from sklearn.preprocessing import RobustScaler

# Model inputs: every column except the identifier and the ground-truth label
non_feature_cols = ["customer_id", "archetype"]
feature_cols = [name for name in customer_features_df.columns if name not in non_feature_cols]
X = customer_features_df[feature_cols]

corr_matrix = X.corr()
plt.figure(figsize=(10, 8))
sns.heatmap(corr_matrix, annot=True, cmap="coolwarm")
plt.title("Feature Correlation Matrix")
plt.show()

# Flag a feature when it is almost perfectly correlated (|r| > threshold) with
# any feature that appears before it in the column order; only the strictly
# lower triangle of the matrix is inspected.
threshold = 0.99
abs_corr = corr_matrix.abs().to_numpy()
high_corr_features = {
    corr_matrix.columns[row]
    for row in range(abs_corr.shape[0])
    if (abs_corr[row, :row] > threshold).any()
}
print(f"Features with high correlation: {high_corr_features}")

# Drop the flagged features, then scale what is left with RobustScaler
X_reduced = X.drop(columns=list(high_corr_features))
scaler = RobustScaler()
X_scaled = scaler.fit_transform(X_reduced)
print(f"{X.shape[0]=}, {X_reduced.shape[0]=}, {X_scaled.shape[0]=}")

In [None]:
# Fit one model per candidate number of archetypes (2 .. 10) on the scaled
# features; every candidate shares the same optimiser settings.
fit_options = {"max_iter": 500, "tol": 1e-10, "learning_rate": 0.0001}
models = {}
for k in range(2, 11):
    model = ImprovedArchetypalAnalysis(n_archetypes=k, **fit_options)
    model.fit(X_scaled)
    models[k] = model

# Score every candidate and plot the interpretability metrics
interpreter = ArchetypalAnalysisInterpreter(models)
interpreter.evaluate_all_models(X_scaled)
interpreter.plot_interpretability_metrics()

# Ask the interpreter for its recommended number of archetypes
optimal_k = interpreter.suggest_optimal_archetypes(method="balance")
print(f"Optimal number of archetypes: {optimal_k}")

# Pull out the recommended model and its stored evaluation results
best_model = models[optimal_k]
best_results = interpreter.results[optimal_k]
print(f"Interpretability score: {best_results['interpretability_score']:.4f}")
print(f"Information gain: {best_results.get('information_gain', 'N/A')}")

In [None]:
# Refit the selected configuration with a larger iteration budget
model = ImprovedArchetypalAnalysis(n_archetypes=optimal_k, max_iter=5000, tol=1e-10, learning_rate=0.0001)
model.fit(X_scaled)

# Plot results
visualizer = ArchetypalAnalysisVisualizer()
visualizer.plot_loss(model)
visualizer.plot_membership_weights(model, n_samples=10)
visualizer.plot_archetype_distribution(model)

# The model was fit on X_scaled, so the data-dependent plots get that same
# matrix (not the unscaled feature frame X).
# NOTE(review): this guard tests the number of archetypes, not the number of
# features — confirm that is the intended condition for the 2-D plots.
if model.n_archetypes == 2:
    visualizer.plot_archetypes_2d(model, X_scaled)
    visualizer.plot_reconstruction_comparison(model, X_scaled)

if model.n_archetypes == 3:
    visualizer.plot_simplex_2d(model)

In [None]:
# Initialize the evaluator
evaluator = ArchetypalAnalysisEvaluator(model)

# Generate comprehensive evaluation report.
# Every data-dependent metric is computed on X_scaled, the matrix the model
# was fit on, rather than on the unscaled feature frame X.
evaluator.print_evaluation_report(X_scaled)

# Calculate individual evaluation metrics
explained_var = evaluator.explained_variance(X_scaled)
purity = evaluator.dominant_archetype_purity()

# Visualize high-dimensional data (feature names follow the columns kept in X_reduced)
evaluator.plot_feature_importance_heatmap(feature_names=X_reduced.columns)
evaluator.plot_archetype_feature_comparison(top_n=10, feature_names=X_reduced.columns)
evaluator.plot_weight_distributions(bins=100)
evaluator.plot_distance_matrix()
evaluator.plot_entropy_vs_reconstruction(X_scaled)