In [None]:
import os
import zipfile
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from PIL import Image, ImageFilter
import random
import matplotlib.pyplot as plt
import torchvision

In [None]:
class TwoCropTransform:
    """
    Produce two independently transformed views of the same input.

    Used for contrastive learning (Scenario 3, SimCLR): ``base_transform`` is
    invoked twice on the same input, so any randomness inside it yields two
    different augmentations.
    """

    def __init__(self, base_transform):
        # Callable applied once per generated view.
        self.base_transform = base_transform

    def __call__(self, x):
        """Return ``[view_1, view_2]``, each computed as ``base_transform(x)``."""
        return [self.base_transform(x) for _ in range(2)]

class GaussianBlur(object):
    """Gaussian blur augmentation with a randomly drawn radius (SimCLR).

    Args:
        sigma: Two-element sequence ``(low, high)``. Every call draws the blur
            radius uniformly from this interval.
    """

    def __init__(self, sigma=(.1, 2.)):
        # Immutable default: a list default would be one object shared by
        # every instance created without an explicit ``sigma``.
        self.sigma = sigma

    def __call__(self, x):
        """Blur ``x`` (an object exposing PIL's ``filter`` method) and return the result."""
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))

def get_brb_transform(dataset_name, is_contrastive=False, train=True):
    """
    Build the default torchvision transform pipeline for a dataset.

    - train=True: augmentation (RandomAffine for grayscale data, SimCLR-style
      crop / flip / color jitter / grayscale for color data), then
      ToTensor + Normalize.
    - train=False: ToTensor + Normalize only (GTSRB is resized first).

    When both ``is_contrastive`` and ``train`` are true the pipeline is wrapped
    in ``TwoCropTransform`` so each sample yields two views.

    Raises:
        ValueError: if ``dataset_name`` is in neither dataset group.
    """
    grayscale_sets = ('MNIST', 'KMNIST', 'Fashion-MNIST', 'USPS', 'OPTDIGITS')
    color_sets = ('CIFAR-10', 'CIFAR-100-20', 'GTSRB')

    if dataset_name in grayscale_sets:
        # Group 1: GRAYSCALE.
        # OPTDIGITS is converted to a PIL image first; every other dataset
        # gets an identity placeholder in the same position.
        if dataset_name == 'OPTDIGITS':
            steps = [transforms.ToPILImage()]
        else:
            steps = [transforms.Lambda(lambda x: x)]

        if train:
            # Random affine jitter is the only grayscale augmentation.
            steps.append(
                transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1))
            )

        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize((0.5,), (0.5,)))
        final_transform = transforms.Compose(steps)

    elif dataset_name in color_sets:
        # Group 2: COLOR (CIFAR-10, CIFAR-100-20, GTSRB).
        # One set of per-channel mean/std values is shared by all three datasets.
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        input_size = 32

        if train:
            # SimCLR-style augmentations for color data.
            steps = [
                transforms.RandomResizedCrop(size=input_size, scale=(0.2, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                ], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                normalize,
            ]
        else:
            # Only GTSRB is resized to the standard size at evaluation time;
            # the other datasets pass through an identity step.
            if dataset_name == 'GTSRB':
                head = transforms.Resize((input_size, input_size))
            else:
                head = transforms.Lambda(lambda x: x)
            steps = [head, transforms.ToTensor(), normalize]

        final_transform = transforms.Compose(steps)

    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    if train and is_contrastive:
        return TwoCropTransform(final_transform)
    return final_transform

In [None]:
class OptDigitsDataset(Dataset):
    """In-memory dataset of 8x8 digit images (as returned by sklearn ``load_digits``).

    Args:
        data: Array-like of flattened images; reshaped here to ``(N, 8, 8)``.
        targets: Array-like of integer class labels.
        transform: Optional callable applied to each image tensor on access.
        max_value: Largest raw pixel intensity. Images are divided by it so
            they lie in [0, 1]. The digits data uses intensities 0..16, and
            torchvision's ``ToPILImage`` multiplies float tensors by 255
            before casting to bytes, so unscaled values would overflow.
            Pass ``None`` to keep the raw values.
    """

    def __init__(self, data, targets, transform=None, max_value=16.0):
        images = torch.as_tensor(data, dtype=torch.float32).reshape(-1, 8, 8)
        if max_value is not None:
            images = images / max_value
        self.data = images
        self.targets = torch.as_tensor(targets, dtype=torch.long)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed when a transform is set."""
        img, label = self.data[idx], self.targets[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

class GTSRBDataset(Dataset):
    """GTSRB images listed in a CSV index file.

    The CSV is read with pandas and must provide a ``Path`` column (image
    location relative to ``root_dir``) and a ``ClassId`` column (label).
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.transform = transform
        self.root_dir = root_dir
        self.data = pd.read_csv(csv_file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load the ``idx``-th image as RGB and return ``(image, int_label)``."""
        row = self.data.iloc[idx]
        full_path = os.path.join(self.root_dir, row['Path'])
        image = Image.open(full_path).convert('RGB')
        label = int(row['ClassId'])
        if self.transform:
            image = self.transform(image)
        return image, label

class CIFAR100Coarse(Dataset):
    """
    CIFAR-100 wrapper that exposes the 20 superclass labels (CIFAR-100-20)
    instead of the 100 fine labels.
    """

    def __init__(self, root, train, transform, download):
        self.dataset = datasets.CIFAR100(root=root, train=train, transform=transform, download=download)

        # Lookup table: position = fine label (0-99), value = coarse label (0-19).
        self.coarse_map = np.array([
            4, 1, 14, 8, 0, 6, 7, 7, 18, 3, 3, 14, 9, 18, 7, 11, 3, 9, 7, 11,
            6, 11, 5, 10, 7, 6, 13, 15, 3, 15, 0, 11, 1, 10, 12, 14, 16, 9, 11, 5,
            5, 19, 8, 8, 15, 13, 14, 17, 18, 10, 16, 4, 17, 4, 2, 0, 17, 4, 18, 17,
            10, 3, 2, 12, 12, 16, 12, 1, 9, 19, 2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
            16, 19, 2, 4, 6, 19, 5, 5, 8, 19, 18, 1, 2, 15, 6, 0, 17, 8, 14, 13
        ])

        # Translate every fine label once, up front.
        coarse_labels = []
        for fine_label in self.dataset.targets:
            coarse_labels.append(self.coarse_map[fine_label])
        self.targets = coarse_labels

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # The wrapped dataset has already applied the transform; its fine
        # label is discarded in favour of the precomputed coarse label.
        image, _fine_label = self.dataset[idx]
        return image, self.targets[idx]

In [None]:
def _download_gtsrb(download_dir, dataset_path):
    """Fetch the GTSRB archive with the Kaggle CLI and unpack it into ``dataset_path``."""
    import subprocess  # local import: only needed on the download path

    # An argument list (no shell) passes paths containing spaces or shell
    # metacharacters verbatim; check=True raises when the CLI exits non-zero,
    # which os.system would have ignored.
    subprocess.run(
        ['kaggle', 'datasets', 'download',
         '-d', 'meowmeowmeowmeowmeow/gtsrb-german-traffic-sign',
         '-p', download_dir],
        check=True,
    )
    archive = os.path.join(download_dir, 'gtsrb-german-traffic-sign.zip')
    with zipfile.ZipFile(archive, 'r') as zip_ref:
        zip_ref.extractall(dataset_path)


def get_data_loaders(dataset_name,
                     batch_size=32,
                     download_dir='./datasets',
                     transform=None,
                     is_contrastive=False,
                     num_workers=2):
    """
    Build the train and test DataLoaders for a supported dataset.

    Args:
        dataset_name (str): MNIST, KMNIST, Fashion-MNIST, USPS, OPTDIGITS,
            CIFAR-10, CIFAR-100-20 or GTSRB.
        batch_size (int): batch size for both loaders.
        download_dir (str): directory holding (or receiving) the data; created
            if missing.
        transform (callable): optional transform used for BOTH splits instead
            of the defaults. When given, ``is_contrastive`` has no effect.
        is_contrastive (bool): if True, the default train transform is wrapped
            in TwoCropTransform for SimCLR (Scenario 3).
        num_workers (int): worker processes per DataLoader.

    Returns:
        train_loader, test_loader. The train loader shuffles and drops the
        last incomplete batch; the test loader does neither.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    os.makedirs(download_dir, exist_ok=True)

    # Transform strategy
    if transform is None:
        # Train: augmentation (+ TwoCrop if contrastive); test: normalize only.
        train_transform = get_brb_transform(dataset_name, is_contrastive=is_contrastive, train=True)
        test_transform = get_brb_transform(dataset_name, is_contrastive=False, train=False)
    else:
        train_transform = transform
        test_transform = transform

    # Datasets that share the torchvision ``(root, train, download, transform)``
    # constructor, keyed by our name -> attribute name on ``datasets``.
    torchvision_names = {
        'MNIST': 'MNIST',
        'KMNIST': 'KMNIST',
        'Fashion-MNIST': 'FashionMNIST',
        'USPS': 'USPS',
        'CIFAR-10': 'CIFAR10',
    }

    if dataset_name in torchvision_names:
        dataset_cls = getattr(datasets, torchvision_names[dataset_name])
        train_dataset = dataset_cls(download_dir, train=True, download=True, transform=train_transform)
        test_dataset = dataset_cls(download_dir, train=False, download=True, transform=test_transform)

    elif dataset_name == 'OPTDIGITS':
        digits = load_digits()
        X_train, X_test, y_train, y_test = train_test_split(
            digits.data, digits.target, test_size=0.2, random_state=42
        )
        train_dataset = OptDigitsDataset(X_train, y_train, train_transform)
        test_dataset = OptDigitsDataset(X_test, y_test, test_transform)

    elif dataset_name == 'CIFAR-100-20':
        train_dataset = CIFAR100Coarse(download_dir, True, train_transform, True)
        test_dataset = CIFAR100Coarse(download_dir, False, test_transform, True)

    elif dataset_name == 'GTSRB':
        dataset_path = os.path.join(download_dir, 'gtsrb')
        if not os.path.exists(os.path.join(dataset_path, 'Train.csv')):
            print("Downloading GTSRB from Kaggle...")
            try:
                _download_gtsrb(download_dir, dataset_path)
            except Exception as e:
                # Best effort: report and continue, so a manually placed
                # dataset still works. A missing CSV fails just below.
                print(f"Error downloading GTSRB: {e}. Please manually download or setup Kaggle API.")

        train_dataset = GTSRBDataset(os.path.join(dataset_path, 'Train.csv'), dataset_path, train_transform)
        test_dataset = GTSRBDataset(os.path.join(dataset_path, 'Test.csv'), dataset_path, test_transform)

    else:
        raise ValueError(f"Dataset '{dataset_name}' is not supported in BRB project.")

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              drop_last=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)

    return train_loader, test_loader

**Models**

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18

def init_weights(m):
    """
    Kaiming-uniform initialisation for Linear and Conv2d modules.

    The BRB paper calls for Kaiming init when weights are reset. Weights are
    drawn with ``mode='fan_in'`` and ReLU gain; an existing bias is set to
    zero. Any other module type is left untouched.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return

    nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
    bias = m.bias
    if bias is not None:
        bias.data.zero_()


# FEED-FORWARD AUTOENCODER (Group 1: MNIST, KMNIST, Fashion-MNIST, USPS, OPTDIGITS)
class Autoencoder(nn.Module):
    """Fully connected autoencoder with a tanh-bounded reconstruction."""

    def __init__(self, input_dim=784, latent_dim=10, dims=(500, 500, 2000)):
        """
        Standard DEC/IDEC architecture as used in BRB.
        Default for MNIST: 784 -> 500 -> 500 -> 2000 -> 10, mirrored in the decoder.

        Args:
            input_dim: size of the flattened input.
            latent_dim: size of the embedding.
            dims: hidden layer widths of the encoder, in order. May be empty,
                in which case encoder and decoder are single linear maps.
        """
        super(Autoencoder, self).__init__()
        # Copy so the (immutable) default or a caller's sequence is never aliased.
        hidden_dims = list(dims)

        # Encoder: Linear + ReLU per hidden width.
        self.encoder_layers = nn.ModuleList()
        curr_dim = input_dim
        for dim in hidden_dims:
            self.encoder_layers.append(nn.Linear(curr_dim, dim))
            self.encoder_layers.append(nn.ReLU())
            curr_dim = dim

        # Latent (embedding) layer, no activation.
        self.embedding_layer = nn.Linear(curr_dim, latent_dim)

        # Decoder: the hidden widths in reverse, starting from the latent size.
        # Starting the loop at latent_dim also covers an empty ``dims``.
        self.decoder_layers = nn.ModuleList()
        curr_dim = latent_dim
        for dim in reversed(hidden_dims):
            self.decoder_layers.append(nn.Linear(curr_dim, dim))
            self.decoder_layers.append(nn.ReLU())
            curr_dim = dim

        # Reconstruction layer back to the input size.
        self.reconst_layer = nn.Linear(curr_dim, input_dim)

        # Kaiming init for every Linear layer.
        self.apply(init_weights)

    def forward(self, x):
        """Return ``(x_recon, z)``: tanh-bounded reconstruction and latent code.

        Inputs with more than two dimensions are flattened per sample first.
        """
        if x.dim() > 2:
            # reshape (not view) so non-contiguous batches are accepted too.
            x = x.reshape(x.size(0), -1)

        # Encode
        h = x
        for layer in self.encoder_layers:
            h = layer(h)
        z = self.embedding_layer(h)

        # Decode
        r = z
        for layer in self.decoder_layers:
            r = layer(r)
        x_recon = self.reconst_layer(r)

        # Tanh bounds the reconstruction to [-1, 1].
        x_recon = torch.tanh(x_recon)

        return x_recon, z

# RESNET-18 BACKBONE (Group 2: CIFAR-10, GTSRB)

class ContrastiveResNet18(nn.Module):
    """ResNet-18 feature extractor with a two-layer projection head."""

    def __init__(self, latent_dim=128):
        """
        Build a ResNet-18 adapted to small (32x32) images plus a projection head.

        Compared with the stock network, the stem uses a 3x3 stride-1
        convolution and an ``nn.Identity`` sits where the max-pool would be.
        """
        super().__init__()

        # Stock, randomly initialised ResNet-18; only its sub-modules are reused.
        resnet = resnet18(pretrained=False)

        # Modification for CIFAR/GTSRB (32x32 images).
        stem_conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        stages = [
            stem_conv,
            resnet.bn1,
            resnet.relu,
            nn.Identity(),
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.features = nn.Sequential(*stages)

        self.feature_dim = 512
        self.projector = nn.Sequential(
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim),
        )

        # Re-initialise every Conv2d / Linear layer.
        self.apply(init_weights)

    def forward(self, x):
        """Return ``(h, z)``: flattened backbone features and their projection."""
        # Feature extraction (representation h), flattened to (batch, 512).
        pooled = self.features(x)
        h = pooled.view(pooled.size(0), -1)

        # h is used for clustering, z for the SimCLR training loss.
        return h, self.projector(h)

**metrics**

In [None]:
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from scipy.optimize import linear_sum_assignment

def cluster_acc(y_true, y_pred):
    """
    Clustering accuracy under the best one-to-one cluster/label matching.

    A contingency matrix between predicted clusters and true labels is built,
    the Hungarian algorithm picks the matching that maximises the matched
    count, and that count is divided by the number of samples.

    Args:
        y_true (array-like): ground-truth labels, non-negative integers.
        y_pred (array-like): predicted cluster ids, non-negative integers.

    Returns:
        float: accuracy in [0, 1].

    Raises:
        ValueError: if the inputs differ in size, are empty, or contain
            negative labels (which would silently wrap when used as indices).
    """
    y_true = np.asarray(y_true).astype(np.int64).ravel()
    y_pred = np.asarray(y_pred).astype(np.int64).ravel()

    # Explicit raises instead of assert, so the checks survive ``python -O``.
    if y_pred.size != y_true.size:
        raise ValueError(
            f"y_true and y_pred must have the same size, got {y_true.size} and {y_pred.size}"
        )
    if y_pred.size == 0:
        raise ValueError("cluster_acc requires at least one sample")
    if y_pred.min() < 0 or y_true.min() < 0:
        raise ValueError("labels must be non-negative integers")

    D = max(y_pred.max(), y_true.max()) + 1

    # w[i, j] = number of samples in cluster i whose true label is j.
    # np.add.at accumulates repeated (i, j) pairs, unlike fancy-index +=.
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (y_pred, y_true), 1)

    # linear_sum_assignment minimises cost, so flip counts into costs.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    accuracy = w[row_ind, col_ind].sum() * 1.0 / y_pred.size

    return accuracy

def evaluate_clustering(y_true, y_pred):
    """
    Compute the three headline deep-clustering metrics.

    Args:
        y_true: ground-truth labels (np.array, list, or an object with
            ``.cpu()`` such as a torch tensor).
        y_pred: predicted cluster ids (same accepted types).

    Returns:
        dict: ``{'ACC': ..., 'NMI': ..., 'ARI': ...}``, each rounded to 4 decimals.
    """
    def as_array(labels):
        # Tensor-like objects are moved to CPU and converted; lists are wrapped.
        if hasattr(labels, 'cpu'):
            labels = labels.cpu().numpy()
        if isinstance(labels, list):
            labels = np.array(labels)
        return labels

    y_true = as_array(y_true)
    y_pred = as_array(y_pred)

    raw_scores = {
        'ACC': cluster_acc(y_true, y_pred),                    # accuracy
        'NMI': normalized_mutual_info_score(y_true, y_pred),   # normalized mutual information
        'ARI': adjusted_rand_score(y_true, y_pred),            # adjusted Rand index
    }
    return {metric: np.round(score, 4) for metric, score in raw_scores.items()}

**dec**

In [None]:
class ClusteringLayer(nn.Module):
    """
    Holds the trainable cluster centroids and turns latent vectors into
    soft assignments Q with a Student's t kernel.
    """

    def __init__(self, n_clusters, n_z, alpha=1.0):
        super().__init__()
        self.n_clusters = n_clusters
        self.n_z = n_z
        self.alpha = alpha  # degrees of freedom of the kernel (1.0 by default)

        # Trainable centroids, shape (n_clusters, n_z), Xavier-normal initialised.
        self.centroids = nn.Parameter(torch.Tensor(n_clusters, n_z))
        nn.init.xavier_normal_(self.centroids.data)

    def forward(self, z):
        """
        Soft assignment of every latent vector to every centroid.
        Input: z (batch_size, n_z)
        Output: q (batch_size, n_clusters), each row sums to 1
        """
        # Pairwise squared Euclidean distances via broadcasting: (batch, n_clusters).
        offsets = z[:, None, :] - self.centroids[None, :, :]
        squared_dist = offsets.pow(2).sum(dim=2)

        # Unnormalised kernel: (1 + d / alpha) ** (-(alpha + 1) / 2).
        exponent = -(self.alpha + 1.0) / 2.0
        kernel = (1.0 + squared_dist / self.alpha).pow(exponent)

        # Normalise across clusters.
        return kernel / kernel.sum(dim=1, keepdim=True)

def target_distribution(q):
    """Sharpened target distribution P derived from soft assignments ``q``.

    Each entry is squared and divided by its column sum (computed over the
    rows of ``q``), then every row is renormalised to sum to 1. The result is
    detached so no gradient flows through the target.
    """
    column_mass = q.sum(0)
    weight = q ** 2 / column_mass
    row_mass = weight.sum(1, keepdim=True)
    return (weight / row_mass).detach()

def kl_divergence_loss(q, p):
    """
    KL(P || Q) = sum(p * log(p / q)), averaged over the batch dimension.

    ``q`` holds probabilities; ``F.kl_div`` expects log-probabilities as its
    input, so the log is taken here.
    """
    log_q = q.log()
    return F.kl_div(input=log_q, target=p, reduction='batchmean')

**idec**

In [None]:
class IDEC(nn.Module):
    """
    Improved Deep Embedded Clustering (IDEC):
    an autoencoder followed by a clustering layer on its latent code.
    """

    def __init__(self, autoencoder, n_clusters, alpha=1.0):
        super().__init__()
        self.autoencoder = autoencoder
        self.n_clusters = n_clusters
        self.alpha = alpha

        # The latent size is read off the autoencoder's embedding layer.
        self.latent_dim = autoencoder.embedding_layer.out_features

        # Clustering layer, same as in DEC.
        self.clustering_layer = ClusteringLayer(n_clusters, self.latent_dim, alpha)

    def forward(self, x):
        """
        Run the autoencoder, then the clustering layer.

        Input: x
        Output:
            - x_recon: reconstructed input (for L_rec)
            - q: soft cluster assignment (for L_clus)
            - z: latent vector
        """
        reconstruction, latent = self.autoencoder(x)
        soft_assignment = self.clustering_layer(latent)
        return reconstruction, soft_assignment, latent

def idec_loss_function(x, x_recon, q, p, gamma=0.1):
    """
    IDEC objective: L = L_rec + gamma * L_clus

    Args:
        x: original input; flattened per sample if it has more than 2 dims.
        x_recon: reconstruction produced by the decoder.
        q: soft assignment from the clustering layer.
        p: target distribution.
        gamma: weight of the clustering term.

    Returns:
        total_loss, reconstruction_loss, clustering_loss
    """
    # Reconstruction term: MSE against the flattened input.
    target = x.view(x.size(0), -1) if x.dim() > 2 else x
    loss_rec = F.mse_loss(x_recon, target)

    # Clustering term: KL divergence between target and soft assignment.
    loss_clus = kl_divergence_loss(q, p)

    return loss_rec + gamma * loss_clus, loss_rec, loss_clus

**dcn**

In [None]:
from sklearn.cluster import KMeans

class DCN(nn.Module):
    """
    Deep Clustering Network (DCN), trained by alternating optimisation:
    1. K-Means updates the centroids and hard assignments.
    2. SGD updates the autoencoder weights.
    """

    def __init__(self, autoencoder, n_clusters):
        super().__init__()
        self.autoencoder = autoencoder
        self.n_clusters = n_clusters
        self.latent_dim = autoencoder.embedding_layer.out_features

        # Centroids live in a buffer: saved with the module, not optimised.
        initial_centroids = torch.zeros(n_clusters, self.latent_dim)
        self.register_buffer('centroids', initial_centroids)

    def forward(self, x):
        """
        Forward pass of DCN: simply runs the autoencoder and returns
        ``(x_recon, z)``.
        """
        return self.autoencoder(x)

    def update_centroids(self, z_full_data):
        """
        Step 1: with embeddings fixed, refit the centroids with K-Means over
        the supplied latent vectors.

        Args:
            z_full_data (np.array): latent vectors of the training set.

        Returns:
            assignments (np.array): hard cluster label per row.
        """
        clusterer = KMeans(n_clusters=self.n_clusters, n_init=20, random_state=42)
        assignments = clusterer.fit_predict(z_full_data)

        # Overwrite the buffer in place with the fitted centres.
        fitted_centres = torch.tensor(clusterer.cluster_centers_, dtype=torch.float32)
        self.centroids.copy_(fitted_centres)

        return assignments

def dcn_loss_function(x, x_recon, z, assignments, centroids, beta=1.0):
    """
    DCN objective: L = L_rec + beta * L_clus

    L_rec is the MSE between the reconstruction and the (flattened) input;
    L_clus is the MSE between each latent vector and the centroid selected by
    its hard assignment.

    Returns:
        total_loss, loss_rec, loss_clus
    """
    # Reconstruction loss against the per-sample flattened input.
    target = x.view(x.size(0), -1) if x.dim() > 2 else x
    loss_rec = F.mse_loss(x_recon, target)

    # Hard-assignment clustering loss: pull z towards its assigned centroid.
    assigned_centroids = centroids[assignments]
    loss_clus = F.mse_loss(z, assigned_centroids)

    return loss_rec + beta * loss_clus, loss_rec, loss_clus

**simclr**

In [None]:
class NTXentLoss(nn.Module):
    """
    Normalized Temperature-scaled Cross Entropy Loss for SimCLR.
    Source: Chen et al., 2020.

    Args:
        batch_size: expected number of samples per view; the mask for this
            size is precomputed. Other batch sizes (e.g. a final partial
            batch) are handled at call time.
        temperature: divisor applied to the cosine similarities.
        device: kept for backward compatibility. ``forward`` works on whatever
            device its inputs live on.
    """

    def __init__(self, batch_size, temperature=0.5, device='cuda'):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device

        self.mask = self.mask_correlated_samples(batch_size)

        # Summed cross entropy; forward divides by the number of rows itself.
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size):
        """Boolean (2B, 2B) mask that is True exactly at the negative pairs.

        The diagonal (self-similarity) and the two off-diagonals at offset
        ``batch_size`` (the positive pairs) are False.
        """
        N = 2 * batch_size
        mask = torch.ones((N, N), dtype=torch.bool)
        mask.fill_diagonal_(False)

        first_view = torch.arange(batch_size)
        mask[first_view, first_view + batch_size] = False
        mask[first_view + batch_size, first_view] = False
        return mask

    def forward(self, z_i, z_j):
        """
        Input:
            z_i: projections of view 1, shape (B, dim)
            z_j: projections of view 2, shape (B, dim)
        Output:
            scalar loss averaged over the 2B anchor rows.

        Raises:
            ValueError: if the two views have different shapes.
        """
        if z_i.shape != z_j.shape:
            raise ValueError(
                f"both views must have the same shape, got {tuple(z_i.shape)} and {tuple(z_j.shape)}"
            )

        # Use the actual batch size and device of the inputs, so partial
        # batches and CPU tensors work whatever the constructor was given.
        batch_size = z_i.size(0)
        N = 2 * batch_size
        device = z_i.device

        # (2B, dim)
        z = torch.cat((z_i, z_j), dim=0)

        # Temperature-scaled cosine similarity matrix (2B, 2B).
        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        # Positive pairs sit on the diagonals offset by +/- batch_size.
        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)

        # Reuse the precomputed mask when the size matches, else build one.
        if batch_size == self.batch_size:
            mask = self.mask
        else:
            mask = self.mask_correlated_samples(batch_size)
        negative_samples = sim[mask.to(device)].reshape(N, -1)

        # The positive logit is in column 0, so every target class is 0.
        labels = torch.zeros(N, dtype=torch.long, device=device)
        logits = torch.cat((positive_samples, negative_samples), dim=1)

        loss = self.criterion(logits, labels)
        return loss / N

In [None]:
import torch.optim as optim
from tqdm import tqdm

class _IndexedDataset(Dataset):
    """Wrap a dataset so every item is ``(first_element_of_sample, index)``."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        return self.base[idx][0], idx


class SelfLabeling:
    """
    Self-Labeling:
    1. Generate pseudo-labels from the current model (K-Means on features).
    2. Fine-tune the model plus a linear classifier on those pseudo-labels.

    Pseudo-labels are stored in dataset order and looked up by sample index
    during fine-tuning, so they stay attached to the right sample even when
    the loader shuffles.
    """

    def __init__(self, model, train_loader, n_clusters, device='cuda'):
        self.model = model
        self.train_loader = train_loader
        self.n_clusters = n_clusters
        self.device = device

    def _indexed_loader(self, shuffle, drop_last):
        """DataLoader over ``train_loader.dataset`` yielding ``(imgs, indices)``."""
        source = self.train_loader
        return DataLoader(
            _IndexedDataset(source.dataset),
            # batch_size is None when the source loader uses a batch_sampler.
            batch_size=source.batch_size or 128,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=source.num_workers,
        )

    @staticmethod
    def _first_view(imgs):
        """Pick the first view when the dataset yields a list of views (TwoCropTransform)."""
        return imgs[0] if isinstance(imgs, list) else imgs

    def get_pseudo_labels(self):
        """
        Extract features for every training sample and cluster them with K-Means.

        Returns:
            LongTensor on ``self.device`` of length ``len(dataset)``; entry i
            is the pseudo-label of dataset item i.
        """
        self.model.eval()
        features_list = []

        # Unshuffled and without drop_last: one feature row per dataset item,
        # in dataset order.
        ordered_loader = self._indexed_loader(shuffle=False, drop_last=False)
        with torch.no_grad():
            for imgs, _ in tqdm(ordered_loader):
                imgs = self._first_view(imgs).to(self.device)

                # The model returns (h, z); clustering uses the representation h.
                h, _ = self.model(imgs)
                features_list.append(h.cpu().numpy())

        full_features = np.concatenate(features_list, axis=0)

        # K-Means clustering of the features.
        kmeans = KMeans(n_clusters=self.n_clusters, n_init=20, random_state=42)
        pseudo_labels = kmeans.fit_predict(full_features)

        return torch.LongTensor(pseudo_labels).to(self.device)

    def finetune(self, epochs=20, lr=0.01):
        """
        Train the model and a fresh linear classifier on the pseudo-labels.

        Returns:
            (model, classifier_head)
        """
        pseudo_labels = self.get_pseudo_labels()

        # Linear head: model.feature_dim -> n_clusters.
        classifier_head = nn.Linear(self.model.feature_dim, self.n_clusters).to(self.device)

        # One optimizer over both the model and the classifier head.
        optimizer = optim.SGD(
            list(self.model.parameters()) + list(classifier_head.parameters()),
            lr=lr, momentum=0.9, weight_decay=1e-4
        )
        criterion = nn.CrossEntropyLoss()

        self.model.train()
        classifier_head.train()

        # Shuffled training loader that also reports each sample's dataset
        # index; drop_last follows the loader passed to the constructor.
        train_loader = self._indexed_loader(shuffle=True, drop_last=self.train_loader.drop_last)

        for epoch in range(epochs):
            total_loss = 0

            for imgs, sample_idx in train_loader:
                imgs = self._first_view(imgs).to(self.device)
                # Look targets up by index rather than by position in the pass.
                targets = pseudo_labels[sample_idx.to(self.device)]

                h, _ = self.model(imgs)
                logits = classifier_head(h)

                loss = criterion(logits, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            print(f"Epoch {epoch+1}/{epochs} - Self-Label Loss: {total_loss / len(train_loader):.4f}")

        return self.model, classifier_head

**Benchmark**

In [None]:
from sklearn.metrics import confusion_matrix
from sklearn.manifold import TSNE
import seaborn as sns

class Visualizer:
    """Plot helpers for a clustering run: metrics, confusion matrix and t-SNE."""

    @staticmethod
    def align_labels(y_true, y_pred):
        """Rename cluster ids to the true labels they best match.

        The matching is the one-to-one assignment (Hungarian algorithm) that
        maximises agreement between clusters and labels. Both inputs must be
        integer numpy arrays with non-negative entries.

        Returns:
            np.ndarray: ``y_pred`` with every cluster id replaced by its
            matched label.
        """
        D = max(y_pred.max(), y_true.max()) + 1
        # w[i, j] = number of samples in cluster i with true label j.
        w = np.zeros((D, D), dtype=np.int64)
        np.add.at(w, (y_pred, y_true), 1)
        row_ind, col_ind = linear_sum_assignment(w.max() - w)

        # w is square, so every cluster id in [0, D) receives a target label.
        lookup = np.arange(D)
        lookup[row_ind] = col_ind
        return lookup[y_pred]

    @staticmethod
    def plot_all(results, y_true, y_pred, features, dataset_name, method_name):
        """Draw a 3-panel figure: metric bars, aligned confusion matrix, t-SNE.

        Args:
            results: mapping of metric name -> value.
            y_true, y_pred: integer numpy arrays of labels / cluster ids.
            features: 2-D numpy array of embeddings, one row per sample.
            dataset_name, method_name: used in the figure title.
        """
        sns.set_theme(style="whitegrid")
        fig = plt.figure(figsize=(20, 12))
        gs = fig.add_gridspec(2, 2)

        # Common title
        fig.suptitle(f"{method_name} on {dataset_name}", fontsize=18, fontweight='bold')

        # Chart 1: Metrics
        ax1 = fig.add_subplot(gs[0, 0])
        sns.barplot(x=list(results.keys()), y=list(results.values()), palette='viridis', ax=ax1)
        for i, v in enumerate(results.values()):
            ax1.text(i, v + 0.01, f"{v:.4f}", ha='center', fontweight='bold')
        ax1.set_title("Clustering Metrics", fontsize=14)
        ax1.set_ylim(0, 1.1)

        # Chart 2: Confusion Matrix (row-normalised)
        ax2 = fig.add_subplot(gs[0, 1])
        y_pred_aligned = Visualizer.align_labels(y_true, y_pred)
        cm = confusion_matrix(y_true, y_pred_aligned)
        # A label that occurs only in the predictions gives an all-zero row;
        # leave such rows at 0 instead of dividing by zero (NaN cells).
        row_totals = cm.sum(axis=1, keepdims=True)
        cm_norm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)
        sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap='Blues', ax=ax2)
        ax2.set_title("Confusion Matrix (Aligned)", fontsize=14)

        # Chart 3: t-SNE on at most 2000 randomly chosen samples
        ax3 = fig.add_subplot(gs[1, :])
        if len(features) > 2000:
            idx = np.random.choice(len(features), 2000, replace=False)
            feat_sample, y_sample = features[idx], y_true[idx]
        else:
            feat_sample, y_sample = features, y_true

        print("   >> Computing t-SNE...")
        tsne = TSNE(n_components=2, init='pca', learning_rate='auto', random_state=42)
        z_emb = tsne.fit_transform(feat_sample)
        scatter = ax3.scatter(z_emb[:, 0], z_emb[:, 1], c=y_sample, cmap='tab10', s=15, alpha=0.7)
        ax3.legend(*scatter.legend_elements(), title="Classes", loc="upper right", ncol=2)
        ax3.set_title("t-SNE Latent Space Visualization", fontsize=14)

        plt.tight_layout()
        plt.show()

