# ✅ Final Integrated Notebook: KAN Ablation Experiments A–G

# Ablation Experiments: A–G

In [49]:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
from scipy.optimize import linear_sum_assignment
from copy import deepcopy


In [50]:

def load_trace_dataset(train_path='Trace_TRAIN.tsv', test_path='Trace_TEST.tsv', normalize=True):
    train_data = np.loadtxt(train_path)
    test_data = np.loadtxt(test_path)
    X = np.concatenate([train_data[:, 1:], test_data[:, 1:]], axis=0)
    y = np.concatenate([train_data[:, 0], test_data[:, 0]], axis=0).astype(int)

    unique_labels = np.unique(y)
    label_map = {old: new for new, old in enumerate(unique_labels)}
    y = np.array([label_map[label] for label in y])

    if normalize:
        X = StandardScaler().fit_transform(X)
    return X, y

X, y = load_trace_dataset()
num_classes = len(np.unique(y))


In [52]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys

sys.path.insert(0, './pykan_local/kan')  # 加入你本地pykan目录

from KANLayer_local import KANLayer # 来自你 pykan 文件夹
all_encoders = {}
class KANEncoder(nn.Module):
    def __init__(self, input_dim=275, hidden_dim=128, output_dim=128):
        super().__init__()
        self.layer1 = KANLayer(in_dim=input_dim, out_dim=hidden_dim, num=21)
        self.layer2 = KANLayer(in_dim=hidden_dim, out_dim=hidden_dim, num=21)
        self.layer3 = KANLayer(in_dim=hidden_dim, out_dim=output_dim, num=21)

    def forward(self, x):
        x, _, _, _ = self.layer1(x)
        x, _, _, _ = self.layer2(x)
        x, _, _, _ = self.layer3(x)
        return F.normalize(x, dim=-1)

In [53]:
import torch.nn as nn

class MLPEncoder(nn.Module):
    def __init__(self, input_dim=275, hidden_dim=128, output_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)


In [8]:

class CenterLoss(nn.Module):
    def __init__(self, num_classes, feat_dim):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, x, labels):
        return ((x - self.centers[labels])**2).sum() / 2.0

class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        device = features.device
        features = F.normalize(features, dim=1)
        sim_matrix = torch.matmul(features, features.T) / self.temperature
        mask = torch.eq(labels.unsqueeze(1), labels.unsqueeze(0)).float().to(device)
        logits_mask = torch.ones_like(mask) - torch.eye(features.shape[0], device=device)
        exp_sim = torch.exp(sim_matrix) * logits_mask
        log_prob = sim_matrix - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-9)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1).clamp(min=1)
        return -mean_log_prob_pos.mean()

def differentiable_kshape_loss(features, labels, num_classes):
    loss = 0
    for c in range(num_classes):
        mask = labels == c
        if mask.sum() < 2:
            continue
        cluster_feat = features[mask]
        prototype = cluster_feat.mean(dim=0)
        aligned = F.cosine_similarity(cluster_feat, prototype.unsqueeze(0))
        loss += (1 - aligned).mean()
    return loss


## Experiment A: {'kan': False, 'supcon': False, 'diversity': False}

In [55]:
use_kan = False
use_supcon = False
use_diversity = False
exp_name = 'A' 

In [56]:
encoder = KANEncoder(input_dim=X.shape[1], output_dim=128) if use_kan else MLPEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

In [57]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

encoder = KANEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

epochs = 20
loss_history = []
ari_history = []
nmi_history = []

