In [1]:
import nn_utils
import builders
import importlib

import skorch
from sklearn import metrics
from ray import tune
import optuna
from ray.tune.suggest.optuna import OptunaSearch
import torch

from ray.tune.integration.torch import DistributedTrainableCreator

# Configuration

In [2]:
# Dotted path of the experiment configuration module; imported with importlib further down.
MODULE = "adult.cls.config"
# Used as the default `checkpoint_dir` of `trainable` and as `local_dir` for `tune.run`.
CHECKPOINT_DIR = "./adult/cls/checkpoint"
# Seed handed to `dataset.load` and to the default training callbacks.
SEED = 11

# Get dataset and components

In [3]:
def get_class_from_type(module, class_type):
    """Return the first class exposed by ``module`` that derives from ``class_type``.

    Attributes are visited in ``dir(module)`` order. Anything that is not a
    class (functions, builtins, callable instances, constants) is skipped, so
    it is never handed to ``issubclass``, which raises ``TypeError`` for
    non-class arguments.

    A strict subclass is preferred. ``class_type`` itself is returned only when
    the module exposes it and no strict subclass, so a config module that
    re-exports the base class next to its concrete subclass still resolves to
    the subclass.

    Args:
        module: Object whose attributes are scanned (normally an imported module).
        class_type: Base class the result must derive from.

    Returns:
        The matching class, or ``None`` when the module exposes no match.
    """
    fallback = None
    for attr in dir(module):
        candidate = getattr(module, attr)
        # Only real classes are valid first arguments for issubclass().
        if not isinstance(candidate, type) or not issubclass(candidate, class_type):
            continue
        if candidate is class_type:
            # Remember the base class, but keep looking for a concrete subclass.
            fallback = candidate
            continue
        return candidate

    return fallback

In [4]:
# Import the experiment configuration module named by MODULE.
module = importlib.import_module(MODULE)

In [5]:
# Locate the DatasetConfig subclass in the config module; fail early if there is none.
dataset = get_class_from_type(module, builders.DatasetConfig)
if dataset is None:
    raise ValueError("Dataset configuration not found")

# Replace the class with an instance of it.
dataset = dataset()

In [6]:
# Locate the TransformerConfig subclass in the config module; fail early if there is none.
transformer_config = get_class_from_type(module, builders.TransformerConfig)
if transformer_config is None:
    raise ValueError("Transformer configuration not found")

# Replace the class with an instance of it.
transformer_config = transformer_config()

# Configure dataset

In [7]:
# Download the dataset only when `exists()` reports that it is not available yet.
if not dataset.exists():
    dataset.download()
    
# Load the data, passing the notebook-wide seed.
dataset.load(seed=SEED)

Target mapping: {'<=50K': 0, '>50K': 1}
Numerical columns: ['fnlwgt', 'education-num']
Categorical columns: ['age', 'workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capitalgain', 'capitalloss', 'hoursperweek', 'native-country']
Columns: ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capitalgain', 'capitalloss', 'hoursperweek', 'native-country', 'class']


In [8]:
# Build the default preprocessing pipeline from the dataset's categorical and
# numerical column lists (in that argument order).
preprocessor = nn_utils.get_default_preprocessing_pipeline(
                        dataset.get_categorical_columns(),
                        dataset.get_numerical_columns()
                    )

# Data preparation

In [9]:
# Fetch the three splits from the dataset object.
train_features, train_labels = dataset.get_train_data()
val_features, val_labels = dataset.get_val_data()
test_features, test_labels = dataset.get_test_data()

# Fit the preprocessor on the training split only ...
preprocessor = preprocessor.fit(train_features, train_labels)

# ... then apply the fitted transformation to every split.
train_features = preprocessor.transform(train_features)
val_features = preprocessor.transform(val_features)
test_features = preprocessor.transform(test_features)

# Join train and validation data into single feature/label containers.
# NOTE(review): `indices` presumably holds, per input split, the positions of its
# rows inside the joined data -- confirm against nn_utils.join_data.
all_features, all_labels, indices = nn_utils.join_data([train_features, val_features], [train_labels, val_labels])
# The two index sets are later passed to nn_utils.build_transformer_model inside `trainable`.
train_indices, val_indices = indices[0], indices[1]

In [10]:
# Choose the output width and the loss class from the number of labels:
# more than two labels -> one output per label with cross-entropy,
# otherwise            -> a single output with BCE-with-logits.
n_labels = dataset.get_n_labels()
if n_labels > 2:
    criterion = torch.nn.CrossEntropyLoss
else:
    n_labels = 1
    criterion = torch.nn.BCEWithLogitsLoss

# Hyperparameter search

