<a target="_blank" href="https://colab.research.google.com/github/yandex-research/rtdl-revisiting-models/blob/main/package/example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

---

**See also** [RTDL](https://github.com/yandex-research/rtdl)
-- **other projects on tabular deep learning**.

---

- This notebook provides a usage example of the
  [rtdl_revisiting_models](https://github.com/yandex-research/rtdl-revisiting-models)
  package.
- Hyperparameters are not tuned and may be suboptimal.

In [19]:
%pip install delu==0.0.23
%pip install rtdl_revisiting_models



In [29]:
# ruff: noqa: E402
import math
import warnings
from typing import Dict, Literal

warnings.simplefilter("ignore")
import delu  # Deep Learning Utilities: https://github.com/Yura52/delu
import numpy as np
import scipy.special
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection
import sklearn.preprocessing
import torch
import torch.nn.functional as F
import torch.optim
from torch import Tensor
from tqdm.std import tqdm
from scipy.stats import mode

warnings.resetwarnings()

from rtdl_revisiting_models import MLP, ResNet, FTTransformer

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set random seeds in all libraries.
delu.random.seed(11)

11

## Dataset

In [22]:
pip install openml



In [30]:
import openml
import pandas as pd
dataset_id = 1590   #45068
dataset = openml.datasets.get_dataset(dataset_id)
X, y, _, _ = dataset.get_data(target=dataset.default_target_attribute)
full_data = X.copy()
full_data[dataset.default_target_attribute] = y

csv_file = 'adult.csv'
full_data.to_csv(csv_file, index=False)

if isinstance(X, pd.DataFrame):
    print("Data types of features in X:")
    print(X.dtypes)
else:
    print(f"X is not a DataFrame; its type is {type(X)}")

Data types of features in X:
age                  uint8
workclass         category
fnlwgt             float64
education         category
education-num        uint8
marital-status    category
occupation        category
relationship      category
race              category
sex               category
capital-gain       float64
capital-loss       float64
hours-per-week       uint8
native-country    category
dtype: object


In [36]:
print(X.shape)


(48842, 14)


In [39]:
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import OrdinalEncoder
from scipy.stats import mode
import numpy as np
# Fetch dataset from OpenML
TaskType = Literal["regression", "binclass", "multiclass"]

task_type: TaskType = "binclass"
n_classes = 2
#dataset = sklearn.datasets.fetch_california_housing()
dataset = fetch_openml(data_id=1590, as_frame=False)
X: np.ndarray = dataset["data"]
Y: np.ndarray = dataset["target"]

# Check the dtype of the target
print(f"Original dtype of Y: {Y.dtype}")

# Map class '1' to '0' and class '2' to '1' while preserving object dtype
mapping = {'<=50K': '0', '>50K': '1'}
Y = np.array([mapping[str(y)] if str(y) in mapping else y for y in Y], dtype=object)
# NOTE: uncomment to solve a classification task.
# n_classes = 2
#assert n_classes >= 2
# task_type: TaskType = 'binclass' if n_classes == 2 else 'multiclass'
# X_cont, Y = sklearn.datasets.make_classification(
#     n_samples=20000,
#     n_features=8,
#     n_classes=n_classes,
#     n_informative=3,
#     n_redundant=2,
# )

#numerical_indices = [0,1,2,3,4,5, 6,7,8,9,10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70]  # Replace with actual indices of numerical features
numerical_indices = [0,2,4,10,11,12]
categorical_indices = [1,3,5,6,7,8,9,13]
X_cont: np.ndarray = X[:, numerical_indices]
X_cat: np.ndarray = X[:, categorical_indices] if categorical_indices else None
X_cont: np.ndarray = X_cont.astype(np.float32)
n_cont_features = X_cont.shape[1]
###########################################################################################
# if X_cat is not None:
#    print(f"X_cat dtype: {X_cat.dtype}")
#    print("Example categorical data:", X_cat[:5])
# else:
#    print("No categorical features.")

#if X_cat is not None:
    # Handle NaN values (replace NaN with a placeholder, e.g., "missing")
#    X_cat = np.where(X_cat == np.array(None), "missing", X_cat)
#    X_cat = np.where(X_cat == np.nan, "missing", X_cat)


if X_cat is not None:
    # Loop through each column
    for i in range(X_cat.shape[1]):
        col = X_cat[:, i]

        try:
            if not np.issubdtype(col.dtype, np.number):  # Non-numeric columns
                col = np.where(col == None, "missing", col)  # Replace None
                col = np.where(col == "nan", "missing", col)  # Replace string "nan"
            else:  # Numeric columns
                col = np.where(col == None, np.nan, col)  # Replace None with np.nan
                col = np.where(col == "nan", np.nan, col)  # Replace string "nan"
                valid_values = col[~np.isnan(col)]  # Get non-NaN values
                if len(valid_values) > 0:  # Ensure there are valid values
                    col_mode_result = mode(valid_values, nan_policy="omit")
                    if col_mode_result.count[0] > 0:  # Check if mode computation was successful
                        col_mode = col_mode_result.mode[0]
                        col[np.isnan(col)] = col_mode  # Replace NaN with the mode
                    else:
                        print(f"Column {i}: No valid mode found.")
                else:
                    print(f"Column {i} contains only NaN values and cannot be imputed.")
        except Exception as e:
            print(f"Error processing column {i}: {e}")

    # Use OrdinalEncoder to convert categories to integers
encoder = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)
X_cat = encoder.fit_transform(X_cat)

    # Convert to integers for compatibility with embedding layers
X_cat = X_cat.astype(np.int64)  # <- Change here

cat_cardinalities = [len(np.unique(X_cat[:, col])) for col in range(X_cat.shape[1])]
#print(f"Cardinalities of categorical features: {cat_cardinalities}")


cat_cardinalities = [len(np.unique(X_cat[:, col])) for col in range(X_cat.shape[1])]
#print(f"Cardinalities of categorical features: {cat_cardinalities}")
#########################################################################################
########################################################################################
# >>> Categorical features.
# NOTE: the above datasets do not have categorical features, but,
# for the demonstration purposes, it is possible to generate them.
#cat_cardinalities = [
    # NOTE: uncomment the two lines below to add two categorical features.
    # 4,  # Allowed values: [0, 1, 2, 3].
    # 7,  # Allowed values: [0, 1, 2, 3, 4, 5, 6].
#]
# X_cat = (
#     np.column_stack(
#         [np.random.randint(0, c, (len(X_cont),)) for c in cat_cardinalities]
#     )
#     if cat_cardinalities
#     else None
# )

# >>> Labels.
# Regression labels must be represented by float32.
if task_type == "regression":
    Y = Y.astype(np.float32)
else:
    assert n_classes is not None
    Y = Y.astype(np.int64)
    assert set(Y.tolist()) == set(
        range(n_classes)
    ), "Classification labels must form the range [0, 1, ..., n_classes - 1]"

# >>> Split the dataset.
all_idx = np.arange(len(Y))
trainval_idx, test_idx = sklearn.model_selection.train_test_split(
    all_idx, train_size=0.8
)
train_idx, val_idx = sklearn.model_selection.train_test_split(
    trainval_idx, train_size=0.8125
)
data_numpy = {
    "train": {"x_cont": X_cont[train_idx], "y": Y[train_idx]},
    "val": {"x_cont": X_cont[val_idx], "y": Y[val_idx]},
    "test": {"x_cont": X_cont[test_idx], "y": Y[test_idx]},
}
if X_cat is not None:
    data_numpy["train"]["x_cat"] = X_cat[train_idx]
    data_numpy["val"]["x_cat"] = X_cat[val_idx]
    data_numpy["test"]["x_cat"] = X_cat[test_idx]

Original dtype of Y: object


  X_cat = X_cat.astype(np.int64)  # <- Change here


## Preprocessing

In [40]:
# >>> Feature preprocessing.
# NOTE
# The choice between preprocessing strategies depends on a task and a model.

# (A) Simple preprocessing strategy.
# preprocessing = sklearn.preprocessing.StandardScaler().fit(
#     data_numpy['train']['x_cont']
# )

# (B) Fancy preprocessing strategy.
# The noise is added to improve the output of QuantileTransformer in some cases.
X_cont_train_numpy = data_numpy["train"]["x_cont"]
noise = (
    np.random.default_rng(0)
    .normal(0.0, 1e-5, X_cont_train_numpy.shape)
    .astype(X_cont_train_numpy.dtype)
)
preprocessing = sklearn.preprocessing.QuantileTransformer(
    n_quantiles=max(min(len(train_idx) // 30, 1000), 10),
    output_distribution="normal",
    subsample=10**9,
).fit(X_cont_train_numpy + noise)
del X_cont_train_numpy

for part in data_numpy:
    data_numpy[part]["x_cont"] = preprocessing.transform(data_numpy[part]["x_cont"])

# >>> Label preprocessing.
if task_type == "regression":
    Y_mean = data_numpy["train"]["y"].mean().item()
    Y_std = data_numpy["train"]["y"].std().item()
    for part in data_numpy:
        data_numpy[part]["y"] = (data_numpy[part]["y"] - Y_mean) / Y_std

# >>> Convert data to tensors.
data = {
    part: {k: torch.as_tensor(v, device=device) for k, v in data_numpy[part].items()}
    for part in data_numpy
}

if task_type != "multiclass":
    # Required by F.binary_cross_entropy_with_logits
    for part in data:
        data[part]["y"] = data[part]["y"].float()

## Model

In [41]:
pip install optuna



In [43]:
import optuna
def objective(trial):
    # Precompute valid (d_block, attention_n_heads) pairs
    valid_pairs = []
    for d_block in [64, 128, 192, 256]:  # Possible values for d_block
        valid_n_heads = [h for h in range(1, d_block + 1) if d_block % h == 0]  # Ensure divisibility
        for n_heads in valid_n_heads:
            valid_pairs.append((d_block, n_heads))

    # Sample a valid pair
    d_block, attention_n_heads = trial.suggest_categorical("d_block_attention_n_heads", valid_pairs)

    # Sample other hyperparameters
    n_blocks = trial.suggest_int("n_blocks", 1, 4)
    attention_dropout = trial.suggest_float("attention_dropout", 0.0, 0.5, step=0.1)
    ffn_d_hidden_multiplier = trial.suggest_float("ffn_d_hidden_multiplier", 0.5, 2.0, step=0.1)
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])

    # Update FTTransformer configuration
    default_kwargs = FTTransformer.get_default_kwargs()
    default_kwargs.update({
        "n_blocks": n_blocks,
        "d_block": d_block,
        "attention_n_heads": attention_n_heads,
        "attention_dropout": attention_dropout,
        "ffn_d_hidden_multiplier": ffn_d_hidden_multiplier,
    })

    model = FTTransformer(
        n_cont_features=n_cont_features,
        cat_cardinalities=cat_cardinalities,
        d_out=n_classes,  # Output size matches number of classes
        **default_kwargs,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    def apply_model(batch):
        if isinstance(model, (MLP, ResNet)):
            x_cat_ohe = (
                [
                    F.one_hot(column, cardinality)
                    for column, cardinality in zip(batch["x_cat"].T)  # cat_cardinalities
                ]
                if "x_cat" in batch
                else []
            )
            return model(torch.column_stack([batch["x_cont"]] + x_cat_ohe)).squeeze(-1)

        elif isinstance(model, FTTransformer):
            return model(batch["x_cont"], batch.get("x_cat"))

        else:
            raise RuntimeError(f"Unknown model type: {type(model)}")

    loss_fn = (
        F.cross_entropy  # Handles both binary and multi-class classification
        if task_type in ["binclass", "multiclass"]
        else F.mse_loss
    )

    @torch.no_grad()
    def evaluate(part):
        model.eval()
        eval_batch_size = 128
        y_pred = (
            torch.cat(
                [
                    apply_model(batch)
                    for batch in delu.iter_batches(data[part], eval_batch_size)
                ]
            )
            .cpu()
            .numpy()
        )
        y_true = data[part]["y"].cpu().numpy()

        if task_type == "binclass":
            y_pred = y_pred.argmax(1)  # Convert logits to class predictions
            score = sklearn.metrics.accuracy_score(y_true, y_pred)
        elif task_type == "multiclass":
            y_pred = y_pred.argmax(1)
            score = sklearn.metrics.accuracy_score(y_true, y_pred)
        else:
            assert task_type == "regression"
            score = -(sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5 * Y_std)
        return score

    patience = 30
    n_epochs = 1000
    epoch_size = math.ceil(len(train_idx) / batch_size)
    early_stopping = delu.tools.EarlyStopping(patience, mode="max")

    best = {
        "val": -math.inf,
    }

    for epoch in range(n_epochs):
        for batch in tqdm(
            delu.iter_batches(data["train"], batch_size, shuffle=True),
            desc=f"Epoch {epoch}",
            total=epoch_size,
        ):
            model.train()
            optimizer.zero_grad()
            loss = loss_fn(apply_model(batch), batch["y"].long())  # Ensure targets are long integers
            loss.backward()
            optimizer.step()

        val_score = evaluate("val")
        early_stopping.update(val_score)

        if early_stopping.should_stop():
            break

        if val_score > best["val"]:
            best["val"] = val_score

    return best["val"]


# Run the Optuna study
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)

# Display the best trial
print("Best trial:")
print(study.best_trial)



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch 18: 100%|██████████| 993/993 [00:07<00:00, 126.52it/s]
Epoch 19: 100%|██████████| 993/993 [00:07<00:00, 125.00it/s]
Epoch 20: 100%|██████████| 993/993 [00:08<00:00, 122.62it/s]
Epoch 21: 100%|██████████| 993/993 [00:07<00:00, 125.67it/s]
Epoch 22: 100%|██████████| 993/993 [00:08<00:00, 123.48it/s]
Epoch 23: 100%|██████████| 993/993 [00:08<00:00, 122.68it/s]
Epoch 24: 100%|██████████| 993/993 [00:07<00:00, 126.06it/s]
Epoch 25: 100%|██████████| 993/993 [00:08<00:00, 122.44it/s]
Epoch 26: 100%|██████████| 993/993 [00:07<00:00, 124.38it/s]
Epoch 27: 100%|██████████| 993/993 [00:07<00:00, 125.19it/s]
Epoch 28: 100%|██████████| 993/993 [00:07<00:00, 128.30it/s]
Epoch 29: 100%|██████████| 993/993 [00:08<00:00, 123.55it/s]
Epoch 30: 100%|██████████| 993/993 [00:07<00:00, 124.32it/s]
Epoch 31: 100%|██████████| 993/993 [00:07<00:00, 126.21it/s]
Epoch 32: 100%|██████████| 993/993 [00:07<00:00, 126.92it/s]
Epoch 33: 100%|█████

KeyboardInterrupt: 

In [47]:
best_params = study.best_params

In [44]:
best_params = study.best_params

In [45]:
print(best_params)

{'d_block_attention_n_heads': (64, 4), 'n_blocks': 4, 'attention_dropout': 0.4, 'ffn_d_hidden_multiplier': 0.8, 'learning_rate': 8.845091084611816e-05, 'batch_size': 32}


  and should_run_async(code)


In [48]:
print(best_params)

{'d_block_attention_n_heads': (64, 4), 'n_blocks': 4, 'attention_dropout': 0.4, 'ffn_d_hidden_multiplier': 0.8, 'learning_rate': 8.845091084611816e-05, 'batch_size': 32}


In [52]:
# Extract the best hyperparameters from the study
best_params = study.best_params

# Extract d_block and attention_n_heads
d_block, attention_n_heads = best_params["d_block_attention_n_heads"]

# Update FTTransformer configuration with the best hyperparameters
default_kwargs = FTTransformer.get_default_kwargs()
default_kwargs.update({
    "n_blocks": best_params["n_blocks"],
    "d_block": d_block,
    "attention_n_heads": attention_n_heads,
    "attention_dropout": best_params["attention_dropout"],
    "ffn_d_hidden_multiplier": best_params["ffn_d_hidden_multiplier"],
})

# Instantiate the model with the best hyperparameters
model = FTTransformer(
    n_cont_features=n_cont_features,
    cat_cardinalities=cat_cardinalities,
    d_out=n_classes,  # Output size matches number of classes
    **default_kwargs,
).to(device)

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=best_params["learning_rate"])

