In [1]:
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.nn import Node2Vec, GCNConv
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE



In [7]:
# Generate a synthetic graph: 200 nodes (100 lncRNA + 100 gene), 1000 edges,
# and random binary labels.
num_lnc = 100
num_gene = 100
lnc_nodes = [f'lncRNA{i}' for i in range(num_lnc)]
gene_nodes = [f'gene{i}' for i in range(num_gene)]
all_nodes = lnc_nodes + gene_nodes

num_edges = 1000
np.random.seed(42)
# Sources and targets are drawn independently and with replacement, so the edge
# list may contain self-loops and duplicate edges.
edge_sources = np.random.choice(all_nodes, num_edges)
edge_targets = np.random.choice(all_nodes, num_edges)
edge_df = pd.DataFrame({'source': edge_sources, 'target': edge_targets})

labels = np.random.choice([0, 1], size=len(all_nodes))
label_df = pd.DataFrame({'node': all_nodes, 'label': labels})

# Node index mapping: node name -> row position in `all_nodes`.
node2idx = {name: i for i, name in enumerate(all_nodes)}
num_nodes = len(all_nodes)

source_idx = edge_df['source'].map(node2idx)
target_idx = edge_df['target'].map(node2idx)
# `.map` yields NaN for names missing from the mapping; fail loudly instead of
# letting NaN be cast to an arbitrary integer index.
assert source_idx.notna().all() and target_idx.notna().all(), \
    "edge list references a node that is not in all_nodes"

# Stack into ONE (2, num_edges) ndarray before handing it to torch: building a
# tensor from a Python list of ndarrays is the slow path PyTorch warns about.
edge_index = torch.tensor(
    np.stack([source_idx.to_numpy(), target_idx.to_numpy()]),
    dtype=torch.long,
)

# Labels (ordered like `all_nodes`) and a random 80/20 train/test split.
y_map = dict(zip(label_df['node'], label_df['label']))
y = torch.tensor([y_map[node] for node in all_nodes], dtype=torch.long)
idx = np.arange(num_nodes)
np.random.shuffle(idx)
train_size = int(0.8 * num_nodes)
train_idx = idx[:train_size]
test_idx = idx[train_size:]
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[train_idx] = True
test_mask[test_idx] = True

# Sanity checks: the split must partition the node set.
assert not (train_mask & test_mask).any(), "train and test masks overlap"
assert (train_mask | test_mask).all(), "some nodes are in neither split"

# Build PyG Data object.
# No node feature matrix exists yet at this point, so the node count is passed
# explicitly; otherwise PyG has to infer it from the largest index that occurs
# in `edge_index`, which under-counts whenever the highest-indexed nodes have
# no edges and would then disagree with the length of `y` and the masks.
data = Data(
    edge_index=edge_index,
    y=y,
    train_mask=train_mask,
    test_mask=test_mask,
    num_nodes=num_nodes,
)


In [8]:
# Train Node2Vec to extract structure features

# Seed PyTorch's global RNG as well: the data cell only seeds NumPy, so without
# this the embedding initialisation and batch shuffling differ on every run.
torch.manual_seed(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
node2vec = Node2Vec(
    data.edge_index, embedding_dim=64,
    walk_length=10, context_size=5, walks_per_node=5,
    num_negative_samples=1, p=1, q=1, sparse=True,
    # One embedding row per labelled node, even if the highest-indexed nodes
    # never appear in `edge_index`.
    num_nodes=data.y.size(0),
).to(device)

# Each batch yields positive and negative random-walk samples for 128 start nodes.
loader = node2vec.loader(batch_size=128, shuffle=True, num_workers=0)
# sparse=True above produces sparse embedding gradients, hence SparseAdam.
optimizer = torch.optim.SparseAdam(list(node2vec.parameters()), lr=0.01)

def train_node2vec():
    """Run one epoch of Node2Vec training over `loader`.

    Uses the module-level `node2vec`, `loader`, `optimizer` and `device`.

    Returns:
        float: the mean per-batch loss for this epoch.
    """
    node2vec.train()
    batch_losses = []
    for walks_pos, walks_neg in tqdm(loader):
        optimizer.zero_grad()
        batch_loss = node2vec.loss(walks_pos.to(device), walks_neg.to(device))
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

print("Training Node2Vec...")
# 20 epochs is enough for this demonstration.
node2vec_epochs = range(1, 21)
for epoch in node2vec_epochs:
    loss = train_node2vec()
    print(f'Node2Vec Epoch: {epoch:02d}, Loss: {loss:.4f}')

@torch.no_grad()
def get_embeddings():
    """Return the current Node2Vec embedding of every node as a CPU tensor.

    Switches the module-level `node2vec` model to eval mode and looks up the
    embeddings for node ids 0 .. data.num_nodes - 1. Runs without gradient
    tracking.
    """
    node2vec.eval()
    all_node_ids = torch.arange(data.num_nodes, device=device)
    return node2vec(all_node_ids).cpu()

# Attach the learned Node2Vec embeddings to the graph as its node feature matrix
node2vec_features = get_embeddings()
data.x = node2vec_features


Training Node2Vec...


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 71.91it/s]


Node2Vec Epoch: 01, Loss: 6.8641


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 70.19it/s]


Node2Vec Epoch: 02, Loss: 6.4812


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 68.28it/s]


Node2Vec Epoch: 03, Loss: 6.2985


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 69.32it/s]


Node2Vec Epoch: 04, Loss: 5.9819


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 69.01it/s]


