In [1]:
# %pip install matplotlib seaborn networkx

In [2]:
# !conda install -c conda-forge pygraphviz -y

In [3]:
import os
import re
from math import pi

import numpy as np
import pandas as pd
import seaborn as sns
import networkx as nx
import matplotlib.pyplot as plt

from scipy.interpolate import make_interp_spline

## AMVL Workflow

In [4]:
# The AMVL pipeline is a single linear chain of steps, drawn top-to-bottom
# with Graphviz `dot` (requires pygraphviz).
# Keys are node ids, values are the labels drawn on the figure; the dict's
# insertion order is the order of the chain.
nodes = {
    "Input": "Input Matrices and Parameters",
    "Step1_Start": "Step 1: Bounded Matrix Completion (BMC)",
    "BMC": "Perform BMC to obtain T_mc",
    "GIP": "Compute GIP Similarities (Grr, Gdd)",
    "Step2_Start": "Step 2: Matrix Factorization with Similarity Regularization",
    "Combine_Similarity": "Combine GIP and Input Similarities",
    "Remove_Diagonal": "Remove Diagonal Elements (Wrr_ML, Wdd_ML)",
    "MSBMF": "Perform MSBMF (Matrix Factorization)",
    "Step3_Start": "Step 3: Multi-View Learning (MVL)",
    "Update_Similarity": "Update Similarity Matrices (SR, SD)",
    "Predict_MVL": "Generate Multi-View Prediction (F_mv)",
    "Step4_Start": "Step 4: Final Combination",
    "Combine_Results": "Combine Thresholded Results and F_mv",
    "Clip_Values": "Ensure Values in Range [0, 1]",
    "Output": "Output Final Prediction Matrix (F_final)"
}

# Every step feeds the next one, so the edges are the consecutive node-id pairs.
edges = list(zip(nodes, list(nodes)[1:]))

G = nx.DiGraph()
G.add_nodes_from((node_id, {"label": label}) for node_id, label in nodes.items())
G.add_edges_from(edges)

# Hierarchical layout computed by Graphviz.
pos = nx.nx_agraph.graphviz_layout(G, prog="dot")

plt.figure(figsize=(12, 8))
nx.draw(
    G,
    pos,
    with_labels=True,
    labels=nx.get_node_attributes(G, 'label'),
    node_size=5000,
    node_color="lightblue",
    font_size=9,
    font_weight="bold",
    arrows=True,
)
plt.title("AdaMVL Workflow", fontsize=14)
plt.show()

## Benchmark 5+2

In [3]:
# Methods compared in the benchmark; AMVL is listed first.
models = ['AMVL', 'MLMC', 'MSBMF', 'HGIMC', 'ITRPCA', 'DRPADC', 'VDA-GKSBMF']

# Radar-chart spokes: each real metric is immediately followed by one
# 'Virtual' spoke (AUC, Virtual1, AUPR, Virtual2, F1, Virtual3).
metrics = [
    spoke
    for real_metric, virtual_metric in zip(['AUC', 'AUPR', 'F1'], ['Virtual1', 'Virtual2', 'Virtual3'])
    for spoke in (real_metric, virtual_metric)
]

In [4]:
# Hard-coded benchmark scores: dataset -> {method: [AUC, AUPR, F1]}.
# Positions 0-1 are treated as AUC/AUPR and position 2 as F1 by
# rescale_f1_to_auc_aupr_range and add_virtual_metrics.
# TODO: record the provenance of these numbers (paper table / run logs).
data = {
    'Fdataset': {
        'AMVL': [0.9587, 0.9656, 0.7131],
        'MLMC': [0.9573, 0.9644, 0.6909],
        'MSBMF': [0.9462, 0.9583, 0.6862],
        'HGIMC': [0.9260, 0.9432, 0.6769],
        'ITRPCA': [0.9333, 0.9399, 0.6323],
        'DRPADC': [0.9112, 0.9350, 0.6704],
        'VDA-GKSBMF': [0.9379, 0.9508, 0.6821],
    },
    'Cdataset': {
        'AMVL': [0.9702, 0.9753, 0.7213],
        'MLMC': [0.9689, 0.9742, 0.6967],
        'MSBMF': [0.9634, 0.9722, 0.6946],
        'HGIMC': [0.9448, 0.9592, 0.6862],
        'ITRPCA': [0.9500, 0.9553, 0.6408],
        'DRPADC': [0.9269, 0.9460, 0.6778],
        'VDA-GKSBMF': [0.9555, 0.9640, 0.6905],
    },
    'Ydataset': {
        'AMVL': [0.9709, 0.9749, 0.7222],
        'MLMC': [0.9526, 0.9624, 0.6902],
        'MSBMF': [0.9647, 0.9725, 0.6956],
        'HGIMC': [0.9565, 0.9658, 0.6917],
        'ITRPCA': [0.9473, 0.9512, 0.6371],
        'DRPADC': [0.9555, 0.9648, 0.6911],
        'VDA-GKSBMF': [0.9588, 0.9605, 0.6917],
    }
}

