In [None]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import json
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras import regularizers

In [5]:
# Load the preprocessed training split into a DataFrame (path is relative to
# the notebook's working directory).
train_csv_path = "preprocessed_dataset/preprocessed_train.csv"
df = pd.read_csv(train_csv_path)

In [6]:
# Encode SEX as an integer flag: 'M' -> 1, 'F' -> 0. Any other value becomes
# NaN, because Series.map yields NaN for keys missing from the mapping.
# The dtype guard keeps this cell idempotent: without it, re-running the cell
# would map the already-encoded 0/1 values and turn the whole column into NaN.
if not pd.api.types.is_numeric_dtype(df['SEX']):
    df['SEX'] = df['SEX'].map({'M': 1, 'F': 0})

In [7]:
# Rescale AGE with a MinMaxScaler (default feature range, i.e. [0, 1]).
# The fitted scaler is kept in `min_max_scaler` so the same transform can be
# reused later. NOTE(review): re-running this cell refits the scaler on the
# already-scaled column, which discards the original AGE min/max.
min_max_scaler = MinMaxScaler().fit(df[['AGE']])
df['AGE'] = min_max_scaler.transform(df[['AGE']])

In [9]:
def obtain_evidences(path="huggingface_dataset/ddxplus/release_evidences.json"):
    """Build the flat list of evidence feature names from the evidences JSON.

    The file is expected to hold a JSON object mapping each evidence name to a
    dict that has a 'possible-values' list.

    Args:
        path: Location of the evidences JSON file. Defaults to the path this
            notebook has always used, so existing calls are unaffected.

    Returns:
        A list of strings in file order. An evidence with a non-empty
        'possible-values' list contributes one '<name>_@_<value>' entry per
        value; an evidence with an empty list contributes just '<name>'.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If an evidence entry has no 'possible-values' key.
    """
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # default locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        release_evidences = json.load(f)

    set_evidences = []
    for key, value in release_evidences.items():
        possible_values = value['possible-values']
        if possible_values:
            # One feature name per (evidence, value) pair.
            set_evidences.extend(f'{key}_@_{val}' for val in possible_values)
        else:
            # No enumerated values: the evidence name itself is the feature.
            set_evidences.append(key)
    return set_evidences

In [10]:
def obtain_conditions(path="huggingface_dataset/ddxplus/release_conditions.json"):
    """Return the condition names listed in the conditions JSON.

    Args:
        path: Location of the conditions JSON file (a JSON object keyed by
            condition name). Defaults to the path this notebook has always
            used, so existing calls are unaffected.

    Returns:
        A list of the top-level keys of the JSON object, in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # default locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        release_conditions = json.load(f)
    return list(release_conditions.keys())

In [11]:
# Model inputs: the two demographic columns followed by every evidence
# feature name. Model targets: one column per condition name.
features = ['AGE', 'SEX', *obtain_evidences()]
labels = obtain_conditions()

In [None]:
# Pull the selected columns out of the DataFrame as plain NumPy arrays:
# X_train has one column per entry of `features`, Y_train one per `labels`.
X_train = df.loc[:, features].values
Y_train = df.loc[:, labels].values

In [16]:
def elementwise_accuracy(y_true, y_pred, threshold=0.03):
    """Fraction of rows whose every entry is predicted within `threshold`.

    A row counts as correct only when |y_true - y_pred| <= threshold holds
    for all of its entries along axis 1.

    Args:
        y_true: Target values; reduced along axis 1, so at least rank 2.
        y_pred: Predicted values, same shape as `y_true`.
        threshold: Largest absolute error allowed for a single entry.

    Returns:
        A scalar float32 tensor: the mean of the per-row 0/1 indicator.
    """
    within_threshold = tf.abs(y_true - y_pred) <= threshold
    # A single out-of-tolerance entry makes the whole row a miss.
    row_matches = tf.reduce_all(within_threshold, axis=1)
    return tf.reduce_mean(tf.cast(row_matches, tf.float32))

In [None]:


def top3_tolerant_accuracy(y_true, y_pred, k=3, tolerance=0.05001):
    """Fraction of rows whose top-k predictions all match the targets closely.

    For each row, the `k` largest entries of `y_pred` are selected and compared
    with `y_true` at the same positions. The row counts as correct only when
    every one of those `k` absolute differences is <= `tolerance`.

    Args:
        y_true: Target values, shape (batch, classes).
        y_pred: Predicted values, same shape as `y_true`.
        k: How many of the highest-scoring predictions to check per row.
            Defaults to 3, the value this metric has always used.
        tolerance: Largest absolute error allowed for a checked entry.
            NOTE(review): the default 0.05001 looks like 0.05 plus a small
            margin for float rounding -- confirm the intent.

    Returns:
        A scalar float32 tensor: the mean of the per-row 0/1 indicator.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Align dtypes so the subtraction below does not fail when the targets
    # arrive as float64 while the predictions are float32.
    y_true = tf.cast(y_true, y_pred.dtype)

    # top_k returns the k largest values per row with their column indices,
    # so no full sort or hand-built (row, column) index tensor is needed.
    top_pred, top_indices = tf.math.top_k(y_pred, k=k)
    # batch_dims=1 makes row i of `top_indices` index into row i of `y_true`.
    top_true = tf.gather(y_true, top_indices, batch_dims=1)

    within_tolerance = tf.abs(top_true - top_pred) <= tolerance
    correct = tf.reduce_all(within_tolerance, axis=1)
    return tf.reduce_mean(tf.cast(correct, tf.float32))

