In [3]:
import numpy as np
import pandas as pd
from scipy import io
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, LeakyReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler, ReduceLROnPlateau
from tensorflow.keras.regularizers import l2
from imblearn.over_sampling import SMOTE

# Load data
y_tr = pd.read_csv('tox21_labels_train.csv.gz', index_col=0, compression="gzip")
y_te = pd.read_csv('tox21_labels_test.csv.gz', index_col=0, compression="gzip")
x_tr_dense = pd.read_csv('tox21_dense_train.csv.gz', index_col=0, compression="gzip").values
x_te_dense = pd.read_csv('tox21_dense_test.csv.gz', index_col=0, compression="gzip").values
x_tr_sparse = io.mmread('tox21_sparse_train.mtx.gz').tocsc()
x_te_sparse = io.mmread('tox21_sparse_test.mtx.gz').tocsc()

# Filter and concatenate dense + sparse features
sparse_col_idx = ((x_tr_sparse > 0).mean(0) > 0.05).A.ravel()
x_tr = np.hstack([x_tr_dense, x_tr_sparse[:, sparse_col_idx].A])
x_te = np.hstack([x_te_dense, x_te_sparse[:, sparse_col_idx].A])

# Standardize features with tanh scaling
scaler = StandardScaler()
x_tr = np.tanh(scaler.fit_transform(x_tr))
x_te = np.tanh(scaler.transform(x_te))

# Learning rate decay function
def lr_schedule(epoch, lr):
    return lr * 0.95 if epoch > 10 else lr

# Create a shared multi-task model architecture with LeakyReLU
def create_model(input_dim):
    model = Sequential([
        Dense(256, input_dim=input_dim, kernel_regularizer=l2(1e-5)),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        Dropout(0.4),

        Dense(128, kernel_regularizer=l2(1e-5)),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        Dropout(0.3),

        Dense(1, activation='sigmoid')  # Single output for binary classification
    ])
    return model

# Early stopping and learning rate scheduler
early_stopping = EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6)

# Train and evaluate the model for each task independently
auc_scores = []

for i, target in enumerate(y_tr.columns):
    print(f"\nTraining on assay: {target}")

    # Select valid rows for the current task
    valid_rows = np.isfinite(y_tr[target]).values
    x_target, y_target = x_tr[valid_rows], y_tr[target][valid_rows]

    # Apply SMOTE to handle class imbalance
    smote = SMOTE(random_state=42)
    x_balanced, y_balanced = smote.fit_resample(x_target, y_target)

    # Split into train/validation sets
    x_train, x_val, y_train, y_val = train_test_split(
        x_balanced, y_balanced, test_size=0.2, random_state=42
    )

    # Create and compile a new model and optimizer for the current task
    model = create_model(input_dim=x_tr.shape[1])
    optimizer = Adam(learning_rate=1e-3)  # Create a new optimizer instance
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['AUC'])

    # Train the model
    model.fit(
        x_train, y_train, validation_data=(x_val, y_val),
        epochs=100, batch_size=32, callbacks=[early_stopping, reduce_lr], verbose=2
    )

    # Evaluate on the test set
    valid_test_rows = np.isfinite(y_te[target]).values
    y_test = y_te[target][valid_test_rows].values
    p_test = model.predict(x_te[valid_test_rows]).ravel()

    # Calculate the AUC score for the current task
    auc = roc_auc_score(y_test, p_test)
    auc_scores.append(auc)
    print(f"{target}: Test AUC = {auc:.3f}")

# Calculate and print the average AUC score
avg_auc = np.mean(auc_scores)
print(f"\nAverage Test AUC: {avg_auc:.3f}")



