In [15]:
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import random
import numpy as np
import pandas as pd
import os
# import wandb
import logging
import warnings

# configure logging at the root level of Lightning
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
os.environ['PT_LOGLEVEL'] = "CRITICAL" # Setting Log Level for PyTorch Tabular. Need to do it before importing the modules
# %load_ext autoreload
# %autoreload 2

# Utility Functions

In [2]:
def make_mixed_classification(n_samples, n_features, n_categories):
    X, y = make_classification(n_samples=n_samples, n_features=n_features, random_state=42, n_informative=5)
    cat_cols = random.choices(list(range(X.shape[-1])), k=n_categories)
    num_cols = [i for i in range(X.shape[-1]) if i not in cat_cols]
    for col in cat_cols:
        X[:, col] = pd.qcut(X[:, col], q=4).codes.astype(int)
    col_names = []
    num_col_names = []
    cat_col_names = []
    for i in range(X.shape[-1]):
        if i in cat_cols:
            col_names.append(f"cat_col_{i}")
            cat_col_names.append(f"cat_col_{i}")
        if i in num_cols:
            col_names.append(f"num_col_{i}")
            num_col_names.append(f"num_col_{i}")
    X = pd.DataFrame(X, columns=col_names)
    y = pd.Series(y, name="target")
    data = X.join(y)
    return data, cat_col_names, num_col_names


def print_metrics(y_true, y_pred, tag):
    if isinstance(y_true, pd.DataFrame) or isinstance(y_true, pd.Series):
        y_true = y_true.values
    if isinstance(y_pred, pd.DataFrame) or isinstance(y_pred, pd.Series):
        y_pred = y_pred.values
    if y_true.ndim > 1:
        y_true = y_true.ravel()
    if y_pred.ndim > 1:
        y_pred = y_pred.ravel()
    val_acc = accuracy_score(y_true, y_pred)
    val_f1 = f1_score(y_true, y_pred)
    print(f"{tag} Acc: {val_acc} | {tag} F1: {val_f1}")


# Generate Synthetic Data 

First of all, let's create a synthetic data which is a mix of numerical and categorical features

In [3]:
data, cat_col_names, num_col_names = make_mixed_classification(n_samples=10000, n_features=20, n_categories=4)


# Importing the Library

In [4]:
from pytorch_tabular import TabularModel
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    NodeConfig,
    TabNetModelConfig,
    GatedAdditiveTreeEnsembleConfig,
)
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig




# Cross Validation

In [5]:
train, test = train_test_split(data, random_state=42)


In [6]:
data_config = DataConfig(
    target=[
        "target"
    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    batch_size=1024,
    max_epochs=100,
    early_stopping="valid_loss",  # Monitor valid_loss for early stopping
    early_stopping_mode="min",  # Set the mode as min because for val_loss, lower is better
    early_stopping_patience=5,  # No. of epochs of degradation training will wait before terminating
    checkpoints="valid_loss",  # Save best checkpoint monitoring val_loss
    load_best=True,  # After training, load the best checkpoint
    progress_bar="none",  # Turning off Progress bar
    trainer_kwargs=dict(enable_model_summary=False),  # Turning off model summary
)
optimizer_config = OptimizerConfig()

head_config = LinearHeadConfig(
    layers="", dropout=0.1, initialization="kaiming"  # No additional layer in head, just a mapping layer to output_dim
).__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # Number of nodes in each layer
    activation="LeakyReLU",  # Activation between each layers
    learning_rate=1e-3,
    head="LinearHead",  # Linear Head
    head_config=head_config,  # Linear Head Config
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)


## Using High-Level API

In [9]:
# cross validation loop usnig sklearn
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, f1_score

kf = KFold(n_splits=5, shuffle=True, random_state=42)

def _accuracy(y_true, y_pred):
    return accuracy_score(y_true, y_pred['prediction'].values)