# PIPELINE GROUP 1: IDEC (Grayscale)
def run_group1_idec(dataset_name, n_clusters, epochs_ae, epochs_idec, batch_size, device):
    """Run the Group 1 pipeline: pre-train an autoencoder, then fine-tune IDEC.

    Steps: (1) pre-train the autoencoder with an MSE reconstruction loss,
    (2) initialise the centroids with K-Means on the train embeddings,
    (3) fine-tune IDEC with reconstruction + clustering loss, where the target
    distribution is recomputed from each batch, (4) evaluate on the test
    loader and plot.

    Args:
        dataset_name: grayscale dataset name understood by ``get_data_loaders``.
        n_clusters: number of clusters / centroids.
        epochs_ae: autoencoder pre-training epochs.
        epochs_idec: IDEC fine-tuning epochs.
        batch_size: loader batch size.
        device: torch device for model and batches.
    """
    print(f"\n RUNNING GROUP 1 (IDEC): {dataset_name}")

    # Load data with the default transforms. Note that the default TRAIN
    # transform still applies RandomAffine augmentation; only the test loader
    # is augmentation-free.
    train_loader, test_loader = get_data_loaders(dataset_name, batch_size, is_contrastive=False)

    sample_batch, _ = next(iter(train_loader))
    input_dim = sample_batch.view(sample_batch.size(0), -1).shape[1]
    print(f"   Input Dim: {input_dim}")

    # Model & pre-training
    ae = Autoencoder(input_dim=input_dim, latent_dim=10).to(device)
    optimizer_ae = optim.Adam(ae.parameters(), lr=1e-3)

    ae.train()
    for _ in tqdm(range(epochs_ae), desc="   Step 1: Pre-train AE", leave=False):
        for x, _ in train_loader:
            x = x.to(device)
            if x.dim() > 2: x = x.view(x.size(0), -1)
            x_recon, _ = ae(x)
            loss = F.mse_loss(x_recon, x)
            optimizer_ae.zero_grad()
            loss.backward()
            optimizer_ae.step()

    # Centroid initialisation from K-Means on the train embeddings
    print("   Step 2: Init K-Means...")
    ae.eval()
    features = []
    with torch.no_grad():
        for x, _ in train_loader:
            x = x.to(device)
            if x.dim() > 2: x = x.view(x.size(0), -1)
            _, z = ae(x)
            features.append(z.cpu().numpy())
    features = np.concatenate(features)
    kmeans = KMeans(n_clusters=n_clusters, n_init=20, random_state=42).fit(features)

    idec = IDEC(ae, n_clusters=n_clusters).to(device)
    # Copy into the existing parameter (instead of rebinding ``.data``) with
    # an explicit dtype/device, so the centroids always match the rest of the
    # model whatever dtype sklearn returns.
    centroid_param = idec.clustering_layer.centroids
    with torch.no_grad():
        centroid_param.copy_(
            torch.as_tensor(kmeans.cluster_centers_, dtype=centroid_param.dtype, device=centroid_param.device)
        )

    # Train IDEC
    optimizer_idec = optim.Adam(idec.parameters(), lr=1e-4)
    idec.train()
    for _ in tqdm(range(epochs_idec), desc="   Step 3: Fine-tune IDEC", leave=False):
        for x, _ in train_loader:
            x = x.to(device)
            x_recon, q, z = idec(x)
            p = target_distribution(q)
            loss, _, _ = idec_loss_function(x, x_recon, q, p, gamma=0.1)
            optimizer_idec.zero_grad()
            loss.backward()
            optimizer_idec.step()

    # Eval & visualize on the test split
    idec.eval()
    final_z, final_pred, final_true = [], [], []
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            _, q, z = idec(x)
            final_z.append(z.cpu().numpy())
            final_pred.append(torch.argmax(q, dim=1).cpu().numpy())
            final_true.append(y.numpy())

    final_z = np.concatenate(final_z)
    final_pred = np.concatenate(final_pred)
    final_true = np.concatenate(final_true)

    results = evaluate_clustering(final_true, final_pred)
    print(f"   Result: {results}")
    Visualizer.plot_all(results, final_true, final_pred, final_z, dataset_name, "IDEC")

