In [None]:
pip install munkres



In [None]:
pip install scanpy



In [None]:
import os
import random
import h5py
import scanpy as sc
import scipy as sp
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn.functional import binary_cross_entropy_with_logits as bce_logits
from torch.nn.functional import mse_loss as mse
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.decomposition import TruncatedSVD
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, normalized_mutual_info_score, adjusted_rand_score, accuracy_score
from scipy.optimize import linear_sum_assignment as hungarian
from sklearn import metrics
from munkres import Munkres

# Compute device: this cell runs everything on the CPU (CUDA is not used here).
device = torch.device(type="cpu")


# Utility Classes
class AverageMeter(object):
    """Keep the most recent value and a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running mean, running total and count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

# Model Classes
class AutoEncoder(torch.nn.Module):
    """Masked-denoising autoencoder over gene-expression vectors.

    The encoder maps ``num_genes`` inputs to a ``hidden_size`` latent code.
    A linear head predicts, from the latent code, logits for which inputs were
    corrupted; a linear decoder reconstructs the clean input from the latent
    code concatenated with those mask logits.
    """

    def __init__(
        self,
        num_genes,
        hidden_size=128,
        dropout=0,
        masked_data_weight=.75,
        mask_loss_weight=0.7,
    ):
        super().__init__()
        self.num_genes = num_genes
        self.masked_data_weight = masked_data_weight
        self.mask_loss_weight = mask_loss_weight

        # Submodules are built in a fixed order (encoder, mask head, decoder)
        # so seeded weight initialisation and state_dict keys stay stable.
        width = 256
        layers = [nn.Dropout(p=dropout)]
        layers += [nn.Linear(num_genes, width), nn.LayerNorm(width), nn.Mish(inplace=True)]
        layers += [nn.Linear(width, hidden_size), nn.LayerNorm(hidden_size), nn.Mish(inplace=True)]
        layers.append(nn.Linear(hidden_size, hidden_size))
        self.encoder = nn.Sequential(*layers)

        self.mask_predictor = nn.Linear(hidden_size, num_genes)
        self.decoder = nn.Linear(in_features=hidden_size + num_genes, out_features=num_genes)

    def forward_mask(self, x):
        """Return ``(latent, mask_logits, reconstruction)`` for a batch ``x``."""
        latent = self.encoder(x)
        mask_logits = self.mask_predictor(latent)
        decoder_input = torch.cat([latent, mask_logits], dim=1)
        return latent, mask_logits, self.decoder(decoder_input)

    def loss_mask(self, x, y, mask):
        """Combined reconstruction + mask-prediction loss.

        Args:
            x: corrupted input batch.
            y: clean target batch (same shape as ``x``).
            mask: float tensor, 1.0 where an entry of ``x`` was corrupted.

        Returns:
            ``(latent, loss)``. Squared errors are weighted by
            ``masked_data_weight`` on corrupted entries and by
            ``1 - masked_data_weight`` elsewhere. The mean weighted error is
            scaled by ``1 - mask_loss_weight``, and the BCE-with-logits mask
            loss by ``mask_loss_weight``.
        """
        latent, mask_logits, reconstruction = self.forward_mask(x)
        alpha = self.masked_data_weight
        entry_weights = mask * alpha + (1 - mask) * (1 - alpha)
        squared_error = mse(reconstruction, y, reduction='none')
        recon_term = (1 - self.mask_loss_weight) * torch.mul(entry_weights, squared_error)
        mask_term = self.mask_loss_weight * bce_logits(mask_logits, mask, reduction="mean")
        return latent, recon_term.mean() + mask_term

    def feature(self, x):
        """Encode ``x`` to its latent representation."""
        return self.encoder(x)

# Data Processing Classes
# Keyword arguments for sklearn's TruncatedSVD (rank-128 decomposition).
default_svd_params = dict(
    n_components=128,
    random_state=42,
    n_oversamples=20,
    n_iter=7,
)

class IterativeSVDImputator(object):
    """Impute entries equal to ``missing_values`` by repeated truncated-SVD reconstruction.

    ``fit`` runs ``iters`` rounds. In round ``i`` a fresh ``TruncatedSVD`` is
    fitted on the current matrix, and only the missing entries are overwritten
    with its low-rank reconstruction. ``transform`` replays the fitted rounds
    on new data.
    """

    def __init__(self, svd_params=None, iters=2, missing_values=0.0):
        """
        Args:
            svd_params: keyword arguments for ``TruncatedSVD``; defaults to a
                copy of the module-level ``default_svd_params``.
            iters: number of impute-and-refit rounds.
            missing_values: value marking a missing entry (compared with ``==``).
        """
        self.missing_values = missing_values
        # Copy so instances never share (or accidentally mutate) one dict.
        self.svd_params = dict(default_svd_params if svd_params is None else svd_params)
        self.iters = iters
        self.svd_decomposers = [None] * self.iters

    def _run_rounds(self, X, refit):
        """Apply every round to a copy of ``X``; refit each SVD first when ``refit``."""
        # The mask is taken once from the input, so only originally-missing
        # entries are ever replaced.
        mask = X == self.missing_values
        transformed_X = X.copy()
        for i in range(self.iters):
            if refit:
                self.svd_decomposers[i] = TruncatedSVD(**self.svd_params)
                self.svd_decomposers[i].fit(transformed_X)
            decomposer = self.svd_decomposers[i]
            new_X = decomposer.inverse_transform(decomposer.transform(transformed_X))
            transformed_X[mask] = new_X[mask]
        return transformed_X

    def fit(self, X):
        """Fit one TruncatedSVD per round on ``X``; returns ``self``."""
        self._run_rounds(X, refit=True)
        return self

    def transform(self, X):
        """Return a copy of ``X`` with missing entries imputed.

        Raises:
            NotFittedError: ``fit`` has not been called yet. It subclasses
                AttributeError, as the previous failure mode did.
        """
        if any(decomposer is None for decomposer in self.svd_decomposers):
            from sklearn.exceptions import NotFittedError
            raise NotFittedError("IterativeSVDImputator must be fit before calling transform().")
        return self._run_rounds(X, refit=False)

class scRNADataset(Dataset):
    """Preprocessed single-cell expression matrix with integer labels.

    Reads ``X`` (count matrix) and ``Y`` (labels) from
    ``<config["paths"]["data"]>/<dataset_name>.h5``. The counts are then:
    - filtered and normalised, with up to 1000 highly variable genes kept;
    - reordered by ``adata.var["mean"]``;
    - passed through ``IterativeSVDImputator``, which re-estimates entries
      equal to 0.0.
    Items are ``(sample, label)`` pairs: a float32 row and an ``int``.

    Side effects on ``config``:
    - sets ``config["feat_dim"]`` to the number of kept genes;
    - overwrites ``config["n_classes"]`` with the number of distinct labels
      when they differ.
    """

    def __init__(self, config, dataset_name, mode='train'):
        self.config = config
        # Both modes currently yield identical pairs; kept separate so the
        # training path can diverge (e.g. augmentation) without touching test.
        if mode == 'train':
            self.iterator = self.prepare_training_pairs
        else:
            self.iterator = self.prepare_test_pairs
        self.paths = config["paths"]
        self.dataset_name = dataset_name
        self.data_path = os.path.join(self.paths["data"], dataset_name)
        self.data, self.labels = self._load_data()
        self.data_dim = self.data.shape[1]

    def __len__(self):
        return len(self.data)

    def prepare_training_pairs(self, idx):
        """Return ``(sample, label)`` for the training loader."""
        sample = self.data[idx]
        cluster = int(self.labels[idx])
        return sample, cluster

    def prepare_test_pairs(self, idx):
        """Return ``(sample, label)`` for the test loader."""
        sample = self.data[idx]
        cluster = int(self.labels[idx])
        return sample, cluster

    def __getitem__(self, index):
        return self.iterator(index)

    def _load_data(self):
        """Load the data, sync ``config`` with it and record value statistics."""
        data, labels = self.load_data(self.data_path)
        n_classes = len(set(labels.reshape(-1, ).tolist()))
        self.config["feat_dim"] = data.shape[1]
        # Capture the old value before overwriting it so the message is accurate
        # (previously it always printed "from N to N").
        previous_n_classes = self.config["n_classes"]
        if previous_n_classes != n_classes:
            self.config["n_classes"] = n_classes
            print(f"{50 * '>'} Number of classes changed "
                  f"from {previous_n_classes} to {n_classes} {50 * '<'}")
        self.data_max = np.max(np.abs(data))
        self.data_min = np.min(np.abs(data))
        return data, labels

    def load_data(self, path):
        """Read ``<path>.h5`` and return the preprocessed ``(X, Y)``."""
        # Context manager so the HDF5 handle is closed once the arrays are read.
        with h5py.File(f"{path}.h5", "r") as data_mat:
            X = np.array(data_mat['X'])
            Y = np.array(data_mat['Y'])

        # Labels that are not int64 (e.g. strings/bytes/floats) become 0..K-1.
        if Y.dtype != "int64":
            encoder_x = LabelEncoder()
            Y = encoder_x.fit_transform(Y)

        nb_genes = 1000
        # Round up to integers: normalize() asserts the input is raw counts.
        X = np.ceil(X).astype(int)
        print(X.shape, X.shape, f"keeping {nb_genes} genes")
        adata = sc.AnnData(X)

        adata = self.normalize(adata,
                               copy=True,
                               highly_genes=nb_genes,
                               size_factors=True,
                               normalize_input=True,
                               logtrans_input=True)
        # Order the kept genes by ascending adata.var["mean"].
        sorted_genes = adata.var_names[np.argsort(adata.var["mean"])]
        adata = adata[:, sorted_genes]
        X = adata.X.astype(np.float32)

        imputator = IterativeSVDImputator(iters=2)
        imputator.fit(X)
        X = imputator.transform(X)

        return X, Y

    def normalize(self, adata, copy=True, highly_genes=None, filter_min_counts=True,
                  size_factors=True, normalize_input=True, logtrans_input=True):
        """Count preprocessing: filter, normalise, log1p, HVG selection, scale.

        Args:
            adata: ``AnnData`` of raw integer counts, or a path for ``sc.read``.
            copy: work on a copy when ``adata`` is an ``AnnData``.
            highly_genes: if not None, keep this many highly variable genes.
            filter_min_counts: drop genes and cells with zero total counts.
            size_factors: total-count normalise and store
                ``obs['size_factors']`` (per-cell count / median count).
            normalize_input: scale each gene with ``sc.pp.scale``.
            logtrans_input: apply ``log1p``.

        Returns:
            The processed ``AnnData``; the data as it was before normalisation
            is kept in ``.raw``.

        Raises:
            NotImplementedError: ``adata`` is neither ``AnnData`` nor ``str``.
            AssertionError: a matrix with fewer than 50M entries is not integer counts.
        """
        if isinstance(adata, sc.AnnData):
            if copy:
                adata = adata.copy()
        elif isinstance(adata, str):
            adata = sc.read(adata)
        else:
            raise NotImplementedError
        norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.'
        # NOTE(review): checks 'n_count' (singular); possibly meant 'n_counts' — confirm.
        assert 'n_count' not in adata.obs, norm_error
        # Integrality is only verified on matrices small enough to check cheaply.
        if adata.X.size < 50e6:
            if sp.sparse.issparse(adata.X):
                assert (adata.X.astype(int) != adata.X).nnz == 0, norm_error
            else:
                assert np.all(adata.X.astype(int) == adata.X), norm_error

        if filter_min_counts:
            # filter_cells(min_counts=...) also records obs['n_counts'].
            sc.pp.filter_genes(adata, min_counts=1)
            sc.pp.filter_cells(adata, min_counts=1)
        if size_factors or normalize_input or logtrans_input:
            adata.raw = adata.copy()
        else:
            adata.raw = adata
        if size_factors:
            if 'n_counts' not in adata.obs:
                # Filtering was skipped, so compute per-cell totals from the raw counts.
                adata.obs['n_counts'] = np.asarray(adata.X.sum(axis=1)).ravel()
            sc.pp.normalize_total(adata)
            adata.obs['size_factors'] = adata.obs.n_counts / \
                np.median(adata.obs.n_counts)
        else:
            adata.obs['size_factors'] = 1.0
        if logtrans_input:
            sc.pp.log1p(adata)
        if highly_genes is not None:
            sc.pp.highly_variable_genes(
                adata, min_mean=0.0125, max_mean=3, min_disp=0.5, n_top_genes=highly_genes, subset=True)
        if normalize_input:
            sc.pp.scale(adata)
        return adata

class Loader(object):
    """Wrap a dataset in a shuffled train ``DataLoader`` and an ordered test one.

    The test loader uses a 5x larger batch, does not shuffle and never drops
    the last partial batch. ``data_max``/``data_min`` are copied from the
    training dataset.
    """

    def __init__(self, config, dataset_name, drop_last=True, kwargs=None):
        """
        Args:
            config: run configuration; ``config["batch_size"]`` is required.
            dataset_name: dataset file stem, forwarded to the dataset class.
            drop_last: drop the final partial batch of the training loader.
            kwargs: extra keyword arguments for both ``DataLoader``s.
        """
        # None instead of a shared mutable ``{}`` default.
        kwargs = {} if kwargs is None else dict(kwargs)
        batch_size = config["batch_size"]
        self.config = config
        train_dataset, test_dataset = self.get_dataset(dataset_name)
        self.data_max = train_dataset.data_max
        self.data_min = train_dataset.data_min

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last, **kwargs)
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size*5, shuffle=False, drop_last=False, **kwargs)

    def get_dataset(self, dataset_name):
        """Build ``(train_dataset, test_dataset)``; unknown names use ``scRNADataset``."""
        loader_map = {'default_loader': scRNADataset}
        dataset = loader_map.get(dataset_name, loader_map['default_loader'])
        train_dataset = dataset(self.config, dataset_name=dataset_name, mode='train')
        test_dataset = dataset(self.config, dataset_name=dataset_name, mode='test')
        return train_dataset, test_dataset

# Evaluation Functions
def evaluate(label, pred):
    """Score a clustering ``pred`` against ground-truth ``label``.

    Returns:
        ``(nmi, ari, acc)``: normalized mutual information, adjusted Rand
        index, and accuracy of the cluster ids after ``get_y_preds`` relabels
        them with their matched true labels.
    """
    n_groups = max(len(set(label)), len(set(pred)))
    matched = get_y_preds(label, pred, n_groups)
    return (
        normalized_mutual_info_score(label, pred),
        adjusted_rand_score(label, pred),
        accuracy_score(label, matched),
    )

def calculate_cost_matrix(C, n_clusters):
    """Turn a confusion matrix into a Munkres cost matrix.

    ``cost[j, i] = C[:, j].sum() - C[i, j]``: the samples of column ``j`` that
    are *not* in row ``i``. Only the leading ``n_clusters`` rows and columns
    are used, but each column total covers every row of ``C``.
    """
    cost = np.zeros((n_clusters, n_clusters))
    for col in range(n_clusters):
        column = C[:, col]
        column_total = np.sum(column)
        for row in range(n_clusters):
            cost[col, row] = column_total - column[row]
    return cost

def get_cluster_labels_from_indices(indices):
    """Return a float array whose i-th entry is the second element of ``indices[i]``.

    ``indices`` is a list of ``(row, column)`` assignment pairs, such as the
    result of ``Munkres().compute``.
    """
    cluster_labels = np.zeros(len(indices))
    for position, pair in enumerate(indices):
        cluster_labels[position] = pair[1]
    return cluster_labels

def get_y_preds(y_true, cluster_assignments, n_clusters):
    """Relabel each cluster id with the true label it best matches (Hungarian).

    The contingency table is built over the sorted union of true labels and
    cluster ids. The one-to-one cluster->label assignment that maximises the
    number of matched samples is then chosen. Indexing through the actual
    label values, rather than assuming both are ``0..k-1``, keeps negative ids
    (e.g. DBSCAN noise ``-1``) and non-contiguous ids aligned with the correct
    rows. Previously such inputs shifted every prediction by one.

    Args:
        y_true: ground-truth labels, 1-D array-like.
        cluster_assignments: predicted cluster ids, same length as ``y_true``.
        n_clusters: kept for backward compatibility. The table size is now
            derived from the labels themselves, so this value is unused.

    Returns:
        ``np.ndarray`` with one entry per sample: the label matched to that
        sample's cluster. Clusters that cannot receive a distinct true label
        get a value absent from ``y_true``, so their samples count as errors.
    """
    y_true = np.asarray(y_true).ravel()
    assignments = np.asarray(cluster_assignments).ravel()
    labels = np.union1d(y_true, assignments)  # sorted unique values of both
    true_idx = np.searchsorted(labels, y_true)
    pred_idx = np.searchsorted(labels, assignments)
    # contingency[p, t] = number of samples in cluster p whose true label is t.
    contingency = np.zeros((labels.size, labels.size), dtype=np.int64)
    np.add.at(contingency, (pred_idx, true_idx), 1)
    # Maximising matched samples == minimising their negation.
    rows, cols = hungarian(-contingency)
    cluster_to_label = np.empty(labels.size, dtype=np.int64)
    cluster_to_label[rows] = cols
    return labels[cluster_to_label[pred_idx]]

# Helper Functions
def apply_noise(X, p=(0.2, 0.4)):
    """Corrupt ``X`` by swapping entries with the same column of a shuffled copy.

    A single random row permutation ``perm`` is drawn. Each entry ``X[i, j]``
    is independently replaced by ``X[perm[i], j]`` with probability ``p``
    (``p`` broadcast against ``X.shape``).

    Args:
        X: 2-D tensor, rows = samples.
        p: swap probability — a scalar, or a per-column sequence broadcastable
            against ``X.shape``. The default is an immutable tuple rather than
            a shared list.

    Returns:
        ``(corrupted_X, mask)``. ``mask`` is 1.0 where the value actually
        changed; a swapped-in value equal to the original counts as 0.
    """
    # as_tensor mirrors torch.tensor's dtype inference. The ones tensor is
    # created directly on X's device instead of on the CPU and then copied.
    probs = torch.as_tensor(p, device=X.device) * torch.ones(X.shape, device=X.device)
    should_swap = torch.bernoulli(probs)
    # randperm stays on the CPU generator to keep the original random stream.
    corrupted_X = torch.where(should_swap == 1, X[torch.randperm(X.shape[0])], X)
    masked = (corrupted_X != X).float()
    return corrupted_X, masked

def make_dir(directory_path, new_folder_name):
    """Ensure ``directory_path/new_folder_name`` exists (creating parents) and return it.

    ``exist_ok=True`` replaces the separate exists()/makedirs() check, which
    could race with another process creating the same folder in between.

    Raises:
        FileExistsError: the path exists but is not a directory.
    """
    path = os.path.join(directory_path, new_folder_name)
    os.makedirs(path, exist_ok=True)
    return path

def inference(net, data_loader_test):
    """Embed every batch of ``data_loader_test`` with ``net.feature``.

    Puts ``net`` in eval mode (and leaves it there). Runs without gradients,
    moving inputs to the module-level ``device``.

    Returns:
        ``(features, labels)`` as NumPy arrays stacked in loader order; two
        empty arrays if the loader yields no batches.
    """
    net.eval()
    feature_batches = []
    label_batches = []
    with torch.no_grad():  # no graph is built, so .detach() is unnecessary
        for x, y in data_loader_test:
            feature_batches.append(net.feature(x.to(device)).cpu().numpy())
            label_batches.append(y.numpy())
    if not feature_batches:
        return np.array([]), np.array([])
    # One concatenate per output instead of per-row list extension + re-stacking.
    return np.concatenate(feature_batches), np.concatenate(label_batches)

def res_search_fixed_clus(adata, fixed_clus_count, increment=0.02):
    """Pick a Leiden resolution whose cluster count is closest to ``fixed_clus_count``.

    Resolutions ``np.arange(0.01, 2.5, increment)`` are tried from largest to
    smallest. The search stops at the first exact match; otherwise the
    resolution with the smallest count difference is returned. On ties, the
    larger resolution (tried first) wins.

    Note: ``sc.tl.leiden`` runs in place, so ``adata.obs['leiden']`` is left
    holding the labels of the last resolution tried. Presumably ``adata``
    needs a neighbour graph already (the caller runs ``sc.pp.neighbors``).
    """
    # arange is strictly increasing, so reversing it sorts it descending.
    resolutions = np.arange(0.01, 2.5, increment)[::-1]
    gaps = []
    for res in resolutions:
        sc.tl.leiden(adata, random_state=0, resolution=res)
        n_found = adata.obs['leiden'].nunique(dropna=False)
        gaps.append(abs(n_found - fixed_clus_count))
        if n_found == fixed_clus_count:
            break
    return resolutions[np.argmin(gaps)]

# Main Training Function
def train(args):
    """Train the masked autoencoder on ``args["dataset"]`` and evaluate it once.

    Each batch is corrupted with ``apply_noise`` (swap probability 0.2 per
    gene), and ``AutoEncoder.loss_mask`` is minimised with Adam.

    At the evaluation epoch (``args.get("eval_epoch", 80)``, clamped to the
    last epoch so short runs still evaluate), the test embedding is clustered:
    - below 10,000 cells: K-Means;
    - otherwise: Leiden, at the resolution whose cluster count is closest to
      ``args["n_classes"]``.
    NMI, ARI, accuracy and silhouette are then recorded.

    Outputs:
    - ``embedding_<epoch>.npy`` and ``types_<epoch>.txt`` (true vs predicted
      labels) in ``args["save_path"]``;
    - model and optimizer state in ``model_checkpoint.pth`` in the working
      directory.

    Returns:
        list of result dicts with keys nmi, ari, acc, sil, dataset, epoch.
    """
    data_load = Loader(args, dataset_name=args["dataset"], drop_last=True)
    data_loader = data_load.train_loader
    data_loader_test = data_load.test_loader
    # Loader records the real post-preprocessing gene count in args["feat_dim"];
    # use it so the model width matches the data ("data_dim" is only a fallback).
    x_shape = args.get("feat_dim", args["data_dim"])

    results = []
    init_lr = args["learning_rate"]
    max_epochs = args["epochs"]
    save_path = args["save_path"]
    # Per-gene probability that apply_noise swaps an entry.
    mask_probas = [0.2] * x_shape
    # Evaluation epoch; clamped so runs with <= 80 epochs still evaluate once.
    eval_epoch = min(args.get("eval_epoch", 80), max_epochs - 1)

    model = AutoEncoder(
        num_genes=x_shape,
        hidden_size=128,
        masked_data_weight=0.75,
        mask_loss_weight=0.7,
    ).to(device)

    model_checkpoint = 'model_checkpoint.pth'
    optimizer = torch.optim.Adam(model.parameters(), lr=init_lr)

    for epoch in range(max_epochs):
        model.train()
        meter = AverageMeter()
        for x, _ in data_loader:
            x = x.to(device)
            x_corrupted, mask = apply_noise(x, mask_probas)
            optimizer.zero_grad()
            _, loss_ae = model.loss_mask(x_corrupted, x, mask)
            loss_ae.backward()
            optimizer.step()
            meter.update(loss_ae.detach().cpu().numpy())

        if epoch != eval_epoch:
            continue

        latent, true_label = inference(model, data_loader_test)
        if latent.shape[0] < 10000:
            clustering_model = KMeans(n_clusters=args["n_classes"])
            clustering_model.fit(latent)
            pred_label = clustering_model.labels_
        else:
            # Large data: Leiden on a kNN graph, resolution tuned so the
            # cluster count is as close as possible to n_classes.
            adata = sc.AnnData(latent)
            sc.pp.neighbors(adata, n_neighbors=10, use_rep="X")
            reso = res_search_fixed_clus(adata, args["n_classes"])
            sc.tl.leiden(adata, resolution=reso)
            pred_label = np.array([int(label) for label in adata.obs['leiden']])

        nmi, ari, acc = evaluate(true_label, pred_label)
        ss = silhouette_score(latent, pred_label)

        results.append({
            "nmi": nmi,
            "ari": ari,
            "acc": acc,
            "sil": ss,
            "dataset": args["dataset"],
            "epoch": epoch,
        })

        print("\tEvalute: [nmi: %f] [ari: %f] [acc: %f]" % (nmi, ari, acc))

        np.save(os.path.join(save_path, f"embedding_{epoch}.npy"), latent)
        pd.DataFrame({"True": true_label, "Pred": pred_label}).to_csv(
            os.path.join(save_path, f"types_{epoch}.txt"))

    torch.save({"optimizer": optimizer.state_dict(), "model": model.state_dict()}, model_checkpoint)
    return results

# Main Execution
if __name__ == "__main__":
    # Reproducibility: seed Python, torch (CPU and every CUDA device) and NumPy.
    seed = 42
    for seed_fn in (random.seed, torch.manual_seed,
                    torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)

    # Run configuration. "n_classes" is only a starting value: the dataset
    # loader replaces it with the number of distinct labels in the file.
    # NOTE: "latent_dim" is not read by train() (it uses hidden_size=128).
    args = {
        "num_workers": 4,
        "paths": {
            "data": ".",  # directory holding <dataset>.h5, e.g. ./Pollen.h5
            "results": "./results",
        },
        "batch_size": 256,
        "data_dim": 1000,
        "n_classes": 7,
        "epochs": 100,
        "dataset": "Pollen",  # file name without the .h5 extension
        "learning_rate": 1e-3,
        "latent_dim": 32,
        "save_path": "./results/Pollen",  # per-dataset output directory
    }

    for out_dir in (args["save_path"], args["paths"]["results"]):
        os.makedirs(out_dir, exist_ok=True)

    print("Starting training with configuration:")
    print(args)

    results = train(args)

    # Persist the per-evaluation metrics returned by train().
    results_df = pd.DataFrame(results)
    results_df.to_csv(os.path.join(args["paths"]["results"], "final_results.csv"))
    print("Training completed and results saved.")

Starting training with configuration:
{'num_workers': 4, 'paths': {'data': '.', 'results': './results'}, 'batch_size': 256, 'data_dim': 1000, 'n_classes': 7, 'epochs': 100, 'dataset': 'Pollen', 'learning_rate': 0.001, 'latent_dim': 32, 'save_path': './results/Pollen'}
(301, 21721) (301, 21721) keeping 1000 genes
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Number of classes changed from 11 to 11 <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
(301, 21721) (301, 21721) keeping 1000 genes
	Evalute: [nmi: 0.920688] [ari: 0.933688] [acc: 0.913621]
Training completed and results saved.


In [None]:
import os
import random
import h5py
import scanpy as sc
import scipy as sp
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn.functional import binary_cross_entropy_with_logits as bce_logits
from torch.nn.functional import mse_loss as mse
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.decomposition import TruncatedSVD
from sklearn.cluster import KMeans, DBSCAN, SpectralClustering
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import silhouette_score, normalized_mutual_info_score, adjusted_rand_score, accuracy_score
from scipy.optimize import linear_sum_assignment as hungarian
from sklearn import metrics
from munkres import Munkres
import matplotlib.pyplot as plt

# Prefer the GPU whenever CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Utility Classes and Functions
class AverageMeter(object):
    """Keep the most recent value and a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running mean, running total and count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def evaluate(label, pred):
    """Score a clustering ``pred`` against ground-truth ``label``.

    Returns:
        ``(nmi, ari, acc)``: normalized mutual information, adjusted Rand
        index, and accuracy of the cluster ids after ``get_y_preds`` relabels
        them with their matched true labels.
    """
    n_groups = max(len(set(label)), len(set(pred)))
    matched = get_y_preds(label, pred, n_groups)
    return (
        normalized_mutual_info_score(label, pred),
        adjusted_rand_score(label, pred),
        accuracy_score(label, matched),
    )