# Define batch size
batch_size = best_params["batch_size"]


# Set up for multiple runs with different seeds
def retrain_model_with_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)

    n_epochs = 50  # Adjust as needed
    epoch_size = math.ceil(len(train_idx) / batch_size)
    loss_fn = (
        F.cross_entropy  # Handles both binary and multi-class classification
        if task_type in ["binclass", "multiclass"]
        else F.mse_loss  # Regression
    )

    for epoch in range(n_epochs):
        model.train()
        for batch in tqdm(delu.iter_batches(data["train"], batch_size, shuffle=True), desc=f"Epoch {epoch} (Seed: {seed})", total=epoch_size):
            optimizer.zero_grad()
            y_pred = apply_model(batch, model, task_type)
            loss = loss_fn(y_pred, batch["y"].long() if task_type != "regression" else batch["y"])
            loss.backward()
            optimizer.step()

@torch.no_grad()
def evaluate_on_test_with_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    return evaluate("test", model, task_type)

# Perform training and evaluation for multiple seeds
num_runs = 5
test_scores = []

for seed in range(num_runs):
    print(f"\n--- Run {seed + 1}/{num_runs} with Seed: {seed} ---")
    retrain_model_with_seed(seed)
    test_score = evaluate_on_test_with_seed(seed)
    test_scores.append(test_score)
    print(f"Test score for Seed {seed}: {test_score:.4f}")

