In [None]:
import nn_utils
import builders
import importlib

from ray import tune
import optuna
from ray.tune.suggest.optuna import OptunaSearch
import torch
import torch.nn as nn

from ray.tune.schedulers import AsyncHyperBandScheduler

import inspect
import argparse

import skorch
import numpy as np


from tab_transformer_pytorch import TabTransformer, FTTransformer

from sklearn import base, pipeline, preprocessing, compose, metrics

In [None]:

# Experiment selectors. Together they determine the configuration module
# "<dataset>.<aggregator_str>.config" and the checkpoint directory.
# NOTE: the name `dataset` is rebound further down to the instantiated
# dataset configuration object.
dataset, aggregator_str = "adult", "cls"

In [None]:
# Dotted import path of the experiment configuration module, and the
# directory handed to Ray Tune as `local_dir`.
MODULE = ".".join((dataset, aggregator_str, "config"))
CHECKPOINT_DIR = "./{}/{}/checkpoint".format(dataset, aggregator_str)

SEED = 11        # passed to nn_utils.get_default_callbacks
N_SAMPLES = 30   # number of configurations sampled by tune.run

BATCH_SIZE = 128
MAX_EPOCHS = 500
EARLY_STOPPING = 15    # used as the scheduler grace period
MAX_CHECKPOINTS = 10   # used as keep_checkpoints_num

# Switched to True later on when the dataset has more than two labels.
multiclass = False

In [None]:
def get_class_from_type(module, class_type):
    """Return the first strict subclass of `class_type` exposed by `module`.

    Attributes are visited in the order given by `dir(module)`.

    Args:
        module: Object (typically an imported module) whose attributes are scanned.
        class_type: Base class the returned class must derive from.

    Returns:
        The first attribute that is a class deriving from `class_type` and is
        not `class_type` itself, or None when there is no such attribute.
    """
    for attr in dir(module):
        clazz = getattr(module, attr)
        # isclass() must come first: issubclass() raises TypeError on non-classes.
        # The base class is excluded by identity; comparing str() representations
        # would also reject a distinct subclass that merely prints the same.
        if inspect.isclass(clazz) and issubclass(clazz, class_type) and clazz is not class_type:
            return clazz

    return None

In [None]:
def get_params_startswith(params, prefix):
    """Move every entry whose key starts with `prefix` out of `params`.

    Args:
        params: Dict with string keys. It is modified in place: the matching
            entries are removed from it.
        prefix: Leading string that selects the entries to extract.

    Returns:
        A new dict holding the extracted values, keyed by the original key
        with the leading `prefix` removed.
    """
    # Materialise the matching keys first: `params` is mutated while extracting.
    keys = [k for k in params if k.startswith(prefix)]

    # Slice the prefix off instead of str.replace(), which would also delete
    # later occurrences of the prefix inside the key.
    return {k[len(prefix):]: params.pop(k) for k in keys}

In [None]:
def _instantiate_config(config_module, base_type, label):
    """Instantiate the subclass of `base_type` found in `config_module`.

    Args:
        config_module: Imported configuration module to search.
        base_type: Base class the configuration class derives from.
        label: Human-readable name used in the error message.

    Returns:
        An instance of the class found, built with no arguments.

    Raises:
        ValueError: If `config_module` exposes no subclass of `base_type`.
    """
    config_class = get_class_from_type(config_module, base_type)
    if config_class is None:
        raise ValueError(f"{label} configuration not found")
    return config_class()


# Import the experiment configuration module and build the three
# configuration objects it is expected to define.
module = importlib.import_module(MODULE)

# NOTE: `dataset` held the dataset name until here; from now on it is the
# dataset configuration object.
dataset = _instantiate_config(module, builders.DatasetConfig, "Dataset")
transformer_config = _instantiate_config(module, builders.TransformerConfig, "Transformer")
search_space_config = _instantiate_config(module, builders.SearchSpaceConfig, "Search space")

