# In[ ]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, precision_recall_curve, auc
import numpy as np
import pandas as pd
import networkx as nx
from sklearn.preprocessing import StandardScaler
from itertools import product
from karateclub import Role2Vec
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, precision_recall_curve

# --- Inputs --------------------------------------------------------------
# DNABERT-extracted sequence features, one CSV per RNA type.
circRNA_features = pd.read_csv('circRNA_Extractedfeatures.csv')
miRNA_features = pd.read_csv('miRNA_Extractedfeatures.csv')

# Known circRNA-miRNA interactions (CMI); each row names one miRNA and one circRNA.
cmi_df = pd.read_csv('CMI9905pairs.csv')

# --- Interaction graph ---------------------------------------------------
# Undirected graph: one node per RNA label, one edge per known interaction.
# Nodes are inserted in row order (miRNA first, then circRNA, for each row).
G = nx.Graph()
G.add_edges_from(zip(cmi_df['miRNA'], cmi_df['circRNA']))

# Relabel nodes to consecutive integers (in insertion order) before fitting
# Role2Vec; `mapping` keeps the original label -> integer id lookup.
mapping = dict(zip(G.nodes(), range(G.number_of_nodes())))
G = nx.relabel_nodes(G, mapping)

# --- Structural embeddings -----------------------------------------------
role2vec = Role2Vec()
role2vec.fit(G)
role2vec_embeddings = role2vec.get_embedding()

# Split the Role2Vec embeddings by RNA type.  Assumes row i of
# `role2vec_embeddings` is integer node i (karateclub's get_embedding order);
# `mapping` lists the original labels in that same order, so zipping the two
# pairs each label with its own embedding.
node_ids = list(mapping.keys())  # original labels, in integer-node order


def _looks_like_circRNA(label):
    """Return True for labels treated as circRNA: strings starting with 'hsa_circ'."""
    return isinstance(label, str) and label.startswith('hsa_circ')


# Anything not recognised as a circRNA is assumed to be a miRNA (or another type).
circRNA_embeddings = np.array(
    [vec for label, vec in zip(node_ids, role2vec_embeddings) if _looks_like_circRNA(label)]
)
miRNA_embeddings = np.array(
    [vec for label, vec in zip(node_ids, role2vec_embeddings) if not _looks_like_circRNA(label)]
)

print(f"circRNA features shape: {circRNA_features.shape}")
print(f"circRNA embeddings shape: {circRNA_embeddings.shape}")


def _truncate_to_common_length(features, embeddings):
    """Cut a feature DataFrame and an embedding array to the shorter row count.

    Returns ``(n_rows, trimmed_features, trimmed_embeddings)``.
    """
    n_rows = min(len(features), len(embeddings))
    return n_rows, features.iloc[:n_rows], embeddings[:n_rows]


# NOTE(review): rows are paired purely by position (feature row i <-> i-th node
# of that type in graph order) and the longer side is silently truncated.
# Nothing here checks that the CSV rows follow graph-node order -- confirm, or
# join on an ID column instead.
min_circRNA_size, aligned_circRNA_features, aligned_circRNA_embeddings = _truncate_to_common_length(
    circRNA_features, circRNA_embeddings
)
X_circRNA = np.hstack([aligned_circRNA_features, aligned_circRNA_embeddings])
print(f"X_circRNA shape after alignment: {X_circRNA.shape}")

min_miRNA_size, aligned_miRNA_features, aligned_miRNA_embeddings = _truncate_to_common_length(
    miRNA_features, miRNA_embeddings
)
X_miRNA = np.hstack([aligned_miRNA_features, aligned_miRNA_embeddings])

# Build a balanced dataset: every known CMI pair is a positive (label 1) and an
# equal number of miRNA x circRNA combinations that are NOT in the CMI data are
# sampled as negatives (label 0).
#
# Sampling works directly on the RNA labels from the CMI columns, so every
# negative pairs a real miRNA (in the 'miRNA' column) with a real circRNA (in
# the 'circRNA' column).  DataFrame row positions and graph node ids are never
# mixed here.
miRNA_labels = pd.Index(pd.unique(cmi_df['miRNA']))
circRNA_labels = pd.Index(pd.unique(cmi_df['circRNA']))
n_circRNA = len(circRNA_labels)

# Known interactions as (miRNA, circRNA) label tuples; duplicate rows collapse.
positive_pairs = set(zip(cmi_df['miRNA'], cmi_df['circRNA']))

