In [42]:
import numpy as np
import matplotlib.pyplot as plt
import random
from sklearn.datasets import make_blobs
from clustering import PCKmeans

%matplotlib qt

def _sample_gaussian_clusters(means, covs, points_per_cluster):
    """Draw labelled 2D samples, one Gaussian per (mean, cov) pair.

    Args:
        means: Sequence of cluster means.
        covs: Sequence of covariance matrices, aligned with ``means``.
        points_per_cluster: Number of samples drawn for every cluster.

    Returns:
        Tuple ``(X, y)``: the stacked samples and, for each row, the index of
        the (mean, cov) pair it was drawn from.
    """
    samples = []
    targets = []
    for idx, (mean, cov) in enumerate(zip(means, covs)):
        samples.append(
            np.random.multivariate_normal(mean, cov, points_per_cluster)
        )
        targets.append(np.full(points_per_cluster, idx))
    return np.vstack(samples), np.concatenate(targets)


def _pairwise_constraints(labeled_indices, labels):
    """Build every must-link / cannot-link pair implied by known labels.

    Args:
        labeled_indices: Dataset indices of the labeled points.
        labels: Label of each labeled point, aligned with ``labeled_indices``.

    Returns:
        Tuple ``(must_link_pairs, cannot_link_pairs)`` of lists of
        ``(index_a, index_b)`` tuples expressed in dataset indices.
    """
    from itertools import combinations

    # Must-link: all unordered pairs of labeled points sharing a label,
    # grouped label by label.
    must_link_pairs = []
    for cls in np.unique(labels):
        members = np.where(labels == cls)[0]
        must_link_pairs += [
            (labeled_indices[i], labeled_indices[j])
            for i, j in combinations(members, 2)
        ]

    # Cannot-link: all unordered pairs of labeled points with different labels.
    cannot_link_pairs = [
        (labeled_indices[i], labeled_indices[j])
        for i, j in combinations(range(len(labels)), 2)
        if labels[i] != labels[j]
    ]
    return must_link_pairs, cannot_link_pairs


def test_pckmeans():
    """
    A simple smoke test for PCKmeans.
    1) Creates three synthetic 2D Gaussian clusters.
    2) Samples a few labeled points and derives all must-link and
       cannot-link constraints from their true labels.
    3) Runs PCKmeans and prints the loss, assignments, centers and the
       number of violated constraints.
    """

    # Set random seed for reproducibility
    seed = 42
    np.random.seed(seed)
    random.seed(seed)

    # 1) Generate synthetic data: 3 Gaussian clusters with 2D features.

    # Number of points per cluster
    n = 1000

    # Number of points whose true label is revealed to build constraints
    n_labeled = 30

    # Means for the three clusters (in 2D)
    means = [(3, 0), (6, 2), (2, 6)]

    # Covariance matrices for the three clusters; the last one is
    # stretched along the second axis.
    covs = [
        [[1, 0],
         [0, 1]],
        [[1, 0],
         [0, 1]],
        [[1, 0],
         [0, 7]],
    ]

    X, y_true = _sample_gaussian_clusters(means, covs, n)

    # 2) Define constraints
    # Reveal the true label of a random subset of points. The pool is sized
    # from the data itself so it stays valid if n or the number of clusters
    # changes.
    labeled_indices = np.random.choice(
        np.arange(len(X)), size=n_labeled, replace=False
    )
    labels = y_true[labeled_indices]
    must_link_pairs, cannot_link_pairs = _pairwise_constraints(
        labeled_indices, labels
    )

    constraints = {
        'must_link': must_link_pairs,
        'cannot_link': cannot_link_pairs
    }

    # 3) Instantiate PCKmeans
    # One cluster per generated Gaussian. must_link_weight is set higher than
    # penalty_weight so must-link constraints are weighted more strongly.
    pck = PCKmeans(
        k=len(means),
        device='cpu',
        plot=True,  # Plotting is enabled; fig/axes are passed to cluster() below
        constraints=constraints,
        labeled_indices=labeled_indices,
        max_iter=1,
        penalty_weight=1.0,      # penalty for cannot-link violation
        must_link_weight=5.0,    # stronger penalty for must-link violation
    )

    # 4) Run clustering
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    final_loss = pck.cluster(
        fig=fig,
        axes=axes,
        x_data=X,
        true_labels=y_true,
        epoch=1,
        verbose=True
    )

    # 5) Print results: assignments, final cluster centers, constraint violations
    print("\n===== Test Results =====")
    print(f"Final Loss: {final_loss:.4f}")
    print(f"Labels assigned: {pck.labels_}")
    print(f"Cluster Centers:\n{pck.cluster_centers_}")
    print("Images Lists (cluster -> indices):")
    for cluster_id, indices in enumerate(pck.images_lists):
        print(f"  Cluster {cluster_id}: {indices}")

    # Must-link and cannot-link checks
    # NOTE(review): this relies on a private PCKmeans method; prefer a public
    # accessor if the class offers one.
    must_link_violations, cannot_link_violations = pck._compute_constraint_violations(pck.labels_)
    print(f"\nMust-link violations: {len(must_link_violations)} / {len(must_link_pairs)}")
    print(f"Cannot-link violations: {len(cannot_link_violations)} / {len(cannot_link_pairs)}")

    # Show the figure without blocking, then keep the GUI event loop alive
    # for a few seconds so the plot can render. The figure is left open.
    plt.show(block=False)
    plt.pause(3)

# Entry point: run the smoke test when this file is executed directly.
if __name__ == "__main__":
    test_pckmeans()



Constraint Violation Details:
Must-link violations (0/154)

Cannot-link violations (60/281)
  Pair (2165,1647): both assigned to cluster 1
  Pair (2165,1907): both assigned to cluster 1
  Pair (2165,1353): both assigned to cluster 1
  Pair (2165,1409): both assigned to cluster 1
  Pair (2165,1564): both assigned to cluster 1

Must-link constraint satisfaction: 154/154 (100.0%)
[PCKmeans] Finished in 3.23s, final_loss=12104.9691, #supernodes=2973

===== Test Results =====
Final Loss: 12104.9691
Labels assigned: [0 0 2 ... 1 1 1]
Cluster Centers:
[[3.9468226  1.5666414 ]
 [4.2381186  4.9412184 ]
 [2.2053146  0.21160527]]
Images Lists (cluster -> indices):
  Cluster 0: [0, 1, 3, 10, 15, 20, 27, 29, 33, 35, 36, 41, 43, 45, 48, 53, 56, 59, 61, 62, 70, 72, 74, 75, 78, 81, 82, 83, 87, 88, 89, 93, 97, 100, 101, 103, 105, 106, 110, 117, 121, 124, 126, 128, 129, 133, 135, 136, 140, 141, 142, 148, 149, 151, 153, 154, 155, 156, 157, 158, 160, 161, 163, 164, 167, 175, 180, 182, 185, 186, 187, 189,