In [1]:
# Importing necessary libraries
import copy
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.model_selection import KFold
from tensorflow.keras.callbacks import EarlyStopping
from tqdm.notebook import tqdm

In [2]:
# Custom imports
from config import get_config
from dataset import DataSet, prepare_interaction_pairs, create_tf_dataset
from models.lstm_model import build_lstm_model
from models.lstm_attention_model import build_lstm_attention_model
from models.transformer_model import build_transformer_model

In [3]:
# Cell 1: Define CI score function
def ci_score(y_true, y_pred):
    """Compute the concordance index (CI) between true and predicted values.

    Every pair of samples whose true values differ is "comparable". A
    comparable pair contributes 1 when the sample with the larger true value
    also has the larger prediction, 0.5 when both predictions are equal, and
    0 otherwise. The CI is the sum of contributions divided by the number of
    comparable pairs.

    Args:
        y_true: 1-D array-like of ground-truth values.
        y_pred: 1-D array-like of predicted values, same length as ``y_true``.

    Returns:
        float: The concordance index in ``[0, 1]``. Returns ``0.0`` when there
        is no comparable pair (fewer than two samples, or all true values
        equal) instead of dividing by zero.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` have different lengths.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same length, "
            f"got {y_true.shape[0]} and {y_pred.shape[0]}"
        )

    # Sort by true value so every potential "smaller" partner of sample i
    # lives in the prefix [:i].
    order = np.argsort(y_true)
    y_true = y_true[order]
    y_pred = y_pred[order]

    comparable_pairs = 0
    concordance = 0.0
    for i in range(1, len(y_true)):
        # Strict inequality: pairs with equal true values are not comparable.
        strictly_lower = y_true[:i] < y_true[i]
        n_lower = int(np.count_nonzero(strictly_lower))
        if n_lower == 0:
            continue
        partner_pred = y_pred[:i][strictly_lower]
        comparable_pairs += n_lower
        concordance += int(np.count_nonzero(partner_pred < y_pred[i]))
        # Tied predictions are credited half a concordant pair.
        concordance += 0.5 * int(np.count_nonzero(partner_pred == y_pred[i]))

    if comparable_pairs == 0:
        return 0.0
    return concordance / comparable_pairs



In [4]:
# Cell 2: Define custom callback with tqdm
class TqdmCallback(tf.keras.callbacks.Callback):
    """Keras callback that shows training progress as two tqdm bars.

    An outer bar advances once per epoch and displays the latest values of the
    selected metrics; an inner bar advances once per training batch and is
    recreated at the start of every epoch.

    Args:
        epochs: Total number of epochs, used as the outer bar's total.
        metrics: Names of the entries in Keras' ``logs`` dict to display after
            each epoch. Defaults to ``['loss', 'val_loss']``.
    """

    def __init__(self, epochs, metrics=None):
        super().__init__()
        self.epochs = epochs
        # Build a fresh list per instance; a list literal as the default
        # argument would be shared by every callback ever created.
        self.metrics = ['loss', 'val_loss'] if metrics is None else list(metrics)
        self.tqdm_outer = None
        self.tqdm_inner = None

    def on_train_begin(self, logs=None):
        """Create the per-epoch (outer) progress bar."""
        self.tqdm_outer = tqdm(total=self.epochs, desc='Epochs', position=0)

    def on_epoch_begin(self, epoch, logs=None):
        """Create a fresh per-batch (inner) progress bar for this epoch."""
        # 'steps' may be missing or None; tqdm accepts total=None and then
        # simply shows a running count.
        steps = (self.params or {}).get('steps')
        self.tqdm_inner = tqdm(total=steps, desc=f'Epoch {epoch+1}/{self.epochs}', position=1, leave=False)

    def on_train_batch_end(self, batch, logs=None):
        """Advance the inner bar by one batch."""
        if self.tqdm_inner is not None:
            self.tqdm_inner.update()

    def on_epoch_end(self, epoch, logs=None):
        """Advance the outer bar, drop the inner bar and show the metrics."""
        logs = logs or {}
        self.tqdm_outer.update()
        self._close_inner()
        metrics_str = ' - '.join([f'{m}: {logs[m]:.4f}' for m in self.metrics if m in logs])
        self.tqdm_outer.set_postfix_str(metrics_str)

    def on_train_end(self, logs=None):
        """Close whatever bars are still open."""
        # The inner bar can still be open if training stopped mid-epoch.
        self._close_inner()
        if self.tqdm_outer is not None:
            self.tqdm_outer.close()

    def _close_inner(self):
        """Close the inner bar, if any, and forget it."""
        if self.tqdm_inner is not None:
            self.tqdm_inner.close()
            self.tqdm_inner = None



In [5]:
def run_experiment(config, dataset_type, n_splits=5):
    """Cross-validate every model architecture on one dataset.

    For each of the LSTM, LSTM-with-attention and Transformer builders, the
    labelled interaction pairs are split into ``n_splits`` shuffled folds
    (fixed ``random_state=42``). A fresh model is trained per fold with early
    stopping on ``val_loss`` and evaluated on the held-out fold using the
    concordance index and RMSE.

    Args:
        config: Mapping of settings. Reads ``'<dataset_type>_path'``,
            ``'<dataset_type>_convert_to_log'``, ``'problem_type'``,
            ``'max_seq_len'``, ``'max_smi_len'``, ``'batch_size'`` and
            ``'num_epoch'``. The caller's object is not modified.
        dataset_type: ``'davis'`` or ``'kiba'``.
        n_splits: Number of cross-validation folds (default 5).

    Returns:
        dict: ``{model_name: {'CI': mean_ci, 'RMSE': mean_rmse}}`` where the
        means are taken over the folds.

    Raises:
        ValueError: If ``dataset_type`` is not ``'davis'`` or ``'kiba'``.
    """
    if dataset_type not in ('davis', 'kiba'):
        # Previously any other value silently fell through to the KIBA branch.
        raise ValueError(f"Unknown dataset_type {dataset_type!r}; expected 'davis' or 'kiba'")

    print(f"\nRunning experiment for {dataset_type} dataset")

    # Work on a shallow copy so the dataset-specific alphabet sizes written
    # below do not leak into the caller's config or into a later run.
    config = copy.copy(config)

    dataset = DataSet(config[f'{dataset_type}_path'], config['problem_type'], config['max_seq_len'],
                      config['max_smi_len'], dataset_type, config[f'{dataset_type}_convert_to_log'])

    XD, XT, Y, label_row_inds, label_col_inds, _, _ = dataset.get_data()

    # Update config with dataset-specific alphabet sizes
    config['charsmiset_size'] = dataset.charsmiset_size
    config['charseqset_size'] = dataset.charseqset_size

    models = {
        'LSTM': build_lstm_model,
        'LSTM with Attention': build_lstm_attention_model,
        'Transformer': build_transformer_model
    }

    results = {}
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    for model_name, model_builder in models.items():
        print(f"\nTraining {model_name}...")
        model_ci_scores = []
        model_rmse_scores = []

        for fold, (train_index, val_index) in enumerate(kf.split(label_row_inds)):
            print(f"Fold {fold + 1}/{n_splits}")

            train_drug_data, train_target_data, train_affinity = prepare_interaction_pairs(
                XD, XT, Y, label_row_inds[train_index], label_col_inds[train_index])

            val_drug_data, val_target_data, val_affinity = prepare_interaction_pairs(
                XD, XT, Y, label_row_inds[val_index], label_col_inds[val_index])

            train_dataset = create_tf_dataset(train_drug_data, train_target_data, train_affinity, config['batch_size'])
            val_dataset = create_tf_dataset(val_drug_data, val_target_data, val_affinity, config['batch_size'], shuffle=False)

            # Release the graph/state of the previous fold's model; otherwise
            # memory grows with every model built in this loop.
            tf.keras.backend.clear_session()
            model = model_builder(config)

            early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
            tqdm_callback = TqdmCallback(epochs=config['num_epoch'])

            start_time = time.time()
            model.fit(
                train_dataset,
                epochs=config['num_epoch'],
                validation_data=val_dataset,
                callbacks=[early_stopping, tqdm_callback],
                verbose=0
            )
            training_time = time.time() - start_time

            val_targets = np.asarray(val_affinity, dtype=float).ravel()
            val_predictions = model.predict(val_dataset, verbose=0).flatten()
            val_ci = ci_score(val_targets, val_predictions)
            # Plain NumPy RMSE: same quantity, without allocating a stateful
            # Keras metric object on every fold.
            val_rmse = float(np.sqrt(np.mean((val_targets - val_predictions) ** 2)))

            model_ci_scores.append(val_ci)
            model_rmse_scores.append(val_rmse)

            print(f"Fold {fold + 1} - CI: {val_ci:.4f}, RMSE: {val_rmse:.4f}, Training Time: {training_time:.2f} seconds")

        avg_ci = np.mean(model_ci_scores)
        avg_rmse = np.mean(model_rmse_scores)
        results[model_name] = {'CI': avg_ci, 'RMSE': avg_rmse}

        print(f"{model_name} - Average CI: {avg_ci:.4f}, Average RMSE: {avg_rmse:.4f}")

    return results

In [None]:
# Cell 4: Run experiments for both datasets
# One configuration object is loaded and handed to both benchmark runs;
# Davis is evaluated first, then KIBA.
config = get_config()
davis_results = run_experiment(config=config, dataset_type='davis')
kiba_results = run_experiment(config=config, dataset_type='kiba')


Running experiment for davis dataset
Reading davis dataset from ../data/davis/
Parsing davis dataset

Training LSTM...
Fold 1/5




Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1/100:   0%|          | 0/94 [00:00<?, ?it/s]

In [None]:
# Cell 5: Compare results
def print_results(results, dataset_name):
    """Print the averaged CI and RMSE of every model for one dataset.

    Args:
        results: Mapping ``{model_name: {'CI': float, 'RMSE': float}}``.
        dataset_name: Label used in the heading line.
    """
    heading = f"\nFinal Results for {dataset_name} dataset:"
    print(heading)
    for name, metrics in results.items():
        ci_value = metrics['CI']
        rmse_value = metrics['RMSE']
        print(f"{name}:")
        print(f"  CI: {ci_value:.4f}")
        print(f"  RMSE: {rmse_value:.4f}")

# Report the Davis summary first, then the KIBA summary.
print_results(results=davis_results, dataset_name="Davis")
print_results(results=kiba_results, dataset_name="KIBA")



In [None]:
# Cell 6: Visualize results
def plot_results(davis_results, kiba_results):
    """Draw grouped bar charts comparing the models on Davis and KIBA.

    Two stacked panels are produced: CI on top, RMSE below. Within each panel
    every model gets a Davis bar and a KIBA bar side by side.

    Args:
        davis_results: ``{model_name: {'CI': float, 'RMSE': float}}`` for Davis.
            Its keys define which models are plotted and in what order.
        kiba_results: Same structure for KIBA; must contain the same models.
    """
    model_names = list(davis_results.keys())
    positions = np.arange(len(model_names))
    bar_width = 0.2

    figure, axes = plt.subplots(2, 1, figsize=(12, 12))

    # (metric key, y-axis label, panel title) for the top and bottom panels.
    panels = (
        ('CI', 'CI Score', 'CI Score Comparison'),
        ('RMSE', 'RMSE', 'RMSE Comparison'),
    )

    for axis, (metric, y_label, title) in zip(axes, panels):
        davis_values = [davis_results[name][metric] for name in model_names]
        kiba_values = [kiba_results[name][metric] for name in model_names]

        axis.bar(positions - bar_width/2, davis_values, bar_width, label='Davis', color='b', alpha=0.7)
        axis.bar(positions + bar_width/2, kiba_values, bar_width, label='KIBA', color='r', alpha=0.7)
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.set_xticks(positions)
        axis.set_xticklabels(model_names, rotation=45, ha='right')
        axis.legend()

    figure.tight_layout()
    plt.show()

# Render the comparison figure for the two finished runs.
plot_results(davis_results=davis_results, kiba_results=kiba_results)



In [None]:
# Cell 7: Compare with DeepDTA results
# Baseline DeepDTA scores per dataset: concordance index and mean squared error.
# NOTE(review): presumably copied from the DeepDTA publication - confirm the source.
deepdta_results = dict(
    Davis=dict(CI=0.878, MSE=0.261),
    KIBA=dict(CI=0.863, MSE=0.194),
)

def compare_with_deepdta(results, dataset_type):
    """Print how each model's scores differ from the DeepDTA baseline.

    The baseline is looked up in the module-level ``deepdta_results``. Since
    that table stores MSE while ``results`` stores RMSE, the baseline MSE is
    square-rooted before the RMSE difference is taken.

    Args:
        results: Mapping ``{model_name: {'CI': float, 'RMSE': float}}``.
        dataset_type: Key into ``deepdta_results`` (``'Davis'`` or ``'KIBA'``).

    Raises:
        KeyError: If ``dataset_type`` is not a key of ``deepdta_results``.
    """
    def _verdict(diff, higher_is_better):
        # An exact tie is neither better nor worse. Previously a zero CI
        # difference was reported as "worse" and a zero RMSE difference as
        # "better".
        if diff == 0:
            return 'equal'
        return 'better' if (diff > 0) == higher_is_better else 'worse'

    print(f"\nComparison with DeepDTA for {dataset_type} dataset:")
    deepdta_ci = deepdta_results[dataset_type]['CI']
    deepdta_mse = deepdta_results[dataset_type]['MSE']
    print(f"DeepDTA - CI: {deepdta_ci:.4f}, MSE: {deepdta_mse:.4f}")

    for model_name, scores in results.items():
        ci_diff = scores['CI'] - deepdta_ci
        rmse_diff = scores['RMSE'] - np.sqrt(deepdta_mse)

        print(f"{model_name}:")
        # Higher CI is better; lower RMSE is better.
        print(f"  CI difference: {ci_diff:.4f} ({_verdict(ci_diff, higher_is_better=True)})")
        print(f"  RMSE difference: {rmse_diff:.4f} ({_verdict(rmse_diff, higher_is_better=False)})")

# Compare each dataset's results against its DeepDTA baseline entry.
compare_with_deepdta(results=davis_results, dataset_type='Davis')
compare_with_deepdta(results=kiba_results, dataset_type='KIBA')