In [None]:
import torch.nn as nn


class Autoencoder(nn.Module):
    """Fully connected autoencoder with mirrored encoder/decoder stacks.

    The encoder maps ``input_dim -> 500 -> 500 -> 2000 -> latent_dim`` and the
    decoder runs the same widths in reverse. Every linear layer except the
    last one of each stack is followed by a ReLU.

    Args:
        input_dim: Number of input features.
        latent_dim: Size of the latent code produced by the encoder.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        hidden = (500, 500, 2000)
        self.encoder = self._stack((input_dim, *hidden, latent_dim))
        self.decoder = self._stack((latent_dim, *reversed(hidden), input_dim))

    @staticmethod
    def _stack(widths):
        """Build Linear/ReLU pairs for consecutive widths, ending on a Linear."""
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # the output layer of each stack has no activation
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(z, x_hat)``: the latent code and the reconstruction of ``x``."""
        z = self.encoder(x)
        return z, self.decoder(z)


In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans


class GCEALsHead(nn.Module):
    """Clustering head producing two soft cluster assignments for a latent code.

    ``P`` is derived from a diagonal-covariance Mahalanobis distance between
    the latent code and learnable centroids; ``Q`` comes from a linear layer
    followed by a softmax.

    Args:
        latent_dim: Dimensionality of the latent codes fed to the head.
        n_clusters: Number of clusters K.
    """

    # Added under the square root so its gradient stays finite when a latent
    # code coincides with a centroid (sqrt'(0) is inf, and inf * 0 gives NaN).
    _EPS = 1e-12

    def __init__(self, latent_dim, n_clusters):
        super().__init__()
        self.n_clusters = n_clusters

        # Centroids μ_j (to be initialized with k-means, Algorithm 1 step 6)
        self.centroids = nn.Parameter(torch.zeros(n_clusters, latent_dim))

        # Diagonal covariance Σ_j = diag(exp(logvar))
        self.logvar = nn.Parameter(torch.zeros(n_clusters, latent_dim))

        # MLP Head for auxiliary distribution Q (Eq. 10)
        self.mlp = nn.Linear(latent_dim, n_clusters)

    def forward(self, z):
        """Compute both cluster distributions for a batch of latent codes.

        Args:
            z: Latent codes of shape (B, D), or a single code of shape (D,)
               which is treated as a batch of one.

        Returns:
          P = cluster distribution from Gaussian/Mahalanobis (Eq. 6–9), (B, K).
              It is a softmax over negative distances; no cluster prior is
              applied inside this module.
          Q = cluster distribution from MLP head (Eq. 10), (B, K)

        Raises:
            ValueError: If ``z`` is neither 1-D nor 2-D.
        """
        if z.dim() == 1:  # single sample (D,)
            z = z.unsqueeze(0)  # → (1, D)
        if z.dim() != 2:
            raise ValueError(f"expected z of shape (B, D) or (D,), got {tuple(z.shape)}")

        # --- Gaussian assignment (P) ---
        # Broadcasting (B, 1, D) against (1, K, D) gives (B, K, D) without
        # materialising expanded copies.
        diff = z.unsqueeze(1) - self.centroids.unsqueeze(0)
        var = torch.exp(self.logvar).unsqueeze(0)
        sq_dist = (diff ** 2 / var).sum(dim=2)  # (B, K)
        dist = torch.sqrt(sq_dist + self._EPS)
        P = F.softmax(-dist, dim=1)  # soft assignments across clusters

        # --- MLP assignment (Q) ---
        logits = self.mlp(z)  # (B, K)
        Q = F.softmax(logits, dim=1)

        return P, Q


