This notebook is used to train and evaluate OncoPlex on the cancer-specific datasets:

- Load the previously preprocessed data
- Create a hypergraph for each cancer individually
- Define the model class
- Train and evaluate the model on each cancer individually

In [None]:
import pandas as pd
import numpy as np
import pickle
import random
import os 
import math 

import torch
import torch_geometric
import torch.nn as nn
from torch_geometric.nn import GCNConv, LayerNorm, HypergraphConv
from torch.nn import Dropout, Parameter

import torch.nn.functional as F
import torch.optim as optim
from torch.nn.modules.module import Module


from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, f1_score
from sklearn.model_selection import StratifiedKFold, KFold, ParameterGrid

from scipy.sparse import coo_matrix

from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.model_selection import train_test_split

import warnings
warnings.filterwarnings('ignore')

# Load data

In [None]:
# Load the gene x pathway incidence matrix (stored as COO components in an .npz)
# into a sparse DataFrame indexed by gene, with one column per pathway (hyperedge).
loaded_coo = np.load("/content/incidence_matrix__FI_coo13560.npz")
incidence_matrix_coo = coo_matrix(
    (loaded_coo["data"], (loaded_coo["row"], loaded_coo["col"])), shape=loaded_coo["shape"]
)
gene_names = loaded_coo["genes"]
pathway_names = loaded_coo["pathways"]

incidence_df = pd.DataFrame.sparse.from_spmatrix(incidence_matrix_coo, index=gene_names, columns=pathway_names)
# Gene order of the hypergraph; later cells index genes by position in this list.
genelist = incidence_df.index.tolist()
incidence_df

Unnamed: 0,BIOCARTA_41BB_PATHWAY,BIOCARTA_ACE2_PATHWAY,BIOCARTA_ACETAMINOPHEN_PATHWAY,BIOCARTA_ACH_PATHWAY,BIOCARTA_ACTINY_PATHWAY,BIOCARTA_AGPCR_PATHWAY,BIOCARTA_AGR_PATHWAY,BIOCARTA_AHSP_PATHWAY,BIOCARTA_AKAP13_PATHWAY,BIOCARTA_AKAP95_PATHWAY,...,WP_VITAMIN_D_RECEPTOR_PATHWAY,WP_WARBURG_EFFECT_MODULATED_BY_DEUBIQUITINATING_ENZYMES_AND_THEIR_SUBSTRATES,WP_WHITE_FAT_CELL_DIFFERENTIATION,WP_WNTBETACATENIN_SIGNALING_INHIBITORS_IN_CURRENT_AND_PAST_CLINICAL_TRIALS,WP_WNTBETACATENIN_SIGNALING_IN_LEUKEMIA,WP_WNT_SIGNALING_AND_PLURIPOTENCY,WP_WNT_SIGNALING_IN_KIDNEY_DISEASE,WP_WNT_SIGNALING_WP363,WP_WNT_SIGNALING_WP428,WP_ZINC_HOMEOSTASIS
A1BG,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
A1CF,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
A2M,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
A3GALT2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
A4GALT,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZWILCH,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ZWINT,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ZYG11B,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ZYX,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


### Read the multi-omics features files

In [None]:
def process_features(file_paths, incidence=None):
    """
    Load the preprocessed omics matrices, align them to the hypergraph's genes,
    min-max normalise them and concatenate them into one feature matrix.

    Parameters:
        file_paths (dict): dataset name -> path of a tab-separated matrix whose
            first column holds gene names and whose remaining columns are
            cancer types (e.g. {'expression': ..., 'meth': ..., 'mutation': ...}).
        incidence (pd.DataFrame, optional): gene x pathway incidence matrix whose
            index defines the gene (node) order. Defaults to the notebook-level
            ``incidence_df`` built by the loading cell.

    Returns:
        pd.DataFrame: genes x features with columns named '<DATASET>_<CANCER>'
        (e.g. 'EXPRESSION_BRCA'). Only cancer types present in every dataset are
        kept; genes missing from a dataset are filled with 0 before scaling.
    """
    if incidence is None:
        incidence = incidence_df  # notebook-level incidence matrix (backward-compatible default)

    def read_and_process(file_path):
        # First column -> 'Name' (gene index); cancer-type headers upper-cased.
        df = pd.read_csv(file_path, sep='\t')
        df.columns = ['Name'] + [col.upper() for col in df.columns[1:]]
        return df.set_index('Name')

    datasets = {name: read_and_process(path) for name, path in file_paths.items()}
    ppi_index = incidence.index.tolist()

    # Keep only cancer types shared by all datasets. Sorted so the column order
    # does not depend on set iteration order (string hashing is salted per process).
    common_ctypes = sorted(set.intersection(*(set(df.columns) for df in datasets.values())))
    datasets = {name: df[common_ctypes] for name, df in datasets.items()}

    # Report how many genes of each dataset are nodes of the hypergraph.
    display_names = {'mutation': 'mutation', 'expression': 'expression', 'meth': 'methylation'}
    for name, df in datasets.items():
        n_in_network = int(df.index.isin(ppi_index).sum())
        print(f'Number of genes in {display_names.get(name, name)} matrix: {n_in_network}')

    # Align every dataset to the hypergraph's gene order (missing genes -> 0).
    reindexed_datasets = {name: df.reindex(ppi_index, fill_value=0) for name, df in datasets.items()}

    # Min-max scale each column to [0, 1] after taking absolute values, so the
    # sign of every value is discarded before scaling.
    normalized_datasets = {
        name: pd.DataFrame(
            MinMaxScaler().fit_transform(np.abs(df)),
            index=df.index,
            columns=[f"{name.upper()}_{col}" for col in df.columns],
        )
        for name, df in reindexed_datasets.items()
    }

    # Combine datasets into a single feature matrix (genes x all dataset/cancer columns).
    return pd.concat(normalized_datasets.values(), axis=1)


