# In [None]:
# Importing necessary libraries
import numpy as np
import tensorflow as tf
from sklearn.model_selection import KFold
from tensorflow.keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt

# Custom imports
from config import get_config
from dataset import DataSet, prepare_interaction_pairs, create_tf_dataset
from lstm_model import build_lstm_model
from lstm_attention_model import build_lstm_attention_model
from transformer_model import build_transformer_model

# Cell 1: Define CI score function
def ci_score(y_true, y_pred):
    """Return the concordance index (CI) of predictions against true values.

    Every pair of samples whose true values differ is "comparable". A
    comparable pair scores 1 when the sample with the larger true value also
    has the larger prediction, 0.5 when the two predictions are tied, and 0
    otherwise. The CI is the mean score over all comparable pairs.

    Args:
        y_true: 1-D array-like of ground-truth values.
        y_pred: 1-D array-like of predictions, aligned with ``y_true``.

    Returns:
        The CI as a float in [0, 1]. Returns 0.0 when no pair is comparable
        (fewer than two samples, or all true values equal) instead of
        dividing by zero.

    Raises:
        ValueError: If the two inputs do not contain the same number of
            elements.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same length, "
            f"got {y_true.size} and {y_pred.size}"
        )

    # Sort by true value so that, for sample i, every sample with a strictly
    # smaller true value sits somewhere in the slice [:i].
    order = np.argsort(y_true)
    y_true = y_true[order]
    y_pred = y_pred[order]

    comparable_pairs = 0
    concordance = 0.0
    for i in range(1, len(y_true)):
        # Pairs with equal true values are not comparable and are skipped.
        lower = y_true[:i] < y_true[i]
        n_lower = int(np.count_nonzero(lower))
        if n_lower == 0:
            continue
        lower_preds = y_pred[:i][lower]
        comparable_pairs += n_lower
        concordance += int(np.count_nonzero(lower_preds < y_pred[i]))
        # Tied predictions earn half credit, per the CI definition.
        concordance += 0.5 * int(np.count_nonzero(lower_preds == y_pred[i]))

    if comparable_pairs == 0:
        return 0.0
    return float(concordance / comparable_pairs)

# Cell 2: Load configuration and dataset
config = get_config()
dataset_type = 'davis'  # Change to 'kiba' for KIBA dataset

# Pick the dataset-specific config keys first, then build the DataSet once.
# Any value other than 'davis' selects the KIBA settings.
if dataset_type == 'davis':
    path_key, dataset_name, log_key = 'davis_path', 'davis', 'davis_convert_to_log'
else:
    path_key, dataset_name, log_key = 'kiba_path', 'kiba', 'kiba_convert_to_log'

dataset = DataSet(
    config[path_key],
    config['problem_type'],
    config['max_seq_len'],
    config['max_smi_len'],
    dataset_name,
    config[log_key],
)

# get_data() yields seven values; the last two are not used in this script.
XD, XT, Y, label_row_inds, label_col_inds, _, _ = dataset.get_data()

# Cell 3: Define models to be tested
# Display name -> builder function. Each builder is later called with `config`
# and must return a model exposing fit() and predict(). Insertion order is the
# order in which the architectures are trained and reported.
models = dict([
    ('LSTM', build_lstm_model),
    ('LSTM with Attention', build_lstm_attention_model),
    ('Transformer', build_transformer_model),
])

# Cell 4: Experiment setup
# 5-fold cross-validation, shuffled with a fixed seed so the folds are
# reproducible and identical for every model.
kf = KFold(5, shuffle=True, random_state=42)
# Filled by the experiment loop: model name -> {'CI': mean CI, 'RMSE': mean RMSE}.
results = dict()

# Cell 5: Run experiments
# For every architecture, cross-validate over the interaction pairs and store
# the mean validation CI and RMSE in `results`.
n_folds = kf.get_n_splits()  # keeps the progress message in sync with the KFold setup

for model_name, model_builder in models.items():
    print(f"\nTraining {model_name}...")
    model_ci_scores = []
    model_rmse_scores = []

    for fold, (train_index, val_index) in enumerate(kf.split(label_row_inds)):
        print(f"Fold {fold + 1}/{n_folds}")

        # A new model is built for every fold of every architecture. Reset
        # Keras' global state first so graphs and layers from earlier folds do
        # not pile up in memory over the whole run.
        tf.keras.backend.clear_session()

        train_drug_data, train_target_data, train_affinity = prepare_interaction_pairs(
            XD, XT, Y, label_row_inds[train_index], label_col_inds[train_index])

        val_drug_data, val_target_data, val_affinity = prepare_interaction_pairs(
            XD, XT, Y, label_row_inds[val_index], label_col_inds[val_index])

        train_dataset = create_tf_dataset(train_drug_data, train_target_data, train_affinity, config['batch_size'])
        # shuffle=False so predictions come back in the same order as val_affinity.
        val_dataset = create_tf_dataset(val_drug_data, val_target_data, val_affinity, config['batch_size'], shuffle=False)

        model = model_builder(config)

        # NOTE(review): early stopping monitors the same fold that is scored
        # below, so the reported metrics are optimistically biased. Consider
        # holding out a separate split for early stopping.
        early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

        history = model.fit(
            train_dataset,
            epochs=config['num_epoch'],
            validation_data=val_dataset,
            callbacks=[early_stopping],
            verbose=1
        )

        # Flatten once; both metrics expect one prediction per validation pair.
        val_predictions = model.predict(val_dataset)
        flat_predictions = val_predictions.flatten()
        val_ci = ci_score(val_affinity, flat_predictions)
        val_rmse = tf.keras.metrics.RootMeanSquaredError()(val_affinity, flat_predictions).numpy()

        model_ci_scores.append(val_ci)
        model_rmse_scores.append(val_rmse)

        print(f"Fold {fold + 1} - CI: {val_ci:.4f}, RMSE: {val_rmse:.4f}")

    # Summarise the architecture by the mean of its per-fold scores.
    avg_ci = np.mean(model_ci_scores)
    avg_rmse = np.mean(model_rmse_scores)
    results[model_name] = {'CI': avg_ci, 'RMSE': avg_rmse}

    print(f"{model_name} - Average CI: {avg_ci:.4f}, Average RMSE: {avg_rmse:.4f}")

# Cell 6: Compare results
# Print each model's averaged cross-validation metrics, one metric per line.
print("\nFinal Results:")
for model_name, scores in results.items():
    print(
        f"{model_name}:",
        f"  CI: {scores['CI']:.4f}",
        f"  RMSE: {scores['RMSE']:.4f}",
        sep="\n",
    )

# Cell 7: Visualize results
def plot_results(results):
    """Show a grouped bar chart comparing the models' CI and RMSE.

    CI bars (blue) use the left y-axis and RMSE bars (red) use a twin right
    y-axis, so the two metrics keep their own scales.

    Args:
        results: Mapping of model name to a dict with 'CI' and 'RMSE' entries.
    """
    labels = list(results)
    ci_values = [results[label]['CI'] for label in labels]
    rmse_values = [results[label]['RMSE'] for label in labels]

    positions = np.arange(len(labels))
    bar_width = 0.35
    offset = bar_width / 2

    fig, ci_axis = plt.subplots(figsize=(12, 6))
    rmse_axis = ci_axis.twinx()  # second y-axis sharing the same x positions

    # Each model gets two side-by-side bars centred on its tick.
    ci_axis.bar(positions - offset, ci_values, bar_width, label='CI', color='b', alpha=0.7)
    rmse_axis.bar(positions + offset, rmse_values, bar_width, label='RMSE', color='r', alpha=0.7)

    ci_axis.set_title('Model Performance Comparison')
    ci_axis.set_xlabel('Models')
    ci_axis.set_ylabel('CI Score')
    rmse_axis.set_ylabel('RMSE')
    ci_axis.set_xticks(positions)
    ci_axis.set_xticklabels(labels, rotation=45, ha='right')

    # Each axis carries its own legend, placed in opposite corners.
    ci_axis.legend(loc='upper left')
    rmse_axis.legend(loc='upper right')

    fig.tight_layout()
    plt.show()

plot_results(results)

# Cell 8: Compare with DeepDTA results
# Baseline figures for DeepDTA; note the baseline is given as MSE, not RMSE.
deepdta_results = {
    'Davis': {'CI': 0.878, 'MSE': 0.261},
    'KIBA': {'CI': 0.863, 'MSE': 0.194}
}

# Choose the baseline with the same rule used to load the data ('davis' means
# Davis, anything else means KIBA). `dataset_type.capitalize()` would turn
# 'kiba' into 'Kiba', which is not a key of `deepdta_results`.
baseline_name = 'Davis' if dataset_type == 'davis' else 'KIBA'
baseline = deepdta_results[baseline_name]
# Convert the baseline MSE to RMSE so it is comparable with our metric.
baseline_rmse = np.sqrt(baseline['MSE'])

print(f"\nDeepDTA results on {baseline_name} dataset:")
print(f"CI: {baseline['CI']:.4f}")
print(f"MSE: {baseline['MSE']:.4f}")

print("\nComparison with DeepDTA:")
for model_name, scores in results.items():
    # Higher CI is better; lower RMSE is better.
    ci_diff = scores['CI'] - baseline['CI']
    rmse_diff = scores['RMSE'] - baseline_rmse

    print(f"{model_name}:")
    print(f"  CI difference: {ci_diff:.4f} ({'better' if ci_diff > 0 else 'worse'})")
    print(f"  RMSE difference: {rmse_diff:.4f} ({'worse' if rmse_diff > 0 else 'better'})")