In [1]:
import nn_utils
import builders
import importlib

from ray import tune
import optuna
from ray.tune.suggest.optuna import OptunaSearch
import torch

from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune import ExperimentAnalysis
from ray.tune import register_trainable

import inspect
import argparse
import skorch
import os

import json

from torch.utils import tensorboard
from sklearn import metrics

import numpy as np

In [2]:
DATASETS = [
    "sylvine", "volkert",
    "adult", "australian",
    "anneal",  
    "jasmine", "kr_vs_kp", 
    "nomao", "ldpa"
]
DATASETS = [
    "adult"
]

AGGREGATORS = ["cls", "concatenate", "rnn", "sum", "mean", "max"]
AGGREGATORS = ["cls", "concatenate"]
BATCH_SIZE = 128
SEED = 11

In [3]:
results = {}
best_configs = {}

In [4]:
def count_parameters(model, trainable=True):
    total_params = 0
    
    for name, parameter in model.module_.named_parameters():
        
        if not parameter.requires_grad and trainable: 
            continue
            
        params = parameter.numel()
        total_params+=params
        
    return total_params

In [5]:
errors= []

for dataset_ in DATASETS:
    for aggregator_str_ in AGGREGATORS:
        
        dataset = dataset_
        aggregator_str = aggregator_str_
       
       
        print(f"Using -- Dataset:{dataset} Aggregator:{aggregator_str}")

        #####################################################
        # Configuration
        #####################################################

        MODULE = f"{dataset}.{aggregator_str}.config"
        CHECKPOINT_DIR = f"./{dataset}/{aggregator_str}/checkpoint"
        SEED = 11
        N_SAMPLES = 30

        BATCH_SIZE = 128
        MAX_EPOCHS = 1000
        EARLY_STOPPING = 30
        MAX_CHECKPOINTS = 10
        multiclass = False

        #####################################################
        # Util functions
        #####################################################

        def get_class_from_type(module, class_type):
            for attr in dir(module):
                clazz = getattr(module, attr)
                if callable(clazz) and inspect.isclass(clazz) and issubclass(clazz, class_type) and not str(clazz)==str(class_type):
                    return clazz

            return None

        def get_params_startswith(params, prefix):
            keys = [k for k in params.keys() if k.startswith(prefix)]
            extracted = {}

            for k in keys:
                extracted[k.replace(prefix, "")] = params.pop(k)

            return extracted

        def trainable(config, checkpoint_dir=CHECKPOINT_DIR):
            embedding_size = config.pop("embedding_size")

            aggregator_params = get_params_startswith(config, "aggregator__")
            preprocessor_params = get_params_startswith(config, "preprocessor__")

            limit_categories = len( transformer_config.get_n_categories())

            model_params = {
                **config,
                "n_categories": transformer_config.get_n_categories(),
                "n_numerical": transformer_config.get_n_numerical(),
                "embed_dim": embedding_size,
                "aggregator": transformer_config.get_aggregator(embedding_size, **{**config, **aggregator_params}),
                "categorical_preprocessor": transformer_config.get_preprocessor(**{**config, **preprocessor_params}),
                "optimizer": torch.optim.AdamW,
                "criterion": criterion,
                "device": "cuda" if torch.cuda.is_available() else "cpu",
                "batch_size": BATCH_SIZE,
                "max_epochs": MAX_EPOCHS,
                "n_output": n_labels, # The number of output neurons
                "need_weights": False,
                "decoder_hidden_units": transformer_config.get_decoder_hidden_units(),
                "decoder_activation_fn": transformer_config.get_decoder_activation_fn(),
                "verbose": 1
                
            }

            if not os.path.exists(os.path.join(CHECKPOINT_DIR, "best_model/.fitted")):
                print("Not fitted before! I'm not going to do anything")
                return


            checkpoint = skorch.callbacks.Checkpoint(monitor="balanced_accuracy_best", dirname=os.path.join(CHECKPOINT_DIR, "best_model"))

            model = nn_utils.build_transformer_model(
                        train_indices,
                        val_indices, 
                        [],
                        **model_params
                        )
            model.load_params(checkpoint=checkpoint)
            return model
        

            

        #####################################################
        # Dataset and components
        #####################################################

        module = importlib.import_module(MODULE)

        dataset = get_class_from_type(module, builders.DatasetConfig)
        if dataset is not None:
            dataset = dataset()
        else:
            raise ValueError("Dataset configuration not found")

        transformer_config = get_class_from_type(module, builders.TransformerConfig)
        if transformer_config is not None:
            transformer_config = transformer_config()
        else:
            raise ValueError("Transformer configuration not found")

        search_space_config = get_class_from_type(module, builders.SearchSpaceConfig)
        if search_space_config is not None:
            search_space_config = search_space_config()
        else:
            raise ValueError("Search space configuration not found")

        #####################################################
        # Configure dataset
        #####################################################

        if not dataset.exists():
            dataset.download()

        dataset.load(seed=SEED)

        preprocessor = nn_utils.get_default_preprocessing_pipeline(
                                dataset.get_categorical_columns(),
                                dataset.get_numerical_columns()
                            )

        #####################################################
        # Data preparation
        #####################################################

        train_features, train_labels = dataset.get_train_data()
        val_features, val_labels = dataset.get_val_data()
        test_features, test_labels = dataset.get_test_data()

        preprocessor = preprocessor.fit(train_features, train_labels)

        train_features = preprocessor.transform(train_features)
        val_features = preprocessor.transform(val_features)
        test_features = preprocessor.transform(test_features)

        all_features, all_labels, indices = nn_utils.join_data([train_features, val_features], [train_labels, val_labels])
        train_indices, val_indices = indices[0], indices[1]

        if dataset.get_n_labels() <= 2:
            n_labels = 1
            criterion = torch.nn.BCEWithLogitsLoss
        else:
            n_labels = dataset.get_n_labels()
            multiclass = True
            criterion = torch.nn.CrossEntropyLoss

        #####################################################
        # Hyperparameter search
        #####################################################
        
        #register_trainable("training_function", training_function)

        limit_categories = len( transformer_config.get_n_categories())
        
        register_trainable("trainable", trainable)
        
        try:
            '''
            analysis = tune.run(
                trainable,
                resume="AUTO",
                local_dir=CHECKPOINT_DIR, 
                name="param_search"    
            )
            '''
            
            analysis = ExperimentAnalysis(os.path.join(CHECKPOINT_DIR, "param_search"))
            
            best_config = analysis.get_best_config(metric="balanced_accuracy", mode="max")
            '''
            if dataset_ not in results:
                results[dataset_] = {}

            if aggregator_str_ not in results[dataset_]:
                results[dataset_][aggregator_str_] = {}
                
            results[dataset_][aggregator_str_]["balanced_accuracy"] = analysis.get_best_trial(metric="balanced_accuracy", mode="max").last_result["balanced_accuracy"]
            del analysis
            '''
            model = trainable(best_config)
            y_pred_train = model.predict({
                "x_categorical": train_features[:, :limit_categories].astype(np.int32), 
                "x_numerical": train_features[:, limit_categories:].astype(np.float32)
                })
            y_pred_val = model.predict({
                "x_categorical": val_features[:, :limit_categories].astype(np.int32), 
                "x_numerical": val_features[:, limit_categories:].astype(np.float32)
                })
            y_pred_test = model.predict({
                "x_categorical": test_features[:, :limit_categories].astype(np.int32), 
                "x_numerical": test_features[:, limit_categories:].astype(np.float32)
                })

            if dataset_ not in results:
                results[dataset_] = {}
                best_configs[dataset_] = {}

            if aggregator_str_ not in results[dataset_]:
                results[dataset_][aggregator_str_] = {}
                best_configs[dataset_][aggregator_str_] = best_config

            #results[dataset_][aggregator_str_]["loss"] = metrics.log_loss(test_labels, y_pred)
            results[dataset_][aggregator_str_]["balanced_accuracy_train"] = metrics.balanced_accuracy_score(train_labels, y_pred_train)
            results[dataset_][aggregator_str_]["balanced_accuracy_val"] = metrics.balanced_accuracy_score(val_labels, y_pred_val)
            results[dataset_][aggregator_str_]["balanced_accuracy_test"] = metrics.balanced_accuracy_score(test_labels, y_pred_test)
            results[dataset_][aggregator_str_]["n_parameters"] = count_parameters(model, trainable=False)
            results[dataset_][aggregator_str_]["n_trainable"] = count_parameters(model)
            #results[dataset_][aggregator_str_]["roc_auc"] = metrics.roc_auc_score(test_labels, y_pred)
            #print(metrics.balanced_accuracy_score(test_labels, y_pred))
            
        except Exception as e:
            errors.append("{}.{} - {}".format(dataset_, aggregator_str_, str(e)))
            pass
        
        


