In [None]:



import numpy as np
import pandas as pd
from scipy import io
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, LeakyReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.regularizers import l2
from imblearn.over_sampling import SMOTE

# Load data
def _read_gz_table(path):
    """Read a gzip-compressed CSV, using its first column as the row index."""
    return pd.read_csv(path, index_col=0, compression="gzip")


# Label tables stay as DataFrames; their columns are iterated as the
# individual targets further down.
y_tr = _read_gz_table('tox21_labels_train.csv.gz')
y_te = _read_gz_table('tox21_labels_test.csv.gz')

# Dense features are reduced to plain NumPy arrays.
x_tr_dense = _read_gz_table('tox21_dense_train.csv.gz').values
x_te_dense = _read_gz_table('tox21_dense_test.csv.gz').values

# Sparse features are stored in Matrix Market format and converted to CSC,
# since they are sliced by column afterwards.
x_tr_sparse, x_te_sparse = (
    io.mmread(path).tocsc()
    for path in ('tox21_sparse_train.mtx.gz', 'tox21_sparse_test.mtx.gz')
)

# Filter and concatenate dense + sparse features.
# A sparse column is kept only if it is non-zero in more than 5% of the
# training rows; the same column mask is then applied to the test matrix so
# both splits end up with identical feature layouts.
# `np.asarray(...)` / `.toarray()` replace the `.A` shorthand, which recent
# SciPy releases no longer provide on sparse matrices.
sparse_col_idx = np.asarray((x_tr_sparse > 0).mean(0)).ravel() > 0.05
x_tr = np.hstack([x_tr_dense, x_tr_sparse[:, sparse_col_idx].toarray()])
x_te = np.hstack([x_te_dense, x_te_sparse[:, sparse_col_idx].toarray()])

# Standardize features with tanh scaling:
# the scaler's statistics are estimated on the training matrix only, and the
# same fitted scaler is applied to both splits before squashing with tanh.
scaler = StandardScaler()
scaler.fit(x_tr)
x_tr, x_te = (np.tanh(scaler.transform(split)) for split in (x_tr, x_te))

# Architecture template: a fresh single-output network is built from this
# function for every assay (the networks do not share weights).
def create_model(input_dim):
    """Build an uncompiled feed-forward binary classifier.

    The network stacks two hidden blocks (Dense -> BatchNormalization ->
    LeakyReLU -> Dropout) of width 1024 and 512, followed by a single
    sigmoid unit.

    Args:
        input_dim: Number of input features per sample.

    Returns:
        A ``Sequential`` model; the caller is responsible for compiling it.
    """
    # Local import keeps this definition self-contained with respect to the
    # file's import header.
    from tensorflow.keras import Input

    model = Sequential([
        # Declare the input shape with an explicit Input layer instead of
        # passing `input_dim` to the first Dense layer, which newer Keras
        # versions warn about.
        Input(shape=(input_dim,)),

        Dense(1024, kernel_regularizer=l2(1e-5)),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        Dropout(0.5),  # Heavier dropout on the widest layer to reduce overfitting

        Dense(512, kernel_regularizer=l2(1e-5)),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        Dropout(0.4),

        Dense(1, activation='sigmoid')  # Binary output for classification
    ])
    return model

# Early stopping and learning rate reduction, both driven by validation loss.
# The same two callback objects are passed to every per-assay fit below.
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6)

# Train and evaluate the model for each task independently
test_auc_scores = []
models = []  # Store models for later use