# Preprocessed TCGA omics matrices (genes x cancer types), one file per omics layer.
file_paths = dict(
    expression="/content/drive/MyDrive/DTGNN/Cancer data/TCGA/processed/cancer_gene_expression_matrix.tsv",
    meth="/content/drive/MyDrive/DTGNN/Cancer data/TCGA/processed/meth_mean_logfold.csv",
    mutation="/content/drive/MyDrive/DTGNN/Cancer data/TCGA/processed/mutation_genecancer.tsv",
)

# Genes x '<OMIC>_<CANCER>' feature matrix aligned to the hypergraph's genes.
biological_features = process_features(file_paths=file_paths)
biological_features.head()

Number of genes in mutation matrix: 12648
Number of genes in expression matrix: 13314
Number of genes in methylation matrix: 11671


Unnamed: 0_level_0,EXPRESSION_THCA,EXPRESSION_STAD,EXPRESSION_BRCA,EXPRESSION_LUAD,EXPRESSION_LUSC,EXPRESSION_UCEC,EXPRESSION_PRAD,EXPRESSION_CESC,EXPRESSION_READ,EXPRESSION_COAD,...,MUTATION_PRAD,MUTATION_CESC,MUTATION_READ,MUTATION_COAD,MUTATION_KIRC,MUTATION_BLCA,MUTATION_KIRP,MUTATION_ESCA,MUTATION_LIHC,MUTATION_HNSC
Name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
A1BG,0.064627,0.05059,0.042643,0.035614,0.013315,0.026581,0.069536,0.071728,0.022431,0.025391,...,0.016757,0.030481,0.011118,0.025176,0.005916,0.036646,0.021801,0.0,0.028092,0.019984
A1CF,0.103195,0.075999,0.013237,0.378315,0.175999,0.094288,0.282134,0.195314,0.076384,0.121657,...,0.108539,0.044913,0.007343,0.015552,0.01757,0.017955,0.0,0.0,0.035585,0.021146
A2M,0.06892,0.002152,0.129277,0.196795,0.338599,0.178535,0.037951,0.133509,0.100596,0.131797,...,0.046368,0.047025,0.052003,0.057855,0.029411,0.064517,0.065283,0.023726,0.037366,0.014702
A3GALT2,0.082871,0.008796,0.037956,0.003012,0.01693,0.012469,0.029517,0.176873,0.042488,0.006947,...,0.0,0.020544,0.0,0.012161,0.0,0.004654,0.0,0.005985,0.0,0.0
A4GALT,0.055952,0.086685,0.022587,0.001923,0.142636,0.044437,0.087904,0.006782,0.055393,0.046637,...,0.033047,0.020172,0.0,0.011678,0.0,0.004634,0.021722,0.005925,0.00935,0.002511


### Read the driver labels for each cancer

In [None]:
def get_labels(genes, cancer_type, data_dir='/content'):
    """Build driver / non-driver labels for one cancer type.

    Reads ``{data_dir}/Positive_{cancer_type}_driver.txt`` and
    ``{data_dir}/Negative_nondriver.csv`` (one gene per line, no header) and
    keeps only genes that appear in ``genes``.

    Parameters:
        genes (sequence of str): node (gene) names; defines the order of the
            returned full label vector. Assumed to be unique.
        cancer_type (str): TCGA code used in the driver file name, e.g. 'BRCA'.
        data_dir (str): directory holding the label files (default '/content').

    Returns:
        tuple:
            sample_indices (np.ndarray): positions in ``genes`` of the labelled
                genes, drivers first, then non-drivers.
            sample_labels (np.ndarray): 1 for drivers, 0 for non-drivers, aligned
                with ``sample_indices``.
            full_labels (np.ndarray): one entry per gene in ``genes``:
                1 driver, 0 non-driver, -1 unlabelled.
            driver_genes (list[str]): sorted driver genes present in ``genes``.
            nondriver_genes (list[str]): sorted non-driver genes present in
                ``genes`` (excluding any gene already labelled as a driver).
    """
    gene_set = set(genes)

    # Cancer-specific driver genes.
    driver_path = os.path.join(data_dir, f'Positive_{cancer_type}_driver.txt')
    driver_df = pd.read_csv(driver_path, sep='\t', header=None, names=['gene'])
    driver_genes = sorted(gene_set & set(driver_df['gene']))
    print(f"[{cancer_type}] Driver genes count: {len(driver_genes)}")

    # Shared non-driver genes. A gene listed in both files would otherwise be
    # indexed twice with conflicting labels; the cancer-specific driver
    # annotation takes precedence.
    nondriver_path = os.path.join(data_dir, 'Negative_nondriver.csv')
    nondriver_df = pd.read_csv(nondriver_path, header=None, names=['gene'])
    nondriver_genes = sorted((gene_set & set(nondriver_df['gene'])) - set(driver_genes))
    print(f"[{cancer_type}] Nondriver genes count: {len(nondriver_genes)}")

    # -1 = unlabelled by default.
    labels = pd.DataFrame(data=[-1] * len(genes), index=genes)
    labels.loc[driver_genes, 0] = 1
    labels.loc[nondriver_genes, 0] = 0

    # Positions of the labelled (known) samples.
    driver_idx = labels.index.get_indexer(driver_genes)
    nondriver_idx = labels.index.get_indexer(nondriver_genes)

    sample_indices = np.concatenate([driver_idx, nondriver_idx])
    sample_labels = np.array([1] * len(driver_idx) + [0] * len(nondriver_idx))

    return sample_indices, sample_labels, labels.values.ravel(), driver_genes, nondriver_genes