In [5]:
# Rescaling F1 to the range of AUC and AUPR min and max for each dataset
def rescale_f1_to_auc_aupr_range(data):
    scaled_data = {}

    for dataset, model_data in data.items():
        # Find the min and max across AUC and AUPR values for each dataset
        auc_aupr_values = [values[0:2] for values in model_data.values()]
        auc_aupr_min = min([min(values) for values in auc_aupr_values])
        auc_aupr_max = max([max(values) for values in auc_aupr_values])
        
        # Rescale F1 values within the range [auc_aupr_min, auc_aupr_max]
        f1_values = [values[2] for values in model_data.values()]
        min_f1_value = min(f1_values)
        max_f1_value = max(f1_values)
        
        # Rescale F1 values to the range of auc_aupr_min to auc_aupr_max
        def rescale_f1(value):
            return auc_aupr_min + (auc_aupr_max - auc_aupr_min) * (value - min_f1_value) / (max_f1_value - min_f1_value)
        
        # Update the dataset with rescaled F1 values
        scaled_data[dataset] = {}
        for model, values in model_data.items():
            scaled_f1 = rescale_f1(values[2])
            scaled_data[dataset][model] = [values[0], values[1], scaled_f1]
    
    return scaled_data

In [None]:
# Map F1 onto each dataset's AUC/AUPR range so the radar charts can use one radial scale; show the result.
scaled_data_f1_to_auc_aupr_range = rescale_f1_to_auc_aupr_range(data=data)
scaled_data_f1_to_auc_aupr_range

In [7]:
# Function to add virtual metrics between each original metric
def add_virtual_metrics(data):
    extended_data = {}
    for dataset, model_data in data.items():
        extended_data[dataset] = {}
        for model, values in model_data.items():
            # Calculate virtual metrics as the average of each pair of consecutive original metrics
            virtual1 = (values[0] + values[1]) / 2
            virtual2 = (values[1] + values[2]) / 2
            virtual3 = (values[2] + values[0]) / 2
            # Extend the original values with the virtual metrics
            extended_values = [values[0], virtual1, values[1], virtual2, values[2], virtual3]
            extended_data[dataset][model] = extended_values
    return extended_data

# Give every dataset/model one value per radar spoke (real metrics plus midpoint 'virtual' spokes).
data_with_virtual_metrics = add_virtual_metrics(data=scaled_data_f1_to_auc_aupr_range)

In [8]:
# # Define Nature-style colors that are colorblind-friendly
# nature_colors = [
#     '#E69F00', '#56B4E9', '#009E73', '#F0E442', '#0072B2', '#D55E00', '#CC79A7'
# ]

# def create_radar_chart(ax, dataset, data, models):
#     # Number of metrics
#     num_vars = len(metrics)
#     angles = [n / float(num_vars) * 2 * pi for n in range(num_vars)]
#     # Adding virtual points between each pair of metrics
#     extended_angles = []
#     for i in range(len(angles)):
#         extended_angles.append(angles[i])
#         if i < len(angles) - 1:
#             # Add a virtual point between every pair of metrics
#             extended_angles.append(angles[i] + (angles[(i + 1) % len(angles)] - angles[i]) * 0.5)
#     angles = extended_angles + extended_angles[:1]  # Closing the loop for the radar chart

#     # Add virtual points between metrics in the data
#     extended_data = {}
#     for model in data[dataset]:
#         values = data[dataset][model]
#         extended_values = []
#         for i in range(len(values)):
#             extended_values.append(values[i])
#             if i < len(values) - 1:
#                 # Add a value halfway between each metric to create the virtual points
#                 extended_values.append((values[i] + values[(i + 1) % len(values)]) / 2)
#         extended_values += extended_values[:1]  # Close the loop for the data
#         extended_data[model] = extended_values

#     # Draw one axe per metric + add labels
#     ax.set_theta_offset(pi / 2)
#     ax.set_theta_direction(-1)
#     ax.set_xticks(angles[:-1:2])  # Set ticks for original metrics only
#     ax.set_xticklabels(metrics)

#     # Customize the grid
#     ax.grid(color='gray', linestyle='--', linewidth=0.5)
#     ax.set_facecolor('#f7f7f7')  # Light background color for better contrast

#     # Plot data for each model and save the line objects for the legend
#     lines = []
#     marker_styles = ['o', 's', 'D', '^', 'v', 'p', '*']  # Different markers for each model
#     line_styles = ['solid', 'dashed', 'dashdot', 'dotted']  # Different line styles

#     for idx, model in enumerate(models):
#         values = extended_data[model]
#         color = nature_colors[idx % len(nature_colors)]  # Cycle through Nature-style colors
#         marker = marker_styles[idx % len(marker_styles)]  # Cycle through marker styles
#         line_style = line_styles[idx % len(line_styles)]  # Cycle through line styles
#         line = ax.plot(angles, values, linewidth=2 + 0.5 * (idx % 2), linestyle=line_style, label=model, color=color, marker=marker)
#         ax.fill(angles, values, color=color, alpha=0.15 + 0.1 * (idx % 3))  # Vary opacity for distinction
#         lines.append(line[0])  # Save the line for legend

