In [1]:
import copy

import yaml


def generate_model_template(
    dataset_name, model_name, best_params=None, param_grid=None
):
    """Build the run template for one (dataset, model) pair.

    The supplied dictionaries are deep-copied, so the returned template never
    shares nested dicts or lists with the caller's objects or with other
    templates built from the same grid. Shared nested objects would make
    ``yaml.dump`` emit ``&id001`` / ``*id001`` anchors and would let an edit
    to one template leak into the others.

    Args:
        dataset_name: value stored under ``"dataset"``.
        model_name: value stored under ``"model"``.
        best_params: optional dict stored under ``"best_params"``; a missing
            or empty value yields ``{}``.
        param_grid: optional dict stored under ``"param_grid"``; a missing
            or empty value yields ``{}``.

    Returns:
        dict with the keys ``dataset``, ``model``, ``best_params`` and
        ``param_grid``, in that order.
    """
    return {
        "dataset": dataset_name,
        "model": model_name,
        "best_params": copy.deepcopy(best_params) if best_params else {},
        "param_grid": copy.deepcopy(param_grid) if param_grid else {},
    }

# S1DCNN

In [2]:
def _build_s1dcnn_param_grid(batch_sizes):
    """Return a new S1DCNN grid; only ``batch_size`` varies between sizes.

    The dict literal is evaluated on every call, so two grids built here
    share no nested dict or list.

    Args:
        batch_sizes: list stored under ``"batch_size"``.

    Returns:
        dict whose first key is ``"outer_params"``, followed by the
        per-hyperparameter entries.
    """
    # NOTE(review): "outer_params" presumably holds fixed run settings rather
    # than searched values -- confirm against the reader of the YAML file.
    return {
        "outer_params": {
            "hyperopt_evals": 10,
            "max_epochs": 1000,
            "early_stopping": True,
            "shuffle": True,
            "validation_fraction": 0.15,
            "early_stopping_patience": 6,
        },
        "batch_size": batch_sizes,
        "hidden_size": [1024, 2048, 4096],
        "optimizer_fn": {
            "Adam": {
                "weight_decay": [0.00001, 0.001],
                "learning_rate": [0.0001, 0.001],
            },
            "AdamW": {
                "weight_decay": [0.00001, 0.001],
                "learning_rate": [0.0001, 0.001],
            },
        },
        "scheduler_fn": {
            "ReduceLROnPlateau": {"factor": [0.1, 0.9], "patience": [3, 5]},
            "ExponentialLR": {"gamma": [0.9, 0.99]},  # Adjusted for ExponentialLR
            "StepLR": {"step_size": [5, 10], "gamma": [0.9, 0.99]},
        },
    }


s1dcnn_medium_param_grid = _build_s1dcnn_param_grid([512, 1024, 2048])
s1dcnn_large_param_grid = _build_s1dcnn_param_grid([1024, 2048, 4096])

# ResNet

In [3]:
def _build_resnet_param_grid(batch_sizes):
    """Return a new ResNet grid; only ``batch_size`` varies between sizes.

    The dict literal is evaluated on every call, so two grids built here
    share no nested dict or list.

    Args:
        batch_sizes: list stored under ``"batch_size"``.

    Returns:
        dict whose first key is ``"outer_params"``, followed by the
        per-hyperparameter entries.
    """
    return {
        "outer_params": {
            "hyperopt_evals": 10,
            "max_epochs": 1000,
            "early_stopping": True,
            "early_stopping_patience": 6,
            "validation_fraction": 0.15,
        },
        "batch_size": batch_sizes,
        "resnet_depth": ["resnet18", "resnet34", "resnet50"],
        "optimizer_fn": {
            "Adam": {
                "weight_decay": [0.00001, 0.001],
                "learning_rate": [0.0001, 0.001],
            },
            "AdamW": {
                "weight_decay": [0.00001, 0.001],
                "learning_rate": [0.0001, 0.001],
            },
        },
        "scheduler_fn": {
            "ReduceLROnPlateau": {"factor": [0.1, 0.9], "patience": [3, 5]},
            "ExponentialLR": {"gamma": [0.9, 0.99]},  # Adjusted for ExponentialLR
            "StepLR": {"step_size": [5, 10], "gamma": [0.9, 0.99]},
        },
    }


resnet_medium_param_grid = _build_resnet_param_grid([512, 1024, 2048])
resnet_large_param_grid = _build_resnet_param_grid([1024, 2048, 4096])

