# TabM

This is a standalone usage example for the TabM project.
The easiest way to run it is [Pixi](https://pixi.sh/latest/#installation):

```shell
git clone https://github.com/yandex-research/tabm
cd tabm

# With GPU:
pixi run -e cuda jupyter-lab example.ipynb

# Without GPU:
pixi run jupyter-lab example.ipynb
```

For the full overview of the project, and for non-Pixi environment setups, see README in the repository:
https://github.com/yandex-research/tabm

In [1]:
# ruff: noqa: E402
import math
import random
import warnings
from typing import Literal, NamedTuple

import numpy as np
import rtdl_num_embeddings  # https://github.com/yandex-research/rtdl-num-embeddings
import scipy.special
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection
import sklearn.preprocessing
import torch
import torch.nn.functional as F
import torch.optim
from torch import Tensor
from tqdm.std import tqdm
from torch.utils.data import Dataset
import pandas as pd


warnings.simplefilter('ignore')
from tabm_reference import Model, make_parameter_groups

warnings.resetwarnings()

In [2]:
seed = 0
random.seed(seed)
np.random.seed(seed + 1)
torch.manual_seed(seed + 2)
pass

# Dataset

In [3]:
CAT_TO_NUM_LABELS = {
    "Normal traffic": 0,
    "Suspicious traffic": 1,
    "DDoS attack": 2,
}

class DDoSDataset(Dataset):
    def __init__(self, split):
        self.train_data_paths = [f'/home/appuser/data/train/SCLDDoS2024_SetA_events_extended.csv',
                                 f'/home/appuser/data/train/SCLDDoS2024_SetB_events_extended.csv']
        self.test_data_paths = [f'/home/appuser/data/test/SCLDDoS2024_SetC_events_extended.csv']     
        
        self.split = split   
        
        if split == 'train':
            self.features, self.lables = self.load_data(self.train_data_paths, apply_smote=False)
        elif split == 'test':
            self.features, self.lables = self.load_data(self.test_data_paths, apply_smote=False)
        else:
            print("Invalid split. Use 'train' or 'test'")
            
    
    def get_ports(self):
        return self.ddos_ports
    
    
    def get_data(self):
        return self.features.numpy(), self.lables.numpy()
        
        
    # preload the data as it makes the training much faster (and it easily fits in memory)
    def load_data(self, data_paths, apply_smote=False, undersample=False, sample_factor=4):
        data = []
        for path in data_paths:
            data.append(pd.read_csv(path).fillna(0))  # Read and fill NaNs with 0s
            
        df = pd.concat(data, ignore_index=True)  # Combine all dataframes
        
        self.ddos_ports = df[df['Type'] == "DDoS attack"]["Port number"].unique()
        
        feature_columns = df.columns[:19]  # All except the last column
        #feature_columns = df.columns[:-1]
        label_column = df.columns[-1]  # The last column
        
        # Convert categorical labels to numeric using the dictionary
        df[label_column] = df[label_column].map(CAT_TO_NUM_LABELS)
        
        # Check for missing or unknown labels
        if df[label_column].isna().any():
            print(df[label_column].isna().sum(), "missing labels")

        features = df[feature_columns].values
        labels = df[label_column].values
        
        # Normalize the features
        #features = self.normalize(features)
        
        # Convert to PyTorch tensors
        features = torch.tensor(features, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.long)  # Classification requires long dtype
        
        
        
        if undersample and self.split == 'train':
            # Undersample the majority class (label=0)
            class_0_indices = np.where(labels.cpu().numpy() == 0)[0]
            class_1_indices = np.where(labels.cpu().numpy() == 1)[0]
            class_2_indices = np.where(labels.cpu().numpy() == 2)[0]

            # Randomly undersample the majority class
            num_class_0_samples = sample_factor*(len(class_1_indices) + len(class_2_indices))  # Same number as the minority class
            class_0_indices_undersampled = np.random.choice(class_0_indices, num_class_0_samples, replace=False)

            # Concatenate indices of class 1, 2, and undersampled class 0
            undersampled_indices = np.concatenate([class_0_indices_undersampled, class_1_indices, class_2_indices])

            # Subset the dataset to include only the sampled indices
            features = features[undersampled_indices]
            labels = labels[undersampled_indices]

        return features, labels


    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.lables[idx]

In [4]:
# >>> Dataset.
TaskType = Literal['regression', 'binclass', 'multiclass']

# Regression.
task_type: TaskType = 'multiclass'
n_classes = None
dataset = sklearn.datasets.fetch_california_housing()
X_cont: np.ndarray = dataset['data']
Y: np.ndarray = dataset['target']

dataset = DDoSDataset(split='train')
X_cont = dataset.features.numpy()
Y = dataset.lables.numpy()


# Classification.
n_classes = 3
assert n_classes >= 2
task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

#task_is_regression = task_type == 'regression'

# >>> Continuous features.
X_cont: np.ndarray = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]

