In [1]:
import pandas as pd
import numpy as np
import random
from itertools import combinations
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch_geometric.nn import HANConv

In [2]:
# Fix random seeds for reproducibility
random.seed(42)
np.random.seed(42)

# Set the number of each node type
num_lnc = 30
num_mi = 10
num_m = 15

# Generate node IDs
lnc_ids = [f'lnc{i+1}' for i in range(num_lnc)]
mi_ids = [f'mi{i+1}' for i in range(num_mi)]
m_ids = [f'm{i+1}' for i in range(num_m)]

# Build edge list for the heterogeneous network
edge_list = []
# Step 1: Randomly add lncRNA-miRNA and lncRNA-mRNA relations
for lnc in lnc_ids:
    mi_samples = random.sample(mi_ids, random.randint(2, 5))  # More connections
    for mi in mi_samples:
        edge_list.append([lnc, mi, 'lncRNA', 'miRNA', 'regulate'])
    m_samples = random.sample(m_ids, random.randint(2, 4))    # More connections
    for m in m_samples:
        edge_list.append([lnc, m, 'lncRNA', 'mRNA', 'coexpression'])

# Step 1b: Randomly add miRNA-mRNA regulate edges
for mi in mi_ids:
    m_samples = random.sample(m_ids, random.randint(3, 6))    # Denser connections
    for m in m_samples:
        edge_list.append([mi, m, 'miRNA', 'mRNA', 'regulate'])

# Step 2: Manually create 4-hop closed paths to guarantee some meta-path-2 links
for i in range(5):  # Manually build 5 such structures
    lncA, lncB = random.sample(lnc_ids, 2)
    mi1, mi2 = random.sample(mi_ids, 2)
    m = random.choice(m_ids)
    edge_list.append([lncA, mi1, 'lncRNA', 'miRNA', 'regulate'])
    edge_list.append([lncB, mi2, 'lncRNA', 'miRNA', 'regulate'])
    edge_list.append([mi1, m, 'miRNA', 'mRNA', 'regulate'])
    edge_list.append([mi2, m, 'miRNA', 'mRNA', 'regulate'])

# Convert the edge list to a DataFrame
edges_df = pd.DataFrame(edge_list, columns=['A', 'B', 'A_type', 'B_type', 'relation'])

In [3]:
print(edges_df)

         A    B  A_type B_type      relation
0     lnc1  mi1  lncRNA  miRNA      regulate
1     lnc1  mi5  lncRNA  miRNA      regulate
2     lnc1   m4  lncRNA   mRNA  coexpression
3     lnc1   m3  lncRNA   mRNA  coexpression
4     lnc2  mi9  lncRNA  miRNA      regulate
..     ...  ...     ...    ...           ...
250    mi6   m1   miRNA   mRNA      regulate
251   lnc2  mi8  lncRNA  miRNA      regulate
252  lnc19  mi9  lncRNA  miRNA      regulate
253    mi8  m15   miRNA   mRNA      regulate
254    mi9  m15   miRNA   mRNA      regulate

[255 rows x 5 columns]


In [4]:
# Generate random features and labels for lncRNAs
lnc_labels = np.random.randint(0, 2, size=num_lnc)   # Binary classification
feat_dim = 8
lnc_feat = np.random.randn(num_lnc, feat_dim)
lnc_feat_df = pd.DataFrame({'lncRNA': lnc_ids, 'label': lnc_labels})
for i in range(feat_dim):
    lnc_feat_df[f'feat_{i}'] = lnc_feat[:, i]

In [5]:
print(lnc_feat_df.head())

  lncRNA  label    feat_0    feat_1    feat_2    feat_3    feat_4    feat_5  \
0   lnc1      0 -0.571380 -0.924083 -2.612549  0.950370  0.816445 -1.523876   
1   lnc2      1 -0.703344 -2.139621 -0.629475  0.597720  2.559488  0.394233   
2   lnc3      0 -0.600254  0.947440  0.291034 -0.635560 -1.021552 -0.161755   
3   lnc4      0 -0.229450  0.389349 -1.265119  1.091992  2.778313  1.193640   
4   lnc5      0 -1.009085 -1.583294  0.773700 -0.538142 -1.346678 -0.880591   

     feat_6    feat_7  