def get_y_preds(y_true, cluster_assignments, n_clusters):
    """Relabel each cluster id with the true label it best matches (Hungarian).

    The contingency table is built over the sorted union of true labels and
    cluster ids. The one-to-one cluster->label assignment that maximises the
    number of matched samples is then chosen. Indexing through the actual
    label values, rather than assuming both are ``0..k-1``, keeps negative ids
    (e.g. DBSCAN noise ``-1``) and non-contiguous ids aligned with the correct
    rows. Previously such inputs shifted every prediction by one.

    Args:
        y_true: ground-truth labels, 1-D array-like.
        cluster_assignments: predicted cluster ids, same length as ``y_true``.
        n_clusters: kept for backward compatibility. The table size is now
            derived from the labels themselves, so this value is unused.

    Returns:
        ``np.ndarray`` with one entry per sample: the label matched to that
        sample's cluster. Clusters that cannot receive a distinct true label
        get a value absent from ``y_true``, so their samples count as errors.
    """
    y_true = np.asarray(y_true).ravel()
    assignments = np.asarray(cluster_assignments).ravel()
    labels = np.union1d(y_true, assignments)  # sorted unique values of both
    true_idx = np.searchsorted(labels, y_true)
    pred_idx = np.searchsorted(labels, assignments)
    # contingency[p, t] = number of samples in cluster p whose true label is t.
    contingency = np.zeros((labels.size, labels.size), dtype=np.int64)
    np.add.at(contingency, (pred_idx, true_idx), 1)
    # Maximising matched samples == minimising their negation.
    rows, cols = hungarian(-contingency)
    cluster_to_label = np.empty(labels.size, dtype=np.int64)
    cluster_to_label[rows] = cols
    return labels[cluster_to_label[pred_idx]]