for target in y_tr.columns:
    print(f"\nTraining on assay: {target}")

    # Select the rows that actually have a finite label for the current task
    valid_rows = np.isfinite(y_tr[target]).values
    x_target = x_tr[valid_rows]
    y_target = y_tr[target][valid_rows].values

    # Split into train/validation sets BEFORE oversampling. Resampling first
    # would place synthetic points interpolated from training rows into the
    # validation set, making val_loss / val_AUC optimistic and misleading the
    # early-stopping and learning-rate callbacks. Stratifying keeps the class
    # ratio the same in both parts.
    x_train, x_val, y_train, y_val = train_test_split(
        x_target, y_target, test_size=0.2, random_state=42, stratify=y_target
    )

    # Apply SMOTE to the training portion only to handle class imbalance;
    # the validation set keeps its original class distribution.
    smote = SMOTE(random_state=42)
    x_train, y_train = smote.fit_resample(x_train, y_train)

    # Create and compile a new model
    model = create_model(input_dim=x_tr.shape[1])
    optimizer = Adam(learning_rate=0.01)
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['AUC'])

    # Train the model
    model.fit(
        x_train, y_train, validation_data=(x_val, y_val),
        epochs=100, batch_size=128,
        callbacks=[early_stopping, reduce_lr], verbose=2
    )

    # Store the model
    models.append(model)

    # Evaluate on the test rows that have a finite label for this task
    valid_test_rows = np.isfinite(y_te[target]).values
    y_test = y_te[target][valid_test_rows].values
    test_predictions = model.predict(x_te[valid_test_rows]).ravel()
    test_auc = roc_auc_score(y_test, test_predictions)
    test_auc_scores.append(test_auc)

    print(f"{target}: Test AUC = {test_auc:.3f}")

# Calculate and print the average Test AUC score
avg_test_auc = np.mean(test_auc_scores)
print(f"\nAverage Test AUC: {avg_test_auc:.3f}")