# MLP

In [4]:
# Settings stored under "outer_params" of the MLP grid.
# NOTE(review): the key names match scikit-learn's MLP estimator arguments
# (max_iter, n_iter_no_change, ...) -- presumably forwarded to it; confirm.
_mlp_outer_params = {
    "cv_iterations": 10,
    "early_stopping": True,
    "cv_size": 5,
    "validation_fraction": 0.15,
    "n_iter_no_change": 6,
    "max_iter": 1000,
}

# Candidate values per MLP hyperparameter, in the order they are written out.
_mlp_candidates = {
    "hidden_layer_sizes": [
        [64, 32, 16],
        [256, 128, 64, 32],
        [128, 64, 32, 16],
    ],
    "activation": ["relu", "tanh", "logistic"],
    "solver": ["adam", "lbfgs"],
    "alpha": [0.0001, 0.001, 0.01],
    "learning_rate_init": [0.0001, 0.01, 0.1],
    "beta_1": [0.99, 0.8],
    "beta_2": [0.999, 0.9],
    "batch_size": [512, 1024, 2048],
}

# "outer_params" stays the first key, followed by the candidates.
mlp_large_param_grid = {"outer_params": _mlp_outer_params, **_mlp_candidates}

# XGB

In [5]:
# Settings stored under "outer_params" of the XGBoost grid.
_xgb_outer_params = {
    "hyperopt_evals": 10,
    "validation_fraction": 0.15,
    "early_stopping_rounds": 30,
    "verbose": False,
}

# Candidate values per XGBoost hyperparameter, in the order they are written out.
_xgb_candidates = {
    "n_estimators": [100, 1000],
    "max_bin": [256, 32],
    "tree_method": ["auto", "hist"],
    "max_depth": [4, 10],
    "learning_rate": [0.1, 0.33],
    "subsample": [0.7, 1.0],
    "colsample_bytree": [0.5, 1.0],
    "min_child_weight": [1, 10],
    "alpha": [0.0, 5.0],
    "gamma": [0.0, 5.0],
    "lambda": [0.0, 5.0],
}

# "outer_params" stays the first key, followed by the candidates.
xgb_large_param_grid = {"outer_params": _xgb_outer_params, **_xgb_candidates}

# CatBoost

In [6]:
# Settings stored under "outer_params" of the CatBoost grid.
_catboost_outer_params = {
    "hyperopt_evals": 10,
    "validation_fraction": 0.15,
    "early_stopping_rounds": 100,
    "verbose": False,
}

# Candidate values per CatBoost hyperparameter, in the order they are written out.
_catboost_candidates = {
    "iterations": [200, 2000],
    "learning_rate": [0.001, 0.1],
    "depth": [4, 10],
    # "colsample_bylevel": [0.05,1.0],
    "l2_leaf_reg": [0.5, 5.0],
    "min_child_samples": [1, 100],
    "bagging_temperature": [0.1, 2.0],
}

# "outer_params" stays the first key, followed by the candidates.
catboost_medium_param_grid = {
    "outer_params": _catboost_outer_params,
    **_catboost_candidates,
}

# GATE

In [7]:
def _build_gate_param_grid(batch_sizes):
    """Return a new GATE grid; only ``batch_size`` varies between sizes.

    The dict literal is evaluated on every call, so two grids built here
    share no nested dict or list.

    Args:
        batch_sizes: list stored under ``"batch_size"``.

    Returns:
        dict whose first key is ``"outer_params"``, followed by the
        per-hyperparameter entries.
    """
    return {
        "outer_params": {
            "hyperopt_evals": 10,
            "auto_lr_find": True,
            "max_epochs": 1000,
            "val_size": 0.15,
            "early_stopping_patience": 6,
        },
        # "learning_rate": [0.01, 0.001],
        "batch_size": batch_sizes,
        "tree_depth": [4, 7],
        "num_trees": [4, 10],
        "chain_trees": [False, True],
        "gflu_stages": [3, 8],
        "gflu_dropout": [0.0, 0.05],
        "tree_dropout": [0.0, 0.05],
        "tree_wise_attention_dropout": [0.0, 0.05],
        "embedding_dropout": [0, 0.2],
        "optimizer_fn": {
            "Adam": {"weight_decay": [0.0001, 0.00001]},
            "AdamW": {"weight_decay": [0.0001, 0.00001]},
        },
        "scheduler_fn": {
            "ReduceLROnPlateau": {"factor": [0.1, 0.9], "patience": [3, 5]},
            "ExponentialLR": {"gamma": [0.9, 0.99]},
        },
    }