#     # Set dynamic ylim based on the max and min of AUC, AUPR, and rescaled F1
#     dataset_values = [value for model_data in data[dataset].values() for value in model_data]
#     ax.set_ylim(min(dataset_values) - 5e-3, max(dataset_values) + 5e-5)
    
#     return lines

In [9]:
# Colorblind-friendly palette (Okabe-Ito hues), cycled per model in the radar charts.
nature_colors = [
    '#E69F00', '#56B4E9', '#009E73', '#F0E442', '#0072B2', '#D55E00', '#CC79A7'
]

def create_radar_chart(ax, dataset, data, models, metric_names=None):
    """Draw a radar (polar) chart comparing ``models`` on one dataset.

    Parameters
    ----------
    ax : matplotlib polar Axes
        Axes to draw into (e.g. from ``subplot_kw=dict(polar=True)``).
    dataset : str
        Key of ``data`` to plot.
    data : dict
        ``{dataset: {model: [score, ...]}}`` with one score per entry of
        ``metric_names``, in the same order. It is only read, never modified.
    models : sequence of str
        Models to draw, in legend order.
    metric_names : sequence of str, optional
        Spoke names. Names containing ``'Virtual'`` keep their spoke but get an
        empty tick label. Defaults to the notebook-level ``metrics`` list
        (looked up at call time), which is what this function always used.

    Returns
    -------
    list
        One line handle per model, for building a shared legend.

    Raises
    ------
    ValueError
        If a model's score list does not have one value per metric.
    """
    if metric_names is None:
        metric_names = metrics
    num_vars = len(metric_names)

    # Evenly spaced spokes; the first angle is repeated so the polygon closes.
    angles = [n / float(num_vars) * 2 * pi for n in range(num_vars)]
    angles += angles[:1]

    # First spoke at the top, then clockwise.
    ax.set_theta_offset(pi / 2)
    ax.set_theta_direction(-1)
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels([name if 'Virtual' not in name else '' for name in metric_names])

    ax.grid(color='gray', linestyle='--', linewidth=0.5)
    ax.set_facecolor('#f7f7f7')  # light background for contrast

    lines = []
    marker_styles = ['o', 's', 'D', '^', 'v', 'p', '*']  # one marker per model (cycled)
    line_styles = ['solid', 'dashed', 'dashdot', 'dotted']  # cycled line styles

    for idx, model in enumerate(models):
        scores = list(data[dataset][model])
        if len(scores) != num_vars:
            raise ValueError(
                f"{dataset}/{model}: expected {num_vars} scores (one per metric), got {len(scores)}"
            )
        # Close the polygon on a copy. Appending to the caller's list in place
        # would grow it on every call and break re-running the plot cell.
        closed_scores = scores + scores[:1]

        color = nature_colors[idx % len(nature_colors)]
        marker = marker_styles[idx % len(marker_styles)]
        line_style = line_styles[idx % len(line_styles)]
        line = ax.plot(angles, closed_scores, linewidth=2 + 0.5 * (idx % 2), linestyle=line_style,
                       label=model, color=color, marker=marker)
        ax.fill(angles, closed_scores, color=color, alpha=0.15 + 0.1 * (idx % 3))  # vary opacity for distinction
        lines.append(line[0])

    # Zoom the radial axis onto this dataset's score range (margin 5e-3 below, 5e-5 above).
    dataset_values = [value for model_data in data[dataset].values() for value in model_data]
    ax.set_ylim(min(dataset_values) - 5e-3, max(dataset_values) + 5e-5)

    return lines

In [10]:
# # Function to create radar chart with dynamic ylim and matching legend colors
# def create_radar_chart(ax, dataset, data, models):
#     # Number of metrics
#     num_vars = len(metrics)
#     angles = [n / float(num_vars) * 2 * pi for n in range(num_vars)]
#     angles += angles[:1]  # Closing the loop for the radar chart

#     # Draw one axe per metric + add labels
#     ax.set_theta_offset(pi / 2)
#     ax.set_theta_direction(-1)
#     ax.set_xticks(angles[:-1])
#     ax.set_xticklabels(metrics)

#     # Plot data for each model and save the line objects for the legend
#     lines = []
#     for model in models:
#         values = data[dataset][model]
#         values += values[:1]  # Close the loop for the data
#         line = ax.plot(angles, values, linewidth=2, linestyle='solid', label=model)
#         ax.fill(angles, values, alpha=0.25)
#         lines.append(line[0])  # Save the line for legend

#     # Set dynamic ylim based on the max and min of AUC, AUPR, and rescaled F1
#     dataset_values = [values for model_data in data[dataset].values() for values in model_data]
#     ax.set_ylim(min(dataset_values) - 5e-3, max(dataset_values) + 5e-5)
    
