# Multiclass classification on ToN http dataset with encoder clustering

### Imports

In [None]:
import torch
import logging
import os
import joblib

from src.utilities.config_manager import ConfigManager
from src.utilities.io_handler import load_data
from src.utilities.dataset_utils import *
from pytorch_tabnet.tab_model import TabNetClassifier
import torch.nn as nn
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
import matplotlib.pyplot as plt


### Configuration

In [None]:
# Seed reused by the stratified splits further down so runs are reproducible.
RANDOM_STATE = 42

# Load the dataset-specific configuration once; its sections are read below.
DATASET_CONFIG_PATH = './config/ton_config.json'
ConfigManager.load_config(DATASET_CONFIG_PATH)

paths_config, data_cols_config = (
    ConfigManager.get_section(section) for section in ("paths", "data_columns")
)

# Filesystem locations, taken from the "paths" section.
DATA_PATH, OUTPUT_DIR = (
    paths_config.get(key) for key in ("dataset_path", "output_dir")
)

# Column roles, taken from the "data_columns" section: the label column and
# the two groups of feature columns.
TARGET_COL, NUMERICAL_COLS, CATEGORICAL_COLS = (
    data_cols_config.get(key)
    for key in ("target_category_column", "numerical_cols", "categorical_cols")
)

### Dataset loading and splitting

In [None]:
# Load the dataset and keep only the configured feature columns plus the label.
df = load_data(DATA_PATH)

keep_cols = CATEGORICAL_COLS + NUMERICAL_COLS + [TARGET_COL]
df = df[keep_cols].copy()

# Stratified 70/15/15 split: 30% is held out first, then halved into the
# validation and test sets. Both splits stratify on the target column.
# NOTE(review): train_test_split is not imported explicitly in this file;
# presumably it is re-exported by `from src.utilities.dataset_utils import *`
# -- confirm.
train_df, temp_df = train_test_split(
    df, test_size=0.3, random_state=RANDOM_STATE, stratify=df[TARGET_COL]
)
valid_df, test_df = train_test_split(
    temp_df, test_size=0.5, random_state=RANDOM_STATE, stratify=temp_df[TARGET_COL]
)

# Persist the test split before any scaling/encoding is applied to it.
# The target directory is created first so to_csv does not fail on a fresh
# checkout where it does not exist yet.
test_set_path = './resources/dataset/test_set_ton.csv'
os.makedirs(os.path.dirname(test_set_path), exist_ok=True)
test_df.to_csv(test_set_path, index=False)

### Preprocessing
- StandardScaler per le features numeriche; nonostante TabNet accetti features numeriche raw, normalizzare i dati aumenta le performance del modello

- LabelEncoder per le features categoriche; sarebbe meglio usare OrdinalEncoder, ma questo è un esperimento. Inoltre, le categorie non presenti nel train set vengono mappate su '_UNK'.

In [None]:
# --- Numerical features ---
# The scaler is fitted on the training split only and then applied to all
# three splits, so validation/test statistics never influence the scaling.
scaler = StandardScaler().fit(train_df[NUMERICAL_COLS])
for frame in (train_df, valid_df, test_df):
    frame[NUMERICAL_COLS] = scaler.transform(frame[NUMERICAL_COLS])


# --- Categorical features ---
# One LabelEncoder per column, fitted on the training split. An extra "_UNK"
# class is appended so that values absent from training can still be encoded.
categorical_dims = {}
encoders = {}
for col in CATEGORICAL_COLS:
    col_encoder = LabelEncoder()
    col_encoder.fit(train_df[col])
    col_encoder.classes_ = np.append(col_encoder.classes_, "_UNK")

    train_df[col] = col_encoder.transform(train_df[col])
    for frame in (valid_df, test_df):
        # Replace values the encoder does not know with the "_UNK" sentinel.
        known = frame[col].isin(col_encoder.classes_)
        frame[col] = col_encoder.transform(frame[col].where(known, "_UNK"))

    # Cardinality includes the "_UNK" slot.
    categorical_dims[col] = len(col_encoder.classes_)
    encoders[col] = col_encoder

