# In [None]:
import pandas as pd
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from sklearn.preprocessing import StandardScaler
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from sklearn.metrics import average_precision_score, roc_auc_score
import numpy as np
import random
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit


# Dataset code substituted into every input path below.
# The paths must be f-strings: without the `f` prefix the literal text "{c}" ends up
# in the file name and the wrong (or no) file is read.
c = 'BRCA'
edges_df = pd.read_csv(f'/content/drive/My Drive/networks/ppi_{c}_plus_knn3_bidirectional.csv')
features_df = pd.read_csv(f'/content/drive/My Drive/features/features_for_{c}.csv', index_col=0)
# NOTE(review): the 'lables' directory spelling is kept as-is; presumably it matches the
# folder name on Drive — verify before renaming.
labels_df = pd.read_csv(f'/content/drive/My Drive/lables/{c}_labels(0_1).csv')


# --- Node index --------------------------------------------------------------
# Drop rows with a missing endpoint BEFORE collecting gene ids. Otherwise a float
# NaN lands in the set next to string ids, `sorted` raises TypeError, and a bogus
# "nan" node would be created.
edges = edges_df[['gene1', 'gene2']].dropna()
genes_from_edges = set(edges['gene1']).union(edges['gene2'])
genes_from_features = set(features_df.index)  # kept for later cells; not used below
all_genes = sorted(genes_from_edges)

node_to_idx = {gene: i for i, gene in enumerate(all_genes)}
idx_to_node = {i: gene for gene, i in node_to_idx.items()}

# --- Edge index (shape (2, num_edges)) ---------------------------------------------
# Every endpoint of `edges` is in node_to_idx by construction, so no filter is needed.
edge_pairs = [[node_to_idx[a], node_to_idx[b]] for a, b in edges.itertuples(index=False, name=None)]
if edge_pairs:
    edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
else:
    # Keep the (2, 0) shape for an empty edge list instead of a 1-D tensor.
    edge_index = torch.empty((2, 0), dtype=torch.long)

# --- Node features -----------------------------------------------------------
feature_dim = features_df.shape[1]
x_matrix = np.zeros((len(all_genes), feature_dim))
has_feature = np.zeros(len(all_genes), dtype=bool)

# Standardise every feature column, fitting on all genes in the feature table
# (including genes that are not graph nodes).
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features_df.values)
features_scaled_df = pd.DataFrame(features_scaled, index=features_df.index)

# Copy scaled rows onto graph nodes. Genes absent from the graph are ignored.
for gene, row in zip(features_scaled_df.index, features_scaled):
    idx = node_to_idx.get(gene)
    if idx is not None:
        x_matrix[idx] = row
        has_feature[idx] = True

# --- Neighbour-mean imputation for featureless nodes ---------------------------
neighbors_dict = {i: [] for i in range(len(all_genes))}
for src, dst in edge_index.t().tolist():
    neighbors_dict[src].append(dst)
    neighbors_dict[dst].append(src)

# Only nodes that originally had features contribute (has_feature is not updated
# here), so the result does not depend on iteration order. Neighbours are
# de-duplicated: a bidirectional edge list registers each neighbour more than once,
# and each should count once in the mean. Featureless nodes with no featured
# neighbour stay all-zero.
for i in range(len(all_genes)):
    if has_feature[i]:
        continue
    featured_nbrs = sorted({n for n in neighbors_dict[i] if has_feature[n]})
    if featured_nbrs:
        x_matrix[i] = x_matrix[featured_nbrs].mean(axis=0)

x = torch.tensor(x_matrix, dtype=torch.float)

# --- Labels ------------------------------------------------------------------
# -1 marks an unlabelled node; train_model_5fold excludes those from every split.
labels_map = dict(zip(labels_df['Gene'], labels_df['Labels']))
y = torch.full((x.size(0),), -1, dtype=torch.long)

for gene, label in labels_map.items():
    idx = node_to_idx.get(gene)
    if idx is not None:
        y[idx] = int(label)