#     return lines

In [None]:
# One radar chart per benchmark dataset in a 2x2 grid; the fourth cell hosts the shared legend.
fig, axs = plt.subplots(2, 2, figsize=(12, 12), subplot_kw=dict(polar=True))

datasets = list(data_with_virtual_metrics.keys())
radar_axes = [axs[0, 0], axs[0, 1], axs[1, 0]]

legend_handles = None
for radar_ax, dataset in zip(radar_axes, datasets):
    handles = create_radar_chart(radar_ax, dataset, data_with_virtual_metrics, models)
    radar_ax.set_title(f'{dataset}', size=15, y=1.1)
    if legend_handles is None:
        # Every chart uses the same models and colours, so the first chart's handles serve all.
        legend_handles = handles

# The bottom-right polar axis is hidden and the figure legend is placed over it.
axs[1, 1].axis('off')
fig.legend(handles=legend_handles, loc='center', bbox_to_anchor=(0.75, 0.25), fontsize=20)

plt.tight_layout()
plt.show()


In [12]:
# # Create the overall figure with a 2x2 layout
# fig, axs = plt.subplots(2, 2, figsize=(12, 12), subplot_kw=dict(polar=True))

# # Create radar charts for each dataset with dynamic ylim and collect line objects for the legend
# datasets = list(scaled_data_f1_to_auc_aupr_range.keys())
# lines_1 = create_radar_chart(axs[0, 0], datasets[0], scaled_data_f1_to_auc_aupr_range, models)
# axs[0, 0].set_title(f'{datasets[0]}', size=15, y=1.1)

# lines_2 = create_radar_chart(axs[0, 1], datasets[1], scaled_data_f1_to_auc_aupr_range, models)
# axs[0, 1].set_title(f'{datasets[1]}', size=15, y=1.1)

# lines_3 = create_radar_chart(axs[1, 0], datasets[2], scaled_data_f1_to_auc_aupr_range, models)
# axs[1, 0].set_title(f'{datasets[2]}', size=15, y=1.1)

# # Adjust the fourth quadrant to display the legend with correct colors
# axs[1, 1].axis('off')  # Turn off the polar axis for the legend area
# fig.legend(handles=lines_1, loc='center', bbox_to_anchor=(0.75, 0.25), fontsize=20)  # Place the legend in the blank subplot

# # Show the plot
# plt.tight_layout()
# plt.show()


In [13]:
# Export the radar-chart figure; create img/ first so a fresh checkout does not hit FileNotFoundError.
os.makedirs("img", exist_ok=True)
fig.savefig("img/Figure 3.tiff", dpi=300, format='tiff')

## Embedding Validation

In [None]:
# # LLM Metrics
# heatmap_data = pd.DataFrame({
#     'Fdataset': [0.8815, 0.8737, 0.8754, 0.8758, 0.8557, 0.8914, 0.8864, 0.8819, 0.8814, 0.8589, 0.6526, 0.6292, 0.6490, 0.6491, 0.6382],
#     'Cdataset': [0.9057, 0.8892, 0.8931, 0.8996, 0.8887, 0.9161, 0.9034, 0.9029, 0.9103, 0.8915, 0.6652, 0.6394, 0.6590, 0.6621, 0.6554],
#     'Ydataset': [0.9298, 0.9094, 0.9275, 0.9242, 0.9042, 0.9330, 0.9156, 0.9297, 0.9269, 0.9060, 0.6768, 0.6495, 0.6770, 0.6740, 0.6639]
# }, index=['SVM AUC', 'RandomForest AUC', 'XGBoost AUC', 'LightGBM AUC', 'MLP AUC', 
#           'SVM AUPR', 'RandomForest AUPR', 'XGBoost AUPR', 'LightGBM AUPR', 'MLP AUPR', 
#           'SVM F1', 'RandomForest F1', 'XGBoost F1', 'LightGBM F1', 'MLP F1'])

In [None]:
# # Split the dataframe into separate metrics to create different color maps
# auc_data = heatmap_data.loc[['SVM AUC', 'RandomForest AUC', 'XGBoost AUC', 'LightGBM AUC', 'MLP AUC']].rename(index=lambda x: x.split()[0])
# aupr_data = heatmap_data.loc[['SVM AUPR', 'RandomForest AUPR', 'XGBoost AUPR', 'LightGBM AUPR', 'MLP AUPR']].rename(index=lambda x: x.split()[0])
# f1_data = heatmap_data.loc[['SVM F1', 'RandomForest F1', 'XGBoost F1', 'LightGBM F1', 'MLP F1']].rename(index=lambda x: x.split()[0])

# # Set up the figure
# fig, axes = plt.subplots(1, 3, figsize=(18, 9))

# # AUC Heatmap
# sns.heatmap(auc_data, annot=True, fmt=".4f", cmap="Blues", linewidths=.5, ax=axes[0])
# axes[0].set_title('AUC Performance', fontsize=14)
# axes[0].set_ylabel("Models", fontsize=14)