def calculate_cost_matrix(C, n_clusters):
    """Turn a confusion matrix into a Munkres cost matrix.

    ``cost[j, i] = C[:, j].sum() - C[i, j]``: the samples of column ``j`` that
    are *not* in row ``i``. Only the leading ``n_clusters`` rows and columns
    are used, but each column total covers every row of ``C``.
    """
    cost = np.zeros((n_clusters, n_clusters))
    for col in range(n_clusters):
        column = C[:, col]
        column_total = np.sum(column)
        for row in range(n_clusters):
            cost[col, row] = column_total - column[row]
    return cost

def get_cluster_labels_from_indices(indices):
    """Return a float array whose i-th entry is the second element of ``indices[i]``.

    ``indices`` is a list of ``(row, column)`` assignment pairs, such as the
    result of ``Munkres().compute``.
    """
    cluster_labels = np.zeros(len(indices))
    for position, pair in enumerate(indices):
        cluster_labels[position] = pair[1]
    return cluster_labels

def _estimate_dbscan_eps(latent, quantile=0.95):
    """Return the ``quantile``-th smallest nearest-neighbour distance in ``latent``."""
    # n_neighbors=2 because each point's first neighbour is itself (distance 0).
    distances, _ = NearestNeighbors(n_neighbors=2).fit(latent).kneighbors(latent)
    nn_distances = np.sort(distances[:, 1])
    return nn_distances[int(quantile * len(nn_distances))]