Node2Vec Epoch: 05, Loss: 5.8813


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 63.21it/s]


Node2Vec Epoch: 06, Loss: 5.6732


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 66.30it/s]


Node2Vec Epoch: 07, Loss: 5.4141


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 68.81it/s]


Node2Vec Epoch: 08, Loss: 5.2935


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 74.05it/s]


Node2Vec Epoch: 09, Loss: 5.1119


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 68.81it/s]


Node2Vec Epoch: 10, Loss: 4.9524


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 54.41it/s]


Node2Vec Epoch: 11, Loss: 4.7886


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 68.40it/s]


Node2Vec Epoch: 12, Loss: 4.6489


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 61.95it/s]


Node2Vec Epoch: 13, Loss: 4.4421


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 69.05it/s]


Node2Vec Epoch: 14, Loss: 4.3889


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 76.78it/s]


Node2Vec Epoch: 15, Loss: 4.2375


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 72.52it/s]


Node2Vec Epoch: 16, Loss: 4.1236


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 72.58it/s]


Node2Vec Epoch: 17, Loss: 4.0086


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 74.63it/s]


Node2Vec Epoch: 18, Loss: 3.9492


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 82.44it/s]


Node2Vec Epoch: 19, Loss: 3.8329


100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 69.00it/s]

Node2Vec Epoch: 20, Loss: 3.7484





In [9]:
# Inspect the node feature matrix (the Node2Vec embeddings assigned to data.x).
print(data.x)

tensor([[-0.5938, -0.7929,  1.2026,  ..., -1.4360,  1.0293,  0.5700],
        [-1.0564, -0.1794, -0.8746,  ..., -0.1808,  0.1863, -0.6419],
        [ 0.0204, -0.4464, -0.8447,  ..., -0.1966,  0.7166,  1.4012],
        ...,
        [ 0.7138, -0.5813,  0.5633,  ...,  0.0350, -0.2183, -0.1386],
        [-0.6607,  0.0411,  1.2769,  ...,  0.3121,  0.5011, -0.0204],
        [-1.4351,  0.0510, -0.7352,  ...,  0.7844,  0.0200, -0.1610]])


In [5]:
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network that outputs per-node scores.

    Args:
        in_channels: size of each input node feature vector.
        hidden_channels: width of the hidden representation after the first layer.
        out_channels: size of the output per node (one score per class).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2 and return the raw (un-normalised) output."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

# Size the network from the data: input width is the feature dimension, output
# width is the largest label value plus one.
num_features = data.x.shape[1]
num_classes = int(data.y.max().item()) + 1
model_gcn = GCN(num_features, 32, num_classes).to(device)

# Move the graph to the same device as the model.
data = data.to(device)

optimizer_gcn = torch.optim.Adam(model_gcn.parameters(), lr=0.01, weight_decay=5e-4)

def train_gcn():
    """Perform one full-batch optimisation step of `model_gcn`.

    The loss is the cross-entropy on the nodes selected by `data.train_mask`.

    Returns:
        float: the training loss computed before the parameter update.
    """
    model_gcn.train()
    optimizer_gcn.zero_grad()
    train_nodes = data.train_mask
    logits = model_gcn(data.x, data.edge_index)
    train_loss = F.cross_entropy(logits[train_nodes], data.y[train_nodes])
    train_loss.backward()
    optimizer_gcn.step()
    return train_loss.item()

@torch.no_grad()
def test_gcn():
    model_gcn.eval()
    out = model_gcn(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
    return acc

print("Training GCN with Node2Vec features...")
# 50 full-batch epochs; test accuracy is reported after every update.
gcn_epochs = range(1, 51)
for epoch in gcn_epochs:
    loss, acc = train_gcn(), test_gcn()
    print(f'GCN Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {acc:.4f}')


Training GCN with Node2Vec features...
GCN Epoch: 001, Loss: 0.7206, Test Acc: 0.5000
GCN Epoch: 002, Loss: 0.6745, Test Acc: 0.6250
GCN Epoch: 003, Loss: 0.6509, Test Acc: 0.5750
GCN Epoch: 004, Loss: 0.6328, Test Acc: 0.6250
GCN Epoch: 005, Loss: 0.6126, Test Acc: 0.7000
GCN Epoch: 006, Loss: 0.5910, Test Acc: 0.6000
GCN Epoch: 007, Loss: 0.5708, Test Acc: 0.6000
GCN Epoch: 008, Loss: 0.5533, Test Acc: 0.6000
GCN Epoch: 009, Loss: 0.5380, Test Acc: 0.5750
GCN Epoch: 010, Loss: 0.5230, Test Acc: 0.5750
GCN Epoch: 011, Loss: 0.5078, Test Acc: 0.6000
GCN Epoch: 012, Loss: 0.4921, Test Acc: 0.6250
GCN Epoch: 013, Loss: 0.4764, Test Acc: 0.6250
GCN Epoch: 014, Loss: 0.4616, Test Acc: 0.6250
GCN Epoch: 015, Loss: 0.4477, Test Acc: 0.6250
GCN Epoch: 016, Loss: 0.4342, Test Acc: 0.6500
GCN Epoch: 017, Loss: 0.4208, Test Acc: 0.6500
GCN Epoch: 018, Loss: 0.4072, Test Acc: 0.6500
GCN Epoch: 019, Loss: 0.3938, Test Acc: 0.6750
GCN Epoch: 020, Loss: 0.3809, Test Acc: 0.6750
GCN Epoch: 021, Loss: