In [None]:
import pandas as pd
import torch
from torch_geometric.data import HeteroData

In [None]:

df=pd.read_csv('finalized_data.csv')

In [None]:
df['score'] = df['score'].round(2)

Disease Specificity Index
There are genes (or variants) that are associated with multiple diseases (e.g., TNF) while others are associated with a small set of diseases or even to a single disease. The Disease Specificity Index (DSI) is a measure of this property of the genes (and variants). It reflects if a gene (or variant) is associated to several or fewer diseases.

Disease Pleiotropy Index
The rationale is similar to the DSI, but we consider if the multiple diseases associated with the gene (or variant) are similar among them (belong to the same MeSH disease class, e.g., Cardiovascular Diseases) or are completely different diseases and belong to different disease classes.

In [None]:

df = df.drop_duplicates()
print(f"Number of lines after removing duplicates: {len(df)}")


In [None]:
import torch
from torch_geometric.data import HeteroData

# Assuming df is your DataFrame

# Prepare gene node features using gene_index as indices
genes = df[['gene_index', 'geneDSI', 'geneDPI']].drop_duplicates().set_index('gene_index')
num_genes = genes.index.max() + 1  # Assuming indices start from 0
gene_features = torch.zeros((num_genes, 2), dtype=torch.float)
gene_features[genes.index] = torch.tensor(genes[['geneDSI', 'geneDPI']].values, dtype=torch.float)

# Prepare disease nodes
diseases = df[['disease_index']].drop_duplicates().set_index('disease_index')
num_diseases = diseases.index.max() + 1  # Assuming indices start from 0

# Prepare edge indices using existing indices
edge_index = torch.tensor([
    df['gene_index'].values,
    df['disease_index'].values
], dtype=torch.long)

# Edge attributes (scores)
edge_attr = torch.tensor(df['score'].values, dtype=torch.float).unsqueeze(1)

# Create HeteroData object
data = HeteroData()

# Add gene node features
data['gene'].x = gene_features

# Set the number of disease nodes (no features)
data['disease'].num_nodes = num_diseases

# Add edges between genes and diseases with edge attributes
data['gene', 'associates_with', 'disease'].edge_index = edge_index
data['gene', 'associates_with', 'disease'].edge_attr = edge_attr

In [88]:
data

HeteroData(
  gene={ x=[3608, 2] },
  disease={ num_nodes=23 },
  (gene, associates_with, disease)={
    edge_index=[2, 5955],
    edge_attr=[5955, 1],
  },
  (disease, rev_associates_with, gene)={
    edge_index=[2, 5955],
    edge_attr=[5955, 1],
  }
)

In [None]:
import torch_geometric.transforms as T
# Add reverse edges to allow message passing in both directions
# Convert the graph to undirected (adds reverse edges)
data = T.ToUndirected()(data)

# Define the edge types for splitting
edge_types = ('gene', 'associates_with', 'disease')
rev_edge_types = ('disease', 'rev_associates_with', 'gene')  # Reverse edge type

# Perform RandomLinkSplit with corrected parameter
transform = T.RandomLinkSplit(
    num_val=0.1,                     # 10% for validation
    num_test=0.1,                    # 10% for testing
    disjoint_train_ratio=0.3,        # 30% of training edges for supervision
    neg_sampling_ratio=2.0,          # Negative edge ratio for evaluation
    is_undirected=True,              # Graph is undirected
    add_negative_train_samples=False,  # Negative edges generated on-the-fly during training
    edge_types=edge_types,
    rev_edge_types=rev_edge_types,
)

In [None]:
train_data, val_data, test_data = transform(data)

In [89]:
print("Training data:")
print("==============")
print(train_data)
print()
print("Validation data:")
print("================")
print(val_data)

print("Training data edges (gene -> disease):", train_data["gene", "associates_with", "disease"].num_edges)
print("Training data edge label index size:", train_data["gene", "associates_with", "disease"].edge_label_index.size(1))
print("Training data edges (disease -> gene):", train_data["disease", "rev_associates_with", "gene"].num_edges)
print("Training data edge label min:", train_data["gene", "associates_with", "disease"].edge_label.min())
print("Training data edge label max:", train_data["gene", "associates_with", "disease"].edge_label.max())

