# TabM

This is a standalone usage example for the TabM project.
The easiest way to run it is [Pixi](https://pixi.sh/latest/#installation):

```shell
git clone https://github.com/yandex-research/tabm
cd tabm

# With GPU:
pixi run -e cuda jupyter-lab example.ipynb

# Without GPU:
pixi run jupyter-lab example.ipynb
```

For the full overview of the project, and for non-Pixi environment setups, see README in the repository:
https://github.com/yandex-research/tabm

In [28]:
# ruff: noqa: E402
import math
import random
import warnings
from typing import Literal, NamedTuple

import os
import numpy as np
import rtdl_num_embeddings  # https://github.com/yandex-research/rtdl-num-embeddings
import scipy.special
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection
import sklearn.preprocessing
from sklearn.preprocessing import MultiLabelBinarizer
import torch
import torch.nn.functional as F
import torch.optim
from torch import Tensor
from tqdm.std import tqdm
from torch.utils.data import Dataset
import pandas as pd
from imblearn.over_sampling import SMOTE, BorderlineSMOTE
from collections import Counter

warnings.simplefilter('ignore')
from tabm_reference import Model, make_parameter_groups

warnings.resetwarnings()

In [29]:
seed = 0
random.seed(seed)
np.random.seed(seed + 1)
torch.manual_seed(seed + 2)
pass

# Dataset

In [41]:
# Adding The Inferred Attack Code columns and saving it, so we don't need to wait through it mutliple times
CAT_TO_NUM_LABELS = {
    "Normal traffic": 0,
    "Suspicious traffic": 1,
    "DDoS attack": 2,
}

component_columns = [
    "Attack ID", "Detect count", "Card", "Victim IP", "Port number",
    "Attack code", "Significant flag", "Packet speed", "Data speed", "Avg packet len",
    "Source IP count", "Time"
]

event_columns = [
    "Attack ID", "Card", "Victim IP", "Port number", "Attack code", 
    "Detect count", "Significant flag", "Packet speed", "Data speed", 
    "Avg packet len", "Avg source IP count", "Start time", "End time", 
    "Whitelist flag", "Type"
]

class DDoSDataset(Dataset):
    def __init__(self, split):
        self.train_data_paths = [f'/home/appuser/data/train/SCLDDoS2024_SetA_events_extended.csv',
                                 f'/home/appuser/data/train/SCLDDoS2024_SetB_events_extended.csv']
        self.test_data_paths = [f'/home/appuser/data/test/SCLDDoS2024_SetC_events_extended.csv']     
        
        self.split = split   
        
        if split == 'train':
            self.load_data(self.train_data_paths, apply_smote=False)
        elif split == 'test':
            self.load_data(self.test_data_paths, apply_smote=False)
        else:
            print("Invalid split. Use 'train' or 'test'")
            
    
    def get_ports(self):
        return self.ddos_ports
    
    
    def get_data(self):
        return self.features.numpy(), self.lables.numpy()
    
    def engineer_features_from_components(self, df_components):

        grouped = df_components.groupby('Attack ID')

        features = pd.DataFrame()

        features['Unique Ports'] = grouped['Port number'].nunique()
        features['Unique Victim IPs'] = grouped['Victim IP'].nunique()

        return features.reset_index()
    
    # Function to infer attack codes for a full attack group
    def infer_attack_code_row(self, row):
        codes = set()

        # CHARGEN:
        if row["Packet speed"] > 500000 and row["Data speed"] > 400 and row["Port number"] == 443:
            codes.add("CHARGEN")
            
        # CLDAP:
        if row["Detect count"] >= 10 and row["Data speed"] > 400 and row["Port number"] in [389, 53,80,443,0]:
            codes.add("CLDAP")
            
        # CoAP: I don't see any indicators for this
        
        # # DNS: notghing specific but we can use the port number
        # if row["Port number"] in [53, 443] and row["Data speed"] < 30:
        #     codes.add("DNS")
            
        # Generic UDP:
        if row["Data speed"] < 20 and row["Port number"] in [0,80, 56, 5656, 4500]:
            codes.add("Generic UDP")  
            
        # IPV4 fragmentation:
        if row["Packet speed"] > 1000000 and row["Data speed"] > 1000 and row["Port number"] in [0,80,443]:
            codes.add("IPV4 fragmentation")
            
        # NTP: I don't see anything
        
        # RDP: same
        
        # RPC: same
        
        # SNMP: same
        
        # SSDP: same
        
        # SYN Attack:
        if row["Data speed"] <= 10 and row["Avg packet len"] <= 10 and row["Port number"] in [80,11,22, 443, 0]:
            codes.add("SYN Attack")
            
        # Sentinel:
        if row["Packet speed"] < 30000 and row["Data speed"] < 10 and row["Port number"] == 0:
            codes.add("Sentinel")
            
        # TCP Anomaly:
        if row["Avg packet len"] == 0:
            codes.add("TCP Anomaly")

        return "; ".join(sorted(codes)) if codes else "Unknown"
        
    # preload the data as it makes the training much faster (and it easily fits in memory)
    def load_data(self, data_paths, apply_smote=False, undersample=False, sample_factor=4, add_features=True):
        data = []
        component_data = []
        
        for path in data_paths:
            event_df = pd.read_csv(path).fillna(0)
            data.append(event_df)

            # Attempt to load corresponding component file
            comp_path = path.replace('_events_extended.csv', '_components.csv')
            ref_event_path = path.replace('_events_extended.csv', '_events.csv')
            if os.path.exists(comp_path) and add_features:
                # Load event data
                ref_ev_df = pd.read_csv(ref_event_path).fillna(0)
                ref_ev_df.columns = event_columns

                # Filter out invalid 'Attack ID's based on 'End time'
                ref_ev_df2 = ref_ev_df[ref_ev_df['End time'].astype(str) != '0']
                invalid_attack_ids = ref_ev_df[ref_ev_df['End time'].astype(str) == '0']['Attack ID'].unique()

                # Filter the event data by removing rows with invalid 'Attack ID's
                valid_attack_ids = ref_ev_df2['Attack ID'].unique()  # Attack IDs present in valid events

                # Load component data
                component_df = pd.read_csv(comp_path).fillna(0)
                component_df.columns = component_columns

                # Remove invalid attack IDs from component data
                component_df = component_df[~component_df['Attack ID'].isin(invalid_attack_ids)]
                
                # Now filter component data to only include 'Attack ID's present in valid events
                component_df = component_df[component_df['Attack ID'].isin(valid_attack_ids)]

                # Append the filtered component data
                component_data.append(component_df)
            else:
                print(f"Component file not found: {comp_path}")
        
        df = pd.concat(data, ignore_index=True)
        
        if component_data:
            # df_components = pd.concat(component_data, ignore_index=True)
            
            # attack_id_to_code = (
            #     df_components.groupby("Attack ID")
            #     .apply(self.infer_attack_code_group)
            #     .rename("Inferred Attack Code")
            # )
            
            # Merge back the inferred attack code to all component rows
            #df_components = df_components.merge(attack_id_to_code, on="Attack ID")
            
            #comp_features = self.engineer_features_from_components(df_components)
            df["Inferred Attack Code"] = df.apply(self.infer_attack_code_row, axis=1)
            # df = pd.concat([df, attack_id_to_code], axis=1)
            # df = df.drop(columns=['Attack ID'])
            cols = list(df.columns)
            cols[-2], cols[-1] = cols[-1], cols[-2]
            df = df[cols]
            df = df.dropna(how='all')
            
        # Save the dataframes to a CSV file
        if self.split == 'train':
            df.to_csv('/home/appuser/data/train/A_B_inferred_attack_code.csv', index=False)
        elif self.split == 'test':
            df.to_csv('/home/appuser/data/test/C_inferred_attack_code.csv', index=False)
       
train = DDoSDataset('train')
test = DDoSDataset('test')            
    

In [42]:
CAT_TO_NUM_LABELS = {
    "Normal traffic": 0,
    "Suspicious traffic": 1,
    "DDoS attack": 2,
}

component_columns = [
    "Attack ID", "Detect count", "Card", "Victim IP", "Port number",
    "Attack code", "Significant flag", "Packet speed", "Data speed", "Avg packet len",
    "Source IP count", "Time"
]

event_columns = [
    "Attack ID", "Card", "Victim IP", "Port number", "Attack code", 
    "Detect count", "Significant flag", "Packet speed", "Data speed", 
    "Avg packet len", "Avg source IP count", "Start time", "End time", 
    "Whitelist flag", "Type"
]

attack_code_ports = {
    "ACK Attack": [],
    "CHARGEN": [19],
    "CLDAP": [389],
    "CoAP": [5683],
    "DNS": [53],
    "Generic UDP": [],
    "IPV4 fragmentation": [],
    "RDP": [3389],
    "RPC": [111, 135],
    "SNMP": [161, 162],
    "SYN Attack": [],
    "TCP Anomaly": [],
    "NTP": [123],
    "SSDP": [1900],
    "Sentinel": [],
    "Memcached": [11211],
    "RIP": [520],
    "TFTP": [69],
    "WSD": [3702],
}

# Port to attack mapping for fast lookup
port_to_attack = {}
for attack, ports in attack_code_ports.items():
    for port in ports:
        port_to_attack.setdefault(port, []).append(attack)

class DDoSDataset(Dataset):
    def __init__(self, split, use_inferred_atck_code):
        self.train_data_paths = [f'/home/appuser/data/train/SCLDDoS2024_SetA_events_extended.csv',
                                 f'/home/appuser/data/train/SCLDDoS2024_SetB_events_extended.csv']
        self.test_data_paths = [f'/home/appuser/data/test/SCLDDoS2024_SetC_events_extended.csv']     
        
        self.inferred_train_data_paths = [f'/home/appuser/data/train/A_B_inferred_attack_code.csv']
        self.inferred_test_data_paths = [f'/home/appuser/data/test/C_inferred_attack_code.csv']
        
        self.split = split   
        self.use_inferred_atck_code = use_inferred_atck_code
        
        if split == 'train':
            if use_inferred_atck_code:
                self.features, self.lables = self.load_data(self.inferred_train_data_paths, apply_smote=False)
            else:
                self.features, self.lables = self.load_data(self.train_data_paths, apply_smote=False)            
        elif split == 'test':
            if use_inferred_atck_code:
                self.features, self.lables = self.load_data(self.inferred_test_data_paths, apply_smote=False)
            else:
                self.features, self.lables = self.load_data(self.test_data_paths, apply_smote=False)
        else:
            print("Invalid split. Use 'train' or 'test'")
            
    
    def get_ports(self):
        return self.ddos_ports
    
    
    def get_data(self):
        return self.features.numpy(), self.lables.numpy()
        
    # preload the data as it makes the training much faster (and it easily fits in memory)
    def load_data(self, data_paths, apply_smote=False, undersample=False, sample_factor=4):
        data = []
        
        for path in data_paths:
            event_df = pd.read_csv(path).fillna(0)
            data.append(event_df)

            # Attempt to load corresponding component file
            comp_path = path.replace('_events_extended.csv', '_components.csv')
            ref_event_path = path.replace('_events_extended.csv', '_events.csv')
        
        df = pd.concat(data, ignore_index=True)
            
        
        self.ddos_ports = df[df['Type'] == "DDoS attack"]["Port number"].unique()
        
        #feature_columns = df.columns[:19]  # All except the last column
        
        # comment this line I u need the frequency features
        df = df.loc[:, ~df.columns.str.contains('Ac')]
        feature_columns = df.columns[:-1]
        label_column = df.columns[-1]  # The last column
        
        # Convert categorical labels to numeric using the dictionary
        df[label_column] = df[label_column].map(CAT_TO_NUM_LABELS)
        
        # Check for missing or unknown labels
        if df[label_column].isna().any():
            print(df[label_column].isna().sum(), "missing labels")

        X = df[feature_columns]
        y = df[label_column]
        
        # one-hot encode the categorical features
        if self.use_inferred_atck_code:
            split_labels = X['Inferred Attack Code'].str.split('; ')
            mlb = MultiLabelBinarizer()
            attack_code_ohe = pd.DataFrame(mlb.fit_transform(split_labels),
                                        columns=[f"Inferred Attack Code_{cls}" for cls in mlb.classes_],
                                        index=X.index)

            # Drop the original column and join the new one-hot encoded columns
            X = X.drop(columns=['Inferred Attack Code'])
            X = X.join(attack_code_ohe)
            
            
            #X = pd.get_dummies(X, columns=["Inferred Attack Code"], drop_first=False)        
        
        self.columns = X.columns
        
        cols = list(X.columns)
        
        
        # Normalize the features
        #features = self.normalize(features)
        
        # Convert to PyTorch tensors
        if apply_smote and self.split == 'train':
            class_counts = Counter(y)
            sm = SMOTE(sampling_strategy={0: class_counts[0], 1:50000, 2:50000}, random_state=42)
            sm = SMOTE(sampling_strategy='not majority', random_state=42)
            sm = BorderlineSMOTE(sampling_strategy='not majority', random_state=42)
            X_resampled, y_resampled = sm.fit_resample(X.values, y.values)
            features = torch.tensor(X_resampled, dtype=torch.float32)
            labels = torch.tensor(y_resampled, dtype=torch.long)
        else:
            features = torch.tensor(X.values, dtype=torch.float32)
            labels = torch.tensor(y.values, dtype=torch.long)
        
        
        
        if undersample and self.split == 'train':
            # Undersample the majority class (label=0)
            class_0_indices = np.where(labels.cpu().numpy() == 0)[0]
            class_1_indices = np.where(labels.cpu().numpy() == 1)[0]
            class_2_indices = np.where(labels.cpu().numpy() == 2)[0]

            # Randomly undersample the majority class
            num_class_0_samples = sample_factor*(len(class_1_indices) + len(class_2_indices))  # Same number as the minority class
            class_0_indices_undersampled = np.random.choice(class_0_indices, num_class_0_samples, replace=False)

            # Concatenate indices of class 1, 2, and undersampled class 0
            undersampled_indices = np.concatenate([class_0_indices_undersampled, class_1_indices, class_2_indices])

            # Subset the dataset to include only the sampled indices
            features = features[undersampled_indices]
            labels = labels[undersampled_indices]

        return features, labels


    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        x = self.features[idx]
        y = self.lables[idx]

        return x, y  # y is class index (long)

In [32]:
dataset = DDoSDataset('test', use_inferred_atck_code=True)

In [43]:
# >>> Dataset.
TaskType = Literal['regression', 'binclass', 'multiclass']

# Regression.
task_type: TaskType = 'multiclass'
n_classes = None
dataset = sklearn.datasets.fetch_california_housing()
X_cont: np.ndarray = dataset['data']
Y: np.ndarray = dataset['target']

dataset = DDoSDataset(split='train', use_inferred_atck_code=True)
X_cont = dataset.features.numpy()
Y = dataset.lables.numpy()


# Classification.
n_classes = 3
assert n_classes >= 2
task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

#task_is_regression = task_type == 'regression'

# >>> Continuous features.
X_cont: np.ndarray = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]

# # >>> Categorical features.
# # NOTE: the above datasets do not have categorical features, however,
# # for the demonstration purposes, it is possible to generate them.
# cat_cardinalities = [
#     # NOTE: uncomment the two lines below to add two categorical features.
#     # 4,  # Allowed values: [0, 1, 2, 3].
#     # 7,  # Allowed values: [0, 1, 2, 3, 4, 5, 6].
# ]
# X_cat = (
#     np.column_stack(
#         [np.random.randint(0, c, (len(X_cont),)) for c in cat_cardinalities]
#     )
#     if cat_cardinalities
#     else None
# )

# >>> Labels.
if task_type == 'regression':
    Y = Y.astype(np.float32)
else:
    assert n_classes is not None
    Y = Y.astype(np.int64)
    assert set(Y.tolist()) == set(
        range(n_classes)
    ), 'Classification labels must form the range [0, 1, ..., n_classes - 1]'

# >>> Split the dataset.
all_idx = np.arange(len(Y))
train_idx, val_idx = sklearn.model_selection.train_test_split(
    all_idx, train_size=0.8
)

data_numpy = {
    'train': {'x_cont': X_cont[train_idx], 'y': Y[train_idx]},
    'val': {'x_cont': X_cont[val_idx], 'y': Y[val_idx]},
}

s = 0
# if X_cat is not None:
#     data_numpy['train']['x_cat'] = X_cat[train_idx]
#     data_numpy['val']['x_cat'] = X_cat[val_idx]
#     data_numpy['test']['x_cat'] = X_cat[test_idx]

In [44]:
# Add the Dataset C as the testing class
dataset = DDoSDataset(split='test', use_inferred_atck_code=True)
X_cont = dataset.features.numpy()
Y = dataset.lables.numpy()


# Classification.
n_classes = 3
assert n_classes >= 2
task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

#task_is_regression = task_type == 'regression'

# >>> Continuous features.
X_cont: np.ndarray = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]

# >>> Labels.
if task_type == 'regression':
    Y = Y.astype(np.float32)
else:
    assert n_classes is not None
    Y = Y.astype(np.int64)
    assert set(Y.tolist()) == set(
        range(n_classes)
    ), 'Classification labels must form the range [0, 1, ..., n_classes - 1]'
    
    
data_numpy['test'] = {'x_cont': X_cont, 'y': Y}

# Data preprocessing

In [47]:
# Feature preprocessing.
# NOTE
# The choice between preprocessing strategies depends on a task and a model.

# Simple preprocessing strategy.
# preprocessing = sklearn.preprocessing.StandardScaler().fit(
#     data_numpy[part]['x_cont'][:, continuous_feature_indices]
# )