In [None]:
def get_default_preprocessing_pipeline(categorical_cols, numerical_cols):
    """Build the (unfitted) default feature preprocessing pipeline.

    Categorical columns are ordinal-encoded, with categories unseen during
    fitting mapped to -1. Numerical columns go through a MinMaxScaler.
    Columns listed in neither group are passed through unchanged, and the
    result is finally handed to nn_utils.DTypeTransformer(np.float32).

    Args:
        categorical_cols: Column selector for the categorical features.
        numerical_cols: Column selector for the numerical features.

    Returns:
        A sklearn Pipeline with the steps "columns_transformer" and
        "dtype_transform".
    """
    ordinal_encoder = preprocessing.OrdinalEncoder(
        handle_unknown="use_encoded_value",
        unknown_value=-1,
    )
    categorical_branch = pipeline.Pipeline(steps=[("label", ordinal_encoder)])
    numerical_branch = pipeline.Pipeline(steps=[("scaler", preprocessing.MinMaxScaler())])

    column_router = compose.ColumnTransformer(
        transformers=[
            ("categorical_transformer", categorical_branch, categorical_cols),
            ("numerical_transformer", numerical_branch, numerical_cols),
        ],
        remainder="passthrough",  # pass through features not listed above
    )

    return pipeline.Pipeline([
        ("columns_transformer", column_router),
        ("dtype_transform", nn_utils.DTypeTransformer(np.float32)),
    ])


In [None]:

# Download the dataset only when it is not reported as already present.
if not dataset.exists():
    dataset.download()

# NOTE(review): loaded with seed=None even though SEED is defined in the
# configuration cell; confirm whether the split is meant to be seeded.
dataset.load(seed=None)

# Unfitted preprocessing pipeline for this dataset's column layout.
preprocessor = get_default_preprocessing_pipeline(
    categorical_cols=dataset.get_categorical_columns(),
    numerical_cols=dataset.get_numerical_columns(),
)

In [None]:
# Retrieve the three splits from the dataset object.
train_features, train_labels = dataset.get_train_data()
val_features, val_labels = dataset.get_val_data()
test_features, test_labels = dataset.get_test_data()

total_examples = sum(
    split.shape[0] for split in (train_features, val_features, test_features)
)

# Report each split's size and its fraction of all examples.
print(f"Training examples {train_features.shape[0]} ({train_features.shape[0] / total_examples})")
print(f"Validation examples {val_features.shape[0]} ({val_features.shape[0] / total_examples})")
print(f"Test examples {test_features.shape[0]} ({test_features.shape[0] / total_examples})")

In [None]:
# Fit the preprocessing on the training split only, then apply it to all splits.
# NOTE: the raw feature variables are overwritten, so this cell is not
# idempotent; re-run the data loading cell before re-running this one.
preprocessor = preprocessor.fit(train_features, train_labels)

train_features, val_features, test_features = (
    preprocessor.transform(split)
    for split in (train_features, val_features, test_features)
)

# Merge train and validation data into single arrays; `indices` holds one
# entry per input split, used later to rebuild the predefined split.
all_features, all_labels, indices = nn_utils.join_data(
    [train_features, val_features],
    [train_labels, val_labels],
)
train_indices, val_indices = indices[0], indices[1]

# More than two labels: one output per label with cross entropy.
# Otherwise: a single logit trained with BCEWithLogitsLoss.
if dataset.get_n_labels() > 2:
    n_labels = dataset.get_n_labels()
    multiclass = True
    criterion = torch.nn.CrossEntropyLoss
else:
    n_labels = 1
    criterion = torch.nn.BCEWithLogitsLoss