Using -- Dataset:adult Aggregator:cls
Target mapping: {'<=50K': 0, '>50K': 1}
Numerical columns: ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
Categorical columns: ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']
Columns: ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'class']
Not fitted before! I'm not going to do anything
Using -- Dataset:adult Aggregator:concatenate
Target mapping: {'<=50K': 0, '>50K': 1}
Numerical columns: ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
Categorical columns: ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']
Columns: ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relations

In [6]:
results

{'adult': {'concatenate': {'balanced_accuracy_train': 0.8022011983862758,
   'balanced_accuracy_val': 0.7963661976992864,
   'balanced_accuracy_test': 0.79289930085895,
   'n_parameters': 354433,
   'n_trainable': 354433}}}

In [7]:
errors

["adult.cls - 'NoneType' object has no attribute 'predict'"]

In [8]:
with open("eval_results.json", "w") as f:
    json.dump(results, f)

In [9]:
best_configs

{'adult': {'concatenate': {'n_layers': 1,
   'optimizer__lr': 1.3473254295104352e-05,
   'n_head': 16,
   'n_hid': 128,
   'dropout': 0.014089422184433743,
   'numerical_passthrough': False}}}

In [10]:
with open("best_configs.json", "w") as f:
    json.dump(best_configs, f)

In [11]:
print(errors)

["adult.cls - 'NoneType' object has no attribute 'predict'"]


In [12]:
import pickle

with open("adult/concatenate/checkpoint/param_search/searcher-state-2022-04-28_04-26-37.pkl", "rb") as f:
    obj = pickle.load(f)

obj[0], obj[2]
dir(obj[2].best_params)
from optuna.visualization import plot_optimization_history
plot_optimization_history(obj[2])

FileNotFoundError: [Errno 2] No such file or directory: 'adult/concatenate/checkpoint/param_search/searcher-state-2022-04-28_04-26-37.pkl'

In [None]:
obj[2].best_params