# Advanced preprocessing strategy.
# The noise is added to improve the output of QuantileTransformer in some cases.

# own
port_columns = [
    'DNS', 'NTP', 'SNMP', 'SSDP', 'CLDAP', 'QUIC', 'RDP', 'CoAP', 
    'HTTP Flood', 'FTP', 'SSH', 'Memcached', 'WS-DD', 'NetBIOS', 'Kubernetes'
]

# gpt
# port_columns = [
#     'DNS', 'RDP', 'TCP', 'SYN', 'UDP', 'CoAP', 'Attack Ports'
# ]

continuous_feature_indices = [
    i for i, col in enumerate(dataset.columns)
    if not col.startswith("Inferred Attack Code_")
]


X_cont_train_numpy = data_numpy['train']['x_cont'][:, continuous_feature_indices]
# for testing the old model
#X_cont_train_numpy = data_numpy['train']['x_cont']
noise = (
    np.random.default_rng(0)
    .normal(0.0, 1e-5, X_cont_train_numpy.shape)
    .astype(X_cont_train_numpy.dtype)
)
preprocessing = sklearn.preprocessing.QuantileTransformer(
    n_quantiles=max(min(len(train_idx) // 30, 1000), 10),
    output_distribution='normal',
    subsample=10**9,
).fit(X_cont_train_numpy + noise)
del X_cont_train_numpy

# for the old model
# Apply the preprocessing to the training, test and validation sets.
# for part in data_numpy:
#     # Transform only continuous features
#     data_numpy[part]['x_cont'] = preprocessing.transform(
#         data_numpy[part]['x_cont']
#     )
#     # Leave port features unchanged

# Apply the preprocessing to the training, test and validation sets.
for part in data_numpy:
    # Transform only continuous features
    data_numpy[part]['x_cont'][:, continuous_feature_indices] = preprocessing.transform(
        data_numpy[part]['x_cont'][:, continuous_feature_indices]
    )
    # Leave port features unchanged

# Label preprocessing.
class RegressionLabelStats(NamedTuple):
    mean: float
    std: float


Y_train = data_numpy['train']['y'].copy()
if task_type == 'regression':
    # For regression tasks, it is highly recommended to standardize the training labels.
    regression_label_stats = RegressionLabelStats(
        Y_train.mean().item(), Y_train.std().item()
    )
    Y_train = (Y_train - regression_label_stats.mean) / regression_label_stats.std
else:
    regression_label_stats = None
    
regression_label_stats = None
Y_train = data_numpy['train']['y'].copy()
    

#  PyTorch settings

In [49]:
regression_label_stats = None
Y_train = data_numpy['train']['y'].copy()

# Device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Convert data to tensors
data = {
    part: {k: torch.as_tensor(v, device=device) for k, v in data_numpy[part].items()}
    for part in data_numpy
}
Y_train = torch.as_tensor(Y_train, device=device)
if task_type == 'regression':
    for part in data:
        data[part]['y'] = data[part]['y'].float()
    Y_train = Y_train.float()

# Automatic mixed precision (AMP)
# torch.float16 is implemented for completeness,
# but it was not tested in the project,
# so torch.bfloat16 is used by default.
amp_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
    if torch.cuda.is_available()
    else None
)
# Changing False to True will result in faster training on compatible hardware.
amp_enabled = False and amp_dtype is not None
grad_scaler = torch.cuda.amp.GradScaler() if amp_dtype is torch.float16 else None  # type: ignore

# torch.compile
compile_model = False

# fmt: off
print(
    f'Device:        {device.type.upper()}'
    f'\nAMP:           {amp_enabled} (dtype: {amp_dtype})'
    f'\ntorch.compile: {compile_model}'
)
# fmt: on

Device:        CUDA
AMP:           False (dtype: torch.bfloat16)
torch.compile: False


# Model

In [53]:
# Choose one of the two configurations below.

# Define the model
arch_type = 'tabm'
bins = None

# arch_type = 'tabm-mini'
# bins = rtdl_num_embeddings.compute_bins(data['train']['x_cont'])

model = Model(
    n_num_features=data['train']['x_cont'].shape[1],
    cat_cardinalities=[],
    n_classes=n_classes,
    backbone={
        'type': 'MLP',
        'n_blocks': 3 if bins is None else 2,
        'd_block': 256,
        'dropout': 0.2,
        'n_blocks': 5
    },
    bins=bins,
    num_embeddings=(
        None
        if bins is None
        else {
            'type': 'PiecewiseLinearEmbeddings',
            'd_embedding': 16,
            'activation': False,
            'version': 'B',
        }
    ),
    arch_type=arch_type,
    k=48,
    share_training_batches=True,
).to(device)
optimizer = torch.optim.AdamW(make_parameter_groups(model), lr=1e-3, weight_decay=3e-4)
# regi: 1e-3
if compile_model:
    # NOTE
    # `torch.compile` is intentionally called without the `mode` argument
    # (mode="reduce-overhead" caused issues during training with torch==2.0.1).
    model = torch.compile(model)
    evaluation_mode = torch.no_grad
else:
    evaluation_mode = torch.inference_mode

In [54]:
@torch.autocast(device.type, enabled=amp_enabled, dtype=amp_dtype)  # type: ignore[code]
def apply_model(part: str, idx: Tensor) -> Tensor:
    return (
        model(
            data[part]['x_cont'][idx],
            data[part]['x_cat'][idx] if 'x_cat' in data[part] else None,
        )
        .squeeze(-1)  # Remove the last dimension for regression tasks.
        .float()
    )


base_loss_fn = F.mse_loss if task_type == 'regression' else F.cross_entropy
#weight = torch.tensor([1.0, 1.0, 1.0], device=device)


def loss_fn(y_pred: Tensor, y_true: Tensor, weight: float = None) -> Tensor:
    # TabM produces k predictions. Each of them must be trained separately.
    # (regression)     y_pred.shape == (batch_size, k)
    # (classification) y_pred.shape == (batch_size, k, n_classes)
    k = y_pred.shape[-1 if task_type == 'regression' else -2]
    return base_loss_fn(
        y_pred.flatten(0, 1),
        y_true.repeat_interleave(k) if model.share_training_batches else y_true,
        weight=weight,
    )


@evaluation_mode()
def evaluate(part: str) -> float:
    model.eval()

    # When using torch.compile, you may need to reduce the evaluation batch size.
    eval_batch_size = 8096
    y_pred: np.ndarray = (
        torch.cat(
            [
                apply_model(part, idx)
                for idx in torch.arange(len(data[part]['y']), device=device).split(
                    eval_batch_size
                )
            ]
        )
        .cpu()
        .numpy()
    )
    if task_type == 'regression':
        # Transform the predictions back to the original label space.
        assert regression_label_stats is not None
        y_pred = y_pred * regression_label_stats.std + regression_label_stats.mean

    # Compute the mean of the k predictions.
    if task_type != 'regression':
        # For classification, the mean must be computed in the probabily space.
        y_pred = scipy.special.softmax(y_pred, axis=-1)
    y_pred = y_pred.mean(1)

    y_true = data[part]['y'].cpu().numpy()
    score = (
        -(sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5)
        if task_type == 'regression'
        else sklearn.metrics.f1_score(y_true, y_pred.argmax(1), average='macro')
    )
    return float(score)  # The higher -- the better.


print(f'Test score before training: {evaluate("test"):.4f}')

Test score before training: 0.0054


# Training

In [55]:
# Tensorboard for training
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


run_name = "tabm_no_frequency_features_inferred_attack_code"
tb_log_dir = f'/home/appuser/src/logs/TabM/{run_name}'
writer = SummaryWriter(log_dir=tb_log_dir)

n_epochs = 1000

loss_weight = torch.tensor([1.0, 1.0, 3.0], device=device)

train_size = len(train_idx)
batch_size = 256
epoch_size = math.ceil(train_size / batch_size)
best = {
    'val': -math.inf,
    'test': -math.inf,
    'epoch': -1,
}
# Early stopping: the training stops when
# there are more than `patience` consequtive bad updates.
patience = 1000
remaining_patience = patience



print('-' * 88 + '\n')
for epoch in range(n_epochs):
    batches = (
        torch.randperm(train_size, device=device).split(batch_size)
        if model.share_training_batches
        else [
            x.transpose(0, 1).flatten()
            for x in torch.rand((model.k, train_size), device=device)
            .argsort(dim=1)
            .split(batch_size, dim=1)
        ]
    )
    epoch_loss = 0.0
    epoch_num = 0
    for batch_idx in tqdm(batches, desc=f'Epoch {epoch}'):
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(apply_model('train', batch_idx), Y_train[batch_idx], weight=loss_weight)
        epoch_loss += loss.item()
        epoch_num += 1
        if grad_scaler is None:
            loss.backward()
            optimizer.step()
        else:
            grad_scaler.scale(loss).backward()  # type: ignore
            grad_scaler.step(optimizer)
            grad_scaler.update()

    val_score = evaluate('val')
    test_score = evaluate('test')
    writer.add_scalar('validation_macro_F1', val_score, epoch+1)
    writer.add_scalar('test_macro_F1', test_score, epoch+1)
    writer.add_scalar('train_loss', epoch_loss/epoch_num, epoch+1)
    print(f'(val) {val_score:.4f} (test) {test_score:.4f}')

    if test_score > best['test']:
        print('🌸 New best epoch! 🌸')
        best = {'val': val_score, 'test': test_score, 'epoch': epoch}
        remaining_patience = patience
        torch.save(model.state_dict(), f"{run_name}.pth")
    else:
        remaining_patience -= 1

    if remaining_patience < 0:
        break

    print()
    if (epoch+1) % 30 == 0:
        loss_weight += torch.tensor([0.0, 0.0, 0.0], device=device)


print('\n\nResult:')
print(best)

----------------------------------------------------------------------------------------



Epoch 0: 100%|██████████| 828/828 [00:07<00:00, 113.69it/s]


(val) 0.7510 (test) 0.7353
🌸 New best epoch! 🌸



Epoch 1: 100%|██████████| 828/828 [00:07<00:00, 113.05it/s]


(val) 0.7890 (test) 0.7722
🌸 New best epoch! 🌸



Epoch 2: 100%|██████████| 828/828 [00:07<00:00, 113.09it/s]


(val) 0.8129 (test) 0.7184



Epoch 3: 100%|██████████| 828/828 [00:07<00:00, 113.83it/s]


(val) 0.8158 (test) 0.7351



Epoch 4: 100%|██████████| 828/828 [00:07<00:00, 112.94it/s]


(val) 0.8310 (test) 0.7617



Epoch 5: 100%|██████████| 828/828 [00:07<00:00, 113.19it/s]


(val) 0.8326 (test) 0.7475



Epoch 6: 100%|██████████| 828/828 [00:07<00:00, 113.69it/s]


(val) 0.8339 (test) 0.7761
🌸 New best epoch! 🌸



Epoch 7: 100%|██████████| 828/828 [00:07<00:00, 113.42it/s]


(val) 0.8320 (test) 0.7491



Epoch 8: 100%|██████████| 828/828 [00:07<00:00, 114.50it/s]


(val) 0.8317 (test) 0.7448



Epoch 9: 100%|██████████| 828/828 [00:07<00:00, 114.30it/s]


(val) 0.8361 (test) 0.7882
🌸 New best epoch! 🌸



Epoch 10: 100%|██████████| 828/828 [00:07<00:00, 114.17it/s]


(val) 0.8328 (test) 0.7286



Epoch 11: 100%|██████████| 828/828 [00:07<00:00, 114.28it/s]


(val) 0.8317 (test) 0.7568



Epoch 12: 100%|██████████| 828/828 [00:07<00:00, 113.71it/s]


(val) 0.8321 (test) 0.8173
🌸 New best epoch! 🌸



Epoch 13: 100%|██████████| 828/828 [00:07<00:00, 114.14it/s]


(val) 0.8312 (test) 0.7319



Epoch 14: 100%|██████████| 828/828 [00:07<00:00, 113.22it/s]


(val) 0.8367 (test) 0.7311



Epoch 15: 100%|██████████| 828/828 [00:07<00:00, 113.82it/s]


(val) 0.8388 (test) 0.8355
🌸 New best epoch! 🌸



Epoch 16: 100%|██████████| 828/828 [00:07<00:00, 113.27it/s]


(val) 0.8428 (test) 0.7643



Epoch 17: 100%|██████████| 828/828 [00:07<00:00, 114.01it/s]


(val) 0.8428 (test) 0.7677



Epoch 18: 100%|██████████| 828/828 [00:07<00:00, 114.44it/s]


(val) 0.8437 (test) 0.7368



Epoch 19: 100%|██████████| 828/828 [00:07<00:00, 113.94it/s]


(val) 0.8384 (test) 0.7439



Epoch 20: 100%|██████████| 828/828 [00:07<00:00, 114.00it/s]


(val) 0.8383 (test) 0.7483



Epoch 21: 100%|██████████| 828/828 [00:07<00:00, 114.11it/s]


(val) 0.8452 (test) 0.7925



Epoch 22: 100%|██████████| 828/828 [00:07<00:00, 114.22it/s]


(val) 0.8388 (test) 0.7670



Epoch 23: 100%|██████████| 828/828 [00:07<00:00, 113.17it/s]


(val) 0.8439 (test) 0.7988



Epoch 24: 100%|██████████| 828/828 [00:07<00:00, 113.27it/s]


(val) 0.8408 (test) 0.7604



Epoch 25: 100%|██████████| 828/828 [00:07<00:00, 113.58it/s]


(val) 0.8432 (test) 0.7915



Epoch 26: 100%|██████████| 828/828 [00:07<00:00, 113.56it/s]


(val) 0.8516 (test) 0.8294



Epoch 27: 100%|██████████| 828/828 [00:07<00:00, 113.34it/s]


(val) 0.8492 (test) 0.7582



Epoch 28: 100%|██████████| 828/828 [00:07<00:00, 113.32it/s]


(val) 0.8440 (test) 0.8307



Epoch 29: 100%|██████████| 828/828 [00:07<00:00, 113.83it/s]


(val) 0.8516 (test) 0.7951



Epoch 30: 100%|██████████| 828/828 [00:07<00:00, 112.80it/s]


(val) 0.8481 (test) 0.7924



Epoch 31: 100%|██████████| 828/828 [00:07<00:00, 113.48it/s]


(val) 0.8425 (test) 0.7912



Epoch 32: 100%|██████████| 828/828 [00:07<00:00, 113.64it/s]


(val) 0.8450 (test) 0.8147



Epoch 33: 100%|██████████| 828/828 [00:07<00:00, 113.56it/s]


(val) 0.8512 (test) 0.8388
🌸 New best epoch! 🌸



Epoch 34: 100%|██████████| 828/828 [00:07<00:00, 113.93it/s]


(val) 0.8567 (test) 0.8242



Epoch 35: 100%|██████████| 828/828 [00:07<00:00, 114.21it/s]


(val) 0.8448 (test) 0.7679



Epoch 36: 100%|██████████| 828/828 [00:07<00:00, 113.84it/s]


(val) 0.8516 (test) 0.8233



Epoch 37: 100%|██████████| 828/828 [00:07<00:00, 114.16it/s]


(val) 0.8474 (test) 0.7501



Epoch 38: 100%|██████████| 828/828 [00:07<00:00, 113.65it/s]


(val) 0.8555 (test) 0.8175



Epoch 39: 100%|██████████| 828/828 [00:07<00:00, 113.84it/s]


(val) 0.8550 (test) 0.8118



Epoch 40: 100%|██████████| 828/828 [00:07<00:00, 113.53it/s]


(val) 0.8512 (test) 0.8013



Epoch 41: 100%|██████████| 828/828 [00:07<00:00, 113.93it/s]


(val) 0.8411 (test) 0.7858



Epoch 42: 100%|██████████| 828/828 [00:07<00:00, 114.18it/s]


(val) 0.8429 (test) 0.7799



Epoch 43: 100%|██████████| 828/828 [00:07<00:00, 114.25it/s]


(val) 0.8489 (test) 0.8536
🌸 New best epoch! 🌸



Epoch 44: 100%|██████████| 828/828 [00:07<00:00, 114.27it/s]


(val) 0.8518 (test) 0.8118



Epoch 45: 100%|██████████| 828/828 [00:07<00:00, 113.47it/s]


(val) 0.8466 (test) 0.7712



Epoch 46: 100%|██████████| 828/828 [00:07<00:00, 115.06it/s]


(val) 0.8401 (test) 0.7478



Epoch 47: 100%|██████████| 828/828 [00:07<00:00, 113.46it/s]


(val) 0.8499 (test) 0.8401



Epoch 48: 100%|██████████| 828/828 [00:07<00:00, 113.44it/s]


(val) 0.8558 (test) 0.8157



Epoch 49: 100%|██████████| 828/828 [00:07<00:00, 113.68it/s]


(val) 0.8520 (test) 0.8052



Epoch 50: 100%|██████████| 828/828 [00:07<00:00, 113.90it/s]


(val) 0.8561 (test) 0.8045



Epoch 51: 100%|██████████| 828/828 [00:07<00:00, 113.54it/s]


(val) 0.8482 (test) 0.8045



Epoch 52: 100%|██████████| 828/828 [00:07<00:00, 113.45it/s]


(val) 0.8507 (test) 0.8222



Epoch 53: 100%|██████████| 828/828 [00:07<00:00, 113.15it/s]


(val) 0.8519 (test) 0.7936



Epoch 54: 100%|██████████| 828/828 [00:07<00:00, 112.89it/s]


(val) 0.8523 (test) 0.8218



Epoch 55: 100%|██████████| 828/828 [00:07<00:00, 110.88it/s]


(val) 0.8485 (test) 0.8004



Epoch 56: 100%|██████████| 828/828 [00:07<00:00, 111.12it/s]


(val) 0.8529 (test) 0.7683



Epoch 57: 100%|██████████| 828/828 [00:07<00:00, 109.57it/s]


(val) 0.8554 (test) 0.7767



Epoch 58: 100%|██████████| 828/828 [00:07<00:00, 107.25it/s]


(val) 0.8525 (test) 0.8025



Epoch 59: 100%|██████████| 828/828 [00:07<00:00, 110.97it/s]


(val) 0.8548 (test) 0.7514



Epoch 60: 100%|██████████| 828/828 [00:07<00:00, 108.00it/s]


(val) 0.8457 (test) 0.7769



Epoch 61: 100%|██████████| 828/828 [00:07<00:00, 109.68it/s]


(val) 0.8576 (test) 0.7782



Epoch 62: 100%|██████████| 828/828 [00:07<00:00, 113.02it/s]


(val) 0.8517 (test) 0.8058



Epoch 63: 100%|██████████| 828/828 [00:07<00:00, 112.27it/s]


(val) 0.8538 (test) 0.8206



Epoch 64: 100%|██████████| 828/828 [00:07<00:00, 113.13it/s]


(val) 0.8562 (test) 0.8032



Epoch 65: 100%|██████████| 828/828 [00:07<00:00, 113.70it/s]


(val) 0.8540 (test) 0.7945



Epoch 66:  94%|█████████▍| 777/828 [00:06<00:00, 112.55it/s]


KeyboardInterrupt: 

In [56]:
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score
)