In [None]:
def build_transformer_model(
    train_indices,
    validation_indices,
    callbacks,
    n_categories, # List of number of categories
    n_numerical, # Number of numerical features
    n_head, # Number of heads per layer
    n_hid, # Size of the MLP inside each transformer encoder layer
    n_layers, # Number of transformer encoder layers
    n_output, # The number of output neurons
    embed_dim,
    dropout=0.1, # Used dropout
    aggregator=None, # The aggregator for output vectors before decoder
    categorical_preprocessor=None,
    numerical_preprocessor=None,
    need_weights=False,
    numerical_passthrough=False,
    decoder_hidden_units=None,
    decoder_activation_fn=None,
    **kwargs
    ):
    """Build a skorch classifier wrapping a TabTransformer module.

    The architecture is taken from the arguments instead of being hard-coded,
    so the builder also works for datasets with a different column layout.

    Args:
        train_indices: Row indices of the training part of the predefined split.
        validation_indices: Row indices of the validation part of the split.
        callbacks: skorch callbacks forwarded to the classifier.
        n_categories: Number of unique values of each categorical feature.
        n_numerical: Number of continuous features.
        n_head: Attention heads per layer.
        n_hid: Not used by this builder (TabTransformer exposes no such size).
        n_layers: Transformer depth.
        n_output: Number of output neurons (`dim_out`).
        embed_dim: Embedding dimension (`dim`).
        dropout: Used for both the attention and the feed-forward dropout.
        aggregator, categorical_preprocessor, numerical_preprocessor,
        need_weights, numerical_passthrough, decoder_hidden_units: Accepted
            but not used here; presumably kept so this builder can be called
            with the same keywords as nn_utils.build_transformer_model.
        decoder_activation_fn: Activation of the final MLP; ReLU when None.
        **kwargs: Forwarded to skorch.NeuralNetClassifier.

    Returns:
        An unfitted skorch.NeuralNetClassifier.
    """
    if decoder_activation_fn is None:
        decoder_activation_fn = torch.nn.ReLU()

    module = TabTransformer(
        categories=tuple(n_categories),   # number of unique values within each category
        num_continuous=n_numerical,       # number of continuous values
        dim=embed_dim,                    # embedding dimension
        dim_out=n_output,                 # number of output logits
        depth=n_layers,                   # number of transformer layers
        heads=n_head,                     # attention heads per layer
        attn_dropout=dropout,             # post-attention dropout
        ff_dropout=dropout,               # feed forward dropout
        mlp_hidden_mults=(4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
        mlp_act=decoder_activation_fn,    # activation for final mlp
    )

    # One predefined (train, validation) split instead of an internal random split.
    model = skorch.NeuralNetClassifier(
        module=module,
        train_split=skorch.dataset.CVSplit(((train_indices, validation_indices),)),
        callbacks=callbacks,
        **kwargs
    )

    return model

In [None]:
def build_transformer_model_ft(
    train_indices,
    validation_indices,
    callbacks,
    n_categories, # List of number of categories
    n_numerical, # Number of numerical features
    n_head, # Number of heads per layer
    n_hid, # Size of the MLP inside each transformer encoder layer
    n_layers, # Number of transformer encoder layers
    n_output, # The number of output neurons
    embed_dim,
    dropout=0.1, # Used dropout
    aggregator=None, # The aggregator for output vectors before decoder
    categorical_preprocessor=None,
    numerical_preprocessor=None,
    need_weights=False,
    numerical_passthrough=False,
    decoder_hidden_units=None,
    decoder_activation_fn=None,
    **kwargs
    ):
    """Build a skorch classifier wrapping an FTTransformer module.

    The architecture is taken from the arguments instead of being hard-coded,
    so the builder also works for datasets with a different column layout.

    Args:
        train_indices: Row indices of the training part of the predefined split.
        validation_indices: Row indices of the validation part of the split.
        callbacks: skorch callbacks forwarded to the classifier.
        n_categories: Number of unique values of each categorical feature.
        n_numerical: Number of continuous features.
        n_head: Attention heads per layer.
        n_hid: Not used by this builder.
        n_layers: Transformer depth.
        n_output: Number of output neurons (`dim_out`).
        embed_dim: Embedding dimension (`dim`).
        dropout: Used for both the attention and the feed-forward dropout.
        aggregator, categorical_preprocessor, numerical_preprocessor,
        need_weights, numerical_passthrough, decoder_hidden_units,
        decoder_activation_fn: Accepted but not used here; presumably kept so
            this builder can be called with the same keywords as
            nn_utils.build_transformer_model.
        **kwargs: Forwarded to skorch.NeuralNetClassifier.

    Returns:
        An unfitted skorch.NeuralNetClassifier.
    """
    module = FTTransformer(
        categories=tuple(n_categories),   # number of unique values within each category
        num_continuous=n_numerical,       # number of continuous values
        dim=embed_dim,                    # embedding dimension
        dim_out=n_output,                 # number of output logits
        depth=n_layers,                   # number of transformer layers
        heads=n_head,                     # attention heads per layer
        attn_dropout=dropout,             # post-attention dropout
        ff_dropout=dropout,               # feed forward dropout
    )

    # One predefined (train, validation) split instead of an internal random split.
    model = skorch.NeuralNetClassifier(
        module=module,
        train_split=skorch.dataset.CVSplit(((train_indices, validation_indices),)),
        callbacks=callbacks,
        **kwargs
    )

    return model

In [None]:
# Train one hand-picked configuration (no search).
config = {
    "n_layers": 6,
    # NOTE(review): 10e-4 equals 1e-3 and 10e-1 equals 1.0; confirm these are
    # the intended learning rate and weight decay.
    "optimizer__lr": 10e-4,
    "optimizer__weight_decay": 10e-1,
    "n_head": 8,  # Number of heads per layer
    "n_hid": 128,  # Size of the MLP inside each transformer encoder layer
    "dropout": 0.1,  # Used dropout
    "embedding_size": 32,
    "numerical_passthrough": False,
}

# Split the flat config: "embedding_size" and the prefixed groups are removed
# from `config`; what remains is forwarded to the model builder as is.
embedding_size = config.pop("embedding_size")

encoders_params = get_params_startswith(config, "encoders__")
aggregator_params = get_params_startswith(config, "aggregator__")
preprocessor_params = get_params_startswith(config, "preprocessor__")

model_params = {
    **config,
    # Column layout: 8 categorical features (with these cardinalities)
    # followed by 6 numerical ones.
    "n_categories": (8, 16, 7, 14, 6, 5, 2, 41),
    "n_numerical": 6,
    "embed_dim": embedding_size,
    "aggregator": transformer_config.get_aggregator(embedding_size, **{**config, **aggregator_params}),
    "numerical_preprocessor": transformer_config.get_preprocessor(**{**config, **preprocessor_params}),
    "optimizer": torch.optim.AdamW,
    "criterion": criterion,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "batch_size": BATCH_SIZE,
    "max_epochs": MAX_EPOCHS,
    "n_output": n_labels,  # The number of output neurons
    "need_weights": False,
    "decoder_hidden_units": [128, 64],
    "decoder_activation_fn": nn.ReLU(),
    "verbose": 1,
}

model = nn_utils.build_transformer_model(
    train_indices,
    val_indices,
    nn_utils.get_default_callbacks(seed=SEED, multiclass=multiclass),
    **model_params
)

# The model is fitted exactly once, with the input names that go with the
# builder used above. The first 8 columns are the categorical features.
model = model.fit(
    X={
        "x_categorical": all_features[:, :8].astype(np.int32),
        "x_numerical": all_features[:, 8:].astype(np.float32),
    },
    y=all_labels,
)

# To train one of the tab_transformer_pytorch baselines instead, build the
# model with the notebook-local build_transformer_model (TabTransformer) or
# build_transformer_model_ft (FT-Transformer), passing the same arguments,
# and replace the fit call above with the matching input names:
#   TabTransformer: X={"x_categ": <categorical>, "x_cont": <numerical>}
#   FT-Transformer: X={"x_categ": <categorical>, "x_numer": <numerical>}
# using y=all_labels[:, np.newaxis].astype(np.double) in both cases.

In [None]:
def trainable(config, checkpoint_dir=CHECKPOINT_DIR):
    """Train one model for a sampled hyper-parameter configuration.

    Intended to be passed to `tune.run`. Reads the notebook globals
    `transformer_config`, `criterion`, `n_labels`, `train_indices`,
    `val_indices`, `all_features`, `all_labels` and `multiclass`.

    Args:
        config: Flat dict of hyper-parameters. Must contain "embedding_size";
            keys prefixed with "encoders__", "aggregator__" or
            "preprocessor__" are routed to the matching factory of
            `transformer_config`. The dict is not modified.
        checkpoint_dir: Not used in the body.

    Returns:
        None.
    """
    # Work on a copy: the keys popped below must not disappear from the dict
    # owned by the caller.
    config = dict(config)

    embedding_size = config.pop("embedding_size")

    encoders_params = get_params_startswith(config, "encoders__")
    aggregator_params = get_params_startswith(config, "aggregator__")
    preprocessor_params = get_params_startswith(config, "preprocessor__")

    # NOTE(review): the single-configuration cell calls the same builder with a
    # different keyword set ("n_categories", "numerical_preprocessor", ...) and
    # fits on a dict of inputs; confirm which of the two call styles matches
    # the current nn_utils.build_transformer_model.
    model_params = {
        **config,
        "encoders": transformer_config.get_encoders(embedding_size, **{**config, **encoders_params}),
        "aggregator": transformer_config.get_aggregator(embedding_size, **{**config, **aggregator_params}),
        "preprocessor": transformer_config.get_preprocessor(**{**config, **preprocessor_params}),
        "optimizer": torch.optim.AdamW,
        "criterion": criterion,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "batch_size": BATCH_SIZE,
        "max_epochs": MAX_EPOCHS,
        "n_output": n_labels, # The number of output neurons
        "need_weights": False,
        "verbose": 1
    }

    model = nn_utils.build_transformer_model(
                train_indices,
                val_indices,
                nn_utils.get_default_callbacks(seed=SEED, multiclass=multiclass),
                **model_params
                )

    # No metric is reported explicitly here; presumably the default callbacks
    # do that - TODO confirm.
    model.fit(X=all_features, y=all_labels)

In [None]:
# Hyper-parameter search space handed to tune.run. Entries with a single
# choice are effectively fixed; only the learning rate and dropout vary.
search_space = dict(
    n_layers=tune.choice([6]),  # Number of transformer encoder layers
    optimizer__lr=tune.choice([10e-6, 10e-5, 10e-4, 10e-3]),
    n_head=tune.choice([8]),  # Number of heads per layer
    n_hid=tune.choice([56]),  # Size of the MLP inside each transformer encoder layer
    dropout=tune.choice([0, 0.1, 0.2, 0.3, 0.4, 0.5]),  # Used dropout
    embedding_size=tune.choice([32]),
    numerical_passthrough=tune.choice([True]),
)

In [None]:
# Resume strategies tried in order; the second one is only used when the
# run with the first one raises.
resume_modes = ["AUTO", "ERRORED_ONLY"]


for try_cnt, resume_mode in enumerate(resume_modes):
    try:
        analysis = tune.run(
            trainable,
            config=search_space,
            resources_per_trial={
                "gpu": 1,
                "cpu": 6
            },
            search_alg=OptunaSearch(
                metric="roc_auc",
                mode="max",
                sampler=optuna.samplers.TPESampler()
            ),
            num_samples=N_SAMPLES,
            fail_fast=True,
            checkpoint_score_attr="max-roc_auc",
            keep_checkpoints_num=MAX_CHECKPOINTS,
            resume=resume_mode,
            local_dir=CHECKPOINT_DIR,
            name="param_search",
            scheduler=AsyncHyperBandScheduler(
                            time_attr="training_iteration",
                            metric="roc_auc",
                            mode="max",
                            grace_period=EARLY_STOPPING
                        )
        )

        # The search finished: no further resume mode is needed.
        break
    except Exception as e:

        # The last strategy failed as well: propagate the error, keeping the
        # original traceback.
        if try_cnt + 1 == len(resume_modes):
            raise

        print(e)
        print("Retrying in second mode")

In [None]:
# Configuration of the trial with the highest "roc_auc".
best_config = analysis.get_best_config(metric="roc_auc", mode="max")
print("Best config: ", best_config)