# Summarize results
average_test_score = sum(test_scores) / num_runs
print("\n=== Final Results ===")
print(f"Test Scores: {test_scores}")
print(f"Average Test Score: {average_test_score:.4f}")
print(f"Standard Deviation: {np.std(test_scores):.4f}")



--- Run 1/5 with Seed: 0 ---


Epoch 0 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 71.72it/s]
Epoch 1 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 72.38it/s]
Epoch 2 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 72.59it/s]
Epoch 3 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 72.36it/s]
Epoch 4 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 72.89it/s]
Epoch 5 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 71.18it/s]
Epoch 6 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 73.17it/s]
Epoch 7 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 72.50it/s]
Epoch 8 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 72.49it/s]
Epoch 9 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 72.44it/s]
Epoch 10 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 72.42it/s]
Epoch 11 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 71.79it/s]
Epoch 12 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 72.84it/s]
Epoch 13 (Seed: 0): 100%|██████████| 993/993 [00:13<00:00, 72.85it/s]
Epoch 14 (Seed: 0): 100%|█████

Test score for Seed 0: 0.8614

--- Run 2/5 with Seed: 1 ---


Epoch 0 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 70.94it/s]
Epoch 1 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 71.45it/s]
Epoch 2 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 72.45it/s]
Epoch 3 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 72.29it/s]
Epoch 4 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 73.49it/s]
Epoch 5 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 72.54it/s]
Epoch 6 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 73.88it/s]
Epoch 7 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 72.63it/s]
Epoch 8 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 71.77it/s]
Epoch 9 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 71.61it/s]
Epoch 10 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 72.22it/s]
Epoch 11 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 73.74it/s]
Epoch 12 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 72.43it/s]
Epoch 13 (Seed: 1): 100%|██████████| 993/993 [00:13<00:00, 72.03it/s]
Epoch 14 (Seed: 1): 100%|█████