In [13]:
# PREV BEST SETUP

# Define the model
arch_type = 'tabm'
bins = None

# arch_type = 'tabm-mini'
# bins = rtdl_num_embeddings.compute_bins(data['train']['x_cont'])

model = Model(
    n_num_features=data['train']['x_cont'].shape[1],
    cat_cardinalities=[],
    n_classes=n_classes,
    backbone={
        'type': 'MLP',
        'n_blocks': 3 if bins is None else 2,
        'd_block': 256,
        'dropout': 0.2,
        'n_blocks': 5
    },
    bins=bins,
    num_embeddings=(
        None
        if bins is None
        else {
            'type': 'PiecewiseLinearEmbeddings',
            'd_embedding': 16,
            'activation': False,
            'version': 'B',
        }
    ),
    arch_type=arch_type,
    k=48,
    share_training_batches=True,
).to(device)

In [None]:
# DYNMIC WEIHGT SETUP
arch_type = 'tabm'
bins = None

# TabM-mini with the piecewise-linear embeddings.
# arch_type = 'tabm-mini'
# bins = rtdl_num_embeddings.compute_bins(data['train']['x_cont'])

# arch_type = 'tabm-packed'
# bins = rtdl_num_embeddings.compute_bins(data['train']['x_cont'])

# d_block: 512
# n_blocks: 3