# Per-cancer label vectors (1 driver, 0 non-driver, -1 unlabelled), keyed by TCGA code.
dataset_FI_HGNN = {}
cancer_types = ['BLCA', 'BRCA', 'LUAD', 'HNSC', 'ESCA', 'CESC', 'PRAD', 'STAD', 'LIHC']
for cancer_type in cancer_types:
    print(f"\nProcessing cancer type: {cancer_type}")
    (sample_indices, sample_labels, full_labels,
     driver_genes, nondriver_genes) = get_labels(gene_names, cancer_type)
    dataset_FI_HGNN[cancer_type] = dict(label=torch.from_numpy(full_labels))


Processing cancer type: BLCA
[BLCA] Driver genes count: 45
[BLCA] Nondriver genes count: 1116

Processing cancer type: BRCA
[BRCA] Driver genes count: 48
[BRCA] Nondriver genes count: 1116

Processing cancer type: LUAD
[LUAD] Driver genes count: 28
[LUAD] Nondriver genes count: 1116

Processing cancer type: HNSC
[HNSC] Driver genes count: 42
[HNSC] Nondriver genes count: 1116

Processing cancer type: ESCA
[ESCA] Driver genes count: 18
[ESCA] Nondriver genes count: 1116

Processing cancer type: CESC
[CESC] Driver genes count: 26
[CESC] Nondriver genes count: 1116

Processing cancer type: PRAD
[PRAD] Driver genes count: 36
[PRAD] Nondriver genes count: 1116

Processing cancer type: STAD
[STAD] Driver genes count: 37
[STAD] Nondriver genes count: 1116

Processing cancer type: LIHC
[LIHC] Driver genes count: 44
[LIHC] Nondriver genes count: 1116


### Create a weighted hypergraph for each cancer

In [None]:
def generate_G_from_H_weight(H, W):
    """
    Build the HGNN propagation matrix G from a weighted incidence matrix.

    G = Dv^{-1/2} H diag(W) De^{-1} H^T Dv^{-1/2}, where
    Dv[v] = sum_e H[v, e] * W[e] (weighted vertex degree) and
    De[e] = sum_v H[v, e] (hyperedge degree).
    Adapted from HGNN github repo: https://github.com/iMoonLab/HGNN

    The diagonal matrices are applied by broadcasting instead of being built as
    dense (n x n) / (m x m) arrays, which avoids O(n^2) extra memory and two
    O(n^2 m) matrix products.

    :param H: (n_nodes, n_edges) incidence matrix; entries may be fractional.
    :param W: (n_edges,) hyperedge weights.
    :return: G as an (n_nodes, n_nodes) ``np.matrix``. Rows/columns of vertices
        with non-positive degree are zero, and empty hyperedges (De == 0)
        contribute nothing, so G never contains inf/NaN from those cases.
    """
    H = np.asarray(H, dtype=float)
    W = np.asarray(W, dtype=float).ravel()

    # Vertex degree: weighted sum over hyperedges; hyperedge degree: sum over vertices.
    DV = np.sum(H * W, axis=1)
    DE = np.sum(H, axis=0)

    with np.errstate(divide='ignore', invalid='ignore'):
        # Isolated vertices get 0 instead of inf (the intent of the old nan_to_num step).
        dv_inv_sqrt = np.where(DV > 0, np.power(DV, -0.5), 0.0)
        # W / De per hyperedge; empty hyperedges contribute 0 instead of inf/NaN.
        edge_scale = np.where(DE != 0, W / DE, 0.0)

    left = H * dv_inv_sqrt[:, None]            # Dv^{-1/2} H
    G = (left * edge_scale[None, :]) @ left.T  # ... diag(W) De^{-1} H^T Dv^{-1/2}
    return np.asmatrix(G)

In [None]:
# Densify the incidence matrix once. The guard keeps this cell safe to re-run:
# the .sparse accessor raises on a frame whose columns are already dense.
if len(incidence_df.columns) and all(isinstance(dtype, pd.SparseDtype) for dtype in incidence_df.dtypes):
    incidence_df = incidence_df.sparse.to_dense()