Training on assay: NR.AhR


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
373/373 - 8s - 21ms/step - AUC: 0.9036 - loss: 0.4028 - val_AUC: 0.9506 - val_loss: 0.2895 - learning_rate: 0.0010
Epoch 2/100
373/373 - 5s - 13ms/step - AUC: 0.9375 - loss: 0.3252 - val_AUC: 0.9591 - val_loss: 0.2625 - learning_rate: 0.0010
Epoch 3/100
373/373 - 1s - 3ms/step - AUC: 0.9497 - loss: 0.2928 - val_AUC: 0.9670 - val_loss: 0.2374 - learning_rate: 0.0010
Epoch 4/100
373/373 - 1s - 3ms/step - AUC: 0.9559 - loss: 0.2739 - val_AUC: 0.9728 - val_loss: 0.2162 - learning_rate: 0.0010
Epoch 5/100
373/373 - 1s - 2ms/step - AUC: 0.9656 - loss: 0.2444 - val_AUC: 0.9696 - val_loss: 0.2261 - learning_rate: 0.0010
Epoch 6/100
373/373 - 1s - 3ms/step - AUC: 0.9680 - loss: 0.2356 - val_AUC: 0.9780 - val_loss: 0.1925 - learning_rate: 0.0010
Epoch 7/100
373/373 - 1s - 2ms/step - AUC: 0.9719 - loss: 0.2215 - val_AUC: 0.9760 - val_loss: 0.2034 - learning_rate: 0.0010
Epoch 8/100
373/373 - 1s - 3ms/step - AUC: 0.9742 - loss: 0.2124 - val_AUC: 0.9788 - val_loss: 0.1917 - learning_rat

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
464/464 - 9s - 20ms/step - AUC: 0.9099 - loss: 0.3787 - val_AUC: 0.9637 - val_loss: 0.2642 - learning_rate: 0.0010
Epoch 2/100
464/464 - 3s - 7ms/step - AUC: 0.9576 - loss: 0.2711 - val_AUC: 0.9776 - val_loss: 0.2087 - learning_rate: 0.0010
Epoch 3/100
464/464 - 1s - 2ms/step - AUC: 0.9737 - loss: 0.2175 - val_AUC: 0.9849 - val_loss: 0.1709 - learning_rate: 0.0010
Epoch 4/100
464/464 - 1s - 3ms/step - AUC: 0.9815 - loss: 0.1834 - val_AUC: 0.9865 - val_loss: 0.1582 - learning_rate: 0.0010
Epoch 5/100
464/464 - 1s - 2ms/step - AUC: 0.9838 - loss: 0.1723 - val_AUC: 0.9928 - val_loss: 0.1166 - learning_rate: 0.0010
Epoch 6/100
464/464 - 1s - 3ms/step - AUC: 0.9870 - loss: 0.1549 - val_AUC: 0.9908 - val_loss: 0.1254 - learning_rate: 0.0010
Epoch 7/100
464/464 - 1s - 2ms/step - AUC: 0.9897 - loss: 0.1386 - val_AUC: 0.9934 - val_loss: 0.1148 - learning_rate: 0.0010
Epoch 8/100
464/464 - 1s - 2ms/step - AUC: 0.9909 - loss: 0.1318 - val_AUC: 0.9955 - val_loss: 0.1000 - learning_rate

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
428/428 - 8s - 19ms/step - AUC: 0.9493 - loss: 0.2928 - val_AUC: 0.9900 - val_loss: 0.1482 - learning_rate: 0.0010
Epoch 2/100
428/428 - 5s - 11ms/step - AUC: 0.9799 - loss: 0.1877 - val_AUC: 0.9924 - val_loss: 0.1180 - learning_rate: 0.0010
Epoch 3/100
428/428 - 1s - 2ms/step - AUC: 0.9878 - loss: 0.1482 - val_AUC: 0.9962 - val_loss: 0.0799 - learning_rate: 0.0010
Epoch 4/100
428/428 - 1s - 2ms/step - AUC: 0.9920 - loss: 0.1191 - val_AUC: 0.9958 - val_loss: 0.0845 - learning_rate: 0.0010
Epoch 5/100
428/428 - 1s - 3ms/step - AUC: 0.9930 - loss: 0.1103 - val_AUC: 0.9967 - val_loss: 0.0730 - learning_rate: 0.0010
Epoch 6/100
428/428 - 1s - 2ms/step - AUC: 0.9936 - loss: 0.1038 - val_AUC: 0.9980 - val_loss: 0.0629 - learning_rate: 0.0010
Epoch 7/100
428/428 - 1s - 2ms/step - AUC: 0.9940 - loss: 0.1054 - val_AUC: 0.9978 - val_loss: 0.0629 - learning_rate: 0.0010
Epoch 8/100
428/428 - 1s - 2ms/step - AUC: 0.9954 - loss: 0.0898 - val_AUC: 0.9968 - val_loss: 0.0674 - learning_rat

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
354/354 - 8s - 24ms/step - AUC: 0.9035 - loss: 0.4021 - val_AUC: 0.9613 - val_loss: 0.2568 - learning_rate: 0.0010
Epoch 2/100
354/354 - 1s - 2ms/step - AUC: 0.9528 - loss: 0.2827 - val_AUC: 0.9764 - val_loss: 0.1967 - learning_rate: 0.0010
Epoch 3/100
354/354 - 1s - 2ms/step - AUC: 0.9694 - loss: 0.2274 - val_AUC: 0.9817 - val_loss: 0.1727 - learning_rate: 0.0010
Epoch 4/100
354/354 - 1s - 2ms/step - AUC: 0.9764 - loss: 0.2000 - val_AUC: 0.9884 - val_loss: 0.1394 - learning_rate: 0.0010
Epoch 5/100
354/354 - 1s - 2ms/step - AUC: 0.9800 - loss: 0.1841 - val_AUC: 0.9855 - val_loss: 0.1550 - learning_rate: 0.0010
Epoch 6/100
354/354 - 1s - 2ms/step - AUC: 0.9838 - loss: 0.1658 - val_AUC: 0.9909 - val_loss: 0.1209 - learning_rate: 0.0010
Epoch 7/100
354/354 - 1s - 3ms/step - AUC: 0.9862 - loss: 0.1545 - val_AUC: 0.9931 - val_loss: 0.1132 - learning_rate: 0.0010
Epoch 8/100
354/354 - 1s - 2ms/step - AUC: 0.9891 - loss: 0.1368 - val_AUC: 0.9901 - val_loss: 0.1286 - learning_rate

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
350/350 - 9s - 25ms/step - AUC: 0.7813 - loss: 0.5784 - val_AUC: 0.8461 - val_loss: 0.4870 - learning_rate: 0.0010
Epoch 2/100
350/350 - 1s - 2ms/step - AUC: 0.8340 - loss: 0.5073 - val_AUC: 0.8700 - val_loss: 0.4696 - learning_rate: 0.0010
Epoch 3/100
350/350 - 1s - 2ms/step - AUC: 0.8622 - loss: 0.4686 - val_AUC: 0.8769 - val_loss: 0.4572 - learning_rate: 0.0010
Epoch 4/100
350/350 - 1s - 2ms/step - AUC: 0.8789 - loss: 0.4449 - val_AUC: 0.8846 - val_loss: 0.4370 - learning_rate: 0.0010
Epoch 5/100
350/350 - 1s - 4ms/step - AUC: 0.8952 - loss: 0.4167 - val_AUC: 0.9050 - val_loss: 0.4039 - learning_rate: 0.0010
Epoch 6/100
350/350 - 1s - 4ms/step - AUC: 0.9046 - loss: 0.3997 - val_AUC: 0.9141 - val_loss: 0.3903 - learning_rate: 0.0010
Epoch 7/100
350/350 - 1s - 2ms/step - AUC: 0.9133 - loss: 0.3849 - val_AUC: 0.9295 - val_loss: 0.3534 - learning_rate: 0.0010
Epoch 8/100
350/350 - 1s - 2ms/step - AUC: 0.9250 - loss: 0.3620 - val_AUC: 0.9390 - val_loss: 0.3360 - learning_rate

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
430/430 - 8s - 19ms/step - AUC: 0.9117 - loss: 0.3852 - val_AUC: 0.9648 - val_loss: 0.2580 - learning_rate: 0.0010
Epoch 2/100
430/430 - 6s - 13ms/step - AUC: 0.9567 - loss: 0.2736 - val_AUC: 0.9762 - val_loss: 0.2073 - learning_rate: 0.0010
Epoch 3/100
430/430 - 1s - 3ms/step - AUC: 0.9695 - loss: 0.2313 - val_AUC: 0.9826 - val_loss: 0.1878 - learning_rate: 0.0010
Epoch 4/100
430/430 - 1s - 3ms/step - AUC: 0.9764 - loss: 0.2035 - val_AUC: 0.9875 - val_loss: 0.1543 - learning_rate: 0.0010
Epoch 5/100
430/430 - 1s - 3ms/step - AUC: 0.9801 - loss: 0.1869 - val_AUC: 0.9893 - val_loss: 0.1388 - learning_rate: 0.0010
Epoch 6/100
430/430 - 1s - 2ms/step - AUC: 0.9837 - loss: 0.1722 - val_AUC: 0.9928 - val_loss: 0.1234 - learning_rate: 0.0010
Epoch 7/100
430/430 - 1s - 2ms/step - AUC: 0.9865 - loss: 0.1557 - val_AUC: 0.9929 - val_loss: 0.1205 - learning_rate: 0.0010
Epoch 8/100
430/430 - 1s - 3ms/step - AUC: 0.9870 - loss: 0.1547 - val_AUC: 0.9930 - val_loss: 0.1230 - learning_rat

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
411/411 - 9s - 22ms/step - AUC: 0.9405 - loss: 0.3170 - val_AUC: 0.9697 - val_loss: 0.2565 - learning_rate: 0.0010
Epoch 2/100
411/411 - 4s - 10ms/step - AUC: 0.9814 - loss: 0.1762 - val_AUC: 0.9915 - val_loss: 0.1142 - learning_rate: 0.0010
Epoch 3/100
411/411 - 1s - 2ms/step - AUC: 0.9900 - loss: 0.1281 - val_AUC: 0.9958 - val_loss: 0.0815 - learning_rate: 0.0010
Epoch 4/100
411/411 - 1s - 3ms/step - AUC: 0.9928 - loss: 0.1108 - val_AUC: 0.9963 - val_loss: 0.0773 - learning_rate: 0.0010
Epoch 5/100
411/411 - 1s - 3ms/step - AUC: 0.9937 - loss: 0.1032 - val_AUC: 0.9981 - val_loss: 0.0657 - learning_rate: 0.0010
Epoch 6/100
411/411 - 1s - 3ms/step - AUC: 0.9953 - loss: 0.0885 - val_AUC: 0.9983 - val_loss: 0.0481 - learning_rate: 0.0010
Epoch 7/100
411/411 - 1s - 2ms/step - AUC: 0.9951 - loss: 0.0925 - val_AUC: 0.9975 - val_loss: 0.0609 - learning_rate: 0.0010
Epoch 8/100
411/411 - 1s - 2ms/step - AUC: 0.9960 - loss: 0.0821 - val_AUC: 0.9970 - val_loss: 0.0672 - learning_rat

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
313/313 - 8s - 27ms/step - AUC: 0.8015 - loss: 0.5582 - val_AUC: 0.8659 - val_loss: 0.4678 - learning_rate: 0.0010
Epoch 2/100
313/313 - 4s - 13ms/step - AUC: 0.8684 - loss: 0.4605 - val_AUC: 0.8977 - val_loss: 0.4102 - learning_rate: 0.0010
Epoch 3/100
313/313 - 1s - 2ms/step - AUC: 0.8946 - loss: 0.4166 - val_AUC: 0.9147 - val_loss: 0.3783 - learning_rate: 0.0010
Epoch 4/100
313/313 - 1s - 4ms/step - AUC: 0.9094 - loss: 0.3906 - val_AUC: 0.9084 - val_loss: 0.4010 - learning_rate: 0.0010
Epoch 5/100
313/313 - 1s - 4ms/step - AUC: 0.9180 - loss: 0.3735 - val_AUC: 0.9230 - val_loss: 0.3668 - learning_rate: 0.0010
Epoch 6/100
313/313 - 1s - 5ms/step - AUC: 0.9266 - loss: 0.3543 - val_AUC: 0.9397 - val_loss: 0.3245 - learning_rate: 0.0010
Epoch 7/100
313/313 - 1s - 3ms/step - AUC: 0.9376 - loss: 0.3313 - val_AUC: 0.9254 - val_loss: 0.3598 - learning_rate: 0.0010
Epoch 8/100
313/313 - 1s - 3ms/step - AUC: 0.9401 - loss: 0.3234 - val_AUC: 0.9493 - val_loss: 0.3043 - learning_rat

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
450/450 - 6s - 14ms/step - AUC: 0.9220 - loss: 0.3600 - val_AUC: 0.9672 - val_loss: 0.2302 - learning_rate: 0.0010
Epoch 2/100
450/450 - 1s - 2ms/step - AUC: 0.9688 - loss: 0.2267 - val_AUC: 0.9810 - val_loss: 0.1805 - learning_rate: 0.0010
Epoch 3/100
450/450 - 1s - 3ms/step - AUC: 0.9801 - loss: 0.1794 - val_AUC: 0.9915 - val_loss: 0.1275 - learning_rate: 0.0010
Epoch 4/100
450/450 - 2s - 3ms/step - AUC: 0.9864 - loss: 0.1485 - val_AUC: 0.9921 - val_loss: 0.1055 - learning_rate: 0.0010
Epoch 5/100
450/450 - 1s - 3ms/step - AUC: 0.9883 - loss: 0.1396 - val_AUC: 0.9949 - val_loss: 0.0990 - learning_rate: 0.0010
Epoch 6/100
450/450 - 1s - 2ms/step - AUC: 0.9905 - loss: 0.1240 - val_AUC: 0.9943 - val_loss: 0.0952 - learning_rate: 0.0010
Epoch 7/100
450/450 - 1s - 2ms/step - AUC: 0.9923 - loss: 0.1130 - val_AUC: 0.9954 - val_loss: 0.0800 - learning_rate: 0.0010
Epoch 8/100
450/450 - 1s - 2ms/step - AUC: 0.9931 - loss: 0.1065 - val_AUC: 0.9918 - val_loss: 0.1204 - learning_rate

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
399/399 - 9s - 22ms/step - AUC: 0.8603 - loss: 0.4773 - val_AUC: 0.9435 - val_loss: 0.3123 - learning_rate: 0.0010
Epoch 2/100
399/399 - 4s - 10ms/step - AUC: 0.9307 - loss: 0.3419 - val_AUC: 0.9706 - val_loss: 0.2402 - learning_rate: 0.0010
Epoch 3/100
399/399 - 1s - 2ms/step - AUC: 0.9538 - loss: 0.2812 - val_AUC: 0.9780 - val_loss: 0.1960 - learning_rate: 0.0010
Epoch 4/100
399/399 - 1s - 3ms/step - AUC: 0.9640 - loss: 0.2464 - val_AUC: 0.9762 - val_loss: 0.2123 - learning_rate: 0.0010
Epoch 5/100
399/399 - 1s - 3ms/step - AUC: 0.9741 - loss: 0.2126 - val_AUC: 0.9859 - val_loss: 0.1622 - learning_rate: 0.0010
Epoch 6/100
399/399 - 1s - 3ms/step - AUC: 0.9769 - loss: 0.2006 - val_AUC: 0.9899 - val_loss: 0.1309 - learning_rate: 0.0010
Epoch 7/100
399/399 - 1s - 3ms/step - AUC: 0.9821 - loss: 0.1777 - val_AUC: 0.9890 - val_loss: 0.1461 - learning_rate: 0.0010
Epoch 8/100
399/399 - 1s - 3ms/step - AUC: 0.9821 - loss: 0.1770 - val_AUC: 0.9896 - val_loss: 0.1316 - learning_rat

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
319/319 - 9s - 27ms/step - AUC: 0.8912 - loss: 0.4270 - val_AUC: 0.9404 - val_loss: 0.3193 - learning_rate: 0.0010
Epoch 2/100
319/319 - 4s - 13ms/step - AUC: 0.9311 - loss: 0.3423 - val_AUC: 0.9543 - val_loss: 0.2897 - learning_rate: 0.0010
Epoch 3/100
319/319 - 1s - 2ms/step - AUC: 0.9458 - loss: 0.3047 - val_AUC: 0.9651 - val_loss: 0.2481 - learning_rate: 0.0010
Epoch 4/100
319/319 - 1s - 4ms/step - AUC: 0.9516 - loss: 0.2880 - val_AUC: 0.9689 - val_loss: 0.2290 - learning_rate: 0.0010
Epoch 5/100
319/319 - 1s - 2ms/step - AUC: 0.9601 - loss: 0.2623 - val_AUC: 0.9717 - val_loss: 0.2170 - learning_rate: 0.0010
Epoch 6/100
319/319 - 1s - 4ms/step - AUC: 0.9654 - loss: 0.2451 - val_AUC: 0.9627 - val_loss: 0.2546 - learning_rate: 0.0010
Epoch 7/100
319/319 - 1s - 4ms/step - AUC: 0.9690 - loss: 0.2344 - val_AUC: 0.9796 - val_loss: 0.1884 - learning_rate: 0.0010
Epoch 8/100
319/319 - 1s - 3ms/step - AUC: 0.9730 - loss: 0.2182 - val_AUC: 0.9782 - val_loss: 0.1924 - learning_rat

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
417/417 - 7s - 18ms/step - AUC: 0.8783 - loss: 0.4435 - val_AUC: 0.9542 - val_loss: 0.2797 - learning_rate: 0.0010
Epoch 2/100
417/417 - 1s - 2ms/step - AUC: 0.9374 - loss: 0.3214 - val_AUC: 0.9583 - val_loss: 0.2620 - learning_rate: 0.0010
Epoch 3/100
417/417 - 1s - 3ms/step - AUC: 0.9567 - loss: 0.2694 - val_AUC: 0.9767 - val_loss: 0.2074 - learning_rate: 0.0010
Epoch 4/100
417/417 - 1s - 2ms/step - AUC: 0.9680 - loss: 0.2329 - val_AUC: 0.9827 - val_loss: 0.1701 - learning_rate: 0.0010
Epoch 5/100
417/417 - 1s - 3ms/step - AUC: 0.9735 - loss: 0.2126 - val_AUC: 0.9860 - val_loss: 0.1563 - learning_rate: 0.0010
Epoch 6/100
417/417 - 1s - 3ms/step - AUC: 0.9794 - loss: 0.1881 - val_AUC: 0.9884 - val_loss: 0.1358 - learning_rate: 0.0010
Epoch 7/100
417/417 - 1s - 3ms/step - AUC: 0.9820 - loss: 0.1781 - val_AUC: 0.9892 - val_loss: 0.1421 - learning_rate: 0.0010
Epoch 8/100
417/417 - 1s - 2ms/step - AUC: 0.9859 - loss: 0.1564 - val_AUC: 0.9921 - val_loss: 0.1125 - learning_rate