In [None]:
import os
import sys
import time

import jax
import jax.numpy as jnp
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import optax
from matplotlib import animation
from matplotlib.patches import Polygon
from scipy.spatial import ConvexHull
from sklearn.datasets import make_blobs, make_moons


sys.path.insert(0, os.path.abspath("../"))

from archetypax.models import ImprovedArchetypalAnalysis
from archetypax.models.archetypes import ArchetypeTracker
from archetypax.tools.evaluation import ArchetypalAnalysisEvaluator

In [None]:
def generate_data(n_samples=10000, n_centers=3, noise=0.05, random_seed=42, dataset="moons"):
    """Generate a shuffled 2-D toy dataset for archetypal analysis.

    Args:
        n_samples: Total number of points to generate.
        n_centers: Number of Gaussian clusters; only used when ``dataset="blobs"``.
        noise: Standard deviation of the Gaussian noise added to the moons;
            only used when ``dataset="moons"``.
        random_seed: Seed for the generator and for the shuffle. This also seeds
            NumPy's global RNG as a side effect.
        dataset: ``"moons"`` (default, two interleaving half-circles) or
            ``"blobs"`` (``n_centers`` isotropic Gaussian clusters, std 0.5).

    Returns:
        Tuple ``(X, y)`` where ``X`` has shape ``(n_samples, 2)`` and ``y`` holds the
        integer label of each row; both are shuffled with the same permutation.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Seed the global RNG (kept from the original); later cells may rely on this side effect.
    np.random.seed(random_seed)

    if dataset == "moons":
        X, y = make_moons(n_samples=n_samples, noise=noise, random_state=random_seed)
    elif dataset == "blobs":
        # Every sample gets a label and the requested number of centers is honoured.
        X, y = make_blobs(
            n_samples=n_samples,
            n_features=2,
            centers=n_centers,
            cluster_std=0.5,
            random_state=random_seed,
        )
    else:
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'moons' or 'blobs'.")

    # One permutation applied to both arrays keeps data and labels aligned.
    order = np.random.permutation(len(X))
    return X[order], y[order]


def visualize_data(X, y=None):
    """Scatter-plot the first two feature columns of ``X``.

    Args:
        X: Array whose columns 0 and 1 are drawn.
        y: Optional per-point values used to colour the markers (viridis colormap).
    """
    _, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(X[:, 0], X[:, 1], c=y, cmap="viridis", alpha=0.6)
    ax.set_title("Generated 2D Data")
    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")
    ax.grid(True, linestyle="--", alpha=0.7)
    plt.show()


def visualize_final_archetypes(X, model):
    """Plot the data together with the fitted archetypes and their convex hull.

    Only the first two feature columns are drawn. The hull is computed on those
    same two columns, so the shaded polygon matches what is plotted even when
    the data has more than two features.

    Args:
        X: Data array; columns 0 and 1 are plotted.
        model: Fitted model exposing an ``archetypes`` array with one archetype per row.
    """
    # Convert once (the archetypes may be a non-NumPy array) and keep only the plotted plane.
    archetypes_2d = np.asarray(model.archetypes)[:, :2]
    n_points = len(archetypes_2d)

    _, ax = plt.subplots(figsize=(12, 10))

    # Data points and final archetypes
    ax.scatter(X[:, 0], X[:, 1], alpha=0.5, label="Data Points")
    ax.scatter(archetypes_2d[:, 0], archetypes_2d[:, 1], c="red", s=200, marker="*", label="Archetypes")

    # Annotate archetypes as A1, A2, ...
    for idx, (px, py) in enumerate(archetypes_2d, start=1):
        ax.annotate(f"A{idx}", (px, py), fontsize=12, xytext=(10, 10), textcoords="offset points")

    # Draw the hull of the archetypes: a filled polygon for 3+ points, a segment for 2.
    if n_points > 2:
        try:
            # For 2-D input, ConvexHull.vertices is in counter-clockwise order, i.e. a valid polygon.
            hull = ConvexHull(archetypes_2d)
            hull_vertices = archetypes_2d[hull.vertices]
            ax.fill(hull_vertices[:, 0], hull_vertices[:, 1], alpha=0.2, color="red")
        except Exception as e:
            # Best effort: degenerate (e.g. collinear) archetypes have no 2-D hull.
            print(f"Failed to compute the convex hull: {e}")
    elif n_points == 2:
        ax.plot(archetypes_2d[:, 0], archetypes_2d[:, 1], "r-", alpha=0.2)

    ax.set_title("Relationship Between Data and Archetypes")
    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")
    ax.grid(True, linestyle="--", alpha=0.7)
    ax.legend()
    plt.show()


def visualize_loss_history(model):
    """Plot ``model.loss_history`` against the iteration index on a log-scaled y axis."""
    _, ax = plt.subplots(figsize=(10, 6))
    ax.plot(model.loss_history, linestyle="-", marker="o", markersize=4)
    ax.set_title("Progression of Loss per Iteration")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.grid(True, linestyle="--", alpha=0.7)
    # Log scale makes the late, small improvements visible next to the initial drop.
    ax.set_yscale("log")
    plt.show()


def create_animation(X, model, projection_method, output_dir="./output"):
    """Animate how the archetypes move during training and save the result as a GIF.

    Args:
        X: Data array; columns 0 and 1 are drawn as a static background.
        model: Fitted model exposing ``n_archetypes``, ``archetype_history`` (one
            ``(n_archetypes, n_features)`` array per recorded iteration) and
            ``loss_history``.
        projection_method: Tag embedded in the output filename
            (``archetype_movement_<tag>.gif``).
        output_dir: Directory the GIF is written to; created if it does not exist.

    Returns:
        The ``FuncAnimation`` object (keep a reference so it is not garbage-collected
        before rendering), or ``None`` if there is no history to animate.
    """
    if len(model.archetype_history) == 0:
        print("No archetype history recorded; nothing to animate.")
        return None

    n_archetypes = model.n_archetypes
    # Stack the history once -> (n_frames, n_archetypes, 2). Slicing this array per frame
    # replaces rebuilding Python lists of trail points on every update.
    history = np.asarray(model.archetype_history)[:, :, :2]
    n_frames = len(history)
    start = history[0]

    fig, ax = plt.subplots(figsize=(12, 10))

    # Data points (fixed background)
    ax.scatter(X[:, 0], X[:, 1], alpha=0.3, color="blue")

    # Initial positions of archetypes
    archetype_scatter = ax.scatter(start[:, 0], start[:, 1], c="red", s=200, marker="*")

    # One trail line per archetype, filled in by update()
    trails = [ax.plot([], [], "r-", alpha=0.3)[0] for _ in range(n_archetypes)]

    # Labels for archetypes
    archetype_labels = [
        ax.annotate(
            f"A{i + 1}",
            (start[i, 0], start[i, 1]),
            fontsize=12,
            xytext=(10, 10),
            textcoords="offset points",
        )
        for i in range(n_archetypes)
    ]

    # Hull polygon/line for the current frame; replaced on every update.
    hull_artists = []

    ax.set_title("Movement of Archetypes")
    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")
    ax.grid(True, linestyle="--", alpha=0.7)

    # Text to display iteration number and loss
    iteration_text = ax.text(0.02, 0.95, "", transform=ax.transAxes, fontsize=12)
    loss_text = ax.text(0.02, 0.90, "", transform=ax.transAxes, fontsize=12)

    def update(frame):
        # Clamp the frame index so it never exceeds the recorded history.
        frame = min(frame, n_frames - 1)
        current = history[frame]

        artists = [archetype_scatter, *trails, *archetype_labels, iteration_text, loss_text]

        archetype_scatter.set_offsets(current)

        for i in range(n_archetypes):
            # Trail from the first recorded position up to the current frame.
            trails[i].set_data(history[: frame + 1, i, 0], history[: frame + 1, i, 1])
            # Move the annotated point itself. set_position() would only change the text
            # offset (interpreted in "offset points"), leaving the label at the start.
            archetype_labels[i].xy = (current[i, 0], current[i, 1])

        # Remove the previous frame's hull
        for artist in hull_artists:
            if artist in ax.get_children():
                artist.remove()
        hull_artists.clear()

        # Draw the current hull: polygon for 3+ points, segment for 2.
        if len(current) > 2:
            try:
                hull = ConvexHull(current)
                hull_vertices = current[hull.vertices]
                # Draw on this figure's axes explicitly; plt.fill would target whichever
                # axes happen to be current.
                (poly,) = ax.fill(hull_vertices[:, 0], hull_vertices[:, 1], alpha=0.2, color="red")
                hull_artists.append(poly)
            except Exception as e:
                print(f"Failed to compute the convex hull at frame {frame}: {e}")
        elif len(current) == 2:
            (line,) = ax.plot(current[:, 0], current[:, 1], "r-", alpha=0.2)
            hull_artists.append(line)
        artists.extend(hull_artists)

        iteration_text.set_text(f"Iteration: {frame}")
        if frame < len(model.loss_history):
            loss_text.set_text(f"Loss: {model.loss_history[frame]:.6f}")

        return artists

    ani = animation.FuncAnimation(
        fig,
        update,
        frames=n_frames,
        interval=200,  # Delay between frames in milliseconds
        blit=True,
    )

    # Save the animation (best effort; failures are reported but not raised)
    filename = f"archetype_movement_{projection_method}.gif"
    try:
        print("Saving the animation as a GIF file...")
        os.makedirs(output_dir, exist_ok=True)
        ani.save(os.path.join(output_dir, filename), writer="pillow", fps=5)
        print(f"Animation saved: {filename}")
    except Exception as e:
        print(f"An error occurred while saving the animation: {e}")
        print("Skipping the saving of the animation.")

    plt.show()
    return ani


In [None]:
# Build the toy dataset used throughout this notebook and take a first look at it.
X, y = generate_data()
visualize_data(X, y=y)

In [None]:
# Train one model for every (initialisation, projection) combination and visualise each run.
n_archetypes = 3
projection_methods = ["cbap", "convex_hull", "knn"]
archetype_init_methods = ["directional", "qhull", "kmeans_pp"]

# Hyper-parameters shared by every run, hoisted so all runs are configured identically.
MAX_ITER = 500  # Limit the number of iterations for easier visualization
TOL = 1e-10  # very small tolerance; presumably so runs rarely stop early — confirm against ArchetypeTracker
LEARNING_RATE = 0.0001
PROJECTION_ALPHA = 0.05

# (archetype_init_method, projection_method) -> fitted model, kept for later inspection.
models = {}

for archetype_init_method in archetype_init_methods:
    for projection_method in projection_methods:
        print("--------------------------------")
        print(f"Projection method: {projection_method}")
        print(f"Archetype init method: {archetype_init_method}")
        print("--------------------------------")
        model = ArchetypeTracker(
            n_archetypes=n_archetypes,
            max_iter=MAX_ITER,
            tol=TOL,
            learning_rate=LEARNING_RATE,
            projection_method=projection_method,  # varies per run (loop variable)
            projection_alpha=PROJECTION_ALPHA,
            archetype_init_method=archetype_init_method,
        )

        # Train the model; perf_counter is monotonic and meant for measuring durations.
        start_time = time.perf_counter()
        model.fit(X)
        end_time = time.perf_counter()
        models[(archetype_init_method, projection_method)] = model

        print(f"Training time: {end_time - start_time:.2f} seconds")
        print(f"Number of iterations: {len(model.loss_history)}")
        if len(model.loss_history) > 0:
            print(f"Final loss: {model.loss_history[-1]:.6f}")
        else:
            print("Final loss: n/a (no loss recorded)")

        # Visualize the final archetypes, the loss curve and the archetype trajectories.
        visualize_final_archetypes(X, model)
        visualize_loss_history(model)
        create_animation(X, model, f"{projection_method}_{archetype_init_method}")