# Optimize Direct Classification Model

In [1]:
!pip install calflops lightning pyts==0.12.0
!pip install --no-deps tsai==0.3.9

Collecting calflops
  Downloading calflops-0.3.2-py3-none-any.whl.metadata (28 kB)
Collecting lightning
  Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)
Collecting pyts==0.12.0
  Downloading pyts-0.12.0-py3-none-any.whl.metadata (10 kB)
Downloading pyts-0.12.0-py3-none-any.whl (2.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m58.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading calflops-0.3.2-py3-none-any.whl (29 kB)
Downloading lightning-2.5.1-py3-none-any.whl (818 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m818.9/818.9 kB[0m [31m46.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyts, lightning, calflops
Successfully installed calflops-0.3.2 lightning-2.5.1 pyts-0.12.0
Collecting tsai==0.3.9
  Downloading tsai-0.3.9-py3-none-any.whl.metadata (16 kB)
Downloading tsai-0.3.9-py3-none-any.whl (324 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m324.3/324.

In [2]:
import os
import logging
import warnings
import optuna
from optuna.samplers import TPESampler
from optuna.exceptions import OptunaError
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset
from calflops import calculate_flops
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger, Logger
from lightning.pytorch.utilities import disable_possible_user_warnings, rank_zero_only
from fastai.imports import noop
from fastai.layers import AdaptiveConcatPool1d
from tsai.models.layers import Conv, Concat, Norm, ConvBlock, GAP1d, Add
from typing import Literal
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor


class BaseLightningModel(L.LightningModule):
    def __init__(self, config=None):
        super().__init__()
        self.config = config
        self._val_loss = 0.0
        self._test_loss = 0.0
        self._val_batches = 0
        self._test_batches = 0

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self(inputs)
        loss = self.criterion(output, target)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, target = batch
        output = self(inputs)
        self._val_loss += self.val_criterion(output, target).item()
        self._val_batches += 1

    def test_step(self, batch, batch_idx):
        inputs, target = batch
        output = self(inputs)
        self._test_loss += self.val_criterion(output, target).item()
        self._test_batches += 1

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        if isinstance(batch, list) or isinstance(batch, tuple):
            batch, y = batch
            return self(batch), y
        return self(batch)

    def configure_optimizers(self):
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam(self.parameters(), lr=self.config.lr)
        elif self.config.optimizer == "adamw":
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.lr)
        else:
            raise ValueError(f"Unknown optimizer: {self.config.optimizer}")
        if self.config.use_one_cycle:
            scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.lr, pct_start=0.25,
                                                            total_steps=self.trainer.estimated_stepping_batches)
            opt_config = {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": "step",
                    "frequency": 1,
                }
            }
            return opt_config
        return optimizer

    def configure_callbacks(self):
        callbacks = []
        early_stop = EarlyStopping(
            monitor="val_loss",
            patience=self.config.patience,
            min_delta=self.config.min_delta,
            verbose=True
        )
        callbacks.append(early_stop)
        if self.config.modelOutput is not None:
            checkpoint = ModelCheckpoint(
                dirpath=self.config.modelOutput,
                filename="best_{epoch:02d}-{val_loss:.3f}",
                monitor="val_loss",
                save_top_k=1,
                verbose=True,
                save_last=True
            )
            callbacks.append(checkpoint)
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)
        return callbacks

    def on_validation_epoch_end(self):
        """Compute and log validation metrics at the end of the validation epoch."""
        val_loss = self._val_loss / self._val_batches
        self.log("val_loss", val_loss)
        self._val_loss = 0.0
        self._val_batches = 0

    def on_test_epoch_end(self):
        """Compute and log test metrics at the end of the test epoch."""
        test_loss = self._test_loss / self._test_batches
        self.log("test_loss", test_loss)
        self._test_loss = 0.0
        self._test_batches = 0

    @staticmethod
    def get_act(activation: Literal["relu", "leakyrelu", "mish", "silu", "hardswish", "gelu", "celu", "elu"] = "relu"):
        activation = activation.lower()
        if activation == "relu":
            return torch.nn.ReLU
        elif activation == "leakyrelu":
            return torch.nn.LeakyReLU
        elif activation == "mish":
            return torch.nn.Mish
        elif activation == "silu":  # Swish
            return torch.nn.SiLU
        elif activation == "hardswish":
            return torch.nn.Hardswish
        elif activation == "gelu":
            return torch.nn.GELU
        elif activation == "celu":
            return torch.nn.CELU
        elif activation == "elu":
            return torch.nn.ELU
        else:
            raise ValueError(f"Unknown activation function: {activation}")


class BaseClassifier(BaseLightningModel):
    def __init__(self, criterion, val_criterion, optimizer, lr, patience, min_delta, checkpoint_dir, use_one_cycle):
        super().__init__()
        self.criterion = criterion if criterion is not None else nn.BCEWithLogitsLoss()
        self.val_criterion = val_criterion if val_criterion is not None else criterion
        self.optimizer_name = optimizer
        self.lr = lr
        self.patience = patience
        self.min_delta = min_delta
        self.checkpoint_dir = checkpoint_dir
        self.use_one_cycle = use_one_cycle

        # self.save_hyperparameters()

    def configure_optimizers(self):
        if self.optimizer_name == "adam":
            optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        elif self.optimizer_name == "adamw":
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        else:
            raise ValueError(f"Unknown optimizer: {self.optimizer_name}")
        if self.use_one_cycle:
            scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.lr, pct_start=0.25,
                                                            total_steps=self.trainer.estimated_stepping_batches)
            opt_config = {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": "step",
                    "frequency": 1,
                }
            }
            return opt_config
        return optimizer

    def configure_callbacks(self):
        callbacks = []
        early_stop = EarlyStopping(
            monitor="val_loss",
            patience=self.patience,
            min_delta=self.min_delta,
            verbose=True
        )
        callbacks.append(early_stop)
        if self.checkpoint_dir is not None:
            checkpoint = ModelCheckpoint(
                dirpath=self.checkpoint_dir,
                filename="best_{epoch:02d}-{val_loss:.3f}",
                monitor="val_loss",
                save_top_k=1,
                verbose=True,
                save_last=True
            )
            callbacks.append(checkpoint)
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)
        return callbacks