# PIPELINE GROUP 2: SimCLR (Color/Complex)
def run_group2_simclr(dataset_name, n_clusters, epochs, batch_size, device):
    """Run the Group 2 pipeline: SimCLR training, then K-Means on test features.

    The model is trained with the NT-Xent loss on two augmented views per
    sample; afterwards the backbone features ``h`` of the test split are
    clustered with K-Means, scored and plotted.
    """
    print(f"\nðŸ”¸ RUNNING GROUP 2 (SimCLR): {dataset_name}")

    # Two loader builds: the train loader comes from the contrastive call
    # (two views per sample), the test loader from the plain call.
    contrastive_loaders = get_data_loaders(dataset_name, batch_size, is_contrastive=True)
    train_loader = contrastive_loaders[0]
    plain_loaders = get_data_loaders(dataset_name, batch_size, is_contrastive=False)
    test_loader = plain_loaders[1]

    # Model, optimizer and loss
    model = ContrastiveResNet18(latent_dim=128).to(device)
    optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
    criterion = NTXentLoss(batch_size=batch_size, temperature=0.5, device=device)

    # SimCLR training loop
    model.train()
    for epoch in tqdm(range(epochs), desc="   Step 1: Training SimCLR", leave=False):
        total_loss = 0
        for views, _ in train_loader:
            # TwoCropTransform makes the first batch element a [view1, view2] pair.
            x_i, x_j = views
            x_i = x_i.to(device)
            x_j = x_j.to(device)

            # Only the projections z enter the contrastive loss.
            z_i = model(x_i)[1]
            z_j = model(x_j)[1]
            loss = criterion(z_i, z_j)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

    # Evaluation: K-Means on the learned backbone features h
    print("   Step 2: Evaluating Features...")
    model.eval()
    feature_chunks, label_chunks = [], []

    with torch.no_grad():
        for x, y in test_loader:
            h, _ = model(x.to(device))
            feature_chunks.append(h.cpu().numpy())
            label_chunks.append(y.numpy())

    final_h = np.concatenate(feature_chunks)
    final_true = np.concatenate(label_chunks)

    clusterer = KMeans(n_clusters=n_clusters, n_init=20, random_state=42)
    final_pred = clusterer.fit_predict(final_h)

    results = evaluate_clustering(final_true, final_pred)
    print(f"    Result: {results}")
    Visualizer.plot_all(results, final_true, final_pred, final_h, dataset_name, "SimCLR+KMeans")