Training on assay: NR.AhR


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
94/94 - 12s - 131ms/step - AUC: 0.8717 - loss: 0.5891 - val_AUC: 0.9320 - val_loss: 0.4226 - learning_rate: 0.0100
Epoch 2/100
94/94 - 0s - 4ms/step - AUC: 0.9296 - loss: 0.4234 - val_AUC: 0.9282 - val_loss: 0.4390 - learning_rate: 0.0100
Epoch 3/100
94/94 - 0s - 4ms/step - AUC: 0.9412 - loss: 0.3971 - val_AUC: 0.9488 - val_loss: 0.3955 - learning_rate: 0.0100
Epoch 4/100
94/94 - 1s - 7ms/step - AUC: 0.9501 - loss: 0.3758 - val_AUC: 0.9635 - val_loss: 0.3481 - learning_rate: 0.0100
Epoch 5/100
94/94 - 1s - 6ms/step - AUC: 0.9560 - loss: 0.3621 - val_AUC: 0.9604 - val_loss: 0.3574 - learning_rate: 0.0100
Epoch 6/100
94/94 - 0s - 4ms/step - AUC: 0.9602 - loss: 0.3536 - val_AUC: 0.9689 - val_loss: 0.3486 - learning_rate: 0.0100
Epoch 7/100
94/94 - 1s - 6ms/step - AUC: 0.9640 - loss: 0.3508 - val_AUC: 0.9706 - val_loss: 0.3312 - learning_rate: 0.0100
Epoch 8/100
94/94 - 0s - 4ms/step - AUC: 0.9647 - loss: 0.3534 - val_AUC: 0.9703 - val_loss: 0.3313 - learning_rate: 0.0100
Epoch

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
116/116 - 10s - 85ms/step - AUC: 0.8689 - loss: 0.5491 - val_AUC: 0.9250 - val_loss: 0.5261 - learning_rate: 0.0100
Epoch 2/100
116/116 - 3s - 24ms/step - AUC: 0.9361 - loss: 0.4037 - val_AUC: 0.9417 - val_loss: 0.4593 - learning_rate: 0.0100
Epoch 3/100
116/116 - 1s - 5ms/step - AUC: 0.9560 - loss: 0.3618 - val_AUC: 0.9652 - val_loss: 0.4050 - learning_rate: 0.0100
Epoch 4/100
116/116 - 1s - 5ms/step - AUC: 0.9688 - loss: 0.3285 - val_AUC: 0.9798 - val_loss: 0.3108 - learning_rate: 0.0100
Epoch 5/100
116/116 - 0s - 3ms/step - AUC: 0.9745 - loss: 0.3141 - val_AUC: 0.9825 - val_loss: 0.2881 - learning_rate: 0.0100
Epoch 6/100
116/116 - 1s - 5ms/step - AUC: 0.9784 - loss: 0.3035 - val_AUC: 0.9821 - val_loss: 0.2945 - learning_rate: 0.0100
Epoch 7/100
116/116 - 1s - 7ms/step - AUC: 0.9801 - loss: 0.3018 - val_AUC: 0.9861 - val_loss: 0.2888 - learning_rate: 0.0100
Epoch 8/100
116/116 - 1s - 5ms/step - AUC: 0.9818 - loss: 0.2946 - val_AUC: 0.9867 - val_loss: 0.2824 - learning_ra

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
107/107 - 10s - 89ms/step - AUC: 0.8989 - loss: 0.5326 - val_AUC: 0.9435 - val_loss: 0.4775 - learning_rate: 0.0100
Epoch 2/100
107/107 - 3s - 27ms/step - AUC: 0.9699 - loss: 0.3090 - val_AUC: 0.9916 - val_loss: 0.2518 - learning_rate: 0.0100
Epoch 3/100
107/107 - 1s - 6ms/step - AUC: 0.9829 - loss: 0.2579 - val_AUC: 0.9954 - val_loss: 0.2247 - learning_rate: 0.0100
Epoch 4/100
107/107 - 1s - 6ms/step - AUC: 0.9871 - loss: 0.2402 - val_AUC: 0.9915 - val_loss: 0.2205 - learning_rate: 0.0100
Epoch 5/100
107/107 - 1s - 6ms/step - AUC: 0.9882 - loss: 0.2371 - val_AUC: 0.9940 - val_loss: 0.2174 - learning_rate: 0.0100
Epoch 6/100
107/107 - 1s - 6ms/step - AUC: 0.9910 - loss: 0.2188 - val_AUC: 0.9926 - val_loss: 0.2292 - learning_rate: 0.0100
Epoch 7/100
107/107 - 0s - 4ms/step - AUC: 0.9918 - loss: 0.2120 - val_AUC: 0.9968 - val_loss: 0.1800 - learning_rate: 0.0100
Epoch 8/100
107/107 - 1s - 6ms/step - AUC: 0.9924 - loss: 0.2121 - val_AUC: 0.9965 - val_loss: 0.1852 - learning_ra

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
89/89 - 10s - 109ms/step - AUC: 0.8408 - loss: 0.6785 - val_AUC: 0.9246 - val_loss: 0.4428 - learning_rate: 0.0100
Epoch 2/100
89/89 - 0s - 4ms/step - AUC: 0.9308 - loss: 0.4231 - val_AUC: 0.9458 - val_loss: 0.4012 - learning_rate: 0.0100
Epoch 3/100
89/89 - 0s - 4ms/step - AUC: 0.9531 - loss: 0.3685 - val_AUC: 0.9545 - val_loss: 0.3875 - learning_rate: 0.0100
Epoch 4/100
89/89 - 0s - 4ms/step - AUC: 0.9660 - loss: 0.3318 - val_AUC: 0.9790 - val_loss: 0.2907 - learning_rate: 0.0100
Epoch 5/100
89/89 - 1s - 7ms/step - AUC: 0.9729 - loss: 0.3157 - val_AUC: 0.9811 - val_loss: 0.3135 - learning_rate: 0.0100
Epoch 6/100
89/89 - 1s - 7ms/step - AUC: 0.9775 - loss: 0.3029 - val_AUC: 0.9831 - val_loss: 0.2727 - learning_rate: 0.0100
Epoch 7/100
89/89 - 0s - 4ms/step - AUC: 0.9808 - loss: 0.2906 - val_AUC: 0.9864 - val_loss: 0.2716 - learning_rate: 0.0100
Epoch 8/100
89/89 - 1s - 7ms/step - AUC: 0.9835 - loss: 0.2841 - val_AUC: 0.9815 - val_loss: 0.2950 - learning_rate: 0.0100
Epoch

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
88/88 - 9s - 104ms/step - AUC: 0.7386 - loss: 0.7464 - val_AUC: 0.7882 - val_loss: 0.6972 - learning_rate: 0.0100
Epoch 2/100
88/88 - 4s - 49ms/step - AUC: 0.8076 - loss: 0.6126 - val_AUC: 0.8132 - val_loss: 0.6259 - learning_rate: 0.0100
Epoch 3/100
88/88 - 1s - 6ms/step - AUC: 0.8349 - loss: 0.5809 - val_AUC: 0.8641 - val_loss: 0.5477 - learning_rate: 0.0100
Epoch 4/100
88/88 - 1s - 8ms/step - AUC: 0.8629 - loss: 0.5464 - val_AUC: 0.8740 - val_loss: 0.5451 - learning_rate: 0.0100
Epoch 5/100
88/88 - 1s - 6ms/step - AUC: 0.8767 - loss: 0.5337 - val_AUC: 0.8668 - val_loss: 0.5752 - learning_rate: 0.0100
Epoch 6/100
88/88 - 0s - 5ms/step - AUC: 0.8895 - loss: 0.5215 - val_AUC: 0.8897 - val_loss: 0.5380 - learning_rate: 0.0100
Epoch 7/100
88/88 - 0s - 4ms/step - AUC: 0.8952 - loss: 0.5222 - val_AUC: 0.8991 - val_loss: 0.5386 - learning_rate: 0.0100
Epoch 8/100
88/88 - 1s - 7ms/step - AUC: 0.9019 - loss: 0.5193 - val_AUC: 0.9070 - val_loss: 0.5214 - learning_rate: 0.0100
Epoch

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
108/108 - 10s - 90ms/step - AUC: 0.8481 - loss: 0.6373 - val_AUC: 0.9208 - val_loss: 0.4735 - learning_rate: 0.0100
Epoch 2/100
108/108 - 3s - 25ms/step - AUC: 0.9356 - loss: 0.4119 - val_AUC: 0.9548 - val_loss: 0.4064 - learning_rate: 0.0100
Epoch 3/100
108/108 - 1s - 5ms/step - AUC: 0.9585 - loss: 0.3557 - val_AUC: 0.9731 - val_loss: 0.3863 - learning_rate: 0.0100
Epoch 4/100
108/108 - 0s - 4ms/step - AUC: 0.9669 - loss: 0.3353 - val_AUC: 0.9683 - val_loss: 0.3503 - learning_rate: 0.0100
Epoch 5/100
108/108 - 0s - 4ms/step - AUC: 0.9731 - loss: 0.3160 - val_AUC: 0.9831 - val_loss: 0.2822 - learning_rate: 0.0100
Epoch 6/100
108/108 - 0s - 4ms/step - AUC: 0.9787 - loss: 0.2990 - val_AUC: 0.9850 - val_loss: 0.2839 - learning_rate: 0.0100
Epoch 7/100
108/108 - 0s - 4ms/step - AUC: 0.9792 - loss: 0.3045 - val_AUC: 0.9470 - val_loss: 0.4944 - learning_rate: 0.0100
Epoch 8/100
108/108 - 0s - 4ms/step - AUC: 0.9814 - loss: 0.2967 - val_AUC: 0.9838 - val_loss: 0.2979 - learning_ra

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
103/103 - 10s - 93ms/step - AUC: 0.8618 - loss: 0.6226 - val_AUC: 0.9667 - val_loss: 0.3359 - learning_rate: 0.0100
Epoch 2/100
103/103 - 0s - 4ms/step - AUC: 0.9702 - loss: 0.3094 - val_AUC: 0.9890 - val_loss: 0.2666 - learning_rate: 0.0100
Epoch 3/100
103/103 - 1s - 6ms/step - AUC: 0.9851 - loss: 0.2524 - val_AUC: 0.9929 - val_loss: 0.2069 - learning_rate: 0.0100
Epoch 4/100
103/103 - 1s - 6ms/step - AUC: 0.9894 - loss: 0.2295 - val_AUC: 0.9955 - val_loss: 0.1904 - learning_rate: 0.0100
Epoch 5/100
103/103 - 1s - 6ms/step - AUC: 0.9908 - loss: 0.2243 - val_AUC: 0.9950 - val_loss: 0.2094 - learning_rate: 0.0100
Epoch 6/100
103/103 - 1s - 6ms/step - AUC: 0.9920 - loss: 0.2179 - val_AUC: 0.9885 - val_loss: 0.2412 - learning_rate: 0.0100
Epoch 7/100
103/103 - 1s - 6ms/step - AUC: 0.9926 - loss: 0.2191 - val_AUC: 0.9899 - val_loss: 0.2493 - learning_rate: 0.0100
Epoch 8/100
103/103 - 0s - 4ms/step - AUC: 0.9937 - loss: 0.2086 - val_AUC: 0.9954 - val_loss: 0.1995 - learning_rat

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
79/79 - 9s - 111ms/step - AUC: 0.7400 - loss: 0.7576 - val_AUC: 0.8293 - val_loss: 0.6726 - learning_rate: 0.0100
Epoch 2/100
79/79 - 0s - 4ms/step - AUC: 0.8380 - loss: 0.5897 - val_AUC: 0.8492 - val_loss: 0.6134 - learning_rate: 0.0100
Epoch 3/100
79/79 - 0s - 4ms/step - AUC: 0.8730 - loss: 0.5378 - val_AUC: 0.8928 - val_loss: 0.5195 - learning_rate: 0.0100
Epoch 4/100
79/79 - 0s - 4ms/step - AUC: 0.8938 - loss: 0.5012 - val_AUC: 0.9062 - val_loss: 0.5054 - learning_rate: 0.0100
Epoch 5/100
79/79 - 1s - 8ms/step - AUC: 0.9090 - loss: 0.4788 - val_AUC: 0.9033 - val_loss: 0.5392 - learning_rate: 0.0100
Epoch 6/100
79/79 - 0s - 4ms/step - AUC: 0.9187 - loss: 0.4681 - val_AUC: 0.9025 - val_loss: 0.5291 - learning_rate: 0.0100
Epoch 7/100
79/79 - 0s - 4ms/step - AUC: 0.9245 - loss: 0.4652 - val_AUC: 0.9232 - val_loss: 0.4742 - learning_rate: 0.0100
Epoch 8/100
79/79 - 0s - 4ms/step - AUC: 0.9280 - loss: 0.4654 - val_AUC: 0.9235 - val_loss: 0.4804 - learning_rate: 0.0100
Epoch 

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
113/113 - 9s - 80ms/step - AUC: 0.8547 - loss: 0.6322 - val_AUC: 0.9553 - val_loss: 0.3784 - learning_rate: 0.0100
Epoch 2/100
113/113 - 4s - 39ms/step - AUC: 0.9626 - loss: 0.3410 - val_AUC: 0.9705 - val_loss: 0.3140 - learning_rate: 0.0100
Epoch 3/100
113/113 - 1s - 5ms/step - AUC: 0.9783 - loss: 0.2861 - val_AUC: 0.9849 - val_loss: 0.2770 - learning_rate: 0.0100
Epoch 4/100
113/113 - 0s - 4ms/step - AUC: 0.9817 - loss: 0.2821 - val_AUC: 0.9859 - val_loss: 0.2704 - learning_rate: 0.0100
Epoch 5/100
113/113 - 1s - 5ms/step - AUC: 0.9854 - loss: 0.2702 - val_AUC: 0.9886 - val_loss: 0.2460 - learning_rate: 0.0100
Epoch 6/100
113/113 - 0s - 4ms/step - AUC: 0.9873 - loss: 0.2627 - val_AUC: 0.9875 - val_loss: 0.2500 - learning_rate: 0.0100
Epoch 7/100
113/113 - 0s - 4ms/step - AUC: 0.9861 - loss: 0.2727 - val_AUC: 0.9825 - val_loss: 0.3125 - learning_rate: 0.0100
Epoch 8/100
113/113 - 0s - 3ms/step - AUC: 0.9890 - loss: 0.2583 - val_AUC: 0.9896 - val_loss: 0.2479 - learning_rat

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
100/100 - 9s - 93ms/step - AUC: 0.8090 - loss: 0.6656 - val_AUC: 0.9176 - val_loss: 0.4631 - learning_rate: 0.0100
Epoch 2/100
100/100 - 0s - 4ms/step - AUC: 0.9122 - loss: 0.4729 - val_AUC: 0.9380 - val_loss: 0.4932 - learning_rate: 0.0100
Epoch 3/100
100/100 - 0s - 4ms/step - AUC: 0.9418 - loss: 0.4122 - val_AUC: 0.9228 - val_loss: 0.5886 - learning_rate: 0.0100
Epoch 4/100
100/100 - 0s - 4ms/step - AUC: 0.9579 - loss: 0.3745 - val_AUC: 0.9719 - val_loss: 0.3843 - learning_rate: 0.0100
Epoch 5/100
100/100 - 0s - 3ms/step - AUC: 0.9652 - loss: 0.3616 - val_AUC: 0.9716 - val_loss: 0.3447 - learning_rate: 0.0100
Epoch 6/100
100/100 - 1s - 6ms/step - AUC: 0.9708 - loss: 0.3482 - val_AUC: 0.9796 - val_loss: 0.3159 - learning_rate: 0.0100
Epoch 7/100
100/100 - 0s - 3ms/step - AUC: 0.9714 - loss: 0.3501 - val_AUC: 0.9764 - val_loss: 0.3563 - learning_rate: 0.0100
Epoch 8/100
100/100 - 0s - 4ms/step - AUC: 0.9740 - loss: 0.3468 - val_AUC: 0.9681 - val_loss: 0.4063 - learning_rate

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
80/80 - 9s - 116ms/step - AUC: 0.8575 - loss: 0.6222 - val_AUC: 0.9239 - val_loss: 0.5154 - learning_rate: 0.0100
Epoch 2/100
80/80 - 0s - 4ms/step - AUC: 0.9225 - loss: 0.4422 - val_AUC: 0.9482 - val_loss: 0.3815 - learning_rate: 0.0100
Epoch 3/100
80/80 - 1s - 8ms/step - AUC: 0.9373 - loss: 0.4072 - val_AUC: 0.9585 - val_loss: 0.3485 - learning_rate: 0.0100
Epoch 4/100
80/80 - 0s - 4ms/step - AUC: 0.9509 - loss: 0.3747 - val_AUC: 0.9572 - val_loss: 0.3662 - learning_rate: 0.0100
Epoch 5/100
80/80 - 1s - 8ms/step - AUC: 0.9567 - loss: 0.3619 - val_AUC: 0.9580 - val_loss: 0.3605 - learning_rate: 0.0100
Epoch 6/100
80/80 - 0s - 4ms/step - AUC: 0.9592 - loss: 0.3615 - val_AUC: 0.9650 - val_loss: 0.3561 - learning_rate: 0.0100
Epoch 7/100
80/80 - 0s - 4ms/step - AUC: 0.9619 - loss: 0.3597 - val_AUC: 0.9688 - val_loss: 0.3475 - learning_rate: 0.0100
Epoch 8/100
80/80 - 0s - 4ms/step - AUC: 0.9639 - loss: 0.3605 - val_AUC: 0.9663 - val_loss: 0.3524 - learning_rate: 0.0100
Epoch 

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/100
105/105 - 9s - 89ms/step - AUC: 0.7906 - loss: 0.7627 - val_AUC: 0.9101 - val_loss: 0.5160 - learning_rate: 0.0100
Epoch 2/100
105/105 - 0s - 4ms/step - AUC: 0.9157 - loss: 0.4554 - val_AUC: 0.9144 - val_loss: 0.4739 - learning_rate: 0.0100
Epoch 3/100
105/105 - 0s - 4ms/step - AUC: 0.9473 - loss: 0.3887 - val_AUC: 0.9381 - val_loss: 0.4188 - learning_rate: 0.0100
Epoch 4/100
105/105 - 1s - 6ms/step - AUC: 0.9568 - loss: 0.3727 - val_AUC: 0.9763 - val_loss: 0.3085 - learning_rate: 0.0100
Epoch 5/100
105/105 - 0s - 4ms/step - AUC: 0.9664 - loss: 0.3521 - val_AUC: 0.9746 - val_loss: 0.3570 - learning_rate: 0.0100
Epoch 6/100
105/105 - 0s - 4ms/step - AUC: 0.9711 - loss: 0.3433 - val_AUC: 0.9679 - val_loss: 0.3623 - learning_rate: 0.0100
Epoch 7/100
105/105 - 0s - 4ms/step - AUC: 0.9733 - loss: 0.3439 - val_AUC: 0.9750 - val_loss: 0.3469 - learning_rate: 0.0100
Epoch 8/100
105/105 - 0s - 4ms/step - AUC: 0.9733 - loss: 0.3477 - val_AUC: 0.9759 - val_loss: 0.3630 - learning_rate