def weighted_H(genes, incidence_df, cancer_type, min_drivers=2, seed=42):
    """Split one cancer's labelled genes and build its weighted hypergraph operator.

    The labelled genes from ``get_labels`` are split 60/20/20 into train /
    validation / test (stratified, random_state=42). Only the *training*
    drivers are used to choose and weight hyperedges, so validation and test
    labels do not shape G.

    Parameters:
        genes (list[str]): gene names; positions in this list are the returned
            indices, and it is expected to follow ``incidence_df.index`` order.
        incidence_df (pd.DataFrame): gene x pathway 0/1 incidence matrix.
        cancer_type (str): TCGA code forwarded to ``get_labels``.
        min_drivers (int): keep hyperedges containing at least this many
            training drivers (default 2, the value that gave the best results).
        seed (int or None): seed for the RNG that attaches isolated genes to a
            random hyperedge; ``None`` uses the global ``random`` state.

    Returns:
        (train_idx, val_idx, test_idx, train_label, val_label, test_label, G)
        where G is the propagation matrix from ``generate_G_from_H_weight``.

    Raises:
        ValueError: if no hyperedge contains ``min_drivers`` training drivers.
    """
    # Labelled genes (positions in `genes`) and their 0/1 labels.
    sampleIndex, label, _, _, _ = get_labels(genes, cancer_type)

    # Hold out 20% for test, then 25% of the remainder for validation.
    train_idx, test_idx, train_label, test_label = train_test_split(
        sampleIndex, label, test_size=0.2, random_state=42, stratify=label, shuffle=True
    )
    train_idx, val_idx, train_label, val_label = train_test_split(
        train_idx, train_label, test_size=0.25, random_state=42, stratify=train_label
    )

    print(f"\n[{cancer_type}] Number of training samples:", len(train_idx))
    print(f"[{cancer_type}] Number of test samples:", len(test_idx))
    print(f"[{cancer_type}] Number of validation samples:", len(val_idx))

    # Training drivers. get_labels returns the full label vector as a flat
    # ndarray (no .iloc), so take the drivers directly from the training split.
    positive_train = [genes[i] for i in train_idx[train_label == 1]]

    print(f"[{cancer_type}] Number of positive genes in training:", len(positive_train))

    # Number of training drivers in each hyperedge (pathway).
    positiveMatrix = incidence_df.loc[positive_train].sum()

    # Keep hyperedges containing at least `min_drivers` training drivers.
    selHyperedgeIndex = np.where(positiveMatrix.to_numpy() >= min_drivers)[0]

    print(f"[{cancer_type}] Number of selected hyperedges:", len(selHyperedgeIndex))
    if len(selHyperedgeIndex) == 0:
        raise ValueError(
            f"[{cancer_type}] no hyperedge contains >= {min_drivers} training drivers; cannot build G"
        )

    H = incidence_df.iloc[:, selHyperedgeIndex].to_numpy(dtype=float)
    # Weight = share of the hyperedge's genes that are training drivers.
    # .iloc: positional selection (integer [] on a label-indexed Series is deprecated).
    hyperedgeWeight = positiveMatrix.iloc[selHyperedgeIndex].to_numpy(dtype=float) / H.sum(axis=0)

    # Give each isolated gene (zero weighted degree) a tiny incidence in one
    # randomly chosen selected hyperedge so it still has a non-zero degree.
    rng = random if seed is None else random.Random(seed)
    DV = np.sum(H * hyperedgeWeight, axis=1)
    for i in np.flatnonzero(DV == 0):
        H[i, rng.randint(0, H.shape[1] - 1)] = 0.0001

    G = generate_G_from_H_weight(H, hyperedgeWeight)

    return train_idx, val_idx, test_idx, train_label, val_label, test_label, G


cancer_types = ['BLCA', 'BRCA', 'LUAD', 'HNSC', 'ESCA', 'CESC', 'PRAD', 'STAD', 'LIHC']
for cancer_type in cancer_types:
    print(f"\nProcessing: {cancer_type}")

    (train_idx, val_idx, test_idx,
     train_label, val_label, test_label, G) = weighted_H(genelist, incidence_df, cancer_type)

    # Store the split next to the labels. 'edge_index' holds the dense
    # propagation matrix G (not a PyG-style edge list).
    dataset_FI_HGNN[cancer_type].update(
        train_idx=train_idx,
        val_idx=val_idx,
        test_idx=test_idx,
        edge_index=G,
    )


Processing: BLCA
[BLCA] Driver genes count: 45
[BLCA] Nondriver genes count: 1116

[BLCA] Number of training samples: 696
[BLCA] Number of test samples: 233
[BLCA] Number of validation samples: 232
[BLCA] Number of positive genes in training: 27
[BLCA] Number of selected hyperedges: 335

Processing: BRCA
[BRCA] Driver genes count: 48
[BRCA] Nondriver genes count: 1116

[BRCA] Number of training samples: 698
[BRCA] Number of test samples: 233
[BRCA] Number of validation samples: 233
[BRCA] Number of positive genes in training: 28
[BRCA] Number of selected hyperedges: 540

Processing: LUAD
[LUAD] Driver genes count: 28
[LUAD] Nondriver genes count: 1116

[LUAD] Number of training samples: 686
[LUAD] Number of test samples: 229
[LUAD] Number of validation samples: 229
[LUAD] Number of positive genes in training: 16
[LUAD] Number of selected hyperedges: 363

Processing: HNSC
[HNSC] Driver genes count: 42
[HNSC] Nondriver genes count: 1116

[HNSC] Number of training samples: 694
[HNSC] Num

### Assign core features to each cancer

In [None]:
tcga = [
    'BLCA', 'BRCA', 'LUAD', 'HNSC', 'ESCA', 'CESC', 'PRAD', 'STAD', 'LIHC'
]

# Keep the dict built by the label / hypergraph cells (labels, splits, G) and
# only start a fresh one if those cells were not run. The checked name must be
# the dict's own name, otherwise the dict is always reset and those entries lost.
if 'dataset_FI_HGNN' not in globals():
    dataset_FI_HGNN = {}

for c in tcga:
    # Columns are '<OMIC>_<CANCER>'; match the cancer code exactly as a suffix
    # rather than as a substring anywhere in the column name.
    matching_columns = [col for col in biological_features.columns if col.endswith(f"_{c}")]
    cancer_features = biological_features[matching_columns]
    nodes = cancer_features.index.tolist()

    print(f"Features for {c}:\n", cancer_features.head())
    print(f"Number of genes in {c}: {len(nodes)}")

    # Save features and node order in the per-cancer data dict.
    entry = dataset_FI_HGNN.setdefault(c, {})
    entry['features'] = torch.tensor(cancer_features.values, dtype=torch.float32)
    entry['nodes'] = nodes


