# Extract data

In [1]:
import nn_utils
import builders
import importlib

from ray import tune
import optuna
from ray.tune.suggest.optuna import OptunaSearch
import torch

from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune import ExperimentAnalysis
from ray.tune import register_trainable

import inspect
import argparse
import skorch
import os

import json

from torch.utils import tensorboard
from sklearn import metrics
import pandas as pd

In [2]:
DATASETS = [
    "sylvine", #"volkert",
    "adult", "australian",
    "anneal",  
    "jasmine", "kr_vs_kp", 
    "nomao", "ldpa"
]
AGGREGATORS = ["cls", "concatenate", "rnn", "sum", "mean", "max"]
BATCH_SIZE = 128
SEED = 11

In [3]:
results = {}
best_configs = {}

In [4]:
def count_parameters(model, only_trainable=True):
    total_params = 0
    
    for name, parameter in model.module.named_parameters():
        
        if not parameter.requires_grad and only_trainable: 
            continue
            
        params = parameter.numel()
        total_params+=params
        
    return total_params

In [5]:
errors= []

for dataset_ in DATASETS:
    for aggregator_str_ in AGGREGATORS:
        
        dataset = dataset_
        aggregator_str = aggregator_str_
       
       
        print(f"Using -- Dataset:{dataset} Aggregator:{aggregator_str}")

        #####################################################
        # Configuration
        #####################################################

        MODULE = f"{dataset}.{aggregator_str}.config"
        CHECKPOINT_DIR = f"./{dataset}/{aggregator_str}/checkpoint"
        SEED = 11
        N_SAMPLES = 30
        BATCH_SIZE = 128
        multiclass = False

        #####################################################
        # Util functions
        #####################################################

        def get_class_from_type(module, class_type):
            for attr in dir(module):
                clazz = getattr(module, attr)
                if callable(clazz) and inspect.isclass(clazz) and issubclass(clazz, class_type) and not str(clazz)==str(class_type):
                    return clazz

            return None

        def get_params_startswith(params, prefix):
            keys = [k for k in params.keys() if k.startswith(prefix)]
            extracted = {}

            for k in keys:
                extracted[k.replace(prefix, "")] = params.pop(k)

            return extracted


        def trainable(config, checkpoint_dir=CHECKPOINT_DIR):
            embedding_size = config.pop("embedding_size")

            encoders_params = get_params_startswith(config, "encoders__")
            aggregator_params = get_params_startswith(config, "aggregator__")
            preprocessor_params = get_params_startswith(config, "preprocessor__")

            model_params = {
                **config,
                "encoders": transformer_config.get_encoders(embedding_size, **{**config, **encoders_params}),
                "aggregator": transformer_config.get_aggregator(embedding_size, **{**config, **aggregator_params}),
                "preprocessor": transformer_config.get_preprocessor(**{**config, **preprocessor_params}),
                "optimizer": torch.optim.SGD,
                "criterion": criterion,
                "device": "cuda" if torch.cuda.is_available() else "cpu",
                "batch_size": BATCH_SIZE,
                "max_epochs": 1,
                "n_output": n_labels, # The number of output neurons
                "need_weights": False,
                "verbose": 1

            }

            model = nn_utils.build_transformer_model(
                        train_indices,
                        val_indices, 
                        [],
                        **model_params
                        )
            
            return model
        

        #####################################################
        # Dataset and components
        #####################################################

        module = importlib.import_module(MODULE)

        dataset = get_class_from_type(module, builders.DatasetConfig)
        if dataset is not None:
            dataset = dataset()
        else:
            raise ValueError("Dataset configuration not found")

        transformer_config = get_class_from_type(module, builders.TransformerConfig)
        if transformer_config is not None:
            transformer_config = transformer_config()
        else:
            raise ValueError("Transformer configuration not found")

        search_space_config = get_class_from_type(module, builders.SearchSpaceConfig)
        if search_space_config is not None:
            search_space_config = search_space_config()
        else:
            raise ValueError("Search space configuration not found")

        #####################################################
        # Configure dataset
        #####################################################

        if not dataset.exists():
            dataset.download()

        dataset.load(seed=SEED)

        preprocessor = nn_utils.get_default_preprocessing_pipeline(
                                dataset.get_categorical_columns(),
                                dataset.get_numerical_columns()
                            )

        #####################################################
        # Data preparation
        #####################################################

        train_features, train_labels = dataset.get_train_data()
        val_features, val_labels = dataset.get_val_data()
        test_features, test_labels = dataset.get_test_data()

        preprocessor = preprocessor.fit(train_features, train_labels)

        train_features = preprocessor.transform(train_features)
        val_features = preprocessor.transform(val_features)
        test_features = preprocessor.transform(test_features)

        all_features, all_labels, indices = nn_utils.join_data([train_features, val_features], [train_labels, val_labels])
        train_indices, val_indices = indices[0], indices[1]

        if dataset.get_n_labels() <= 2:
            n_labels = 1
            criterion = torch.nn.BCEWithLogitsLoss
        else:
            n_labels = dataset.get_n_labels()
            multiclass = True
            criterion = torch.nn.CrossEntropyLoss

        #####################################################
        # Hyperparameter search
        #####################################################
        
        #register_trainable("training_function", training_function)
        register_trainable("trainable", trainable)
        
        try:
            '''
            analysis = tune.run(
                trainable,
                resume="AUTO",
                local_dir=CHECKPOINT_DIR, 
                name="param_search"    
            )
            '''
            
            analysis = ExperimentAnalysis(os.path.join(CHECKPOINT_DIR, "param_search"))
            best_config = analysis.get_best_config(metric="balanced_accuracy", mode="max")
            
            if dataset_ not in results:
                results[dataset_] = {}
            
            if aggregator_str_ not in results[dataset_]:
                results[dataset_][aggregator_str_] = {}
            
            
            for trial_idx, trial in enumerate(analysis.trials):
                model = trainable(trial.config)
                #print("*" * 50)
                #print(trial.config)
                #print(trial.last_result)
                #print(trial.metric_analysis)
                #print(trial.checkpoint)
                #print(count_parameters(model, trainable=False))
                #print(count_parameters(model, trainable=True))
                #print("*" * 50)
                
                results[dataset_][aggregator_str_][trial_idx] = {}
                
                results[dataset_][aggregator_str_][trial_idx]["trial"] = trial_idx
                results[dataset_][aggregator_str_][trial_idx]["config"] = trial.config
                results[dataset_][aggregator_str_][trial_idx]["trial_balanced_accuracy_max"] = trial.metric_analysis["balanced_accuracy"]["max"]
                results[dataset_][aggregator_str_][trial_idx]["training_iter_sec"] = trial.metric_analysis["time_total_s"]["avg"]
                results[dataset_][aggregator_str_][trial_idx]["non_trainable_params"] = count_parameters(model, only_trainable=False) - count_parameters(model, only_trainable=True)
                results[dataset_][aggregator_str_][trial_idx]["trainable_params"] = count_parameters(model, only_trainable=True)
            
            #model = trainable(best_config)
            #y_pred = model.predict(test_features)

            #if dataset_ not in results:
            #    results[dataset_] = {}
            #    best_configs[dataset_] = {}

            #if aggregator_str_ not in results[dataset_]:
            #    results[dataset_][aggregator_str_] = {}
            #    best_configs[dataset_][aggregator_str_] = best_config

            #results[dataset_][aggregator_str_]["loss"] = metrics.log_loss(test_labels, y_pred)
            #results[dataset_][aggregator_str_]["balanced_accuracy"] = metrics.balanced_accuracy_score(test_labels, y_pred)
            #results[dataset_][aggregator_str_]["n_parameters"] = count_parameters(model, trainable=False)
            #results[dataset_][aggregator_str_]["n_trainable"] = count_parameters(model)
            #results[dataset_][aggregator_str_]["roc_auc"] = metrics.roc_auc_score(test_labels, y_pred)
            #print(metrics.balanced_accuracy_score(test_labels, y_pred))
            
        except Exception as e:
            errors.append("{}.{} - {}".format(dataset_, aggregator_str_, str(e)))
            pass
        
        