# # >>> Categorical features.
# # NOTE: the above datasets do not have categorical features, however,
# # for the demonstration purposes, it is possible to generate them.
# cat_cardinalities = [
#     # NOTE: uncomment the two lines below to add two categorical features.
#     # 4,  # Allowed values: [0, 1, 2, 3].
#     # 7,  # Allowed values: [0, 1, 2, 3, 4, 5, 6].
# ]
# X_cat = (
#     np.column_stack(
#         [np.random.randint(0, c, (len(X_cont),)) for c in cat_cardinalities]
#     )
#     if cat_cardinalities
#     else None
# )

# >>> Labels.
if task_type == 'regression':
    Y = Y.astype(np.float32)
else:
    assert n_classes is not None
    Y = Y.astype(np.int64)
    assert set(Y.tolist()) == set(
        range(n_classes)
    ), 'Classification labels must form the range [0, 1, ..., n_classes - 1]'

# >>> Split the dataset.
all_idx = np.arange(len(Y))
train_idx, val_idx = sklearn.model_selection.train_test_split(
    all_idx, train_size=0.8
)

data_numpy = {
    'train': {'x_cont': X_cont[train_idx], 'y': Y[train_idx]},
    'val': {'x_cont': X_cont[val_idx], 'y': Y[val_idx]},
}

s = 0
# if X_cat is not None:
#     data_numpy['train']['x_cat'] = X_cat[train_idx]
#     data_numpy['val']['x_cat'] = X_cat[val_idx]
#     data_numpy['test']['x_cat'] = X_cat[test_idx]

In [5]:
# Add the Dataset C as the testing class
dataset = DDoSDataset(split='test')
X_cont = dataset.features.numpy()
Y = dataset.lables.numpy()


# Classification.
n_classes = 3
assert n_classes >= 2
task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

#task_is_regression = task_type == 'regression'

# >>> Continuous features.
X_cont: np.ndarray = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]

# >>> Labels.
if task_type == 'regression':
    Y = Y.astype(np.float32)
else:
    assert n_classes is not None
    Y = Y.astype(np.int64)
    assert set(Y.tolist()) == set(
        range(n_classes)
    ), 'Classification labels must form the range [0, 1, ..., n_classes - 1]'
    
    
data_numpy['test'] = {'x_cont': X_cont, 'y': Y}

# Data preprocessing

In [6]:
# Feature preprocessing.
# NOTE
# The choice between preprocessing strategies depends on a task and a model.

# Simple preprocessing strategy.
# preprocessing = sklearn.preprocessing.StandardScaler().fit(
#     data_numpy['train']['x_cont']
# )

