# TabM

This is a standalone usage example for the TabM project.
The easiest way to run it is [Pixi](https://pixi.sh/latest/#installation):

```shell
git clone https://github.com/yandex-research/tabm
cd tabm

# With GPU:
pixi run -e cuda jupyter-lab example.ipynb

# Without GPU:
pixi run jupyter-lab example.ipynb
```

For the full overview of the project, and for non-Pixi environment setups, see README in the repository:
https://github.com/yandex-research/tabm

In [1]:
# ruff: noqa: E402
import math
import random
import warnings
from typing import Literal, NamedTuple

import os
import numpy as np
import rtdl_num_embeddings  # https://github.com/yandex-research/rtdl-num-embeddings
import scipy.special
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection
import sklearn.preprocessing
import torch
import torch.nn.functional as F
import torch.optim
from torch import Tensor
from tqdm.std import tqdm
from torch.utils.data import Dataset
import pandas as pd


warnings.simplefilter('ignore')
from tabm_reference import Model, make_parameter_groups

warnings.resetwarnings()

In [2]:
seed = 0
random.seed(seed)
np.random.seed(seed + 1)
torch.manual_seed(seed + 2)
pass

# Dataset

In [5]:
CAT_TO_NUM_LABELS = {
    "Normal traffic": 0,
    "Suspicious traffic": 1,
    "DDoS attack": 2,
}

component_columns = [
    "Attack ID", "Detect count", "Card", "Victim IP", "Port number",
    "Attack code", "Significant flag", "Packet speed", "Data speed", "Avg packet len",
    "Source IP count", "Time"
]

event_columns = [
    "Attack ID", "Card", "Victim IP", "Port number", "Attack code", 
    "Detect count", "Significant flag", "Packet speed", "Data speed", 
    "Avg packet len", "Avg source IP count", "Start time", "End time", 
    "Whitelist flag", "Type"
]

class DDoSDataset(Dataset):
    def __init__(self, split):
        self.train_data_paths = [f'/home/appuser/data/train/SCLDDoS2024_SetA_events_extended.csv',
                                 f'/home/appuser/data/train/SCLDDoS2024_SetB_events_extended.csv']
        self.test_data_paths = [f'/home/appuser/data/test/SCLDDoS2024_SetC_events_extended.csv']     
        
        self.split = split   
        
        if split == 'train':
            self.features, self.lables = self.load_data(self.train_data_paths, apply_smote=False)
        elif split == 'test':
            self.features, self.lables = self.load_data(self.test_data_paths, apply_smote=False)
        else:
            print("Invalid split. Use 'train' or 'test'")
            
    
    def get_ports(self):
        return self.ddos_ports
    
    
    def get_data(self):
        return self.features.numpy(), self.lables.numpy()
    
    def engineer_features_from_components(self, df_components):

        grouped = df_components.groupby('Attack ID')

        features = pd.DataFrame()

        features['Unique Ports'] = grouped['Port number'].nunique()
        features['Unique Victim IPs'] = grouped['Victim IP'].nunique()

        return features.reset_index()
    
    # ChatGPT-version
    def add_protocol_columns(self, df):
        # DNS: Port 53
        df['DNS'] = df['Port number'].apply(lambda x: 1 if x == 53 else 0)
        
        # RDP: Port 3389
        df['RDP'] = df['Port number'].apply(lambda x: 1 if x == 3389 else 0)
        
        # TCP: Ports 80, 443, 21, 22
        df['TCP'] = df['Port number'].apply(lambda x: 1 if x in [80, 443, 21, 22, 8080, 8001] else 0)
        
        # SYN: Ports 80, 443
        df['SYN'] = df['Port number'].apply(lambda x: 1 if x in [80, 443] else 0)
        
        # UDP: Ports 53, 123, 161, 162, 69
        df['UDP'] = df['Port number'].apply(lambda x: 1 if x in [53, 123, 161, 162, 69, 123] else 0)
        
        # CoAP: Port 5683
        df['CoAP'] = df['Port number'].apply(lambda x: 1 if x == 5683 else 0)

        # Additional ports that frequently appear in DDoS attacks
        df['Attack Ports'] = df['Port number'].apply(lambda x: 1 if x in [51822, 7777, 1900, 9987, 8080, 7547, 7010, 7007, 2301] else 0)

        return df
        
    # def add_protocol_columns(self, df):
    #     ''' UDP-based protocols '''
    #     # DNS: Port 53
    #     df['DNS'] = df['Port number'].apply(lambda x: 1 if x in [53, 5353] else 0)
        
    #     # NTP: Port 123
    #     df['NTP'] = df['Port number'].apply(lambda x: 1 if x == 123 else 0)
        
    #     # SNMP: Ports 161, 162
    #     df['SNMP'] = df['Port number'].apply(lambda x: 1 if x in [161, 162] else 0)
        
    #     # SSDP: Port 1900
    #     df['SSDP'] = df['Port number'].apply(lambda x: 1 if x == 1900 else 0)
        
    #     # CLDAP: Port 389
    #     df['CLDAP'] = df['Port number'].apply(lambda x: 1 if x == 389 else 0)
        
    #     # Quic: Port 443
    #     df['QUIC'] = df['Port number'].apply(lambda x: 1 if x == 443 else 0)
        
    #     # RDP: Port 3389
    #     df['RDP'] = df['Port number'].apply(lambda x: 1 if x == 3389 else 0)
        
    #     # CoAP: Port 5683, 5684
    #     df['CoAP'] = df['Port number'].apply(lambda x: 1 if x in [5683, 5684] else 0)
        
    #     # TCP: Ports 80, 443, 21, 22
    #     df['HTTP Flood'] = df['Port number'].apply(lambda x: 1 if x in [80, 443] else 0)
                
    #     # FTP: Port 21
    #     df['FTP'] = df['Port number'].apply(lambda x: 1 if x == 21 else 0)
        
    #     # SSH: Port 22
    #     df['SSH'] = df['Port number'].apply(lambda x: 1 if x == 22 else 0)
    
    #     df['Memcached'] = df['Port number'].apply(lambda x: 1 if x == 11211 else 0)
        
    #     df['WS-DD'] = df['Port number'].apply(lambda x: 1 if x == 3702 else 0)
        
    #     df['NetBIOS'] = df['Port number'].apply(lambda x: 1 if x in [137, 138] else 0)
        
    #     df['Kubernetes'] = df['Port number'].apply(lambda x: 1 if x in [10250, 6443] else 0)

    #     return df
        
    # preload the data as it makes the training much faster (and it easily fits in memory)
    def load_data(self, data_paths, apply_smote=False, undersample=False, sample_factor=4, add_features=True):
        data = []
        component_data = []
        
        for path in data_paths:
            event_df = pd.read_csv(path).fillna(0)
            data.append(event_df)

            # Attempt to load corresponding component file
            comp_path = path.replace('_events_extended.csv', '_components.csv')
            ref_event_path = path.replace('_events_extended.csv', '_events.csv')
            if os.path.exists(comp_path) and add_features:
                # Load event data
                ref_ev_df = pd.read_csv(ref_event_path).fillna(0)
                ref_ev_df.columns = event_columns

                # Filter out invalid 'Attack ID's based on 'End time'
                ref_ev_df2 = ref_ev_df[ref_ev_df['End time'].astype(str) != '0']
                invalid_attack_ids = ref_ev_df[ref_ev_df['End time'].astype(str) == '0']['Attack ID'].unique()

                # Filter the event data by removing rows with invalid 'Attack ID's
                valid_attack_ids = ref_ev_df2['Attack ID'].unique()  # Attack IDs present in valid events

                # Load component data
                component_df = pd.read_csv(comp_path).fillna(0)
                component_df.columns = component_columns

                # Remove invalid attack IDs from component data
                component_df = component_df[~component_df['Attack ID'].isin(invalid_attack_ids)]
                
                # Now filter component data to only include 'Attack ID's present in valid events
                component_df = component_df[component_df['Attack ID'].isin(valid_attack_ids)]

                # Append the filtered component data
                component_data.append(component_df)
            else:
                print(f"Component file not found: {comp_path}")
        
        df = pd.concat(data, ignore_index=True)
        
        if component_data:
            df_components = pd.concat(component_data, ignore_index=True)
            comp_features = self.engineer_features_from_components(df_components)
            df = pd.concat([df, comp_features], axis=1)
            df = df.drop(columns=['Attack ID'])
            cols = list(df.columns)
            cols[-3], cols[-1] = cols[-1], cols[-3]
            df = df[cols]
            df = df.dropna(how='all')
            
        
        self.ddos_ports = df[df['Type'] == "DDoS attack"]["Port number"].unique()
        
        #feature_columns = df.columns[:19]  # All except the last column
        
        df = df.loc[:, ~df.columns.str.contains('Ac')]
        feature_columns = df.columns[:-1]
        label_column = df.columns[-1]  # The last column
        
        # Convert categorical labels to numeric using the dictionary
        df[label_column] = df[label_column].map(CAT_TO_NUM_LABELS)
        
        # Check for missing or unknown labels
        if df[label_column].isna().any():
            print(df[label_column].isna().sum(), "missing labels")

        X = df[feature_columns]
        if (add_features):
            X = self.add_protocol_columns(X)
        y = df[label_column]
        
        self.columns = X.columns
        
        
        # Normalize the features
        #features = self.normalize(features)
        
        # Convert to PyTorch tensors
        features = torch.tensor(X.values, dtype=torch.float32)
        labels = torch.tensor(y.values, dtype=torch.long)  # Classification requires long dtype
        
        
        
        if undersample and self.split == 'train':
            # Undersample the majority class (label=0)
            class_0_indices = np.where(labels.cpu().numpy() == 0)[0]
            class_1_indices = np.where(labels.cpu().numpy() == 1)[0]
            class_2_indices = np.where(labels.cpu().numpy() == 2)[0]

            # Randomly undersample the majority class
            num_class_0_samples = sample_factor*(len(class_1_indices) + len(class_2_indices))  # Same number as the minority class
            class_0_indices_undersampled = np.random.choice(class_0_indices, num_class_0_samples, replace=False)

            # Concatenate indices of class 1, 2, and undersampled class 0
            undersampled_indices = np.concatenate([class_0_indices_undersampled, class_1_indices, class_2_indices])

            # Subset the dataset to include only the sampled indices
            features = features[undersampled_indices]
            labels = labels[undersampled_indices]

        return features, labels


    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.lables[idx]

In [7]:
dataset = DDoSDataset('test')

In [6]:
# >>> Dataset.
TaskType = Literal['regression', 'binclass', 'multiclass']

# Regression.
task_type: TaskType = 'multiclass'
n_classes = None
dataset = sklearn.datasets.fetch_california_housing()
X_cont: np.ndarray = dataset['data']
Y: np.ndarray = dataset['target']

dataset = DDoSDataset(split='train')
X_cont = dataset.features.numpy()
Y = dataset.lables.numpy()


# Classification.
n_classes = 3
assert n_classes >= 2
task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

#task_is_regression = task_type == 'regression'

# >>> Continuous features.
X_cont: np.ndarray = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]

