In [3]:
# Minimal parity check between your Torch KMeans and scikit-learn's KMeans.
# - Samples n points in R^d with torch
# - Compares only inertia and the first few cluster centers (no alignment, no labels)
#
# Usage:
# 1) Put your class in ./kmeans.py (same folder as the notebook)
# 2) Ensure its imports resolve (utils/distances, etc.), or adjust the import below.
# 3) Make sure scikit-learn is installed: pip install scikit-learn
#
# Feel free to edit n, d, k, seed, and init_method.

import os, sys, importlib
import numpy as np
import torch

# -------- params you may tweak --------
n = 1000
d = 3
k = 5
seed = 0
init_method = "rnd"   # or "rnd" to match your class
num_init = 10
max_iter = 300
tol = 1e-4
first_centers_to_show = 3
# -------------------------------------

torch.manual_seed(seed)

# Sample n points in R^d from a simple Gaussian mixture (uniform cluster prior)
means = torch.randn(k, d) * 4.0
labels = torch.randint(0, k, (n,))
X = means[labels] + 0.5 * torch.randn(n, d)  # (n, d)

# Try importing your class from ./kmeans.py
TorchKMeans = None
LpDistance = None
try:
    if os.path.exists("./kmeans.py"):
        spec = importlib.util.spec_from_file_location("kmeans_local", "./kmeans.py")
        kmeans_local = importlib.util.module_from_spec(spec)
        sys.modules["kmeans_local"] = kmeans_local
        spec.loader.exec_module(kmeans_local)  # type: ignore
        TorchKMeans = kmeans_local.KMeans
        # If your class expects the distance class from its own module, fetch it there
        # (Adjust this if your project structure differs)
        LpDistance = getattr(kmeans_local, "LpDistance", None)
        if LpDistance is None:
            # If not exported, try to access where your class expects it
            try:
                from utils.distances import LpDistance as LpDistance  # type: ignore
            except Exception:
                pass
    else:
        raise FileNotFoundError("Could not find './kmeans.py'. Place your class there.")
except Exception as e:
    print("Error importing your KMeans from './kmeans.py':", e)

# Try importing sklearn KMeans
sklearn_available = True
try:
    from sklearn.cluster import KMeans as SKKMeans
except Exception as e:
    sklearn_available = False
    print("scikit-learn is not available here. Install with: pip install scikit-learn")

# Prepare outputs
results = {}

# ---- Torch KMeans ----
if TorchKMeans is not None and LpDistance is not None:
    try:
        # Your API expects (BS, N, D) with BS=1
        Xt = X.float().unsqueeze(0)  # (1, n, d)
        tkm = TorchKMeans(
            init_method=init_method if init_method in ["k-means++", "rnd"] else "rnd",
            num_init=num_init,
            max_iter=max_iter,
            distance=LpDistance,
            p_norm=2,
            tol=tol,
            normalize=None,
            n_clusters=k,
            verbose=False,
            seed=seed,
        )
        _ = tkm.fit_predict(Xt, k=k)
        inertia_torch = float(tkm._result.inertia[0].cpu().numpy())
        centers_torch = tkm._result.centers[0].cpu().numpy()
        results["torch"] = (inertia_torch, centers_torch)
        print(f"[Torch] inertia = {inertia_torch:.6f}")
        print(f"[Torch] first {first_centers_to_show} centers:\n", centers_torch[:first_centers_to_show])
    except Exception as e:
        print("Error running your Torch KMeans:", e)
else:
    if TorchKMeans is None:
        print("Torch KMeans not available (import issue).")
    if LpDistance is None:
        print("LpDistance not found. Ensure your distances module is importable.")

# ---- sklearn KMeans ----
if sklearn_available:
    try:
        sk = SKKMeans(
            n_clusters=k,
            init="k-means++" if init_method == "k-means++" else "random",
            n_init=num_init,
            max_iter=max_iter,
            tol=tol,
            algorithm="lloyd",
            random_state=seed,
        )
        sk.fit(X.numpy())
        inertia_sk = float(sk.inertia_)
        centers_sk = sk.cluster_centers_
        results["sklearn"] = (inertia_sk, centers_sk)
        print(f"[sklearn] inertia = {inertia_sk:.6f}")
        print(f"[sklearn] first {first_centers_to_show} centers:\n", centers_sk[:first_centers_to_show])
    except Exception as e:
        print("Error running scikit-learn KMeans:", e)

# If both ran, print an immediate delta on inertia (no alignment/labeling)
if "torch" in results and "sklearn" in results:
    inertia_t, _ = results["torch"]
    inertia_s, _ = results["sklearn"]
    rel_gap = abs(inertia_t - inertia_s) / max(1.0, abs(inertia_s))
    print(f"\nRelative inertia gap (|Torch - sklearn| / max(1, sklearn)) = {rel_gap:.6e}")


Error importing your KMeans from './kmeans.py': attempted relative import with no known parent package
Torch KMeans not available (import issue).
LpDistance not found. Ensure your distances module is importable.
[sklearn] inertia = 3545.143066
[sklearn] first 3 centers:
 [[ 2.3362153  -4.318768   -5.521825  ]
 [-1.6123555  -2.374844    0.6924534 ]
 [ 5.9944916  -0.87325835 -8.900952  ]]