# MASTER EXECUTION BLOCK
if __name__ == "__main__":
    # Prefer the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"System Check: Using {str(device).upper()}")

    # One entry per benchmark run. 'group' selects the pipeline, 'k' is the
    # number of clusters, 'bs' the batch size; the remaining keys are the
    # epoch counts of that pipeline.
    all_datasets = [
        # --- GROUP 1: IDEC ---
        {'name': 'MNIST',          'group': 1, 'k': 10, 'ae': 30, 'idec': 30, 'bs': 256},
        {'name': 'Fashion-MNIST',  'group': 1, 'k': 10, 'ae': 50, 'idec': 40, 'bs': 256},
        {'name': 'USPS',           'group': 1, 'k': 10, 'ae': 50, 'idec': 30, 'bs': 256},
        {'name': 'OPTDIGITS',      'group': 1, 'k': 10, 'ae': 50, 'idec': 30, 'bs': 64},

        # --- GROUP 2: SimCLR ---
        {'name': 'CIFAR-10',       'group': 2, 'k': 10, 'epochs': 50, 'bs': 128},
        {'name': 'CIFAR-100-20',   'group': 2, 'k': 20, 'epochs': 50, 'bs': 128},
        {'name': 'GTSRB',          'group': 2, 'k': 43, 'epochs': 50, 'bs': 64}
    ]

    for config in all_datasets:
        # A failure in one dataset is reported and the loop moves on.
        try:
            group = config['group']
            if group == 1:
                run_group1_idec(
                    config['name'],
                    config['k'],
                    config['ae'],
                    config['idec'],
                    config['bs'],
                    device,
                )
            elif group == 2:
                run_group2_simclr(
                    config['name'],
                    config['k'],
                    config['epochs'],
                    config['bs'],
                    device,
                )
        except Exception as e:
            print(f" Error running {config['name']}: {e}")
            import traceback
            traceback.print_exc()