Test score for Seed 1: 0.8590

--- Run 3/5 with Seed: 2 ---


Epoch 0 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 73.97it/s]
Epoch 1 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 73.71it/s]
Epoch 2 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 74.39it/s]
Epoch 3 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 72.05it/s]
Epoch 4 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 73.15it/s]
Epoch 5 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 73.07it/s]
Epoch 6 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 73.68it/s]
Epoch 7 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 73.97it/s]
Epoch 8 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 71.95it/s]
Epoch 9 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 71.15it/s]
Epoch 10 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 73.72it/s]
Epoch 11 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 73.26it/s]
Epoch 12 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 72.47it/s]
Epoch 13 (Seed: 2): 100%|██████████| 993/993 [00:13<00:00, 74.14it/s]
Epoch 14 (Seed: 2): 100%|█████

Test score for Seed 2: 0.8562

--- Run 4/5 with Seed: 3 ---


Epoch 0 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 73.97it/s]
Epoch 1 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 73.23it/s]
Epoch 2 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 73.64it/s]
Epoch 3 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 72.45it/s]
Epoch 4 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 72.15it/s]
Epoch 5 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 73.14it/s]
Epoch 6 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 72.86it/s]
Epoch 7 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 72.98it/s]
Epoch 8 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 73.37it/s]
Epoch 9 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 72.92it/s]
Epoch 10 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 72.95it/s]
Epoch 11 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 73.24it/s]
Epoch 12 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 72.61it/s]
Epoch 13 (Seed: 3): 100%|██████████| 993/993 [00:13<00:00, 72.46it/s]
Epoch 14 (Seed: 3): 100%|█████

Test score for Seed 3: 0.8500

--- Run 5/5 with Seed: 4 ---


Epoch 0 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 73.43it/s]
Epoch 1 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 73.61it/s]
Epoch 2 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 72.74it/s]
Epoch 3 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 72.97it/s]
Epoch 4 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 72.21it/s]
Epoch 5 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 73.11it/s]
Epoch 6 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 73.82it/s]
Epoch 7 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 74.04it/s]
Epoch 8 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 72.93it/s]
Epoch 9 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 72.59it/s]
Epoch 10 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 72.17it/s]
Epoch 11 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 72.99it/s]
Epoch 12 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 73.98it/s]
Epoch 13 (Seed: 4): 100%|██████████| 993/993 [00:13<00:00, 73.65it/s]
Epoch 14 (Seed: 4): 100%|█████

Test score for Seed 4: 0.8427

=== Final Results ===
Test Scores: [0.8613983007472618, 0.8590439144231754, 0.8561777049851571, 0.8500358276179752, 0.8426655747773569]
Average Test Score: 0.8539
Standard Deviation: 0.0068


In [17]:
# Extract the best hyperparameters from the study
best_params = study.best_params

# Extract d_block and attention_n_heads
d_block, attention_n_heads = best_params["d_block_attention_n_heads"]

# Update FTTransformer configuration with the best hyperparameters
default_kwargs = FTTransformer.get_default_kwargs()
default_kwargs.update({
    "n_blocks": best_params["n_blocks"],
    "d_block": d_block,
    "attention_n_heads": attention_n_heads,
    "attention_dropout": best_params["attention_dropout"],
    "ffn_d_hidden_multiplier": best_params["ffn_d_hidden_multiplier"],
})

# Instantiate the model with the best hyperparameters
model = FTTransformer(
    n_cont_features=n_cont_features,
    cat_cardinalities=None,
    d_out=n_classes,  # Output size matches number of classes
    **default_kwargs,
).to(device)

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=best_params["learning_rate"])

# Define batch size
batch_size = best_params["batch_size"]