0 -0.428046 -0.742407  
1  0.122219 -0.515436  
2 -0.533649 -0.005528  
3  0.218638  0.881761  
4 -1.130552  0.134429  


In [6]:
# -------- Meta-path view construction functions --------
def metapath_lnc_mi_lnc(edges_df, lnc_list):
    """
    Construct the lncRNA–miRNA–lncRNA meta-path view.
    Returns edge_index for this view.
    """
    adj = {}
    for _, row in edges_df.iterrows():
        if row['A_type'] == 'lncRNA' and row['B_type'] == 'miRNA':
            adj.setdefault(row['A'], set()).add(row['B'])
    mi2lnc = {}
    for lnc, mis in adj.items():
        for mi in mis:
            mi2lnc.setdefault(mi, set()).add(lnc)
    edge_set = set()
    for lnc_set in mi2lnc.values():
        for u, v in combinations(sorted(lnc_set), 2):
            edge_set.add((u, v))
    node2idx = {nid: i for i, nid in enumerate(lnc_list)}
    edges = list(edge_set)
    if len(edges) > 0:
        edge_index = torch.tensor([
            [node2idx[u] for u, v in edges] + [node2idx[v] for u, v in edges],
            [node2idx[v] for u, v in edges] + [node2idx[u] for u, v in edges]
        ], dtype=torch.long)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return edge_index

def metapath_lnc_mi_m_mi_lnc(edges_df, lnc_list):
    """
    Construct the lncRNA–miRNA–mRNA–miRNA–lncRNA meta-path view.
    Returns edge_index for this view.
    """
    lnc2mi = {}
    for _, row in edges_df.iterrows():
        if row['A_type'] == 'lncRNA' and row['B_type'] == 'miRNA':
            lnc2mi.setdefault(row['A'], set()).add(row['B'])
    mi2m = {}
    for _, row in edges_df.iterrows():
        if row['A_type'] == 'miRNA' and row['B_type'] == 'mRNA':
            mi2m.setdefault(row['A'], set()).add(row['B'])
    m2mi = {}
    for mi, ms in mi2m.items():
        for m in ms:
            m2mi.setdefault(m, set()).add(mi)
    mi2lnc = {}
    for lnc, mis in lnc2mi.items():
        for mi in mis:
            mi2lnc.setdefault(mi, set()).add(lnc)
    edge_set = set()
    for lnc_start in lnc_list:
        for mi1 in lnc2mi.get(lnc_start, []):
            for m in mi2m.get(mi1, []):
                for mi2 in m2mi.get(m, []):
                    for lnc_end in mi2lnc.get(mi2, []):
                        if lnc_start != lnc_end:
                            u, v = sorted([lnc_start, lnc_end])
                            edge_set.add((u, v))
    node2idx = {nid: i for i, nid in enumerate(lnc_list)}
    edges = list(edge_set)
    if len(edges) > 0:
        edge_index = torch.tensor([
            [node2idx[u] for u, v in edges] + [node2idx[v] for u, v in edges],
            [node2idx[v] for u, v in edges] + [node2idx[u] for u, v in edges]
        ], dtype=torch.long)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return edge_index

In [7]:
# Build meta-path-based edge indices for lncRNA nodes
lnc_list = lnc_feat_df['lncRNA'].tolist()
edge_index1 = metapath_lnc_mi_lnc(edges_df, lnc_list)
edge_index2 = metapath_lnc_mi_m_mi_lnc(edges_df, lnc_list)
edge_index_list = [edge_index1, edge_index2]

print(f"View1 edges (lnc-mi-lnc): {edge_index1.size(1)//2}")
print(f"View2 edges (lnc-mi-m-mi-lnc): {edge_index2.size(1)//2}")

View1 edges (lnc-mi-lnc): 348
View2 edges (lnc-mi-m-mi-lnc): 435


In [8]:
# Prepare node features and labels for model input
feat_cols = [c for c in lnc_feat_df.columns if c.startswith('feat_')]
x = torch.tensor(lnc_feat_df[feat_cols].values, dtype=torch.float)
y = torch.tensor(lnc_feat_df['label'].values, dtype=torch.long)

