## Dataset Class

### GNN Link Prediction Workflow

#### 1. Data Preparation
- [ ] Obtain dataset (e.g., Cora, CiteSeer)
- [ ] Preprocess data:
  - [ ] Normalize node features
  - [ ] Convert to undirected graph (if needed)
- [x] Split edges:
  - [x] Training set (80%)
  - [x] Validation set (10%)
  - [x] Test set (10%)
- [x] Generate negative samples:
  - [ ] Validation negatives (1:1 ratio)
  - [ ] Test negatives (1:1 ratio)

#### 2. Model Architecture
- [ ] Implement GNN encoder:
  - [ ] Choose layer type (GCN/GraphSAGE)
  - [ ] 2-layer architecture
  - [ ] ReLU activation
- [ ] Implement decoder:
  - [ ] Dot product scorer
  - [ ] Sigmoid activation

#### 3. Training Setup
- [ ] Initialize optimizer (Adam)
- [ ] Set learning rate (~0.01)
- [ ] Define loss function (BCE)
- [ ] Implement negative sampling:
  - [ ] Dynamic per-epoch sampling
  - [ ] 1:1 positive:negative ratio

#### 4. Evaluation Metrics
- [ ] AUC-ROC calculation
- [ ] Accuracy/F1-score
- [ ] Precision-Recall curve