cv_scores, oof_predictions = tabular_model.cross_validate(
    2, train, metric=_accuracy, return_oof=False, reset_datamodule=False)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` arg

In [10]:
print(f"KFold Mean: {np.mean(cv_scores)} | KFold SD: {np.std(cv_scores)}")

KFold Mean: 0.9356 | KFold SD: 0.021733333333333327


## Using Low-Level API    

Sometimes, the fitting the datamodule is an expensive operation. If the dataset is sufficiently large, we can take an approximation and prepare the `TabularDatamodule` once and then reuse the same for the other folds.

_P.S - The loop can easily be modified to do bagging (predict on test data using model from each fold and average it)_

In [11]:
kf = KFold(n_splits=5, shuffle=True, random_state=42)
# Initialize the tabular model onece
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
metrics = []
datamodule = None
model = None
for fold, (train_idx, val_idx) in enumerate(kf.split(train)):
    print(f"Fold: {fold}")
    train_fold = train.iloc[train_idx]
    val_fold = train.iloc[val_idx]
    if datamodule is None:
        # Initialize datamodule and model in the first fold
        # uses train data from this fold to fit all transformers
        datamodule = tabular_model.prepare_dataloader(train=train_fold, validation=val_fold, seed=42)
        model = tabular_model.prepare_model(datamodule)
    else:
        # Preprocess the current fold data using the fitted transformers and save in datamodule
        datamodule.train, _ = datamodule.preprocess_data(train_fold, stage="inference")
        datamodule.validation, _ = datamodule.preprocess_data(val_fold, stage="inference")
    # Train the model
    tabular_model.train(model, datamodule)
    result = tabular_model.evaluate(val_fold, verbose=False)
    metrics.append(result[0]["test_accuracy"])
    # Reset the trained weights before next fold
    tabular_model.model.reset_weights()


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Fold: 0


/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
LOC

Fold: 1


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider incr

Fold: 2


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider incr

Fold: 3


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider incr

Fold: 4


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


In [12]:
print(f"KFold Mean: {np.mean(metrics)} | KFold SD: {np.std(metrics)}")


KFold Mean: 0.9637333273887634 | KFold SD: 0.016403796287977765


# Evaluating Multiple models without re-fitting DataModules    

Using the Low-level API, we can also train and evaluate multiple models without re-fitting a datamodule. 

In [30]:
train, test = train_test_split(data, random_state=42)
train, val = train_test_split(train, random_state=42)


In [31]:
results = []


In [32]:
data_config = DataConfig(
    target=[
        "target"
    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    batch_size=1024,
    max_epochs=100,
    early_stopping="valid_loss",  # Monitor valid_loss for early stopping
    early_stopping_mode="min",  # Set the mode as min because for val_loss, lower is better
    early_stopping_patience=5,  # No. of epochs of degradation training will wait before terminating
    checkpoints="valid_loss",  # Save best checkpoint monitoring val_loss
    load_best=True,  # After training, load the best checkpoint
    #     progress_bar="none", # Turning off Progress bar
    #     trainer_kwargs=dict(
    #         enable_model_summary=False # Turning off model summary
    #     )
)
optimizer_config = OptimizerConfig()

head_config = LinearHeadConfig(
    layers="", dropout=0.1, initialization="kaiming"  # No additional layer in head, just a mapping layer to output_dim
).__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)

model1_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # Number of nodes in each layer
    activation="LeakyReLU",  # Activation between each layers
    learning_rate=1e-3,
    head="LinearHead",  # Linear Head
    head_config=head_config,  # Linear Head Config
)

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # Number of nodes in each layer
    activation="LeakyReLU",  # Activation between each layers
    learning_rate=1e-3,
    head="LinearHead",  # Linear Head
    head_config=head_config,  # Linear Head Config
)

alt_model_config = GatedAdditiveTreeEnsembleConfig(
    task="classification",
    learning_rate=1e-3,
    head="LinearHead",  # Linear Head
    head_config=head_config,  # Linear Head Config
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)


In [34]:
datamodule = tabular_model.prepare_dataloader(train=train, validation=val, seed=42)
model = tabular_model.prepare_model(datamodule)
tabular_model.train(model, datamodule)


Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

<pytorch_lightning.trainer.trainer.Trainer at 0x7fd4491ea350>

In [35]:
result = tabular_model.evaluate(test)

result = result[0]
result["Model"] = "CategoryEmbedding"
results.append(result)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Output()

In [36]:
alt_tabular_model = TabularModel(
    data_config=data_config,
    model_config=alt_model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
alt_model = alt_tabular_model.prepare_model(datamodule)
alt_tabular_model.train(alt_model, datamodule)


Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

<pytorch_lightning.trainer.trainer.Trainer at 0x7fd41a0eb160>

In [37]:
result = alt_tabular_model.evaluate(test)

result = result[0]
result["Model"] = "GATE"
results.append(result)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Output()

In [38]:
pd.DataFrame(results)


Unnamed: 0,test_loss,test_accuracy,Model
0,0.255171,0.9028,CategoryEmbedding
1,0.244969,0.912,GATE


# Hyperparameter Tuning

Using the Low-level API, we can also implement hyperparameter tuning

In [7]:
results = []


In [8]:
train, test = train_test_split(data, random_state=42)
train, val = train_test_split(train, random_state=42)


In [16]:
data_config = DataConfig(
    target=[
        "target"
    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    batch_size=1024,
    max_epochs=100,
    early_stopping="valid_loss",  # Monitor valid_loss for early stopping
    early_stopping_mode="min",  # Set the mode as min because for val_loss, lower is better
    early_stopping_patience=5,  # No. of epochs of degradation training will wait before terminating
    checkpoints="valid_loss",  # Save best checkpoint monitoring val_loss
    load_best=True,  # After training, load the best checkpoint
    progress_bar="none",  # Turning off Progress bar
    trainer_kwargs=dict(enable_model_summary=False),  # Turning off model summary
)
optimizer_config = OptimizerConfig()

head_config = LinearHeadConfig(
    layers="", dropout=0.1, initialization="kaiming"  # No additional layer in head, just a mapping layer to output_dim
).__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # Number of nodes in each layer
    activation="LeakyReLU",  # Activation between each layers
    learning_rate=1e-3,
    head="LinearHead",  # Linear Head
    head_config=head_config,  # Linear Head Config
)

### Grid Search   

Note: For demonstration we are using the test split for tuning, but in real problems, please use a separate validation set for tuning purposes. Otherwise, you will be overfitting to the test set and have falsely high performance estimates.

#### Define the Hyperparameter Space
The hyperparameter space is defined as a dictionary. The keys are the hyperparameter names and the values are the list of values to be tried. The hyparameter names follow the below convention:
- `model_config__<hyperparameter_name>` for model hyperparameters
- `model_config.head_config__<hyperparameter_name>` for head hyperparameters
- `trainer_config__<hyperparameter_name>` for trainer hyperparameters
- `optimizer_config__<hyperparameter_name>` for optimizer hyperparameters
- We can't use data module hyperparameters for tuning as the datamodule is already fitted and we can't change it's hyperparameters.

In [17]:
search_space = {
    "model_config__layers": ["1024-512-512", "1024-512-256", "1024-512-128"],
    "model_config.head_config__dropout": [0.1, 0.2, 0.3],
    "trainer_config__batch_size": [1024, 2048, 4096],
    "optimizer_config__optimizer": ["RAdam", "AdamW"],
}
# Any other parameter which is not part of the search_space, will be kept constant during the search

In [18]:
from pytorch_tabular.tabular_model_tuner import TabularModelTuner

In [19]:
tuner = TabularModelTuner(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = tuner.tune(
        train=train,
        validation=test, # Need not give validation is we use CV
        search_space=search_space,
        strategy="grid_search",
        # cv=5, # Uncomment this to do a 5 fold cross validation 
        metric="accuracy",
        mode="max",
        progress_bar=True,
        verbose=False # Make True if you want to log metrics and params each iteration
    )

Output()

Result is a namedtuple with trials_df, best_params, and best_score\

- trials_df: A dataframe with all the hyperparameter combinations and their corresponding scores
- best_params: The best hyperparameter combination
- best_score: The best score

In [20]:
result.trials_df.head()

Unnamed: 0,trial_id,model_config.head_config__dropout,model_config__layers,optimizer_config__optimizer,trainer_config__batch_size,loss,accuracy
0,0,0.1,1024-512-512,RAdam,1024,0.21234,0.912
1,1,0.1,1024-512-512,RAdam,2048,0.199576,0.9244
2,2,0.1,1024-512-512,RAdam,4096,0.208375,0.92
3,3,0.1,1024-512-512,AdamW,1024,0.230789,0.9104
4,4,0.1,1024-512-512,AdamW,2048,0.214804,0.9188


In [21]:
result.best_params, result.best_score

({'model_config.head_config__dropout': 0.2,
  'model_config__layers': '1024-512-512',
  'optimizer_config__optimizer': 'AdamW',
  'trainer_config__batch_size': 2048,
  'loss': 0.20326237380504608},
 0.9259999990463257)

### Random Search   

Note: For demonstration we are using the test split for tuning, but in real problems, please use a separate validation set for tuning purposes. Otherwise, you will be overfitting to the test set and have falsely high performance estimates.

#### Define the Hyperparameter Space
The hyperparameter space is defined as a dictionary. The keys are the hyperparameter names and the values are the list of values for categorical and distributions for continuous. The hyparameter names follow the below convention:
- `model_config__<hyperparameter_name>` for model hyperparameters
- `model_config.head_config__<hyperparameter_name>` for head hyperparameters
- `trainer_config__<hyperparameter_name>` for trainer hyperparameters
- `optimizer_config__<hyperparameter_name>` for optimizer hyperparameters
- We can't use data module hyperparameters for tuning as the datamodule is already fitted and we can't change it's hyperparameters.

In [22]:
from scipy.stats import uniform, randint, loguniform
search_space = {
    "model_config__layers": ["1024-512-512", "1024-512-256", "1024-512-128"],
    "model_config.head_config__dropout": uniform(0, 0.5),
    "trainer_config__batch_size": randint(128, 2048),
    "optimizer_config__optimizer": ["RAdam", "AdamW"],
}
# Any other parameter which is not part of the search_space, will be kept constant during the search

In [23]:
from pytorch_tabular.tabular_model_tuner import TabularModelTuner

In [24]:
tuner = TabularModelTuner(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = tuner.tune(
        train=train,
        validation=test, # Need not give validation is we use CV
        search_space=search_space,
        n_trials=10,
        strategy="random_search",
        # cv=5, # Uncomment this to do a 5 fold cross validation 
        metric="accuracy",
        mode="max",
        progress_bar=True,
        verbose=False # Make True if you want to log metrics and params each iteration
    )

Output()

Result is a namedtuple with trials_df, best_params, and best_score\

- trials_df: A dataframe with all the hyperparameter combinations and their corresponding scores
- best_params: The best hyperparameter combination
- best_score: The best score

In [25]:
result.trials_df.head()

Unnamed: 0,trial_id,model_config.head_config__dropout,model_config__layers,optimizer_config__optimizer,trainer_config__batch_size,loss,accuracy
0,0,0.18727,1024-512-512,RAdam,1258,0.207948,0.92
1,1,0.389846,1024-512-512,RAdam,249,0.206577,0.9184
2,2,0.077997,1024-512-128,RAdam,215,0.219574,0.9176
3,3,0.166854,1024-512-128,AdamW,1460,0.215352,0.9192
4,4,0.484955,1024-512-256,AdamW,513,0.217733,0.9148


In [26]:
result.best_params, result.best_score

({'model_config.head_config__dropout': 0.09091248360355031,
  'model_config__layers': '1024-512-512',
  'optimizer_config__optimizer': 'RAdam',
  'trainer_config__batch_size': 587,
  'loss': 0.20665502548217773},
 0.9223999977111816)