# # AUPR Heatmap
# sns.heatmap(aupr_data, annot=True, fmt=".4f", cmap="Greens", linewidths=.5, ax=axes[1])
# axes[1].set_title('AUPR Performance', fontsize=14)
# axes[1].set_xlabel("Dataset", fontsize=14)

# # F1 Heatmap
# sns.heatmap(f1_data, annot=True, fmt=".4f", cmap="Oranges", linewidths=.5, ax=axes[2])
# axes[2].set_title('F1 Performance', fontsize=14)

# # Show the plot
# plt.tight_layout()
# plt.show()

In [None]:
# Size of each benchmark dataset: number of drugs, diseases and drug-disease associations.
# Stacked in the "Dataset Composition" panel.
dataset_counts = pd.DataFrame({
    'Dataset': ['Fdataset', 'Cdataset', 'Ydataset'],
    'Drugs': [593, 663, 1478],
    'Diseases': [313, 409, 655],
    'Associations': [1933, 2352, 8448]
})

# LLM-metrics validation scores. Rows are "<classifier> <metric>" for 5 classifiers
# x {AUC, AUPR, F1}; columns are datasets. The plotting code splits the rows by
# metric and keeps only the classifier name as the row label.
heatmap_data = pd.DataFrame({
    'Fdataset': [0.8815, 0.8737, 0.8754, 0.8758, 0.8557, 0.8914, 0.8864, 0.8819, 0.8814, 0.8589, 0.6526, 0.6292, 0.6490, 0.6491, 0.6382],
    'Cdataset': [0.9057, 0.8892, 0.8931, 0.8996, 0.8887, 0.9161, 0.9034, 0.9029, 0.9103, 0.8915, 0.6652, 0.6394, 0.6590, 0.6621, 0.6554],
    'Ydataset': [0.9298, 0.9094, 0.9275, 0.9242, 0.9042, 0.9330, 0.9156, 0.9297, 0.9269, 0.9060, 0.6768, 0.6495, 0.6770, 0.6740, 0.6639]
}, index=['SVM AUC', 'RandomForest AUC', 'XGBoost AUC', 'LightGBM AUC', 'MLP AUC', 
          'SVM AUPR', 'RandomForest AUPR', 'XGBoost AUPR', 'LightGBM AUPR', 'MLP AUPR', 
          'SVM F1', 'RandomForest F1', 'XGBoost F1', 'LightGBM F1', 'MLP F1'])

# Split the stacked "<classifier> <metric>" rows into one frame per metric,
# relabelled with the classifier name only (e.g. "SVM AUC" -> "SVM").
classifier_names = ['SVM', 'RandomForest', 'XGBoost', 'LightGBM', 'MLP']
per_metric = {
    metric: (
        heatmap_data.loc[[f'{name} {metric}' for name in classifier_names]]
        .rename(index=lambda label: label.split()[0])
    )
    for metric in ('AUC', 'AUPR', 'F1')
}
auc_data, aupr_data, f1_data = per_metric['AUC'], per_metric['AUPR'], per_metric['F1']

# Three metric heatmaps plus a fourth panel with dataset sizes.
fig, axes = plt.subplots(1, 4, figsize=(24, 9), gridspec_kw={'wspace': 0.4})

heatmap_panels = [
    (auc_data, 'Blues', 'AUC Performance'),
    (aupr_data, 'Greens', 'AUPR Performance'),
    (f1_data, 'Oranges', 'F1 Performance'),
]
for heatmap_ax, (metric_frame, cmap, title) in zip(axes, heatmap_panels):
    sns.heatmap(metric_frame, annot=True, fmt=".4f", cmap=cmap, linewidths=.5, ax=heatmap_ax)
    heatmap_ax.set_title(title, fontsize=14)
axes[0].set_ylabel("Models", fontsize=16, labelpad=10)

# (column, bar colour, trend-line colour); bar colours echo the three heatmap colormaps.
composition = [
    ('Drugs', '#2777B8', '#3A8AC2'),
    ('Diseases', '#4BB062', '#2ca02c'),
    ('Associations', '#FDA35C', '#ff7f0e'),
]
count_ax = axes[3]

# Stacked bars: each category sits on the running total of the ones before it.
stack_bottom = 0
for column, bar_color, _ in composition:
    count_ax.bar(dataset_counts['Dataset'], dataset_counts[column], color=bar_color, label=column,
                 bottom=stack_bottom)
    stack_bottom = stack_bottom + dataset_counts[column]

# Trend lines trace the top edge of each stacked layer across datasets.
stack_top = 0
for column, _, line_color in composition:
    stack_top = stack_top + dataset_counts[column]
    count_ax.plot(dataset_counts['Dataset'], stack_top, color=line_color, marker='o', linestyle='-',
                  label=f'{column} (trend)')

count_ax.set_title('Dataset Composition', fontsize=14)
count_ax.legend(loc='upper left')
count_ax.spines['top'].set_visible(False)
count_ax.spines['right'].set_visible(False)