def cluster_with_methods(latent, true_label, n_classes):
    """Cluster ``latent`` with several algorithms and score each against ``true_label``.

    Methods: K-Means, DBSCAN (eps from the 95th-percentile nearest-neighbour
    distance), a Gaussian mixture, and spectral clustering on a
    nearest-neighbour affinity.

    Returns:
        dict ``method -> {'nmi', 'ari', 'acc', 'labels'}``. The DBSCAN entry
        also has ``'fallback_to_kmeans'``: True when DBSCAN found at most one
        label and the K-Means labels were reported in its place.
    """
    results = {}

    def record(name, labels, **extra):
        # Score one labelling and store it under ``name``.
        nmi, ari, acc = evaluate(true_label, labels)
        results[name] = {'nmi': nmi, 'ari': ari, 'acc': acc, 'labels': labels, **extra}

    # 1: K-Means
    pred_kmeans = KMeans(n_clusters=n_classes, random_state=42).fit_predict(latent)
    record('KMeans', pred_kmeans)

    # 2: DBSCAN
    pred_dbscan = DBSCAN(eps=_estimate_dbscan_eps(latent), min_samples=5).fit_predict(latent)
    fell_back = len(set(pred_dbscan)) <= 1
    if fell_back:
        # Degenerate DBSCAN result: reuse the K-Means labels instead of
        # refitting the same seeded model, and flag the substitution.
        pred_dbscan = pred_kmeans
    record('DBSCAN', pred_dbscan, fallback_to_kmeans=fell_back)

    # 3: Gaussian Mixture Model
    gmm = GaussianMixture(n_components=n_classes, random_state=42)
    record('GMM', gmm.fit_predict(latent))

    # 4: Spectral Clustering
    spectral = SpectralClustering(n_clusters=n_classes, random_state=42, affinity='nearest_neighbors')
    record('Spectral', spectral.fit_predict(latent))

    return results

# Model and Data Classes
class AutoEncoder(torch.nn.Module):
    """Masked-denoising autoencoder over gene-expression vectors.

    The encoder maps ``num_genes`` inputs to a ``hidden_size`` latent code.
    A linear head predicts, from the latent code, logits for which inputs were
    corrupted; a linear decoder reconstructs the clean input from the latent
    code concatenated with those mask logits.
    """

    def __init__(
        self,
        num_genes,
        hidden_size=128,
        dropout=0,
        masked_data_weight=.75,
        mask_loss_weight=0.7,
    ):
        super().__init__()
        self.num_genes = num_genes
        self.masked_data_weight = masked_data_weight
        self.mask_loss_weight = mask_loss_weight

        # Submodules are built in a fixed order (encoder, mask head, decoder)
        # so seeded weight initialisation and state_dict keys stay stable.
        width = 256
        layers = [nn.Dropout(p=dropout)]
        layers += [nn.Linear(num_genes, width), nn.LayerNorm(width), nn.Mish(inplace=True)]
        layers += [nn.Linear(width, hidden_size), nn.LayerNorm(hidden_size), nn.Mish(inplace=True)]
        layers.append(nn.Linear(hidden_size, hidden_size))
        self.encoder = nn.Sequential(*layers)

        self.mask_predictor = nn.Linear(hidden_size, num_genes)
        self.decoder = nn.Linear(in_features=hidden_size + num_genes, out_features=num_genes)

    def forward_mask(self, x):
        """Return ``(latent, mask_logits, reconstruction)`` for a batch ``x``."""
        latent = self.encoder(x)
        mask_logits = self.mask_predictor(latent)
        decoder_input = torch.cat([latent, mask_logits], dim=1)
        return latent, mask_logits, self.decoder(decoder_input)

    def loss_mask(self, x, y, mask):
        """Combined reconstruction + mask-prediction loss.

        Args:
            x: corrupted input batch.
            y: clean target batch (same shape as ``x``).
            mask: float tensor, 1.0 where an entry of ``x`` was corrupted.

        Returns:
            ``(latent, loss)``. Squared errors are weighted by
            ``masked_data_weight`` on corrupted entries and by
            ``1 - masked_data_weight`` elsewhere. The mean weighted error is
            scaled by ``1 - mask_loss_weight``, and the BCE-with-logits mask
            loss by ``mask_loss_weight``.
        """
        latent, mask_logits, reconstruction = self.forward_mask(x)
        alpha = self.masked_data_weight
        entry_weights = mask * alpha + (1 - mask) * (1 - alpha)
        squared_error = mse(reconstruction, y, reduction='none')
        recon_term = (1 - self.mask_loss_weight) * torch.mul(entry_weights, squared_error)
        mask_term = self.mask_loss_weight * bce_logits(mask_logits, mask, reduction="mean")
        return latent, recon_term.mean() + mask_term

    def feature(self, x):
        """Encode ``x`` to its latent representation."""
        return self.encoder(x)

# Data Processing Classes
# Keyword arguments for sklearn's TruncatedSVD (rank-128 decomposition).
default_svd_params = dict(
    n_components=128,
    random_state=42,
    n_oversamples=20,
    n_iter=7,
)

class scRNADataset(Dataset):
    """Preprocessed single-cell expression matrix with integer labels.

    Reads ``X`` (count matrix) and ``Y`` (labels) from
    ``<config["paths"]["data"]>/<dataset_name>.h5``. The counts are then:
    - filtered and normalised, with up to 1000 highly variable genes kept;
    - reordered by ``adata.var["mean"]``;
    - passed through ``IterativeSVDImputator``, which re-estimates entries
      equal to 0.0.
    Items are ``(sample, label)`` pairs: a float32 row and an ``int``.

    Side effects on ``config``:
    - sets ``config["feat_dim"]`` to the number of kept genes;
    - overwrites ``config["n_classes"]`` with the number of distinct labels
      when they differ.
    """

    def __init__(self, config, dataset_name, mode='train'):
        self.config = config
        # Both modes currently yield identical pairs; kept separate so the
        # training path can diverge (e.g. augmentation) without touching test.
        if mode == 'train':
            self.iterator = self.prepare_training_pairs
        else:
            self.iterator = self.prepare_test_pairs
        self.paths = config["paths"]
        self.dataset_name = dataset_name
        self.data_path = os.path.join(self.paths["data"], dataset_name)
        self.data, self.labels = self._load_data()
        self.data_dim = self.data.shape[1]

    def __len__(self):
        return len(self.data)

    def prepare_training_pairs(self, idx):
        """Return ``(sample, label)`` for the training loader."""
        sample = self.data[idx]
        cluster = int(self.labels[idx])
        return sample, cluster

    def prepare_test_pairs(self, idx):
        """Return ``(sample, label)`` for the test loader."""
        sample = self.data[idx]
        cluster = int(self.labels[idx])
        return sample, cluster

    def __getitem__(self, index):
        return self.iterator(index)

    def _load_data(self):
        """Load the data, sync ``config`` with it and record value statistics."""
        data, labels = self.load_data(self.data_path)
        n_classes = len(set(labels.reshape(-1, ).tolist()))
        self.config["feat_dim"] = data.shape[1]
        # Capture the old value before overwriting it so the message is accurate
        # (previously it always printed "from N to N").
        previous_n_classes = self.config["n_classes"]
        if previous_n_classes != n_classes:
            self.config["n_classes"] = n_classes
            print(f"{50 * '>'} Number of classes changed "
                  f"from {previous_n_classes} to {n_classes} {50 * '<'}")
        self.data_max = np.max(np.abs(data))
        self.data_min = np.min(np.abs(data))
        return data, labels

    def load_data(self, path):
        """Read ``<path>.h5`` and return the preprocessed ``(X, Y)``."""
        # Context manager so the HDF5 handle is closed once the arrays are read.
        with h5py.File(f"{path}.h5", "r") as data_mat:
            X = np.array(data_mat['X'])
            Y = np.array(data_mat['Y'])

        # Labels that are not int64 (e.g. strings/bytes/floats) become 0..K-1.
        if Y.dtype != "int64":
            encoder_x = LabelEncoder()
            Y = encoder_x.fit_transform(Y)

        nb_genes = 1000
        # Round up to integers: normalize() asserts the input is raw counts.
        X = np.ceil(X).astype(int)
        print(X.shape, X.shape, f"keeping {nb_genes} genes")
        adata = sc.AnnData(X)

        adata = self.normalize(adata,
                               copy=True,
                               highly_genes=nb_genes,
                               size_factors=True,
                               normalize_input=True,
                               logtrans_input=True)
        # Order the kept genes by ascending adata.var["mean"].
        sorted_genes = adata.var_names[np.argsort(adata.var["mean"])]
        adata = adata[:, sorted_genes]
        X = adata.X.astype(np.float32)

        imputator = IterativeSVDImputator(iters=2)
        imputator.fit(X)
        X = imputator.transform(X)

        return X, Y

    def normalize(self, adata, copy=True, highly_genes=None, filter_min_counts=True,
                  size_factors=True, normalize_input=True, logtrans_input=True):
        """Count preprocessing: filter, normalise, log1p, HVG selection, scale.

        Args:
            adata: ``AnnData`` of raw integer counts, or a path for ``sc.read``.
            copy: work on a copy when ``adata`` is an ``AnnData``.
            highly_genes: if not None, keep this many highly variable genes.
            filter_min_counts: drop genes and cells with zero total counts.
            size_factors: total-count normalise and store
                ``obs['size_factors']`` (per-cell count / median count).
            normalize_input: scale each gene with ``sc.pp.scale``.
            logtrans_input: apply ``log1p``.

        Returns:
            The processed ``AnnData``; the data as it was before normalisation
            is kept in ``.raw``.

        Raises:
            NotImplementedError: ``adata`` is neither ``AnnData`` nor ``str``.
            AssertionError: a matrix with fewer than 50M entries is not integer counts.
        """
        if isinstance(adata, sc.AnnData):
            if copy:
                adata = adata.copy()
        elif isinstance(adata, str):
            adata = sc.read(adata)
        else:
            raise NotImplementedError
        norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.'
        # NOTE(review): checks 'n_count' (singular); possibly meant 'n_counts' — confirm.
        assert 'n_count' not in adata.obs, norm_error
        # Integrality is only verified on matrices small enough to check cheaply.
        if adata.X.size < 50e6:
            if sp.sparse.issparse(adata.X):
                assert (adata.X.astype(int) != adata.X).nnz == 0, norm_error
            else:
                assert np.all(adata.X.astype(int) == adata.X), norm_error

        if filter_min_counts:
            # filter_cells(min_counts=...) also records obs['n_counts'].
            sc.pp.filter_genes(adata, min_counts=1)
            sc.pp.filter_cells(adata, min_counts=1)
        if size_factors or normalize_input or logtrans_input:
            adata.raw = adata.copy()
        else:
            adata.raw = adata
        if size_factors:
            if 'n_counts' not in adata.obs:
                # Filtering was skipped, so compute per-cell totals from the raw counts.
                adata.obs['n_counts'] = np.asarray(adata.X.sum(axis=1)).ravel()
            sc.pp.normalize_total(adata)
            adata.obs['size_factors'] = adata.obs.n_counts / \
                np.median(adata.obs.n_counts)
        else:
            adata.obs['size_factors'] = 1.0
        if logtrans_input:
            sc.pp.log1p(adata)
        if highly_genes is not None:
            sc.pp.highly_variable_genes(
                adata, min_mean=0.0125, max_mean=3, min_disp=0.5, n_top_genes=highly_genes, subset=True)
        if normalize_input:
            sc.pp.scale(adata)
        return adata