# # >>> Categorical features.
# # NOTE: the above datasets do not have categorical features, however,
# # for the demonstration purposes, it is possible to generate them.
# cat_cardinalities = [
#     # NOTE: uncomment the two lines below to add two categorical features.
#     # 4,  # Allowed values: [0, 1, 2, 3].
#     # 7,  # Allowed values: [0, 1, 2, 3, 4, 5, 6].
# ]
# X_cat = (
#     np.column_stack(
#         [np.random.randint(0, c, (len(X_cont),)) for c in cat_cardinalities]
#     )
#     if cat_cardinalities
#     else None
# )

# >>> Labels.
if task_type == 'regression':
    Y = Y.astype(np.float32)
else:
    assert n_classes is not None
    Y = Y.astype(np.int64)
    assert set(Y.tolist()) == set(
        range(n_classes)
    ), 'Classification labels must form the range [0, 1, ..., n_classes - 1]'

# >>> Split the dataset.
all_idx = np.arange(len(Y))
train_idx, val_idx = sklearn.model_selection.train_test_split(
    all_idx, train_size=0.8
)

data_numpy = {
    'train': {'x_cont': X_cont[train_idx], 'y': Y[train_idx]},
    'val': {'x_cont': X_cont[val_idx], 'y': Y[val_idx]},
}