data = Data(x=x, edge_index=edge_index, y=y)

class GCN1(nn.Module):
    """Two-layer GCN that produces one raw logit per node.

    ``conv1`` maps ``in_channels`` -> ``hidden_channels`` and is followed by ReLU
    and dropout. ``conv2`` maps ``hidden_channels`` -> 1. The trailing channel
    dimension is squeezed away, so the output suits ``BCEWithLogitsLoss``.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Width of the hidden GCN layer.
        dropout: Dropout probability after the first layer (active only in training mode).
    """

    def __init__(self, in_channels, hidden_channels, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, 1)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return node logits with the size-1 output channel removed."""
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        # Squeeze only the trailing channel dim. A bare .squeeze() would turn a
        # single-node output into a 0-d tensor and break boolean-mask indexing.
        return self.conv2(x, edge_index).squeeze(-1)

def train(model, data, mask, optimizer, loss_fn):
    """Perform one full-batch optimisation step and return the training loss.

    The model runs on the whole graph (``data.x``, ``data.edge_index``). Only
    the nodes selected by ``mask`` contribute to the loss.

    Returns:
        The loss as a Python float, measured before the parameter update.
    """
    model.train()
    optimizer.zero_grad()
    masked_logits = model(data.x, data.edge_index)[mask]
    targets = data.y[mask].float()  # BCEWithLogitsLoss needs float targets
    batch_loss = loss_fn(masked_logits, targets)
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

def evaluate(model, data, mask, loss_fn):
    """Score the model on the nodes selected by ``mask`` without updating it.

    Runs in eval mode under ``torch.no_grad()``.

    Returns:
        ``(accuracy, auprc, loss)``:
        - accuracy: labels vs. sigmoid probabilities thresholded at 0.5;
        - auprc: sklearn average precision on those probabilities;
        - loss: ``loss_fn`` on the masked logits, as a float.
    """
    model.eval()
    with torch.no_grad():
        scores = model(data.x, data.edge_index)[mask]
        truth = data.y[mask].float()
        loss_value = loss_fn(scores, truth).item()

        probabilities = torch.sigmoid(scores).cpu()
        truth_cpu = truth.cpu()
        hard_preds = (probabilities > 0.5).long()
        return (
            accuracy_score(truth_cpu, hard_preds),
            average_precision_score(truth_cpu, probabilities),
            loss_value,
        )

def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs; do nothing when ``seed`` is None."""
    if seed is None:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _fold_masks(n_total, node_groups, device):
    """Return one boolean mask of length ``n_total`` per index group in ``node_groups``."""
    masks = []
    for nodes in node_groups:
        mask = torch.zeros(n_total, dtype=torch.bool, device=device)
        mask[nodes] = True
        masks.append(mask)
    return masks


def _pos_weight(train_labels, device):
    """Return ``#negatives / #positives`` as a scalar float32 tensor for ``pos_weight``.

    Falls back to 1.0 (no re-weighting) when the split has no positive labels.
    Otherwise the division would give inf/nan and a nan loss.
    """
    n_pos = int((train_labels == 1).sum())
    n_neg = int((train_labels == 0).sum())
    value = n_neg / n_pos if n_pos > 0 else 1.0
    return torch.tensor(value, dtype=torch.float32, device=device)


