<a href="https://colab.research.google.com/github/jsvir/idc/blob/main/idc_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interpretable Deep Clustering (for Tabular Data)

You can use **any** data, including images, but for now the model only supports samples shaped like table rows: each sample must be a d-dimensional vector.
The main goal of our method is to discover cluster assignments for the dataset samples and to provide **local** (sample-level) and **global** (cluster-level) interpretations. An interpretation is the set of feature ids that carry the most important information for clustering and that are unlikely to represent noise in the data.

## Model Description:

<img src="https://github.com/jsvir/idc/tree/main/img/img.png" width="500">

We train a gating neural network together with an autoencoder under a reconstruction objective: the goal is to reconstruct the sample x from its gated version (x * z). We then train the clustering head to discover the clustering of the samples. The last stage trains the auxiliary classifier, which learns the global gates matrix for cluster-level features (interpretations). In addition, several training sub-steps act as augmentations to the main stages: we add random binary noise to the input samples, we add noise to the latent embeddings (after the encoder), and we start by training the autoencoder without the gating network.
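
The gating and noise steps described above can be sketched in a few lines. This is a minimal, dependency-free illustration and not the repository's implementation: the helper names (`gate_sample`, `random_binary_mask`, `noisy_latent`, `mse`) are hypothetical, the gate values are made up, and mean squared error is only assumed as the reconstruction loss.

```python
import random

def gate_sample(x, z):
    """Element-wise product x * z; a gate of 0 hides the feature."""
    return [xi * zi for xi, zi in zip(x, z)]

def random_binary_mask(num_features, mask_percentage, rng):
    """0/1 mask in which each feature is zeroed with probability `mask_percentage`."""
    return [0.0 if rng.random() < mask_percentage else 1.0 for _ in range(num_features)]

def noisy_latent(h, std, rng):
    """Multiply each latent value by normal noise with mean 1 and the given std."""
    return [hi * rng.gauss(1.0, std) for hi in h]

def mse(x, x_hat):
    """Mean squared error (assumed here as the reconstruction loss)."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)

rng = random.Random(0)
x = [0.5, 1.2, 0.0, 3.1, 0.7, 2.4]   # one d-dimensional sample (d = 6)
z = [1.0, 0.0, 0.0, 1.0, 0.0, 1.0]   # made-up local gates for this sample

gated = gate_sample(x, z)                                        # x * z
mask = random_binary_mask(len(x), mask_percentage=0.5, rng=rng)  # input noise
augmented = gate_sample(gated, mask)                             # gated and masked input
h = noisy_latent([0.3, -1.1, 0.8], std=0.01, rng=rng)            # stand-in for encoder output

# The reconstruction target is always the full sample x, not its gated version.
# Here `gated` stands in for a decoder output, only to show what the loss compares.
loss = mse(x, gated)
```

In the real model the encoder and decoder sit between `augmented` and the loss; the sketch only shows where the gates and the two kinds of noise enter.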

Next, we walk step by step through an example on the PBMC-2 dataset to show how the training is done. PBMC-2 is a binary-class subset of the original PBMC [1] dataset: we select the two categories with the most samples in the original set. In addition, we remove all all-zero columns from the data, which leaves 17,126 features × 20,742 samples. If you find something unclear, please let us know.



[1] Zheng, G. X., Terry, J. M., Belgrader, P., Ryvkin, P., Bent, Z. W., Wilson, R., Ziraldo, S. B., Wheeler, T. D., McDermott, G. P., Zhu, J., et al. Massively parallel digital transcriptional profiling of single cells. Nature communications, 8(1):14049, 2017.

### Step 0: config file definitions

| Key                                  | Required / Optional | Example Value                   | Description                                                                                                                                                                                                                                                 |
|--------------------------------------|---------------------|---------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| filepath_samples                              | Required            | C:/data/fs/pbmc/pbmc_x.npz      | *Filepath to the npz file with samples table. The table should be of shape [num of samples, num of features]*                                                                                                                                               |
| num_clusters                             | Required            | 2                               | *Expected number of clusters*                                                                                                                                                                                                                               |
| batch_size                           | Required            | 256                             | *Training and evaluation batch size*                                                                                                                                                                                                                        |
| epochs                               | Required            | 200                             | *How many epochs to train the model (total epochs)*                                                                                                                                                                                                         |
| seeds                                | Required            | 1                               | *How many random initializations to use for model re-training* |
| ae_non_gated_epochs                  | Required            | 50                              | *Number of epochs for autoencoder pre-training without gating network.*                                                                                                                                                                                     |
| ae_pretrain_epochs                   | Required            | 100                             | *Number of epochs for autoencoder pre-training with gating network.*                                                                                                                                                                                        |
| start_global_gates_training_on_epoch | Required            | 150                             | *After this number of epochs we start to train aux. classifier with global gates.*                                                                                                                                                                          |
| mask_percentage                      | Required            | 0.9                             | *The fraction of features, chosen at random, that is masked by zero gates. Tune this parameter based on reconstruction loss convergence: try smaller values for better convergence and larger values for a sparser mask.* |
| latent_noise_std                     | Required            | 0.01                            | *The std of the random normal noise with mean=1 that multiplies the latent embeddings (H) output by the encoder.* |
| gtcr_loss                            | Optional            | true                            | *Use it to encourage feature uniqueness at the sample level (the model will try to find a unique set of features for each sample).* |
| gtcr_projection_dim                  | Optional            | 1024                            | *For a large number of features (>10K), a random projection to this smaller dimension is applied; it affects only the GTCR loss.* |
| gtcr_eps                             | Optional            | 1                               | *Code Reduction Rate precision parameter* |
| eps                                  | Required            | 0.1                             | *The clustering head is trained with a code reduction rate-based objective that uses this precision parameter. Note that this loss operates on the latent embeddings and helps to cluster them, while GTCR operates on the gates only and tries to separate them.* |
| use_gating                           | Required            | true                            | *Whether to train with the gating network.* |
| gates_hidden_dim                     | Required            | 784                             | *The hidden layer dimension in the gating network.* |
| encdec                               | Required            | [512,512,2048,128,2048,512,512] | *Autoencoder architecture. Each value represents the hidden layer dimension*                                                                                                                                                                                |
| clustering_head                      | Required            | [128, 2048]                     | *Clustering head dimension. The input dimension and the hidden dimension.*                                                                                                                                                                                  |
| tau                                  | Required            | 100                             | *Temperature for Gumbel-Softmax. We used a fixed value, but you can also try changing it during training.* |
| aux_classifier                       | Required            | 2048                            | *The dimension of the hidden layer in the aux classifier* |
| local_gates_lambda                   | Required            | 1                               | *The weight of the sparsity loss term in the total clustering loss computation.*                                                                                                                                                                            |
| global_gates_lambda                  | Required            | 1                               | *The weight of the sparsity loss term in the total aux classifier loss computation.* |
| gtcr_lambda                          | Required            | 0.01                            | *The weight of the uniqueness loss term in the total clustering loss computation.* |
| lr.pretrain                          | Required            | 1e-3                            | *The learning rate for the autoencoder and gating networks.*                                                                                                                                                                                                |
| lr.clustering                        | Required            | 1e-3                            | *The learning rate for the clustering head.*                                                                                                                                                                                                                |
| lr.aux_classifier                    | Required            | 1e-1                            | *The learning rate for the aux classifier and global gates matrix.*                                                                                                                                                                                         |
| sched.pretrain_min_lr                | Required            | 1e-4                            | *The min learning rate for the autoencoder and gating networks.*                                                                                                                                                                                            |
| sched.clustering_min_lr              | Required            | 1e-4                            | *The min learning rate for the clustering head.*                                                                                                                                                                                                            |
| save_seed_checkpoints                | Required            | false                           | *Change it to true if you would like to save the checkpoint.*                                                                                                                                                                                               |
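
To build intuition for the `tau` key, the sketch below implements the standard Gumbel-Softmax relaxation in plain Python. It is an illustration of the general technique, not the code used in the repository; the function name `gumbel_softmax` and the example logits are assumptions made for this example.

```python
import math
import random

def gumbel_softmax(logits, tau, rng):
    """Draw one relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise by inverse transform sampling: g = -log(-log(u))
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep u strictly inside (0, 1)
        g = -math.log(-math.log(u))
        noisy.append((logit + g) / tau)
    peak = max(noisy)  # subtract the maximum for numerical stability
    exps = [math.exp(v - peak) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
logits = [2.0, 0.5]  # hypothetical scores for num_clusters = 2

soft = gumbel_softmax(logits, tau=100.0, rng=rng)  # high temperature: close to uniform
sharp = gumbel_softmax(logits, tau=0.1, rng=rng)   # low temperature: typically far more peaked
```

With a large temperature such as the example value 100, the sampled assignment stays close to uniform; lowering `tau` usually pushes samples toward one-hot vectors at the cost of noisier gradients.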

Finally, there are some additional pytorch-lightning configs you should provide; they can keep the same values as below:

trainer:
  devices: 1
  accelerator: gpu
  max_epochs: *epochs
  deterministic: true
  logger: true
  log_every_n_steps: 10
  check_val_every_n_epoch: 10
  enable_checkpointing: false
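
The epoch-related keys from the table (`ae_non_gated_epochs`, `ae_pretrain_epochs`, `start_global_gates_training_on_epoch`, `epochs`) together decide which stage is trained when. The sketch below shows one possible reading of them. It assumes that each key is an absolute epoch threshold counted from the start of training (not a duration) and that the clustering head starts right after gated pre-training; the notebook does not confirm this, so check `run.py` before relying on it. The helper names `active_stages` and `check_schedule` are hypothetical.

```python
def active_stages(epoch, cfg):
    """Return the stages assumed to be active at a 0-based epoch (see assumptions above)."""
    if epoch < cfg["ae_non_gated_epochs"]:
        return ["autoencoder (no gating)"]
    if epoch < cfg["ae_pretrain_epochs"]:
        return ["autoencoder + gating network"]
    stages = ["clustering head"]
    if epoch >= cfg["start_global_gates_training_on_epoch"]:
        stages.append("aux classifier + global gates")
    return stages

def check_schedule(cfg):
    """Raise ValueError if the thresholds are not in non-decreasing order."""
    keys = ["ae_non_gated_epochs", "ae_pretrain_epochs",
            "start_global_gates_training_on_epoch", "epochs"]
    values = [cfg[k] for k in keys]
    if values != sorted(values):
        raise ValueError(f"epoch thresholds out of order: {dict(zip(keys, values))}")

# Example values taken from the table above
cfg = {
    "ae_non_gated_epochs": 50,
    "ae_pretrain_epochs": 100,
    "start_global_gates_training_on_epoch": 150,
    "epochs": 200,
}
check_schedule(cfg)
for epoch in (0, 50, 100, 150):
    print(epoch, active_stages(epoch, cfg))
```

A check like this is cheap to run before a long training job, since a threshold that exceeds `epochs` would silently skip a stage under this reading.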
  

We clone the repo and print the yaml config file we will use for PBMC:

In [None]:
!git clone https://github.com/jsvir/idc.git && cd idc && pip install -r requirements.txt

Cloning into 'idc'...
remote: Enumerating objects: 15, done.
remote: Counting objects: 100% (15/15), done.
remote: Compressing objects: 100% (11/11), done.
remote: Total 15 (delta 2), reused 15 (delta 2), pack-reused 0
Receiving objects: 100% (15/15), 16.64 MiB | 12.59 MiB/s, done.
Resolving deltas: 100% (2/2), done.
Collecting torch==2.0.1 (from -r requirements.txt (line 1))
  Downloading torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl.metadata (24 kB)
Collecting pytorch-lightning==2.0.0 (from -r requirements.txt (line 2))
  Downloading pytorch_lightning-2.0.0-py3-none-any.whl.metadata (24 kB)
Collecting scikit-learn==1.1.2 (from -r requirements.txt (line 3))
  Downloading scikit_learn-1.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting scipy==1.9.3 (from -r requirements.txt (line 4))
  Downloading scipy-1.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (58 kB)

In [None]:
with open("cfg/cfg_run.yaml") as f:
    for line in f:
        # rstrip() drops the trailing newline but keeps the YAML indentation
        print(line.rstrip())

### Step 1: data preparation

The PBMC-2 dataset is provided in the data directory; unzip it:


In [None]:
!unzip data/pbmc_x.zip


### Step 2: run clustering training

We modify the function that plots the clustering to show the plot in the notebook:

In [None]:
import torch
import umap
import matplotlib.pyplot as plt

def plot_clustering(val_embs_list, cluster_mtx, current_epoch, silhouette, dbi):
    # Project the validation embeddings to 2D with UMAP (fixed seed for repeatable plots)
    reducer = umap.UMAP(n_neighbors=10, min_dist=0.1, n_components=2, random_state=0)
    embedding = reducer.fit_transform(torch.cat(val_embs_list, dim=0).cpu().numpy())
    plt.figure(figsize=(5, 3))
    # Color each point by its predicted cluster
    plt.scatter(embedding[:, 0], embedding[:, 1], c=cluster_mtx.numpy(), s=20, edgecolor='k')
    plt.title(f'Clustering (UMAP). Epoch: {current_epoch}. Silhouette: {silhouette}. DBI: {dbi}')
    plt.show()

Now we train the model. We first start TensorBoard to monitor the run:

In [None]:
!tensorboard --logdir logs/example --port 6006 &

In [None]:
from tensorboard import notebook
notebook.list() # View open TensorBoard instances
notebook.display(port=6006, height=1000)

In [None]:
import torch
from omegaconf import OmegaConf
import numpy as np
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from run import BaseModule


cfg = OmegaConf.load("cfg/cfg_run.yaml")

# Make the run reproducible: deterministic kernels and a fixed seed
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed = 777
cfg.seed = seed
seed_everything(seed)
np.random.seed(seed)

model = BaseModule(cfg)
# Replace the module's plotting function with the notebook version defined above
model.plot_clustering = plot_clustering

# Log to logs/example so that TensorBoard can pick up the run
logger = TensorBoardLogger("logs", name="example", log_graph=False)
trainer = Trainer(**cfg.trainer, callbacks=[LearningRateMonitor(logging_interval='step')])
trainer.logger = logger
trainer.fit(model)