In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class TabularEncoder(nn.Module):
    """
    Gaussian encoder for tabular inputs.

    A stack of Linear -> ReLU layers extracts features; two parallel linear
    heads then produce the mean and the log-variance of the latent Gaussian.
    """
    def __init__(self, input_dim, hidden_dims, latent_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.latent_dim = latent_dim

        # Layer widths in order: input_dim -> hidden_dims[0] -> hidden_dims[1] -> ...
        widths = [input_dim, *hidden_dims]
        blocks = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            blocks += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        self.feature_extractor = nn.Sequential(*blocks)

        # Both heads read the last width (input_dim itself when hidden_dims is empty).
        head_in = widths[-1]
        self.mu_layer = nn.Linear(head_in, latent_dim)
        self.logvar_layer = nn.Linear(head_in, latent_dim)

    def forward(self, x):
        """
        Encode a batch of rows.

        Args:
            x: (batch_size, input_dim)

        Returns:
            mu: (batch_size, latent_dim)
            logvar: (batch_size, latent_dim)
        """
        features = self.feature_extractor(x)
        return self.mu_layer(features), self.logvar_layer(features)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Denoiser(nn.Module):
    """
    Two-layer MLP denoiser for tabular diffusion.

    The noisy features, the latent code and the timestep are concatenated and
    passed through Linear -> ReLU -> Linear. The first ``num_numeric`` output
    columns are returned unchanged; the remaining columns are cut into one
    group per entry of ``categories`` and each group goes through a softmax.
    """
    def __init__(self, dim_in, latent_dim, dim_hidden, num_numeric, categories):
        super().__init__()
        self.num_numeric = num_numeric
        self.categories = categories
        # The extra "+ 1" input column carries the scalar timestep.
        in_width = dim_in + latent_dim + 1
        self.net = nn.Sequential(
            nn.Linear(in_width, dim_hidden),
            nn.ReLU(),
            nn.Linear(dim_hidden, dim_in),
        )

    def forward(self, x, z, t):
        """
        Run the denoiser on a batch.

        Args:
            x: Tensor, shape (batch_size, dim_in)
            z: Tensor, shape (batch_size, latent_dim)
            t: Tensor, shape (batch_size,)

        Returns:
            out_num: the first ``num_numeric`` output columns (raw values)
            out_cat: per-group softmax probabilities concatenated along
                dim 1, or None when ``categories`` is empty
        """
        time_col = t.unsqueeze(1).float()
        raw = self.net(torch.cat([x, z, time_col], dim=1))

        numeric_part = raw[:, :self.num_numeric]
        cat_logits = raw[:, self.num_numeric:]

        # Walk the categorical columns group by group.
        group_probs = []
        start = 0
        for width in self.categories:
            stop = start + width
            group_probs.append(F.softmax(cat_logits[:, start:stop], dim=1))
            start = stop

        if not group_probs:
            return numeric_part, None
        return numeric_part, torch.cat(group_probs, dim=1)


In [None]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
import torch.nn.utils as nn_utils
import itertools

class CluTaD:
    """
    CluTaD model: wraps an encoder, a denoiser, the diffusion noise schedule,
    a linear cluster-assignment head and a Gaussian mixture fitted on latents.

    Supports pretraining (``pretrain``), GMM fitting (``fit_gmm``) and the
    clustering-regularised training epoch (``train_elbo``).
    """
    def __init__(self, encoder, denoiser, T, num_numeric, categories, n_clusters, device):
        """
        Args:
            encoder: module returning ``(mu, logvar)``; must expose ``latent_dim``
            denoiser: module called as ``denoiser(x_t, z, t)`` returning
                ``(pred_num, pred_cat)``
            T: number of diffusion timesteps
            num_numeric: number of leading numeric columns in the data
            categories: number of classes of each categorical feature
            n_clusters: number of mixture components / cluster logits
            device: torch device holding the schedule and the cluster head
        """
        super().__init__()
        self.encoder = encoder
        self.denoiser = denoiser
        self.T = T
        self.num_numeric = num_numeric
        self.categories = categories
        self.n_clusters = n_clusters
        self.device = device

        # Diffusion schedule: betas grow linearly from 0.01 / T up to 0.01.
        betas = 0.01 * torch.arange(1, T + 1).float() / T
        alphas = 1 - betas
        self.alpha_bars = torch.cumprod(alphas, dim=0).to(device)
        self.sqrtab = self.alpha_bars.sqrt()
        self.sqrtmab = (1 - self.alpha_bars).sqrt()

        self.gmm = None  # Set by fit_gmm()

        # Linear head producing the logits of the auxiliary distribution Q
        self.mlp = nn.Linear(encoder.latent_dim, n_clusters).to(device)

    def _diffuse_and_denoise(self, x):
        """Noise ``x`` at random timesteps, sample a latent and run the denoiser.

        Returns:
            noise: the Gaussian noise mixed into ``x``, same shape as ``x``
            z: latent sample drawn with the reparameterisation trick
            pred_num, pred_cat: the two outputs of the denoiser
        """
        batch = x.shape[0]
        # One timestep index in [0, T) per row.
        t = torch.randint(0, self.T, (batch,), device=self.device)
        noise = torch.randn_like(x)

        # Forward diffusion: x_t = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise
        x_t = self.sqrtab[t].unsqueeze(1) * x + self.sqrtmab[t].unsqueeze(1) * noise

        mu, logvar = self.encoder(x)
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()

        # The denoiser receives the timestep scaled by 1 / T.
        pred_num, pred_cat = self.denoiser(x_t, z, t.float() / self.T)
        return noise, z, pred_num, pred_cat

    def _reconstruction_losses(self, x, noise, pred_num, pred_cat):
        """Return ``(numeric_loss, categorical_loss)`` as 0-dim tensors.

        Numeric columns: MSE between the denoiser output and the injected noise.
        Categorical columns: KL(target || predicted probabilities) per feature,
        with the clean columns of ``x`` as target, averaged over the features.
        """
        if self.num_numeric > 0 and pred_num is not None:
            loss_num = ((pred_num - noise[:, :self.num_numeric]) ** 2).mean()
        else:
            loss_num = torch.zeros((), device=x.device)

        loss_cat = torch.zeros((), device=x.device)
        has_cats = bool(self.categories) and sum(self.categories) > 0
        if has_cats and pred_cat is not None:
            x0_cat = x[:, self.num_numeric:]
            start = 0
            for width in self.categories:
                target = x0_cat[:, start:start + width]
                pred_prob = pred_cat[:, start:start + width]
                # The 1e-10 offsets keep log() finite for zero entries.
                kl = (target * (torch.log(target + 1e-10) - torch.log(pred_prob + 1e-10))).sum(1).mean()
                loss_cat = loss_cat + kl
                start += width
            loss_cat = loss_cat / len(self.categories)

        return loss_num, loss_cat

    def pretrain_step(self, x, optimizer):
        """
        One pretraining step on the reconstruction loss only.

        Args:
            x: input batch (B, D)
            optimizer: optimizer covering encoder and denoiser parameters

        Returns:
            (loss, loss_num, loss_cat) as Python floats
        """
        noise, _, pred_num, pred_cat = self._diffuse_and_denoise(x)
        loss_num, loss_cat = self._reconstruction_losses(x, noise, pred_num, pred_cat)
        loss = loss_num + loss_cat

        optimizer.zero_grad()
        loss.backward()
        nn_utils.clip_grad_norm_(
            itertools.chain(self.encoder.parameters(), self.denoiser.parameters()),
            max_norm=0.5
        )
        optimizer.step()

        return loss.item(), loss_num.item(), loss_cat.item()

    def pretrain(self, dataloader, optimizer, epochs, batch_size, plot_freq=100):
        """
        Pretrain encoder + denoiser for several epochs.

        Args:
            dataloader: iterable yielding 1-tuples ``(x_batch,)``
            optimizer: optimizer for encoder + denoiser
            epochs: number of pretraining epochs
            batch_size: unused; kept for backward compatibility (the batch
                size is whatever the dataloader yields)
            plot_freq: print the average loss every ``plot_freq`` epochs

        Raises:
            ValueError: if the dataloader yields no samples.
        """
        self.encoder.train()
        self.denoiser.train()
        for epoch in range(epochs):
            total_loss = 0.0
            n_samples = 0
            for (x_batch,) in dataloader:
                if x_batch.ndim == 1:
                    x_batch = x_batch.unsqueeze(0)
                x_batch = x_batch.to(self.device)

                loss, _, _ = self.pretrain_step(x_batch, optimizer)

                total_loss += loss * x_batch.size(0)
                n_samples += x_batch.size(0)

            if n_samples == 0:
                raise ValueError("pretrain() received a dataloader that yields no samples")

            avg_loss = total_loss / n_samples
            if (epoch + 1) % plot_freq == 0:
                print(f'[Pretrain] Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

    def fit_gmm(self, dataloader):
        """
        Fit a diagonal-covariance Gaussian mixture on sampled latents.

        Every batch is encoded, one latent sample per row is drawn, and a
        ``GaussianMixture`` with ``n_clusters`` components is fitted on them and
        stored in ``self.gmm``. A previously fitted mixture is used to
        warm-start the new one. The encoder's train/eval mode is restored
        before returning.

        Args:
            dataloader: iterable yielding 1-tuples ``(x_batch,)``

        Raises:
            ValueError: if the dataloader yields no batches.
        """
        was_training = self.encoder.training
        self.encoder.eval()
        latent_z = []
        try:
            with torch.no_grad():
                for (x,) in dataloader:
                    x = x.to(self.device)
                    z_mu, z_sigma2_log = self.encoder(x)
                    z = torch.randn_like(z_mu) * torch.exp(z_sigma2_log / 2) + z_mu
                    latent_z.append(z)
        finally:
            self.encoder.train(was_training)

        if not latent_z:
            raise ValueError("fit_gmm() received a dataloader that yields no batches")
        latent_z = torch.cat(latent_z, 0).detach().cpu().numpy()

        if self.gmm is not None:
            # Warm start from the previous mixture.
            # NOTE(review): reg_covar is 1e-1 here but 1e-2 for the first fit -
            # confirm the difference is intentional.
            gmm = GaussianMixture(n_components=self.n_clusters,
                                  covariance_type='diag',
                                  reg_covar=1e-1,
                                  means_init=self.gmm.means_,
                                  precisions_init=self.gmm.precisions_)
        else:
            gmm = GaussianMixture(n_components=self.n_clusters, covariance_type='diag', reg_covar=1e-2)
        gmm.fit(latent_z)
        self.gmm = gmm

    def elbo_step(self, x, optimizer, kl_weight=0.1):
        """
        One training step on reconstruction + weighted clustering loss.

        Args:
            x: input batch (B, D)
            optimizer: optimizer covering encoder, denoiser and cluster head
            kl_weight: weight of the clustering term

        Returns:
            (total_loss, rec_loss, cluster_loss) as Python floats

        Raises:
            RuntimeError: if ``fit_gmm`` has not been called yet.
        """
        if self.gmm is None:
            raise RuntimeError("fit_gmm() must be called before elbo_step()")

        noise, z, pred_num, pred_cat = self._diffuse_and_denoise(x)
        rec_loss_num, rec_loss_cat = self._reconstruction_losses(x, noise, pred_num, pred_cat)
        rec_loss = rec_loss_num + rec_loss_cat

        # ===== Cluster assignments =====
        # Q: soft assignment from the linear head
        Q = F.softmax(self.mlp(z), dim=1)                                           # (B, K)

        # P: softmax over negative half variance-scaled distances to the GMM means
        means = torch.from_numpy(self.gmm.means_).to(z.device).float().unsqueeze(0)            # (1, K, D)
        variances = torch.from_numpy(self.gmm.covariances_).to(z.device).float().unsqueeze(0)  # (1, K, D)
        dist = torch.sqrt(((z.unsqueeze(1) - means) ** 2 / variances).sum(dim=2)) / 2          # (B, K)
        P = F.softmax(-dist, dim=1)                                                 # (B, K)

        # ===== Clustering loss: cross-entropy -sum_k Q_k log P_k =====
        cluster_loss = -(Q * torch.log(P + 1e-5)).sum(dim=1).mean()

        # ===== Total loss =====
        total_loss = rec_loss + kl_weight * cluster_loss

        optimizer.zero_grad()
        total_loss.backward()
        nn_utils.clip_grad_norm_(
            itertools.chain(self.encoder.parameters(),
                            self.denoiser.parameters(),
                            self.mlp.parameters()),
            max_norm=0.5
        )
        optimizer.step()

        return total_loss.item(), rec_loss.item(), cluster_loss.item()

    def train_elbo(self, dataloader, optimizer, batch_size, kl_weight=0.1, plot_freq=100):
        """
        Run one epoch of ``elbo_step`` over the dataloader.

        Args:
            dataloader: iterable yielding 1-tuples ``(x_batch,)``
            optimizer: optimizer
            batch_size: unused; kept for backward compatibility
            kl_weight: weight of the clustering term
            plot_freq: unused; kept for backward compatibility

        Returns:
            (avg_loss, avg_recon_loss, avg_cluster_loss, stop) where the
            averages are per sample and ``stop`` is True when any GMM weight
            is at most ``1 / (2 * n_clusters)``.

        Raises:
            ValueError: if the dataloader yields no samples.
            RuntimeError: if ``fit_gmm`` has not been called yet.
        """
        self.encoder.train()
        self.denoiser.train()
        self.mlp.train()
        total_loss = 0.0
        total_recon_loss = 0.0
        total_cluster_loss = 0.0
        n_samples = 0

        for (x_batch,) in dataloader:
            if x_batch.ndim == 1:
                x_batch = x_batch.unsqueeze(0)
            x_batch = x_batch.to(self.device)
            loss, rec_loss, cluster_loss = self.elbo_step(x_batch, optimizer, kl_weight=kl_weight)

            rows = x_batch.size(0)
            total_loss += loss * rows
            total_recon_loss += rec_loss * rows
            total_cluster_loss += cluster_loss * rows
            n_samples += rows

        if n_samples == 0:
            raise ValueError("train_elbo() received a dataloader that yields no samples")

        avg_loss = total_loss / n_samples
        avg_recon_loss = total_recon_loss / n_samples
        avg_cluster_loss = total_cluster_loss / n_samples

        # ===== Early stopping check =====
        # The weights come from the last fit_gmm() call; they do not change here.
        stop = bool(np.any(self.gmm.weights_ <= 1.0 / (2 * self.n_clusters)))

        return avg_loss, avg_recon_loss, avg_cluster_loss, stop


In [None]:
import torch
import torch.optim as optim
import pandas as pd
import numpy as np
import os
import sys
import json
from pathlib import Path

from sklearn.metrics import confusion_matrix, accuracy_score, adjusted_rand_score
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader, TensorDataset


# Define cluster alignment function
def cluster_accuracy(y_true, y_pred):
    """
    Clustering accuracy under the best one-to-one relabelling of clusters.

    Predicted cluster ids are arbitrary, so the Hungarian algorithm is used to
    find the assignment of predicted labels to true labels that maximises the
    number of agreeing samples. Labels may be any values that ``np.unique``
    can sort; they do not have to be the integers 0..K-1.

    Args:
        y_true: 1-D array-like of ground-truth labels
        y_pred: 1-D array-like of predicted cluster labels, same length

    Returns:
        acc: fraction of samples whose aligned prediction equals the truth
        y_aligned: ``y_pred`` relabelled into the label space of ``y_true``

    Raises:
        ValueError: if the inputs are empty or differ in length.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred must have the same length, got {y_true.size} and {y_pred.size}"
        )
    if y_true.size == 0:
        raise ValueError("cluster_accuracy() requires at least one sample")

    # Work on positions in the sorted union of labels, not on raw label
    # values, so non-contiguous or non-zero-based labels are handled.
    labels, inverse = np.unique(np.concatenate([y_true, y_pred]), return_inverse=True)
    inverse = inverse.ravel()
    true_idx, pred_idx = inverse[:y_true.size], inverse[y_true.size:]

    # contingency[i, j] = number of samples with true label i and predicted label j
    contingency = np.zeros((labels.size, labels.size), dtype=np.int64)
    np.add.at(contingency, (true_idx, pred_idx), 1)

    # Negate so that the cost minimiser maximises the matched counts.
    row_ind, col_ind = linear_sum_assignment(-contingency)
    pred_to_true = np.empty(labels.size, dtype=np.intp)
    pred_to_true[col_ind] = row_ind

    y_aligned = labels[pred_to_true[pred_idx]]
    acc = float(np.mean(y_aligned == y_true))
    return acc, y_aligned


# ---- Experiment identity and file locations ----
dataset_index = '40982l'

DATA_PATH = 'data_processed.csv'
CHECKPOINT_PATH_PRE = 'pretrain_checkpoint.pth'
CHECKPOINT_PATH_FINAL = 'final_checkpoint.pth'
LABEL_PATH = 'clusters.csv'
METADATA_PATH = 'metadata.json'

# ---- Dataset description: feature layout and number of clusters ----
with open(METADATA_PATH, 'r') as meta_fh:
    metadata = json.load(meta_fh)
num_numeric, categories, n_clusters = (
    metadata['num_numerical_features'],
    metadata['num_classes_per_cat'],
    metadata['num_clusters'],
)

# ---- Hyper-parameters ----
T = 100  # number of diffusion timesteps; the right value is still an open question

pretrain_steps = 1000          # same as in example
em_epochs = 1000               # same as in example (it is 1000 altogether)
batch_size = 256               # same as in example
hidden_dims = [500, 500, 2000]  # same as in example
kl_weight = 0.1                # same as in example
lr = 1e-3                      # same as in example

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- Features: the whole table is moved to the device once ----
df = pd.read_csv(DATA_PATH)
x_real = torch.tensor(df.to_numpy(), dtype=torch.float32).to(device)
dataset = TensorDataset(x_real)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# ---- Ground-truth labels; string/object labels are encoded as integers ----
y_true = pd.read_csv(LABEL_PATH).values.flatten()
if y_true.dtype.kind in 'USO':
    unique_labels, y_true = np.unique(y_true.astype(str), return_inverse=True)
N, D = x_real.shape

# ---- Best-run bookkeeping for the sweep ----
BEST_PATH = f"best_checkpoint_{dataset_index}.pth"
best_acc = -1.0
best_meta = None
# Defaults; the sweep loops below overwrite both values.
dim_hidden = 1000
latent_dim = 15
latent_dim = 15
#for kl_weight in range(2, 11):
 # kl_weight = kl_weight/10
  #print(f"kl_weight: {kl_weight}")
# Grid search over the denoiser hidden width and the latent dimensionality.
# Every configuration is trained from scratch, scored against y_true, appended
# to the JSON results file, and checkpointed when it beats the best accuracy.
for dim_hidden in [500, 1000]:
    for latent_dim in [5, 10, 15, 20]:
        print(f"dim_hidden {dim_hidden}, latent_dim: {latent_dim}")

        # Fresh models for this configuration
        encoder = TabularEncoder(
            input_dim=D,
            hidden_dims=hidden_dims,
            latent_dim=latent_dim
        ).to(device)

        denoiser = Denoiser(
            dim_in=D,
            latent_dim=latent_dim,
            dim_hidden=dim_hidden,
            num_numeric=num_numeric,
            categories=categories
        ).to(device)

        # CluTaD wrapper
        model = CluTaD(
            encoder=encoder,
            denoiser=denoiser,
            T=T,
            num_numeric=num_numeric,
            categories=categories,
            n_clusters=n_clusters,
            device=device
        )

        # Optimizer for the pretraining phase
        optimizer = optim.Adam(
            list(encoder.parameters()) +
            list(denoiser.parameters()) +
            list(model.mlp.parameters()),
            lr=lr, weight_decay=1e-4
        )

        # 🔹 Pretraining
        print("🔹 Starting pretraining...")
        model.pretrain(dataloader, optimizer, epochs=pretrain_steps, batch_size=batch_size, plot_freq=100)

        # 🔹 Fit initial GMM (E-step 0)
        print("🔹 Fitting initial GMM...")
        model.fit_gmm(dataloader)

        # Re-create the optimizer: this discards the Adam moment estimates
        # accumulated during pretraining.
        optimizer = optim.Adam(
            list(encoder.parameters()) +
            list(denoiser.parameters()) +
            list(model.mlp.parameters()),
            lr=lr, weight_decay=1e-4
        )

        # 🔹 EM training loop
        print("🔹 Starting EM training...")
        for epoch in range(em_epochs):
            avg_loss, avg_recon_loss, avg_cluster_loss, stop = model.train_elbo(
                dataloader, optimizer, batch_size=batch_size, kl_weight=kl_weight, plot_freq=50
            )
            if (epoch+1) % 100 == 0:
                print(f"Epoch {epoch+1}/{em_epochs}, "
                      f"Loss: {avg_loss:.4f}, Recon-Loss: {avg_recon_loss:.4f}, "
                      f"Cluster-Loss: {avg_cluster_loss:.4f}")

            if stop:
                print(f"⏹️ Stopping early at epoch {epoch+1}")
                break

            # E-step: refit the mixture on the current latents every 10 epochs
            if epoch % 10 == 0:
                model.fit_gmm(dataloader)

        # Encode all data with the latent means and compute GMM assignments
        model.encoder.eval()
        with torch.no_grad():
            mu, logvar = model.encoder(x_real)
            z_np = mu.cpu().numpy()
            y_pred = model.gmm.predict(z_np)

        # Compute metrics
        accuracy, y_aligned = cluster_accuracy(y_true, y_pred)
        ari = adjusted_rand_score(y_true, y_pred)

        # Print results
        print(f"✅ Final clustering performance:")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"ARI: {ari:.4f}")

        # Record of this run
        results = {
            "accuracy": float(accuracy),
            "ari": float(ari),
            "T": int(T),                          # diffusion timesteps
            "dim_hidden": int(dim_hidden),        # denoiser hidden width
            "latent_dim": int(latent_dim),        # latent dimension
            "dataset_index": str(dataset_index),
            "lambda": float(kl_weight)
        }

        # Path to results file; create its directory so the write below cannot
        # fail with FileNotFoundError after a full training run.
        results_file = Path("PCA/clutad_pca_results.json")
        results_file.parent.mkdir(parents=True, exist_ok=True)

        # If file exists, load and append; else create new
        if results_file.exists():
            with open(results_file, "r") as f:
                all_results = json.load(f)
        else:
            all_results = []

        all_results.append(results)

        # Save updated results
        with open(results_file, "w") as f:
            json.dump(all_results, f, indent=4)

        print(f"Saved results to {results_file}\n")

        # --- update 'best' and save checkpoint if improved ---
        if accuracy > best_acc:
            best_acc = float(accuracy)
            best_meta = {
                "T": int(T),
                "latent_dim": int(latent_dim),
                "dim_hidden": int(dim_hidden),
                "dataset_index": str(dataset_index),
                "accuracy": float(accuracy),
                "ari": float(ari)
            }

            # Save immediately so the best model survives an interrupted sweep
            torch.save({
                "encoder": encoder.state_dict(),
                "denoiser": denoiser.state_dict(),
                "optimizer": optimizer.state_dict(),  # optional but handy
                "gmm": model.gmm,                     # sklearn object; torch.save pickles it
                "config": {
                    "T": T,
                    "num_numeric": num_numeric,
                    "categories": categories,
                    "n_clusters": n_clusters,
                    "hidden_dims": hidden_dims,
                    "dim_hidden": dim_hidden,
                    "latent_dim": latent_dim,
                    "kl_weight": kl_weight,
                    "lr": lr,
                    "batch_size": batch_size,
                    "em_epochs": em_epochs,
                    "pretrain_steps": pretrain_steps,
                },
                "metrics": {
                    "accuracy": float(accuracy),
                    "ari": float(ari)
                }
            }, BEST_PATH)

            print(f"💾 New best model saved to {BEST_PATH} (acc={accuracy:.4f}, T={T}, z={latent_dim})")
        # ----------------------------------------------------------

# --- sweep summary, printed once both loops are done ---
summary_lines = ["🏁 Tuning finished.", f"Best acc: {best_acc:.4f}"]
if best_meta is not None:
    # best_meta is only set once at least one configuration has been scored.
    summary_lines.append(
        f"Best config: dim_hidden={best_meta['dim_hidden']}, latent_dim={best_meta['latent_dim']}"
    )
print("\n".join(summary_lines))

dim_hidden 500, latent_dim: 5
🔹 Starting pretraining...
[Pretrain] Epoch 100/1000, Loss: 0.2738
[Pretrain] Epoch 200/1000, Loss: 0.2489
[Pretrain] Epoch 300/1000, Loss: 0.2063
[Pretrain] Epoch 400/1000, Loss: 0.1634
[Pretrain] Epoch 500/1000, Loss: 0.1325
[Pretrain] Epoch 600/1000, Loss: 0.1021
[Pretrain] Epoch 700/1000, Loss: 0.0829
[Pretrain] Epoch 800/1000, Loss: 0.0876
[Pretrain] Epoch 900/1000, Loss: 0.0833
[Pretrain] Epoch 1000/1000, Loss: 0.1009
🔹 Fitting initial GMM...
🔹 Starting EM training...
⏹️ Stopping early at epoch 1
✅ Final clustering performance:
Accuracy: 0.6611
ARI: 0.4034
Saved results to clutad_PCA_results.json

💾 New best model saved to best_checkpoint_40982l.pth (acc=0.6611, T=100, z=5)
dim_hidden 500, latent_dim: 10
🔹 Starting pretraining...
[Pretrain] Epoch 100/1000, Loss: 0.2690
[Pretrain] Epoch 200/1000, Loss: 0.2702
[Pretrain] Epoch 300/1000, Loss: 0.1882
[Pretrain] Epoch 400/1000, Loss: 0.1508
[Pretrain] Epoch 500/1000, Loss: 0.1212
[Pretrain] Epoch 600/1000