gate_medium_param_grid = _build_gate_param_grid([512, 1024, 2048])
gate_large_param_grid = _build_gate_param_grid([1024, 2048, 4096])

# TABNET

In [8]:
def _build_tabnet_param_grid(batch_sizes):
    """Return a new TabNet grid; only ``batch_size`` varies between sizes.

    The dict literal is evaluated on every call, so two grids built here
    share no nested dict or list.

    Args:
        batch_sizes: list stored under ``"batch_size"``.

    Returns:
        dict whose first key is ``"outer_params"``, followed by the
        per-hyperparameter entries.
    """
    return {
        "outer_params": {
            "hyperopt_evals": 10,
            "auto_lr_find": True,
            "precision": 16,
            "max_epochs": 1000,
            "val_size": 0.15,
            "early_stopping_patience": 6,
        },
        # "learning_rate": [0.001, 0.0001],
        "virtual_batch_size_ratio": [0.125, 0.25, 0.5, 1.0],
        "batch_size": batch_sizes,
        "weights": [0, 1],
        "mask_type": ["sparsemax", "entmax"],
        "n_d": [6, 32],
        "n_steps": [1, 6],
        "gamma": [1.0, 2.0],
        "n_independent": [1, 3],
        "n_shared": [1, 3],
        "optimizer_fn": {
            "Adam": {"weight_decay": [0.0001, 0.00001]},
            "AdamW": {"weight_decay": [0.0001, 0.00001]},
        },
        "scheduler_fn": {
            "ReduceLROnPlateau": {"factor": [0.1, 0.9], "patience": [3, 5]},
            "ExponentialLR": {"gamma": [0.9, 0.99]},
        },
    }


tabnet_medium_param_grid = _build_tabnet_param_grid([512, 1024, 2048])
tabnet_large_param_grid = _build_tabnet_param_grid([1024, 2048, 4096])

# FT-Transformer

In [9]:
def _build_fttransformer_param_grid(batch_sizes):
    """Return a new FT-Transformer grid; only ``batch_size`` varies between sizes.

    The dict literal is evaluated on every call, so two grids built here
    share no nested dict or list.

    Args:
        batch_sizes: list stored under ``"batch_size"``.

    Returns:
        dict whose first key is ``"outer_params"``, followed by the
        per-hyperparameter entries.
    """
    return {
        "outer_params": {
            "hyperopt_evals": 10,
            "auto_lr_find": True,
            "precision": 16,
            "max_epochs": 1000,
            "val_size": 0.15,
            "early_stopping_patience": 6,
            "attn_feature_importance": False,
        },
        "learning_rate": [0.001, 0.0001],
        "batch_size": batch_sizes,
        "num_heads": [4, 10],
        "input_embed_dim_multiplier": [2, 6],
        "embedding_initialization": ["kaiming_uniform", "kaiming_normal"],
        "embedding_dropout": [0.05, 0.2],
        "shared_embedding_fraction": [0.125, 0.25, 0.5],
        "num_attn_blocks": [4, 8],
        "attn_dropout": [0.05, 0.2],
        "add_norm_dropout": [0.05, 0.2],
        "ff_dropout": [0.05, 0.2],
        "ff_hidden_multiplier": [4, 32],
        "transformer_activation": ["GEGLU", "ReGLU", "SwiGLU"],
        "optimizer_fn": {
            "Adam": {"weight_decay": [0.0001, 0.00001]},
            "AdamW": {"weight_decay": [0.0001, 0.00001]},
        },
        "scheduler_fn": {
            "ReduceLROnPlateau": {"factor": [0.1, 0.9], "patience": [3, 5]},
            "ExponentialLR": {"gamma": [0.9, 0.99]},
        },
    }


fttransformer_medium_param_grid = _build_fttransformer_param_grid([512, 1024, 2048])
fttransformer_large_param_grid = _build_fttransformer_param_grid([1024, 2048, 4096])

# GANDALF