for epoch in tqdm(range(epochs)):
    encoder.train()
    total_loss = 0

    for xb, yb in loader:
        optimizer.zero_grad()
        feats = encoder(xb)
        loss = torch.tensor(0.0, requires_grad=True).to(xb.device)

        if use_diversity:
            loss += center_loss_fn(feats, yb) + differentiable_kshape_loss(feats, yb, num_classes)
        if use_supcon:
            loss += supcon_loss_fn(feats, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    loss_history.append(total_loss)

   # ===== ✅ 聚类评估指标 =====
    encoder.eval()
    with torch.no_grad():
        feats_all = encoder(X_tensor).cpu().numpy()
        true_labels = y_tensor.cpu().numpy()

    kmeans = KMeans(n_clusters=num_classes, n_init=20, random_state=42)
    pred_labels = kmeans.fit_predict(feats_all)

    ari = adjusted_rand_score(true_labels, pred_labels)
    nmi = normalized_mutual_info_score(true_labels, pred_labels)
    sil = silhouette_score(feats_all, pred_labels)

    ari_history.append(ari)
    nmi_history.append(nmi)

    print(f"Epoch {epoch} | Loss: {total_loss:.4f} | ARI: {ari:.4f} | NMI: {nmi:.4f} | Silhouette: {sil:.4f}")
    # 保存当前实验的最终 encoder
all_encoders[exp_name] = deepcopy(encoder)


  5%|▌         | 1/20 [00:01<00:35,  1.89s/it]

Epoch 0 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 10%|█         | 2/20 [00:04<00:40,  2.25s/it]

Epoch 1 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 15%|█▌        | 3/20 [00:07<00:42,  2.50s/it]

Epoch 2 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 20%|██        | 4/20 [00:09<00:39,  2.46s/it]

Epoch 3 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 25%|██▌       | 5/20 [00:11<00:36,  2.40s/it]

Epoch 4 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 30%|███       | 6/20 [00:14<00:34,  2.44s/it]

Epoch 5 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 35%|███▌      | 7/20 [00:16<00:31,  2.45s/it]

Epoch 6 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 40%|████      | 8/20 [00:19<00:29,  2.44s/it]

Epoch 7 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 45%|████▌     | 9/20 [00:21<00:27,  2.49s/it]

Epoch 8 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 50%|█████     | 10/20 [00:24<00:24,  2.43s/it]

Epoch 9 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 55%|█████▌    | 11/20 [00:26<00:22,  2.54s/it]

Epoch 10 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 60%|██████    | 12/20 [00:29<00:20,  2.62s/it]

Epoch 11 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 65%|██████▌   | 13/20 [00:32<00:17,  2.50s/it]

Epoch 12 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 70%|███████   | 14/20 [00:34<00:15,  2.56s/it]

Epoch 13 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 75%|███████▌  | 15/20 [00:37<00:12,  2.48s/it]

Epoch 14 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 80%|████████  | 16/20 [00:39<00:09,  2.40s/it]

Epoch 15 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 85%|████████▌ | 17/20 [00:41<00:07,  2.48s/it]

Epoch 16 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 90%|█████████ | 18/20 [00:44<00:05,  2.52s/it]

Epoch 17 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


 95%|█████████▌| 19/20 [00:46<00:02,  2.39s/it]

Epoch 18 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079


100%|██████████| 20/20 [00:48<00:00,  2.42s/it]

Epoch 19 | Loss: 0.0000 | ARI: 0.5820 | NMI: 0.6716 | Silhouette: 0.3079





## Experiment B: {'kan': True, 'supcon': False, 'diversity': False}

In [58]:
use_kan = True
use_supcon = False
use_diversity = False
exp_name = 'B' 

In [59]:
encoder = KANEncoder(input_dim=X.shape[1], output_dim=128) if use_kan else MLPEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

In [60]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

encoder = KANEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

epochs = 30
loss_history = []
ari_history = []
nmi_history = []

for epoch in tqdm(range(epochs)):
    encoder.train()
    total_loss = 0

    for xb, yb in loader:
        optimizer.zero_grad()
        feats = encoder(xb)
        loss = torch.tensor(0.0, requires_grad=True).to(xb.device)
        if use_diversity:
            loss += center_loss_fn(feats, yb) + differentiable_kshape_loss(feats, yb, num_classes)
        if use_supcon:
            loss += supcon_loss_fn(feats, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    loss_history.append(total_loss)

   # ===== ✅ 聚类评估指标 =====
    encoder.eval()
    with torch.no_grad():
        feats_all = encoder(X_tensor).cpu().numpy()
        true_labels = y_tensor.cpu().numpy()

    kmeans = KMeans(n_clusters=num_classes, n_init=20, random_state=42)
    pred_labels = kmeans.fit_predict(feats_all)

    ari = adjusted_rand_score(true_labels, pred_labels)
    nmi = normalized_mutual_info_score(true_labels, pred_labels)
    sil = silhouette_score(feats_all, pred_labels)

    ari_history.append(ari)
    nmi_history.append(nmi)

    print(f"Epoch {epoch} | Loss: {total_loss:.4f} | ARI: {ari:.4f} | NMI: {nmi:.4f} | Silhouette: {sil:.4f}")
    # 保存当前实验的最终 encoder
all_encoders[exp_name] = deepcopy(encoder)


  3%|▎         | 1/30 [00:00<00:23,  1.26it/s]

Epoch 0 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


  7%|▋         | 2/30 [00:02<00:30,  1.08s/it]

Epoch 1 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 10%|█         | 3/30 [00:03<00:29,  1.10s/it]

Epoch 2 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 13%|█▎        | 4/30 [00:04<00:32,  1.25s/it]

Epoch 3 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 17%|█▋        | 5/30 [00:05<00:28,  1.13s/it]

Epoch 4 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 20%|██        | 6/30 [00:07<00:33,  1.39s/it]

Epoch 5 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 23%|██▎       | 7/30 [00:09<00:34,  1.50s/it]

Epoch 6 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 27%|██▋       | 8/30 [00:10<00:29,  1.34s/it]

Epoch 7 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 30%|███       | 9/30 [00:11<00:27,  1.29s/it]

Epoch 8 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 33%|███▎      | 10/30 [00:13<00:32,  1.61s/it]

Epoch 9 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 37%|███▋      | 11/30 [00:15<00:29,  1.54s/it]

Epoch 10 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 40%|████      | 12/30 [00:16<00:28,  1.56s/it]

Epoch 11 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 43%|████▎     | 13/30 [00:17<00:24,  1.45s/it]

Epoch 12 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 47%|████▋     | 14/30 [00:18<00:20,  1.26s/it]

Epoch 13 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 50%|█████     | 15/30 [00:19<00:18,  1.27s/it]

Epoch 14 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 53%|█████▎    | 16/30 [00:21<00:17,  1.28s/it]

Epoch 15 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 57%|█████▋    | 17/30 [00:22<00:17,  1.35s/it]

Epoch 16 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 60%|██████    | 18/30 [00:24<00:17,  1.42s/it]

Epoch 17 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 63%|██████▎   | 19/30 [00:25<00:14,  1.33s/it]

Epoch 18 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 67%|██████▋   | 20/30 [00:26<00:12,  1.25s/it]

Epoch 19 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 70%|███████   | 21/30 [00:27<00:11,  1.24s/it]

Epoch 20 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 73%|███████▎  | 22/30 [00:28<00:09,  1.14s/it]

Epoch 21 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 77%|███████▋  | 23/30 [00:29<00:08,  1.15s/it]

Epoch 22 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 80%|████████  | 24/30 [00:30<00:06,  1.14s/it]

Epoch 23 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 83%|████████▎ | 25/30 [00:33<00:07,  1.40s/it]

Epoch 24 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 87%|████████▋ | 26/30 [00:36<00:08,  2.06s/it]

Epoch 25 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 90%|█████████ | 27/30 [00:40<00:07,  2.58s/it]

Epoch 26 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 93%|█████████▎| 28/30 [00:42<00:05,  2.52s/it]

Epoch 27 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


 97%|█████████▋| 29/30 [00:45<00:02,  2.50s/it]

Epoch 28 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179


100%|██████████| 30/30 [00:47<00:00,  1.58s/it]

Epoch 29 | Loss: 0.0000 | ARI: 0.3407 | NMI: 0.5088 | Silhouette: 0.3179





## Experiment C: {'kan': True, 'supcon': True, 'diversity': False}

In [61]:
use_kan = True
use_supcon = True
use_diversity = False
exp_name = 'C' 

In [62]:
encoder = KANEncoder(input_dim=X.shape[1], output_dim=128) if use_kan else MLPEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

In [63]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

encoder = KANEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

epochs = 50
loss_history = []
ari_history = []
nmi_history = []

for epoch in tqdm(range(epochs)):
    encoder.train()
    total_loss = 0

for xb, yb in loader:
    optimizer.zero_grad()
    feats = encoder(xb)
    
    loss = 0
    if use_diversity:
        loss = loss + center_loss_fn(feats, yb) + differentiable_kshape_loss(feats, yb, num_classes)
    if use_supcon:
        loss = loss + supcon_loss_fn(feats, yb)
    
    if not isinstance(loss, torch.Tensor):
        loss = torch.tensor(loss, dtype=torch.float32, requires_grad=True).to(xb.device)
    
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

    loss_history.append(total_loss)

   # ===== ✅ 聚类评估指标 =====
    encoder.eval()
    with torch.no_grad():
        feats_all = encoder(X_tensor).cpu().numpy()
        true_labels = y_tensor.cpu().numpy()

    kmeans = KMeans(n_clusters=num_classes, n_init=20, random_state=42)
    pred_labels = kmeans.fit_predict(feats_all)

    ari = adjusted_rand_score(true_labels, pred_labels)
    nmi = normalized_mutual_info_score(true_labels, pred_labels)
    sil = silhouette_score(feats_all, pred_labels)

    ari_history.append(ari)
    nmi_history.append(nmi)

    print(f"Epoch {epoch} | Loss: {total_loss:.4f} | ARI: {ari:.4f} | NMI: {nmi:.4f} | Silhouette: {sil:.4f}")
    # 保存当前实验的最终 encoder
all_encoders[exp_name] = deepcopy(encoder)


100%|██████████| 50/50 [00:00<00:00, 71114.00it/s]


Epoch 49 | Loss: 4.5410 | ARI: 0.5452 | NMI: 0.7115 | Silhouette: 0.4965
Epoch 49 | Loss: 7.1088 | ARI: 0.6084 | NMI: 0.7519 | Silhouette: 0.5409
Epoch 49 | Loss: 9.7721 | ARI: 0.5992 | NMI: 0.6909 | Silhouette: 0.2943
Epoch 49 | Loss: 12.3359 | ARI: 0.6230 | NMI: 0.7090 | Silhouette: 0.3355
Epoch 49 | Loss: 14.9533 | ARI: 0.6491 | NMI: 0.7332 | Silhouette: 0.3797
Epoch 49 | Loss: 17.4355 | ARI: 0.6495 | NMI: 0.7336 | Silhouette: 0.4214
Epoch 49 | Loss: 18.0994 | ARI: 0.6503 | NMI: 0.7345 | Silhouette: 0.4639


## Experiment D: {'kan': True, 'supcon': False, 'diversity': True}

In [64]:
use_kan = True
use_supcon = False
use_diversity = True
exp_name = 'D'  

In [65]:
encoder = KANEncoder(input_dim=X.shape[1], output_dim=128) if use_kan else MLPEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

In [66]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

encoder = KANEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

epochs = 10
loss_history = []
ari_history = []
nmi_history = []

for epoch in tqdm(range(epochs)):
    encoder.train()
    total_loss = 0

    for xb, yb in loader:
        optimizer.zero_grad()
        feats = encoder(xb)
        loss = 0
        if use_diversity:
            loss += center_loss_fn(feats, yb) + differentiable_kshape_loss(feats, yb, num_classes)
        if use_supcon:
            loss += supcon_loss_fn(feats, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    loss_history.append(total_loss)

   # ===== ✅ 聚类评估指标 =====
    encoder.eval()
    with torch.no_grad():
        feats_all = encoder(X_tensor).cpu().numpy()
        true_labels = y_tensor.cpu().numpy()

    kmeans = KMeans(n_clusters=num_classes, n_init=20, random_state=42)
    pred_labels = kmeans.fit_predict(feats_all)

    ari = adjusted_rand_score(true_labels, pred_labels)
    nmi = normalized_mutual_info_score(true_labels, pred_labels)
    sil = silhouette_score(feats_all, pred_labels)

    ari_history.append(ari)
    nmi_history.append(nmi)

    print(f"Epoch {epoch} | Loss: {total_loss:.4f} | ARI: {ari:.4f} | NMI: {nmi:.4f} | Silhouette: {sil:.4f}")
    # 保存当前实验的最终 encoder
all_encoders[exp_name] = deepcopy(encoder)


 10%|█         | 1/10 [00:04<00:37,  4.20s/it]

Epoch 0 | Loss: 13255.1648 | ARI: 0.6144 | NMI: 0.7030 | Silhouette: 0.4572


 20%|██        | 2/10 [00:08<00:34,  4.26s/it]

Epoch 1 | Loss: 12632.9910 | ARI: 0.6849 | NMI: 0.7763 | Silhouette: 0.6162


 30%|███       | 3/10 [00:12<00:28,  4.00s/it]

Epoch 2 | Loss: 12498.0450 | ARI: 0.7017 | NMI: 0.7952 | Silhouette: 0.6629


 40%|████      | 4/10 [00:15<00:23,  3.92s/it]

Epoch 3 | Loss: 12430.6806 | ARI: 0.8045 | NMI: 0.8686 | Silhouette: 0.6745


 50%|█████     | 5/10 [00:19<00:19,  3.88s/it]

Epoch 4 | Loss: 12368.0454 | ARI: 0.9481 | NMI: 0.9491 | Silhouette: 0.7493


 60%|██████    | 6/10 [00:23<00:14,  3.72s/it]

Epoch 5 | Loss: 12264.0778 | ARI: 0.9358 | NMI: 0.9404 | Silhouette: 0.8148


 70%|███████   | 7/10 [00:28<00:12,  4.20s/it]

Epoch 6 | Loss: 12139.8682 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.8792


 80%|████████  | 8/10 [00:32<00:08,  4.21s/it]

Epoch 7 | Loss: 12104.2094 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.9083


 90%|█████████ | 9/10 [00:37<00:04,  4.44s/it]

Epoch 8 | Loss: 12086.8430 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.9310


100%|██████████| 10/10 [00:41<00:00,  4.19s/it]

Epoch 9 | Loss: 12076.1656 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.9410





## Experiment E: {'kan': False, 'supcon': True, 'diversity': True}

In [67]:
use_kan = False
use_supcon = True
use_diversity = True
exp_name = 'E' 

In [68]:
encoder = KANEncoder(input_dim=X.shape[1], output_dim=128) if use_kan else MLPEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

In [69]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

encoder = KANEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

epochs = 10
loss_history = []
ari_history = []
nmi_history = []

for epoch in tqdm(range(epochs)):
    encoder.train()
    total_loss = 0

    for xb, yb in loader:
        optimizer.zero_grad()
        feats = encoder(xb)
        loss = 0
        if use_diversity:
            loss += center_loss_fn(feats, yb) + differentiable_kshape_loss(feats, yb, num_classes)
        if use_supcon:
            loss += supcon_loss_fn(feats, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    loss_history.append(total_loss)

   # ===== ✅ 聚类评估指标 =====
    encoder.eval()
    with torch.no_grad():
        feats_all = encoder(X_tensor).cpu().numpy()
        true_labels = y_tensor.cpu().numpy()

    kmeans = KMeans(n_clusters=num_classes, n_init=20, random_state=42)
    pred_labels = kmeans.fit_predict(feats_all)

    ari = adjusted_rand_score(true_labels, pred_labels)
    nmi = normalized_mutual_info_score(true_labels, pred_labels)
    sil = silhouette_score(feats_all, pred_labels)

    ari_history.append(ari)
    nmi_history.append(nmi)

    print(f"Epoch {epoch} | Loss: {total_loss:.4f} | ARI: {ari:.4f} | NMI: {nmi:.4f} | Silhouette: {sil:.4f}")
    # 保存当前实验的最终 encoder
all_encoders[exp_name] = deepcopy(encoder)


 10%|█         | 1/10 [00:03<00:30,  3.35s/it]

Epoch 0 | Loss: 11855.4006 | ARI: 0.5676 | NMI: 0.6707 | Silhouette: 0.4382


 20%|██        | 2/10 [00:07<00:32,  4.09s/it]

Epoch 1 | Loss: 11225.8594 | ARI: 0.6709 | NMI: 0.7609 | Silhouette: 0.5938


 30%|███       | 3/10 [00:12<00:28,  4.13s/it]

Epoch 2 | Loss: 11036.0498 | ARI: 0.6367 | NMI: 0.8051 | Silhouette: 0.7507


 40%|████      | 4/10 [00:17<00:27,  4.64s/it]

Epoch 3 | Loss: 10969.6244 | ARI: 0.6658 | NMI: 0.8200 | Silhouette: 0.7974


 50%|█████     | 5/10 [00:21<00:22,  4.56s/it]

Epoch 4 | Loss: 10935.7238 | ARI: 0.9358 | NMI: 0.9404 | Silhouette: 0.6693


 60%|██████    | 6/10 [00:28<00:21,  5.32s/it]

Epoch 5 | Loss: 10882.1548 | ARI: 0.9606 | NMI: 0.9587 | Silhouette: 0.7814


 70%|███████   | 7/10 [00:35<00:17,  5.73s/it]

Epoch 6 | Loss: 10789.1786 | ARI: 0.9606 | NMI: 0.9587 | Silhouette: 0.8481


 80%|████████  | 8/10 [00:40<00:11,  5.63s/it]

Epoch 7 | Loss: 10690.4960 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.8990


 90%|█████████ | 9/10 [00:44<00:05,  5.15s/it]

Epoch 8 | Loss: 10665.6438 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.9142


100%|██████████| 10/10 [00:50<00:00,  5.07s/it]

Epoch 9 | Loss: 10650.2880 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.9262





## Experiment F: {'kan': False, 'supcon': True, 'diversity': False}

In [70]:
use_kan = False
use_supcon = True
use_diversity = False
exp_name = 'F' 

In [71]:
encoder = KANEncoder(input_dim=X.shape[1], output_dim=128) if use_kan else MLPEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

In [72]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

encoder = KANEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

epochs = 20
loss_history = []
ari_history = []
nmi_history = []

for epoch in tqdm(range(epochs)):
    encoder.train()
    total_loss = 0

    for xb, yb in loader:
        optimizer.zero_grad()
        feats = encoder(xb)
        loss = 0
        if use_diversity:
            loss += center_loss_fn(feats, yb) + differentiable_kshape_loss(feats, yb, num_classes)
        if use_supcon:
            loss += supcon_loss_fn(feats, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    loss_history.append(total_loss)

   # ===== ✅ 聚类评估指标 =====
    encoder.eval()
    with torch.no_grad():
        feats_all = encoder(X_tensor).cpu().numpy()
        true_labels = y_tensor.cpu().numpy()

    kmeans = KMeans(n_clusters=num_classes, n_init=20, random_state=42)
    pred_labels = kmeans.fit_predict(feats_all)

    ari = adjusted_rand_score(true_labels, pred_labels)
    nmi = normalized_mutual_info_score(true_labels, pred_labels)
    sil = silhouette_score(feats_all, pred_labels)

    ari_history.append(ari)
    nmi_history.append(nmi)

    print(f"Epoch {epoch} | Loss: {total_loss:.4f} | ARI: {ari:.4f} | NMI: {nmi:.4f} | Silhouette: {sil:.4f}")
    # 保存当前实验的最终 encoder
all_encoders[exp_name] = deepcopy(encoder)


  5%|▌         | 1/20 [00:04<01:26,  4.53s/it]

Epoch 0 | Loss: 17.4956 | ARI: 0.6250 | NMI: 0.7112 | Silhouette: 0.4862


 10%|█         | 2/20 [00:09<01:29,  4.98s/it]

Epoch 1 | Loss: 14.0047 | ARI: 0.6655 | NMI: 0.7525 | Silhouette: 0.5902


 15%|█▌        | 3/20 [00:15<01:28,  5.21s/it]

Epoch 2 | Loss: 14.4954 | ARI: 0.6658 | NMI: 0.8200 | Silhouette: 0.7010


 20%|██        | 4/20 [00:17<01:05,  4.10s/it]

Epoch 3 | Loss: 12.9096 | ARI: 0.6645 | NMI: 0.7534 | Silhouette: 0.5919


 25%|██▌       | 5/20 [00:19<00:49,  3.27s/it]

Epoch 4 | Loss: 13.5182 | ARI: 0.6645 | NMI: 0.7534 | Silhouette: 0.5579


 30%|███       | 6/20 [00:21<00:39,  2.79s/it]

Epoch 5 | Loss: 12.7036 | ARI: 0.6644 | NMI: 0.7530 | Silhouette: 0.5694


 35%|███▌      | 7/20 [00:23<00:32,  2.51s/it]

Epoch 6 | Loss: 13.7916 | ARI: 0.6343 | NMI: 0.8040 | Silhouette: 0.6792


 40%|████      | 8/20 [00:24<00:25,  2.10s/it]

Epoch 7 | Loss: 12.7886 | ARI: 0.6963 | NMI: 0.7689 | Silhouette: 0.5587


 45%|████▌     | 9/20 [00:26<00:22,  2.00s/it]

Epoch 8 | Loss: 12.4725 | ARI: 0.6780 | NMI: 0.7588 | Silhouette: 0.5633


 50%|█████     | 10/20 [00:27<00:18,  1.88s/it]

Epoch 9 | Loss: 12.2307 | ARI: 0.8127 | NMI: 0.8724 | Silhouette: 0.5444


 55%|█████▌    | 11/20 [00:29<00:15,  1.67s/it]

Epoch 10 | Loss: 13.7762 | ARI: 0.9122 | NMI: 0.9251 | Silhouette: 0.6134


 60%|██████    | 12/20 [00:31<00:14,  1.80s/it]

Epoch 11 | Loss: 12.8877 | ARI: 0.9606 | NMI: 0.9587 | Silhouette: 0.6778


 65%|██████▌   | 13/20 [00:32<00:11,  1.68s/it]

Epoch 12 | Loss: 13.1178 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.6672


 70%|███████   | 14/20 [00:34<00:09,  1.60s/it]

Epoch 13 | Loss: 12.0578 | ARI: 0.9481 | NMI: 0.9491 | Silhouette: 0.6446


 75%|███████▌  | 15/20 [00:35<00:07,  1.51s/it]

Epoch 14 | Loss: 9.1200 | ARI: 0.9735 | NMI: 0.9696 | Silhouette: 0.6494


 80%|████████  | 16/20 [00:36<00:05,  1.45s/it]

Epoch 15 | Loss: 11.5600 | ARI: 0.9238 | NMI: 0.9324 | Silhouette: 0.6410


 85%|████████▌ | 17/20 [00:38<00:04,  1.43s/it]

Epoch 16 | Loss: 11.5765 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.6823


 90%|█████████ | 18/20 [00:39<00:02,  1.42s/it]

Epoch 17 | Loss: 9.9074 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.6662


 95%|█████████▌| 19/20 [00:41<00:01,  1.56s/it]

Epoch 18 | Loss: 10.7485 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.6526


100%|██████████| 20/20 [00:42<00:00,  2.15s/it]

Epoch 19 | Loss: 11.7168 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.6595





## Experiment G: {'kan': False, 'supcon': False, 'diversity': True}

In [73]:
use_kan = False
use_supcon = False
use_diversity = True
exp_name = 'G' 

In [74]:
encoder = KANEncoder(input_dim=X.shape[1], output_dim=128) if use_kan else MLPEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

In [75]:
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

encoder = KANEncoder(input_dim=X.shape[1], output_dim=128)
center_loss_fn = CenterLoss(num_classes=num_classes, feat_dim=128)
supcon_loss_fn = SupConLoss()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

epochs = 10
loss_history = []
ari_history = []
nmi_history = []

for epoch in tqdm(range(epochs)):
    encoder.train()
    total_loss = 0

    for xb, yb in loader:
        optimizer.zero_grad()
        feats = encoder(xb)
        loss = 0
        if use_diversity:
            loss += center_loss_fn(feats, yb) + differentiable_kshape_loss(feats, yb, num_classes)
        if use_supcon:
            loss += supcon_loss_fn(feats, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    loss_history.append(total_loss)

   # ===== ✅ 聚类评估指标 =====
    encoder.eval()
    with torch.no_grad():
        feats_all = encoder(X_tensor).cpu().numpy()
        true_labels = y_tensor.cpu().numpy()

    kmeans = KMeans(n_clusters=num_classes, n_init=20, random_state=42)
    pred_labels = kmeans.fit_predict(feats_all)

    ari = adjusted_rand_score(true_labels, pred_labels)
    nmi = normalized_mutual_info_score(true_labels, pred_labels)
    sil = silhouette_score(feats_all, pred_labels)

    ari_history.append(ari)
    nmi_history.append(nmi)

    print(f"Epoch {epoch} | Loss: {total_loss:.4f} | ARI: {ari:.4f} | NMI: {nmi:.4f} | Silhouette: {sil:.4f}")
    # 保存当前实验的最终 encoder
all_encoders[exp_name] = deepcopy(encoder)


 10%|█         | 1/10 [00:01<00:13,  1.48s/it]

Epoch 0 | Loss: 13557.5269 | ARI: 0.5994 | NMI: 0.6938 | Silhouette: 0.4705


 20%|██        | 2/10 [00:03<00:12,  1.61s/it]

Epoch 1 | Loss: 12868.1075 | ARI: 0.6347 | NMI: 0.7213 | Silhouette: 0.6068


 30%|███       | 3/10 [00:05<00:12,  1.74s/it]

Epoch 2 | Loss: 12701.4049 | ARI: 0.6940 | NMI: 0.7960 | Silhouette: 0.6819


 40%|████      | 4/10 [00:06<00:09,  1.66s/it]

Epoch 3 | Loss: 12622.7784 | ARI: 0.7007 | NMI: 0.8088 | Silhouette: 0.7064


 50%|█████     | 5/10 [00:08<00:08,  1.64s/it]

Epoch 4 | Loss: 12587.8470 | ARI: 0.8686 | NMI: 0.9001 | Silhouette: 0.6898


 60%|██████    | 6/10 [00:09<00:06,  1.68s/it]

Epoch 5 | Loss: 12531.6910 | ARI: 0.9481 | NMI: 0.9491 | Silhouette: 0.7804


 70%|███████   | 7/10 [00:11<00:04,  1.56s/it]

Epoch 6 | Loss: 12446.3091 | ARI: 0.9735 | NMI: 0.9696 | Silhouette: 0.8468


 80%|████████  | 8/10 [00:12<00:02,  1.48s/it]

Epoch 7 | Loss: 12337.7678 | ARI: 0.9866 | NMI: 0.9823 | Silhouette: 0.9017


 90%|█████████ | 9/10 [00:13<00:01,  1.43s/it]

Epoch 8 | Loss: 12292.7165 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.9150


100%|██████████| 10/10 [00:15<00:00,  1.53s/it]

Epoch 9 | Loss: 12281.1765 | ARI: 1.0000 | NMI: 1.0000 | Silhouette: 0.9321





## 📊 Evaluation: ARI, NMI, Silhouette for A–G

In [78]:

from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
from sklearn.cluster import KMeans
import pandas as pd

experiments = all_encoders

results = []
X_all = X_tensor
y_true = y_tensor.numpy()

for exp_name, encoder in experiments.items():
    encoder.eval()
    with torch.no_grad():
        feats = encoder(X_all).cpu().numpy()

    pred_labels = KMeans(n_clusters=num_classes, random_state=42).fit_predict(feats)

    ari = adjusted_rand_score(y_true, pred_labels)
    nmi = normalized_mutual_info_score(y_true, pred_labels)
    sil = silhouette_score(feats, pred_labels)

    results.append({
        'Experiment': exp_name,
        'ARI': round(ari, 4),
        'NMI': round(nmi, 4),
        'Silhouette': round(sil, 4)
    })

df_result = pd.DataFrame(results).sort_values(by='Experiment')
print(df_result.to_markdown(index=False))


| Experiment   |    ARI |    NMI |   Silhouette |
|:-------------|-------:|-------:|-------------:|
| A            | 0.5939 | 0.6889 |       0.3068 |
| B            | 0.3407 | 0.5088 |       0.3179 |
| C            | 0.6503 | 0.7345 |       0.4639 |
| D            | 1      | 1      |       0.941  |
| E            | 1      | 1      |       0.9262 |
| F            | 1      | 1      |       0.6595 |
| G            | 1      | 1      |       0.9321 |