model = Model(
    n_num_features=data['train']['x_cont'].shape[1],
    cat_cardinalities=[],
    n_classes=n_classes,
    backbone={
        'type': 'MLP',
        'n_blocks': 3 if bins is None else 2,
        'd_block': 1024,
        'dropout': 0.05,
        'n_blocks': 2
    },
    bins=bins,
    num_embeddings=(
        None
        if bins is None
        else {
            'type': 'PiecewiseLinearEmbeddings',
            'd_embedding': 32,
            'activation': False,
            'version': 'B',
        }
    ),
    arch_type=arch_type,
    k=48,
    share_training_batches=True,
).to(device)

In [11]:
data['train']['x_cont'].shape[1]

19

In [57]:
@torch.autocast(device.type, enabled=amp_enabled, dtype=amp_dtype)  # type: ignore[code]
def apply_model(part: str, idx: Tensor) -> Tensor:
    return (
        model(
            data[part]['x_cont'][idx],
            data[part]['x_cat'][idx] if 'x_cat' in data[part] else None,
        )
        .squeeze(-1)  # Remove the last dimension for regression tasks.
        .float()
    )

In [58]:
from time import time   

# Inference on the test dataset
#model.load_state_dict(torch.load('/home/appuser/src/visualization/tabm_no_frequ_features_extended_more_weight.pth'))
model.load_state_dict(torch.load(f"{run_name}.pth"))
model.eval()