Using -- Dataset:sylvine Aggregator:cls
Target mapping: {1: 0, 0: 1}
Numerical columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20']
Categorical columns: []
Columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20', 'class']
Using -- Dataset:sylvine Aggregator:concatenate
Target mapping: {1: 0, 0: 1}
Numerical columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20']
Categorical columns: []
Columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20', 'class']
Using -- Dataset:sylvine Aggregator:rnn
Target mapping: {1: 0, 0: 1}
Numerical columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18'



Using -- Dataset:sylvine Aggregator:sum
Target mapping: {1: 0, 0: 1}
Numerical columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20']
Categorical columns: []
Columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20', 'class']
Using -- Dataset:sylvine Aggregator:mean
Target mapping: {1: 0, 0: 1}
Numerical columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20']
Categorical columns: []
Columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20', 'class']
Using -- Dataset:sylvine Aggregator:max
Target mapping: {1: 0, 0: 1}
Numerical columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19'



Using -- Dataset:adult Aggregator:sum
Target mapping: {'<=50K': 0, '>50K': 1}
Numerical columns: ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
Categorical columns: ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']
Columns: ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'class']
Using -- Dataset:adult Aggregator:mean
Target mapping: {'<=50K': 0, '>50K': 1}
Numerical columns: ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
Categorical columns: ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']
Columns: ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'h



