# Optimize Image Classification Model

In [1]:
import os
import logging
import warnings
import optuna
from optuna.samplers import TPESampler
from optuna.exceptions import OptunaError
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from calflops import calculate_flops
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.utilities import disable_possible_user_warnings
from base import TSDataset, InternalLogger
from resnet import XResNet

# Keep per-trial output quiet: only ERROR-level records from the
# "lightning.pytorch" logger are shown, and Lightning's PossibleUserWarning
# category is disabled so dozens of trials don't flood the notebook.
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
disable_possible_user_warnings()


class TrialError(OptunaError):
    """Signal that a single Optuna trial failed and should be skipped.

    Raised by ``objective`` when model construction/profiling or training
    fails, or when no usable metric is produced. ``study.optimize`` is called
    with ``catch=(TrialError,)``, so such a trial is recorded as failed and the
    study moves on to the next trial instead of aborting.
    """


def get_weighted_sampler(dataset: "TSDataset", num_samples: int) -> WeightedRandomSampler:
    """Build a sampler that draws windows with inverse-frequency class weights.

    Every timestep gets a global label: the max over its label columns (so 1 if
    any channel is flagged at that step, otherwise 0). Each timestep is weighted
    by ``1 / count`` of its global label. The weights are then taken at positions
    ``window_size - 1, window_size - 1 + stride, ...``, presumably the last
    timestep of each window produced by ``TSDataset`` (TODO confirm against
    ``TSDataset.__getitem__``).

    Args:
        dataset: Object exposing ``y`` (a CPU torch tensor of per-timestep
            labels; assumed shape (T, n_label_columns) - TODO confirm),
            ``window_size`` and ``stride``.
        num_samples: Number of indices the sampler yields per epoch.

    Returns:
        A ``WeightedRandomSampler`` that draws with replacement.
    """
    global_labels = dataset.y.numpy().max(axis=1).astype(int)
    # Index by each label's position among the unique labels (return_inverse),
    # not by the label value itself. Indexing by value breaks with IndexError
    # when only the positive class is present (class_weights has length 1).
    _, inverse, counts = np.unique(global_labels, return_inverse=True, return_counts=True)
    class_weights = 1.0 / counts
    sample_weights = class_weights[inverse.reshape(-1)]
    # Keep one weight per window, aligned with the window's final timestep.
    sample_weights = sample_weights[dataset.window_size - 1::dataset.stride]
    return WeightedRandomSampler(sample_weights, num_samples, replacement=True)


def get_class_weights(labels: np.ndarray) -> torch.Tensor:
    """Compute per-column positive weights as ``negatives / positives``.

    The result is used as ``pos_weight`` for ``nn.BCEWithLogitsLoss``. A column
    with no positive labels would otherwise divide by zero and produce ``inf``,
    which turns the weighted loss into inf/NaN. For those columns the positive
    count is treated as 1, so the weight equals the number of negatives.

    Args:
        labels: 2-D array of per-row labels, one column per output channel.

    Returns:
        float32 tensor of shape ``(n_columns,)`` on CUDA when available,
        otherwise on CPU.
    """
    pos_counts = labels.sum(axis=0)
    neg_counts = len(labels) - pos_counts
    # Only replace zero counts; every other column keeps its exact ratio.
    safe_pos_counts = np.where(pos_counts > 0, pos_counts, 1)
    class_weights = neg_counts / safe_pos_counts
    print("Class weights:", class_weights)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.tensor(class_weights, dtype=torch.float32, device=device)