def _snapshot(model):
    """Return a detached copy of ``model.state_dict()``.

    ``state_dict()`` holds references to the live parameter tensors, which later
    optimizer steps update in place. Cloning freezes the values at this epoch.
    """
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def train_model_5fold(data, hidden_channels=64,
                      epochs=200, patience=30, min_delta=1e-4,
                      n_splits=5, val_size=0.20, lr=0.01, weight_decay=5e-4,
                      seed=None):
    """Stratified k-fold evaluation of ``GCN1`` on the labelled nodes of ``data``.

    Nodes with label -1 are treated as unlabelled and never used for training
    or scoring; they still take part in message passing.

    Each fold:
    - takes the fold's test nodes;
    - holds out a stratified ``val_size`` fraction of the remaining labelled nodes
      for early stopping on validation AUPRC;
    - trains a fresh model with class-balanced ``BCEWithLogitsLoss``;
    - restores the best-validation weights and scores on the test nodes.

    Args:
        data: Graph with ``x``, ``edge_index`` and ``y``; ``y`` holds 0/1 labels or -1.
        hidden_channels: Hidden width of ``GCN1``.
        epochs: Maximum training epochs per fold.
        patience: Epochs without improvement in validation AUPRC before stopping.
        min_delta: Minimum AUPRC gain that counts as an improvement.
        n_splits: Number of outer folds (default 5).
        val_size: Validation fraction of each fold's non-test labelled nodes.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        seed: If given, seeds Python/NumPy/PyTorch RNGs for reproducible runs.

    Returns:
        A list of ``(test_accuracy, test_auprc)`` tuples, one per fold.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = data.to(device)
    _seed_everything(seed)

    labeled_idx = torch.where(data.y != -1)[0]
    labeled_y = data.y[labeled_idx].cpu().numpy()

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=46)
    results = []

    for fold, (train_val_pos, test_pos) in enumerate(skf.split(np.arange(len(labeled_y)), labeled_y)):
        print(f"\n📂 Fold {fold+1}/{n_splits}")

        # The splitters yield positions within labeled_idx; map them back to node ids.
        train_val_idx = labeled_idx[train_val_pos]
        test_nodes = labeled_idx[test_pos]

        sss = StratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=44)
        train_pos, val_pos = next(
            sss.split(np.zeros(len(train_val_idx)), data.y[train_val_idx].cpu().numpy())
        )
        train_nodes = train_val_idx[train_pos]
        val_nodes = train_val_idx[val_pos]

        train_mask, val_mask, test_mask = _fold_masks(
            data.num_nodes, (train_nodes, val_nodes, test_nodes), device
        )

        loss_fn = nn.BCEWithLogitsLoss(pos_weight=_pos_weight(data.y[train_nodes], device))

        model = GCN1(data.num_node_features, hidden_channels).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        best_val = 0.0
        best_state = None
        wait = 0

        for epoch in range(1, epochs + 1):
            train(model, data, train_mask, optimizer, loss_fn)
            train_acc, train_auprc, _ = evaluate(model, data, train_mask, loss_fn)
            val_acc, val_auprc, vls = evaluate(model, data, val_mask, loss_fn)

            print(f"Ep{epoch:03d} | "
                  f"Train Acc {train_acc:.4f}, AUPRC {train_auprc:.4f} | "
                  f"Val Acc {val_acc:.4f}, AUPRC {val_auprc:.4f}, Loss {vls:.4f}")

            if val_auprc - best_val > min_delta:
                best_val = val_auprc
                # Clone: a plain state_dict() would keep tracking the live weights,
                # so the restored "best" model would really be the last epoch's.
                best_state = _snapshot(model)
                wait = 0
            else:
                wait += 1
                if wait >= patience:
                    print(f"⏹️  Stop early at epoch {epoch} (no impro ≥{patience})")
                    break

        if best_state is not None:
            model.load_state_dict(best_state)

        test_acc, test_auprc, _ = evaluate(model, data, test_mask, loss_fn)
        print(f"✅ Test Accuracy: {test_acc:.4f} | AUPRC: {test_auprc:.4f}")
        results.append((test_acc, test_auprc))

    if results:
        accs, auprcs = map(np.array, zip(*results))
        print(f"\n📊 {n_splits} fold:")
        print(f"  Accuracy: {accs.mean():.4f} ± {accs.std():.4f}")
        print(f"  AUPRC:    {auprcs.mean():.4f} ± {auprcs.std():.4f}")
    return results
# Run cross-validation with default hyper-parameters on the graph built above.
train_model_5fold(data=data)