In [24]:
def combined_loss(y_true, y_pred):
    """Weighted sum of the top-3 miss rate and the KL divergence.

    Computes 0.999 * (1 - top3_tolerant_accuracy) + 0.0001 * KLDivergence.

    NOTE(review): the two weights sum to 0.9991 rather than 1 -- confirm
    whether 0.001 was intended for the KL term.
    NOTE(review): top3_tolerant_accuracy is built from boolean comparisons,
    so that term presumably contributes no gradient -- verify before using
    this function as a training loss.

    Args:
        y_true: Target values, shape (batch, classes).
        y_pred: Predicted values, same shape as `y_true`.

    Returns:
        A scalar tensor holding the combined value.
    """
    kl_loss_fn = tf.keras.losses.KLDivergence()
    kl_term = kl_loss_fn(y_true, y_pred)
    top3_miss_rate = 1.0 - top3_tolerant_accuracy(y_true, y_pred)
    return 0.999 * top3_miss_rate + 0.0001 * kl_term

In [None]:
# Network width is driven by the feature and label lists.
input_dim = len(features)
output_dim = len(labels)

# Fully connected classifier: four hidden stages of decreasing width, each
# Dense(relu) -> BatchNormalization -> Dropout(0.3), then a softmax output
# with one unit per label. Layers are created in the same order as they are
# stacked.
network_layers = [layers.Input(shape=(input_dim,), name='Input')]
for hidden_units in (256, 128, 64, 32):
    network_layers.extend([
        layers.Dense(hidden_units, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.3),
    ])
network_layers.append(
    layers.Dense(output_dim, activation='softmax', name='Output')
)

model = models.Sequential(network_layers)

In [25]:
# Training configuration: Adam with a 1e-4 learning rate, KL divergence as
# the loss, and the two custom tolerance-based accuracies as metrics.
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
kl_divergence_loss = tf.keras.losses.KLDivergence()
training_metrics = [elementwise_accuracy, top3_tolerant_accuracy]

model.compile(
    optimizer=adam_optimizer,
    loss=kl_divergence_loss,
    metrics=training_metrics,
)

In [26]:
# Train on the full training arrays for 10 epochs in mini-batches of 512.
# Left as the cell's final expression so the returned History object is shown.
model.fit(x=X_train, y=Y_train, batch_size=512, epochs=10)

Epoch 1/10
[1m2004/2004[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 21ms/step - elementwise_accuracy: 3.9302e-05 - loss: 1.6035 - top3_tolerant_accuracy: 0.0708
Epoch 2/10
[1m2004/2004[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 21ms/step - elementwise_accuracy: 2.8909e-04 - loss: 0.9331 - top3_tolerant_accuracy: 0.1030
Epoch 3/10
[1m2004/2004[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 21ms/step - elementwise_accuracy: 7.2923e-04 - loss: 0.8566 - top3_tolerant_accuracy: 0.1149
Epoch 4/10
[1m2004/2004[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 22ms/step - elementwise_accuracy: 0.0011 - loss: 0.8298 - top3_tolerant_accuracy: 0.1218
Epoch 5/10
[1m2004/2004[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 21ms/step - elementwise_accuracy: 0.0012 - loss: 0.8217 - top3_tolerant_accuracy: 0.1241
Epoch 6/10
[1m2004/2004[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 21ms/step - elementwise_accuracy: 0.0016 - loss: 0.8187 - top3_

<keras.src.callbacks.history.History at 0x3b65c2430>