# Train/validation/test split
idx = torch.arange(x.size(0))
train_idx, test_idx = train_test_split(idx, test_size=0.2, random_state=42)
train_idx, val_idx  = train_test_split(train_idx, test_size=0.25, random_state=42)  # 0.6 train, 0.2 val, 0.2 test

# PyG 2.x metadata definition: required for HANConv
node_types = ['lncRNA']
edge_types = [
    ('lncRNA', 'meta1', 'lncRNA'),  # Meta-path 1
    ('lncRNA', 'meta2', 'lncRNA'),  # Meta-path 2
]
metadata = (node_types, edge_types)

In [9]:
print(metadata)

(['lncRNA'], [('lncRNA', 'meta1', 'lncRNA'), ('lncRNA', 'meta2', 'lncRNA')])


In [10]:

# Define the HAN model (multi-view, lncRNA node classification)
class HANforLncRNA(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, metadata, dropout=0.5):
        super().__init__()
        # HANConv: multi-view GNN with attention over meta-paths
        self.han = HANConv(
            in_channels=in_dim,
            out_channels=hidden_dim,
            metadata=metadata,
            heads=8,
        )
        self.lin = nn.Linear(hidden_dim, out_dim)  # HANConv outputs hidden_dim (not hidden_dim*heads)
        self.dropout = dropout

    def forward(self, x, edge_index_list):
        # PyG HANConv expects dict inputs: {node_type: x}, {edge_type: edge_index}
        x_dict = {'lncRNA': x}
        edge_index_dict = {etype: edge_index_list[i] for i, etype in enumerate(metadata[1])}
        h_dict = self.han(x_dict, edge_index_dict)
        h = h_dict['lncRNA']
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.lin(h)

# Instantiate model
model = HANforLncRNA(
    in_dim=x.size(1),
    hidden_dim=32,
    out_dim=2,
    metadata=metadata,
    dropout=0.5
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-5)

# -------- Training loop with early stopping --------
best_val = 0.0
patience = 5
patience_cnt = 0

for epoch in range(1, 101):
    model.train()
    optimizer.zero_grad()
    out = model(x, edge_index_list)
    loss = F.cross_entropy(out[train_idx], y[train_idx])
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        out = model(x, edge_index_list)
        pred = out.argmax(dim=-1)
        train_acc = (pred[train_idx] == y[train_idx]).float().mean().item()
        val_acc   = (pred[val_idx]   == y[val_idx]).float().mean().item()
        test_acc  = (pred[test_idx]  == y[test_idx]).float().mean().item()
    print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Train: {train_acc:.4f} | Val: {val_acc:.4f} | Test: {test_acc:.4f}")
    if val_acc > best_val:
        best_val = val_acc
        best_state = model.state_dict()
        patience_cnt = 0
    else:
        patience_cnt += 1
        if patience_cnt >= patience:
            print("Early stopping.")
            break

Epoch 001 | Loss: 0.7143 | Train: 0.4444 | Val: 0.3333 | Test: 0.5000
Epoch 002 | Loss: 0.7027 | Train: 0.4444 | Val: 0.1667 | Test: 0.3333
Epoch 003 | Loss: 0.6857 | Train: 0.5556 | Val: 0.3333 | Test: 0.3333
Epoch 004 | Loss: 0.6795 | Train: 0.5556 | Val: 0.3333 | Test: 0.3333
Epoch 005 | Loss: 0.7100 | Train: 0.5556 | Val: 0.3333 | Test: 0.3333
Epoch 006 | Loss: 0.6888 | Train: 0.5556 | Val: 0.3333 | Test: 0.3333
Early stopping.


In [11]:
# -------- Final evaluation on test set --------
model.load_state_dict(best_state)
model.eval()
with torch.no_grad():
    out = model(x, edge_index_list)
    pred = out.argmax(dim=-1)
    print("\nFinal Test accuracy:", (pred[test_idx] == y[test_idx]).float().mean().item())


Final Test accuracy: 0.3333333432674408