class Loader(object):
    """Wrap a dataset in a shuffled train ``DataLoader`` and an ordered test one.

    The test loader uses a 5x larger batch, does not shuffle and never drops
    the last partial batch. ``data_max``/``data_min`` are copied from the
    training dataset.
    """

    def __init__(self, config, dataset_name, drop_last=True, kwargs=None):
        """
        Args:
            config: run configuration; ``config["batch_size"]`` is required.
            dataset_name: dataset file stem, forwarded to the dataset class.
            drop_last: drop the final partial batch of the training loader.
            kwargs: extra keyword arguments for both ``DataLoader``s.
        """
        # None instead of a shared mutable ``{}`` default.
        kwargs = {} if kwargs is None else dict(kwargs)
        batch_size = config["batch_size"]
        self.config = config
        train_dataset, test_dataset = self.get_dataset(dataset_name)
        self.data_max = train_dataset.data_max
        self.data_min = train_dataset.data_min

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last, **kwargs)
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size*5, shuffle=False, drop_last=False, **kwargs)

    def get_dataset(self, dataset_name):
        """Build ``(train_dataset, test_dataset)``; unknown names use ``scRNADataset``."""
        loader_map = {'default_loader': scRNADataset}
        dataset = loader_map.get(dataset_name, loader_map['default_loader'])
        train_dataset = dataset(self.config, dataset_name=dataset_name, mode='train')
        test_dataset = dataset(self.config, dataset_name=dataset_name, mode='test')
        return train_dataset, test_dataset

# train 
def train(args):
    """Train the masked autoencoder and compare several clusterers on its embedding.

    Each batch is corrupted with ``apply_noise`` (swap probability 0.6 per
    gene), and ``AutoEncoder.loss_mask`` is minimised with Adam. The mean loss
    is printed every 10th epoch and at the last epoch.

    At epoch 80 and at the last epoch, the test embedding is clustered with
    ``cluster_with_methods``, and each method's NMI/ARI/accuracy is appended
    to the returned list.

    Outputs in ``args["save_path"]``:
    - a scatter figure;
    - ``embedding_<epoch>.npy``;
    - ``true_labels_<epoch>.csv`` (a "True" column plus one predicted-label
      column per method);
    - ``final_results.csv``.
    The model and optimizer state go to ``model_checkpoint.pth`` in the
    working directory.

    Returns:
        list of dicts with keys method, nmi, ari, acc, dataset, epoch.
    """
    data_load = Loader(args, dataset_name=args["dataset"], drop_last=True)
    data_loader = data_load.train_loader
    data_loader_test = data_load.test_loader
    # Loader records the real post-preprocessing gene count in args["feat_dim"];
    # use it so the model width matches the data ("data_dim" is only a fallback).
    x_shape = args.get("feat_dim", args["data_dim"])

    results = []
    init_lr = args["learning_rate"]
    max_epochs = args["epochs"]
    save_path = args["save_path"]
    # Per-gene probability that apply_noise swaps an entry.
    mask_probas = [0.6] * x_shape

    model = AutoEncoder(
        num_genes=x_shape,
        hidden_size=128,
        masked_data_weight=0.75,
        mask_loss_weight=0.7
    ).to(device)

    model_checkpoint = 'model_checkpoint.pth'
    optimizer = torch.optim.Adam(model.parameters(), lr=init_lr)

    for epoch in range(max_epochs):
        model.train()
        meter = AverageMeter()
        for x, _ in data_loader:
            x = x.to(device)
            x_corrupted, mask = apply_noise(x, mask_probas)
            optimizer.zero_grad()
            _, loss_ae = model.loss_mask(x_corrupted, x, mask)
            loss_ae.backward()
            optimizer.step()
            meter.update(loss_ae.detach().cpu().numpy())

        is_last_epoch = epoch == max_epochs - 1
        if epoch % 10 == 0 or is_last_epoch:
            print(f"Epoch {epoch}/{max_epochs}, Loss: {meter.avg:.4f}")

        if epoch != 80 and not is_last_epoch:
            continue

        latent, true_label = inference(model, data_loader_test)
        clustering_results = cluster_with_methods(latent, true_label, args["n_classes"])

        print(f"\nEvaluation Results at Epoch {epoch}:")
        print("="*60)
        print(f"{'Method':<15} | {'NMI':<8} | {'ARI':<8} | {'Accuracy':<8}")
        print("-"*60)

        # Named `scores` (not `metrics`) so the sklearn `metrics` import is not shadowed.
        for method, scores in clustering_results.items():
            results.append({
                "method": method,
                "nmi": scores['nmi'],
                "ari": scores['ari'],
                "acc": scores['acc'],
                "dataset": args["dataset"],
                "epoch": epoch,
            })
            print(f"{method:<15} | {scores['nmi']:.4f} | {scores['ari']:.4f} | {scores['acc']:.4f}")

        # Visualization (optional). NOTE: the axes are the first two raw latent
        # dimensions, not a 2-D projection of the embedding.
        n_cols = 3
        n_rows = max(2, -(-len(clustering_results) // n_cols))  # grows past 6 methods
        plt.figure(figsize=(15, 10))
        for i, (method, scores) in enumerate(clustering_results.items(), 1):
            plt.subplot(n_rows, n_cols, i)
            plt.scatter(latent[:, 0], latent[:, 1], c=scores['labels'], cmap='tab20', s=5)
            plt.title(f"{method}\nNMI: {scores['nmi']:.2f}, ARI: {scores['ari']:.2f}")
            plt.colorbar()
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, f"clustering_results_epoch_{epoch}.png"))
        plt.close()

        np.save(os.path.join(save_path, f"embedding_{epoch}.npy"), latent)
        # Keep the "True" column and add the predicted labels of every method,
        # which were previously never written out.
        label_table = {"True": true_label}
        label_table.update({method: scores['labels'] for method, scores in clustering_results.items()})
        pd.DataFrame(label_table).to_csv(os.path.join(save_path, f"true_labels_{epoch}.csv"))

    torch.save({"optimizer": optimizer.state_dict(), "model": model.state_dict()}, model_checkpoint)

    # Save final results
    results_df = pd.DataFrame(results)
    results_df.to_csv(os.path.join(save_path, "final_results.csv"), index=False)

    return results