# Training loop
# Training loop
def retrain_model():
    n_epochs = 100  # Adjust as needed
    epoch_size = math.ceil(len(train_idx) / batch_size)
    loss_fn = (
        F.cross_entropy  # Handles both binary and multi-class classification
        if task_type in ["binclass", "multiclass"]
        else F.mse_loss  # Regression
    )

    for epoch in range(n_epochs):
        model.train()
        for batch in tqdm(delu.iter_batches(data["train"], batch_size, shuffle=True), desc=f"Epoch {epoch}", total=epoch_size):
            optimizer.zero_grad()
            y_pred = apply_model(batch, model, task_type)
            loss = loss_fn(y_pred, batch["y"].long() if task_type != "regression" else batch["y"])
            loss.backward()
            optimizer.step()

# Adjust the evaluate function to properly pass arguments
@torch.no_grad()
def evaluate(part, model, task_type):
    model.eval()
    eval_batch_size = 128
    y_pred = (
        torch.cat(
            [
                apply_model(batch, model, task_type)
                for batch in delu.iter_batches(data[part], eval_batch_size)
            ]
        )
        .cpu()
        .numpy()
    )
    y_true = data[part]["y"].cpu().numpy()

    if task_type == "binclass":
        y_pred = y_pred.argmax(1)  # Convert logits to class predictions
        score = sklearn.metrics.accuracy_score(y_true, y_pred)
    elif task_type == "multiclass":
        y_pred = y_pred.argmax(1)
        score = sklearn.metrics.accuracy_score(y_true, y_pred)
    else:
        assert task_type == "regression"
        score = -(sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5 * Y_std)
    return score

# Adjusted apply_model function
def apply_model(batch, model, task_type):
    if isinstance(model, (MLP, ResNet)):
        x_cat_ohe = (
            [
                F.one_hot(column, cardinality)
                for column, cardinality in zip(batch["x_cat"].T)  # cat_cardinalities
            ]
            if "x_cat" in batch
            else []
        )
        return model(torch.column_stack([batch["x_cont"]] + x_cat_ohe)).squeeze(-1)

    elif isinstance(model, FTTransformer):
        return model(batch["x_cont"], batch.get("x_cat"))

    else:
        raise RuntimeError(f"Unknown model type: {type(model)}")

# Define evaluation for the test set
def evaluate_on_test():
    return evaluate("test", model, task_type)

# Retrain the model
retrain_model()

# Evaluate on the test set
test_score = evaluate_on_test()
print(f"Test score using the best parameters: {test_score:.4f}")



Epoch 0: 100%|██████████| 52/52 [00:00<00:00, 93.24it/s]
Epoch 1: 100%|██████████| 52/52 [00:00<00:00, 99.01it/s]
Epoch 2: 100%|██████████| 52/52 [00:00<00:00, 96.46it/s]
Epoch 3: 100%|██████████| 52/52 [00:00<00:00, 98.52it/s]
Epoch 4: 100%|██████████| 52/52 [00:00<00:00, 96.76it/s]
Epoch 5: 100%|██████████| 52/52 [00:00<00:00, 97.43it/s]
Epoch 6: 100%|██████████| 52/52 [00:00<00:00, 97.43it/s]
Epoch 7: 100%|██████████| 52/52 [00:00<00:00, 97.19it/s]
Epoch 8: 100%|██████████| 52/52 [00:00<00:00, 97.32it/s]
Epoch 9: 100%|██████████| 52/52 [00:00<00:00, 97.11it/s]
Epoch 10: 100%|██████████| 52/52 [00:00<00:00, 97.25it/s]
Epoch 11: 100%|██████████| 52/52 [00:00<00:00, 96.43it/s]
Epoch 12: 100%|██████████| 52/52 [00:00<00:00, 96.96it/s]
Epoch 13: 100%|██████████| 52/52 [00:00<00:00, 96.39it/s]
Epoch 14: 100%|██████████| 52/52 [00:00<00:00, 96.75it/s]
Epoch 15: 100%|██████████| 52/52 [00:00<00:00, 96.15it/s]
Epoch 16: 100%|██████████| 52/52 [00:00<00:00, 96.34it/s]
Epoch 17: 100%|█████████

Test score using the best parameters: 0.9231


In [248]:
# The output size.
d_out = n_classes if task_type == "multiclass" else 1

# # NOTE: uncomment to train MLP
# model = MLP(
#     d_in=n_cont_features + sum(cat_cardinalities),
#     d_out=d_out,
#     n_blocks=2,
#     d_block=384,
#     dropout=0.1,
# ).to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

# # # NOTE: uncomment to train ResNet
# model = ResNet(
#     d_in=n_cont_features + sum(cat_cardinalities),
#     d_out=d_out,
#     n_blocks=2,
#     d_block=192,
#     d_hidden=None,
#     d_hidden_multiplier=2.0,
#     dropout1=0.3,
#     dropout2=0.0,
# ).to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)

model = FTTransformer(
    n_cont_features=n_cont_features,
    cat_cardinalities=None, #cat_cardinalities,
    d_out=d_out,
    **FTTransformer.get_default_kwargs(),
).to(device)
optimizer = model.make_default_optimizer()

In [249]:
print(default_kwargs)

{'n_blocks': 2, 'd_block': 128, 'attention_n_heads': 16, 'attention_dropout': 0.1, 'ffn_d_hidden': None, 'ffn_d_hidden_multiplier': 1.5, 'ffn_dropout': 0.1, 'residual_dropout': 0.0, '_is_default': True}


## Training

In [250]:
def apply_model(batch: Dict[str, Tensor]) -> Tensor:
    if isinstance(model, (MLP, ResNet)):
        x_cat_ohe = (
            [
                F.one_hot(column, cardinality)
                for column, cardinality in zip(batch["x_cat"].T) #cat_cardinalities)
            ]
            if "x_cat" in batch
            else []
        )
        return model(torch.column_stack([batch["x_cont"]] + x_cat_ohe)).squeeze(-1)

    elif isinstance(model, FTTransformer):
        return model(batch["x_cont"], batch.get("x_cat")).squeeze(-1)

    else:
        raise RuntimeError(f"Unknown model type: {type(model)}")