s = 0
# if X_cat is not None:
#     data_numpy['train']['x_cat'] = X_cat[train_idx]
#     data_numpy['val']['x_cat'] = X_cat[val_idx]
#     data_numpy['test']['x_cat'] = X_cat[test_idx]

In [7]:
# Add the Dataset C as the testing class
dataset = DDoSDataset(split='test')
X_cont = dataset.features.numpy()
Y = dataset.lables.numpy()


# Classification.
n_classes = 3
assert n_classes >= 2
task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

#task_is_regression = task_type == 'regression'

# >>> Continuous features.
X_cont: np.ndarray = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]

# >>> Labels.
if task_type == 'regression':
    Y = Y.astype(np.float32)
else:
    assert n_classes is not None
    Y = Y.astype(np.int64)
    assert set(Y.tolist()) == set(
        range(n_classes)
    ), 'Classification labels must form the range [0, 1, ..., n_classes - 1]'
    
    
data_numpy['test'] = {'x_cont': X_cont, 'y': Y}

# Data preprocessing

In [8]:
# Feature preprocessing.
# NOTE
# The choice between preprocessing strategies depends on a task and a model.

# Simple preprocessing strategy.
# preprocessing = sklearn.preprocessing.StandardScaler().fit(
#     data_numpy['train']['x_cont']
# )

# Advanced preprocessing strategy.
# The noise is added to improve the output of QuantileTransformer in some cases.

# own
port_columns = [
    'DNS', 'NTP', 'SNMP', 'SSDP', 'CLDAP', 'QUIC', 'RDP', 'CoAP', 
    'HTTP Flood', 'FTP', 'SSH', 'Memcached', 'WS-DD', 'NetBIOS', 'Kubernetes'
]

# gpt
port_columns = [
    'DNS', 'RDP', 'TCP', 'SYN', 'UDP', 'CoAP', 'Attack Ports'
]

continuous_feature_indices = [i for i, col in enumerate(dataset.columns) 
                            if col not in port_columns and col != 'Port number']