print("Validation data edges (gene -> disease):", val_data["gene", "associates_with", "disease"].num_edges)
print("Validation data edge label index size:", val_data["gene", "associates_with", "disease"].edge_label_index.size(1))
print("Validation data edges (disease -> gene):", val_data["disease", "rev_associates_with", "gene"].num_edges)
print("Validation data edge label bincount:", val_data["gene", "associates_with", "disease"].edge_label.long().bincount().tolist())


Training data:
HeteroData(
  gene={ x=[3608, 2] },
  disease={ num_nodes=23 },
  (gene, associates_with, disease)={
    edge_index=[2, 3336],
    edge_attr=[3336, 1],
    edge_label=[1429],
    edge_label_index=[2, 1429],
  },
  (disease, rev_associates_with, gene)={
    edge_index=[2, 3336],
    edge_attr=[3336, 1],
  }
)

Validation data:
HeteroData(
  gene={ x=[3608, 2] },
  disease={ num_nodes=23 },
  (gene, associates_with, disease)={
    edge_index=[2, 4765],
    edge_attr=[4765, 1],
    edge_label=[1785],
    edge_label_index=[2, 1785],
  },
  (disease, rev_associates_with, gene)={
    edge_index=[2, 4765],
    edge_attr=[4765, 1],
  }
)
Training data edges (gene -> disease): 3336
Training data edge label index size: 1429
Training data edges (disease -> gene): 3336
Training data edge label min: tensor(1.)
Training data edge label max: tensor(1.)
Validation data edges (gene -> disease): 4765
Validation data edge label index size: 1785
Validation data edges (disease -> gene): 4765

In [29]:
for split_name, split_data in [('Train', train_data), ('Validation', val_data), ('Test', test_data)]:
    edge = split_data['gene', 'associates_with', 'disease']
    print(f"{split_name} Data:")
    print(f"  edge_label_index shape: {edge.edge_label_index.shape}")
    print(f"  edge_label shape: {edge.edge_label.shape}")
    print(f"  Max edge_label_index: {edge.edge_label_index.max().item()}")
    print()

Train Data:
  edge_label_index shape: torch.Size([2, 1429])
  edge_label shape: torch.Size([1429])
  Max edge_label_index: 3606

Validation Data:
  edge_label_index shape: torch.Size([2, 1785])
  edge_label shape: torch.Size([1785])
  Max edge_label_index: 3600

Test Data:
  edge_label_index shape: torch.Size([2, 1785])
  edge_label shape: torch.Size([1785])
  Max edge_label_index: 3604



In [None]:
from torch_geometric.loader import LinkNeighborLoader

# Define the edge type for link prediction
edge_label_index = ('gene', 'associates_with', 'disease')

# Batch size (adjust based on your memory constraints)
batch_size = 128

# Training Loader
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[10, 5],  # Number of neighbors to sample at each hop
    batch_size=batch_size,
    edge_label_index=(edge_label_index, train_data[edge_label_index].edge_index),
    edge_label=train_data[edge_label_index].edge_label,
    neg_sampling_ratio=2.0,  # Negative samples per positive sample
    shuffle=True,
)

# Validation Loader
val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[10, 5],
    batch_size=batch_size,
    edge_label_index=(edge_label_index, val_data[edge_label_index].edge_index),
    edge_label=val_data[edge_label_index].edge_label,
    neg_sampling_ratio=0.0,  # Use fixed negative samples
    shuffle=False,
)

# Test Loader
test_loader = LinkNeighborLoader(
    data=test_data,
    num_neighbors=[10, 5],
    batch_size=batch_size,
    edge_label_index=(edge_label_index, test_data[edge_label_index].edge_index),
    edge_label=test_data[edge_label_index].edge_label,
    neg_sampling_ratio=0.0,
    shuffle=False,
)