#### 5. PyTorch Implementation
- [ ] Environment setup:
  ```bash
  pip install torch torch_geometric

In [1]:
import torch
import pandas as pd
from torch_geometric.data import Dataset, Data
from sklearn.preprocessing import LabelEncoder
import numpy as np

class LinkPredictionDataset(Dataset):
    def __init__(self, net_csv_path, label_csv_path, transform=None):
        super().__init__(transform=transform)
        
        # Load data
        self.net_df = pd.read_excel(net_csv_path)
        self.label_df = pd.read_excel(label_csv_path)
        
        # Create node mappings
        all_nodes = set(self.net_df['Regulator']).union(set(self.net_df['Target']))
        self.node_to_idx = {node: idx for idx, node in enumerate(all_nodes)}
        self.idx_to_node = {idx: node for node, idx in self.node_to_idx.items()}
        
        # Create edge index from network data
        regulators = [self.node_to_idx[reg] for reg in self.net_df['Regulator']]
        targets = [self.node_to_idx[target] for target in self.net_df['Target']]
        self.edge_index = torch.tensor([regulators, targets], dtype=torch.long)
        
        # Create simple node features (just identity matrix for now)
        self.num_nodes = len(all_nodes)
        self.x = torch.eye(self.num_nodes, dtype=torch.float)
        
        # Prepare labels for link prediction (binary: edge exists or not)
        self.edge_labels = torch.ones(len(regulators), dtype=torch.float)
        
    def len(self):
        return 1  # Single graph for link prediction
    
    def get(self, idx):
        # Return the graph data
        data = Data(
            x=self.x,
            edge_index=self.edge_index,
            edge_attr=self.edge_labels,
            num_nodes=self.num_nodes
        )
        
        if self.transform:
            data = self.transform(data)
            
        return data

# Usage example:
dataset = LinkPredictionDataset('data/net.xlsx', 'data/label.xlsx')
data = dataset[0]  # Get the graph
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.edge_index.size(1)}")

Number of nodes: 3322
Number of edges: 10747


In [4]:

import plotly.graph_objects as go
import plotly.express as px

def visualize_graph_plotly(dataset, width=800, height=600):
    """
    Create an interactive Plotly visualization of the graph
    """
    data = dataset[0]
    
    # Create NetworkX graph
    G = nx.Graph()
    
    # Add nodes with labels
    for idx, node_name in dataset.idx_to_node.items():
        G.add_node(idx, label=node_name)
    
    # Add edges
    edge_index = data.edge_index.numpy()
    edges = [(edge_index[0][i], edge_index[1][i]) for i in range(edge_index.shape[1])]
    G.add_edges_from(edges)
    
    # Generate layout
    pos = nx.spring_layout(G, k=2, iterations=50)
    
    # Determine node types and colors
    node_types = {}
    for _, row in dataset.net_df.iterrows():
        regulator = row['Regulator']
        target = row['Target']
        reg_type = row['RegulatorType']
        target_type = row['TargetType']
        
        if regulator in dataset.node_to_idx:
            node_types[dataset.node_to_idx[regulator]] = reg_type
        if target in dataset.node_to_idx:
            node_types[dataset.node_to_idx[target]] = target_type
    
    # Prepare edge traces
    edge_x = []
    edge_y = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=1, color='#888'),
        hoverinfo='none',
        mode='lines'
    )
    
    # Prepare node traces by type
    color_map = {'lncRNA': '#1f77b4', 'miRNA': '#ff7f0e', 'mRNA': '#2ca02c'}
    node_traces = []
    
    for node_type, color in color_map.items():
        # Filter nodes by type
        node_indices = [idx for idx, ntype in node_types.items() if ntype == node_type]
        
        if node_indices:
            node_x = [pos[idx][0] for idx in node_indices]
            node_y = [pos[idx][1] for idx in node_indices]
            node_labels = [dataset.idx_to_node[idx] for idx in node_indices]
            
            node_trace = go.Scatter(
                x=node_x, y=node_y,
                mode='markers+text',
                text=node_labels,
                textposition="middle center",
                textfont=dict(size=10, color='white'),
                hoverinfo='text',
                hovertext=[f"{label}<br>Type: {node_type}" for label in node_labels],
                marker=dict(
                    size=20,
                    color=color,
                    line=dict(width=2, color='white')
                ),
                name=node_type
            )
            node_traces.append(node_trace)
    
    # Create figure
    fig = go.Figure(data=[edge_trace] + node_traces,
                   layout=go.Layout(
                        title=dict(
                            text='Interactive Regulatory Network Graph',
                            x=0.5,
                            font=dict(size=16)
                        ),
                        showlegend=True,
                        hovermode='closest',
                        margin=dict(b=20,l=5,r=5,t=40),
                        annotations=[ dict(
                            text="Hover over nodes for details. Use mouse to zoom and pan.",
                            showarrow=False,
                            xref="paper", yref="paper",
                            x=0.005, y=-0.002,
                            xanchor="left", yanchor="bottom",
                            font=dict(color="#888", size=12)
                        )],
                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        width=width,
                        height=height,
                        plot_bgcolor='white'
                        ))
    
    # Show the plot
    fig.show()
    
    # Print statistics
    print(f"Interactive Graph Created!")
    print(f"Number of nodes: {data.num_nodes}")
    print(f"Number of edges: {data.edge_index.size(1)}")
    print(f"Node types: {list(color_map.keys())}")
    print("Features: Zoom, pan, hover for details, toggle node types in legend")

def print_sample_data(dataset):
    """
    Print sample of the original data and processed tensors
    """
    print("Sample Network Data:")
    print(dataset.net_df.head())
    print("\nSample Label Data:")
    print(dataset.label_df.head())
    
    data = dataset[0]
    print(f"\nProcessed PyG Data Object:")
    print(f"Node features shape: {data.x.shape}")
    print(f"Edge index shape: {data.edge_index.shape}")
    print(f"First few edges (source -> target):")
    edge_index = data.edge_index.numpy()
    for i in range(min(5, edge_index.shape[1])):
        source_name = dataset.idx_to_node[edge_index[0][i]]
        target_name = dataset.idx_to_node[edge_index[1][i]]
        print(f"  {source_name} -> {target_name}")

# Example usage (uncomment when you have the CSV files):
dataset = LinkPredictionDataset('data/net.xlsx', 'data/label.xlsx')
print_sample_data(dataset)
visualize_graph_plotly(dataset)  # Interactive Plotly version

Sample Network Data:
   Unnamed: 0  Regulator      Target RegulatorType TargetType  \
0           1      NEAT1  miR-194-5p        lncRNA      miRNA   
1           2  LINC00460     miR-206        lncRNA      miRNA   
2           3     MALAT1     miR-497        lncRNA      miRNA   
3           4       MIAT     miR-29b        lncRNA      miRNA   
4           5      CASC7     miR-30c        lncRNA      miRNA   

  regulatory_Mechanism  
0      ceRNA or sponge  
1      ceRNA or sponge  
2      ceRNA or sponge  
3      ceRNA or sponge  
4      ceRNA or sponge  

Sample Label Data:
   Unnamed: 0      Regulator  cell.proliferation  cell.invasion  \
0           1  1700020I14Rik                   0              0   
1           2            7SK                   1              0   
2           3            91H                   0              1   
3           4        A2M-AS1                   1              1   
4           5          AATBC                   1              0   

   cell.migrati

NameError: name 'nx' is not defined

In [5]:
## split into test val train
from torch_geometric.transforms import RandomLinkSplit

# Apply the transform to your data object
transform = RandomLinkSplit(
    num_val=0.1, num_test=0.1,  # 10% val, 10% test
    is_undirected=True,         # Set to True if your graph is undirected
    add_negative_train_samples=True,
    neg_sampling_ratio=1.0      # 1:1 positive:negative
)

data = dataset[0]
train_data, val_data, test_data = transform(data)

print(train_data)
print(val_data)
print(test_data)

Data(x=[3322, 3322], edge_index=[2, 8858], edge_attr=[8858], num_nodes=3322, edge_label=[8858], edge_label_index=[2, 8858])
Data(x=[3322, 3322], edge_index=[2, 8858], edge_attr=[8858], num_nodes=3322, edge_label=[1106], edge_label_index=[2, 1106])
Data(x=[3322, 3322], edge_index=[2, 9964], edge_attr=[9964], num_nodes=3322, edge_label=[1106], edge_label_index=[2, 1106])


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, precision_recall_curve, auc

# 1. GNN Encoder + Dot Product Decoder
class GCNLinkPredictor(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        
    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

    def decode(self, z, edge_label_index):
        # Dot product decoder
        src, dst = edge_label_index
        return (z[src] * z[dst]).sum(dim=1)

    def forward(self, x, edge_index, edge_label_index):
        z = self.encode(x, edge_index)
        logits = self.decode(z, edge_label_index)
        return logits

# 2. Prepare data splits (already done with RandomLinkSplit)
# train_data, val_data, test_data = transform(data)

# 3. Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCNLinkPredictor(in_channels=train_data.x.size(1), hidden_channels=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = nn.BCEWithLogitsLoss()

def train():
    model.train()
    optimizer.zero_grad()
    out = model(
        train_data.x.to(device),
        train_data.edge_index.to(device),
        train_data.edge_label_index.to(device)
    )
    loss = loss_fn(out, train_data.edge_label.to(device))
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test(data):
    model.eval()
    logits = model(
        data.x.to(device),
        data.edge_index.to(device),
        data.edge_label_index.to(device)
    )
    probs = torch.sigmoid(logits).cpu().numpy()
    labels = data.edge_label.cpu().numpy()
    auc_score = roc_auc_score(labels, probs)
    preds = (probs > 0.5).astype(int)
    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds)
    precision, recall, _ = precision_recall_curve(labels, probs)
    pr_auc = auc(recall, precision)
    return auc_score, acc, f1, pr_auc

# 4. Training loop
epochs = 100
for epoch in range(1, epochs+1):
    loss = train()
    if epoch % 10 == 0 or epoch == 1:
        val_auc, val_acc, val_f1, val_pr_auc = test(val_data)
        print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Val AUC: {val_auc:.4f} | Val F1: {val_f1:.4f}")

# 5. Final evaluation
test_auc, test_acc, test_f1, test_pr_auc = test(test_data)
print(f"\nTest AUC: {test_auc:.4f} | Test Accuracy: {test_acc:.4f} | Test F1: {test_f1:.4f} | Test PR AUC: {test_pr_auc:.4f}")

Epoch 001 | Loss: 0.6931 | Val AUC: 0.8626 | Val F1: 0.6667
Epoch 010 | Loss: 0.5889 | Val AUC: 0.8328 | Val F1: 0.6667
Epoch 020 | Loss: 0.4555 | Val AUC: 0.7296 | Val F1: 0.6496
Epoch 030 | Loss: 0.4259 | Val AUC: 0.7277 | Val F1: 0.6530
Epoch 040 | Loss: 0.4035 | Val AUC: 0.7343 | Val F1: 0.6593
Epoch 050 | Loss: 0.3738 | Val AUC: 0.7183 | Val F1: 0.6570
Epoch 060 | Loss: 0.3361 | Val AUC: 0.7119 | Val F1: 0.6563
Epoch 070 | Loss: 0.2871 | Val AUC: 0.7051 | Val F1: 0.6550
Epoch 080 | Loss: 0.2457 | Val AUC: 0.7062 | Val F1: 0.6554
Epoch 090 | Loss: 0.2127 | Val AUC: 0.7048 | Val F1: 0.6582
Epoch 100 | Loss: 0.1843 | Val AUC: 0.7110 | Val F1: 0.6683

Test AUC: 0.7088 | Test Accuracy: 0.6401 | Test F1: 0.6716 | Test PR AUC: 0.7680