Features for BLCA:
          EXPRESSION_BLCA  METH_BLCA  MUTATION_BLCA
Name                                              
A1BG            0.076625   0.212929       0.036646
A1CF            0.343028   0.000000       0.017955
A2M             0.272502   0.013197       0.064517
A3GALT2         0.068790   0.000000       0.004654
A4GALT          0.080663   0.000000       0.004634
Number of genes in BLCA: 13560
Features for BRCA:
          EXPRESSION_BRCA  METH_BRCA  MUTATION_BRCA
Name                                              
A1BG            0.042643   0.007194       0.008502
A1CF            0.013237   0.000000       0.024429
A2M             0.129277   0.182592       0.027949
A3GALT2         0.037956   0.000000       0.000000
A4GALT          0.022587   0.000000       0.005652
Number of genes in BRCA: 13560
Features for LUAD:
          EXPRESSION_LUAD  METH_LUAD  MUTATION_LUAD
Name                                              
A1BG            0.035614   0.010822       0.025328
A1CF       

In [None]:
len(dataset_FI_HGNN['BLCA']['nodes'])  # number of genes (nodes) stored for BLCA

13560

In [None]:
# Persist one pickle per cancer so training can start from here.
output_dir = '/content/specific_cancers'
os.makedirs(output_dir, exist_ok=True)

for cancer_type in dataset_FI_HGNN:
    data = dataset_FI_HGNN[cancer_type]
    filename = os.path.join(output_dir, f"OncoPlex_{cancer_type}.pkl")
    with open(filename, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    print("Saved: " + filename)

# Model Class

In [None]:
class conv_layer(nn.Module):
    """Hypergraph convolution computing ``G @ (x @ weight + bias)``.

    Args:
        in_ft: input feature width.
        out_ft: output feature width.
        bias: if True, add a learnable bias before propagation through G.
    """

    def __init__(self, in_ft, out_ft, bias=True):
        super().__init__()
        self.weight = Parameter(torch.empty(in_ft, out_ft))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.empty(out_ft))
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform(-1/sqrt(out_ft), 1/sqrt(out_ft)) for both weight and bias.
        bound = 1. / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor, G: torch.Tensor):
        projected = x @ self.weight
        if self.bias is not None:
            projected = projected + self.bias
        return G @ projected

#===========================================================
class HGCN_layer(nn.Module):
    """One hypergraph convolution (width preserved), then LeakyReLU and dropout.

    Args:
        n_hid: feature width of the input and output.
        dropout: dropout probability applied in training mode.
    """

    def __init__(self, n_hid, dropout=0.5):
        super().__init__()
        self.hgc1 = conv_layer(n_hid, n_hid)
        self.act = nn.LeakyReLU()
        self.dropout = dropout

    def forward(self, x, G):
        hidden = self.act(self.hgc1(x, G))
        return F.dropout(hidden, self.dropout, training=self.training)

#=======================================================
class HD_sim(nn.Module):
    """Residual block: HGCN_layer -> LeakyReLU -> Linear, plus a skip connection.

    Args:
        h_dim: width of the input and output node embeddings.
        dropout: dropout probability forwarded to the inner HGCN_layer.
    """

    def __init__(self, h_dim, dropout=0.5):
        super(HD_sim, self).__init__()
        # Forward the configured dropout; without it the inner layer always used
        # its own default of 0.5 regardless of the hyperparameter grid.
        self.HD1 = HGCN_layer(h_dim, dropout=dropout)
        self.emb = nn.Linear(h_dim, h_dim)
        self.dropout = dropout

    def forward(self, x, G):
        # NOTE(review): HGCN_layer already applies LeakyReLU, so this is a
        # second activation on the same values; kept to preserve the model.
        x = F.leaky_relu_(self.HD1(x, G))
        x1 = self.emb(x)
        x1 += x  # residual connection
        return x1

#=============================================================
class OncoNet(nn.Module):
    """Node classifier: Linear input projection, a stack of HD_sim blocks, Linear head.

    Args:
        in_dim: number of input features per node.
        hid_dim: hidden embedding width.
        out_dim: number of classes.
        num_layer: number of HD_sim blocks.
        dropout: dropout probability after the input projection and inside each block.

    ``forward(x, G)`` returns per-node log-probabilities (log_softmax over classes).
    """

    def __init__(self, in_dim, hid_dim, out_dim, num_layer=3, dropout=0.5):
        super().__init__()
        self.mlp = nn.Linear(in_dim, hid_dim)
        self.convs = nn.ModuleList(HD_sim(hid_dim, dropout) for _ in range(num_layer))
        self.fc2 = nn.Linear(hid_dim, out_dim)
        self.dropout = dropout

    def forward(self, x, G):
        h = F.dropout(F.leaky_relu(self.mlp(x)), self.dropout, training=self.training)
        for block in self.convs:
            h = block(h, G)
        return F.log_softmax(self.fc2(h), dim=1)


# Train and evaluate on the 9 cancer types

Now that the data for each cancer has been prepared, we load it here and train the model.

All results, including the predictions, are saved under `results/<cancer_type>/`.