X_cont_train_numpy = data_numpy['train']['x_cont'][:, continuous_feature_indices]
noise = (
    np.random.default_rng(0)
    .normal(0.0, 1e-5, X_cont_train_numpy.shape)
    .astype(X_cont_train_numpy.dtype)
)
preprocessing = sklearn.preprocessing.QuantileTransformer(
    n_quantiles=max(min(len(train_idx) // 30, 1000), 10),
    output_distribution='normal',
    subsample=10**9,
).fit(X_cont_train_numpy + noise)
del X_cont_train_numpy

# Apply the preprocessing to the training, test and validation sets.
for part in data_numpy:
    # Transform only continuous features
    data_numpy[part]['x_cont'][:, continuous_feature_indices] = preprocessing.transform(
        data_numpy[part]['x_cont'][:, continuous_feature_indices]
    )
    # Leave port features unchanged

# Label preprocessing.
class RegressionLabelStats(NamedTuple):
    mean: float
    std: float


Y_train = data_numpy['train']['y'].copy()
if task_type == 'regression':
    # For regression tasks, it is highly recommended to standardize the training labels.
    regression_label_stats = RegressionLabelStats(
        Y_train.mean().item(), Y_train.std().item()
    )
    Y_train = (Y_train - regression_label_stats.mean) / regression_label_stats.std
else:
    regression_label_stats = None

#  PyTorch settings

In [9]:
# Device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Convert data to tensors
data = {
    part: {k: torch.as_tensor(v, device=device) for k, v in data_numpy[part].items()}
    for part in data_numpy
}
Y_train = torch.as_tensor(Y_train, device=device)
if task_type == 'regression':
    for part in data:
        data[part]['y'] = data[part]['y'].float()
    Y_train = Y_train.float()

# Automatic mixed precision (AMP)
# torch.float16 is implemented for completeness,
# but it was not tested in the project,
# so torch.bfloat16 is used by default.
amp_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
    if torch.cuda.is_available()
    else None
)
# Changing False to True will result in faster training on compatible hardware.
amp_enabled = False and amp_dtype is not None
grad_scaler = torch.cuda.amp.GradScaler() if amp_dtype is torch.float16 else None  # type: ignore

# torch.compile
compile_model = False

# fmt: off
print(
    f'Device:        {device.type.upper()}'
    f'\nAMP:           {amp_enabled} (dtype: {amp_dtype})'
    f'\ntorch.compile: {compile_model}'
)
# fmt: on

Device:        CUDA
AMP:           False (dtype: torch.bfloat16)
torch.compile: False


# Model

In [10]:
# Choose one of the two configurations below.

# TabM
arch_type = 'tabm'
bins = None

# TabM-mini with the piecewise-linear embeddings.
# arch_type = 'tabm-mini'
# bins = rtdl_num_embeddings.compute_bins(data['train']['x_cont'])

# arch_type = 'tabm-packed'
# bins = rtdl_num_embeddings.compute_bins(data['train']['x_cont'])

# d_block: 512
# n_blocks: 3

model = Model(
    n_num_features=n_cont_features,
    cat_cardinalities=[],
    n_classes=n_classes,
    backbone={
        'type': 'MLP',
        'n_blocks': 3 if bins is None else 2,
        'd_block': 256,
        'dropout': 0.2,
        'n_blocks': 5
    },
    bins=bins,
    num_embeddings=(
        None
        if bins is None
        else {
            'type': 'PiecewiseLinearEmbeddings',
            'd_embedding': 32,
            'activation': False,
            'version': 'B',
        }
    ),
    arch_type=arch_type,
    k=48,
    share_training_batches=True,
).to(device)
optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=1e-3, weight_decay=3e-4)

if compile_model:
    # NOTE
    # `torch.compile` is intentionally called without the `mode` argument
    # (mode="reduce-overhead" caused issues during training with torch==2.0.1).
    model = torch.compile(model)
    evaluation_mode = torch.no_grad
else:
    evaluation_mode = torch.inference_mode

In [11]:
@torch.autocast(device.type, enabled=amp_enabled, dtype=amp_dtype)  # type: ignore[code]
def apply_model(part: str, idx: Tensor) -> Tensor:
    return (
        model(
            data[part]['x_cont'][idx],
            data[part]['x_cat'][idx] if 'x_cat' in data[part] else None,
        )
        .squeeze(-1)  # Remove the last dimension for regression tasks.
        .float()
    )


base_loss_fn = F.mse_loss if task_type == 'regression' else F.cross_entropy
weight = torch.tensor([1.0, 1.0, 3.0], device=device)
#weight = torch.tensor([1.0, 1.0, 1.0], device=device)


def loss_fn(y_pred: Tensor, y_true: Tensor) -> Tensor:
    # TabM produces k predictions. Each of them must be trained separately.
    # (regression)     y_pred.shape == (batch_size, k)
    # (classification) y_pred.shape == (batch_size, k, n_classes)
    k = y_pred.shape[-1 if task_type == 'regression' else -2]
    return base_loss_fn(
        y_pred.flatten(0, 1),
        y_true.repeat_interleave(k) if model.share_training_batches else y_true,
        weight=weight,
    )


