<a href="https://colab.research.google.com/github/jsvir/idc/blob/main/idc_evaluate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interpretable Deep Clustering (for Tabular Data)

You can use **any** data, including images, but for now the model only supports samples shaped like table rows: each sample must be a d-dimensional vector.
The main goal of our method is to discover cluster assignments for the dataset samples and to provide **local** (sample-level) and **global** (cluster-level) interpretations. The interpretations are the ids of the features that carry the most important information for clustering and are therefore unlikely to represent noise in the data.

## Model Description:

<img src="https://github.com/jsvir/idc/tree/main/img/img.png" width="500">

We train a Gating Neural Network together with an autoencoder under a reconstruction objective: the goal is to reconstruct the sample x from its gated version (x * z). Then we train the clustering head to discover the clustering of the samples. The last stage trains the auxiliary classifier, which learns the global gates matrix for cluster-level features (interpretations). In addition, we add sub-steps that serve as augmentations to the main stages: we add random binary noise to the input samples, we add noise to the latent embeddings (after the encoder), and we start by training the autoencoder without the gating network.
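
The gating operation and the two noise augmentations described above can be sketched in NumPy. This is a minimal illustration, not the repository's implementation: the helper names `random_binary_mask`, `gate_input` and `perturb_latent` are hypothetical, and the real model learns the gates z with a neural network rather than sampling them at random.

```python
import numpy as np


def random_binary_mask(shape, mask_percentage, rng):
    """Return a 0/1 mask in which roughly `mask_percentage` of the entries are 0."""
    return (rng.random(shape) >= mask_percentage).astype(np.float64)


def gate_input(x, z):
    """Element-wise gating x * z: features with z == 0 are hidden from the encoder."""
    return x * z


def perturb_latent(h, noise_std, rng):
    """Multiply latent embeddings by Gaussian noise centred at 1."""
    return h * rng.normal(loc=1.0, scale=noise_std, size=h.shape)


rng = np.random.default_rng(0)
x = rng.random((4, 784))                       # 4 samples, 784 features
mask = random_binary_mask(x.shape, 0.9, rng)   # mask_percentage = 0.9
x_masked = gate_input(x, mask)                 # input seen by the autoencoder
h = rng.random((4, 10))                        # stand-in for the encoder output
h_noisy = perturb_latent(h, 0.01, rng)         # latent_noise_std = 0.01
```

The autoencoder is then asked to reconstruct the original `x` from `x_masked`, which is what makes the reconstruction loss sensitive to which features are kept.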

Next, we go step by step through the MNIST example to show how training is done. If you find something unclear, please let us know.

### Step 0: config file definitions

| Key                                  | Required / Optional | Example Value                  | Description                                                                                                                                                                                                                                                 |
|--------------------------------------|---------------------|--------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| dataset                              | Required            | MNIST10K                       | *The dataset class name in dataset.py. The dataset class should implement **setup** function for preprocessing.*                                                                                                                                            |
| data_dir                             | Optional            | C:/data/fs/mnist               | *Dataset directory path. Depends if your **setup** function needs it*                                                                                                                                                                                       |
| scaler                               | Optional            | MinMaxScaler                   | *Depends if it is required by your dataset **setup** function*                                                                                                                                                                                              |
| batch_size                           | Required            | 100                            | *Training and evaluation batch size*                                                                                                                                                                                                                        |
| epochs                               | Required            | 100                            | *How many epochs to train the model (total epochs)*                                                                                                                                                                                                         |
| seeds                                | Required            | 1                              | *How many random intializations for model re-training*                                                                                                                                                                                                      |
| ae_non_gated_epochs                  | Required            | 10                             | *Number of epochs for autoencoder pre-training without gating network.*                                                                                                                                                                                     |
| ae_pretrain_epochs                   | Required            | 20                             | *Number of epochs for autoencoder pre-training with gating network.*                                                                                                                                                                                        |
| start_global_gates_training_on_epoch | Required            | 50                             | *After this number of epochs we start to train aux. classifier with global gates.*                                                                                                                                                                          |
| mask_percentage                      | Required            | 0.9                            | *The fraction of features, chosen at random, that will be masked by zero gates. Tune this parameter based on the convergence of the reconstruction loss: for better convergence try smaller values, for a sparser mask try larger ones.* |
| latent_noise_std                     | Required            | 0.01                           | *The std of the random normal noise (mean=1) that multiplies the latent embeddings (H) output by the encoder.* |
| gtcr_loss                            | Optional            | true                           | *Use it to encourage feature uniqueness at the sample level (the model will try to find a unique set of features for each sample).* |
| gtcr_projection_dim                  | Optional            | null                           | *For large number of features (>10K) it will apply a random projection to the smaller dimension which affects only the GTCR loss*                                                                                                                           |
| gtcr_eps                             | Optional            | 1                              | *Coding rate reduction precision parameter for the GTCR loss.* |
| eps                                  | Required            | 0.1                            | *The clustering head is trained with a coding-rate-reduction-based objective that uses this precision parameter. Note that this loss operates on the latent embeddings and helps to cluster them, while GTCR operates on the gates only and tries to separate them.* |
| use_gating                           | Required            | true                           | *If trained with Gating Network.*                                                                                                                                                                                                                           |
| gates_hidden_dim                     | Required            | 784                            | *The hodden layer dimension in Gating Network.*                                                                                                                                                                                                             |
| encdec                               | Required            | [512,512,2048,10,2048,512,512] | *Autoencoder architecture. Each value represents the hidden layer dimension*                                                                                                                                                                                |
| clustering_head                      | Required            | [10, 2048]                     | *Clustering head dimension. The input dimension and the hidden dimension.*                                                                                                                                                                                  |
| tau                                  | Required            | 100                            | *Temperature for Gumbel-Softmax. We used a fixed value, but you can also try changing it during training.* |
| aux_classifier                       | Required            | 2048                           | *The dimension of the hidden layer in the aux classifier.* |
| local_gates_lambda                   | Required            | 1                              | *The weight of the sparsity loss term in the total clustering loss computation.*                                                                                                                                                                            |
| global_gates_lambda                  | Required            | 1                              | *The weight of the sparsity loss term in the total aux classifier loss computation.*                                                                                                                                                                        |
| gtcr_lambda                          | Required            | 0.01                           | *The weight of the uniqueness loss term in the total clustering loss computation.*                                                                                                                                                                          |
| lr.pretrain                          | Required            | 1e-3                           | *The learning rate for the autoencoder and gating networks.*                                                                                                                                                                                                |
| lr.clustering                        | Required            | 1e-2                           | *The learning rate for the clustering head.*                                                                                                                                                                                                                |
| lr.aux_classifier                    | Required            | 1e-2                           | *The learning rate for the aux classifier and global gates matrix.*                                                                                                                                                                                         |
| sched.pretrain_min_lr                | Required            | 1e-6                           | *The min learning rate for the autoencoder and gating networks.*                                                                                                                                                                                            |
| sched.clustering_min_lr              | Required            | 1e-6                           | *The min learning rate for the clustering head.*                                                                                                                                                                                                            |
| save_seed_checkpoints                | Required            | false                          | *Change it to true if you would like to save the checkpoint.*                                                                                                                                                                                               |
| validate                             | Optional            | true                           | *If you have labeled data and wish to evaluate the method on it (like MNIST), use true. Otherwise use false.* |
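
To make the three epoch-related keys concrete, here is a small sketch of how they could partition a run into stages. It assumes that each key is an absolute epoch threshold and that the clustering head starts training once the gated pre-training ends. Both are assumptions made for illustration, and the function `training_stage` and its stage names are hypothetical, so check `train_evaluate.py` for the actual logic.

```python
def training_stage(epoch, ae_non_gated_epochs, ae_pretrain_epochs,
                   start_global_gates_training_on_epoch):
    """Map an epoch index to a (hypothetical) stage name."""
    if epoch < ae_non_gated_epochs:
        return "autoencoder only (no gates)"
    if epoch < ae_pretrain_epochs:
        return "autoencoder + gating network"
    if epoch < start_global_gates_training_on_epoch:
        return "clustering head"
    return "clustering head + aux classifier (global gates)"


# Thresholds taken from the "Example Value" column of the table above.
for epoch in (0, 15, 30, 60):
    print(epoch, training_stage(epoch, 10, 20, 50))
```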

Finally, there are some additional pytorch-lightning configs you should provide; they can keep the same values as below:

trainer:
  devices: 1
  accelerator: gpu
  max_epochs: *epochs
  deterministic: true
  logger: true
  log_every_n_steps: 10
  check_val_every_n_epoch: 10
  enable_checkpointing: false
  

We clone the repo and print the yaml config file we will use for MNIST.

In [1]:
!git clone https://github.com/jsvir/idc.git

Cloning into 'idc'...
remote: Enumerating objects: 66, done.
remote: Counting objects: 100% (66/66), done.
remote: Compressing objects: 100% (45/45), done.
remote: Total 66 (delta 30), reused 46 (delta 19), pack-reused 0
Receiving objects: 100% (66/66), 18.08 MiB | 12.82 MiB/s, done.
Resolving deltas: 100% (30/30), done.


In [11]:
# !rm -r idc

In [3]:
!cd idc && pip install -r requirements.txt

Collecting pytorch-lightning==2.0.0 (from -r requirements.txt (line 2))
  Downloading pytorch_lightning-2.0.0-py3-none-any.whl.metadata (24 kB)
Collecting scikit-learn==1.1.2 (from -r requirements.txt (line 3))
  Downloading scikit_learn-1.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting scipy==1.10.0-rc1 (from -r requirements.txt (line 4))
  Downloading scipy-1.10.0rc1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (58 kB)
Collecting omegaconf==2.2.3 (from -r requirements.txt (line 5))
  Downloading omegaconf-2.2.3-py3-none-any.whl.metadata (3.9 kB)
Collecting matplotlib==3.6.3 (from -r requirements.txt (line 6))
  Downloading matplotlib-3.6.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting matplotlib-inline==0.1.6 (from -r requirements.txt (line 7))
  Downloa

In [4]:
with open("idc/cfg/cfg_mnist.yaml") as f:
    for line in f.readlines():
        print(line.strip())

dataset: MNIST10K
data_dir: idc/data
scaler: MinMaxScaler
batch_size: 100
seeds: 1
epochs: &epochs 700

ae_non_gated_epochs: 10
ae_pretrain_epochs: 300
start_global_gates_training_on_epoch: 400

mask_percentage: 0.9
latent_noise_std: 0.01

trainer:
devices: 1
accelerator: gpu
max_epochs: *epochs
deterministic: true
logger: true
log_every_n_steps: 10
check_val_every_n_epoch: 10
enable_checkpointing: false
num_sanity_val_steps: 0


# GTCR loss
gtcr_loss: true
gtcr_projection_dim: null # for large number of features use it
gtcr_eps: 1


# Compression loss
eps: 0.1

# Gating Net
use_gating: true
gates_hidden_dim: 784

# EncoderDecoder
encdec:
- 512
- 512
- 2048
- &bn_layer 10
- 2048
- 512
- 512

clustering_head:
- *bn_layer
- 2048

tau: 100

aux_classifier:
- 2048

local_gates_lambda: 1
global_gates_lambda: 0.0001
gtcr_lambda: 0.01

lr:
pretrain: 1e-3
clustering: 1e-2
aux_classifier: 1e-2

sched:
pretrain_min_lr: 1e-6
clustering_min_lr: 1e-6



save_seed_checkpoints: false
validate: true
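
The `lr` and `sched` entries define a starting and a minimum learning rate, and the training log further below mentions cosine annealing. Assuming the scheduler follows the standard closed-form cosine annealing schedule (the form used by PyTorch's `CosineAnnealingLR`), the pretraining learning rate would decay as sketched here. The helper `cosine_lr` is hypothetical, and the 40000-step horizon is taken from the training log.

```python
import math


def cosine_lr(step, total_steps, lr_max, lr_min):
    """Closed-form cosine annealing from lr_max (step 0) to lr_min (last step)."""
    cos_term = math.cos(math.pi * step / total_steps)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + cos_term)


total_steps = 40000
for step in (0, 10000, 20000, 40000):
    lr = cosine_lr(step, total_steps, lr_max=1e-3, lr_min=1e-6)
    print(f"step {step:>5}: lr = {lr:.6f}")
```

The schedule starts at `lr.pretrain`, passes through the midpoint of the two values halfway through, and ends at `sched.pretrain_min_lr`.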


### Step 1: add your dataset class

Assuming your dataset is stored as NumPy files that are ready for training, this is the minimal code you need (*the X values should be z-score preprocessed; for some datasets, such as MNIST, MinMax(0,1) scaling can be applied instead*):

In [6]:
import sys
sys.path.append("idc")
from dataset import ClusteringDataset
from sklearn import preprocessing
from torchvision.datasets import MNIST


class MNIST10K(ClusteringDataset):
    def __init__(self, data, targets):
        super().__init__(data, targets)

    @classmethod
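    def setup(cls, cfg):
        # Instantiate the scaler named in the config, e.g. MinMaxScaler.
        scaler = getattr(preprocessing, cfg.scaler)()
        # Load MNIST once and flatten each 28x28 image into a 784-d vector.
        mnist = MNIST(cfg.data_dir, train=True, download=True)
        X = mnist.data.reshape(-1, 784).cpu().numpy()
        Y = mnist.targets.cpu().numpy()
        # Fit the scaler on the full training set, then keep the first 10K samples.
        X = scaler.fit_transform(X)
        X = X[:10000]
        Y = Y[:10000]
        return cls(X, Y)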

Now we manually copy it to the dataset.py file.
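
For your own NumPy arrays, the two preprocessing options mentioned above look like this. It is a minimal sketch in plain NumPy, equivalent in spirit to scikit-learn's `StandardScaler` and `MinMaxScaler`; the helper names `zscore` and `minmax` and the toy matrix are illustrative only.

```python
import numpy as np


def zscore(X, eps=1e-12):
    """Per-feature standardization: zero mean, unit variance."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    return (X - mean) / np.maximum(std, eps)  # eps guards constant features


def minmax(X, eps=1e-12):
    """Per-feature rescaling to the [0, 1] range."""
    lo = X.min(axis=0)
    span = X.max(axis=0) - lo
    return (X - lo) / np.maximum(span, eps)   # eps guards constant features


# Toy data: 3 samples, 3 features; the last feature is constant.
X = np.array([[1.0, 10.0, 5.0],
              [2.0, 20.0, 5.0],
              [3.0, 60.0, 5.0]])
Xz = zscore(X)
Xm = minmax(X)
```

Constant features map to zero under both helpers, which avoids a division by zero for columns that carry no information.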


### Step 1: run clustering training

In [None]:
import torch
from omegaconf import OmegaConf
import numpy as np
from pytorch_lightning import Trainer, seed_everything
import os
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from train_evaluate import BaseModule


cfg = OmegaConf.load("idc/cfg/cfg_mnist.yaml")
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if not cfg.validate:
    cfg.trainer.check_val_every_n_epoch = cfg.trainer.max_epochs + 1  # validation never runs
with open(f"results_example.txt", mode='a') as f:
    header = '\t'.join(['seed', 'acc', 'ari', 'nmi', 'local_gates', 'global_gates',
                        'topk_max_silhouette_score', 'topk_min_dbi_score'])
    f.write(f"{header}\n")


for seed in range(cfg.seeds):
    cfg.seed = seed
    seed_everything(seed)
    np.random.seed(seed)
    if not os.path.exists(cfg.dataset):
        os.makedirs(cfg.dataset)
    model = BaseModule(cfg)
    logger = TensorBoardLogger(cfg.dataset, name="example", log_graph=False)
    trainer = Trainer(**cfg.trainer, callbacks=[LearningRateMonitor(logging_interval='step')])
    trainer.logger = logger
    trainer.fit(model)
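    # Mean of the 10 highest silhouette scores recorded by the model.
    topk_max_silhouette_score = np.mean(sorted(model.max_silhouette_score, reverse=True)[:10])
    # NOTE: this also reads model.max_silhouette_score, so it is the mean of the
    # 10 lowest silhouette scores rather than a Davies-Bouldin value; point it at
    # the model's DBI history if one is available.
    topk_min_dbi_score = np.mean(sorted(model.max_silhouette_score)[:10])
    results_str = '\t'.join(
        [f'{seed}',
         f'{model.best_acc}',
         f'{model.best_ari}',
         f'{model.best_nmi}',
         f'{model.best_local_feats}',
         f'{model.best_global_feats}',
         f'{topk_max_silhouette_score}',
         f'{topk_min_dbi_score}',
         ])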
    with open(f"results_example.txt", mode='a') as f:
        f.write(f"{results_str}\n")
        f.flush()


INFO:lightning_fabric.utilities.seed:Global seed set to 0


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to idc/data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 15900186.11it/s]


Extracting idc/data/MNIST/raw/train-images-idx3-ubyte.gz to idc/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to idc/data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 478977.06it/s]


Extracting idc/data/MNIST/raw/train-labels-idx1-ubyte.gz to idc/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to idc/data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1242948.13it/s]


Extracting idc/data/MNIST/raw/t10k-images-idx3-ubyte.gz to idc/data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to idc/data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3071675.07it/s]


Extracting idc/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to idc/data/MNIST/raw

X.shape:  (10000, 784)
X.min=0.0, X.max=1.0
Y.shape:  (10000,)
0: 1001
1: 1127
2: 991
3: 1032
4: 980
5: 863
6: 1014
7: 1070
8: 944
9: 978
Y.min=0, Y.max=9
Dataset length: 10000


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                | Type                          | Params
----------------------------------------------------------------------
0 | gating_net          | GatingNet                     | 1.2 M 
1 | encdec              | EncoderDecoder                | 3.5 M 
2 | clustering_head     | Sequential                    | 47.1 K
3 | aux_classifier_head | Sequential                    | 1.6 M 
4 | mcrr                | MaximalCodingRateReduction    | 0     
5 | gtcr_loss           | TotalCodingRateWithProjection | 0     
-------------------

Cosine annealing LR scheduling is applied during 40000 steps


New best accuracy: 0.1127

New best accuracy: 0.7015

New best accuracy: 0.728

New best accuracy: 0.7341

New best accuracy: 0.7699

New best accuracy: 0.8033

New best accuracy: 0.8068

New best accuracy: 0.8115

New best accuracy: 0.8165

New best accuracy: 0.8214

New best accuracy: 0.8303

New best accuracy: 0.8311

New best accuracy: 0.8322

New best accuracy: 0.8329

New best accuracy: 0.8334

New best accuracy: 0.834

New best accuracy: 0.8351

### Step 2: check your results output file

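We measure the following metrics:

1. ACC (clustering accuracy)
2. ARI (adjusted Rand index)
3. NMI (normalized mutual information)
4. DB-Index (Davies-Bouldin index)
5. Silhouette Score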
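
Cluster ids are arbitrary, so clustering accuracy (ACC) is usually computed after finding the best one-to-one mapping between predicted clusters and ground-truth labels. The sketch below does this by brute force over permutations, which is only practical for a handful of clusters; for larger problems the Hungarian algorithm (`scipy.optimize.linear_sum_assignment`) is the common choice. `clustering_accuracy` is an illustrative helper, not the function used in the repository.

```python
from itertools import permutations


def clustering_accuracy(y_true, y_pred):
    """Best accuracy over all one-to-one mappings of cluster ids to labels."""
    labels = sorted(set(y_true) | set(y_pred))
    k = len(labels)
    index = {label: i for i, label in enumerate(labels)}
    # counts[p][t] = number of samples in predicted cluster p with true label t
    counts = [[0] * k for _ in range(k)]
    for t, p in zip(y_true, y_pred):
        counts[index[p]][index[t]] += 1
    # Try every assignment of clusters to labels and keep the best match count.
    best = max(
        sum(counts[p][perm[p]] for p in range(k))
        for perm in permutations(range(k))
    )
    return best / len(y_true)


y_true = [0, 0, 0, 1, 1, 2, 2, 2]
y_pred = [2, 2, 2, 0, 0, 1, 1, 0]   # same grouping, permuted ids, one mistake
print(clustering_accuracy(y_true, y_pred))  # 0.875
```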

In [10]:
with open("results_example.txt") as f:
  for line in f.readlines():
    print(line.strip())

seed	acc	ari	nmi	local_gates	global_gates	topk_max_silhouette_score	topk_min_dbi_score
0	0.8351	0.7393671948823729	0.7698711335813655	14.817920427322388	393.8912658691406	0.1359824538230896	0.1293429434299469
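
Since the results file is tab-separated with a header row, it can be parsed with the standard `csv` module. The sketch below parses the two lines printed above from an in-memory string; with the real file you would pass the open file handle to `csv.DictReader` instead.

```python
import csv
import io

# The two lines printed above, reproduced as an in-memory file.
raw = (
    "seed\tacc\tari\tnmi\tlocal_gates\tglobal_gates\t"
    "topk_max_silhouette_score\ttopk_min_dbi_score\n"
    "0\t0.8351\t0.7393671948823729\t0.7698711335813655\t"
    "14.817920427322388\t393.8912658691406\t"
    "0.1359824538230896\t0.1293429434299469\n"
)

rows = list(csv.DictReader(io.StringIO(raw), delimiter="\t"))
for row in rows:
    # Every column except the seed is a float.
    metrics = {k: float(v) for k, v in row.items() if k != "seed"}
    print(f"seed {row['seed']}: ACC={metrics['acc']:.4f}, "
          f"ARI={metrics['ari']:.4f}, NMI={metrics['nmi']:.4f}")
```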