loss_fn = (
    F.binary_cross_entropy_with_logits
    if task_type == "binclass"
    else F.cross_entropy
    if task_type == "multiclass"
    else F.mse_loss
)


@torch.no_grad()
def evaluate(part: str) -> float:
    model.eval()

    eval_batch_size = 8096
    y_pred = (
        torch.cat(
            [
                apply_model(batch)
                for batch in delu.iter_batches(data[part], eval_batch_size)
            ]
        )
        .cpu()
        .numpy()
    )
    y_true = data[part]["y"].cpu().numpy()

    if task_type == "binclass":
        y_pred = np.round(scipy.special.expit(y_pred))
        score = sklearn.metrics.accuracy_score(y_true, y_pred)
    elif task_type == "multiclass":
        y_pred = y_pred.argmax(1)
        score = sklearn.metrics.accuracy_score(y_true, y_pred)
    else:
        assert task_type == "regression"
        score = -(sklearn.metrics.mean_squared_error(y_true, y_pred) ** 0.5 * Y_std)
    return score  # The higher -- the better.


print(f'Test score before training: {evaluate("test"):.4f}')

Test score before training: 0.1617


In [251]:
n_epochs = 1_000_000_000
patience = 30

batch_size = 256
epoch_size = math.ceil(len(train_idx) / batch_size)
timer = delu.tools.Timer()
early_stopping = delu.tools.EarlyStopping(patience, mode="max")

best = {
    "val": -math.inf,
    "test": None,  # Store test score only for the best validation score
    "epoch": -1,
}

test_runs = 5  # Number of test evaluations to average

print(f"Device: {device.type.upper()}")
print("-" * 88 + "\n")
timer.run()

for epoch in range(n_epochs):
    for batch in tqdm(
        delu.iter_batches(data["train"], batch_size, shuffle=True),
        desc=f"Epoch {epoch}",
        total=epoch_size,
    ):
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(apply_model(batch), batch["y"])
        loss.backward()
        optimizer.step()

    val_score = evaluate("val")

    # Update early stopping
    early_stopping.update(val_score)

    # Check if early stopping indicates to stop
    if early_stopping.should_stop():
        print("\nEarly stopping triggered. Evaluating test score for the best validation score...\n")

        # Evaluate test score multiple times for the best validation score
        test_scores = [evaluate("test") for _ in range(test_runs)]
        print(f" the test score for epoch {epoch} is {test_scores}")
        average_test_score = sum(test_scores) / test_runs
        best["test"] = average_test_score

        print(f"Best validation score: {best['val']:.4f}")
        print(f"Averaged Test score over {test_runs} runs: {best['test']:.4f}")
        break

    # Check if current epoch has the best validation score
    if val_score > best["val"]:
        print("🌸 New best epoch! 🌸")
        best = {"val": val_score, "test": None, "epoch": epoch}

    print(f"(val) {val_score:.4f} [time] {timer}")
    print()

print("\n\nResult:")
print(f"Best Epoch: {best['epoch']}")
print(f"Validation Score: {best['val']:.4f}")
if best["test"] is not None:
    print(f"Averaged Test Score: {best['test']:.4f}")
else:
    print("Test score was not evaluated.")

Device: CUDA
----------------------------------------------------------------------------------------



Epoch 0: 100%|██████████| 7/7 [00:00<00:00, 27.49it/s]


🌸 New best epoch! 🌸
(val) 0.9344 [time] 0:00:00.454160



Epoch 1: 100%|██████████| 7/7 [00:00<00:00, 38.48it/s]


(val) 0.9344 [time] 0:00:00.833775



Epoch 2: 100%|██████████| 7/7 [00:00<00:00, 37.47it/s]


(val) 0.9344 [time] 0:00:01.222856



Epoch 3: 100%|██████████| 7/7 [00:00<00:00, 37.19it/s]


(val) 0.9344 [time] 0:00:01.609510



Epoch 4: 100%|██████████| 7/7 [00:00<00:00, 37.26it/s]


(val) 0.9344 [time] 0:00:01.994793



Epoch 5: 100%|██████████| 7/7 [00:00<00:00, 37.30it/s]


(val) 0.9344 [time] 0:00:02.383518



Epoch 6: 100%|██████████| 7/7 [00:00<00:00, 36.99it/s]


(val) 0.9344 [time] 0:00:02.771579



Epoch 7: 100%|██████████| 7/7 [00:00<00:00, 36.79it/s]


(val) 0.9344 [time] 0:00:03.161706



Epoch 8: 100%|██████████| 7/7 [00:00<00:00, 36.80it/s]


🌸 New best epoch! 🌸
(val) 0.9370 [time] 0:00:03.552782



Epoch 9: 100%|██████████| 7/7 [00:00<00:00, 37.32it/s]


🌸 New best epoch! 🌸
(val) 0.9423 [time] 0:00:03.938879



Epoch 10: 100%|██████████| 7/7 [00:00<00:00, 36.73it/s]


(val) 0.9423 [time] 0:00:04.331701



Epoch 11: 100%|██████████| 7/7 [00:00<00:00, 37.12it/s]


(val) 0.9265 [time] 0:00:04.719886



Epoch 12: 100%|██████████| 7/7 [00:00<00:00, 37.95it/s]


(val) 0.9423 [time] 0:00:05.105997



Epoch 13: 100%|██████████| 7/7 [00:00<00:00, 36.97it/s]


🌸 New best epoch! 🌸
(val) 0.9475 [time] 0:00:05.497648



Epoch 14: 100%|██████████| 7/7 [00:00<00:00, 36.69it/s]


(val) 0.9318 [time] 0:00:05.888995