# --- Target ---
# Labels are encoded as integers using the classes seen in the training split.
y_le = LabelEncoder()
y_le.fit(train_df[TARGET_COL])
for frame in (train_df, valid_df, test_df):
    frame[TARGET_COL] = y_le.transform(frame[TARGET_COL])

### Some parameters

In [None]:
# Columns the model is allowed to see: every configured numerical or
# categorical column.
model_cols = NUMERICAL_COLS + CATEGORICAL_COLS

# Any other column of df is kept out of the feature matrix.
unused_feat = [col for col in df.columns if col not in model_cols]

# Feature order follows the column order of df.
excluded = unused_feat + [TARGET_COL]
features = [col for col in df.columns if col not in excluded]

# Positions of the categorical columns inside `features`, and their
# cardinalities in the same order.
cat_idxs = [idx for idx, name in enumerate(features) if name in CATEGORICAL_COLS]
cat_dims = [categorical_dims[features[idx]] for idx in cat_idxs]

# Use the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f"Used features: {features}")
print(f"Unused features: {unused_feat}")

### Model instance

In [None]:
# Embedding width per categorical feature: half its cardinality (rounded up),
# capped at 50.
cat_emb_dims = [min(50, (dim + 1) // 2) for dim in cat_dims]

# All TabNet hyperparameters collected in one place before building the model.
tabnet_params = {
    # Architecture
    "n_d": 64,
    "n_a": 64,
    "n_steps": 5,
    "gamma": 1.8,
    # Categorical embeddings
    "cat_idxs": cat_idxs,
    "cat_dims": cat_dims,
    "cat_emb_dim": cat_emb_dims,
    # Regularisation / numerical stability
    "lambda_sparse": 1e-3,
    "momentum": 0.02,
    "clip_value": 1.5,
    "epsilon": 1e-15,
    # Optimiser: Adam, with the learning rate multiplied by 0.95 every 20 epochs
    "optimizer_fn": torch.optim.Adam,
    "optimizer_params": {"lr": 1e-2},
    "scheduler_fn": torch.optim.lr_scheduler.StepLR,
    "scheduler_params": {"gamma": 0.95, "step_size": 20},
    "device_name": device,
}
clf = TabNetClassifier(**tabnet_params)

### Sets prep

In [None]:
# Turn each preprocessed split into a (feature matrix, label vector) pair of
# NumPy arrays, with feature columns in the order given by `features`.
(X_train, y_train), (X_valid, y_valid), (X_test, y_test) = (
    (frame[features].values, frame[TARGET_COL].values)
    for frame in (train_df, valid_df, test_df)
)

#### Computing class weights

In [None]:
from sklearn.utils.class_weight import compute_class_weight

# 'balanced' class weights computed from the training labels only.
train_classes = np.unique(y_train)
cw = compute_class_weight(class_weight='balanced', classes=train_classes, y=y_train)

# Rescale so the weights average to 1: ratios between classes are kept while
# avoiding huge differences in absolute magnitude.
cw = cw / cw.mean()

# Weighted cross-entropy used as the training loss; the weight tensor is
# created on the same device the model runs on.
weights_tensor = torch.tensor(cw, dtype=torch.float, device=device)
loss_fn = nn.CrossEntropyLoss(weight=weights_tensor)

### Model training

In [None]:
# Both splits are evaluated every epoch and logged under the prefixes
# 'train' and 'valid'.
eval_sets = [(X_train, y_train), (X_valid, y_valid)]
eval_names = ['train', 'valid']

# NOTE(review): pytorch-tabnet drives early stopping with the last metric of
# the last eval set (here presumably 'valid_accuracy') -- confirm this is the
# intended criterion given the class imbalance.
clf.fit(
    X_train=X_train,
    y_train=y_train,
    eval_set=eval_sets,
    eval_name=eval_names,
    eval_metric=['balanced_accuracy', 'accuracy'],
    loss_fn=loss_fn,
    max_epochs=100,
    patience=15,
    batch_size=2048,
    virtual_batch_size=256,
)

# clf.history is pytorch-tabnet's History callback; the per-epoch logs are the
# dict stored on its `.history` attribute. Fall back to the object itself in
# case it already is a plain mapping.
history_log = getattr(clf.history, "history", clf.history)
df_hist = pd.DataFrame(history_log)

# The training loss is logged under 'loss' (there is no 'train_loss' key);
# metrics are logged as '<eval_name>_<metric>'. Plot the validation curve too,
# so the figure matches its "Training vs Validation" title.
ax = df_hist[[
    'loss',
    'train_balanced_accuracy',
    'valid_balanced_accuracy',
]].plot(figsize=(10, 5), grid=True)
ax.set_xlabel('Epoch')
ax.set_ylabel('Val')
ax.set_title('Training vs Validation: Loss & Balanced Accuracy')
plt.legend(loc='best')
plt.show()

### Extracting encoder output

Prendiamo gli output di tutti gli step dell'encoder e li sommiamo, esattamente come fa TabNet, senza applicare l'ultimo layer lineare che serve a fare le previsioni.

In [None]:
# Switch the network to evaluation mode and disable gradient tracking: this
# cell only reads activations.
clf.network.eval()
with torch.no_grad():
    X_tensor = torch.tensor(X_test, device=clf.device, dtype=torch.float)
    # The classifier's network is a wrapper made of an embedding layer
    # (`embedder`) and the TabNet core (`tabnet`). The encoder belongs to the
    # core and expects inputs whose categorical columns are already embedded,
    # so the raw matrix goes through the embedder first.
    X_embedded = clf.network.embedder(X_tensor)
    steps_output, _ = clf.network.tabnet.encoder(X_embedded)
    # Sum the per-step outputs, as TabNet does before its final linear mapping.
    encoded = torch.sum(torch.stack(steps_output, dim=0), dim=0)
# Stored under its own name so the `features` column list defined earlier is
# not overwritten (re-running the earlier cells would otherwise break).
Z = encoded.cpu().numpy()

### Dimension reduction

In [None]:
from sklearn.decomposition import PCA

# Project the encoder representation onto its first two principal components
# for plotting. The seed comes from the notebook-wide RANDOM_STATE instead of
# a duplicated literal.
Z_2d = PCA(n_components=2, random_state=RANDOM_STATE).fit_transform(Z)

### Clustering and plotting

In [None]:
from sklearn.cluster import KMeans

# Cluster the 2-D PCA projection (not the full encoder output) into three
# groups. The seed comes from the notebook-wide RANDOM_STATE instead of a
# duplicated literal.
labels = KMeans(n_clusters=3, random_state=RANDOM_STATE).fit_predict(Z_2d)

# Scatter plot of the projection, coloured by cluster assignment.
plt.scatter(Z_2d[:, 0], Z_2d[:, 1], c=labels, alpha=0.7)
plt.show()

### Model evaluation

In [None]:
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    classification_report,
    confusion_matrix,
)

# Predict the (integer-encoded) labels of the held-out test split.
y_pred = clf.predict(X_test)

# Scalar scores first, then the per-class breakdowns.
acc, bal_acc = (
    score(y_test, y_pred) for score in (accuracy_score, balanced_accuracy_score)
)
report = classification_report(y_test, y_pred)
cm = confusion_matrix(y_test, y_pred)

print(f"Test Accuracy: {acc:.4f}")
print(f"Test Balanced Accuracy: {bal_acc:.4f}")
print(f"Classification Report:\n{report}")
print(f"Confusion Matrix:\n{cm}")