# Concept Bottleneck Models: Dot Training Example

This very short notebook will showcase how to set up a Concept Embedding Model
(CEM) using our library and train it on the Dot dataset proposed in our CEM
NeurIPS 2022 paper.

Our example consists of four steps:
1. Loading the dataset of interest in a format that can be "digested" by our models.
2. Instantiating a CEM with the embedding size and encoder/decoder architectures we want to use.
3. Training the CEM on the Dot dataset.
4. Evaluating the CEM's task accuracy, concept AUC, and concept alignment score (CAS).
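For reference, here is what the first two evaluation metrics measure, written as a minimal NumPy sketch. It is not the library's own evaluation code, and the helper names (`task_accuracy`, `binary_auc`, `mean_concept_auc`) are hypothetical. Task accuracy is the fraction of arg-max predictions that match the label. Concept AUC is computed per concept with the rank-based (Mann-Whitney) formulation, with tied scores sharing average ranks. CAS is not covered here.

```python
import numpy as np

def task_accuracy(y_logits, y_true):
    """Fraction of samples whose arg-max class matches the label.

    y_logits: (N, n_tasks) array of class scores; y_true: (N,) integer labels.
    """
    y_pred = np.argmax(y_logits, axis=-1)
    return float(np.mean(y_pred == y_true))

def _average_ranks(scores):
    # 1-based ranks; tied scores share the average of their ranks
    order = np.argsort(scores, kind="mergesort")
    sorted_scores = scores[order]
    ranks = np.empty(len(scores), dtype=float)
    i = 0
    while i < len(scores):
        j = i
        while j + 1 < len(scores) and sorted_scores[j + 1] == sorted_scores[i]:
            j += 1
        ranks[order[i:j + 1]] = (i + j) / 2.0 + 1.0
        i = j + 1
    return ranks

def binary_auc(scores, labels):
    """ROC AUC via the Mann-Whitney U statistic (labels in {0, 1})."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels).astype(bool)
    n_pos, n_neg = labels.sum(), (~labels).sum()
    ranks = _average_ranks(scores)
    u = ranks[labels].sum() - n_pos * (n_pos + 1) / 2.0
    return float(u / (n_pos * n_neg))

def mean_concept_auc(c_probs, c_true):
    """Average of the per-concept AUCs over the k concept columns."""
    return float(np.mean([binary_auc(c_probs[:, i], c_true[:, i])
                          for i in range(c_true.shape[1])]))
```

For example, `binary_auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])` is 0.75, because three of the four (positive, negative) pairs are ranked correctly.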

In [1]:
%load_ext autoreload
%autoreload 2

## Step 1: Load Data

As a first step, we show how to generate a dataset from scratch that is
compatible with our training pipeline.

In practice, you can train any CEM (or CBM variant) using our library as long as
your dataset is structured such that:
1. It is contained within a PyTorch `DataLoader` object.
2. Every sample is a tuple with three elements: the sample $\mathbf{x} \in \mathbb{R}^n$, the task label $y \in \{0, \cdots, L - 1\}$, and a vector of $k$ binary concept annotations $\mathbf{c} \in \{0, 1\}^k$ (in that order).
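The following sketch checks these two requirements on a single `(x, y, c)` batch. The `check_batch` helper is hypothetical and is not part of the library. It converts each element with `np.asarray`, so it should accept NumPy arrays as well as CPU PyTorch tensors. It checks the shapes, that labels are whole numbers in $\{0, \cdots, L - 1\}$, and that concepts are binary.

```python
import numpy as np

def check_batch(batch, n_features, n_concepts, n_labels):
    """Raise ValueError unless `batch` is an (x, y, c) tuple in the expected format."""
    if len(batch) != 3:
        raise ValueError(f"expected (x, y, c), got {len(batch)} elements")
    x, y, c = (np.asarray(t) for t in batch)
    n = x.shape[0]
    if x.shape != (n, n_features):
        raise ValueError(f"x should have shape ({n}, {n_features}), got {x.shape}")
    if y.shape != (n,):
        raise ValueError(f"y should have shape ({n},), got {y.shape}")
    # Labels must be whole numbers in {0, ..., L - 1} (float or int dtype)
    if np.any(y != np.floor(y)) or np.any((y < 0) | (y >= n_labels)):
        raise ValueError(f"labels must be integers in [0, {n_labels - 1}]")
    if c.shape != (n, n_concepts):
        raise ValueError(f"c should have shape ({n}, {n_concepts}), got {c.shape}")
    if not np.all((c == 0) | (c == 1)):
        raise ValueError("concept annotations must be binary (0 or 1)")
```

After the DataLoaders are built, `check_batch(next(iter(train_dl)), n_features=4, n_concepts=2, n_labels=2)` should return without raising.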

Below, we show how we do this for the Dot dataset. For details on the actual
dataset, please refer to our paper.

In [2]:
import numpy as np
import torch
from pytorch_lightning import seed_everything

# We first create a simple helper function to sample random labeled instances
# from the Dot dataset:
def generate_dot_data(size):
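    # Dimensionality of each latent vector
    emb_size = 2
    # Generate the latent vectors: v1 and v3 are random Gaussian vectors,
    # while v2 and v4 are fixed reference directions
    v1 = np.random.randn(size, emb_size) * 2
    v2 = np.ones(emb_size)
    v3 = np.random.randn(size, emb_size) * 2
    v4 = -np.ones(emb_size)
    # Generate the sample (4 input features per sample)
    x = np.hstack([v1+v3, v1-v3])

    # Now the concept vector: one binary concept per dot product
    c = np.stack([
        np.dot(v1, v2).ravel() > 0,
        np.dot(v3, v4).ravel() > 0,
    ]).T
    # And finally the label: whether v1 and v3 have a positive dot product
    y = ((v1*v3).sum(axis=-1) > 0).astype(np.int64)

    # We NEED to put all of these into torch Tensors (THIS IS VERY IMPORTANT).
    # Labels stay integer (int64) because they are class indices.
    x = torch.tensor(x, dtype=torch.float32)
    c = torch.tensor(c, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.long)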
    return x, y, c
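Because $\mathbf{x} = [\mathbf{v}_1 + \mathbf{v}_3, \mathbf{v}_1 - \mathbf{v}_3]$, both latent vectors are linear functions of each sample: $\mathbf{v}_1 = (\mathbf{x}_{1:2} + \mathbf{x}_{3:4})/2$ and $\mathbf{v}_3 = (\mathbf{x}_{1:2} - \mathbf{x}_{3:4})/2$. The NumPy sketch below uses this to recompute the concepts and the label directly from $\mathbf{x}$, following the same rules as `generate_dot_data`. It does not need PyTorch, and `dot_concepts_and_label` is a hypothetical helper.

```python
import numpy as np

def dot_concepts_and_label(x):
    """Recompute Dot's concepts and label from inputs x of shape (N, 4)."""
    v1 = (x[:, :2] + x[:, 2:]) / 2.0  # since x[:, :2] = v1 + v3
    v3 = (x[:, :2] - x[:, 2:]) / 2.0  # since x[:, 2:] = v1 - v3
    c = np.stack([
        v1 @ np.ones(2) > 0,      # concept 1: v1 . (1, 1) > 0
        v3 @ (-np.ones(2)) > 0,   # concept 2: v3 . (-1, -1) > 0
    ], axis=1).astype(np.float32)
    y = ((v1 * v3).sum(axis=-1) > 0).astype(np.int64)  # label: v1 . v3 > 0
    return c, y

# Worked example with v1 = (1, 2) and v3 = (-3, 0.5):
x = np.array([[1 - 3, 2 + 0.5, 1 + 3, 2 - 0.5]])
c, y = dot_concepts_and_label(x)
# Both concepts are on (3 > 0 and 2.5 > 0), but v1 . v3 = -2, so c == [[1., 1.]] and y == [0]
```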

In [3]:

# We then use our helper function to generate DataLoaders with the correct
# number of samples in them. We use a separate function for this to avoid
# repeating code to generate the different folds of our dataset:
def data_generator(
    dataset_size,
    batch_size,
    seed=None,
):
    # For the sake of determinism, always seed everything first
    # so that the generated data can be reproduced
    seed_everything(seed)
    x, y, c = generate_dot_data(dataset_size)
    data = torch.utils.data.TensorDataset(x, y, c)
    dl = torch.utils.data.DataLoader(
        data,
        batch_size=batch_size,
    )
    return dl


In [4]:
# Finally, we generate our training, testing, and validation folds with
# different random seeds

bsz = 8  # Batch size shared by all three folds

train_dl = data_generator(
    dataset_size=int(3000 * 0.7),
    batch_size=bsz,
    seed=42,
)
test_dl = data_generator(
    dataset_size=int(3000 * 0.2),
    batch_size=bsz,
    seed=43,
)
val_dl = data_generator(
    dataset_size=int(3000 * 0.1),
    batch_size=bsz,
    seed=44,
)

Seed set to 42
Seed set to 43
Seed set to 44
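The fold sizes and the batch size together determine how many batches each `DataLoader` yields. `DataLoader` keeps the final incomplete batch by default (`drop_last=False`), so a fold of $N$ samples gives $\lceil N / 8 \rceil$ batches. For the 2,100 training samples that is 263 batches, which matches the `263/263` progress bars in the training log, and the last batch holds 4 samples. A small stdlib sketch of the arithmetic:

```python
import math

BATCH_SIZE = 8
TOTAL = 3000

def fold_stats(fraction, batch_size=BATCH_SIZE, total=TOTAL):
    """Return (n_samples, n_batches, last_batch_size) for one fold."""
    n = int(total * fraction)
    n_batches = math.ceil(n / batch_size)  # drop_last=False keeps the remainder
    last = n - (n_batches - 1) * batch_size
    return n, n_batches, last

for name, frac in [("train", 0.7), ("test", 0.2), ("val", 0.1)]:
    print(name, fold_stats(frac))
```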


## Step 2: Create CEM Model

Now that we have our dataset in the correct `DataLoader` format, we can
proceed to construct our CEM object. For this, we import the `ProtoCEM` class
from `protocbm.models.protocem`, which we use here in place of the `cem`
library's `ConceptEmbeddingModel`. We can then instantiate it by indicating:
1. The number of concepts `n_concepts` in the dataset we will train it on (e.g., 2 for the Dot dataset).
2. The number of output tasks/labels `n_tasks` in the dataset of interest (here 2, since Dot's binary task is treated as a two-class problem).
3. The size `emb_size` of each concept embedding.
4. The weight `concept_loss_weight` to use for the concept prediction loss during training of the CEM (e.g., in our paper we set this value to 1 for the Dot dataset).
5. The `learning_rate` and `optimizer` to use during training (e.g., "adam" or "sgd").
6. The probability `training_intervention_prob` of performing a random intervention at training time via RandInt (we recommend setting this to 0.25).
7. The model architecture `c_extractor_arch` to use for the latent code generator, i.e., the model that maps input samples to a latent representation from which the concept embeddings are learned. For `ProtoCEM`, the cell below instead passes an already-instantiated generator via the `pre_concept_model` argument.
8. The model `c2y_model` to use as a label predictor **after** all concept embeddings have been generated by a CEM.
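To clarify what `training_intervention_prob` does: in the CEM paper, RandInt intervenes on each concept independently with this probability during training. An intervention replaces the predicted concept probability $\hat{p}_i$ with its ground-truth value $c_i$ before the embeddings are mixed. The NumPy sketch below shows only that masking step. `rand_int` is a hypothetical helper, and `ProtoCEM`'s internal implementation may differ.

```python
import numpy as np

def rand_int(c_pred_probs, c_true, prob=0.25, rng=None):
    """Randomly replace predicted concept probabilities with ground truth.

    c_pred_probs, c_true: arrays of shape (batch, n_concepts).
    Each entry is intervened on independently with probability `prob`.
    Returns the (possibly intervened) probabilities and the boolean mask.
    """
    rng = np.random.default_rng() if rng is None else rng
    mask = rng.random(c_pred_probs.shape) < prob
    return np.where(mask, c_true, c_pred_probs), mask

rng = np.random.default_rng(0)
p_hat = np.full((8, 2), 0.6)          # model's concept probabilities
c = rng.integers(0, 2, size=(8, 2))   # ground-truth concepts
p_mixed, mask = rand_int(p_hat, c, prob=0.25, rng=rng)
# Intervened entries now equal the ground truth; the rest are untouched.
```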

The only non-trivial parameters to set for this instantiation are the model
architectures for the latent code generator (passed via the `c_extractor_arch`
argument, or via `pre_concept_model` for `ProtoCEM`) and for the label
predictor (passed via the `c2y_model` argument).


The first of these arguments, namely the latent code generator `c_extractor_arch`,
must be provided as a simple Python function that takes a named argument
`output_dim` and returns a model that maps inputs from your task of interest to
a latent code with shape `output_dim`. For our Dot example, we do this via a
simple MLP (although in practice you can use an arbitrarily complex model).
When building the `ProtoCEM` below, we call this function with `output_dim=128`
and pass the resulting module via `pre_concept_model`:

In [5]:
def latent_code_generator_model(output_dim):
    if output_dim is None:
        output_dim = 128
    return torch.nn.Sequential(*[
        # 4 because Dot has inputs with 4 features in them
        torch.nn.Linear(4, 128),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(128, 128),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(128, output_dim),
    ])

The second of these arguments, namely the label predictor `c2y_model`, must
be any valid PyTorch model that takes as input as many activations as the
CEM's bottleneck (i.e., `n_concepts` * `emb_size`) and generates `n_tasks`
outputs, one for each output label in our dataset's downstream task. If it is
not provided, or is set to `None`, a standard CEM simply attaches a linear
mapping after its bottleneck to obtain the output label prediction. In
practice, this is how a CEM is usually constructed. Note that no `c2y_model`
is passed to the `ProtoCEM` below, and its printed summary shows a `DKNN`
module in that slot rather than a linear layer.
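The bottleneck size follows from how a CEM builds its embeddings. As described in our NeurIPS 2022 paper, each concept $i$ has a positive embedding $\hat{\mathbf{c}}_i^+$ and a negative embedding $\hat{\mathbf{c}}_i^-$. These are mixed using the predicted concept probability: $\hat{\mathbf{c}}_i = \hat{p}_i \hat{\mathbf{c}}_i^+ + (1 - \hat{p}_i)\hat{\mathbf{c}}_i^-$. Concatenating the $k$ mixed embeddings gives the `n_concepts * emb_size` activations that `c2y_model` receives.

The NumPy sketch below illustrates the shapes only, using random weights and omitting biases. Its layers mirror the shapes printed for the model below: a per-concept LeakyReLU context layer of width `2 * emb_size` and a single shared probability layer. On top of the bottleneck it uses a linear label predictor, as in a standard CEM. It is not the library's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, latent_dim = 8, 128               # latent code size from the generator
n_concepts, emb_size, n_tasks = 2, 128, 2

def leaky_relu(z, slope=0.01):
    return np.where(z > 0, z, slope * z)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

h = rng.normal(size=(batch, latent_dim))  # stand-in for the latent code
# One context generator per concept: latent -> [c_plus, c_minus] (2 * emb_size)
W_ctx = rng.normal(scale=0.05, size=(n_concepts, latent_dim, 2 * emb_size))
# A single probability generator shared across concepts: 2 * emb_size -> 1
w_prob = rng.normal(scale=0.05, size=(2 * emb_size,))

mixed, probs = [], []
for i in range(n_concepts):
    ctx = leaky_relu(h @ W_ctx[i])                   # (batch, 2 * emb_size)
    c_plus, c_minus = ctx[:, :emb_size], ctx[:, emb_size:]
    p = sigmoid(ctx @ w_prob)[:, None]               # (batch, 1)
    mixed.append(p * c_plus + (1 - p) * c_minus)     # (batch, emb_size)
    probs.append(p)

bottleneck = np.concatenate(mixed, axis=1)           # (batch, n_concepts * emb_size)
# Linear label predictor on top of the bottleneck
W_y = rng.normal(scale=0.05, size=(n_concepts * emb_size, n_tasks))
y_logits = bottleneck @ W_y                          # (batch, n_tasks)
print(bottleneck.shape, y_logits.shape)              # (8, 256) (8, 2)
```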

In [10]:
# We simply import our CEM class (the same can be done with CBMs to easily train
# any of their variants)
from protocbm.models.protocem import ProtoCEM

# And generate the actual model
cem_model = ProtoCEM(
    n_concepts=2,  # Number of training-time concepts. Dot has 2.
    n_tasks=2,  # Number of output classes. Dot's binary task is modelled as 2 classes.
    proto_train_dl=train_dl,  # Training DataLoader from which ProtoCEM prepares its prototypes
    emb_size=128,  # We will use an embedding size of 128
    concept_loss_weight=1,  # The weight assigned to the concept prediction loss relative to the task predictive loss.
    learning_rate=1e-3,  # The learning rate to use during training.
    optimizer="adam",  # The optimizer to use during training.
    training_intervention_prob=0.25,  # RandInt probability. We recommend setting this to 0.25.
    pre_concept_model=latent_code_generator_model(128),  # Latent code generator mapping each input to a 128-dim code.
    dknn_max_neighbours=50,  # DKNN hyperparameters (see ProtoCEM for their exact semantics)
    dknn_k=10,
)
print(cem_model)

ProtoCEM(
  (x2c_model): Identity()
  (c2y_model): DKNN(
    (soft_sort): NeuralSort()
  )
  (sig): Sigmoid()
  (bottleneck_nonlin): Sigmoid()
  (loss_concept): BCELoss()
  (loss_task): CrossEntropyLoss()
  (dknn_loss_function): DKNNLoss()
  (proto_model): DKNN(
    (soft_sort): NeuralSort()
  )
  (pre_concept_model): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=128, out_features=128, bias=True)
  )
  (concept_context_generators): ModuleList(
    (0-1): 2 x Sequential(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
  )
  (concept_prob_generators): ModuleList(
    (0): Linear(in_features=256, out_features=1, bias=True)
  )
)
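The parameter counts of this architecture can be checked by hand. A `Linear(in, out)` layer with bias has `in * out + out` parameters. The sketch below sums these for the three trainable modules. The totals (33,664, 66,048 and 257) agree with the 33.7 K, 66.0 K and 257 reported in the Lightning model summary during training.

```python
def linear_params(n_in, n_out, bias=True):
    """Number of parameters in a fully connected layer."""
    return n_in * n_out + (n_out if bias else 0)

pre_concept = linear_params(4, 128) + 2 * linear_params(128, 128)
context_gens = 2 * linear_params(128, 256)   # one per concept
prob_gen = linear_params(256, 1)             # a single shared generator

print(pre_concept, context_gens, prob_gen)   # 33664 66048 257
```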


## Step 3: Train the CEM

Now that we have both the dataset and the model defined, we can train our CEM
with PyTorch Lightning's `Trainer`, which only needs the model and the
DataLoaders generated above:
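The next cell uses `max_epochs=500` and `check_val_every_n_epoch=5`, and there are 263 training batches of size 8. From these numbers we can estimate the total amount of training work. The estimate below assumes no early stopping and one optimizer step per batch (i.e., no gradient accumulation). It gives 131,500 optimizer steps and 100 validation passes, in addition to the initial sanity check.

```python
import math

n_train, batch_size = 2100, 8
max_epochs, check_val_every_n_epoch = 500, 5

steps_per_epoch = math.ceil(n_train / batch_size)        # batches per epoch
total_steps = steps_per_epoch * max_epochs               # optimizer steps
n_val_runs = max_epochs // check_val_every_n_epoch       # validation passes

print(steps_per_epoch, total_steps, n_val_runs)
```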



In [11]:
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",  # Change to "cpu" if you are not running on a GPU!
    devices="auto",
    max_epochs=500,  # The number of epochs we will train our model for
    check_val_every_n_epoch=5,  # And how often we will check for validation metrics
    logger=False,  # No logs to be dumped for this trainer
)

# train_dl and val_dl are the DataLoaders built in Step 1
trainer.fit(cem_model, train_dl, val_dl)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                       | Type             | Params
-----------------------------------------------------------------
0  | x2c_model                  | Identity         | 0     
1  | c2y_model                  | DKNN             | 0     
2  | sig                        | Sigmoid          | 0     
3  | bottleneck_nonlin          | Sigmoid          | 0     
4  | loss_concept               | BCELoss          | 0     
5  | loss_task                  | CrossEntropyLoss | 0     
6  | dknn_loss_function         | DKNNLoss         | 0     
7  | proto_model                | DKNN             | 0     
8  | pre_concept_model          | Sequential       | 33.7 K
9  | concept_context_generators | ModuleList       | 66.0 K
10 | concept_prob_generators    | ModuleList       | 257   
-------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/leenux/work/part3_project/pyenv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 429.28it/s]
/home/leenux/work/part3_project/pyenv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 681.90it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 707.37it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 719.77it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 719.44it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 717.23it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 547.62it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 694.09it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 606.01it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

==EMPTY ACCURACIES==
(4, 50)
(50,)
(4,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 634.06it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 699.82it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 663.21it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 696.64it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 677.77it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 619.71it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 689.30it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 679.21it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 700.80it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 732.05it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 661.29it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 699.21it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 658.24it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 674.65it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 698.96it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 700.09it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 717.61it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 705.89it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 660.93it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 716.88it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 649.12it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 722.27it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 711.36it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 724.47it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 710.25it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 691.40it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 610.61it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 660.29it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 645.28it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 558.25it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 619.20it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 627.13it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 713.51it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 695.12it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 538.57it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 530.83it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 583.58it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 724.62it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 655.33it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 696.46it/s]


==EMPTY ACCURACIES==
(4, 50)
(50,)
(4,)


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 565.66it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 705.64it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 685.20it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 724.07it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 710.54it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 661.48it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 695.50it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 718.42it/s]


Preparing prototypes


100%|██████████| 263/263 [00:01<00:00, 239.94it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 718.59it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 639.46it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 728.31it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 697.18it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 605.21it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 645.34it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 713.84it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 694.39it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 675.72it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 646.33it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 523.59it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 639.85it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 686.61it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 589.92it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 706.75it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 655.73it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 571.12it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 557.25it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 608.64it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 549.69it/s]


==EMPTY ACCURACIES==
(4, 50)
(50,)
(4,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 681.56it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 545.10it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 707.44it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 534.95it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 498.35it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 600.95it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 727.41it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 702.62it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
==EMPTY ACCURACIES==
(4, 50)
(50,)
(4,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 550.20it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 669.23it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 640.06it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 665.16it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 690.43it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 706.75it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 561.63it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 707.01it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 622.29it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 653.13it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 713.58it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 627.45it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 546.11it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 543.85it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 674.38it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 692.09it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 677.59it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 691.42it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 524.04it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 692.74it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 678.28it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 577.87it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 658.12it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 594.52it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 617.89it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 633.70it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 673.88it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 692.39it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

==EMPTY ACCURACIES==
(4, 50)
(50,)
(4,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 674.20it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 688.23it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 708.18it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 556.66it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 563.42it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 549.20it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 693.48it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 543.19it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 672.02it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 689.68it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 680.26it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 608.32it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 515.52it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 686.90it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 501.50it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 483.85it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 685.56it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 646.77it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 545.63it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 701.23it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 690.98it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 696.01it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 692.99it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 566.42it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 493.32it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 635.71it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 675.51it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 608.43it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 620.51it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 700.58it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 718.87it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 707.89it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 713.58it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 670.65it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 659.39it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 583.69it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 675.40it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 662.49it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 504.90it/s]


==EMPTY ACCURACIES==
(4, 50)
(50,)
(4,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 617.50it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 722.49it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 654.84it/s]


==EMPTY ACCURACIES==
(4, 50)
(50,)
(4,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 635.87it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 562.45it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 639.51it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 543.58it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 668.09it/s]


==EMPTY ACCURACIES==
(4, 50)
(50,)
(4,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 641.31it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 630.74it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 620.11it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 703.76it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 662.13it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 638.43it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 698.52it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 657.57it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 625.37it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 685.12it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 646.93it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 686.98it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 715.62it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 524.94it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 655.19it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 698.85it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 708.56it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 716.62it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 674.93it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 633.33it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 599.40it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 680.18it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 652.23it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 521.11it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 561.07it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 704.01it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 685.91it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 677.17it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 617.36it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 692.39it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 668.10it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 575.96it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 639.92it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 543.88it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 699.85it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 683.06it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 663.28it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 716.31it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 711.31it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 579.28it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 703.84it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 720.69it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 701.31it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 493.71it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 646.65it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 694.43it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 704.12it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 539.56it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 653.82it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 571.39it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 688.00it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 645.22it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 642.39it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 600.34it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 628.86it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 648.76it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 726.91it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 674.33it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 669.89it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 660.71it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 564.98it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 681.09it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 716.72it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 568.42it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 702.16it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 549.36it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 722.32it/s]


==EMPTY ACCURACIES==
(4, 50)
(50,)
(4,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 711.98it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 622.76it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 527.11it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 629.55it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 708.62it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 698.86it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 607.53it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 657.61it/s]


==EMPTY ACCURACIES==
(8, 50)
(50,)
(8,)
Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 719.62it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 664.56it/s]


Preparing prototypes


100%|██████████| 263/263 [00:00<00:00, 706.02it/s]


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=500` reached.


For more details on everything you can add or configure in the Trainer, please
refer to the [official documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html).

## Step 4: Evaluate Model

Once the CEM has been trained, you can evaluate it with test data to generate
the learnt embeddings, the predicted concepts, and the predicted task labels!

A trained CEM or CBM can be called directly on any batch of inputs of shape
`(batch_size, ...)`, just like any other PyTorch module:
```python
(c_pred, c_embs, y_pred) = cem_model(x)
```
Where:
1. `c_pred` is a $(\text{batch\_size}, k)$-dimensional tensor whose $(b, i)$-th entry is the predicted probability that the $i$-th concept is active for the $b$-th sample.
2. `c_embs` is a $(\text{batch\_size}, k \cdot \text{emb\_size})$-dimensional tensor representing the CEM's bottleneck. It contains all concept embeddings concatenated in the same order as the concepts in the training annotations, so reshaping it to $(\text{batch\_size}, k, \text{emb\_size})$ gives you direct access to each concept's embedding.
3. `y_pred` is a $(\text{batch\_size}, L)$-dimensional tensor whose $(b, i)$-th entry is the logit (unnormalized log-probability) of the $b$-th sample having label $i$ (the model outputs logits by default, so a softmax over the $L$ entries turns them into probabilities). If the downstream task is binary, the CEM instead outputs a $(\text{batch\_size})$-dimensional vector where each entry is the logit of the probability that the downstream label is $1$.
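
To make these shape conventions concrete, here is a minimal numpy sketch that uses random stand-in arrays instead of real model outputs. The sizes (`batch_size=4`, `k=2`, `emb_size=16`, `n_labels=3`) are hypothetical and not taken from this notebook. It shows how to split the flat bottleneck into per-concept embeddings and how to turn task logits into probabilities.

```python
import numpy as np

batch_size, k, emb_size, n_labels = 4, 2, 16, 3  # hypothetical sizes

rng = np.random.default_rng(0)
# Random stand-ins for the three model outputs described above
c_pred = rng.uniform(size=(batch_size, k))            # concept probabilities
c_embs = rng.normal(size=(batch_size, k * emb_size))  # flat concept bottleneck
y_logits = rng.normal(size=(batch_size, n_labels))    # multi-class task logits

# Recover one embedding per concept: (batch_size, k, emb_size)
per_concept = c_embs.reshape(batch_size, k, emb_size)
# Concept i's embedding is the i-th contiguous block of the flat vector
assert np.array_equal(per_concept[1, 1], c_embs[1, emb_size:2 * emb_size])

# Multi-class logits -> probabilities via a numerically stable softmax
shifted = y_logits - y_logits.max(axis=1, keepdims=True)
y_probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)

# Binary task: a (batch_size,) vector of logits -> P(y = 1) via the sigmoid
y_binary_logits = rng.normal(size=batch_size)
p_one = 1.0 / (1.0 + np.exp(-y_binary_logits))
```

The reshape relies on numpy's default row-major order, which matches embeddings that were concatenated concept by concept along the last dimension.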

These outputs allow us to compute some metrics of interest. Below, we use
PyTorch Lightning's API to run inference in batches on a GPU and collect the
model's outputs over the entire test set.

Before doing this, we convert our test dataset into numpy arrays, as these are
easier to work with when computing custom metrics:

In [None]:
# Before anything, however, let's get the underlying numpy arrays of our
# test dataset as they will be easier to work with
x_test, y_test, c_test = [], [], []
for (x, y, c) in test_dl:
    x_test.append(x)
    y_test.append(y)
    c_test.append(c)
x_test = np.concatenate(x_test, axis=0)
y_test = np.concatenate(y_test, axis=0)
c_test = np.concatenate(c_test, axis=0)


Now we are ready to generate the concept, label, and embedding predictions for
the test set using our trained CEM:

In [None]:
import pytorch_lightning as pl

# We will use a Trainer object to run inference in batches over our test
# dataset
trainer = pl.Trainer(
    accelerator="auto",  # uses a GPU if one is available, otherwise the CPU
    devices="auto",
    logger=False, # No logs to be dumped for this trainer
)
batch_results = trainer.predict(cem_model, test_dl)

# Then we combine all results into numpy arrays by joining over the batch
# dimension
c_pred = np.concatenate(
    list(map(lambda x: x[0].detach().cpu().numpy(), batch_results)),
    axis=0,
)
c_embs = np.concatenate(
    list(map(lambda x: x[1].detach().cpu().numpy(), batch_results)),
    axis=0,
)
# Reshape them so that we have embeddings (batch_size, k, emb_size)
c_embs = np.reshape(c_embs, (c_test.shape[0], c_test.shape[1], -1))

y_pred = np.concatenate(
    list(map(lambda x: x[2].detach().cpu().numpy(), batch_results)),
    axis=0,
)
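
If you would rather not build a second `Trainer`, the same batched inference can be written as a plain PyTorch loop. The sketch below assumes only that the model returns the `(c_pred, c_embs, y_pred)` tuple described earlier. To keep it runnable on its own, it uses a tiny hypothetical stand-in module (`ToyCEM`) and a random `TensorDataset` in place of `cem_model` and `test_dl`; `predict_all` is likewise an illustrative helper, not part of our library. It runs on the CPU; on a GPU you would also move the model and each batch to the device.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyCEM(torch.nn.Module):
    """Hypothetical stand-in that mimics the (c_pred, c_embs, y_pred) output."""

    def __init__(self, n_features=4, k=2, emb_size=8):
        super().__init__()
        self.concepts = torch.nn.Linear(n_features, k)
        self.embs = torch.nn.Linear(n_features, k * emb_size)
        self.task = torch.nn.Linear(k * emb_size, 1)

    def forward(self, x):
        c_pred = torch.sigmoid(self.concepts(x))  # concept probabilities
        c_embs = self.embs(x)                     # flat bottleneck
        y_pred = self.task(c_embs).squeeze(-1)    # binary-task logits
        return c_pred, c_embs, y_pred


def predict_all(model, loader):
    """Run the model over every batch and stack its outputs as numpy arrays."""
    model.eval()  # switch layers such as dropout to inference behaviour
    outputs = ([], [], [])
    with torch.no_grad():  # gradients are not needed for inference
        for x, _, _ in loader:  # each sample is an (x, y, c) tuple
            for store, tensor in zip(outputs, model(x)):
                store.append(tensor.cpu().numpy())
    return tuple(np.concatenate(store, axis=0) for store in outputs)


dataset = TensorDataset(
    torch.randn(100, 4),                    # inputs x
    torch.randint(0, 2, (100,)),            # task labels y
    torch.randint(0, 2, (100, 2)).float(),  # concept labels c
)
loader = DataLoader(dataset, batch_size=32)
c_pred, c_embs, y_pred = predict_all(ToyCEM(), loader)
```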

Finally, we compute all the metrics of interest:

In [None]:
##########
## Compute test task accuracy
##########

from scipy.special import expit
from sklearn.metrics import accuracy_score

# Compute the task accuracy. CEMs always return logits, so we explicitly apply
# a sigmoid before thresholding (this assumes a binary task, as in Dot)
task_accuracy = accuracy_score(y_test, expit(y_pred) >= 0.5)
print(f"Our CEM's test task accuracy is {task_accuracy*100:.2f}%")
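
The cell above assumes a binary task, as in the Dot dataset, where thresholding `expit(y_pred)` at 0.5 is equivalent to thresholding the logits at 0. For a task with $L > 2$ labels, the prediction is instead the arg-max over the $L$ logits. The helper below (`logits_to_labels` is a hypothetical name, not part of our library) sketches both cases with numpy.

```python
import numpy as np


def logits_to_labels(y_logits):
    """Map task logits to hard label predictions.

    Binary tasks: a (batch_size,) or (batch_size, 1) array of logits.
    Multi-class tasks: a (batch_size, L) array of logits with L > 1.
    """
    y_logits = np.asarray(y_logits)
    if y_logits.ndim == 1 or y_logits.shape[1] == 1:
        # sigmoid(z) >= 0.5 exactly when z >= 0, so the sigmoid can be skipped
        return (y_logits.reshape(-1) >= 0).astype(np.int64)
    # softmax is monotonic, so the largest logit is the most probable label
    return np.argmax(y_logits, axis=1)


binary = logits_to_labels(np.array([-2.0, 0.5, 3.0]))
multi = logits_to_labels(np.array([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]))
binary_accuracy = np.mean(binary == np.array([0, 1, 0]))
```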

In [None]:
##########
## Compute test concept AUC
##########

from sklearn.metrics import roc_auc_score

# Compute the concept AUC. Unlike y_pred, c_pred already contains concept
# probabilities, so no sigmoid is needed. With (n, k) inputs, roc_auc_score
# returns the unweighted (macro) average of the k per-concept AUCs
concept_auc = roc_auc_score(c_test, c_pred)
print(f"Our CEM's test concept AUC is {concept_auc*100:.2f}%")
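
Because the concept AUC is a macro average, a single poorly predicted concept can hide behind well-predicted ones, so it can be worth inspecting each concept separately. The numpy sketch below (the helpers `binary_auc` and `per_concept_auc` are ours, for illustration only) computes each AUC as the probability that a randomly chosen positive sample is scored above a randomly chosen negative one, counting ties as one half, which is the quantity the ROC AUC measures. The toy labels and scores are made up.

```python
import numpy as np


def binary_auc(labels, scores):
    """ROC AUC as P(score of a positive > score of a negative), ties count 1/2."""
    labels = np.asarray(labels).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[labels], scores[~labels]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("AUC is undefined when only one class is present")
    # Compare every positive score with every negative score
    diff = pos[:, None] - neg[None, :]
    return float(np.mean(diff > 0) + 0.5 * np.mean(diff == 0))


def per_concept_auc(c_true, c_scores):
    """One AUC per concept column of (n, k) label and score matrices."""
    return np.array([
        binary_auc(c_true[:, i], c_scores[:, i]) for i in range(c_true.shape[1])
    ])


c_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
c_scores = np.array([[0.9, 0.2], [0.3, 0.4], [0.8, 0.6], [0.1, 0.7]])
aucs = per_concept_auc(c_true, c_scores)
macro_auc = aucs.mean()  # same averaging as roc_auc_score's default "macro"
```

The pairwise comparison needs memory proportional to the number of positive-negative pairs, so on the real test set it is simpler to call `roc_auc_score(c_test[:, i], c_pred[:, i])` for each concept `i`.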

In [None]:
##########
## Compute test concept alignment score
##########

from cem.metrics.cas import concept_alignment_score

cas, _ = concept_alignment_score(
    c_vec=c_embs,
    c_test=c_test,
    y_test=y_test,
    step=5,
    progress_bar=False,
)
print(f"Our CEM's concept alignment score (CAS) is {cas*100:.2f}%")