# Advanced preprocessing strategy.
# The noise is added to improve the output of QuantileTransformer in some cases.
X_cont_train_numpy = data_numpy['train']['x_cont']
noise = (
    np.random.default_rng(0)
    .normal(0.0, 1e-5, X_cont_train_numpy.shape)
    .astype(X_cont_train_numpy.dtype)
)
preprocessing = sklearn.preprocessing.QuantileTransformer(
    n_quantiles=max(min(len(train_idx) // 30, 1000), 10),
    output_distribution='normal',
    subsample=10**9,
).fit(X_cont_train_numpy + noise)
del X_cont_train_numpy

# Apply the preprocessing.
for part in data_numpy:
    data_numpy[part]['x_cont'] = preprocessing.transform(data_numpy[part]['x_cont'])


# Label preprocessing.
class RegressionLabelStats(NamedTuple):
    mean: float
    std: float


Y_train = data_numpy['train']['y'].copy()
if task_type == 'regression':
    # For regression tasks, it is highly recommended to standardize the training labels.
    regression_label_stats = RegressionLabelStats(
        Y_train.mean().item(), Y_train.std().item()
    )
    Y_train = (Y_train - regression_label_stats.mean) / regression_label_stats.std
else:
    regression_label_stats = None

#  PyTorch settings

In [7]:
# Device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Convert data to tensors
data = {
    part: {k: torch.as_tensor(v, device=device) for k, v in data_numpy[part].items()}
    for part in data_numpy
}
Y_train = torch.as_tensor(Y_train, device=device)
if task_type == 'regression':
    for part in data:
        data[part]['y'] = data[part]['y'].float()
    Y_train = Y_train.float()

# Automatic mixed precision (AMP)
# torch.float16 is implemented for completeness,
# but it was not tested in the project,
# so torch.bfloat16 is used by default.
amp_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
    if torch.cuda.is_available()
    else None
)
# Changing False to True will result in faster training on compatible hardware.
amp_enabled = False and amp_dtype is not None
grad_scaler = torch.cuda.amp.GradScaler() if amp_dtype is torch.float16 else None  # type: ignore

# torch.compile
compile_model = False

# fmt: off
print(
    f'Device:        {device.type.upper()}'
    f'\nAMP:           {amp_enabled} (dtype: {amp_dtype})'
    f'\ntorch.compile: {compile_model}'
)
# fmt: on

Device:        CUDA
AMP:           False (dtype: torch.bfloat16)
torch.compile: False


# Model

In [17]:
# Choose one of the two configurations below.

# TabM
arch_type = 'tabm'
bins = None

# TabM-mini with the piecewise-linear embeddings.
# arch_type = 'tabm-mini'
# bins = rtdl_num_embeddings.compute_bins(data['train']['x_cont'])

# arch_type = 'tabm-packed'
# bins = rtdl_num_embeddings.compute_bins(data['train']['x_cont'])

# d_block: 512
# n_blocks: 3

model = Model(
    n_num_features=n_cont_features,
    cat_cardinalities=[],
    n_classes=n_classes,
    backbone={
        'type': 'MLP',
        'n_blocks': 3 if bins is None else 2,
        'd_block': 256,
        'dropout': 0.2,
        'n_blocks': 5
    },
    bins=bins,
    num_embeddings=(
        None
        if bins is None
        else {
            'type': 'PiecewiseLinearEmbeddings',
            'd_embedding': 32,
            'activation': False,
            'version': 'B',
        }
    ),
    arch_type=arch_type,
    k=48,
    share_training_batches=True,
).to(device)
optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=1e-3, weight_decay=3e-4)

if compile_model:
    # NOTE
    # `torch.compile` is intentionally called without the `mode` argument
    # (mode="reduce-overhead" caused issues during training with torch==2.0.1).
    model = torch.compile(model)
    evaluation_mode = torch.no_grad
else:
    evaluation_mode = torch.inference_mode

In [18]:
@torch.autocast(device.type, enabled=amp_enabled, dtype=amp_dtype)  # type: ignore[code]
def apply_model(part: str, idx: Tensor) -> Tensor:
    return (
        model(
            data[part]['x_cont'][idx],
            data[part]['x_cat'][idx] if 'x_cat' in data[part] else None,
        )
        .squeeze(-1)  # Remove the last dimension for regression tasks.
        .float()
    )


base_loss_fn = F.mse_loss if task_type == 'regression' else F.cross_entropy
weight = torch.tensor([1.0, 1.0, 3.0], device=device)


def loss_fn(y_pred: Tensor, y_true: Tensor) -> Tensor:
    # TabM produces k predictions. Each of them must be trained separately.
    # (regression)     y_pred.shape == (batch_size, k)
    # (classification) y_pred.shape == (batch_size, k, n_classes)
    k = y_pred.shape[-1 if task_type == 'regression' else -2]
    return base_loss_fn(
        y_pred.flatten(0, 1),
        y_true.repeat_interleave(k) if model.share_training_batches else y_true,
        weight=weight,
    )