In [5]:
# Hand-written reference configurations, keyed by the same parameter names that
# `objective` samples (window_size, block, expansion, layers, ...).
# NOTE(review): this list is not referenced anywhere in the visible code; the
# study only enqueues `prepared_trials`. Keep or remove deliberately.
test_trials = [
    {
        "window_size": 224,
        "block": "resblock",
        "expansion": 1,
        "layers": 3,
        "p": 0.0,
        "stem_szs_0": 32,
        "stem_szs_1": 32,
        "widen": 1.0,
        "sa": False,
        "act_cls": "relu",
        "ks": 3,
        "stride": 2,
    },
    {
        "window_size": 192,
        "block": "resblock",
        "expansion": 1,
        "layers": 3,
        "p": 0.0,
        "stem_szs_0": 32,
        "stem_szs_1": 32,
        "widen": 0.5,
        "sa": False,
        "act_cls": "relu",
        "ks": 3,
        "stride": 2,
    },
    {
        "window_size": 224,
        "block": "resblock",
        "expansion": 2,
        "layers": 2,
        "p": 0.0,
        "stem_szs_0": 32,
        "stem_szs_1": 32,
        "widen": 0.5,
        "sa": False,
        "act_cls": "hardswish",
        "ks": 3,
        "stride": 2,
    },
    {
        "window_size": 128,
        "block": "seresnextblock",
        "expansion": 2,
        "layers": 2,
        "p": 0.0,
        "stem_szs_0": 32,
        "stem_szs_1": 32,
        "widen": 0.5,
        "sa": False,
        "act_cls": "relu",
        "ks": 3,
        "stride": 2,
    },
    {
        "window_size": 192,
        "block": "seresnextblock",
        "expansion": 4,
        "layers": 2,
        "p": 0.0,
        "stem_szs_0": 32,
        "stem_szs_1": 32,
        "widen": 0.75,
        "sa": False,
        "act_cls": "hardswish",
        "ks": 3,
        "stride": 2,
    },
    {
        "window_size": 128,
        "block": "seresnextblock",
        "expansion": 4,
        "layers": 1,
        "p": 0.0,
        "stem_szs_0": 16,
        "stem_szs_1": 32,
        "widen": 0.5,
        "sa": False,
        "act_cls": "hardswish",
        "ks": 3,
        "stride": 2,
    },
    {
        "window_size": 128,
        "block": "resblock",
        "expansion": 1,
        "layers": 1,
        "p": 0.0,
        "stem_szs_0": 32,
        "stem_szs_1": 32,
        "widen": 1.0,
        "sa": True,
        "act_cls": "relu",
        "ks": 5,
        "stride": 3,
    },
    {
        "window_size": 128,
        "block": "seresnextblock",
        "expansion": 2,
        "layers": 2,
        "p": 0.0,
        "stem_szs_0": 16,
        "stem_szs_1": 32,
        "widen": 0.5,
        "sa": False,
        "act_cls": "hardswish",
        "ks": 5,
        "stride": 3,
    },
    {
        "window_size": 224,
        "block": "resblock",
        "expansion": 2,
        "layers": 3,
        "p": 0.0,
        "stem_szs_0": 32,
        "stem_szs_1": 32,
        "widen": 0.25,
        "sa": False,
        "act_cls": "relu",
        "ks": 5,
        "stride": 3,
    },
    {
        "window_size": 224,
        "block": "resblock",
        "expansion": 2,
        "layers": 4,
        "p": 0.0,
        "stem_szs_0": 32,
        "stem_szs_1": 32,
        "widen": 0.25,
        "sa": False,
        "act_cls": "relu",
        "ks": 3,
        "stride": 2,
    },
    {
        "window_size": 224,
        "block": "resblock",
        "expansion": 2,
        "layers": 5,
        "p": 0.0,
        "stem_szs_0": 16,
        "stem_szs_1": 16,
        "widen": 1/8,
        "sa": False,
        "act_cls": "hardswish",
        "ks": 3,
        "stride": 2,
    },
]