# Shared x-axis caption for all four panels.
fig.text(0.5, 0.05, 'Datasets', ha='center', fontsize=16)

plt.tight_layout()
plt.show()

In [None]:
# Export the LLM-metrics figure; create img/ first so a fresh checkout does not hit FileNotFoundError.
os.makedirs("img", exist_ok=True)
fig.savefig("img/Figure 5.tiff", dpi=300, format='tiff', bbox_inches='tight')

In [None]:
# # KG Metrics
# heatmap_data = pd.DataFrame({
#     'Fdataset': [0.7523, 0.7684, 0.7725, 0.7710, 0.7227, 0.7453, 0.7729, 0.7685, 0.7721, 0.7092, 0.5834, 0.5807, 0.5947, 0.5947, 0.5684],
#     'Cdataset': [0.7873, 0.8062, 0.8197, 0.8198, 0.7584, 0.7735, 0.8054, 0.8135, 0.8115, 0.7394, 0.6020, 0.5994, 0.6196, 0.6196, 0.5859],
#     'Ydataset': [0.8649, 0.8552, 0.8667, 0.8600, 0.8264, 0.8594, 0.8397, 0.8550, 0.8496, 0.8230, 0.6423, 0.6155, 0.6431, 0.6397, 0.6235]
# }, index=['SVM AUC', 'RandomForest AUC', 'XGBoost AUC', 'LightGBM AUC', 'MLP AUC', 
#           'SVM AUPR', 'RandomForest AUPR', 'XGBoost AUPR', 'LightGBM AUPR', 'MLP AUPR', 
#           'SVM F1', 'RandomForest F1', 'XGBoost F1', 'LightGBM F1', 'MLP F1'])

In [None]:
# # Heatmap
# fig = plt.figure(figsize=(10, 8))
# sns.heatmap(heatmap_data, annot=True, fmt=".3f", cmap="YlGnBu", cbar_kws={'label': 'Performance'}, linewidths=.5)
# plt.title("Heatmap of Model Performance on Different Datasets", fontsize=14)
# plt.ylabel("Model and Metric")
# plt.xlabel("Dataset")
# # Show
# plt.tight_layout()
# plt.show()

In [None]:
# # Split the dataframe into separate metrics to create different color maps
# auc_data = heatmap_data.loc[['SVM AUC', 'RandomForest AUC', 'XGBoost AUC', 'LightGBM AUC', 'MLP AUC']].rename(index=lambda x: x.split()[0])
# aupr_data = heatmap_data.loc[['SVM AUPR', 'RandomForest AUPR', 'XGBoost AUPR', 'LightGBM AUPR', 'MLP AUPR']].rename(index=lambda x: x.split()[0])
# f1_data = heatmap_data.loc[['SVM F1', 'RandomForest F1', 'XGBoost F1', 'LightGBM F1', 'MLP F1']].rename(index=lambda x: x.split()[0])

# # Set up the figure
# fig, axes = plt.subplots(1, 3, figsize=(18, 9))

# # AUC Heatmap
# sns.heatmap(auc_data, annot=True, fmt=".4f", cmap="Blues", linewidths=.5, ax=axes[0])
# axes[0].set_title('AUC Performance', fontsize=14)
# axes[0].set_ylabel("Models", fontsize=14)

# # AUPR Heatmap
# sns.heatmap(aupr_data, annot=True, fmt=".4f", cmap="Greens", linewidths=.5, ax=axes[1])
# axes[1].set_title('AUPR Performance', fontsize=14)
# axes[1].set_xlabel("Dataset", fontsize=14)

# # F1 Heatmap
# sns.heatmap(f1_data, annot=True, fmt=".4f", cmap="Oranges", linewidths=.5, ax=axes[2])
# axes[2].set_title('F1 Performance', fontsize=14)

# # Show the plot
# plt.tight_layout()
# plt.show()

In [None]:
# Size of each benchmark dataset: number of drugs, diseases and drug-disease associations.
# Same values as in the LLM-metrics cell; redefined so this cell runs on its own.
dataset_counts = pd.DataFrame({
    'Dataset': ['Fdataset', 'Cdataset', 'Ydataset'],
    'Drugs': [593, 663, 1478],
    'Diseases': [313, 409, 655],
    'Associations': [1933, 2352, 8448]
})

