In [2]:
import torch_geometric
from torch_geometric.data import Data
import torch
import pandas as pd
import numpy as np
import h5py
from torch_geometric.utils import add_self_loops
from sklearn.model_selection import train_test_split

In [3]:
# Load health.tsv into a DataFrame; read_table uses the tab character as its default separator.
health = pd.read_table('../data/real/smg_data/health.tsv')

In [4]:
# Show the loaded table; pandas elides the middle rows of a long frame in its repr.
health

Unnamed: 0,entrez,symbol,pubmed_id,type,organ_system,organ_site,tissue_type,method
0,92,ACVR2A,31645727,WES-Gene Panel,Developmental gastrointestinal,hepatobiliary,normal_hepatocytes_and_cirrhotic_hepatocytes,dN/dS
1,207,AKT1,31187483,Gene Panel,Gynecologic,endometrium,normal_endometrium,Paper-specific
2,213,ALB,31645727,WES-Gene Panel,Developmental gastrointestinal,hepatobiliary,normal_hepatocytes_and_cirrhotic_hepatocytes,dN/dS
3,213,ALB,30955891,WES-Gene Panel,Developmental gastrointestinal,hepatobiliary,non-dysplastic_hepatocytes_from_diseased_liver,Recurrence-based
4,338,APOB,30955891,WES-Gene Panel,Developmental gastrointestinal,hepatobiliary,non-dysplastic_hepatocytes_from_diseased_liver,Recurrence-based
...,...,...,...,...,...,...,...,...
162,171023,ASXL1,25326804,WES,Hematologic and lymphatic,blood,normal_blood_cells,Recurrence-based
163,196528,ARID2,30337457,Gene Panel,Core gastrointestinal,esophagus,normal_oesophageal_epithelium,dNdScv
164,196528,ARID2,33029006,Gene Panel,Skin,skin,normal_melanocytes,Paper-specific
165,196528,ARID2,31996850,WGS,Thoracic,lung,normal_bronchial_epithelium,dNdScv


In [5]:
def load_h5_graph(PATH, ppi):
    """Load one PPI multi-omics graph as a PyTorch Geometric ``Data`` object.

    Reads ``{PATH}/{ppi}_multiomics.h5`` (datasets ``network``, ``features`` and
    ``gene_names``) and ``{PATH}/health.tsv`` (column ``symbol``). Genes whose
    symbol appears in ``health.tsv`` are labelled positive; an equally sized
    random set of the remaining genes is drawn as negatives. That balanced
    sample is split with stratification into test (20%), validation (20% of
    the remainder) and train.

    Args:
        PATH: Directory containing the HDF5 file and ``health.tsv``. A trailing
            slash is optional.
        ppi: Name of the PPI network, used as the HDF5 file-name prefix.

    Returns:
        A ``Data`` object with ``x`` (node features, in the dtype stored in the
        file), ``edge_index`` (non-zero entries of ``network`` plus self-loops),
        ``y`` (float labels, one row per gene name, shape ``[n, 1]``), boolean
        ``train_mask`` / ``val_mask`` / ``test_mask`` (each ``[n, 1]``) and
        ``name`` (the gene-name array).

    Raises:
        ValueError: If no symbol from ``health.tsv`` occurs among the gene names
            of the graph, so no positive example exists.

    Note:
        Calls ``np.random.seed(42)``, i.e. NumPy's global RNG is reseeded on
        every call. This is kept so the sampled negatives stay unchanged.
    """
    # Copy everything we need into memory, then let the context manager close
    # the HDF5 handle (previously the file was left open on every call).
    with h5py.File(f'{PATH}/{ppi}_multiomics.h5', 'r') as f:
        network = f['network'][:]
        features = f['features'][:]
        gene_name = f['gene_names'][..., -1].astype(str)

    # Every non-zero entry (i, j) of the network matrix becomes an edge i -> j.
    src, dst = np.nonzero(network)
    edge_index = torch.tensor(np.vstack((src, dst)), dtype=torch.long)

    # torch.from_numpy keeps the stored dtype (no cast is applied here).
    x = torch.from_numpy(features)
    num_nodes = x.size(0)
    num_genes = len(gene_name)

    # gene name -> node index
    gene_map = {g: i for i, g in enumerate(gene_name)}

    # Build the label path the same way as the HDF5 path so that PATH works
    # with or without a trailing slash.
    label_df = pd.read_csv(f'{PATH}/health.tsv', sep='\t').astype(str)  # TODO fix this for druggable gene prediction
    label_symbols = label_df['symbol'].tolist()

    # Positive nodes: genes present both in health.tsv and in the graph.
    # Sorting keeps the order (and hence the split) deterministic.
    mask = [gene_map[g] for g in sorted(set(label_symbols) & set(gene_name))]
    if not mask:
        raise ValueError(
            f"No symbol from health.tsv matches a gene name in '{ppi}'; "
            "cannot build a labelled split."
        )

    # Randomly select as many negatives as positives (or all remaining nodes
    # if there are fewer) from the nodes that are not positive.
    np.random.seed(42)
    negative_candidates = sorted(set(range(num_genes)) - set(mask))
    neg_sample_size = min(len(mask), num_genes - len(mask))
    neg_mask = np.random.choice(negative_candidates, size=neg_sample_size, replace=False).tolist()

    print("Negative mask indices:", neg_mask)

    # Label vector: 1 for positive genes, 0 everywhere else.
    y = torch.zeros(num_genes, dtype=torch.float)
    y[mask] = 1
    y = y.unsqueeze(1)  # shape: [n, 1]

    # Balanced sample of node indices and their labels (NumPy, for stratification).
    final_mask = mask + neg_mask
    final_labels = y[final_mask].squeeze(1).numpy()

    # 80/20 stratified split into (train + val) and test ...
    train_idx, test_idx = train_test_split(final_mask, test_size=0.2, shuffle=True,
                                           stratify=final_labels, random_state=42)
    # ... then 80/20 again to carve the validation set out of the training part.
    train_labels = y[train_idx].squeeze(1).numpy()
    train_idx, val_idx = train_test_split(train_idx, test_size=0.2, shuffle=True,
                                          stratify=train_labels, random_state=42)

    # Boolean masks over all genes; only sampled nodes are ever set to True.
    train_mask = torch.zeros(num_genes, dtype=torch.bool)
    test_mask = torch.zeros(num_genes, dtype=torch.bool)
    val_mask = torch.zeros(num_genes, dtype=torch.bool)
    train_mask[train_idx] = True
    test_mask[test_idx] = True
    val_mask[val_idx] = True

    # Add self-loops to the edge_index
    edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

    # Build the PyTorch Geometric data object; masks are unsqueezed to [n, 1]
    # to mimic the original label/mask layout.
    data = Data(x=x, edge_index=edge_index, y=y)
    data.train_mask = train_mask.unsqueeze(1)
    data.test_mask = test_mask.unsqueeze(1)
    data.val_mask = val_mask.unsqueeze(1)
    data.name = gene_name  # optional: storing node names

    return data