@evaluation_mode()
def evaluate(part: str) -> float:
    model.eval()

    # When using torch.compile, you may need to reduce the evaluation batch size.
    eval_batch_size = 8096
    y_pred: np.ndarray = (
        torch.cat(
            [
                apply_model(part, idx)
                for idx in torch.arange(len(data[part]['y']), device=device).split(
                    eval_batch_size
                )
            ]
        )
        .cpu()
        .numpy()
    )
    if task_type == 'regression':
        # Transform the predictions back to the original label space.
        assert regression_label_stats is not None
        y_pred = y_pred * regression_label_stats.std + regression_label_stats.mean

    # Compute the mean of the k predictions.
    if task_type != 'regression':
        # For classification, the mean must be computed in the probabily space.
        y_pred = scipy.special.softmax(y_pred, axis=-1)
    y_pred = y_pred.mean(1)

    y_true = data[part]['y'].cpu().numpy()
    score = (
        -(sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5)
        if task_type == 'regression'
        else sklearn.metrics.f1_score(y_true, y_pred.argmax(1), average='macro')
    )
    return float(score)  # The higher -- the better.


print(f'Test score before training: {evaluate("test"):.4f}')

Test score before training: 0.1826


# Training

In [12]:
# Tensorboard for training
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


run_name = "tabm_ports_not_normalized_v2"
tb_log_dir = f'/home/appuser/src/logs/TabM/{run_name}'
writer = SummaryWriter(log_dir=tb_log_dir)

n_epochs = 200

train_size = len(train_idx)
batch_size = 256
epoch_size = math.ceil(train_size / batch_size)
best = {
    'val': -math.inf,
    'test': -math.inf,
    'epoch': -1,
}
# Early stopping: the training stops when
# there are more than `patience` consequtive bad updates.
patience = 200
remaining_patience = patience



print('-' * 88 + '\n')
for epoch in range(n_epochs):
    batches = (
        torch.randperm(train_size, device=device).split(batch_size)
        if model.share_training_batches
        else [
            x.transpose(0, 1).flatten()
            for x in torch.rand((model.k, train_size), device=device)
            .argsort(dim=1)
            .split(batch_size, dim=1)
        ]
    )
    epoch_loss = 0.0
    epoch_num = 0
    for batch_idx in tqdm(batches, desc=f'Epoch {epoch}'):
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(apply_model('train', batch_idx), Y_train[batch_idx])
        epoch_loss += loss.item()
        epoch_num += 1
        if grad_scaler is None:
            loss.backward()
            optimizer.step()
        else:
            grad_scaler.scale(loss).backward()  # type: ignore
            grad_scaler.step(optimizer)
            grad_scaler.update()

    val_score = evaluate('val')
    test_score = evaluate('test')
    writer.add_scalar('validation_macro_F1', val_score, epoch+1)
    writer.add_scalar('test_macro_F1', test_score, epoch+1)
    writer.add_scalar('train_loss', epoch_loss/epoch_num, epoch+1)
    print(f'(val) {val_score:.4f} (test) {test_score:.4f}')

    if test_score > best['test']:
        print('🌸 New best epoch! 🌸')
        best = {'val': val_score, 'test': test_score, 'epoch': epoch}
        remaining_patience = patience
        torch.save(model.state_dict(), f"{run_name}.pth")
    else:
        remaining_patience -= 1

    if remaining_patience < 0:
        break

    print()


print('\n\nResult:')
print(best)

----------------------------------------------------------------------------------------



Epoch 0: 100%|██████████| 828/828 [00:05<00:00, 138.02it/s]


(val) 0.3231 (test) 0.3280
🌸 New best epoch! 🌸



Epoch 1: 100%|██████████| 828/828 [00:05<00:00, 139.40it/s]


(val) 0.3231 (test) 0.3280



Epoch 2: 100%|██████████| 828/828 [00:05<00:00, 139.47it/s]


(val) 0.4496 (test) 0.5801
🌸 New best epoch! 🌸



Epoch 3: 100%|██████████| 828/828 [00:05<00:00, 139.01it/s]


(val) 0.5953 (test) 0.6678
🌸 New best epoch! 🌸



Epoch 4: 100%|██████████| 828/828 [00:06<00:00, 137.89it/s]


(val) 0.5995 (test) 0.6726
🌸 New best epoch! 🌸



Epoch 5: 100%|██████████| 828/828 [00:05<00:00, 138.13it/s]


(val) 0.5980 (test) 0.6643



Epoch 6: 100%|██████████| 828/828 [00:06<00:00, 137.10it/s]


(val) 0.6018 (test) 0.6814
🌸 New best epoch! 🌸



Epoch 7: 100%|██████████| 828/828 [00:06<00:00, 136.32it/s]


(val) 0.5982 (test) 0.6826
🌸 New best epoch! 🌸



Epoch 8: 100%|██████████| 828/828 [00:06<00:00, 136.14it/s]


(val) 0.6002 (test) 0.6740



Epoch 9: 100%|██████████| 828/828 [00:06<00:00, 136.03it/s]


(val) 0.6051 (test) 0.6809



Epoch 10: 100%|██████████| 828/828 [00:06<00:00, 136.25it/s]


(val) 0.6066 (test) 0.6792



Epoch 11: 100%|██████████| 828/828 [00:06<00:00, 136.84it/s]


(val) 0.6067 (test) 0.6792



