In [1]:
import warnings

import os
import torch
import pickle
import numpy as np
import os.path as osp
from tqdm import tqdm
from functools import reduce
import matplotlib.pyplot as plt

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from src.datasets import DatasetBuilder
from src.models.classifiers import SimpleCNNtorch
from src.models.lightning_wrappers import ClassifierLightningWrapper 

from src.utils.generic_utils import seed_everything, get_config, load_model_weights
from src.utils.generic_utils import evaluate_classification_model

  check_for_updates()
2025-05-13 04:50:40.335583: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747104640.354052 3761728 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747104640.359636 3761728 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1747104640.374757 3761728 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747104640.374775 3761728 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747104640.374777 3761728 computation_placer.cc:177] 

In [2]:
seed_everything()
warnings.filterwarnings("ignore", category=UserWarning)

In [3]:
config_path = r"/data/leuven/365/vsc36567/CF-Robustness-Benchmark/configs/train_classifier_fmnist.yaml"
config = get_config(config_path)

In [4]:
config.data.classes = [0, 2, 4, 6]
config.data.num_classes = len(config.data.classes)
config.classifier.checkpoints_path = r"/data/leuven/365/vsc36567/CF-Robustness-Benchmark/notebooks/experiments/fmnist_classification/multiclass/checkpoints/fmnist__epoch=12_val_accuracy=0.84.pth" 
# r"/data/leuven/365/vsc36567/CF-Robustness-Benchmark/notebooks/experiments/fmnist_classification/binary/checkpoints/fmnist_0_4_epoch=06_val_accuracy=0.99.pth"

In [5]:
ds_builder = DatasetBuilder(config)
ds_builder.setup()
train_loader, val_loader, test_loader = ds_builder.get_dataloaders()

In [6]:
baseline_classifier = SimpleCNNtorch(**config.classifier.args,
                                    num_classes=config.data.num_classes,    
                                    img_size=config.data.img_size)
load_model_weights(baseline_classifier, weights_path=config.classifier.checkpoints_path)
evaluate_classification_model(baseline_classifier, test_loader, config.data.num_classes)

Accuracy for the test dataset: 81.994%


In [7]:
expt_dir = "/data/leuven/365/vsc36567/CF-Robustness-Benchmark/notebooks/experiments"
# r"D:\PycharmProjects\CF-Robustness-Benchmark\notebooks\experiments"
expt_name = f"{config.data.name}_classification"
expt_version = "binary"  if config.data.num_classes == 2 else "multiclass"
checkpints_dir = osp.join(expt_dir, expt_name, expt_version, 'checkpoints', 'mc_all')
class_names = ds_builder.class_encodings
classes4fname = ("_").join([str(i) for i in class_names.values()]) if config.data.num_classes == 2 else ""

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=3,
    verbose=False,
    mode="min",
)
tb_logger = TensorBoardLogger(save_dir=expt_dir, name=expt_name, version=expt_version)

In [8]:
from random import randint
from torchmetrics import Accuracy


seed_list = [randint(1000, 3000) for _ in range(10)]
num_classes = config.data.num_classes
baseline_accuracy = 0.82

for seed in tqdm(seed_list):
    seed_everything(seed)

    cnn_wi = SimpleCNNtorch(**config.classifier.args,
                            num_classes=config.data.num_classes,
                            img_size=config.data.img_size)
    cnn_wrapper = ClassifierLightningWrapper(config, cnn_wi)
    # chekpoint_callback = ModelCheckpoint(
    #     monitor="val_accuracy",
    #     dirpath=checkpints_dir,
    #     filename=f"{config.data.name}_{classes4fname}_{seed}",
    #     save_top_k=1,
    #     mode="max",
    #     save_weights_only=True,
    # )
    # chekpoint_callback.FILE_EXTENSION = '.pth'
    trainer = Trainer(
        log_every_n_steps=10,
        max_epochs=10,
        enable_checkpointing=False,
        # callbacks=[early_stop_callback],
        logger=tb_logger
    )
    trainer.fit(model=cnn_wrapper, 
            train_dataloaders=train_loader, 
            val_dataloaders=val_loader)
    
    # Check predictive power of the model
    calc_metric = Accuracy(task="binary" if num_classes==2 else "multiclass", num_classes=num_classes)

    accuracy = 0
    for images, labels in test_loader:
        preds = torch.argmax(cnn_wi(images), axis=1, keepdim=False)
        accuracy += calc_metric(preds, labels)
    accuracy = (accuracy/len(test_loader)).item()
    
    print("Accuracy for the test dataset: {:.3%}".format(accuracy))  

    if baseline_accuracy - accuracy < 0.1:
        fname = f"{config.data.name}_{classes4fname}_{seed}.pth"
        torch.save(cnn_wi.state_dict(), osp.join(checkpints_dir, fname))
        
    else:
        seed_list.append(randint(1000, 3000))
        continue 

  0%|          | 0/10 [00:00<?, ?it/s]GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | SimpleCNNtorch   | 25.3 K | train
1 | loss_fn       | CrossEntropyLoss | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | valid_metrics | MetricCollection | 0      | train
-----------------------------------------------------------
25.3 K    Trainable params
0         Non-trainable params
25.3 K    Total params
0.101     Total estimated model params size (MB)
25        Modules in train mode
0         Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


Accuracy for the test dataset: 81.052%


 10%|█         | 1/10 [00:52<07:51, 52.43s/it]GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | SimpleCNNtorch   | 25.3 K | train
1 | loss_fn       | CrossEntropyLoss | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | valid_metrics | MetricCollection | 0      | train
-----------------------------------------------------------
25.3 K    Trainable params
0         Non-trainable params
25.3 K    Total params
0.101     Total estimated model params size (MB)
25        Modules in train mode
0         Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [16]:
baseline_accuracy - accuracy/len(test_loader) < 0.1

tensor(False)