# 
if __name__ == "__main__":
    # Seed Python, torch and NumPy; CUDA RNGs only when a GPU is present.
    seed = 42
    seed_fns = [random.seed, torch.manual_seed, np.random.seed]
    if torch.cuda.is_available():
        seed_fns.append(torch.cuda.manual_seed_all)
    for seed_fn in seed_fns:
        seed_fn(seed)

    # Run configuration. "n_classes" is only a starting value: the dataset
    # loader replaces it with the number of distinct labels in the file.
    args = {
        "num_workers": 4,
        "paths": {
            "data": ".",
            "results": "./results",
        },
        "batch_size": 256,
        "data_dim": 1000,
        "n_classes": 5,
        "epochs": 100,
        "dataset": "Pollen",
        "learning_rate": 1e-3,
        "latent_dim": 32,
        "save_path": "./results/Pollen",
    }

    for out_dir in (args["save_path"], args["paths"]["results"]):
        os.makedirs(out_dir, exist_ok=True)

    print("Starting training with configuration:")
    print(args)

    results = train(args)

    table_header = f"{'Method':<15} | {'NMI':<8} | {'ARI':<8} | {'Accuracy':<8}"
    for line in ("\nFinal Results Summary:", "=" * 60, table_header, "-" * 60):
        print(line)

    # Re-read what train() persisted and report each method's best-NMI evaluation.
    final_results = pd.read_csv(os.path.join(args["save_path"], "final_results.csv"))
    for method in ['KMeans', 'DBSCAN', 'GMM', 'Spectral']:
        method_results = final_results.loc[final_results['method'] == method]
        best_idx = method_results['nmi'].idxmax()
        best = method_results.loc[best_idx]
        print(f"{method:<15} | {best['nmi']:.4f} | {best['ari']:.4f} | {best['acc']:.4f}")

Starting training with configuration:
{'num_workers': 4, 'paths': {'data': '.', 'results': './results'}, 'batch_size': 256, 'data_dim': 1000, 'n_classes': 5, 'epochs': 100, 'dataset': 'Pollen', 'learning_rate': 0.001, 'latent_dim': 32, 'save_path': './results/Pollen'}
(301, 21721) (301, 21721) keeping 1000 genes
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Number of classes changed from 11 to 11 <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
(301, 21721) (301, 21721) keeping 1000 genes
Epoch 0/100, Loss: 0.6513
Epoch 10/100, Loss: 0.5041
Epoch 20/100, Loss: 0.4340
Epoch 30/100, Loss: 0.4090
Epoch 40/100, Loss: 0.4101
Epoch 50/100, Loss: 0.4020
Epoch 60/100, Loss: 0.4102
Epoch 70/100, Loss: 0.3964
Epoch 80/100, Loss: 0.4008





Evaluation Results at Epoch 80:
Method          | NMI      | ARI      | Accuracy
------------------------------------------------------------
KMeans          | 0.9105 | 0.9092 | 0.8837
DBSCAN          | 0.7823 | 0.5198 | 0.0897
GMM             | 0.9105 | 0.9092 | 0.8837
Spectral        | 0.8255 | 0.6466 | 0.7475
Epoch 90/100, Loss: 0.3965
Epoch 99/100, Loss: 0.3909





Evaluation Results at Epoch 99:
Method          | NMI      | ARI      | Accuracy
------------------------------------------------------------
KMeans          | 0.9256 | 0.9301 | 0.9203
DBSCAN          | 0.7805 | 0.5147 | 0.0797
GMM             | 0.9256 | 0.9301 | 0.9203
Spectral        | 0.8209 | 0.6351 | 0.6811

Final Results Summary:
Method          | NMI      | ARI      | Accuracy
------------------------------------------------------------
KMeans          | 0.9256 | 0.9301 | 0.9203
DBSCAN          | 0.7823 | 0.5198 | 0.0897
GMM             | 0.9256 | 0.9301 | 0.9203
Spectral        | 0.8255 | 0.6466 | 0.7475


In [None]:
import os
import random
import h5py
import scanpy as sc
import scipy as sp
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn.functional import binary_cross_entropy_with_logits as bce_logits
from torch.nn.functional import mse_loss as mse
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.decomposition import TruncatedSVD
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score, normalized_mutual_info_score, adjusted_rand_score, accuracy_score
from scipy.optimize import linear_sum_assignment as hungarian
from sklearn import metrics
from munkres import Munkres
import matplotlib.pyplot as plt

# Torch device used for the model and every batch: the GPU when PyTorch can see one,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Utility Classes and Functions
class AverageMeter:
    """Keep the most recent value and a count-weighted running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, the running sum, the sample count and the mean."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as if it were observed ``n`` times and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def evaluate(label, pred):
    """Score a clustering against ground-truth labels.

    Accuracy is computed after relabelling ``pred`` with the best one-to-one
    cluster-to-label matching from :func:`get_y_preds`.

    Returns:
        tuple: ``(nmi, ari, acc)``.
    """
    n_groups = max(len(set(label)), len(set(pred)))
    matched = get_y_preds(label, pred, n_groups)
    return (
        normalized_mutual_info_score(label, pred),
        adjusted_rand_score(label, pred),
        accuracy_score(matched, label),
    )

def get_y_preds(y_true, cluster_assignments, n_clusters):
    """Relabel cluster ids with their best-matching true labels (Hungarian matching).

    Builds the confusion matrix of ``y_true`` vs ``cluster_assignments``, turns it
    into a cost matrix, solves the assignment with Munkres, and maps every
    cluster id through the resulting lookup.

    NOTE(review): the lookup is indexed by confusion-matrix position, so this
    assumes cluster ids and labels are contiguous integers starting at 0 (ids
    are shifted to start at 0 before lookup).
    """
    contingency = metrics.confusion_matrix(y_true, cluster_assignments, labels=None)
    assignment = Munkres().compute(calculate_cost_matrix(contingency, n_clusters))
    cluster_to_label = get_cluster_labels_from_indices(assignment)
    offset = np.min(cluster_assignments)
    if offset != 0:
        cluster_assignments = cluster_assignments - offset
    return cluster_to_label[cluster_assignments]

def calculate_cost_matrix(C, n_clusters):
    """Turn a confusion matrix into a Munkres cost matrix.

    Entry ``[j, i]`` is the number of samples in column ``j`` of ``C`` (summed
    over all rows) that would be mislabelled if column ``j`` were matched to
    row ``i``, i.e. ``sum(C[:, j]) - C[i, j]``. Only the leading
    ``n_clusters`` x ``n_clusters`` pairings are scored.

    Returns:
        float64 array of shape ``(n_clusters, n_clusters)``.
    """
    rows = []
    for j in range(n_clusters):
        column_total = np.sum(C[:, j])
        rows.append([column_total - C[i, j] for i in range(n_clusters)])
    return np.array(rows, dtype=float).reshape(n_clusters, n_clusters)

def get_cluster_labels_from_indices(indices):
    """Convert a Munkres assignment into a cluster-id -> label lookup array.

    Args:
        indices: sequence of ``(row, col)`` pairs as returned by
            ``Munkres().compute``. With the cost matrix from
            ``calculate_cost_matrix``, rows are predicted-cluster positions and
            columns are true-label positions of the confusion matrix.

    Returns:
        Integer array ``lookup`` with ``lookup[row] = col``. Integer dtype keeps
        the relabelled predictions integer-typed like the ground truth.
    """
    lookup = np.zeros(len(indices), dtype=int)
    # Place each pair by its own row index instead of relying on list order.
    for row, col in indices:
        lookup[row] = col
    return lookup

def cluster_with_methods(latent, true_label, n_classes):
    """Cluster ``latent`` into ``n_classes`` groups and score it against ``true_label``.

    Only agglomerative clustering (sklearn defaults apart from ``n_clusters``)
    is run.

    Returns:
        dict: ``{'Agglomerative': {'nmi', 'ari', 'acc', 'labels'}}`` where
        ``labels`` are the raw predicted cluster ids.
    """
    labels = AgglomerativeClustering(n_clusters=n_classes).fit_predict(latent)
    nmi, ari, acc = evaluate(true_label, labels)
    return {'Agglomerative': dict(nmi=nmi, ari=ari, acc=acc, labels=labels)}

class AttentionLayer(nn.Module):
    """Single-head scaled dot-product self-attention with square Q/K/V projections.

    Scores are ``Q @ K^T / sqrt(d)`` over the last two axes, softmax-normalised
    along the last axis, then applied to ``V``.

    NOTE(review): inside AutoEncoder this layer receives 2-D ``(batch, features)``
    tensors, so ``transpose(-2, -1)`` makes the score matrix batch x batch and
    each sample attends over the other samples in its batch. Embeddings
    therefore depend on batch composition. Confirm this is intended.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Attribute names form the state_dict keys; keep them stable.
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        scale = x.size(-1) ** 0.5
        scores = self.query(x) @ self.key(x).transpose(-2, -1) / scale
        return self.softmax(scores) @ self.value(x)