class GCEALs(nn.Module):
    """Autoencoder plus clustering head, with pretraining and joint training loops.

    Args:
        input_dim: Number of input features.
        latent_dim: Size of the autoencoder's latent code.
        n_clusters: Number of clusters K.

    Attributes set here:
        ae: The autoencoder.
        cluster_head: The clustering head returning ``(P, Q)``.
        pi: Cluster priors; ``None`` until ``init_centroids`` or
            ``train_gceals`` sets them.
    """

    def __init__(self, input_dim, latent_dim, n_clusters):
        super().__init__()
        self.ae = Autoencoder(input_dim, latent_dim)
        self.cluster_head = GCEALsHead(latent_dim, n_clusters)
        self.n_clusters = n_clusters
        self.pi = None  # cluster priors (Eq. 8)

    @staticmethod
    def _unpack(batch):
        """Return the input tensor of a batch that may be wrapped in a list/tuple.

        A ``DataLoader`` over a ``TensorDataset`` yields ``[x]``, while iterating
        a tensor (or a ``DataLoader`` over a tensor) yields ``x`` directly.
        """
        if isinstance(batch, (list, tuple)):
            return batch[0]
        return batch

    def forward(self, x):
        """Return ``(P, Q, x_hat, z)`` for input ``x``."""
        z, x_hat = self.ae(x)
        P, Q = self.cluster_head(z)
        return P, Q, x_hat, z

    def pretrain(self, dataloader, optimizer, epochs=50, device='cuda', patience=10):
        """
        Pretrain autoencoder only (Eq. 11).

        Only the reconstruction MSE of ``self.ae`` is optimised; the cluster
        head is not evaluated here.

        Args:
            dataloader: Iterable of input batches (tensors, or lists/tuples
                whose first element is the input tensor).
            optimizer: Optimizer stepping the parameters to train.
            epochs: Maximum number of epochs.
            device: Device the model and batches are moved to.
            patience: Stop after this many consecutive epochs without the
                average loss improving by more than 1e-4.
        """
        self.train()
        self.to(device)
        loss_fn = nn.MSELoss()

        best_loss = float('inf')
        patience_counter = 0

        for epoch in range(epochs):
            total_loss, n_samples = 0.0, 0

            for batch in dataloader:
                x_batch = self._unpack(batch).to(device)
                _, x_hat = self.ae(x_batch)
                loss = loss_fn(x_hat, x_batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item() * x_batch.size(0)
                n_samples += x_batch.size(0)

            avg_loss = total_loss / n_samples
            print(f'[Pretrain] Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

            # --- Early stopping check ---
            if avg_loss < best_loss - 1e-4:  # small tolerance
                best_loss = avg_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f'[Pretrain] Early stopping at epoch {epoch+1}, best loss={best_loss:.4f}')
                break

    def init_centroids(self, dataloader, device='cuda'):
        """
        Initialize centroids μ_j using k-means on latent space (Alg. 1, step 6).

        Encodes every batch, runs k-means on the collected latent codes, copies
        the resulting cluster centres into ``cluster_head.centroids`` and resets
        ``self.pi`` to a uniform prior. The model is left in eval mode.

        Args:
            dataloader: Iterable of input batches (see ``pretrain``).
            device: Device the model and batches are moved to.
        """
        self.eval()
        self.to(device)
        all_z = []
        with torch.no_grad():
            for batch in dataloader:
                x_batch = self._unpack(batch).to(device)
                z, _ = self.ae(x_batch)
                if z.dim() == 1:
                    # A single un-batched sample: keep one row per sample so
                    # the concatenated array stays 2-D for k-means.
                    z = z.unsqueeze(0)
                all_z.append(z.cpu())
        all_z = torch.cat(all_z, dim=0).numpy()

        # Run k-means
        kmeans = KMeans(n_clusters=self.n_clusters, n_init=20)
        kmeans.fit(all_z)
        centroids = torch.tensor(kmeans.cluster_centers_, dtype=torch.float)

        # Initialize cluster centroids μ_j
        self.cluster_head.centroids.data.copy_(centroids)

        # Initialize priors ω_j equally (Eq. 8)
        self.pi = torch.ones(self.n_clusters, device=device) / self.n_clusters
        print("[Init] Centroids initialized with k-means")

    def train_gceals(self, dataloader, optimizer, epochs=100, device='cuda', gamma=1.0):
        """
        Joint training of AE + clustering (Algorithm 1).
        Loss = L_rec + γ * KL(P || Q)   (Eq. 13)

        After every epoch the priors ``self.pi`` are recomputed as the mean of
        ``P`` over all samples, and training stops as soon as any prior is at
        most ``1 / (2 * n_clusters)``.

        Args:
            dataloader: Iterable of input batches (see ``pretrain``).
            optimizer: Optimizer stepping the parameters to train.
            epochs: Maximum number of epochs.
            device: Device the model and batches are moved to.
            gamma: Weight of the KL term.
        """
        self.train()
        self.to(device)
        recon_loss_fn = nn.MSELoss()

        for epoch in range(epochs):
            total_loss, n_samples = 0.0, 0

            for batch in dataloader:
                x_batch = self._unpack(batch).to(device)

                # Forward pass
                P, Q, x_hat, z = self(x_batch)  # P=Gaussian, Q=MLP

                # ====== Loss ======
                recon_loss = recon_loss_fn(x_hat, x_batch)
                # Clamp before the log so an underflowed Q entry cannot
                # produce -inf and poison the loss.
                log_q = Q.clamp_min(1e-12).log()
                kl_div = F.kl_div(log_q, P, reduction='batchmean')  # KL(P || Q)
                loss = recon_loss + gamma * kl_div

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item() * x_batch.size(0)
                n_samples += x_batch.size(0)

            # ====== Update priors ω_j (Eq. 8) ======
            with torch.no_grad():
                all_P = [self(self._unpack(batch).to(device))[0] for batch in dataloader]
                # Concatenate along the sample axis (batches may differ in
                # size) and average over samples to get one prior per cluster.
                self.pi = torch.cat(all_P, dim=0).mean(dim=0)

            avg_loss = total_loss / n_samples
            print(f'[GCEALs] Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

            # ====== Early stopping check (Algorithm 1) ======
            if torch.any(self.pi <= 1.0 / (2 * self.n_clusters)):
                print(f"[Early Stopping] Cluster prior too small: {self.pi}")
                break


In [None]:
import torch
import torch.optim as optim
import pandas as pd
import numpy as np
import os
import sys
from torch.utils.data import DataLoader, TensorDataset
import json
from pathlib import Path

from sklearn.metrics import confusion_matrix, accuracy_score, adjusted_rand_score
from scipy.optimize import linear_sum_assignment


# Config: input files for the processed features, the ground-truth cluster
# labels and the dataset metadata.
DATA_PATH = 'data_processed.csv'
LABEL_PATH = 'clusters.csv'
METADATA_PATH = 'metadata.json'

# Identifier used in result records and in the checkpoint file name.
dataset_index = '37'

with open(METADATA_PATH, 'r') as f:
    metadata = json.load(f)
num_numeric = metadata['num_numerical_features']
categories = metadata['num_classes_per_cat']
n_clusters = metadata['num_clusters']

# Training hyperparameters.
pretrain_epochs = 100
train_epochs = 20
batch_size = 256
lr = 1e-3
gamma = 0.1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load data
x_real = pd.read_csv(DATA_PATH).values.astype(np.float32)
y_true = pd.read_csv(LABEL_PATH).values.flatten()
if y_true.dtype.kind in {'U', 'S', 'O'}:
    # Non-numeric labels are replaced by integer codes 0..n_labels-1.
    unique_labels, y_true = np.unique(np.asarray(y_true).astype(str), return_inverse=True)
x_real = torch.tensor(x_real, dtype=torch.float32).to(device)
dataset = TensorDataset(x_real)
# Use the configured batch size (also recorded in saved checkpoints) instead
# of a separate hard-coded value.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
N, D = x_real.shape



def cluster_accuracy(y_true, y_pred):
    """Return ``(accuracy, y_aligned)`` after optimally matching clusters to classes.

    Predicted cluster ids are mapped to true labels with the Hungarian
    algorithm on the contingency table. The table is built over an explicit,
    sorted label set so that row/column indices are translated back to label
    values rather than being assumed to equal them.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    labels = np.unique(np.concatenate([y_true, y_pred]))
    contingency = confusion_matrix(y_true, y_pred, labels=labels)
    row_ind, col_ind = linear_sum_assignment(-contingency)
    mapping = {labels[col]: labels[row] for row, col in zip(row_ind, col_ind)}
    y_aligned = np.array([mapping[label] for label in y_pred])
    acc = accuracy_score(y_true, y_aligned)
    return acc, y_aligned


BEST_PATH = f"best_checkpoint_{dataset_index}.pth"
best_acc = -1.0
best_meta = None

# Sweep over latent dimensions: train one model per value, log its metrics to
# a JSON file and keep a checkpoint of the most accurate model.
for latent_dim in [5, 10, 15, 20]:
    print(f"latent_dim: {latent_dim}")
    # Models
    model = GCEALs(input_dim=D, latent_dim=latent_dim, n_clusters=n_clusters).to(device)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Pretrain AE
    # NOTE(review): x_real is passed directly, so pretrain/train_gceals iterate
    # over it one row at a time; `dataloader` is not used for training here.
    # Confirm whether mini-batch training was intended.
    print("🔹 Pretraining autoencoder...")
    model.pretrain(x_real, optimizer, epochs=pretrain_epochs, device=device)

    # Initialize centroids with k-means on the pretrained latent space
    # (Algorithm 1, step 6). Without this call the centroids stay at their
    # all-zero initial value when joint training starts. A DataLoader over the
    # raw tensor yields 2-D batches, which is what init_centroids concatenates.
    init_loader = DataLoader(x_real, batch_size=batch_size, shuffle=False)
    model.init_centroids(init_loader, device=device)

    # Full training
    print("🔹 Training GCEALs model...")
    model.train_gceals(x_real, optimizer, epochs=train_epochs, device=device, gamma=gamma)

    # Predict cluster assignments from the MLP head's distribution Q
    model.eval()
    with torch.no_grad():
        z, _ = model.ae(x_real)
        _, q = model.cluster_head(z)
        y_pred = q.argmax(dim=1).cpu().numpy()

    # Alignment
    accuracy, y_aligned = cluster_accuracy(y_true, y_pred)
    ari = adjusted_rand_score(y_true, y_pred)

    print("✅ Final clustering performance:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"ARI: {ari:.4f}")

    # Record for this run
    results = {
        "accuracy": float(accuracy),
        "ari": float(ari),
        "latent_dim": int(latent_dim),
        "dataset_index": str(dataset_index),  # identifier set in the config above
    }

    # Path to results file
    results_file = Path("gceals_results.json")

    # If file exists, load and append; else create new
    if results_file.exists():
        with open(results_file, "r") as f:
            all_results = json.load(f)
    else:
        all_results = []

    all_results.append(results)

    # Save updated results
    with open(results_file, "w") as f:
        json.dump(all_results, f, indent=4)

    print(f"Saved results to {results_file}\n")

    # --- update 'best' and save checkpoint if improved ---
    if accuracy > best_acc:
        best_acc = float(accuracy)
        best_meta = {
            "latent_dim": int(latent_dim),
            "dataset_index": str(dataset_index),
            "accuracy": float(accuracy),
            "ari": float(ari),
        }

        torch.save({
            "model_state": model.state_dict(),  # AE + cluster head
            "optimizer_state": optimizer.state_dict() if optimizer is not None else None,
            "config": {
                "input_dim": D,
                "latent_dim": latent_dim,
                "n_clusters": n_clusters,
                "lr": lr,
                "batch_size": batch_size,
                "pretrain_epochs": pretrain_epochs,
                "gceals_epochs": train_epochs,
                "gamma": gamma,
                "architecture": {
                    "encoder_layers": [D, 500, 500, 2000, latent_dim],
                    "decoder_layers": [latent_dim, 2000, 500, 500, D],
                },
            },
            "metrics": {
                "accuracy": float(accuracy),
                "ari": float(ari),
            },
        }, BEST_PATH)

        print(f"💾 New best model saved to {BEST_PATH} (acc={accuracy:.4f}, z={latent_dim}, K={n_clusters})")

# --- after the latent_dim sweep finishes ---
print("🏁 Tuning finished.")
print(f"Best acc: {best_acc:.4f}")
if best_meta is not None:
    print(f"Best config: latent_dim={best_meta['latent_dim']}")


latent_dim: 5
🔹 Pretraining autoencoder...
[Pretrain] Epoch 1/100, Loss: 2.1478
[Pretrain] Epoch 2/100, Loss: 1.5777
[Pretrain] Epoch 3/100, Loss: 1.2285
[Pretrain] Epoch 4/100, Loss: 0.9982
[Pretrain] Epoch 5/100, Loss: 0.9848
[Pretrain] Epoch 6/100, Loss: 0.9445
[Pretrain] Epoch 7/100, Loss: 0.8180
[Pretrain] Epoch 8/100, Loss: 0.7329
[Pretrain] Epoch 9/100, Loss: 0.6794
[Pretrain] Epoch 10/100, Loss: 0.6460
[Pretrain] Epoch 11/100, Loss: 0.7188
[Pretrain] Epoch 12/100, Loss: 0.5657
[Pretrain] Epoch 13/100, Loss: 0.5549
[Pretrain] Epoch 14/100, Loss: 0.5368
[Pretrain] Epoch 15/100, Loss: 0.6294
[Pretrain] Epoch 16/100, Loss: 0.4759
[Pretrain] Epoch 17/100, Loss: 0.5130
[Pretrain] Epoch 18/100, Loss: 0.4619
[Pretrain] Epoch 19/100, Loss: 0.4508
[Pretrain] Epoch 20/100, Loss: 0.4992
[Pretrain] Epoch 21/100, Loss: 0.3922
[Pretrain] Epoch 22/100, Loss: 0.4297
[Pretrain] Epoch 23/100, Loss: 0.4057
[Pretrain] Epoch 24/100, Loss: 0.4032
[Pretrain] Epoch 25/100, Loss: 0.3723
[Pretrain] Epoch