In [None]:
def load_cancer_datasets(path):
    """Load every ``OncoPlex_<cancer>.pkl`` file in ``path``.

    Files are read in sorted name order, so the returned dict (and therefore
    the order in which cancers are trained, and how the seeded RNG is consumed)
    is reproducible across machines; ``os.listdir`` order is arbitrary.

    Note: uses ``pickle.load`` -- only point this at files you trust.

    Returns:
        dict: cancer type -> unpickled per-cancer data dict.
    """
    prefix, suffix = 'OncoPlex_', '.pkl'
    cancer_data = {}
    for file in sorted(os.listdir(path)):
        if not file.endswith(suffix):
            continue
        cancer_type = file[:-len(suffix)]
        if cancer_type.startswith(prefix):
            cancer_type = cancer_type[len(prefix):]
        with open(os.path.join(path, file), 'rb') as f:
            cancer_data[cancer_type] = pickle.load(f)
    return cancer_data

# Reload every per-cancer pickle written by the save cell.
dataset_path = '/content/specific_cancers'
cancer_data = load_cancer_datasets(path=dataset_path)
print("Loaded datasets for cancers: " + str(list(cancer_data.keys())))

Loaded datasets for cancers: ['HNSC', 'BLCA', 'BRCA', 'CESC', 'LIHC', 'STAD', 'ESCA', 'PRAD', 'LUAD']


In [None]:
def cal_auc(y_true, y_pred):
    """Return (AUROC, AUPRC) for the driver class.

    ``y_pred`` holds per-class log-probabilities (as returned by OncoNet);
    column 1 is the driver class. AUPRC is the trapezoidal area under the
    precision-recall curve.
    """
    driver_prob = np.exp(y_pred.cpu().detach().numpy())[:, 1]
    labels = y_true.cpu().numpy()
    auroc = roc_auc_score(labels, driver_prob)
    precision, recall, _ = precision_recall_curve(labels, driver_prob)
    return auroc, auc(recall, precision)

def accuracy_fn(y_true, y_pred):
    """Fraction of nodes whose arg-max class in ``y_pred`` equals ``y_true``."""
    predicted = y_pred.argmax(dim=1).cpu().numpy()
    return (predicted == y_true.cpu().numpy()).mean()


def f1_score_(y_true, y_pred):
    """F1 of the driver class, predicting 'driver' when its probability exceeds 0.5.

    ``y_pred`` holds per-class log-probabilities; column 1 is the driver class.
    """
    driver_prob = np.exp(y_pred.cpu().detach().numpy())[:, 1]
    hard_pred = np.where(driver_prob > 0.5, 1.0, 0.0)
    return f1_score(y_true.cpu().numpy(), hard_pred)
 