# Encode each (miRNA, circRNA) combination as a single integer
#   code = miRNA_position * n_circRNA + circRNA_position
# so "all combinations minus positives" is a vectorised set difference instead
# of materialising millions of Python tuples.
positive_codes = np.unique(
    miRNA_labels.get_indexer(cmi_df['miRNA']).astype(np.int64) * n_circRNA
    + circRNA_labels.get_indexer(cmi_df['circRNA'])
)
all_pair_codes = np.arange(len(miRNA_labels) * n_circRNA, dtype=np.int64)
candidate_codes = np.setdiff1d(all_pair_codes, positive_codes, assume_unique=True)

# Seeded so the negative set (and every downstream metric) is reproducible.
# Capped in case there are fewer non-interacting combinations than positives.
rng = np.random.default_rng(42)
n_negative = min(len(positive_codes), len(candidate_codes))
negative_codes = rng.choice(candidate_codes, size=n_negative, replace=False)
negative_pairs = [
    (miRNA_labels[code // n_circRNA], circRNA_labels[code % n_circRNA])
    for code in negative_codes
]

# Create DataFrame for negative samples (columns in the same order as the tuples).
negative_df = pd.DataFrame(negative_pairs, columns=['miRNA', 'circRNA'])
negative_df['interaction'] = 0  # Label these as non-interaction

# Add label 1 for positive pairs
cmi_df['interaction'] = 1

# Combine positive and negative samples
interaction_df = pd.concat([cmi_df[['miRNA', 'circRNA', 'interaction']], negative_df], ignore_index=True)

# Create one feature vector per circRNA-miRNA pair: [miRNA row | circRNA row].
#
# X_miRNA / X_circRNA rows follow the order in which each RNA type appears among
# the graph nodes: the Role2Vec embeddings were split by the 'hsa_circ' prefix
# while walking `mapping` in order.  A node's row is therefore its rank WITHIN
# ITS OWN TYPE, not its global graph index `mapping[label]` (which also counts
# nodes of the other type).
circRNA_row = {}
miRNA_row = {}
for node_label in mapping:
    if isinstance(node_label, str) and node_label.startswith('hsa_circ'):
        circRNA_row[node_label] = len(circRNA_row)
    else:
        miRNA_row[node_label] = len(miRNA_row)

valid_rows = []  # interaction_df index labels that received a feature vector
X = []

for index, miRNA_id, circRNA_id in zip(interaction_df.index, interaction_df['miRNA'], interaction_df['circRNA']):
    miRNA_pos = miRNA_row.get(miRNA_id)
    circRNA_pos = circRNA_row.get(circRNA_id)
    # Skip pairs whose RNAs are unknown or of the wrong type...
    if miRNA_pos is None or circRNA_pos is None:
        continue
    # ...or whose rows were cut off when features/embeddings were truncated.
    if miRNA_pos >= X_miRNA.shape[0] or circRNA_pos >= X_circRNA.shape[0]:
        continue
    X.append(np.hstack([X_miRNA[miRNA_pos], X_circRNA[circRNA_pos]]))
    valid_rows.append(index)

if not X:
    raise ValueError("No circRNA-miRNA pair could be matched to feature rows; check RNA ID alignment.")

# Convert the list of feature vectors to numpy array
X = np.array(X)

# Labels for exactly the rows kept above, in the same order as X.
y = interaction_df.loc[valid_rows, 'interaction'].to_numpy()

# Normalize the features.
# NOTE(review): the scaler is fit on every pair, including pairs that later land
# in CV test folds; fit per fold on training rows if strict separation is needed.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Define GAT model
class GAT(torch.nn.Module):
    """Two-layer graph attention network returning per-node log-probabilities.

    Layer 1 is a multi-head GATConv whose head outputs are concatenated, giving
    ``hidden_channels * heads`` features; layer 2 is a single-head GATConv that
    maps those features to ``out_channels`` scores, which go through
    ``log_softmax``.

    Args:
        in_channels: node feature width of ``data.x``.
        out_channels: number of output classes.
        hidden_channels: per-head width of the first layer.
        heads: number of attention heads in the first layer.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=64, heads=2):
        super().__init__()
        concat_width = hidden_channels * heads  # conv1 concatenates its heads
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(concat_width, out_channels, heads=1)

    def forward(self, data):
        """Run both attention layers over ``data.x`` / ``data.edge_index``."""
        edges = data.edge_index
        hidden = F.elu(self.conv1(data.x, edges))
        # Dropout between layers; only active in training mode.
        hidden = F.dropout(hidden, p=0.6, training=self.training)
        scores = self.conv2(hidden, edges)
        return F.log_softmax(scores, dim=1)

# Define training function
def train(model, optimizer, data):
    """Perform one full-batch optimisation step on the training nodes.

    The model is switched to training mode, the NLL loss is computed on the
    nodes selected by ``data.train_mask`` (the model is expected to output
    log-probabilities), and one optimiser step is taken.

    Returns:
        The loss value (Python float) computed before the parameter update.
    """
    model.train()
    optimizer.zero_grad()
    mask = data.train_mask
    log_probs = model(data)
    batch_loss = F.nll_loss(log_probs[mask], data.y[mask])
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

# Define evaluation function
def evaluate(model, data):
    """Evaluate ``model`` on the nodes selected by ``data.test_mask``.

    Returns a 9-tuple ``(accuracy, precision, recall, f1, roc_auc, aupr,
    y_true, y_pred, y_probs)`` where ``y_probs`` is the predicted probability of
    class 1.  The model's output is treated as log-probabilities (the GAT in
    this file ends with ``log_softmax``), so it is exponentiated to obtain real
    probabilities; ROC-AUC and AUPR only depend on the ranking of scores and are
    unaffected by that conversion.

    Raises:
        ValueError: from ``roc_auc_score`` if the test nodes contain only one
            class (ROC-AUC is undefined then).
    """
    model.eval()
    with torch.no_grad():  # inference only; don't build an autograd graph
        test_log_probs = model(data)[data.test_mask]

    y_true = data.y[data.test_mask].cpu().numpy()
    y_pred = test_log_probs.argmax(dim=1).cpu().numpy()
    y_probs = test_log_probs[:, 1].exp().cpu().numpy()  # probabilities for class 1

    # zero_division=0 yields the same 0.0 as the default for degenerate folds
    # (no predicted / actual positives) without raising UndefinedMetricWarning.
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    roc_auc = roc_auc_score(y_true, y_probs)

    # AUPR as the trapezoidal area under the PR curve (same formula as the plots).
    # NOTE(review): sklearn's average_precision_score is the non-interpolated
    # alternative; trapezoidal interpolation of PR curves can be optimistic.
    precision_curve, recall_curve, _ = precision_recall_curve(y_true, y_probs)
    aupr = auc(recall_curve, precision_curve)

    return accuracy, precision, recall, f1, roc_auc, aupr, y_true, y_pred, y_probs

# Initialize dictionaries to store FPR, TPR, Precision, and Recall for each fold
fpr_dict = {}
tpr_dict = {}
precision_dict = {}
recall_dict = {}

# K-Fold Cross-Validation Setup
kf = KFold(n_splits=5, shuffle=True, random_state=42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

fold_metrics = []


def _shared_rna_edge_index(*label_columns):
    """Build a ``(2, E)`` long-tensor edge_index linking rows that share a label.

    Each argument is a sequence holding one label per row (e.g. the miRNA ID of
    each pair).  Any two distinct rows with the same label in a column are
    connected in both directions; duplicate edges are removed and no self-loops
    are added (GATConv adds its own by default).
    """
    src_parts, dst_parts = [], []
    for labels in label_columns:
        groups = {}
        for row_pos, label in enumerate(labels):
            groups.setdefault(label, []).append(row_pos)
        for members in groups.values():
            if len(members) < 2:
                continue
            members = np.asarray(members, dtype=np.int64)
            src, dst = np.meshgrid(members, members, indexing='ij')
            off_diagonal = src != dst
            src_parts.append(src[off_diagonal])
            dst_parts.append(dst[off_diagonal])
    if not src_parts:
        return torch.empty((2, 0), dtype=torch.long)
    edges = np.unique(np.vstack([np.concatenate(src_parts), np.concatenate(dst_parts)]), axis=1)
    return torch.as_tensor(edges, dtype=torch.long)


# The GAT classifies *pairs*: graph node i is row i of X_scaled / y, i.e. the
# pair interaction_df.loc[valid_rows[i]].  G's node ids index individual RNAs,
# not pairs, so G.edges cannot serve as edge_index here.  Instead, connect pairs
# that share a miRNA or a circRNA (edges in both directions).
# NOTE(review): hub RNAs produce dense cliques (k pairs -> k*(k-1) edges); if
# memory becomes a problem, subsample neighbours per RNA.
pair_rows = interaction_df.loc[valid_rows]
edge_index = _shared_rna_edge_index(pair_rows['miRNA'].tolist(), pair_rows['circRNA'].tolist())
x = torch.tensor(X_scaled, dtype=torch.float)
y_tensor = torch.tensor(y, dtype=torch.long)

for fold, (train_index, test_index) in enumerate(kf.split(X_scaled)):
    print(f'Fold {fold + 1}')

    # Boolean masks selecting this fold's train / test pair-nodes.
    train_mask = torch.zeros(len(y), dtype=torch.bool)
    test_mask = torch.zeros(len(y), dtype=torch.bool)
    train_mask[train_index] = True
    test_mask[test_index] = True

    # Transductive setup: every pair-node (with its features) is in the graph;
    # only the loss and the metrics are restricted by the masks.
    data = Data(x=x, edge_index=edge_index, y=y_tensor, train_mask=train_mask, test_mask=test_mask).to(device)

    # Fresh, seeded model per fold so folds are independent and reproducible.
    torch.manual_seed(42 + fold)
    model = GAT(in_channels=x.shape[1], out_channels=2).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

    # Train model (full-batch, fixed number of epochs)
    for _ in range(200):
        loss = train(model, optimizer, data)

    # Evaluate model on the test set
    accuracy, precision, recall, f1, roc_auc, aupr, y_true, y_pred, y_probs = evaluate(model, data)
    print(f'Fold {fold + 1} -- Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, AUROC: {roc_auc:.4f}, AUPR: {aupr:.4f}')

    # Store metrics for each fold
    fold_metrics.append({
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'roc_auc': roc_auc,
        'aupr': aupr
    })

    # ROC Curve
    fpr, tpr, _ = roc_curve(y_true, y_probs)
    fpr_dict[fold] = fpr
    tpr_dict[fold] = tpr

    # Precision-Recall Curve
    precision_curve, recall_curve, _ = precision_recall_curve(y_true, y_probs)
    precision_dict[fold] = precision_curve
    recall_dict[fold] = recall_curve

# Plot all ROC curves on one figure -- one curve per fold that actually ran
# (derived from the collected results rather than a hard-coded fold count).
plt.figure()
for fold in sorted(fpr_dict):
    fold_fpr, fold_tpr = fpr_dict[fold], tpr_dict[fold]
    plt.plot(fold_fpr, fold_tpr, lw=2, label=f'Fold {fold + 1} ROC curve (AUC = {auc(fold_fpr, fold_tpr):.4f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance line
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves for Each Fold')
plt.legend(loc="lower right")
plt.show()

# Plot all PR curves on one figure
plt.figure()
for fold in sorted(recall_dict):
    fold_recall, fold_precision = recall_dict[fold], precision_dict[fold]
    plt.plot(fold_recall, fold_precision, lw=2, label=f'Fold {fold + 1} PR curve (AUPR = {auc(fold_recall, fold_precision):.4f})')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curves for Each Fold')
plt.legend(loc="lower left")
plt.show()

# Aggregate metrics across folds (np.std is the population std, ddof=0).
n_folds = len(fold_metrics)
metric_names = list(fold_metrics[0].keys())
avg_metrics = {metric: np.mean([f[metric] for f in fold_metrics]) for metric in metric_names}
std_metrics = {metric: np.std([f[metric] for f in fold_metrics]) for metric in metric_names}

print(f'Average Metrics across {n_folds} folds:')
print(f'Accuracy: {avg_metrics["accuracy"]:.4f}, Precision: {avg_metrics["precision"]:.4f}, Recall: {avg_metrics["recall"]:.4f}, F1: {avg_metrics["f1"]:.4f}, AUROC: {avg_metrics["roc_auc"]:.4f}, AUPR: {avg_metrics["aupr"]:.4f}')
print(f'Standard Deviation across {n_folds} folds:')
print(f'Accuracy: {std_metrics["accuracy"]:.4f}, Precision: {std_metrics["precision"]:.4f}, Recall: {std_metrics["recall"]:.4f}, F1: {std_metrics["f1"]:.4f}, AUROC: {std_metrics["roc_auc"]:.4f}, AUPR: {std_metrics["aupr"]:.4f}')