# KG-metrics validation scores. Rows are "<classifier> <metric>" for 5 classifiers
# x {AUC, AUPR, F1}; columns are datasets. Overwrites the LLM-metrics `heatmap_data`.
heatmap_data = pd.DataFrame({
    'Fdataset': [0.7523, 0.7684, 0.7725, 0.7710, 0.7227, 0.7453, 0.7729, 0.7685, 0.7721, 0.7092, 0.5834, 0.5807, 0.5947, 0.5947, 0.5684],
    'Cdataset': [0.7873, 0.8062, 0.8197, 0.8198, 0.7584, 0.7735, 0.8054, 0.8135, 0.8115, 0.7394, 0.6020, 0.5994, 0.6196, 0.6196, 0.5859],
    'Ydataset': [0.8649, 0.8552, 0.8667, 0.8600, 0.8264, 0.8594, 0.8397, 0.8550, 0.8496, 0.8230, 0.6423, 0.6155, 0.6431, 0.6397, 0.6235]
}, index=['SVM AUC', 'RandomForest AUC', 'XGBoost AUC', 'LightGBM AUC', 'MLP AUC', 
          'SVM AUPR', 'RandomForest AUPR', 'XGBoost AUPR', 'LightGBM AUPR', 'MLP AUPR', 
          'SVM F1', 'RandomForest F1', 'XGBoost F1', 'LightGBM F1', 'MLP F1'])

# Split the stacked "<classifier> <metric>" rows into one frame per metric,
# relabelled with the classifier name only (e.g. "SVM AUC" -> "SVM").
classifier_names = ['SVM', 'RandomForest', 'XGBoost', 'LightGBM', 'MLP']
per_metric = {
    metric: (
        heatmap_data.loc[[f'{name} {metric}' for name in classifier_names]]
        .rename(index=lambda label: label.split()[0])
    )
    for metric in ('AUC', 'AUPR', 'F1')
}
auc_data, aupr_data, f1_data = per_metric['AUC'], per_metric['AUPR'], per_metric['F1']

# Three metric heatmaps plus a fourth panel with dataset sizes.
fig, axes = plt.subplots(1, 4, figsize=(24, 9), gridspec_kw={'wspace': 0.4})

heatmap_panels = [
    (auc_data, 'Blues', 'AUC Performance'),
    (aupr_data, 'Greens', 'AUPR Performance'),
    (f1_data, 'Oranges', 'F1 Performance'),
]
for heatmap_ax, (metric_frame, cmap, title) in zip(axes, heatmap_panels):
    sns.heatmap(metric_frame, annot=True, fmt=".4f", cmap=cmap, linewidths=.5, ax=heatmap_ax)
    heatmap_ax.set_title(title, fontsize=14)
axes[0].set_ylabel("Models", fontsize=16, labelpad=10)

# (column, bar colour, trend-line colour); bar colours echo the three heatmap colormaps.
composition = [
    ('Drugs', '#2777B8', '#3A8AC2'),
    ('Diseases', '#4BB062', '#2ca02c'),
    ('Associations', '#FDA35C', '#ff7f0e'),
]
count_ax = axes[3]

# Stacked bars: each category sits on the running total of the ones before it.
stack_bottom = 0
for column, bar_color, _ in composition:
    count_ax.bar(dataset_counts['Dataset'], dataset_counts[column], color=bar_color, label=column,
                 bottom=stack_bottom)
    stack_bottom = stack_bottom + dataset_counts[column]

# Trend lines trace the top edge of each stacked layer across datasets.
stack_top = 0
for column, _, line_color in composition:
    stack_top = stack_top + dataset_counts[column]
    count_ax.plot(dataset_counts['Dataset'], stack_top, color=line_color, marker='o', linestyle='-',
                  label=f'{column} (trend)')

count_ax.set_title('Dataset Composition', fontsize=14)
count_ax.legend(loc='upper left')
count_ax.spines['top'].set_visible(False)
count_ax.spines['right'].set_visible(False)

# Shared x-axis caption for all four panels.
fig.text(0.5, 0.05, 'Datasets', ha='center', fontsize=16)

plt.tight_layout()
plt.show()

In [None]:
# Export the KG-metrics figure; create img/ first so a fresh checkout does not hit FileNotFoundError.
os.makedirs("img", exist_ok=True)
fig.savefig("img/Figure 4.tiff", dpi=300, format='tiff', bbox_inches='tight')

## Enhanced Input

In [5]:
# Enhanced-input benchmark sheet; its "mean (lower - upper)" cells are parsed in the next cell.
data_file = 'data/Benchmark/amvl_idrug.xlsx'
df = pd.read_excel(io=data_file)
df.head(n=5)

In [6]:
# Define a function to extract mean, lower, and upper bounds from the given data format
def parse_performance(value):
    mean, ci = value.split(" (")
    lower, upper = ci[:-1].split(" - ")
    return float(mean), float(lower), float(upper)

# The first three sheet columns are descriptors (they include `dataset` and `metric`);
# each remaining column is one model combination holding "mean (lower - upper)" strings.
# NOTE(review): assumes the descriptor columns come first in the sheet; confirm.
# Named `model_columns` so this cell does not overwrite the radar-chart `metrics` list.
model_columns = df.columns[3:].to_list()

# Only the iDrug rows are plotted below. `parsed_fdataset_df` keeps its historical
# name because later cells refer to it.
idrug_rows = df[df["dataset"] == "iDrug"]

# One long-format record per (metric row, model column).
parsed_records = []
for _, row in idrug_rows.iterrows():
    for column in model_columns:
        mean, lower, upper = parse_performance(row[column])
        parsed_records.append({
            "metric": row["metric"],
            "model": column,
            "mean": mean,
            "lower": lower,
            "upper": upper,
        })