Epoch 15: 100%|██████████| 7/7 [00:00<00:00, 36.83it/s]


(val) 0.9370 [time] 0:00:06.280526



Epoch 16: 100%|██████████| 7/7 [00:00<00:00, 37.22it/s]


(val) 0.9423 [time] 0:00:06.668918



Epoch 17: 100%|██████████| 7/7 [00:00<00:00, 36.88it/s]


(val) 0.9475 [time] 0:00:07.059625



Epoch 18: 100%|██████████| 7/7 [00:00<00:00, 36.70it/s]


(val) 0.9475 [time] 0:00:07.452042



Epoch 19: 100%|██████████| 7/7 [00:00<00:00, 37.38it/s]


(val) 0.9475 [time] 0:00:07.838730



Epoch 20: 100%|██████████| 7/7 [00:00<00:00, 37.06it/s]


(val) 0.9475 [time] 0:00:08.231434



Epoch 21: 100%|██████████| 7/7 [00:00<00:00, 36.67it/s]


🌸 New best epoch! 🌸
(val) 0.9501 [time] 0:00:08.622608



Epoch 22: 100%|██████████| 7/7 [00:00<00:00, 36.61it/s]


(val) 0.9239 [time] 0:00:09.014412



Epoch 23: 100%|██████████| 7/7 [00:00<00:00, 37.20it/s]


(val) 0.9501 [time] 0:00:09.406848



Epoch 24: 100%|██████████| 7/7 [00:00<00:00, 36.46it/s]


(val) 0.9423 [time] 0:00:09.802016



Epoch 25: 100%|██████████| 7/7 [00:00<00:00, 37.30it/s]


(val) 0.9475 [time] 0:00:10.193431



Epoch 26: 100%|██████████| 7/7 [00:00<00:00, 36.64it/s]


(val) 0.9423 [time] 0:00:10.589431



Epoch 27: 100%|██████████| 7/7 [00:00<00:00, 36.60it/s]


🌸 New best epoch! 🌸
(val) 0.9528 [time] 0:00:10.982805



Epoch 28: 100%|██████████| 7/7 [00:00<00:00, 36.70it/s]


(val) 0.9318 [time] 0:00:11.378905



Epoch 29: 100%|██████████| 7/7 [00:00<00:00, 36.89it/s]


(val) 0.9475 [time] 0:00:11.772438



Epoch 30: 100%|██████████| 7/7 [00:00<00:00, 36.42it/s]


(val) 0.9475 [time] 0:00:12.168027



Epoch 31: 100%|██████████| 7/7 [00:00<00:00, 36.95it/s]


(val) 0.9318 [time] 0:00:12.562079



Epoch 32: 100%|██████████| 7/7 [00:00<00:00, 36.56it/s]


(val) 0.9528 [time] 0:00:12.957627



Epoch 33: 100%|██████████| 7/7 [00:00<00:00, 37.00it/s]


(val) 0.9449 [time] 0:00:13.348811



Epoch 34: 100%|██████████| 7/7 [00:00<00:00, 36.00it/s]


(val) 0.9370 [time] 0:00:13.746994



Epoch 35: 100%|██████████| 7/7 [00:00<00:00, 36.53it/s]


(val) 0.9344 [time] 0:00:14.141150



Epoch 36: 100%|██████████| 7/7 [00:00<00:00, 36.46it/s]


(val) 0.9213 [time] 0:00:14.536065



Epoch 37: 100%|██████████| 7/7 [00:00<00:00, 36.15it/s]


(val) 0.9475 [time] 0:00:14.933829



Epoch 38: 100%|██████████| 7/7 [00:00<00:00, 36.34it/s]


(val) 0.9423 [time] 0:00:15.331387



Epoch 39: 100%|██████████| 7/7 [00:00<00:00, 36.77it/s]


(val) 0.9239 [time] 0:00:15.729546



Epoch 40: 100%|██████████| 7/7 [00:00<00:00, 36.57it/s]


(val) 0.9423 [time] 0:00:16.128593



Epoch 41: 100%|██████████| 7/7 [00:00<00:00, 36.30it/s]


(val) 0.9423 [time] 0:00:16.524126



Epoch 42: 100%|██████████| 7/7 [00:00<00:00, 36.38it/s]


(val) 0.9213 [time] 0:00:16.918842



Epoch 43: 100%|██████████| 7/7 [00:00<00:00, 36.68it/s]


(val) 0.9423 [time] 0:00:17.314871



Epoch 44: 100%|██████████| 7/7 [00:00<00:00, 37.01it/s]


(val) 0.9475 [time] 0:00:17.709709



Epoch 45: 100%|██████████| 7/7 [00:00<00:00, 36.45it/s]


(val) 0.9449 [time] 0:00:18.105909



Epoch 46: 100%|██████████| 7/7 [00:00<00:00, 36.03it/s]


(val) 0.9055 [time] 0:00:18.505555



Epoch 47: 100%|██████████| 7/7 [00:00<00:00, 35.62it/s]


(val) 0.9475 [time] 0:00:18.908814



Epoch 48: 100%|██████████| 7/7 [00:00<00:00, 35.74it/s]


(val) 0.9501 [time] 0:00:19.309974



Epoch 49: 100%|██████████| 7/7 [00:00<00:00, 36.48it/s]


(val) 0.9423 [time] 0:00:19.710234



Epoch 50: 100%|██████████| 7/7 [00:00<00:00, 36.17it/s]


(val) 0.9449 [time] 0:00:20.112595



Epoch 51: 100%|██████████| 7/7 [00:00<00:00, 36.48it/s]


(val) 0.9475 [time] 0:00:20.510886



Epoch 52: 100%|██████████| 7/7 [00:00<00:00, 36.59it/s]


(val) 0.9318 [time] 0:00:20.904872



Epoch 53: 100%|██████████| 7/7 [00:00<00:00, 36.38it/s]


(val) 0.9396 [time] 0:00:21.302571



