In [None]:
import nn_utils
import builders
import importlib

from ray import tune
import optuna
from ray.tune.suggest.optuna import OptunaSearch
import torch

from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune import ExperimentAnalysis
from ray.tune import register_trainable

import inspect
import argparse
import skorch
import os

import json

from torch.utils import tensorboard
from sklearn import metrics

In [None]:
# Datasets included in the sweep ("volkert" is currently disabled).
DATASETS = [
    "sylvine",
    # "volkert",
    "adult",
    "australian",
    "anneal",
    "jasmine",
    "kr_vs_kp",
    "nomao",
    "ldpa",
]

# Aggregation strategies evaluated for every dataset.
AGGREGATORS = [
    "cls",
    "concatenate",
    "rnn",
    "sum",
    "mean",
    "max",
]

# Mini-batch size handed to the model builder.
BATCH_SIZE = 128

# Seed passed to the dataset loader.
SEED = 11

In [None]:
# Accumulators filled by the sweep:
#   results      -> per-dataset / per-aggregator / per-trial statistics
#   best_configs -> reserved for best configurations (not written by the visible code)
results, best_configs = {}, {}

In [None]:
def count_parameters(model, trainable=True):
    """Count the scalar parameters of ``model.module``.

    Parameters
    ----------
    model
        Object exposing a ``module`` attribute whose ``named_parameters()``
        yields ``(name, parameter)`` pairs.
    trainable : bool, default True
        When true, only parameters with ``requires_grad`` set are counted.
        When false, *every* parameter is counted (trainable and frozen alike);
        it does not select the frozen parameters only.

    Returns
    -------
    int
        Sum of ``parameter.numel()`` over the selected parameters.
    """
    return sum(
        tensor.numel()
        for _, tensor in model.module.named_parameters()
        if tensor.requires_grad or not trainable
    )

In [None]:
# Failure messages collected by the sweep, one string per failed
# dataset/aggregator combination ("<dataset>.<aggregator> - <error>").
errors= []