@evaluation_mode()
def evaluate(part: str) -> float:
    model.eval()

    # When using torch.compile, you may need to reduce the evaluation batch size.
    eval_batch_size = 8096
    y_pred: np.ndarray = (
        torch.cat(
            [
                apply_model(part, idx)
                for idx in torch.arange(len(data[part]['y']), device=device).split(
                    eval_batch_size
                )
            ]
        )
        .cpu()
        .numpy()
    )
    if task_type == 'regression':
        # Transform the predictions back to the original label space.
        assert regression_label_stats is not None
        y_pred = y_pred * regression_label_stats.std + regression_label_stats.mean

    # Compute the mean of the k predictions.
    if task_type != 'regression':
        # For classification, the mean must be computed in the probabily space.
        y_pred = scipy.special.softmax(y_pred, axis=-1)
    y_pred = y_pred.mean(1)

    y_true = data[part]['y'].cpu().numpy()
    score = (
        -(sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5)
        if task_type == 'regression'
        else sklearn.metrics.f1_score(y_true, y_pred.argmax(1), average='macro')
    )
    return float(score)  # The higher -- the better.


print(f'Test score before training: {evaluate("test"):.4f}')

Test score before training: 0.0153


# Training

In [19]:
# For demonstration purposes (fast training and bad performance),
# one can set smaller values:
# n_epochs = 20
# patience = 2
n_epochs = 200

train_size = len(train_idx)
batch_size = 256
epoch_size = math.ceil(train_size / batch_size)
best = {
    'val': -math.inf,
    'test': -math.inf,
    'epoch': -1,
}
# Early stopping: the training stops when
# there are more than `patience` consequtive bad updates.
patience = 200
remaining_patience = patience



print('-' * 88 + '\n')
for epoch in range(n_epochs):
    batches = (
        torch.randperm(train_size, device=device).split(batch_size)
        if model.share_training_batches
        else [
            x.transpose(0, 1).flatten()
            for x in torch.rand((model.k, train_size), device=device)
            .argsort(dim=1)
            .split(batch_size, dim=1)
        ]
    )
    for batch_idx in tqdm(batches, desc=f'Epoch {epoch}'):
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(apply_model('train', batch_idx), Y_train[batch_idx])
        if grad_scaler is None:
            loss.backward()
            optimizer.step()
        else:
            grad_scaler.scale(loss).backward()  # type: ignore
            grad_scaler.step(optimizer)
            grad_scaler.update()

    val_score = evaluate('val')
    test_score = evaluate('test')
    print(f'(val) {val_score:.4f} (test) {test_score:.4f}')

    if test_score > best['test']:
        print('🌸 New best epoch! 🌸')
        best = {'val': val_score, 'test': test_score, 'epoch': epoch}
        remaining_patience = patience
        torch.save(model.state_dict(), 'tabm_no_frequ_features_extended_more_weight.pth')
    else:
        remaining_patience -= 1

    if remaining_patience < 0:
        break

    print()


print('\n\nResult:')
print(best)

----------------------------------------------------------------------------------------



Epoch 0: 100%|██████████| 828/828 [00:06<00:00, 120.31it/s]


(val) 0.7652 (test) 0.7199
🌸 New best epoch! 🌸



Epoch 1: 100%|██████████| 828/828 [00:06<00:00, 120.84it/s]


(val) 0.7554 (test) 0.7836
🌸 New best epoch! 🌸



Epoch 2: 100%|██████████| 828/828 [00:06<00:00, 120.64it/s]


(val) 0.8198 (test) 0.7800



Epoch 3: 100%|██████████| 828/828 [00:06<00:00, 119.02it/s]


(val) 0.8349 (test) 0.7723



Epoch 4: 100%|██████████| 828/828 [00:06<00:00, 120.45it/s]


(val) 0.8302 (test) 0.7797



Epoch 5: 100%|██████████| 828/828 [00:06<00:00, 118.96it/s]


(val) 0.8235 (test) 0.7903
🌸 New best epoch! 🌸



Epoch 6: 100%|██████████| 828/828 [00:06<00:00, 120.84it/s]


(val) 0.8212 (test) 0.8021
🌸 New best epoch! 🌸



Epoch 7: 100%|██████████| 828/828 [00:06<00:00, 120.86it/s]


(val) 0.8373 (test) 0.7889



Epoch 8: 100%|██████████| 828/828 [00:06<00:00, 120.25it/s]


(val) 0.8407 (test) 0.7883



Epoch 9: 100%|██████████| 828/828 [00:06<00:00, 120.89it/s]


(val) 0.8351 (test) 0.8080
🌸 New best epoch! 🌸