Epoch 12: 100%|██████████| 828/828 [00:06<00:00, 137.36it/s]


(val) 0.6097 (test) 0.6813



Epoch 13: 100%|██████████| 828/828 [00:06<00:00, 137.45it/s]


(val) 0.6094 (test) 0.6809



Epoch 14: 100%|██████████| 828/828 [00:06<00:00, 136.08it/s]


(val) 0.6125 (test) 0.6959
🌸 New best epoch! 🌸



Epoch 15: 100%|██████████| 828/828 [00:06<00:00, 135.47it/s]


(val) 0.6160 (test) 0.6951



Epoch 16: 100%|██████████| 828/828 [00:06<00:00, 137.23it/s]


(val) 0.6478 (test) 0.7095
🌸 New best epoch! 🌸



Epoch 17: 100%|██████████| 828/828 [00:06<00:00, 137.07it/s]


(val) 0.6823 (test) 0.7265
🌸 New best epoch! 🌸



Epoch 18: 100%|██████████| 828/828 [00:06<00:00, 136.93it/s]


(val) 0.7187 (test) 0.7338
🌸 New best epoch! 🌸



Epoch 19: 100%|██████████| 828/828 [00:06<00:00, 135.05it/s]


(val) 0.8167 (test) 0.6885



Epoch 20: 100%|██████████| 828/828 [00:06<00:00, 135.95it/s]


(val) 0.8223 (test) 0.7119



Epoch 21: 100%|██████████| 828/828 [00:06<00:00, 136.34it/s]


(val) 0.8260 (test) 0.6963



Epoch 22: 100%|██████████| 828/828 [00:06<00:00, 136.40it/s]


(val) 0.8305 (test) 0.7377
🌸 New best epoch! 🌸



Epoch 23: 100%|██████████| 828/828 [00:06<00:00, 135.67it/s]


(val) 0.8349 (test) 0.7387
🌸 New best epoch! 🌸



Epoch 24: 100%|██████████| 828/828 [00:06<00:00, 136.17it/s]


(val) 0.8410 (test) 0.7963
🌸 New best epoch! 🌸



Epoch 25: 100%|██████████| 828/828 [00:06<00:00, 135.45it/s]


(val) 0.8378 (test) 0.7658



Epoch 26: 100%|██████████| 828/828 [00:06<00:00, 135.56it/s]


(val) 0.8364 (test) 0.7330



Epoch 27: 100%|██████████| 828/828 [00:06<00:00, 136.76it/s]


(val) 0.8300 (test) 0.7761



Epoch 28: 100%|██████████| 828/828 [00:06<00:00, 136.79it/s]


(val) 0.8371 (test) 0.7692



Epoch 29: 100%|██████████| 828/828 [00:06<00:00, 136.44it/s]


(val) 0.8385 (test) 0.7875



Epoch 30: 100%|██████████| 828/828 [00:06<00:00, 135.98it/s]


(val) 0.8457 (test) 0.7610



Epoch 31: 100%|██████████| 828/828 [00:06<00:00, 136.55it/s]


(val) 0.8455 (test) 0.7809



Epoch 32: 100%|██████████| 828/828 [00:06<00:00, 136.63it/s]


(val) 0.8479 (test) 0.7747



Epoch 33: 100%|██████████| 828/828 [00:06<00:00, 136.53it/s]


(val) 0.8389 (test) 0.7795



Epoch 34: 100%|██████████| 828/828 [00:06<00:00, 136.06it/s]


(val) 0.8485 (test) 0.8105
🌸 New best epoch! 🌸



Epoch 35: 100%|██████████| 828/828 [00:06<00:00, 136.71it/s]


(val) 0.8555 (test) 0.7968



Epoch 36: 100%|██████████| 828/828 [00:06<00:00, 136.64it/s]


(val) 0.8356 (test) 0.7873



Epoch 37: 100%|██████████| 828/828 [00:06<00:00, 135.54it/s]


(val) 0.8437 (test) 0.7846



Epoch 38: 100%|██████████| 828/828 [00:06<00:00, 136.13it/s]


(val) 0.8511 (test) 0.7982



Epoch 39: 100%|██████████| 828/828 [00:06<00:00, 136.59it/s]


(val) 0.8564 (test) 0.7915



Epoch 40: 100%|██████████| 828/828 [00:06<00:00, 136.70it/s]


(val) 0.8420 (test) 0.7879



Epoch 41: 100%|██████████| 828/828 [00:06<00:00, 136.17it/s]


(val) 0.8477 (test) 0.8305
🌸 New best epoch! 🌸



Epoch 42: 100%|██████████| 828/828 [00:06<00:00, 136.86it/s]


(val) 0.8499 (test) 0.8070



Epoch 43: 100%|██████████| 828/828 [00:06<00:00, 136.76it/s]


(val) 0.8495 (test) 0.7926



Epoch 44: 100%|██████████| 828/828 [00:06<00:00, 136.62it/s]


(val) 0.8438 (test) 0.8062