In [None]:
# Baseline configurations, enqueued into the study before any sampler-proposed
# trials. Their count also sets `n_startup_trials`. On a fresh study they are
# therefore exempt from the MACs/params pruning in `objective`, which only
# applies once trial.number >= n_startup_trials.
# Every value must lie on the grid defined by the trial.suggest_* calls in
# `objective` (e.g. window_size a multiple of 8 in [64, 224]).
prepared_trials = [{'act_cls': 'hardswish',
  'block': 'separableblock',
  'expansion': 2,
  'ks': 5,
  'layers': 5,
  'p': 0.1,
  'sa': True,
  'stem_szs_0': 8,
  'stem_szs_1': 40,
  'stride': 3,
  'widen': 0.25,
  'window_size': 152},
 {'act_cls': 'leakyrelu',
  'block': 'seresnextblock',
  'expansion': 4,
  'ks': 7,
  'layers': 2,
  'p': 0.1,
  'sa': False,
  'stem_szs_0': 48,
  'stem_szs_1': 32,
  'stride': 3,
  'widen': 0.75,
  'window_size': 88},
 {'act_cls': 'hardswish',
  'block': 'seresnextblock',
  'expansion': 2,
  'ks': 3,
  'layers': 3,
  'p': 0.0,
  'sa': True,
  'stem_szs_0': 24,
  'stem_szs_1': 48,
  'stride': 2,
  'widen': 0.75,
  'window_size': 104},
 {'act_cls': 'relu',
  'block': 'seblock',
  'expansion': 2,
  'ks': 7,
  'layers': 2,
  'p': 0.0,
  'sa': True,
  'stem_szs_0': 8,
  'stem_szs_1': 24,
  'stride': 3,
  'widen': 0.75,
  'window_size': 72},
 {'act_cls': 'leakyrelu',
  'block': 'resblock',
  'expansion': 1,
  'ks': 3,
  'layers': 1,
  'p': 0.15,
  'sa': False,
  'stem_szs_0': 56,
  'stem_szs_1': 8,
  'stride': 2,
  'widen': 0.25,
  'window_size': 168},
 {'act_cls': 'hardswish',
  'block': 'resblock',
  'expansion': 2,
  'ks': 3,
  'layers': 1,
  'p': 0.1,
  'sa': True,
  'stem_szs_0': 16,
  'stem_szs_1': 48,
  'stride': 3,
  'widen': 0.25,
  'window_size': 64},
 {'act_cls': 'leakyrelu',
  'block': 'resblock',
  'expansion': 4,
  'ks': 5,
  'layers': 3,
  'p': 0.0,
  'sa': False,
  'stem_szs_0': 8,
  'stem_szs_1': 32,
  'stride': 2,
  'widen': 0.25,
  'window_size': 72},
 {'act_cls': 'relu',
  'block': 'seresnextblock',
  'expansion': 4,
  'ks': 3,
  'layers': 4,
  'p': 0.2,
  'sa': False,
  'stem_szs_0': 48,
  'stem_szs_1': 32,
  'stride': 2,
  'widen': 0.5,
  'window_size': 64},
 {'act_cls': 'leakyrelu',
  'block': 'resblock',
  'expansion': 1,
  'ks': 3,
  'layers': 2,
  'p': 0.2,
  'sa': True,
  'stem_szs_0': 24,
  'stem_szs_1': 16,
  'stride': 1,
  'widen': 0.125,
  'window_size': 104},
 {'act_cls': 'leakyrelu',
  'block': 'seresnextblock',
  'expansion': 2,
  'ks': 5,
  'layers': 5,
  'p': 0.2,
  'sa': False,
  'stem_szs_0': 64,
  'stem_szs_1': 24,
  'stride': 2,
  'widen': 0.5,
  'window_size': 96},
 {'act_cls': 'relu',
  'block': 'resblock',
  'expansion': 4,
  'ks': 3,
  'layers': 5,
  'p': 0.15,
  'sa': False,
  'stem_szs_0': 8,
  'stem_szs_1': 8,
  'stride': 3,
  'widen': 0.75,
  'window_size': 112},
 {'act_cls': 'hardswish',
  'block': 'resblock',
  'expansion': 2,
  'ks': 5,
  'layers': 3,
  'p': 0.2,
  'sa': True,
  'stem_szs_0': 48,
  'stem_szs_1': 24,
  'stride': 3,
  'widen': 0.125,
  'window_size': 64},
 {'act_cls': 'leakyrelu',
  'block': 'seblock',
  'expansion': 1,
  'ks': 7,
  'layers': 5,
  'p': 0.15,
  'sa': True,
  'stem_szs_0': 40,
  'stem_szs_1': 16,
  'stride': 3,
  'widen': 0.25,
  'window_size': 112},
 {'act_cls': 'relu',
  'block': 'seresnextblock',
  'expansion': 1,
  'ks': 5,
  'layers': 3,
  'p': 0.0,
  'sa': True,
  'stem_szs_0': 64,
  'stem_szs_1': 56,
  'stride': 2,
  'widen': 0.5,
  'window_size': 72},
 {'act_cls': 'hardswish',
  'block': 'resblock',
  'expansion': 2,
  'ks': 5,
  'layers': 1,
  'p': 0.15,
  'sa': True,
  'stem_szs_0': 16,
  'stem_szs_1': 40,
  'stride': 3,
  'widen': 0.5,
  'window_size': 120},
 {'act_cls': 'hardswish',
  'block': 'separableblock',
  'expansion': 2,
  'ks': 5,
  'layers': 4,
  'p': 0.0,
  'sa': False,
  'stem_szs_0': 8,
  'stem_szs_1': 16,
  'stride': 2,
  'widen': 0.75,
  'window_size': 64},
 {'act_cls': 'leakyrelu',
  'block': 'seblock',
  'expansion': 4,
  'ks': 5,
  'layers': 3,
  'p': 0.2,
  'sa': False,
  'stem_szs_0': 32,
  'stem_szs_1': 24,
  'stride': 2,
  'widen': 0.25,
  'window_size': 64},
 {'act_cls': 'leakyrelu',
  'block': 'seresnextblock',
  'expansion': 4,
  'ks': 3,
  'layers': 3,
  'p': 0.15,
  'sa': True,
  'stem_szs_0': 56,
  'stem_szs_1': 32,
  'stride': 2,
  'widen': 0.25,
  'window_size': 184},
 {'act_cls': 'leakyrelu',
  'block': 'seresnextblock',
  'expansion': 1,
  'ks': 3,
  'layers': 1,
  'p': 0.05,
  'sa': False,
  'stem_szs_0': 32,
  'stem_szs_1': 24,
  'stride': 2,
  'widen': 0.75,
  'window_size': 88},
 {'act_cls': 'leakyrelu',
  'block': 'separableblock',
  'expansion': 2,
  'ks': 7,
  'layers': 2,
  'p': 0.1,
  'sa': False,
  'stem_szs_0': 40,
  'stem_szs_1': 24,
  'stride': 3,
  'widen': 1.0,
  'window_size': 88},
 {'act_cls': 'leakyrelu',
  'block': 'seresnextblock',
  'expansion': 4,
  'ks': 3,
  'layers': 2,
  'p': 0.15,
  'sa': False,
  'stem_szs_0': 8,
  'stem_szs_1': 32,
  'stride': 3,
  'widen': 0.25,
  'window_size': 192},
 {'act_cls': 'leakyrelu',
  'block': 'resblock',
  'expansion': 4,
  'ks': 7,
  'layers': 2,
  'p': 0.1,
  'sa': True,
  'stem_szs_0': 32,
  'stem_szs_1': 8,
  'stride': 3,
  'widen': 0.25,
  'window_size': 88},
 {'act_cls': 'relu',
  'block': 'separableblock',
  'expansion': 4,
  'ks': 3,
  'layers': 3,
  'p': 0.0,
  'sa': False,
  'stem_szs_0': 64,
  'stem_szs_1': 16,
  'stride': 2,
  'widen': 1.0,
  'window_size': 80},
 {'act_cls': 'hardswish',
  'block': 'seresnextblock',
  'expansion': 4,
  'ks': 5,
  'layers': 1,
  'p': 0.0,
  'sa': True,
  'stem_szs_0': 32,
  'stem_szs_1': 8,
  'stride': 2,
  'widen': 1.0,
  'window_size': 128},
 {'act_cls': 'leakyrelu',
  'block': 'resblock',
  'expansion': 4,
  'ks': 7,
  'layers': 1,
  'p': 0.2,
  'sa': False,
  'stem_szs_0': 48,
  'stem_szs_1': 16,
  'stride': 2,
  'widen': 0.5,
  'window_size': 96},
 {'act_cls': 'leakyrelu',
  'block': 'seresnextblock',
  'expansion': 2,
  'ks': 3,
  'layers': 1,
  'p': 0.05,
  'sa': True,
  'stem_szs_0': 24,
  'stem_szs_1': 48,
  'stride': 3,
  'widen': 0.5,
  'window_size': 200},
 {'act_cls': 'leakyrelu',
  'block': 'seresnextblock',
  'expansion': 1,
  'ks': 3,
  'layers': 2,
  'p': 0.1,
  'sa': True,
  'stem_szs_0': 64,
  'stem_szs_1': 48,
  'stride': 2,
  'widen': 0.5,
  'window_size': 128},
 {'act_cls': 'leakyrelu',
  'block': 'resblock',
  'expansion': 1,
  'ks': 7,
  'layers': 5,
  'p': 0.1,
  'sa': True,
  'stem_szs_0': 32,
  'stem_szs_1': 16,
  'stride': 3,
  'widen': 0.125,
  'window_size': 96},
 {'act_cls': 'relu',
  'block': 'separableblock',
  'expansion': 2,
  'ks': 3,
  'layers': 1,
  'p': 0.1,
  'sa': True,
  'stem_szs_0': 24,
  'stem_szs_1': 32,
  'stride': 3,
  'widen': 0.75,
  'window_size': 136},
 {'act_cls': 'leakyrelu',
  'block': 'separableblock',
  'expansion': 4,
  'ks': 3,
  'layers': 2,
  'p': 0.1,
  'sa': False,
  'stem_szs_0': 32,
  'stem_szs_1': 24,
  'stride': 3,
  'widen': 0.25,
  'window_size': 160}]