class XceptionModulePlus(nn.Module):
    def __init__(self, ni, nf, ks=40, kss=None, bottleneck=True, coord=False, separable=True, norm='Batch',
                 bn_1st=True, act=nn.ReLU, act_kwargs=None, norm_act=False):
        super().__init__()
        act_kwargs = {} if act_kwargs is None else act_kwargs
        if kss is None:
            kss = [ks // (2 ** i) for i in range(3)]
        kss = [ksi if ksi % 2 != 0 else ksi - 1 for ksi in kss]  # ensure odd kss for padding='same'
        self.bottleneck = Conv(ni, nf, 1, coord=coord, bias=False) if bottleneck else noop
        self.convs = nn.ModuleList()
        for i in range(len(kss)):
            self.convs.append(Conv(nf if bottleneck else ni, nf, kss[i], coord=coord, separable=separable, bias=False))
        self.mp_conv = nn.Sequential(*[nn.MaxPool1d(3, stride=1, padding=1), Conv(ni, nf, 1, coord=coord, bias=False)])
        self.concat = Concat()
        _norm_act = []
        if act is not None:
            _norm_act.append(act(**act_kwargs))
        _norm_act.append(Norm(nf * 4, norm=norm, zero_norm=False))
        if bn_1st:
            _norm_act.reverse()
        self.norm_act = noop if not norm_act else _norm_act[0] if act is None else nn.Sequential(*_norm_act)

    def forward(self, x):
        input_tensor = x
        x = self.bottleneck(x)
        x = self.concat([l(x) for l in self.convs] + [self.mp_conv(input_tensor)])
        return self.norm_act(x)


class XceptionBlockPlus(nn.Module):
    def __init__(self, ni, nf, residual=True, coord=False, norm='Batch', act=nn.ReLU, act_kwargs=None, dropout=0., **kwargs):
        super().__init__()
        act_kwargs = {} if act_kwargs is None else act_kwargs
        self.residual = residual
        self.xception, self.shortcut, self.act = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        for i in range(4):
            if self.residual and (i - 1) % 2 == 0:
                self.shortcut.append(
                    Norm(n_in, norm=norm) if n_in == n_out else
                    ConvBlock(n_in, n_out * 4 * 2, 1, coord=coord, bias=False, norm=norm, act=None, dropout=dropout)
                )
                self.act.append(act(**act_kwargs))
            n_out = nf * 2 ** i
            n_in = ni if i == 0 else n_out * 2
            self.xception.append(XceptionModulePlus(n_in, n_out, coord=coord, norm=norm,
                                                    act=act if self.residual and (i - 1) % 2 == 0 else None, **kwargs))
        self.add = Add()

    def forward(self, x):
        res = x
        for i in range(4):
            x = self.xception[i](x)
            if self.residual and (i + 1) % 2 == 0:
                res = x = self.act[i // 2](self.add(x, self.shortcut[i // 2](res)))
        return x


class XceptionTimePlus(BaseClassifier):
    def __init__(
            self,
            c_in,
            c_out,
            nf=16,
            coord=False,
            norm="Batch",
            concat_pool=False,
            adaptive_size=48,
            activation="relu",
            dropout=0.0,
            ks=40,
            bottleneck=True,
            bn_1st=True,
            norm_act=False,
            width=16,
            input_norm=True,
            window_size=None,
            norm_weights=False,
            # BaseClassifier params
            criterion=None,
            val_criterion=None,
            optimizer="adam",
            lr=1e-3,
            patience=5,
            min_delta=0.0,
            checkpoint_dir=None,
            use_one_cycle=False
    ):
        super().__init__(criterion, val_criterion, optimizer, lr, patience, min_delta, checkpoint_dir, use_one_cycle)

        # params
        act = self.get_act(activation)  # nn.ReLU

        # input standardization
        if input_norm:
            assert window_size is not None, "window_size must be provided if input_norm is True"
            self.in_norm = nn.LayerNorm([c_in, window_size], elementwise_affine=norm_weights, bias=norm_weights)
        else:
            self.in_norm = nn.Identity()

        # Backbone
        self.backbone = XceptionBlockPlus(c_in, nf, coord=coord, norm=norm, act=act, dropout=dropout, ks=ks,
                                          bottleneck=bottleneck, bn_1st=bn_1st, norm_act=norm_act)
        # Head
        gap1 = AdaptiveConcatPool1d(adaptive_size) if adaptive_size and concat_pool else nn.AdaptiveAvgPool1d(adaptive_size) if adaptive_size else noop
        mult = 2 if adaptive_size and concat_pool else 1
        conv1x1_1 = ConvBlock(nf * 32 * mult, nf * width * mult, 1, coord=coord, norm=norm)
        conv1x1_2 = ConvBlock(nf * width * mult, nf * width // 2 * mult, 1, coord=coord, norm=norm)
        conv1x1_3 = ConvBlock(nf * width // 2 * mult, c_out, 1, coord=coord, norm=norm)
        gap2 = GAP1d(1)
        lin = nn.Linear(c_out, c_out)  # Added by me to avoid ReLU preventing negative values
        self.head = nn.Sequential(gap1, conv1x1_1, conv1x1_2, conv1x1_3, gap2, lin)

    def forward(self, x):
        x = self.in_norm(x)
        x = self.backbone(x)
        x = self.head(x)
        return x


class TSDataset(Dataset):
    def __init__(
            self,
            X: np.ndarray,
            y: np.ndarray = None,
            window_size: int = 1,
            stride: int = 1,
    ):

        self.window_size = window_size
        self.stride = stride
        self.number_of_windows = ((len(X) - window_size) // stride) + 1
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32).squeeze() if y is not None else None

    def __len__(self):
        return self.number_of_windows

    def __getitem__(self, idx):
        # get window
        start_idx = idx * self.stride
        end_idx = start_idx + self.window_size

        # apply instance normalization
        window = self.X[start_idx:end_idx]  # (window_size, n_features)

        # change to channels first format
        window = window.permute(1, 0)

        # get label and return
        if self.y is not None:
            return window, self.y[end_idx - 1]
        return window


def get_weighted_sampler(dataset: TSDataset, num_samples: int):
    global_labels = dataset.y.numpy().max(axis=1).astype(int)
    unique, counts = np.unique(global_labels, return_counts=True)
    class_weights = 1.0 / counts
    sample_weights = class_weights[global_labels]
    sample_weights = sample_weights[dataset.window_size - 1::dataset.stride]
    return WeightedRandomSampler(sample_weights, num_samples, replacement=True)


def get_class_weights(labels: np.ndarray):
    pos_counts = labels.sum(axis=0)
    neg_counts = len(labels) - pos_counts
    class_weights = neg_counts / pos_counts
    print("Class weights:", class_weights)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.tensor(class_weights, dtype=torch.float32, device=device)

In [3]:
# num_samples_per_epoch = 2_097_152  # 2048 steps
epochs = 10
batch_size = 1024
accumulate_grad_batches = 2  # virtual batch size of 2048
split = 264_960
file_dir = "/kaggle/working"

seed = 42
in_channels = 6
out_channels = 6
eval_steps = 1024  # val_check_interval

In [4]:
df_train = pd.read_csv("/kaggle/input/esa-mission-1-train-dataset/84_months.train.csv")

target_channels = [f"channel_{i}" for i in range(41, 47)]
label_cols = ["is_anomaly_" + col for col in target_channels]

train_array = df_train[target_channels].iloc[:-split].values.astype(np.float32)
val_array = df_train[target_channels].iloc[-split:].values.astype(np.float32)
train_labels = df_train[label_cols].iloc[:-split].values.astype(np.float32).clip(0., 1.)
val_labels = df_train[label_cols].iloc[-split:].values.astype(np.float32).clip(0., 1.)

class_weights = get_class_weights(train_labels)

train_array.shape, train_labels.shape, val_array.shape, val_labels.shape

Class weights: [75.98615178 76.00953507 75.47281678 75.46128577 76.02624612 75.9719617 ]


((7099201, 6), (7099201, 6), (264960, 6), (264960, 6))

In [5]:
# sample model configuration
window_size = 56
nf = 4
coord = True
norm = "Instance"
concat_pool = False
adaptive_size = 20
activation = "relu"
dropout = 0.3
ks = 8
bottleneck = True
bn_1st = False
norm_act = True
width = 6
input_norm = False
norm_weights = False

# set seed
L.seed_everything(seed=seed, verbose=False)

# setup datasets
train_dataset = TSDataset(train_array, train_labels, window_size=window_size)
val_dataset = TSDataset(val_array, val_labels, window_size=window_size)

# train_sampler = get_weighted_sampler(train_dataset, num_samples=num_samples_per_epoch)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False, sampler=None)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

# initialize model
model = XceptionTimePlus(
    c_in=in_channels,
    c_out=out_channels,
    nf=nf,
    coord=coord,
    norm=norm,
    concat_pool=concat_pool,
    adaptive_size=adaptive_size,
    activation=activation,
    dropout=dropout,
    ks=ks,
    bottleneck=bottleneck,
    bn_1st=bn_1st,
    norm_act=norm_act,
    width=width,
    input_norm=input_norm,
    window_size=window_size,
    norm_weights=norm_weights,
    # fixed BaseClassifier params
    criterion=nn.BCEWithLogitsLoss(pos_weight=class_weights),
    val_criterion=nn.BCEWithLogitsLoss(pos_weight=class_weights),
    optimizer="adam",
    lr=1e-3,
    patience=40,
    min_delta=0.,
    checkpoint_dir=file_dir,
    use_one_cycle=True
).cuda()

# calculate MACs and model parameters
flops, macs, num_params = calculate_flops(
    model=model,
    input_shape=(1, in_channels, window_size),
    print_results=True,
    print_detailed=False,
    output_as_string=False,
    include_backPropagation=False,
)

# Setup trainer
# torch.set_float32_matmul_precision("high")

trainer = L.Trainer(
    max_epochs=epochs,
    max_time="00:05:00:00",
    log_every_n_steps=16,
    accumulate_grad_batches=accumulate_grad_batches,
    logger=TensorBoardLogger(save_dir=file_dir, name="lightning_logs"),
    val_check_interval=eval_steps,
    enable_model_summary=True,
    enable_progress_bar=False,
)

# train model
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint



------------------------------------- Calculate Flops Results -------------------------------------
Notations:
number of parameters (Params), number of multiply-accumulate operations(MACs),
number of floating-point operations (FLOPs), floating-point operations per second (FLOPS),
fwd FLOPs (model forward propagation FLOPs), bwd FLOPs (model backward propagation FLOPs),
default model backpropagation takes 2.00 times as much computation as forward propagation.

Total Training Params:                                                  18.88 K 
fwd MACs:                                                               907.05 KMACs
fwd FLOPs:                                                              1.96 MFLOPS
fwd+bwd MACs:                                                           2.72 MMACs
fwd+bwd FLOPs:                                                          5.88 MFLOPS
---------------------------------------------------------------------------------------------------


/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /kaggle/working exists and is not empty.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loading `train_dataloader` to estimate number of stepping batches.
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
INFO: 
  | Name          | Type              | Params | Mode
-----------------------------------------------------------
0 | criterion     | BCEWithLogitsLoss | 0      | eval
1 | val_criterion | BCEWithLogitsLoss | 0      | eval
2 | in_norm       | Identity          | 0      | eval
3 | backbone      | XceptionBlockPlus | 15.3 K | eval
4 | head          | Sequential        | 3.6 K  | eval
--------------------------------