# Visit every dataset / aggregator combination.
for dataset_ in DATASETS:
    for aggregator_str_ in AGGREGATORS:
        
        # Working copies of the loop variables. NOTE: `dataset` is rebound to a
        # dataset configuration instance further down in this loop body, so the
        # underscore-suffixed names are the ones that keep the plain strings.
        dataset = dataset_
        aggregator_str = aggregator_str_
       
       
        print(f"Using -- Dataset:{dataset} Aggregator:{aggregator_str}")

        #####################################################
        # Configuration
        #####################################################

        # Dotted path of the per-combination configuration module.
        MODULE = f"{dataset}.{aggregator_str}.config"
        # Directory under which the stored "param_search" experiment is looked up.
        CHECKPOINT_DIR = f"./{dataset}/{aggregator_str}/checkpoint"
        # Re-assigns the notebook-level constant to the same value.
        SEED = 11
        # Not referenced anywhere else in the visible part of this cell.
        N_SAMPLES = 30
        # Re-assigns the notebook-level constant to the same value.
        BATCH_SIZE = 128
        # Reset per combination; switched to True below when the dataset has
        # more than two labels.
        multiclass = False

        #####################################################
        # Util functions
        #####################################################

        def get_class_from_type(module, class_type):
            """Return the first proper subclass of ``class_type`` exposed by ``module``.

            Attributes are visited in ``dir(module)`` order. An attribute qualifies
            when it is a class, is a subclass of ``class_type`` and its string
            representation differs from that of ``class_type`` (which excludes the
            base class itself). ``None`` is returned when nothing qualifies.
            """
            base_repr = str(class_type)
            members = (getattr(module, member_name) for member_name in dir(module))

            return next(
                (
                    member
                    for member in members
                    if inspect.isclass(member)
                    and issubclass(member, class_type)
                    and str(member) != base_repr
                ),
                None,
            )

        def get_params_startswith(params, prefix):
            """Move every entry whose key starts with ``prefix`` out of ``params``.

            Parameters
            ----------
            params : dict
                Flat parameter dictionary. It is modified in place: matching
                entries are removed from it.
            prefix : str
                Key prefix that selects the entries to extract.

            Returns
            -------
            dict
                The extracted entries, keyed by the original key with the leading
                ``prefix`` removed.
            """
            # Snapshot the matching keys first: `params` is mutated while extracting.
            matching_keys = [key for key in params if key.startswith(prefix)]

            # Slice off only the *leading* prefix. `str.replace` would also delete
            # later occurrences of the prefix inside the key and corrupt names
            # such as "encoders__sub_encoders__rate".
            return {key[len(prefix):]: params.pop(key) for key in matching_keys}


        def trainable(config, checkpoint_dir=CHECKPOINT_DIR):
            """Build a model from a flat hyperparameter configuration.

            The function relies on names bound in the enclosing loop body at call
            time: ``transformer_config``, ``criterion``, ``n_labels``,
            ``train_indices``, ``val_indices`` and ``BATCH_SIZE``.

            Parameters
            ----------
            config : dict
                Flat configuration. Must contain ``"embedding_size"``. Keys
                prefixed with ``"encoders__"``, ``"aggregator__"`` or
                ``"preprocessor__"`` are routed (prefix stripped) to the matching
                ``transformer_config`` factory; the remaining keys are forwarded
                both to those factories and to the model builder. The caller's
                dictionary is left untouched.
            checkpoint_dir : str
                Not used by the body; kept so the signature stays unchanged.

            Returns
            -------
            The object returned by ``nn_utils.build_transformer_model``.
            """
            # Work on a copy: the extraction below pops keys, and callers keep
            # using (and storing) the dictionary they passed in.
            config = dict(config)

            embedding_size = config.pop("embedding_size")

            encoders_params = get_params_startswith(config, "encoders__")
            aggregator_params = get_params_startswith(config, "aggregator__")
            preprocessor_params = get_params_startswith(config, "preprocessor__")

            model_params = {
                **config,
                "encoders": transformer_config.get_encoders(embedding_size, **{**config, **encoders_params}),
                "aggregator": transformer_config.get_aggregator(embedding_size, **{**config, **aggregator_params}),
                "preprocessor": transformer_config.get_preprocessor(**{**config, **preprocessor_params}),
                "optimizer": torch.optim.SGD,
                "criterion": criterion,
                "device": "cuda" if torch.cuda.is_available() else "cpu",
                "batch_size": BATCH_SIZE,
                "max_epochs": 1,
                "n_output": n_labels, # The number of output neurons
                "need_weights": False,
                "verbose": 1
            }

            model = nn_utils.build_transformer_model(
                        train_indices,
                        val_indices,
                        [],
                        **model_params
                        )

            return model
        

        #####################################################
        # Dataset and components
        #####################################################

        # Import the "<dataset>.<aggregator>.config" module for this combination.
        module = importlib.import_module(MODULE)

        # Each component is the first class in the module that subclasses the
        # given builders base class; it is instantiated without arguments.
        # NOTE: from here on `dataset` is the configuration instance, no longer
        # the dataset name string.
        dataset = get_class_from_type(module, builders.DatasetConfig)
        if dataset is not None:
            dataset = dataset()
        else:
            raise ValueError("Dataset configuration not found")

        transformer_config = get_class_from_type(module, builders.TransformerConfig)
        if transformer_config is not None:
            transformer_config = transformer_config()
        else:
            raise ValueError("Transformer configuration not found")

        search_space_config = get_class_from_type(module, builders.SearchSpaceConfig)
        if search_space_config is not None:
            search_space_config = search_space_config()
        else:
            raise ValueError("Search space configuration not found")

        #####################################################
        # Configure dataset
        #####################################################

        # Download only when the dataset reports that it is not available yet.
        if not dataset.exists():
            dataset.download()

        dataset.load(seed=SEED)

        preprocessor = nn_utils.get_default_preprocessing_pipeline(
                                dataset.get_categorical_columns(),
                                dataset.get_numerical_columns()
                            )

        #####################################################
        # Data preparation
        #####################################################

        train_features, train_labels = dataset.get_train_data()
        val_features, val_labels = dataset.get_val_data()
        test_features, test_labels = dataset.get_test_data()

        # The preprocessor is fitted on the training split only and then applied
        # to all three splits.
        preprocessor = preprocessor.fit(train_features, train_labels)

        train_features = preprocessor.transform(train_features)
        val_features = preprocessor.transform(val_features)
        test_features = preprocessor.transform(test_features)

        # Presumably merges train and validation data and returns one index
        # collection per input split -- TODO confirm against nn_utils.join_data.
        # Only the indices are used by the visible code (inside `trainable`).
        all_features, all_labels, indices = nn_utils.join_data([train_features, val_features], [train_labels, val_labels])
        train_indices, val_indices = indices[0], indices[1]

        # Up to two labels: a single output with BCEWithLogitsLoss.
        # More labels: one output per label with CrossEntropyLoss.
        # `criterion` holds the loss class itself, not an instance.
        if dataset.get_n_labels() <= 2:
            n_labels = 1
            criterion = torch.nn.BCEWithLogitsLoss
        else:
            n_labels = dataset.get_n_labels()
            multiclass = True
            criterion = torch.nn.CrossEntropyLoss

        #####################################################
        # Hyperparameter search
        #####################################################
        
        #register_trainable("training_function", training_function)
        # Registered again on every iteration under the same name, each time with
        # the `trainable` function defined for the current combination.
        register_trainable("trainable", trainable)
        
        try:
            # The stored experiment is only read here. The call that was used to
            # run/resume the search is kept for reference:
            #
            #     analysis = tune.run(
            #         trainable,
            #         resume="AUTO",
            #         local_dir=CHECKPOINT_DIR,
            #         name="param_search"
            #     )

            analysis = ExperimentAnalysis(os.path.join(CHECKPOINT_DIR, "param_search"))
            best_config = analysis.get_best_config(metric="balanced_accuracy", mode="max")

            # results[dataset][aggregator] -> {trial index: statistics}
            combination_results = results.setdefault(dataset_, {}).setdefault(aggregator_str_, {})

            for trial_idx, trial in enumerate(analysis.trials):
                # Snapshot the configuration before building the model: `trainable`
                # pops keys from the dictionary it receives, which would otherwise
                # leave a stripped configuration in the stored results.
                trial_config = dict(trial.config)
                model = trainable(dict(trial_config))

                # count_parameters(..., trainable=False) counts *all* parameters,
                # so the frozen ones are the difference to the trainable count.
                n_total_params = count_parameters(model, trainable=False)
                n_trainable_params = count_parameters(model, trainable=True)

                combination_results[trial_idx] = {
                    "trial": trial_idx,
                    "config": trial_config,
                    "balanced_accuracy_max": trial.metric_analysis["balanced_accuracy"]["max"],
                    # Average of the reported "time_total_s" metric.
                    "training_iter_sec": trial.metric_analysis["time_total_s"]["avg"],
                    "total_params": n_total_params,
                    "non_trainable_params": n_total_params - n_trainable_params,
                    "trainable_params": n_trainable_params,
                }

            # Disabled: evaluation of the best configuration on the test split.
            #model = trainable(best_config)
            #y_pred = model.predict(test_features)

            #if dataset_ not in results:
            #    results[dataset_] = {}
            #    best_configs[dataset_] = {}

            #if aggregator_str_ not in results[dataset_]:
            #    results[dataset_][aggregator_str_] = {}
            #    best_configs[dataset_][aggregator_str_] = best_config

            #results[dataset_][aggregator_str_]["loss"] = metrics.log_loss(test_labels, y_pred)
            #results[dataset_][aggregator_str_]["balanced_accuracy"] = metrics.balanced_accuracy_score(test_labels, y_pred)
            #results[dataset_][aggregator_str_]["n_parameters"] = count_parameters(model, trainable=False)
            #results[dataset_][aggregator_str_]["n_trainable"] = count_parameters(model)
            #results[dataset_][aggregator_str_]["roc_auc"] = metrics.roc_auc_score(test_labels, y_pred)
            #print(metrics.balanced_accuracy_score(test_labels, y_pred))

        except Exception as e:
            # Best-effort sweep: a failing combination must not abort the others.
            # The failure is recorded in `errors` for inspection after the loop.
            errors.append("{}.{} - {}".format(dataset_, aggregator_str_, str(e)))