In [None]:
batch = next(iter(train_loader))
print("Batch edge_label_index shape:", batch['edge_label_index'].shape)
print("Batch edge_label shape:", batch['edge_label'].shape)
print("Batch positive labels:", (batch['edge_label'] == 1).sum().item())
print("Batch negative labels:", (batch['edge_label'] == 0).sum().item())

Verification of edge_label and edge_label_index attributes in  train_data, val_data, and test_data objects:

1. Check for the existence of edge_label and edge_label_index
Ensure that both edge_label and edge_label_index are present in each split for the specified edge type.

In [None]:


for split_name, split_data in [('Train', train_data), ('Validation', val_data), ('Test', test_data)]:
    edge = split_data['gene', 'associates_with', 'disease']
    print(f"{split_name} Data:")
    print(f"  edge_label_index exists: {'edge_label_index' in edge}")
    print(f"  edge_label exists: {'edge_label' in edge}")
    print()

2. Check the shapes of edge_label and edge_label_index

Verify that the number of labels matches the number of edges.

In [None]:
for split_name, split_data in [('Train', train_data), ('Validation', val_data), ('Test', test_data)]:
    edge = split_data['gene', 'associates_with', 'disease']
    if 'edge_label' in edge and 'edge_label_index' in edge:
        num_edges = edge.edge_label_index.size(1)
        num_labels = edge.edge_label.size(0)
        print(f"{split_name} Data:")
        print(f"  Number of edges: {num_edges}")
        print(f"  Number of labels: {num_labels}")
        print()

3. Check the distribution of edge labels

Count the number of positive and negative labels to ensure they are correctly assigned.

In [None]:
for split_name, split_data in [('Train', train_data), ('Validation', val_data), ('Test', test_data)]:
    edge = split_data['gene', 'associates_with', 'disease']
    labels = edge.edge_label
    pos_labels = (labels == 1).sum().item()
    neg_labels = (labels == 0).sum().item()
    print(f"{split_name} Data:")
    print(f"  Positive labels (edges): {pos_labels}")
    print(f"  Negative labels (non-edges): {neg_labels}")
    print()

4. Inspect a small sample of edges and labels

View a few edges and their corresponding labels to manually verify correctness.

In [None]:
split_data = train_data  # Replace with val_data or test_data as needed
edge = split_data['gene', 'associates_with', 'disease']

# View the first 5 edges and labels
print("First 5 edges:")
print(edge.edge_label_index[:, :5])
print("Corresponding labels:")
print(edge.edge_label[:5])

5. Verify that positive edges correspond to actual connections

For a few positive labels, confirm that the edges exist in original DataFrame or data object

In [None]:
# Assuming you have access to the original edge indices
original_edges = data['gene', 'associates_with', 'disease'].edge_index

# Compare with the training edges
train_edges = train_data['gene', 'associates_with', 'disease'].edge_label_index

# Check if training edges are a subset of original edges
edge_pairs = set(map(tuple, original_edges.t().tolist()))
train_edge_pairs = set(map(tuple, train_edges.t().tolist()))
print("All training edges are in the original edges:", train_edge_pairs.issubset(edge_pairs))

6. Confirm that negative samples are not in the original edges
For validation and test splits, ensure that negative samples do not exist in the actual graph

In [None]:
split_name = 'Validation'  # or 'Test'
split_data = val_data if split_name == 'Validation' else test_data
edge = split_data['gene', 'associates_with', 'disease']

# Get negative edge indices
neg_edge_indices = edge.edge_label_index[:, edge.edge_label == 0]

# Convert to set of edge pairs
neg_edge_pairs = set(map(tuple, neg_edge_indices.t().tolist()))

# Check if any negative edges are in the original edges
common_edges = neg_edge_pairs.intersection(edge_pairs)
print(f"Number of negative {split_name.lower()} edges in original edges:", len(common_edges))

7. Check the LinkNeighborLoader outputs
You can also check the batches produced by the LinkNeighborLoader.