In [2]:
# --- Search / pruning limits (applied only after the startup trials) ---
macs_threshold = 200  # Million MACs
params_threshold = 500_000  # Num Params
n_startup_trials = len(prepared_trials)

# --- Training budget ---
num_samples_per_epoch = 2_097_152  # 2048 steps at batch_size=1024
batch_size = 1024
virtual_batch_size = 1024  # effective batch size targeted via gradient accumulation
# Clamp to >= 1: with batch_size > virtual_batch_size the integer division
# would give 0, which is not a valid accumulation factor.
accumulate_grad_batches = max(1, virtual_batch_size // batch_size)
num_evals = 8  # number of evaluations per epoch
split = 264_960  # number of trailing rows held out for validation
file_dir = "output"
study_name = "model_optimization_image_v2"

os.makedirs(file_dir, exist_ok=True)

seed = 0
in_channels = 6  # one per target channel (channel_41 .. channel_46)
out_channels = 6  # one anomaly logit per target channel
# Training batches between validation runs (passed as val_check_interval).
eval_steps = num_samples_per_epoch // batch_size // num_evals

In [None]:
# Load the training CSV and keep channels 41-46 plus their anomaly-flag columns.
df_train = pd.read_csv("84_months.train.csv")
target_channels = [f"channel_{idx}" for idx in range(41, 47)]
label_cols = [f"is_anomaly_{name}" for name in target_channels]

# Convert each column group to float32 once, then split by position: the last
# `split` rows form the validation set and all earlier rows the training set
# (rows are assumed to be in time order - TODO confirm).
all_features = df_train[target_channels].values.astype(np.float32)
all_labels = df_train[label_cols].values.astype(np.float32).clip(0., 1.)
train_array, val_array = all_features[:-split], all_features[-split:]
train_labels, val_labels = all_labels[:-split], all_labels[-split:]

class_weights = get_class_weights(train_labels)

# Notebook cell output: show the resulting shapes.
(train_array.shape, train_labels.shape, val_array.shape, val_labels.shape)

In [4]:
def objective(trial):
    """Optuna objective: train one sampled XResNet configuration and score it.

    Samples an architecture from the search space below and builds train/val
    datasets for the sampled window size. It measures FLOPs, MACs and
    parameter count with ``calflops``, prunes oversized models once past the
    startup trials, then trains for one epoch with Lightning.

    Args:
        trial: The ``optuna.Trial`` being evaluated.

    Returns:
        tuple: ``(MACs in millions, minimum logged "val_loss")``. The study
        minimises both.

    Raises:
        TrialError: Model construction/profiling or training failed, or no
            usable validation loss / MAC count was produced. The underlying
            exception, if any, is chained as ``__cause__``.
        optuna.TrialPruned: After the startup trials, when the model exceeds
            ``macs_threshold`` (millions of MACs) or ``params_threshold``.
    """
    # sample model configuration
    window_size = trial.suggest_int("window_size", 64, 224, step=8)
    block = trial.suggest_categorical("block", ["resblock", "seblock", "seresnextblock", "separableblock"])
    expansion = trial.suggest_categorical("expansion", [1, 2, 4])
    # The sampled depth becomes a list of that many 1s, passed to XResNet as `layers`.
    layers = [1] * trial.suggest_int("layers", 1, 5, step=1)
    p = trial.suggest_float("p", 0., 0.2, step=0.05)
    stem_szs_0 = trial.suggest_int("stem_szs_0", 8, 64, step=8)
    stem_szs_1 = trial.suggest_int("stem_szs_1", 8, 64, step=8)
    widen = trial.suggest_categorical("widen", [1/8, 0.25, 0.5, 0.75, 1.0])
    sa = trial.suggest_categorical("sa", [True, False])
    act_cls = trial.suggest_categorical("act_cls", ["relu", "leakyrelu", "hardswish"])
    ks = trial.suggest_int("ks", 3, 7, step=2)
    stride = trial.suggest_int("stride", 1, 3, step=1)

    # set seed
    L.seed_everything(seed=seed, verbose=False)

    # setup datasets
    train_dataset = TSDataset(train_array, train_labels, window_size=window_size)
    val_dataset = TSDataset(val_array, val_labels, window_size=window_size)

    train_sampler = get_weighted_sampler(train_dataset, num_samples=num_samples_per_epoch)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=False, sampler=train_sampler)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

    try:
        # initialize model
        model = XResNet(
            block=block,
            expansion=expansion,
            layers=layers,
            p=p,
            c_in=in_channels,
            n_out=out_channels,
            stem_szs=(stem_szs_0, stem_szs_1, 64),
            widen=widen,
            sa=sa,
            act_cls=act_cls,
            ndim=2,
            ks=ks,
            stride=stride,
            # fixed BaseClassifier params
            criterion=nn.BCEWithLogitsLoss(),
            # Validation loss is class-weighted with pos_weight = negatives/positives.
            val_criterion=nn.BCEWithLogitsLoss(pos_weight=class_weights),
            optimizer="adam",
            lr=1e-4,
            patience=100,
            min_delta=0.,
            checkpoint_dir=None,
            use_one_cycle=True
        ).cuda()

        # calculate MACs and model parameters
        flops, macs, num_params = calculate_flops(
            model=model,
            input_shape=(1, in_channels, window_size),
            print_results=False,
            output_as_string=False,
            include_backPropagation=False,
        )
        # store all metrics
        trial.set_user_attr("flops", flops)
        trial.set_user_attr("macs", macs)
        trial.set_user_attr("num_params", num_params)

        # scale macs
        macs /= 1_000_000

    except Exception as err:
        # `except Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
        # stop the study. Otherwise they would become a TrialError, which
        # study.optimize(catch=...) swallows. Chaining keeps the root cause visible.
        raise TrialError(f"Trial {trial.number} model initialization failed.") from err

    # prune
    if trial.number >= n_startup_trials:  # Don't prune during startup
        if (macs > macs_threshold) or (num_params > params_threshold):
            raise optuna.TrialPruned()

    # Setup trainer
    metric_logger = InternalLogger()
    torch.set_float32_matmul_precision("high")

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        trainer = L.Trainer(
            accelerator="gpu",
            max_epochs=1,
            log_every_n_steps=8,
            accumulate_grad_batches=accumulate_grad_batches,
            logger=[metric_logger, TensorBoardLogger(save_dir=file_dir, name="lightning_logs")],
            val_check_interval=eval_steps,
            enable_model_summary=False,
            enable_progress_bar=False,
            enable_checkpointing=False,
        )

    try:
        # train model
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

        # get min validation loss
        metric_df = metric_logger.history
        loss = metric_df["val_loss"].min()
        trial.set_user_attr("val_loss", loss)

    except Exception as err:
        # Same rationale as above: don't convert interrupts into a caught TrialError.
        raise TrialError(f"Trial {trial.number} training failed.") from err

    if (loss is None) or (np.isnan(loss)):
        raise TrialError("No validation loss")
    if (macs is None) or (np.isnan(macs)):
        raise TrialError("No MACs")

    return macs, loss

In [None]:
# Two-objective study (MACs, validation loss), both minimised. It is persisted
# to SQLite so an existing study with the same name is reopened rather than
# recreated.
storage_url = f"sqlite:///{file_dir}/{study_name}.db"
sampler = TPESampler(seed=seed, multivariate=True, n_startup_trials=n_startup_trials)
study = optuna.create_study(
    study_name=study_name,
    storage=storage_url,
    sampler=sampler,
    load_if_exists=True,
    directions=["minimize", "minimize"],
)

In [None]:
# baselines
# Baselines: enqueue the hand-picked configurations so they are evaluated first.
# The study is persisted and reopened (load_if_exists=True), so skip_if_exists
# keeps a re-run of this cell from enqueueing, and re-training, configurations
# the study already contains.
for trial_dict in prepared_trials:
    study.enqueue_trial(trial_dict, skip_if_exists=True)

# n_trials counts every trial run by this call: queued trials are consumed
# first, and any remaining budget goes to sampler-proposed trials.
study.optimize(
    objective,
    n_trials=len(prepared_trials),
    timeout=int(60 * 60 * 8),  # 8 hours
    catch=(TrialError,),  # a failing trial is marked failed; the study continues
    gc_after_trial=True,
)