In [10]:
def _build_gandalf_param_grid(batch_sizes):
    """Return a new GANDALF grid; only ``batch_size`` varies between sizes.

    The dict literal is evaluated on every call, so two grids built here
    share no nested dict or list.

    Args:
        batch_sizes: list stored under ``"batch_size"``.

    Returns:
        dict whose first key is ``"outer_params"``, followed by the
        per-hyperparameter entries.
    """
    return {
        "outer_params": {
            "hyperopt_evals": 10,
            "auto_lr_find": True,
            "precision": 16,
            "max_epochs": 1000,
            "val_size": 0.15,
            "early_stopping_patience": 6,
        },
        "learning_rate": [0.001, 0.0001],
        "batch_size": batch_sizes,
        "gflu_stages": [4, 10],
        "gflu_dropout": [0.0, 0.3],
        "embedding_dropout": [0.0, 0.3],
        "optimizer_fn": {
            "Adam": {"weight_decay": [0.0001, 0.00001]},
            "AdamW": {"weight_decay": [0.0001, 0.00001]},
        },
        "scheduler_fn": {
            "ReduceLROnPlateau": {"factor": [0.1, 0.9], "patience": [3, 5]},
            "ExponentialLR": {"gamma": [0.9, 0.99]},
        },
    }


gandalf_medium_param_grid = _build_gandalf_param_grid([512, 1024, 2048])
gandalf_large_param_grid = _build_gandalf_param_grid([1024, 2048, 4096])

# Node

In [11]:
def _build_node_param_grid(batch_sizes):
    """Return a new NODE grid; only ``batch_size`` varies between sizes.

    The dict literal is evaluated on every call, so two grids built here
    share no nested dict or list.

    Args:
        batch_sizes: list stored under ``"batch_size"``.

    Returns:
        dict whose first key is ``"outer_params"``, followed by the
        per-hyperparameter entries.
    """
    return {
        "outer_params": {
            "hyperopt_evals": 10,
            "auto_lr_find": True,
            "max_epochs": 1000,
            "val_size": 0.15,
            "early_stopping_patience": 6,
        },
        # "learning_rate": [0.2, 0.001],
        "batch_size": batch_sizes,
        "num_layers": [1, 2, 3],
        "num_trees": [12, 128],
        "additional_tree_output_dim": [2, 3, 4],
        "depth": [5, 7],
        "choice_function": ["entmax15", "sparsemax"],
        "bin_function": ["entmoid15", "sparsemoid"],
        "input_dropout": [0.0, 0.1],
        "embedding_dropout": [0.0, 0.1],
        "embed_categorical": [True, False],
        "optimizer_fn": {
            "Adam": {"weight_decay": [0.0001, 0.00001]},
            "AdamW": {"weight_decay": [0.0001, 0.00001]},
        },
        "scheduler_fn": {
            "ReduceLROnPlateau": {"factor": [0.1, 0.9], "patience": [3, 5]},
            "ExponentialLR": {"gamma": [0.9, 0.99]},
        },
    }


node_medium_param_grid = _build_node_param_grid([512, 1024, 2048])
node_large_param_grid = _build_node_param_grid([1024, 2048, 4096])

# Category Embedding

In [12]:
def _build_catembed_param_grid(batch_sizes):
    """Return a new Category Embedding grid; only ``batch_size`` varies.

    Building both sizes from one literal also removes the duplicated
    ``"precision"`` key that the hand-copied large grid carried in its
    ``outer_params``. The resulting dict contents are unchanged.

    The dict literal is evaluated on every call, so two grids built here
    share no nested dict or list.

    Args:
        batch_sizes: list stored under ``"batch_size"``.

    Returns:
        dict whose first key is ``"outer_params"``, followed by the
        per-hyperparameter entries.
    """
    return {
        "outer_params": {
            "hyperopt_evals": 10,
            "auto_lr_find": True,
            "precision": 16,
            "max_epochs": 1000,
            "val_size": 0.15,
            "early_stopping_patience": 6,
        },
        "batch_size": batch_sizes,
        "learning_rate": [0.001, 0.0001],
        "layers": ["128-64-32", "128-64-32-16", "256-128-64"],
        "activation": ["ReLU", "LeakyReLU", "Tanh"],
        "initialization": ["kaiming", "xavier"],
        "dropout": [0.0, 0.3],
        "embedding_dropout": [0.0, 0.3],
        "optimizer_fn": {
            "Adam": {"weight_decay": [0.0001, 0.00001]},
            "AdamW": {"weight_decay": [0.0001, 0.00001]},
        },
        "scheduler_fn": {
            "ReduceLROnPlateau": {"factor": [0.1, 0.9], "patience": [3, 5]},
            "ExponentialLR": {"gamma": [0.9, 0.99]},
        },
    }