part = "test"

eval_batch_size = 12
# y_pred: np.ndarray = (
#     torch.cat(
#         [
#             apply_model(part, idx).cpu()
#             for idx in torch.arange(len(data[part]['y']), device=device).split(
#                 eval_batch_size
#             )
#         ]
#     )
#     .numpy()
# )

num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params}")

# model.cpu()

# single_sample = data[part]['x_cont'][0].unsqueeze(0)
# for single_sample in data[part]['x_cont']:
#     single_sample = single_sample.unsqueeze(0)
#     start = time()
#     preds = model(single_sample.cpu())
#     duration = (time() - start)*10**3
#     print(f"Time taken for single sample: {duration:.2f} ms")

y_pred_list = []
for idx in tqdm(torch.arange(len(data[part]['y']), device=device).split(eval_batch_size)):
    with torch.no_grad():
        preds = apply_model(part, idx).cpu()
        probs = scipy.special.softmax(preds.numpy(), axis=-1)
        averaged = probs.mean(1)  # shape: [B, C]
        preds_class = np.argmax(averaged, axis=1)
        y_pred_list.append(preds_class)

y_pred = np.concatenate(y_pred_list)

y_test = data[part]['y'].cpu().numpy()

# Overall accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.4f}")

# Classification report (includes precision, recall, F1 per class + macro/micro)
report = classification_report(y_test, y_pred, output_dict=True)
report_df = pd.DataFrame(report).transpose()
print("\nClassification Report:\n", report_df)