Epoch 45: 100%|██████████| 828/828 [00:06<00:00, 136.04it/s]


(val) 0.8517 (test) 0.8153



Epoch 46: 100%|██████████| 828/828 [00:06<00:00, 136.18it/s]


(val) 0.8539 (test) 0.7828



Epoch 47: 100%|██████████| 828/828 [00:06<00:00, 134.28it/s]


(val) 0.8496 (test) 0.8256



Epoch 48: 100%|██████████| 828/828 [00:06<00:00, 135.25it/s]


(val) 0.8537 (test) 0.8298



Epoch 49: 100%|██████████| 828/828 [00:06<00:00, 135.38it/s]


(val) 0.8421 (test) 0.7973



Epoch 50: 100%|██████████| 828/828 [00:06<00:00, 135.90it/s]


(val) 0.8542 (test) 0.8068



Epoch 51: 100%|██████████| 828/828 [00:06<00:00, 136.50it/s]


(val) 0.8612 (test) 0.8117



Epoch 52: 100%|██████████| 828/828 [00:06<00:00, 135.95it/s]


(val) 0.8568 (test) 0.8219



Epoch 53: 100%|██████████| 828/828 [00:06<00:00, 136.54it/s]


(val) 0.8601 (test) 0.7949



Epoch 54: 100%|██████████| 828/828 [00:06<00:00, 135.90it/s]


(val) 0.8479 (test) 0.8037



Epoch 55: 100%|██████████| 828/828 [00:06<00:00, 135.77it/s]


(val) 0.8562 (test) 0.8115



Epoch 56: 100%|██████████| 828/828 [00:06<00:00, 136.00it/s]


(val) 0.8611 (test) 0.8246



Epoch 57: 100%|██████████| 828/828 [00:06<00:00, 136.35it/s]


(val) 0.8586 (test) 0.8100



Epoch 58: 100%|██████████| 828/828 [00:06<00:00, 136.40it/s]


(val) 0.8591 (test) 0.8170



Epoch 59: 100%|██████████| 828/828 [00:06<00:00, 135.61it/s]


(val) 0.8533 (test) 0.7983



Epoch 60: 100%|██████████| 828/828 [00:06<00:00, 135.98it/s]


(val) 0.8631 (test) 0.8092



Epoch 61: 100%|██████████| 828/828 [00:06<00:00, 136.46it/s]


(val) 0.8557 (test) 0.8320
🌸 New best epoch! 🌸



Epoch 62: 100%|██████████| 828/828 [00:06<00:00, 136.27it/s]


(val) 0.8584 (test) 0.8113



Epoch 63: 100%|██████████| 828/828 [00:06<00:00, 135.93it/s]


(val) 0.8549 (test) 0.7968



Epoch 64: 100%|██████████| 828/828 [00:06<00:00, 136.38it/s]


(val) 0.8570 (test) 0.7812



Epoch 65: 100%|██████████| 828/828 [00:06<00:00, 136.34it/s]


(val) 0.8566 (test) 0.8336
🌸 New best epoch! 🌸



Epoch 66: 100%|██████████| 828/828 [00:06<00:00, 136.44it/s]


(val) 0.8505 (test) 0.8173



Epoch 67: 100%|██████████| 828/828 [00:06<00:00, 135.93it/s]


(val) 0.8532 (test) 0.7834



Epoch 68: 100%|██████████| 828/828 [00:06<00:00, 136.48it/s]


(val) 0.8548 (test) 0.8236



Epoch 69: 100%|██████████| 828/828 [00:06<00:00, 136.40it/s]


(val) 0.8525 (test) 0.8133



Epoch 70: 100%|██████████| 828/828 [00:06<00:00, 136.49it/s]


(val) 0.8503 (test) 0.8325



Epoch 71: 100%|██████████| 828/828 [00:06<00:00, 136.43it/s]


(val) 0.8547 (test) 0.8291



Epoch 72: 100%|██████████| 828/828 [00:06<00:00, 136.33it/s]


(val) 0.8552 (test) 0.7803



Epoch 73: 100%|██████████| 828/828 [00:06<00:00, 136.38it/s]


(val) 0.8591 (test) 0.8426
🌸 New best epoch! 🌸



Epoch 74: 100%|██████████| 828/828 [00:06<00:00, 135.73it/s]


(val) 0.8564 (test) 0.7947



Epoch 75: 100%|██████████| 828/828 [00:06<00:00, 136.15it/s]


(val) 0.8525 (test) 0.8096



Epoch 76: 100%|██████████| 828/828 [00:06<00:00, 134.77it/s]


(val) 0.8533 (test) 0.8086



Epoch 77: 100%|██████████| 828/828 [00:06<00:00, 135.49it/s]


(val) 0.8557 (test) 0.8322



Epoch 78: 100%|██████████| 828/828 [00:06<00:00, 134.51it/s]


(val) 0.8612 (test) 0.7918



Epoch 79: 100%|██████████| 828/828 [00:06<00:00, 135.76it/s]


(val) 0.8583 (test) 0.8242