Epoch 54: 100%|██████████| 7/7 [00:00<00:00, 35.91it/s]


(val) 0.9344 [time] 0:00:21.702742



Epoch 55: 100%|██████████| 7/7 [00:00<00:00, 35.88it/s]


(val) 0.9396 [time] 0:00:22.103426



Epoch 56: 100%|██████████| 7/7 [00:00<00:00, 36.05it/s]


(val) 0.9318 [time] 0:00:22.504252



Epoch 57: 100%|██████████| 7/7 [00:00<00:00, 36.30it/s]



Early stopping triggered. Evaluating test score for the best validation score...

 the test score for epoch 57 is [0.9388560157790927, 0.9388560157790927, 0.9388560157790927, 0.9388560157790927, 0.9388560157790927]
Best validation score: 0.9528
Averaged Test score over 5 runs: 0.9389


Result:
Best Epoch: 27
Validation Score: 0.9528
Averaged Test Score: 0.9389


In [252]:
# For demonstration purposes (fast training and bad performance),
# one can set smaller values:
# n_epochs = 20
# patience = 2
n_epochs = 1_000_000_000
patience = 16

batch_size = 256
epoch_size = math.ceil(len(train_idx) / batch_size)
timer = delu.tools.Timer()
early_stopping = delu.tools.EarlyStopping(patience, mode="max")
best = {
    "val": -math.inf,
    "test": -math.inf,
    "epoch": -1,
}

print(f"Device: {device.type.upper()}")
print("-" * 88 + "\n")
timer.run()
for epoch in range(n_epochs):
    for batch in tqdm(
        delu.iter_batches(data["train"], batch_size, shuffle=True),
        desc=f"Epoch {epoch}",
        total=epoch_size,
    ):
        model.train()
        optimizer.zero_grad()
        loss = loss_fn(apply_model(batch), batch["y"])
        loss.backward()
        optimizer.step()

    val_score = evaluate("val")
    test_score = evaluate("test")
    print(f"(val) {val_score:.4f} (test) {test_score:.4f} [time] {timer}")

    early_stopping.update(val_score)
    if early_stopping.should_stop():
        break

    if val_score > best["val"]:
        print("🌸 New best epoch! 🌸")
        best = {"val": val_score, "test": test_score, "epoch": epoch}
    print()

print("\n\nResult:")
print(best)

Device: CUDA
----------------------------------------------------------------------------------------



Epoch 0: 100%|██████████| 7/7 [00:00<00:00, 29.46it/s]


(val) 0.9318 (test) 0.9290 [time] 0:00:00.486449
🌸 New best epoch! 🌸



Epoch 1: 100%|██████████| 7/7 [00:00<00:00, 36.85it/s]


(val) 0.9449 (test) 0.9349 [time] 0:00:00.917690
🌸 New best epoch! 🌸



Epoch 2: 100%|██████████| 7/7 [00:00<00:00, 36.29it/s]


(val) 0.9475 (test) 0.9369 [time] 0:00:01.357165
🌸 New best epoch! 🌸



Epoch 3: 100%|██████████| 7/7 [00:00<00:00, 36.15it/s]


(val) 0.9291 (test) 0.9329 [time] 0:00:01.788525



Epoch 4: 100%|██████████| 7/7 [00:00<00:00, 36.93it/s]


(val) 0.9081 (test) 0.9250 [time] 0:00:02.219209



Epoch 5: 100%|██████████| 7/7 [00:00<00:00, 37.00it/s]


(val) 0.9423 (test) 0.9349 [time] 0:00:02.651217



Epoch 6: 100%|██████████| 7/7 [00:00<00:00, 35.85it/s]


(val) 0.9370 (test) 0.9389 [time] 0:00:03.085070



Epoch 7: 100%|██████████| 7/7 [00:00<00:00, 36.41it/s]


(val) 0.9265 (test) 0.9211 [time] 0:00:03.520750



Epoch 8: 100%|██████████| 7/7 [00:00<00:00, 36.86it/s]


(val) 0.9370 (test) 0.9389 [time] 0:00:03.951609



Epoch 9: 100%|██████████| 7/7 [00:00<00:00, 36.51it/s]


(val) 0.9291 (test) 0.9290 [time] 0:00:04.385092



Epoch 10: 100%|██████████| 7/7 [00:00<00:00, 36.92it/s]


(val) 0.9318 (test) 0.9290 [time] 0:00:04.813484



Epoch 11: 100%|██████████| 7/7 [00:00<00:00, 36.00it/s]


(val) 0.9344 (test) 0.9290 [time] 0:00:05.249231



Epoch 12: 100%|██████████| 7/7 [00:00<00:00, 37.09it/s]


(val) 0.9370 (test) 0.9290 [time] 0:00:05.676646



Epoch 13: 100%|██████████| 7/7 [00:00<00:00, 37.38it/s]


(val) 0.9160 (test) 0.9172 [time] 0:00:06.107090



Epoch 14: 100%|██████████| 7/7 [00:00<00:00, 35.66it/s]


(val) 0.9186 (test) 0.9211 [time] 0:00:06.545377



Epoch 15: 100%|██████████| 7/7 [00:00<00:00, 37.03it/s]


(val) 0.9370 (test) 0.9349 [time] 0:00:06.972477



Epoch 16: 100%|██████████| 7/7 [00:00<00:00, 36.52it/s]


(val) 0.9344 (test) 0.9329 [time] 0:00:07.405049



Epoch 17: 100%|██████████| 7/7 [00:00<00:00, 36.81it/s]


(val) 0.9423 (test) 0.9369 [time] 0:00:07.836598



Epoch 18: 100%|██████████| 7/7 [00:00<00:00, 36.74it/s]


(val) 0.9186 (test) 0.9270 [time] 0:00:08.268169


Result:
{'val': 0.94750656167979, 'test': 0.9368836291913215, 'epoch': 2}
