# evaluation

Evaluate both the baseline (MLP) and TabTransformer models on the test set.

In [11]:
import keras_preprocessing, tensorflow_addons, keras
from keras import layers
import tensorflow as tf

from pathlib import Path
import pandas as pd


# Show which GPU devices TensorFlow can see (an empty list means none).
print(tf.config.list_physical_devices('GPU'))

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [12]:
def split_label(data: pd.DataFrame):
    """Separate the ``stroke`` label column from the remaining columns.

    Args:
        data: Frame that contains a ``stroke`` column. It is not modified.

    Returns:
        A ``(x, y)`` tuple where ``x`` is a new frame without the ``stroke``
        column and ``y`` is the ``stroke`` column of ``data``.
    """
    labels = data["stroke"]
    # ``drop`` returns a new frame, so the input is left untouched.
    features = data.drop(columns=["stroke"])
    return features, labels



In [13]:
# Column names of the header-less CSV files, in file order; the last entry is
# the label column.
CSV_HEADER = [
    "gender",
    "age",
    "hypertension",
    "heart_disease",
    "ever_married",
    "work_type",
    "residence_type",
    "avg_glucose_level",
    "bmi",
    "smoking_status",
    "stroke",
]

FEATURES = CSV_HEADER[:-1]
TARGET = CSV_HEADER[-1]

# Path().resolve() is the absolute current working directory, so the joined
# path is already absolute and needs no further normalisation.
test_data_path = Path().resolve().joinpath("dataset/test_data.csv")
test_data_file = str(test_data_path)
test_data = pd.read_csv(test_data_file, names=CSV_HEADER)

x_test, y_test = split_label(test_data)

# Map textual labels to integers; any other values are left unchanged.
y_test = y_test.replace({"No": 0, 'Yes': 1})


In [14]:
# Hyper-parameters (presumably shared with the training notebook — TODO confirm).
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001
DROPOUT_RATE = 0.1
BATCH_SIZE = 32
NUM_EPOCHS = 100

# Saved-model locations, resolved against the current working directory.
MLP_MODEL_PATH=str(Path().resolve().joinpath('model/mlp_model'))
TABTRANSFORMER_MODEL_PATH=str(Path().resolve().joinpath('model/tabtransformer_model'))

# Name of the label column in the CSV files.
TARGET_FEATURE_NAME='stroke'
# Vocabulary handed to the target StringLookup layer.
TARGET_LABELS = ["1", "0"]

In [15]:
# data processing pipeline

# Lookup layer over TARGET_LABELS with no mask token and no OOV bucket.
# prepare_example does not apply it; targets are passed through unchanged.
target_label_lookup = layers.StringLookup(
    vocabulary=TARGET_LABELS, mask_token=None, num_oov_indices=0
)


def prepare_example(features, target):
    """Return a dataset element as a ``(features, target)`` pair, unchanged.

    The target is forwarded as-is; ``target_label_lookup`` is not applied.
    """
    return features, target


def get_dataset_from_csv(csv_file_path, batch_size=128, shuffle=False):
    """Build a cached, batched ``tf.data`` dataset from a header-less CSV file.

    Args:
        csv_file_path: File path (or pattern) passed to ``make_csv_dataset``.
        batch_size: Number of rows per batch.
        shuffle: Whether ``make_csv_dataset`` should shuffle the rows.

    Returns:
        A cached dataset yielding ``(features, target)`` batches for a single
        pass over the file; the target is the ``TARGET_FEATURE_NAME`` column.
    """
    dataset = tf.data.experimental.make_csv_dataset(
        csv_file_path,
        batch_size=batch_size,
        column_names=CSV_HEADER,
        label_name=TARGET_FEATURE_NAME,
        num_epochs=1,  # one pass only; do not repeat indefinitely
        header=False,
        na_value="?",  # "?" marks a missing value
        shuffle=shuffle,
    )
    # Only permit out-of-order results when the caller asked for shuffling.
    # Without shuffling, batches keep file order so that predictions can be
    # matched back to the rows of the CSV file.
    dataset = dataset.map(
        prepare_example,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=not shuffle,
    )
    return dataset.cache()


  return bool(asarray(a1 == a2).all())