Epoch 10: 100%|██████████| 828/828 [00:06<00:00, 120.67it/s]


(val) 0.8355 (test) 0.8004



Epoch 11: 100%|██████████| 828/828 [00:06<00:00, 120.84it/s]


(val) 0.8449 (test) 0.7764



Epoch 12: 100%|██████████| 828/828 [00:06<00:00, 120.92it/s]


(val) 0.8394 (test) 0.7903



Epoch 13: 100%|██████████| 828/828 [00:06<00:00, 120.27it/s]


(val) 0.8363 (test) 0.7961



Epoch 14: 100%|██████████| 828/828 [00:06<00:00, 120.96it/s]


(val) 0.8412 (test) 0.8304
🌸 New best epoch! 🌸



Epoch 15: 100%|██████████| 828/828 [00:06<00:00, 120.55it/s]


(val) 0.8402 (test) 0.8203



Epoch 16: 100%|██████████| 828/828 [00:06<00:00, 118.83it/s]


(val) 0.8471 (test) 0.8020



Epoch 17: 100%|██████████| 828/828 [00:06<00:00, 119.69it/s]


(val) 0.8439 (test) 0.8121



Epoch 18: 100%|██████████| 828/828 [00:06<00:00, 120.21it/s]


(val) 0.8411 (test) 0.8225



Epoch 19: 100%|██████████| 828/828 [00:06<00:00, 120.02it/s]


(val) 0.8474 (test) 0.8266



Epoch 20: 100%|██████████| 828/828 [00:06<00:00, 119.78it/s]


(val) 0.8389 (test) 0.7732



Epoch 21: 100%|██████████| 828/828 [00:06<00:00, 120.36it/s]


(val) 0.8465 (test) 0.8435
🌸 New best epoch! 🌸



Epoch 22: 100%|██████████| 828/828 [00:06<00:00, 119.30it/s]


(val) 0.8425 (test) 0.7917



Epoch 23: 100%|██████████| 828/828 [00:06<00:00, 120.75it/s]


(val) 0.8481 (test) 0.8238



Epoch 24: 100%|██████████| 828/828 [00:06<00:00, 118.38it/s]


(val) 0.8456 (test) 0.8211



Epoch 25: 100%|██████████| 828/828 [00:06<00:00, 120.48it/s]


(val) 0.8493 (test) 0.8361



Epoch 26: 100%|██████████| 828/828 [00:06<00:00, 120.42it/s]


(val) 0.8365 (test) 0.8260



Epoch 27: 100%|██████████| 828/828 [00:06<00:00, 120.80it/s]


(val) 0.8491 (test) 0.8426



Epoch 28: 100%|██████████| 828/828 [00:06<00:00, 118.42it/s]


(val) 0.8477 (test) 0.8408



Epoch 29: 100%|██████████| 828/828 [00:06<00:00, 118.56it/s]


(val) 0.8535 (test) 0.8320



Epoch 30: 100%|██████████| 828/828 [00:06<00:00, 120.58it/s]


(val) 0.8514 (test) 0.8251



Epoch 31: 100%|██████████| 828/828 [00:06<00:00, 120.99it/s]


(val) 0.8521 (test) 0.8328



Epoch 32: 100%|██████████| 828/828 [00:06<00:00, 119.25it/s]


(val) 0.8509 (test) 0.8359



Epoch 33: 100%|██████████| 828/828 [00:06<00:00, 121.27it/s]


(val) 0.8468 (test) 0.8525
🌸 New best epoch! 🌸



Epoch 34: 100%|██████████| 828/828 [00:06<00:00, 121.17it/s]


(val) 0.8479 (test) 0.7500



Epoch 35: 100%|██████████| 828/828 [00:06<00:00, 120.61it/s]


(val) 0.8470 (test) 0.7696



Epoch 36: 100%|██████████| 828/828 [00:06<00:00, 120.61it/s]


(val) 0.8558 (test) 0.8401



Epoch 37: 100%|██████████| 828/828 [00:06<00:00, 120.89it/s]


(val) 0.8482 (test) 0.8548
🌸 New best epoch! 🌸



Epoch 38: 100%|██████████| 828/828 [00:06<00:00, 120.80it/s]


(val) 0.8502 (test) 0.8191



Epoch 39: 100%|██████████| 828/828 [00:06<00:00, 120.91it/s]