catembed_medium_param_grid = _build_catembed_param_grid([512, 1024, 2048])
catembed_large_param_grid = _build_catembed_param_grid([1024, 2048, 4096])

# TabTransformer

In [13]:
def _build_tabtransformer_param_grid(batch_sizes):
    """Return a new TabTransformer grid; only ``batch_size`` varies between sizes.

    The dict literal is evaluated on every call, so two grids built here
    share no nested dict or list.

    Args:
        batch_sizes: list stored under ``"batch_size"``.

    Returns:
        dict whose first key is ``"outer_params"``, followed by the
        per-hyperparameter entries.
    """
    return {
        "outer_params": {
            "hyperopt_evals": 10,
            "auto_lr_find": True,
            "precision": 16,
            "max_epochs": 1000,
            "val_size": 0.15,
            "early_stopping_patience": 6,
        },
        "batch_size": batch_sizes,
        "learning_rate": [0.01, 0.0001],
        "embedding_bias": [True, False],
        "embedding_initialization": ["kaiming_uniform", "kaiming_normal"],
        "shared_embedding_fraction": [0.125, 0.25, 0.5],
        "num_attn_blocks": [4, 10],
        "attn_dropout": [0.05, 0.3],
        "add_norm_dropout": [0.05, 0.3],
        "ff_dropout": [0.05, 0.3],
        "ff_hidden_multiplier": [2, 6],
        "transformer_activation": ["GEGLU", "ReGLU", "SwiGLU"],
        "embedding_dropout": [0.05, 0.3],
        "optimizer_fn": {
            "Adam": {"weight_decay": [0.0001, 0.00001]},
            "AdamW": {"weight_decay": [0.0001, 0.00001]},
        },
        "scheduler_fn": {
            "ReduceLROnPlateau": {"factor": [0.1, 0.9], "patience": [3, 5]},
            "ExponentialLR": {"gamma": [0.9, 0.99]},
        },
    }


tabtransformer_medium_param_grid = _build_tabtransformer_param_grid([512, 1024, 2048])
tabtransformer_large_param_grid = _build_tabtransformer_param_grid([1024, 2048, 4096])

# AutoINT

In [14]:
def _build_autoint_param_grid(batch_sizes):
    """Return a new AutoInt grid; only ``batch_size`` varies between sizes.

    The dict literal is evaluated on every call, so two grids built here
    share no nested dict or list.

    Args:
        batch_sizes: list stored under ``"batch_size"``.

    Returns:
        dict whose first key is ``"outer_params"``, followed by the
        per-hyperparameter entries.
    """
    return {
        "outer_params": {
            "hyperopt_evals": 10,
            "auto_lr_find": True,
            "precision": 16,
            "max_epochs": 1000,
            "val_size": 0.15,
            "early_stopping_patience": 6,
        },
        "batch_size": batch_sizes,
        "learning_rate": [0.001, 0.0001],
        # "attn_embed_dim": [32, 16, 64],
        # Original author's note: this multiplier (x) sets attn_embed_dim to
        # be a multiple of num_heads * x.
        "attn_embed_dim_multiplier": [2, 16],
        "num_heads": [2, 8],
        "num_attn_blocks": [2, 6],
        "attn_dropouts": [0.0, 0.3],
        "embedding_dim": [8, 32],
        "embedding_initialization": ["kaiming_uniform", "kaiming_normal"],
        "embedding_bias": [True, False],
        "share_embedding": [True, False],
        "share_embedding_strategy": ["add", "fraction"],
        "shared_embedding_fraction": [0.25, 0.1, 0.5],
        "deep_layers": [True, False],
        "layers": ["128-64-32", "128-64-32-16", "256-128-64"],
        "dropout": [0.0, 0.3],
        "activation": ["ReLU", "LeakyReLU"],
        "initialization": ["kaiming", "xavier"],
        "attention_pooling": [True, False],
        "optimizer_fn": {
            "Adam": {"weight_decay": [0.0001, 0.00001]},
            "AdamW": {"weight_decay": [0.0001, 0.00001]},
        },
        "scheduler_fn": {
            "ReduceLROnPlateau": {"factor": [0.1, 0.9], "patience": [3, 5]},
            "ExponentialLR": {"gamma": [0.9, 0.99]},
        },
    }


autoint_medium_param_grid = _build_autoint_param_grid([512, 1024, 2048])
autoint_large_param_grid = _build_autoint_param_grid([1024, 2048, 4096])