In [11]:
def trainable(config, checkpoint_dir=CHECKPOINT_DIR):
    """Build one transformer model from a hyperparameter configuration and fit it.

    The model is fitted on the joined train/validation data prepared earlier in
    the notebook (``all_features`` / ``all_labels``), with ``train_indices`` and
    ``val_indices`` passed to the model builder.

    Args:
        config: Mapping of hyperparameters. Must contain ``"embedding_size"``,
            which is passed to the transformer configuration to create the
            encoders and the aggregator. Every other entry is forwarded
            unchanged as a keyword argument to
            ``nn_utils.build_transformer_model``. The mapping is not modified.
        checkpoint_dir: Not used by this function body.
            NOTE(review): presumably kept for Ray Tune's function-trainable
            signature -- confirm.

    Raises:
        KeyError: If ``config`` has no ``"embedding_size"`` entry.
    """
    # Work on a copy: popping from `config` directly would alter the caller's dict.
    hyperparams = dict(config)
    embedding_size = hyperparams.pop("embedding_size")

    model_params = {
        **hyperparams,
        "encoders": transformer_config.get_encoders(embedding_size),
        "aggregator": transformer_config.get_aggregator(embedding_size),
        "preprocessor": transformer_config.get_preprocessor(),
        "optimizer": torch.optim.SGD,
        "criterion": criterion,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "batch_size": 128,
        "max_epochs": 2,
        "n_output": n_labels,  # The number of output neurons
        "need_weights": False,
        "verbose": 1,
    }

    model = nn_utils.build_transformer_model(
        train_indices,
        val_indices,
        nn_utils.get_default_callbacks(seed=SEED),
        **model_params
    )

    # NOTE(review): nothing is reported to Tune here explicitly; metric reporting
    # presumably happens inside the default callbacks -- confirm.
    model.fit(X=all_features, y=all_labels)

In [None]:
# Run the hyperparameter search over `trainable`, storing results under CHECKPOINT_DIR
# with the experiment name "param_search".
# With resume="AUTO", the log output recorded for this cell shows a local experiment
# checkpoint being found and the previous experiment state being restored.
# NOTE(review): no `config` search space, search algorithm or metric is passed here,
# although `trainable` reads config["embedding_size"]. A run with no checkpoint to
# resume would presumably fail with KeyError -- confirm and restore the search space.
analysis = tune.run(
    trainable,
    resume="AUTO",
    local_dir=CHECKPOINT_DIR, 
    name="param_search"
    
)

2022-02-20 22:31:10,250	INFO trial_runner.py:488 -- A local experiment checkpoint was found and will be used to restore the previous experiment state.
2022-02-20 22:31:11,507	INFO tune.py:551 -- TrialRunner resumed, ignoring new add_experiment but updating trial resources.


Trial name,status,loc,dropout,embedding_size,n_head,n_hid,n_layers,optimizer__lr,iter,total time (s),train_loss,valid_loss,balanced_accuracy
trainable_727f4248,TERMINATED,192.168.1.72:2942,0.0699997,128,2,256,3,0.00909573,15,80.1114,0.333458,0.40196,0.816635
trainable_00535cf0,TERMINATED,192.168.1.72:2469,0.144154,256,2,32,3,0.008829,15,82.4752,0.332235,0.384565,0.819564
trainable_a0f55f9c,TERMINATED,192.168.1.72:2363,0.13851,1024,4,32,2,0.0129706,60,617.249,0.322936,0.360713,0.816545
trainable_c038bce2,TERMINATED,192.168.1.72:1519,0.000403666,1024,8,64,2,0.0119609,15,156.124,0.328601,0.450239,0.82022
trainable_85d079fa,TERMINATED,192.168.1.72:1351,0.00726525,1024,8,64,1,0.0322606,60,372.776,0.317215,0.343748,0.814118
trainable_25d5a1ec,TERMINATED,192.168.1.72:1084,0.0319958,1024,8,64,1,0.0395129,15,93.8465,0.391116,0.385768,0.779754
trainable_b3acad78,TERMINATED,192.168.1.72:481,0.048099,1024,8,64,2,0.0944024,15,157.05,0.324286,0.397083,0.813924
trainable_8a8a0242,TERMINATED,192.168.1.72:353,0.0503045,1024,8,64,2,0.065878,60,617.106,0.315559,0.347267,0.813576
trainable_616a1c6c,TERMINATED,192.168.1.72:32751,0.0504516,64,8,64,2,0.00253157,15,65.7435,0.408894,0.407552,0.754284
trainable_adfcbd1a,TERMINATED,192.168.1.72:32549,0.242383,64,8,512,2,0.00150805,15,65.047,0.532288,0.527636,0.5


2022-02-20 22:31:11,740	INFO tune.py:636 -- Total run time: 1.53 seconds (0.00 seconds for the tuning loop).


In [None]:
# Show the configuration of the trial with the lowest validation loss.
print("Best config: ", analysis.get_best_config(metric="valid_loss", mode="min"))

In [None]:
# Display the per-trial results table as the cell output.
analysis.results_df