# ===========================================================================================
def train(model, optimizer, x, G, y, train_idx, weight):
    """Run one optimisation step using the loss over ``train_idx``.

    Args:
        model: network returning per-node log-probabilities (e.g. OncoNet).
        optimizer: optimizer over ``model``'s parameters.
        x: node feature matrix passed to ``model``.
        G: propagation matrix passed to ``model``.
        y: per-node integer labels; only entries at ``train_idx`` are used.
        train_idx: indices of the nodes that contribute to the loss.
        weight: per-class weights for ``F.nll_loss`` (sequence or tensor, class 0 first).

    Returns:
        (loss, AUROC, AUPRC, F1) on the training nodes, computed from the
        logits produced before the parameter update.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(x, G)
    # Build the class weights on the logits' device/dtype; a plain
    # torch.tensor(weight) is always a CPU tensor and breaks GPU training.
    class_weight = torch.as_tensor(weight, dtype=logits.dtype, device=logits.device)
    loss = F.nll_loss(logits[train_idx], y[train_idx], weight=class_weight)
    train_auroc, train_auprc = cal_auc(y[train_idx], logits[train_idx])
    train_f1 = f1_score_(y[train_idx], logits[train_idx])
    loss.backward()
    optimizer.step()
    return loss.item(), train_auroc, train_auprc, train_f1

@torch.no_grad()
def val(model, x, G, y, val_idx, weight):
    """Evaluate ``model`` on ``val_idx`` without updating it.

    Args:
        model, x, G, y: as in ``train``.
        val_idx: indices of the validation nodes.
        weight: per-class weights for ``F.nll_loss`` (sequence or tensor).

    Returns:
        (loss, accuracy, AUROC, AUPRC, F1) over the validation nodes.
    """
    model.eval()
    logits = model(x, G)
    # Class weights on the same device/dtype as the logits (works on CPU and GPU).
    class_weight = torch.as_tensor(weight, dtype=logits.dtype, device=logits.device)
    loss = F.nll_loss(logits[val_idx], y[val_idx], weight=class_weight)
    val_acc = accuracy_fn(y[val_idx], logits[val_idx])
    val_auroc, val_auprc = cal_auc(y[val_idx], logits[val_idx])
    val_f1 = f1_score_(y[val_idx], logits[val_idx])
    return loss.item(), val_acc, val_auroc, val_auprc, val_f1

@torch.no_grad()
def test(model, x, G, y, test_idx, unknown_idx, weight, node_names=None):
    """Evaluate on ``test_idx`` and return class probabilities for test and unlabelled genes.

    Args:
        model, x, G, y: as in ``train``.
        test_idx: indices of the labelled test nodes.
        unknown_idx: indices of unlabelled nodes to score.
        weight: per-class weights for ``F.nll_loss`` (sequence or tensor).
        node_names: gene name of every node, used as the index of the returned
            frames. Defaults to the module-level ``nodes`` list set by the
            feature-assignment cell (the original behaviour).

    Returns:
        (loss, accuracy, AUROC, AUPRC, F1, final_results, test_results, unknown_results)
        where the DataFrames have columns ["non_driver", "driver"] holding class
        probabilities, and ``final_results`` stacks test rows above unknown rows.
    """
    if node_names is None:
        node_names = nodes  # backward-compatible fallback to the notebook global

    model.eval()
    logits = model(x, G)
    class_weight = torch.as_tensor(weight, dtype=logits.dtype, device=logits.device)
    loss = F.nll_loss(logits[test_idx], y[test_idx], weight=class_weight)
    test_acc = accuracy_fn(y[test_idx], logits[test_idx])
    test_auroc, test_auprc = cal_auc(y[test_idx], logits[test_idx])
    test_f1 = f1_score_(y[test_idx], logits[test_idx])

    # Convert log-probabilities to probabilities once, then slice both node sets.
    probs = logits.exp().cpu().numpy()

    def _probability_frame(idx):
        # Accept tensors or array-likes of node positions.
        positions = idx.cpu().numpy() if torch.is_tensor(idx) else np.asarray(idx)
        return pd.DataFrame(
            probs[positions],
            index=[node_names[i] for i in positions],
            columns=["non_driver", "driver"],
        )

    test_results = _probability_frame(test_idx)
    unknown_results = _probability_frame(unknown_idx)
    final_results = pd.concat([test_results, unknown_results])

    return loss.item(), test_acc, test_auroc, test_auprc, test_f1, final_results, test_results, unknown_results



In [None]:
def _fit_with_early_stopping(params, x_, G_, y_, fit_idx, val_idx, num_epoch, patience, inner_fold):
    """Train a fresh OncoNet on ``fit_idx`` with early stopping on the loss over ``val_idx``.

    Returns:
        (best_val_loss, (acc, auroc, auprc, f1)) where the metrics come from the
        epoch with the lowest validation loss (the last epoch if none improved).
    """
    model = OncoNet(in_dim=x_.shape[1], hid_dim=params['hidden_dim'], out_dim=2,
                    num_layer=params['num_layers'], dropout=params['dropout'])
    optimizer = torch.optim.AdamW(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])

    best_val_loss = float('inf')
    best_metrics = None
    last_metrics = (float('nan'),) * 4
    patience_counter = 0
    for epoch in range(num_epoch):
        train(model, optimizer, x_, G_, y_, fit_idx, weight=params['class_weight'])
        val_loss, *metrics = val(model, x_, G_, y_, val_idx, weight=params['class_weight'])
        last_metrics = tuple(metrics)

        if val_loss < best_val_loss:
            best_val_loss, best_metrics = val_loss, last_metrics
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f" Early stopping at epoch {epoch+1} in inner fold {inner_fold+1} (no improvement for {patience} epochs)")
                break

    return float(best_val_loss), (best_metrics if best_metrics is not None else last_metrics)


def _select_hyperparams(param_grid, x_, G_, y_, train_val_idx, inner_k, seed, num_epoch, patience):
    """Grid-search ``param_grid`` with inner stratified K-fold CV on ``train_val_idx``.

    Returns:
        (best_params, best_inner_metrics) for the combination with the lowest
        mean best-epoch validation loss across the inner folds.
    """
    inner_kfold = StratifiedKFold(n_splits=inner_k, shuffle=True, random_state=seed)
    # split() yields positions *within* train_val_idx; map them back to node ids.
    inner_splits = [
        (train_val_idx[tr_pos], train_val_idx[va_pos])
        for tr_pos, va_pos in inner_kfold.split(np.zeros(len(train_val_idx)), y_[train_val_idx].numpy())
    ]

    best_hyperparams, best_val_loss, best_inner_metrics = None, float('inf'), None
    for params in ParameterGrid(param_grid):
        val_losses, fold_metrics = [], []
        for inner_fold, (inner_train_idx, inner_val_idx) in enumerate(inner_splits):
            assert len(set(inner_train_idx.tolist()) & set(inner_val_idx.tolist())) == 0
            loss, metrics = _fit_with_early_stopping(
                params, x_, G_, y_, inner_train_idx, inner_val_idx, num_epoch, patience, inner_fold)
            val_losses.append(loss)
            fold_metrics.append(metrics)

        avg_val_loss = float(np.mean(val_losses))
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_hyperparams = params
            mean_acc, mean_auroc, mean_auprc, mean_f1 = np.mean(fold_metrics, axis=0)
            best_inner_metrics = {
                'val_loss': avg_val_loss,
                'val_acc': mean_acc,
                'val_auroc': mean_auroc,
                'val_auprc': mean_auprc,
                'val_f1': mean_f1,
            }

    if best_hyperparams is None:
        raise RuntimeError("No hyperparameter combination produced a finite validation loss")
    return best_hyperparams, best_inner_metrics


def main():
    """Nested cross-validation of OncoNet on every saved cancer dataset.

    For each cancer loaded from '/content/specific_cancers':
      * the labelled genes of the stored 'train_idx' split are divided into
        ``outer_k`` stratified outer folds;
      * within each outer training part, every ``param_grid`` combination is
        scored by ``inner_k``-fold CV with early stopping on validation loss;
      * a fresh model with the best combination is trained on the outer
        training part for ``num_epoch`` epochs and evaluated on the outer fold.
    Per-fold predictions and a mean/std summary are written to results/<cancer_type>/.
    """
    # test() labels its predictions through the module-level `nodes`; rebind it per cancer.
    global nodes

    # train/val/test build their tensors where the logits live; data and model
    # are not moved here, so the run stays on CPU.
    device = torch.device("cpu")
    print(f"Using device: {device}")

    seed = 42
    num_epoch = 200
    patience = 20
    log_every = 50  # progress-line frequency for the final (outer) training runs

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Load all saved cancer datasets
    dataset_path = '/content/specific_cancers'
    cancer_data = load_cancer_datasets(dataset_path)
    print(f"Loaded datasets for cancers: {list(cancer_data.keys())}")

    # 3*2*3*3*3*3 = 486 combinations, each trained inner_k times per outer fold.
    # NOTE(review): class_weight gives the driver class (1) a lower weight than
    # non-drivers although drivers are the minority -- confirm this is intended.
    param_grid = {
        'lr': [0.001, 0.0005, 0.005],
        'weight_decay': [0.001, 0.0001],
        'hidden_dim': [128, 64, 256],
        'dropout': [0.5, 0.4, 0.25],
        'num_layers': [2, 3, 4],
        'class_weight': [[1.0, 0.4], [1.0, 0.2], [1.0, 0.45]]
    }

    outer_k = 5
    inner_k = 3

    for cancer_type, data in cancer_data.items():
        print(f"\n\n==================== Cancer Type: {cancer_type} ====================")

        x_ = torch.as_tensor(data['features'], dtype=torch.float32)
        y_ = torch.as_tensor(data['label'], dtype=torch.long)
        # 'edge_index' holds the dense propagation matrix G.
        G_ = torch.as_tensor(np.asarray(data['edge_index']), dtype=torch.float32)

        unknown_idx = torch.where(y_ == -1)[0]
        nodes = data['nodes']

        # Nested CV runs over the labelled genes of the stored training split.
        # NOTE(review): G was built from the drivers of this same split, so the
        # outer test folds are not independent of G; the stored 'test_idx' is unused.
        labelled_idx = np.asarray(data['train_idx'])

        outer_results = []
        outer_kfold = StratifiedKFold(n_splits=outer_k, shuffle=True, random_state=seed)
        outer_splits = outer_kfold.split(np.zeros(len(labelled_idx)), y_[labelled_idx].numpy())

        for fold, (outer_train_pos, outer_test_pos) in enumerate(outer_splits):
            print(f"\n=== Outer Fold {fold+1}/{outer_k} ===")

            # split() returns positions inside labelled_idx; convert them to node
            # ids before indexing x_/y_/G_ (otherwise mostly unlabelled genes are hit).
            train_val_idx = labelled_idx[outer_train_pos]
            test_idx = labelled_idx[outer_test_pos]
            assert len(set(train_val_idx.tolist()) & set(test_idx.tolist())) == 0

            best_hyperparams, best_inner_metrics = _select_hyperparams(
                param_grid, x_, G_, y_, train_val_idx, inner_k, seed, num_epoch, patience)

            print(f"\nBest hyperparameters for Outer Fold {fold+1}: {best_hyperparams}")
            print(f"Inner-CV metrics for these hyperparameters: {best_inner_metrics}")

            model = OncoNet(in_dim=x_.shape[1], hid_dim=best_hyperparams['hidden_dim'], out_dim=2,
                            num_layer=best_hyperparams['num_layers'], dropout=best_hyperparams['dropout'])
            optimizer = torch.optim.AdamW(model.parameters(), lr=best_hyperparams['lr'],
                                          weight_decay=best_hyperparams['weight_decay'])

            for epoch in range(num_epoch):
                train_loss, train_auroc, train_auprc, train_f1 = train(
                    model, optimizer, x_, G_, y_, train_val_idx, weight=best_hyperparams['class_weight'])
                if (epoch + 1) % log_every == 0 or epoch + 1 == num_epoch:
                    print(f"Training Outer Fold {fold+1}, Epoch {epoch+1}/{num_epoch} "
                          f"(loss {train_loss:.4f}, AUROC {train_auroc:.4f}, "
                          f"AUPRC {train_auprc:.4f}, F1 {train_f1:.4f})")

            (test_loss, test_acc, auroc_test, auprc_test, test_f1,
             final_results, test_results, unknown_results) = test(
                model, x_, G_, y_, test_idx, unknown_idx, weight=best_hyperparams['class_weight'])

            outer_results.append({
                'test_loss': float(test_loss),  # test() already returns a Python float
                'test_acc': test_acc,
                'test_auroc': auroc_test,
                'test_auprc': auprc_test,
                'test_f1': test_f1
            })

            result_dir = f"results/{cancer_type}/fold_{fold+1}"
            os.makedirs(result_dir, exist_ok=True)
            test_results.to_csv(f"{result_dir}/test_results.csv")
            unknown_results.to_csv(f"{result_dir}/unknown_results.csv")
            final_results.to_csv(f"{result_dir}/final_results.csv")

        metrics_df = pd.DataFrame(outer_results)
        mean_metrics = metrics_df.mean()
        std_metrics = metrics_df.std()
        summary_df = pd.DataFrame({
            "Metric": mean_metrics.index,
            "Mean": mean_metrics.values,
            "Std": std_metrics.values
        })

        summary_dir = f"results/{cancer_type}"
        os.makedirs(summary_dir, exist_ok=True)
        summary_df.to_csv(os.path.join(summary_dir, "outer_fold_summary.csv"), index=False)

        print("\nAverage Results Across Outer Folds:")
        print(summary_df)

if __name__ == "__main__":
    main()