In [16]:
def evalate_model(model: keras.Model, test_data_file):
    """Compile ``model`` with evaluation metrics and evaluate it on a CSV file.

    NOTE: the spelling of the function name is kept because callers use it.

    Args:
        model: Keras model to evaluate. It is compiled in place with a binary
            cross-entropy loss and AUC, accuracy, precision and recall metrics.
        test_data_file: Path of the CSV file, passed to
            ``get_dataset_from_csv``.

    Returns:
        The value returned by ``model.evaluate``: the loss followed by the
        metric values.
    """
    test_data = get_dataset_from_csv(test_data_file)

    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[
            tf.keras.metrics.AUC(
                num_thresholds=200,
                curve="ROC",
                summation_method="interpolation",
                name="auc",
            ),
            tf.keras.metrics.BinaryAccuracy(name="accuracy"),
            tf.keras.metrics.Precision(),
            tf.keras.metrics.Recall(),
        ]
    )

    # No ``batch_size`` here: the dataset is already batched by
    # get_dataset_from_csv, and Keras documents that ``batch_size`` must not
    # be specified for dataset inputs. The other arguments are left at their
    # defaults.
    return model.evaluate(x=test_data)



def predict_model(model: keras.Model, test_data_file):
    """Run ``model.predict`` over the rows of a CSV file.

    Args:
        model: Keras model used for prediction.
        test_data_file: Path of the CSV file, passed to
            ``get_dataset_from_csv``.

    Returns:
        The value returned by ``model.predict`` for the whole dataset.
    """
    test_data = get_dataset_from_csv(test_data_file)

    # No ``batch_size`` here: the dataset is already batched by
    # get_dataset_from_csv, and Keras documents that ``batch_size`` must not
    # be specified for dataset inputs. The other arguments are left at their
    # defaults.
    return model.predict(x=test_data)


In [22]:
# Load the two saved models from disk.
baseline_model = keras.models.load_model(MLP_MODEL_PATH)
tt_model = keras.models.load_model(TABTRANSFORMER_MODEL_PATH)

# Evaluate the baseline first, then the TabTransformer, on the same test file.
evalate_model(baseline_model, test_data_file)
evalate_model(tt_model, test_data_file)



### MLP:

loss: 0.2073 - auc: 0.8241 - accuracy: 0.9397

loss: 0.2145 - auc: 0.8316 - accuracy: 0.9397

### TabTransformer:

loss: 0.2197 - auc: 0.7708 - accuracy: 0.9335

In [18]:
# Predict with the baseline model on the test file.
predict_model(baseline_model, test_data_file=test_data_file)



array([[1.70984521e-07],
       [8.14032108e-02],
       [8.77970681e-02],
       [2.46146023e-01],
       [1.66627788e-03],
       [1.65201331e-04],
       [9.12212818e-06],
       [1.26078874e-01],
       [1.11809105e-01],
       [1.36309057e-01],
       [2.07828606e-07],
       [3.21202388e-05],
       [2.73655951e-06],
       [9.29656289e-06],
       [1.04619801e-01],
       [1.20754521e-06],
       [1.49911575e-04],
       [1.24936670e-01],
       [6.52877532e-07],
       [1.81881770e-01],
       [7.24479854e-02],
       [3.59288186e-01],
       [3.29822183e-01],
       [6.66534603e-02],
       [4.61322287e-07],
       [5.52329839e-05],
       [3.66007589e-04],
       [2.28323713e-01],
       [1.33997932e-01],
       [1.66605115e-01],
       [7.15434965e-07],
       [2.21519826e-06],
       [4.62187512e-04],
       [1.27763301e-02],
       [1.41527578e-01],
       [3.10954750e-01],
       [3.60689569e-06],
       [2.95855403e-01],
       [1.01430778e-04],
       [1.84071049e-01],