(val) 0.8468 (test) 0.7993



Epoch 40: 100%|██████████| 828/828 [00:06<00:00, 119.88it/s]


(val) 0.8453 (test) 0.7851



Epoch 41:  57%|█████▋    | 476/828 [00:03<00:02, 121.37it/s]


KeyboardInterrupt: 

In [13]:
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score
)

In [15]:
# Define the model
arch_type = 'tabm'
bins = None

arch_type = 'tabm-mini'
bins = rtdl_num_embeddings.compute_bins(data['train']['x_cont'])

model = Model(
    n_num_features=n_cont_features,
    cat_cardinalities=[],
    n_classes=n_classes,
    backbone={
        'type': 'MLP',
        'n_blocks': 3 if bins is None else 2,
        'd_block': 512,
        'dropout': 0.1,
    },
    bins=bins,
    num_embeddings=(
        None
        if bins is None
        else {
            'type': 'PiecewiseLinearEmbeddings',
            'd_embedding': 16,
            'activation': False,
            'version': 'B',
        }
    ),
    arch_type=arch_type,
    k=32,
    share_training_batches=True,
).to(device)

In [16]:
@torch.autocast(device.type, enabled=amp_enabled, dtype=amp_dtype)  # type: ignore[code]
def apply_model(part: str, idx: Tensor) -> Tensor:
    return (
        model(
            data[part]['x_cont'][idx],
            data[part]['x_cat'][idx] if 'x_cat' in data[part] else None,
        )
        .squeeze(-1)  # Remove the last dimension for regression tasks.
        .float()
    )

In [17]:
# Inference on the test dataset
model.load_state_dict(torch.load('m_ft_transformer_model.pth'))
model.eval()

part = "test"

eval_batch_size = 12
# y_pred: np.ndarray = (
#     torch.cat(
#         [
#             apply_model(part, idx).cpu()
#             for idx in torch.arange(len(data[part]['y']), device=device).split(
#                 eval_batch_size
#             )
#         ]
#     )
#     .numpy()
# )

y_pred_list = []
for idx in tqdm(torch.arange(len(data[part]['y']), device=device).split(eval_batch_size)):
    with torch.no_grad():
        preds = apply_model(part, idx).cpu()
        probs = scipy.special.softmax(preds.numpy(), axis=-1)
        averaged = probs.mean(1)  # shape: [B, C]
        preds_class = np.argmax(averaged, axis=1)
        y_pred_list.append(preds_class)

y_pred = np.concatenate(y_pred_list)

y_test = data[part]['y'].cpu().numpy()

# Overall accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.4f}")

# Classification report (includes precision, recall, F1 per class + macro/micro)
report = classification_report(y_test, y_pred, output_dict=True)
report_df = pd.DataFrame(report).transpose()
print("\nClassification Report:\n", report_df)

# F1 scores
f1_micro = f1_score(y_test, y_pred, average='micro')
f1_macro = f1_score(y_test, y_pred, average='macro')
print(f"\nF1 (Micro): {f1_micro:.4f}")
print(f"F1 (Macro): {f1_macro:.4f}")

# Class-wise accuracy (same as recall per class)
class_wise_accuracy = report_df.loc[[str(i) for i in np.unique(y_test)], "recall"]
print("\nClass-wise Accuracy (Recall):\n", class_wise_accuracy)

# Confusion Matrix
conf_matrix = confusion_matrix(y_test, y_pred)
print("\nConfusion Matrix:\n", conf_matrix)

100%|██████████| 10834/10834 [00:06<00:00, 1656.53it/s]


Accuracy: 0.9917

Classification Report:
               precision    recall  f1-score        support
0              0.995075  0.996727  0.995901  125892.000000
1              0.887179  0.963630  0.923826    3052.000000
2              0.861063  0.475829  0.612943    1055.000000
accuracy       0.991723  0.991723  0.991723       0.991723
macro avg      0.914439  0.812062  0.844223  129999.000000
weighted avg   0.991455  0.991723  0.991101  129999.000000

F1 (Micro): 0.9917
F1 (Macro): 0.8442

Class-wise Accuracy (Recall):
 0    0.996727
1    0.963630
2    0.475829
Name: recall, dtype: float64

Confusion Matrix:
 [[125480    350     62]
 [    92   2941     19]
 [   529     24    502]]