class DeepAttentionEncoder(nn.Module):
    """Post-norm attention block: residual self-attention then a residual GELU MLP.

    Shape-preserving: the output has the same last dimension (``input_dim``) as
    the input. ``hidden_dim`` is the width of the feed-forward layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Attribute names and order define the state_dict keys; keep them stable.
        self.attention1 = AttentionLayer(input_dim)
        self.norm1 = nn.LayerNorm(input_dim)
        self.ffn1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        attended = self.norm1(x + self.attention1(x))
        return self.norm2(attended + self.ffn1(attended))

class AutoEncoder(torch.nn.Module):
    """Masked-denoising autoencoder with attention blocks.

    The encoder maps ``num_genes`` inputs to a ``hidden_size`` latent. From the
    latent, ``mask_predictor`` emits per-gene mask logits. ``decoder``
    reconstructs the input from ``[latent, mask_logits]``.

    Args:
        num_genes: input/output feature count.
        hidden_size: latent width.
        dropout: dropout probability applied to the encoder input.
        masked_data_weight: per-entry reconstruction weight where ``mask == 1``
            (``1 - masked_data_weight`` where ``mask == 0``).
        mask_loss_weight: mix between the mask BCE term (this weight) and the
            reconstruction term (``1 - mask_loss_weight``).
    """

    def __init__(
        self,
        num_genes,
        hidden_size=128,
        dropout=0.6,
        masked_data_weight=.75,
        mask_loss_weight=0.7,
    ):
        super().__init__()
        self.num_genes = num_genes
        self.masked_data_weight = masked_data_weight
        self.mask_loss_weight = mask_loss_weight

        joint_size = hidden_size + num_genes

        # Module order inside each nn.Sequential fixes the state_dict keys
        # (encoder.0 ... encoder.9); do not reorder.
        self.encoder = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(num_genes, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            DeepAttentionEncoder(256, 512),
            nn.Linear(256, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            DeepAttentionEncoder(hidden_size, 2 * hidden_size),
            nn.Linear(hidden_size, hidden_size),
        )
        # Latent -> per-gene mask logits.
        self.mask_predictor = nn.Sequential(
            DeepAttentionEncoder(hidden_size, 4 * hidden_size),
            nn.Linear(hidden_size, num_genes),
        )
        # [latent, mask logits] -> reconstructed genes.
        self.decoder = nn.Sequential(
            DeepAttentionEncoder(joint_size, 2 * joint_size),
            nn.Linear(joint_size, num_genes),
        )

    def forward_mask(self, x):
        """Return ``(latent, mask_logits, reconstruction)`` for input ``x``."""
        latent = self.encoder(x)
        mask_logits = self.mask_predictor(latent)
        decoder_input = torch.cat((latent, mask_logits), dim=1)
        return latent, mask_logits, self.decoder(decoder_input)

    def loss_mask(self, x, y, mask):
        """Combined reconstruction + mask-prediction loss.

        Args:
            x: (possibly corrupted) model input.
            y: reconstruction target.
            mask: float mask with the same shape as ``y``. It is the BCE target
                for the mask logits and also selects the per-entry
                reconstruction weights (values presumably 0/1, as produced by
                ``apply_noise``).

        Returns:
            ``(latent, loss)`` with ``loss`` a scalar tensor.
        """
        latent, mask_logits, reconstruction = self.forward_mask(x)
        alpha = self.masked_data_weight
        entry_weights = mask * alpha + (1 - mask) * (1 - alpha)
        weighted_sq_err = entry_weights * mse(reconstruction, y, reduction='none')
        recon_term = ((1 - self.mask_loss_weight) * weighted_sq_err).mean()
        mask_term = self.mask_loss_weight * bce_logits(mask_logits, mask, reduction="mean")
        return latent, recon_term + mask_term

    def feature(self, x):
        """Encode ``x`` to its latent representation."""
        return self.encoder(x)

# Data Processing Classes
# Keyword arguments whose names match sklearn's TruncatedSVD (randomized solver).
# NOTE(review): not referenced in this cell; presumably consumed by the SVD/imputation code.
default_svd_params = dict(
    n_components=128,
    random_state=42,
    n_oversamples=20,
    n_iter=7,
)

class scRNADataset(Dataset):
    """Single-cell RNA-seq dataset read from ``<config["paths"]["data"]>/<dataset_name>.h5``.

    Items are ``(expression_row, int_label)`` pairs. ``mode='train'`` and any
    other mode currently yield identical pairs; the switch is kept so the two
    can diverge.

    Side effects on the (caller-shared) ``config`` dict: sets
    ``config["feat_dim"]`` to the number of kept features and overwrites
    ``config["n_classes"]`` with the number of distinct labels found.
    """

    def __init__(self, config, dataset_name, mode='train'):
        self.config = config
        if mode == 'train':
            self.iterator = self.prepare_training_pairs
        else:
            self.iterator = self.prepare_test_pairs
        self.paths = config["paths"]
        self.dataset_name = dataset_name
        self.data_path = os.path.join(self.paths["data"], dataset_name)
        self.data, self.labels = self._load_data()
        self.data_dim = self.data.shape[1]

    def __len__(self):
        return len(self.data)

    def prepare_training_pairs(self, idx):
        """Return ``(row, label)`` for training."""
        sample = self.data[idx]
        cluster = int(self.labels[idx])
        return sample, cluster

    def prepare_test_pairs(self, idx):
        """Return ``(row, label)`` for evaluation."""
        sample = self.data[idx]
        cluster = int(self.labels[idx])
        return sample, cluster

    def __getitem__(self, index):
        return self.iterator(index)

    def _load_data(self):
        """Load data, sync ``n_classes``/``feat_dim`` into config, record value range."""
        data, labels = self.load_data(self.data_path)
        n_classes = len(np.unique(labels))
        self.config["feat_dim"] = data.shape[1]
        # Capture the configured value before overwriting it so the message
        # reports the real change.
        previous_n_classes = self.config["n_classes"]
        if previous_n_classes != n_classes:
            self.config["n_classes"] = n_classes
            print(f"{50 * '>'} Number of classes changed "
                  f"from {previous_n_classes} to {n_classes} {50 * '<'}")
        self.data_max = np.max(np.abs(data))
        self.data_min = np.min(np.abs(data))
        return data, labels

    def load_data(self, path):
        """Read ``<path>.h5`` (datasets ``X`` and ``Y``) and preprocess it.

        Pipeline: round counts up to integers, run :meth:`normalize` (filtering,
        size-factor normalisation, log1p, 1000 highly variable genes, scaling),
        order genes by mean, then impute with ``IterativeSVDImputator``.

        Returns:
            ``(X, Y)``: float32 features and labels, aligned row-for-row.
        """
        # Close the HDF5 file as soon as the arrays are read.
        with h5py.File(f"{path}.h5", "r") as data_mat:
            X = np.array(data_mat['X'])
            Y = np.array(data_mat['Y'])

        if Y.dtype != "int64":
            Y = LabelEncoder().fit_transform(Y)

        nb_genes = 1000
        X = np.ceil(X).astype(int)
        count_X = X
        print(X.shape, count_X.shape, f"keeping {nb_genes} genes")
        adata = sc.AnnData(X)
        # Remember each cell's original row: normalize() may drop cells
        # (filter_cells), and the labels must be subset identically.
        adata.obs["orig_index"] = np.arange(adata.n_obs)

        adata = self.normalize(adata,
                               copy=True,
                               highly_genes=nb_genes,
                               size_factors=True,
                               normalize_input=True,
                               logtrans_input=True)
        Y = Y[adata.obs["orig_index"].to_numpy()]
        sorted_genes = adata.var_names[np.argsort(adata.var["mean"])]
        adata = adata[:, sorted_genes]
        X = adata.X.astype(np.float32)

        imputator = IterativeSVDImputator(iters=2)
        imputator.fit(X)
        X = imputator.transform(X)

        return X, Y

    def normalize(self, adata, copy=True, highly_genes=None, filter_min_counts=True,
                  size_factors=True, normalize_input=True, logtrans_input=True):
        """DCA-style preprocessing of an AnnData of raw integer counts.

        Args:
            adata: ``sc.AnnData`` (copied when ``copy``) or a path for ``sc.read``.
            highly_genes: if not None, keep this many highly variable genes.
            filter_min_counts: drop genes and cells with zero total counts.
            size_factors: total-count normalise and store ``obs['size_factors']``.
            normalize_input: z-score genes with ``sc.pp.scale``.
            logtrans_input: apply ``log1p``.

        Raises:
            NotImplementedError: ``adata`` is neither an AnnData nor a str.
            AssertionError: data (when < 50M entries) is not integer counts.
        """
        if isinstance(adata, sc.AnnData):
            if copy:
                adata = adata.copy()
        elif isinstance(adata, str):
            adata = sc.read(adata)
        else:
            raise NotImplementedError
        norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.'
        assert 'n_count' not in adata.obs, norm_error
        if adata.X.size < 50e6:
            if sp.sparse.issparse(adata.X):
                assert (adata.X.astype(int) != adata.X).nnz == 0, norm_error
            else:
                assert np.all(adata.X.astype(int) == adata.X), norm_error

        if filter_min_counts:
            sc.pp.filter_genes(adata, min_counts=1)
            sc.pp.filter_cells(adata, min_counts=1)
        if size_factors or normalize_input or logtrans_input:
            adata.raw = adata.copy()
        else:
            adata.raw = adata
        if size_factors:
            if 'n_counts' not in adata.obs:
                # filter_cells records n_counts; compute it when filtering was skipped.
                adata.obs['n_counts'] = np.asarray(adata.X.sum(axis=1)).ravel()
            sc.pp.normalize_total(adata)
            adata.obs['size_factors'] = adata.obs.n_counts / \
                np.median(adata.obs.n_counts)
        else:
            adata.obs['size_factors'] = 1.0
        if logtrans_input:
            sc.pp.log1p(adata)
        if highly_genes is not None:
            sc.pp.highly_variable_genes(
                adata, min_mean=0.0125, max_mean=3, min_disp=0.7, n_top_genes=highly_genes, subset=True)
        if normalize_input:
            sc.pp.scale(adata)
        return adata

class Loader(object):
    """Build the train/test DataLoaders for one dataset.

    Attributes:
        train_loader: shuffled, ``config["batch_size"]`` per batch, honours ``drop_last``.
        test_loader: unshuffled, ``5 * batch_size`` per batch, keeps every sample.
        data_max, data_min: copied from the training dataset.

    ``kwargs`` is forwarded to both ``DataLoader`` constructors.
    """

    def __init__(self, config, dataset_name, drop_last=True, kwargs=None):
        # Use None as the default so no dict is shared across calls.
        kwargs = {} if kwargs is None else kwargs
        batch_size = config["batch_size"]
        self.config = config
        train_dataset, test_dataset = self.get_dataset(dataset_name)
        self.data_max = train_dataset.data_max
        self.data_min = train_dataset.data_min

        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last, **kwargs)
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size * 5, shuffle=False, drop_last=False, **kwargs)

    def get_dataset(self, dataset_name):
        """Instantiate ``(train_dataset, test_dataset)``; unknown names use scRNADataset."""
        loader_map = {'default_loader': scRNADataset}
        dataset_cls = loader_map.get(dataset_name, loader_map['default_loader'])
        # NOTE(review): both datasets load and preprocess the same file
        # independently (the load message prints twice); sharing one load would
        # halve startup time.
        train_dataset = dataset_cls(self.config, dataset_name=dataset_name, mode='train')
        test_dataset = dataset_cls(self.config, dataset_name=dataset_name, mode='test')
        return train_dataset, test_dataset

def apply_noise(x, mask_probas):
    """Randomly zero entries of ``x``.

    Args:
        x: input tensor.
        mask_probas: keep-probabilities broadcastable to ``x`` (tensor, list or
            scalar). Each entry of ``x`` is kept with its probability.

    Returns:
        ``(x_corrupted, mask)``: ``mask`` holds 1.0 for kept entries and 0.0 for
        dropped ones, and ``x_corrupted = x * mask``.
    """
    # as_tensor avoids the copy-construct UserWarning that torch.tensor emits
    # when mask_probas is already a tensor (it is, in train()).
    keep_probs = torch.as_tensor(mask_probas, dtype=x.dtype, device=x.device)
    mask = torch.bernoulli(torch.ones_like(x) * keep_probs)
    x_corrupted = x * mask
    return x_corrupted, mask

def inference(model, data_loader):
    """Embed every batch of ``data_loader`` with ``model.feature``.

    Switches the model to eval mode and runs without gradients. Batches are
    cast to float and moved to the module-level ``device``.

    Returns:
        ``(latents, labels)`` as numpy arrays concatenated along axis 0, in
        loader order.
    """
    model.eval()
    latents, labels = [], []
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            z = model.feature(batch_x.float().to(device))
            latents.append(z.cpu().numpy())
            labels.append(batch_y.numpy())
    return np.concatenate(latents, axis=0), np.concatenate(labels, axis=0)

def _train_one_epoch(model, optimizer, data_loader, mask_probas):
    """Run one epoch of masked-reconstruction training and return the mean batch loss.

    Each batch is corrupted with ``apply_noise``. The model is trained to
    reconstruct the clean batch and to predict the mask. Returns 0 when the
    loader yields no batches.
    """
    model.train()
    meter = AverageMeter()
    for x, _ in data_loader:
        x = x.float().to(device)
        x_corrupted, mask = apply_noise(x, mask_probas)
        optimizer.zero_grad()
        _, loss_ae = model.loss_mask(x_corrupted, x, mask)
        loss_ae.backward()
        optimizer.step()
        # .item() gives a Python float rather than a 0-d numpy array.
        meter.update(loss_ae.item())
    return meter.avg


def _evaluate_epoch(model, data_loader_test, args, epoch):
    """Cluster the test-set embedding, print/plot/save it, and return the result rows."""
    latent, true_label = inference(model, data_loader_test)
    clustering_results = cluster_with_methods(latent, true_label, args["n_classes"])

    print(f"\nEvaluation Results at Epoch {epoch}:")
    print("=" * 60)
    print(f"{'Method':<15} | {'NMI':<8} | {'ARI':<8} | {'Accuracy':<8}")
    print("-" * 60)

    rows = []
    # `scores` rather than `metrics`, to avoid shadowing the sklearn module name.
    for method, scores in clustering_results.items():
        rows.append({
            "method": method,
            "nmi": scores['nmi'],
            "ari": scores['ari'],
            "acc": scores['acc'],
            "dataset": args["dataset"],
            "epoch": epoch,
        })
        print(f"{method:<15} | {scores['nmi']:.4f} | {scores['ari']:.4f} | {scores['acc']:.4f}")

    # Scatter of the first two raw latent dimensions (no projection), coloured by
    # the agglomerative labels; the title shows that method's own scores.
    agg = clustering_results['Agglomerative']
    plt.figure(figsize=(8, 6))
    plt.scatter(latent[:, 0], latent[:, 1], c=agg['labels'], cmap='tab20', s=5)
    plt.title(f"Agglomerative Clustering\nNMI: {agg['nmi']:.2f}, ARI: {agg['ari']:.2f}")
    plt.colorbar()
    plt.tight_layout()
    plt.savefig(os.path.join(args["save_path"], f"clustering_results_epoch_{epoch}.png"))
    plt.close()

    np.save(os.path.join(args["save_path"], f"embedding_{epoch}.npy"), latent)
    pd.DataFrame({"True": true_label}).to_csv(
        os.path.join(args["save_path"], f"true_labels_{epoch}.csv"))
    return rows


def train(args):
    """Train the masked autoencoder on ``args["dataset"]`` and evaluate clustering.

    Evaluation (clustering, plot, embedding and label dumps under
    ``args["save_path"]``) runs at epoch ``args.get("eval_epoch", 80)`` and at
    the last epoch. The final model/optimizer state goes to
    ``model_checkpoint.pth`` in the working directory.

    Returns:
        list of dicts, one per (evaluated epoch, clustering method), with keys
        ``method, nmi, ari, acc, dataset, epoch``. The same rows are written to
        ``<save_path>/final_results.csv``.
    """
    data_load = Loader(args, dataset_name=args["dataset"], drop_last=True)
    data_loader = data_load.train_loader
    data_loader_test = data_load.test_loader
    # The dataset records its real feature count in args["feat_dim"]; prefer it so
    # the model input width always matches the data.
    x_shape = args.get("feat_dim", args["data_dim"])

    results = []
    init_lr = args["learning_rate"]
    max_epochs = args["epochs"]
    eval_epoch = args.get("eval_epoch", 80)
    mask_probas = torch.tensor([0.7] * x_shape, device=device)

    model = AutoEncoder(
        num_genes=x_shape,
        hidden_size=128,
        masked_data_weight=0.75,
        mask_loss_weight=0.7
    ).to(device)

    model_checkpoint = 'model_checkpoint.pth'
    optimizer = torch.optim.Adam(model.parameters(), lr=init_lr)

    for epoch in range(max_epochs):
        avg_loss = _train_one_epoch(model, optimizer, data_loader, mask_probas)

        if epoch % 10 == 0 or epoch == max_epochs - 1:
            print(f"Epoch {epoch}/{max_epochs}, Loss: {avg_loss:.4f}")

        if epoch == eval_epoch or epoch == max_epochs - 1:
            results.extend(_evaluate_epoch(model, data_loader_test, args, epoch))

    torch.save({"optimizer": optimizer.state_dict(), "model": model.state_dict()}, model_checkpoint)

    results_df = pd.DataFrame(results)
    results_df.to_csv(os.path.join(args["save_path"], "final_results.csv"), index=False)

    return results

# Script entry point: seed RNGs, define the run configuration, train, then report
# the best evaluated epoch. Module-level names (args, results, final_results,
# best, ...) stay in the notebook namespace for later cells.
if __name__ == "__main__":
    # Set random seeds (Python, NumPy, torch CPU and, if present, all CUDA devices)
    seed = 42
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Configuration
    # NOTE(review): train() does not forward num_workers to Loader, and does not
    # read latent_dim (the AutoEncoder is built with hidden_size=128).
    # n_classes is overwritten by scRNADataset when the data holds a different
    # number of labels.
    args = {
        "num_workers": 10,
        "paths": {
            "data": ".",
            "results": "./results"
        },
        'batch_size': 256,
        "data_dim": 1000,
        'n_classes': 7,
        'epochs': 150,
        "dataset": "Pollen",
        "learning_rate": 0.0008,
        "latent_dim": 64,
        "save_path": "./results/Pollen"
    }

    os.makedirs(args["save_path"], exist_ok=True)
    os.makedirs(args["paths"]["results"], exist_ok=True)

    print("Starting training with configuration:")
    print(args)

    results = train(args)

    print("\nFinal Results Summary:")
    print("="*60)
    print(f"{'Method':<15} | {'NMI':<8} | {'ARI':<8} | {'Accuracy':<8}")
    print("-"*60)

    # NOTE(review): this picks the evaluated epoch with the highest NMI, i.e. it
    # selects the model using ground-truth labels, so the summary can be more
    # optimistic than the last-epoch result.
    final_results = pd.read_csv(os.path.join(args["save_path"], "final_results.csv"))
    method_results = final_results[final_results['method'] == 'Agglomerative']
    best_idx = method_results['nmi'].idxmax()
    best = method_results.loc[best_idx]
    print(f"{'Agglomerative':<15} | {best['nmi']:.4f} | {best['ari']:.4f} | {best['acc']:.4f}")

Starting training with configuration:
{'num_workers': 10, 'paths': {'data': '.', 'results': './results'}, 'batch_size': 256, 'data_dim': 1000, 'n_classes': 7, 'epochs': 150, 'dataset': 'Pollen', 'learning_rate': 0.0008, 'latent_dim': 64, 'save_path': './results/Pollen'}
(301, 21721) (301, 21721) keeping 1000 genes


  return fn(*args_all, **kw)


>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Number of classes changed from 11 to 11 <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
(301, 21721) (301, 21721) keeping 1000 genes


  return fn(*args_all, **kw)
  mask = torch.bernoulli(torch.ones_like(x, device=x.device) * torch.tensor(mask_probas, device=x.device))


Epoch 0/150, Loss: 0.7565
Epoch 10/150, Loss: 0.6171
Epoch 20/150, Loss: 0.5949
Epoch 30/150, Loss: 0.5776
Epoch 40/150, Loss: 0.5795
Epoch 50/150, Loss: 0.5682
Epoch 60/150, Loss: 0.5670
Epoch 70/150, Loss: 0.5602
Epoch 80/150, Loss: 0.5559

Evaluation Results at Epoch 80:
Method          | NMI      | ARI      | Accuracy
------------------------------------------------------------
Agglomerative   | 0.9351 | 0.9391 | 0.9269


  mask = torch.bernoulli(torch.ones_like(x, device=x.device) * torch.tensor(mask_probas, device=x.device))


Epoch 90/150, Loss: 0.5558
Epoch 100/150, Loss: 0.5425
Epoch 110/150, Loss: 0.5378
Epoch 120/150, Loss: 0.5322
Epoch 130/150, Loss: 0.5257
Epoch 140/150, Loss: 0.5184
Epoch 149/150, Loss: 0.5114

Evaluation Results at Epoch 149:
Method          | NMI      | ARI      | Accuracy
------------------------------------------------------------
Agglomerative   | 0.9011 | 0.8780 | 0.8904

Final Results Summary:
Method          | NMI      | ARI      | Accuracy
------------------------------------------------------------
Agglomerative   | 0.9351 | 0.9391 | 0.9269