In [15]:
# Grid used for each model name, per grid size. MLP, XGBoost and CatBoost
# define a single grid each, so both sizes point at the same one.
_single_size_grids = {
    "mlp": mlp_large_param_grid,
    "xgb": xgb_large_param_grid,
    "catboost": catboost_medium_param_grid,
}
_grids_by_size = {
    "medium": {
        **_single_size_grids,
        "resnet": resnet_medium_param_grid,
        "s1dcnn": s1dcnn_medium_param_grid,
        "tabnet": tabnet_medium_param_grid,
        "gate": gate_medium_param_grid,
        "fttransformer": fttransformer_medium_param_grid,
        "categoryembedding": catembed_medium_param_grid,
        "gandalf": gandalf_medium_param_grid,
        "node": node_medium_param_grid,
        "autoint": autoint_medium_param_grid,
        "tabtransformer": tabtransformer_medium_param_grid,
    },
    "large": {
        **_single_size_grids,
        "resnet": resnet_large_param_grid,
        "s1dcnn": s1dcnn_large_param_grid,
        "tabnet": tabnet_large_param_grid,
        "gate": gate_large_param_grid,
        "fttransformer": fttransformer_large_param_grid,
        "categoryembedding": catembed_large_param_grid,
        "gandalf": gandalf_large_param_grid,
        "node": node_large_param_grid,
        "autoint": autoint_large_param_grid,
        "tabtransformer": tabtransformer_large_param_grid,
    },
}

# Model orderings reused by several datasets (order is preserved in the YAML).
_attention_and_tree_nets = (
    "gate",
    "fttransformer",
    "categoryembedding",
    "gandalf",
    "node",
    "autoint",
    "tabtransformer",
)
_all_neural = ("resnet", "s1dcnn", "tabnet") + _attention_and_tree_nets

# One row per dataset: (dataset name, grid size, model names in output order).
_run_plan = [
    # Disabled runs, kept for reference -- uncomment a row to regenerate it:
    # ("housing", "medium", ("mlp", "xgb", "catboost", "resnet", "s1dcnn",
    #                        "tabnet", "gate", "fttransformer",
    #                        "categoryembedding", "node", "tabtransformer",
    #                        "autoint", "gandalf")),
    # ("adult", "medium", ("catboost", "xgb", "mlp") + _all_neural),
    # ("heloc", "medium", ("catboost", "xgb", "mlp") + _all_neural),
    # ("creditcard", "medium", ("catboost", "xgb", "mlp", "resnet", "s1dcnn",
    #                           "tabnet")),
    ("creditcard", "medium", _attention_and_tree_nets),
    ("iris", "large", ("mlp", "xgb", "catboost") + _all_neural),
    ("titanic", "large", ("mlp", "xgb", "catboost") + _all_neural),
    ("breastcancer", "large", _all_neural),
    ("ageconditions", "large", _all_neural),
    ("covertype", "large", ("catboost", "xgb", "mlp") + _all_neural),
]

# Deep-copy each grid: a shallow ``.copy()`` leaves the nested dicts and lists
# shared between templates, which makes ``yaml.dump`` emit ``&id``/``*id``
# anchors and lets an edit to one template leak into the others.
templates = [
    generate_model_template(
        dataset, model, param_grid=copy.deepcopy(_grids_by_size[size][model])
    )
    for dataset, size, models in _run_plan
    for model in models
]

In [16]:
# Wrap the templates under a top-level "runs" key.
# NOTE(review): runs_dict is built but never written -- the dump below
# serialises the bare `templates` list. Confirm which layout the reader of
# experiment_runs.yml expects before switching to runs_dict.
runs_dict = {"runs": templates}

# Write the run definitions as block-style YAML, keeping key insertion order.
with open("../configuration/experiment_runs.yml", "w") as runs_file:
    yaml.dump(
        templates,
        runs_file,
        default_flow_style=False,
        sort_keys=False,
    )

In [17]:
import numpy as np

# Scratch check: arithmetic mean of five hard-coded values.
# NOTE(review): the origin of these numbers is not shown in this notebook --
# presumably scores copied from an earlier run; confirm or remove the cell.
i = [
    -0.9988709989212998,
    -0.8863132201155857,
    -0.9997219786008158,
    -0.959104173635317,
    -0.9428239001500265,
]
np.average(i)

-0.957366854284609