In [11]:
# NOTE(review): `cpdb` is assigned by the `cpdb = load_h5_graph(PATH, ppi)` cell, which
# appears further down in the notebook (execution count 7 vs. 11 here). On a fresh
# Restart & Run All this cell raises NameError; move it below that cell.
cpdb.x

tensor([[0.0000, 0.0053, 0.0000,  ..., 0.0000, 0.4493, 0.0000],
        [0.0000, 0.0053, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0633, 0.0158, 0.0557,  ..., 0.1903, 0.0000, 0.0000],
        ...,
        [0.0628, 0.0619, 0.1089,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0053, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0106, 0.0000,  ..., 0.0000, 0.0000, 0.0402]],
       dtype=torch.float64)

In [7]:
# Data directory (HDF5 graphs and health.tsv) and the PPI network to load first.
PATH, ppi = '../data/real/smg_data/', 'CPDB'

# Build the graph object for the selected network.
cpdb = load_h5_graph(PATH, ppi)

Negative mask indices: [2159, 7175, 8188, 407, 4946, 6510, 2580, 13553, 4724, 8587, 7405, 6021, 12333, 5345, 9624, 11971, 2529, 3514, 5420, 2226, 5388, 4962, 4581, 764, 549, 7506, 9885, 12426, 10669, 11260, 5808, 9419, 4233, 7259, 107, 9819, 1144, 6153, 7574, 5089, 8984, 11719, 8903, 5748, 8351, 12554, 459, 4614, 1743, 4353, 4937, 7393, 3472, 5846, 5464, 9088, 3926, 3203, 10013, 10141, 3114, 2912, 3156, 9933, 4497, 4079, 2708, 10816, 1807, 10743, 5838, 4853, 8331, 3519, 4175, 4359, 1243, 2661, 3324, 3475, 8090, 8662, 6689, 10699, 4222, 9223, 3364, 6586, 7973, 5702]


In [54]:
# PPI networks to load; 'Multinet' is currently commented out of the list.
ppis = ['CPDB', 'IRefIndex_2015', 'PCNet', 'STRINGdb']#'Multinet',

# Maps each PPI network name to the graph object returned by load_h5_graph.
data = {}

# NOTE: the loop variable rebinds the module-level `ppi`; after this cell it
# holds the last entry of `ppis`.
for ppi in ppis:
    data[ppi] = load_h5_graph(PATH, ppi)

Negative mask indices: [2159, 7175, 8188, 407, 4946, 6510, 2580, 13553, 4724, 8587, 7405, 6021, 12333, 5345, 9624, 11971, 2529, 3514, 5420, 2226, 5388, 4962, 4581, 764, 549, 7506, 9885, 12426, 10669, 11260, 5808, 9419, 4233, 7259, 107, 9819, 1144, 6153, 7574, 5089, 8984, 11719, 8903, 5748, 8351, 12554, 459, 4614, 1743, 4353, 4937, 7393, 3472, 5846, 5464, 9088, 3926, 3203, 10013, 10141, 3114, 2912, 3156, 9933, 4497, 4079, 2708, 10816, 1807, 10743, 5838, 4853, 8331, 3519, 4175, 4359, 1243, 2661, 3324, 3475, 8090, 8662, 6689, 10699, 4222, 9223, 3364, 6586, 7973, 5702]
Negative mask indices: [10301, 5539, 12127, 2644, 9425, 7448, 1133, 5991, 2999, 8499, 10557, 11115, 219, 6070, 10655, 7243, 8323, 1465, 10690, 360, 10025, 9909, 5207, 10090, 10486, 7881, 6911, 10712, 9641, 3274, 428, 4433, 9636, 9032, 7741, 4298, 9563, 8489, 8725, 3966, 5140, 10902, 8318, 11592, 7728, 11156, 7641, 681, 4426, 7184, 11872, 10812, 8947, 4087, 2477, 2739, 619, 9529, 1411, 11827, 4286, 6540, 10087, 10666, 11119, 

In [58]:
# Shape of the CPDB node-feature matrix: (number of nodes, features per node).
data['CPDB'].x.shape

torch.Size([13627, 64])