# F1 scores
f1_micro = f1_score(y_test, y_pred, average='micro')
f1_macro = f1_score(y_test, y_pred, average='macro')
print(f"\nF1 (Micro): {f1_micro:.4f}")
print(f"F1 (Macro): {f1_macro:.4f}")

# Class-wise accuracy (same as recall per class)
class_wise_accuracy = report_df.loc[[str(i) for i in np.unique(y_test)], "recall"]
print("\nClass-wise Accuracy (Recall):\n", class_wise_accuracy)

# Confusion Matrix
conf_matrix = confusion_matrix(y_test, y_pred)
print("\nConfusion Matrix:\n", conf_matrix)

Total parameters: 479392


  0%|          | 0/10834 [00:00<?, ?it/s]

100%|██████████| 10834/10834 [00:07<00:00, 1425.33it/s]


Accuracy: 0.9906

Classification Report:
               precision    recall  f1-score      support
0              0.996434  0.994424  0.995428  125892.0000
1              0.873532  0.950524  0.910403    3052.0000
2              0.659615  0.650237  0.654893    1055.0000
accuracy       0.990600  0.990600  0.990600       0.9906
macro avg      0.843194  0.865062  0.853575  129999.0000
weighted avg   0.990815  0.990600  0.990668  129999.0000

F1 (Micro): 0.9906
F1 (Macro): 0.8536

Class-wise Accuracy (Recall):
 0    0.994424
1    0.950524
2    0.650237
Name: recall, dtype: float64

Confusion Matrix:
 [[125190    399    303]
 [   100   2901     51]
 [   348     21    686]]