Using -- Dataset:australian Aggregator:sum
Target mapping: {0: 0, 1: 1}
Numerical columns: ['A2', 'A3', 'A7', 'A10', 'A13', 'A14']
Categorical columns: ['A1', 'A4', 'A5', 'A6', 'A8', 'A9', 'A11', 'A12']
Columns: ['A1', 'A2', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15']
Using -- Dataset:australian Aggregator:mean
Target mapping: {0: 0, 1: 1}
Numerical columns: ['A2', 'A3', 'A7', 'A10', 'A13', 'A14']
Categorical columns: ['A1', 'A4', 'A5', 'A6', 'A8', 'A9', 'A11', 'A12']
Columns: ['A1', 'A2', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15']
Using -- Dataset:australian Aggregator:max
Target mapping: {0: 0, 1: 1}
Numerical columns: ['A2', 'A3', 'A7', 'A10', 'A13', 'A14']
Categorical columns: ['A1', 'A4', 'A5', 'A6', 'A8', 'A9', 'A11', 'A12']
Columns: ['A1', 'A2', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9', 'A10', 'A11', 'A12', 'A13', 'A14', 'A15']
Using -- Dataset:anneal Aggregator:cls
Target mapping: {'3': 0, 'U': 1, 



Using -- Dataset:anneal Aggregator:sum
Target mapping: {'3': 0, 'U': 1, '1': 2, '5': 3, '2': 4}
Numerical columns: ['carbon', 'hardness', 'strength']
Categorical columns: ['family', 'product-type', 'steel', 'temper_rolling', 'condition', 'formability', 'non-ageing', 'surface-finish', 'surface-quality', 'enamelability', 'bc', 'bf', 'bt', 'bw%2Fme', 'bl', 'chrom', 'phos', 'cbond', 'exptl', 'ferro', 'blue%2Fbright%2Fvarn%2Fclean', 'lustre', 'shape', 'thick', 'width', 'len', 'oil', 'bore', 'packing']
Columns: ['family', 'product-type', 'steel', 'carbon', 'hardness', 'temper_rolling', 'condition', 'formability', 'strength', 'non-ageing', 'surface-finish', 'surface-quality', 'enamelability', 'bc', 'bf', 'bt', 'bw%2Fme', 'bl', 'chrom', 'phos', 'cbond', 'exptl', 'ferro', 'blue%2Fbright%2Fvarn%2Fclean', 'lustre', 'shape', 'thick', 'width', 'len', 'oil', 'bore', 'packing', 'class']
Using -- Dataset:anneal Aggregator:mean
Target mapping: {'3': 0, 'U': 1, '1': 2, '5': 3, '2': 4}
Numerical columns:



Using -- Dataset:jasmine Aggregator:sum
Target mapping: {1: 0, 0: 1}
Numerical columns: ['V13', 'V23', 'V43', 'V45', 'V56', 'V59', 'V126', 'V131']
Categorical columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20', 'V21', 'V22', 'V24', 'V25', 'V26', 'V27', 'V28', 'V29', 'V30', 'V31', 'V32', 'V33', 'V34', 'V35', 'V36', 'V37', 'V38', 'V39', 'V40', 'V41', 'V42', 'V44', 'V46', 'V47', 'V48', 'V49', 'V50', 'V51', 'V52', 'V53', 'V54', 'V55', 'V57', 'V58', 'V60', 'V61', 'V62', 'V63', 'V64', 'V65', 'V66', 'V67', 'V68', 'V69', 'V70', 'V71', 'V72', 'V73', 'V74', 'V75', 'V76', 'V77', 'V78', 'V79', 'V80', 'V81', 'V82', 'V83', 'V84', 'V85', 'V86', 'V87', 'V88', 'V89', 'V90', 'V91', 'V92', 'V93', 'V94', 'V95', 'V96', 'V97', 'V98', 'V99', 'V100', 'V101', 'V102', 'V103', 'V104', 'V105', 'V106', 'V107', 'V108', 'V109', 'V110', 'V111', 'V112', 'V113', 'V114', 'V115', 'V116', 'V117', 'V118', 'V119', 'V120', 'V121', 'V122', 'V123

Using -- Dataset:kr_vs_kp Aggregator:rnn
Target mapping: {'won': 0, 'nowin': 1}
Numerical columns: []
Categorical columns: ['bkblk', 'bknwy', 'bkon8', 'bkona', 'bkspr', 'bkxbq', 'bkxcr', 'bkxwp', 'blxwp', 'bxqsq', 'cntxt', 'dsopp', 'dwipd', 'hdchk', 'katri', 'mulch', 'qxmsq', 'r2ar8', 'reskd', 'reskr', 'rimmx', 'rkxwp', 'rxmsq', 'simpl', 'skach', 'skewr', 'skrxp', 'spcop', 'stlmt', 'thrsk', 'wkcti', 'wkna8', 'wknck', 'wkovl', 'wkpos', 'wtoeg']
Columns: ['bkblk', 'bknwy', 'bkon8', 'bkona', 'bkspr', 'bkxbq', 'bkxcr', 'bkxwp', 'blxwp', 'bxqsq', 'cntxt', 'dsopp', 'dwipd', 'hdchk', 'katri', 'mulch', 'qxmsq', 'r2ar8', 'reskd', 'reskr', 'rimmx', 'rkxwp', 'rxmsq', 'simpl', 'skach', 'skewr', 'skrxp', 'spcop', 'stlmt', 'thrsk', 'wkcti', 'wkna8', 'wknck', 'wkovl', 'wkpos', 'wtoeg', 'class']




Using -- Dataset:kr_vs_kp Aggregator:sum
Target mapping: {'won': 0, 'nowin': 1}
Numerical columns: []
Categorical columns: ['bkblk', 'bknwy', 'bkon8', 'bkona', 'bkspr', 'bkxbq', 'bkxcr', 'bkxwp', 'blxwp', 'bxqsq', 'cntxt', 'dsopp', 'dwipd', 'hdchk', 'katri', 'mulch', 'qxmsq', 'r2ar8', 'reskd', 'reskr', 'rimmx', 'rkxwp', 'rxmsq', 'simpl', 'skach', 'skewr', 'skrxp', 'spcop', 'stlmt', 'thrsk', 'wkcti', 'wkna8', 'wknck', 'wkovl', 'wkpos', 'wtoeg']
Columns: ['bkblk', 'bknwy', 'bkon8', 'bkona', 'bkspr', 'bkxbq', 'bkxcr', 'bkxwp', 'blxwp', 'bxqsq', 'cntxt', 'dsopp', 'dwipd', 'hdchk', 'katri', 'mulch', 'qxmsq', 'r2ar8', 'reskd', 'reskr', 'rimmx', 'rkxwp', 'rxmsq', 'simpl', 'skach', 'skewr', 'skrxp', 'spcop', 'stlmt', 'thrsk', 'wkcti', 'wkna8', 'wknck', 'wkovl', 'wkpos', 'wtoeg', 'class']
Using -- Dataset:kr_vs_kp Aggregator:mean
Target mapping: {'won': 0, 'nowin': 1}
Numerical columns: []
Categorical columns: ['bkblk', 'bknwy', 'bkon8', 'bkona', 'bkspr', 'bkxbq', 'bkxcr', 'bkxwp', 'blxwp', 'bx



Using -- Dataset:nomao Aggregator:sum
Target mapping: {2: 0, 1: 1}
Numerical columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V17', 'V18', 'V19', 'V20', 'V21', 'V22', 'V25', 'V26', 'V27', 'V28', 'V29', 'V30', 'V33', 'V34', 'V35', 'V36', 'V37', 'V38', 'V41', 'V42', 'V43', 'V44', 'V45', 'V46', 'V49', 'V50', 'V51', 'V52', 'V53', 'V54', 'V57', 'V58', 'V59', 'V60', 'V61', 'V62', 'V65', 'V66', 'V67', 'V68', 'V69', 'V70', 'V73', 'V74', 'V75', 'V76', 'V77', 'V78', 'V81', 'V82', 'V83', 'V84', 'V85', 'V86', 'V89', 'V90', 'V91', 'V93', 'V94', 'V95', 'V97', 'V98', 'V99', 'V101', 'V102', 'V103', 'V105', 'V106', 'V107', 'V109', 'V110', 'V111', 'V113', 'V114', 'V115', 'V117', 'V118']
Categorical columns: ['V7', 'V8', 'V15', 'V16', 'V23', 'V24', 'V31', 'V32', 'V39', 'V40', 'V47', 'V48', 'V55', 'V56', 'V63', 'V64', 'V71', 'V72', 'V79', 'V80', 'V87', 'V88', 'V92', 'V96', 'V100', 'V104', 'V108', 'V112', 'V116']
Columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8



Using -- Dataset:ldpa Aggregator:sum
Target mapping: {3: 0, 8: 1, 7: 2, 11: 3, 1: 4, 4: 5, 10: 6, 5: 7, 9: 8, 2: 9, 6: 10}
Numerical columns: ['V3', 'V4', 'V5', 'V6', 'V7']
Categorical columns: ['V1', 'V2']
Columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'Class']
Using -- Dataset:ldpa Aggregator:mean
Target mapping: {3: 0, 8: 1, 7: 2, 11: 3, 1: 4, 4: 5, 10: 6, 5: 7, 9: 8, 2: 9, 6: 10}
Numerical columns: ['V3', 'V4', 'V5', 'V6', 'V7']
Categorical columns: ['V1', 'V2']
Columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'Class']
Using -- Dataset:ldpa Aggregator:max
Target mapping: {3: 0, 8: 1, 7: 2, 11: 3, 1: 4, 4: 5, 10: 6, 5: 7, 9: 8, 2: 9, 6: 10}
Numerical columns: ['V3', 'V4', 'V5', 'V6', 'V7']
Categorical columns: ['V1', 'V2']
Columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'Class']


In [6]:
errors

[]

In [7]:
results

{'sylvine': {'cls': {0: {'trial': 0,
    'config': {'n_layers': 2,
     'optimizer__lr': 0.005614550964088092,
     'n_head': 2,
     'n_hid': 512,
     'dropout': 0.19933655105104087,
     'numerical_passthrough': True},
    'trial_balanced_accuracy_max': 0.8643328367737817,
    'training_iter_sec': 5.400446510314939,
    'non_trainable_params': 0,
    'trainable_params': 1070400},
   1: {'trial': 1,
    'config': {'n_layers': 3,
     'optimizer__lr': 0.0969734797928502,
     'n_head': 32,
     'n_hid': 32,
     'dropout': 0.3775996282196361,
     'numerical_passthrough': False},
    'trial_balanced_accuracy_max': 0.9495789377679142,
    'training_iter_sec': 38.24851706981657,
    'non_trainable_params': 0,
    'trainable_params': 3290721},
   2: {'trial': 2,
    'config': {'n_layers': 4,
     'optimizer__lr': 0.006356116028811235,
     'n_head': 1,
     'n_hid': 64,
     'dropout': 0.4914561095534663,
     'numerical_passthrough': False},
    'trial_balanced_accuracy_max': 0.68807953

In [8]:
results_backup = results

# Clean results

In [9]:
ds_results = [] 
for ds_key in results:
    for agg_key in results[ds_key]:       
        for trial in results[ds_key][agg_key]:
            ds_trial_dict = {
                "dataset": ds_key,
                "aggregator": agg_key,
                **results[ds_key][agg_key][trial]["config"],
                **results[ds_key][agg_key][trial]
            }
            
            del ds_trial_dict["config"]
            
            ds_results.append(ds_trial_dict)

ds_results = pd.DataFrame(ds_results)

In [10]:
ds_results

Unnamed: 0,dataset,aggregator,n_layers,optimizer__lr,n_head,n_hid,dropout,numerical_passthrough,trial,trial_balanced_accuracy_max,training_iter_sec,non_trainable_params,trainable_params
0,sylvine,cls,2,0.005615,2,512,0.199337,True,0,0.864333,5.400447,0,1070400
1,sylvine,cls,3,0.096973,32,32,0.377600,False,1,0.949579,38.248517,0,3290721
2,sylvine,cls,4,0.006356,1,64,0.491456,False,2,0.688080,8.777578,0,36225
3,sylvine,cls,2,0.000444,16,32,0.225905,True,3,0.557634,3.126000,0,157952
4,sylvine,cls,3,0.000013,4,128,0.029897,False,4,0.501931,7.945999,0,3585921
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1435,ldpa,max,4,0.056866,1,1024,0.168485,True,25,0.162316,970.861393,0,3169100
1436,ldpa,max,2,0.001652,4,256,0.374320,False,26,0.311577,711.008068,0,2651147
1437,ldpa,max,1,0.021208,4,256,0.251674,False,27,0.368412,470.559668,0,405771
1438,ldpa,max,2,0.000017,8,256,0.491890,False,28,0.141553,97.329034,0,2651147


# Add extra info

## Dataset info

In [11]:
with open("selected_datasets.json", "r") as f:
    ds_info = json.load(f)

In [12]:
ds_info_df = pd.DataFrame(ds_info)[["name", "label"]]

In [13]:
ds_info_df

Unnamed: 0,name,label
27,sylvine,0
23,volkert,0
5,adult,1
16,australian,1
0,anneal,2
26,jasmine,3
1,kr_vs_kp,3
10,nomao,4
9,ldpa,4


In [14]:
ds_results = ds_results.merge(ds_info_df, left_on="dataset", right_on="name").drop("name", axis=1)

In [15]:
ds_results

Unnamed: 0,dataset,aggregator,n_layers,optimizer__lr,n_head,n_hid,dropout,numerical_passthrough,trial,trial_balanced_accuracy_max,training_iter_sec,non_trainable_params,trainable_params,label
0,sylvine,cls,2,0.005615,2,512,0.199337,True,0,0.864333,5.400447,0,1070400,0
1,sylvine,cls,3,0.096973,32,32,0.377600,False,1,0.949579,38.248517,0,3290721,0
2,sylvine,cls,4,0.006356,1,64,0.491456,False,2,0.688080,8.777578,0,36225,0
3,sylvine,cls,2,0.000444,16,32,0.225905,True,3,0.557634,3.126000,0,157952,0
4,sylvine,cls,3,0.000013,4,128,0.029897,False,4,0.501931,7.945999,0,3585921,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1435,ldpa,max,4,0.056866,1,1024,0.168485,True,25,0.162316,970.861393,0,3169100,4
1436,ldpa,max,2,0.001652,4,256,0.374320,False,26,0.311577,711.008068,0,2651147,4
1437,ldpa,max,1,0.021208,4,256,0.251674,False,27,0.368412,470.559668,0,405771,4
1438,ldpa,max,2,0.000017,8,256,0.491890,False,28,0.141553,97.329034,0,2651147,4


## Evaluation results

In [16]:
with open("eval_results.json", "r") as f:
    eval_results = json.load(f)

In [17]:
eval_results_df = []

for ds_key in eval_results:
    for agg_key in eval_results[ds_key]:       
        eval_result_dict = {
            "dataset": ds_key,
            "aggregator": agg_key,
            **eval_results[ds_key][agg_key]
        }
        
        eval_results_df.append(eval_result_dict)
        
eval_results_df =  pd.DataFrame(eval_results_df)        

In [18]:
eval_results_df

Unnamed: 0,dataset,aggregator,balanced_accuracy_train,balanced_accuracy_val,balanced_accuracy_test,n_parameters,n_trainable
0,sylvine,cls,0.973561,0.9612,0.942541,3290721,3290721
1,sylvine,concatenate,0.920683,0.920507,0.918676,171137,171137
2,sylvine,rnn,0.974764,0.947648,0.944894,4531553,4531553
3,sylvine,sum,0.923219,0.924254,0.9274,8442369,8442369
4,sylvine,mean,0.962298,0.947724,0.93553,1215617,1215617
5,sylvine,max,0.960259,0.941933,0.932105,2268801,2268801
6,adult,cls,0.760682,0.777153,0.759971,12760065,12760065
7,adult,concatenate,0.76085,0.77714,0.760074,3439809,3439809
8,adult,rnn,0.530588,0.526987,0.527268,4666881,4666881
9,adult,sum,0.760459,0.776315,0.760074,3232769,3232769


In [19]:
ds_results = ds_results.merge(
    eval_results_df, 
    left_on=["dataset", "aggregator"], 
    right_on=["dataset", "aggregator"]
    
)

In [20]:
ds_results

Unnamed: 0,dataset,aggregator,n_layers,optimizer__lr,n_head,n_hid,dropout,numerical_passthrough,trial,trial_balanced_accuracy_max,training_iter_sec,non_trainable_params,trainable_params,label,balanced_accuracy_train,balanced_accuracy_val,balanced_accuracy_test,n_parameters,n_trainable
0,sylvine,cls,2,0.005615,2,512,0.199337,True,0,0.864333,5.400447,0,1070400,0,0.973561,0.961200,0.942541,3290721,3290721
1,sylvine,cls,3,0.096973,32,32,0.377600,False,1,0.949579,38.248517,0,3290721,0,0.973561,0.961200,0.942541,3290721,3290721
2,sylvine,cls,4,0.006356,1,64,0.491456,False,2,0.688080,8.777578,0,36225,0,0.973561,0.961200,0.942541,3290721,3290721
3,sylvine,cls,2,0.000444,16,32,0.225905,True,3,0.557634,3.126000,0,157952,0,0.973561,0.961200,0.942541,3290721,3290721
4,sylvine,cls,3,0.000013,4,128,0.029897,False,4,0.501931,7.945999,0,3585921,0,0.973561,0.961200,0.942541,3290721,3290721
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1435,ldpa,max,4,0.056866,1,1024,0.168485,True,25,0.162316,970.861393,0,3169100,4,0.440873,0.438302,0.423968,405771,405771
1436,ldpa,max,2,0.001652,4,256,0.374320,False,26,0.311577,711.008068,0,2651147,4,0.440873,0.438302,0.423968,405771,405771
1437,ldpa,max,1,0.021208,4,256,0.251674,False,27,0.368412,470.559668,0,405771,4,0.440873,0.438302,0.423968,405771,405771
1438,ldpa,max,2,0.000017,8,256,0.491890,False,28,0.141553,97.329034,0,2651147,4,0.440873,0.438302,0.423968,405771,405771


# Export

In [21]:
ds_results.to_csv("all_info.csv", index=False, encoding="utf-8")