parsed_fdataset_df = pd.DataFrame(parsed_records)
parsed_fdataset_df.head()

In [None]:
# Palette for the enhanced-input figure. Later cells pick entries by position
# (colors[1]-colors[3] for the per-metric bars, colors[4] and colors[6] for the
# overall panel), so keep the order stable when editing.
colors = [
    '#AAD09D',  # 0  light green
    '#66BC98',  # 1  teal green
    '#E3EA96',  # 2  light yellow
    '#FDCB89',  # 3  light orange (added)
    '#4DA8DA',  # 4  light blue (added)
    '#3288BD',  # 5  blue
    '#2C559A',  # 6  dark blue
    '#F46D43',  # 7  orange-red
    '#C154C1',  # 8  purple (added)
    '#8A233F',  # 9  dark red
    '#6C3483',  # 10 dark purple (added)
]

In [None]:
def draw_color(colors):
    """Preview a list of colours as a row of unit swatches with their codes printed underneath."""
    _fig, ax = plt.subplots(figsize=(10, 2))

    for position, swatch in enumerate(colors):
        ax.add_patch(plt.Rectangle((position, 0), 1, 1, color=swatch))
        # Colour code centred below its swatch.
        ax.text(position + 0.5, -0.5, swatch, ha='center', fontsize=10)

    # One unit of width per swatch; y leaves room for the labels below.
    ax.set_xlim(0, len(colors))
    ax.set_ylim(-1, 1)
    ax.axis('off')
    plt.show()

draw_color(colors)

In [None]:
# Enhanced-input comparison on iDrug.
# Panels a-c: one metric each, per model combination, as a mean bar with
# (mean - lower, upper - mean) error bars.
# Panel d: each model combination's average over all of its metric rows.
# Distinct names (`idrug_models`, `metric_rows`) keep this cell from overwriting the
# notebook-level `models`, `metrics` and `data` used by the benchmark/radar cells.
idrug_models = parsed_fdataset_df["model"].unique()
x = np.arange(len(idrug_models))  # one bar position per model combination

# Hand-picked y-ranges that zoom into each metric's narrow spread.
metric_ylims = {
    "auc": (0.964, 0.970),
    "aupr": (0.968, 0.9715),
    "f1": (0.705, 0.740),
}

fig, axs = plt.subplots(2, 2, figsize=(14, 10), dpi=300)
axs = axs.flatten()

for idx, metric in enumerate(["auc", "aupr", "f1"]):
    ax = axs[idx]
    # Align this metric's rows with the x-axis model order explicitly instead of trusting row order.
    metric_rows = (
        parsed_fdataset_df[parsed_fdataset_df["metric"] == metric]
        .set_index("model")
        .reindex(idrug_models)
    )
    means = metric_rows["mean"].to_numpy()
    errors = [
        (metric_rows["mean"] - metric_rows["lower"]).to_numpy(),
        (metric_rows["upper"] - metric_rows["mean"]).to_numpy(),
    ]

    ax.bar(x, means, width=0.4, yerr=errors, capsize=5, label=metric.upper(), color=colors[idx + 1])
    ax.plot(x, means, marker='o', color='red', linestyle='--', label="Trendline")

    ax.set_xticks(x)
    ax.set_xticklabels(idrug_models, rotation=45, ha="right", fontsize=10)
    ax.set_ylabel("Performance", fontsize=12)
    ax.set_ylim(*metric_ylims[metric])

# Panel d: mean over all metric rows of each model combination.
overall_ax = axs[3]
overall_means = (
    parsed_fdataset_df.groupby("model")["mean"].mean().reindex(idrug_models).to_numpy()
)
overall_ax.bar(x, overall_means, width=0.4, color=colors[4], label="Overall")
overall_ax.plot(x, overall_means, marker='o', color=colors[6], linestyle='--', label="Trendline")
overall_ax.set_xticks(x)
overall_ax.set_xticklabels(idrug_models, rotation=45, ha="right", fontsize=10)
overall_ax.set_ylabel("Mean Performance", fontsize=12)
overall_ax.set_ylim(0.88, 0.8930)

# Shared styling for all four panels, plus the bold panel letters a-d.
for panel_idx, panel_ax in enumerate(axs):
    panel_ax.legend(fontsize=10)
    panel_ax.spines['top'].set_visible(False)
    panel_ax.spines['right'].set_visible(False)
    panel_ax.tick_params(axis='x', labelsize=10)
    panel_ax.tick_params(axis='y', labelsize=10)
    panel_ax.text(-0.12, 1.1, chr(97 + panel_idx), transform=panel_ax.transAxes, size=18, weight='bold')

plt.tight_layout(w_pad=2)
# Create the output folder first so a fresh checkout does not hit FileNotFoundError.
os.makedirs("img", exist_ok=True)
fig.savefig('img/Fig. 8.tiff', dpi=300)
plt.show()