Epoch 80: 100%|██████████| 828/828 [00:06<00:00, 135.04it/s]


(val) 0.8516 (test) 0.8175



Epoch 81: 100%|██████████| 828/828 [00:06<00:00, 135.56it/s]


(val) 0.8573 (test) 0.8182



Epoch 82: 100%|██████████| 828/828 [00:06<00:00, 134.92it/s]


(val) 0.8475 (test) 0.8077



Epoch 83: 100%|██████████| 828/828 [00:06<00:00, 135.26it/s]


(val) 0.8586 (test) 0.8309



Epoch 84: 100%|██████████| 828/828 [00:06<00:00, 135.27it/s]


(val) 0.8545 (test) 0.8299



Epoch 85: 100%|██████████| 828/828 [00:06<00:00, 133.70it/s]


(val) 0.8596 (test) 0.8217



Epoch 86:  75%|███████▌  | 625/828 [00:04<00:01, 134.65it/s]


KeyboardInterrupt: 

In [9]:
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score
)

In [10]:
# Define the model
arch_type = 'tabm'
bins = None

# arch_type = 'tabm-mini'
# bins = rtdl_num_embeddings.compute_bins(data['train']['x_cont'])

model = Model(
    n_num_features=n_cont_features,
    cat_cardinalities=[],
    n_classes=n_classes,
    backbone={
        'type': 'MLP',
        'n_blocks': 3 if bins is None else 2,
        'd_block': 256,
        'dropout': 0.2,
        'n_blocks': 5
    },
    bins=bins,
    num_embeddings=(
        None
        if bins is None
        else {
            'type': 'PiecewiseLinearEmbeddings',
            'd_embedding': 16,
            'activation': False,
            'version': 'B',
        }
    ),
    arch_type=arch_type,
    k=48,
    share_training_batches=True,
).to(device)

In [11]:
@torch.autocast(device.type, enabled=amp_enabled, dtype=amp_dtype)  # type: ignore[code]
def apply_model(part: str, idx: Tensor) -> Tensor:
    return (
        model(
            data[part]['x_cont'][idx],
            data[part]['x_cat'][idx] if 'x_cat' in data[part] else None,
        )
        .squeeze(-1)  # Remove the last dimension for regression tasks.
        .float()
    )

In [12]:
# Inference on the test dataset
model.load_state_dict(torch.load('/home/appuser/src/visualization/tabm_no_frequ_features_extended_more_weight.pth'))
model.eval()

part = "test"

eval_batch_size = 12
# y_pred: np.ndarray = (
#     torch.cat(
#         [
#             apply_model(part, idx).cpu()
#             for idx in torch.arange(len(data[part]['y']), device=device).split(
#                 eval_batch_size
#             )
#         ]
#     )
#     .numpy()
# )

y_pred_list = []
for idx in tqdm(torch.arange(len(data[part]['y']), device=device).split(eval_batch_size)):
    with torch.no_grad():
        preds = apply_model(part, idx).cpu()
        probs = scipy.special.softmax(preds.numpy(), axis=-1)
        averaged = probs.mean(1)  # shape: [B, C]
        preds_class = np.argmax(averaged, axis=1)
        y_pred_list.append(preds_class)

y_pred = np.concatenate(y_pred_list)

y_test = data[part]['y'].cpu().numpy()

# Overall accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.4f}")

# Classification report (includes precision, recall, F1 per class + macro/micro)
report = classification_report(y_test, y_pred, output_dict=True)
report_df = pd.DataFrame(report).transpose()
print("\nClassification Report:\n", report_df)

# F1 scores
f1_micro = f1_score(y_test, y_pred, average='micro')
f1_macro = f1_score(y_test, y_pred, average='macro')
print(f"\nF1 (Micro): {f1_micro:.4f}")
print(f"F1 (Macro): {f1_macro:.4f}")

# Class-wise accuracy (same as recall per class)
class_wise_accuracy = report_df.loc[[str(i) for i in np.unique(y_test)], "recall"]
print("\nClass-wise Accuracy (Recall):\n", class_wise_accuracy)

# Confusion Matrix
conf_matrix = confusion_matrix(y_test, y_pred)
print("\nConfusion Matrix:\n", conf_matrix)

100%|██████████| 10834/10834 [00:06<00:00, 1695.56it/s]


Accuracy: 0.9909

Classification Report:
               precision    recall  f1-score        support
0              0.995859  0.995194  0.995526  125892.000000
1              0.906769  0.930537  0.918499    3052.000000
2              0.649669  0.652133  0.650899    1055.000000
accuracy       0.990892  0.990892  0.990892       0.990892
macro avg      0.850766  0.859288  0.854975  129999.000000
weighted avg   0.990958  0.990892  0.990921  129999.000000

F1 (Micro): 0.9909
F1 (Macro): 0.8550

Class-wise Accuracy (Recall):
 0    0.995194
1    0.930537
2    0.652133
Name: recall, dtype: float64

Confusion Matrix:
 [[125287    270    335]
 [   176   2840     36]
 [   345     22    688]]
