In [38]:
import networkx as nx
import torch
from torch_geometric.data import Data
import numpy as np

# Load graph
G = nx.read_gexf("1ehz_graph.gexf")
# print(G.nodes)

# Create node feature matrix (use 3D coordinates as node features)
pos = []
x = []
base_map = {'A':0, 'U':1, 'G':2, 'C':3}
for node in G.nodes():
    attr = G.nodes[node]
    
    coord = (attr["x"], attr["y"], attr["z"])
    pos.append(coord)
    
    base = attr["nucleotide"]
    x.append([base_map.get(base, 0)])

pos = torch.tensor(pos, dtype=torch.float)
x = torch.tensor(x, dtype=torch.long)

node_map = {node: i for i, node in enumerate(G.nodes())}
# Create edge index
edges = []
for u, v in G.edges():
    edges.append([node_map[u], node_map[v]])
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()  # shape: [2, num_edges]

# Dummy target (e.g., regression)
y = torch.tensor([0.5])  # Placeholder

mean = pos.mean(dim=0, keepdim=True)
std  = pos.std(dim=0, keepdim=True) + 1e-8

y = (pos - mean) / std

data = Data(x=x, edge_index=edge_index, pos=pos, y=y)
data.y_mean = mean
data.y_std  = std

print(data)
print("y mean (should be ~0):", data.y.mean(dim=0))
print("y std  (should be ~1):", data.y.std(dim=0))


Data(x=[62, 1], edge_index=[2, 153], y=[62, 3], pos=[62, 3], y_mean=[1, 3], y_std=[1, 3])
y mean (should be ~0): tensor([-1.3459e-07,  9.9982e-08, -1.1536e-08])
y std  (should be ~1): tensor([1.0000, 1.0000, 1.0000])


In [39]:
print("Num nodes:", data.num_nodes, "Num edges:", data.num_edges)

Num nodes: 62 Num edges: 153


In [40]:
print(data.x.shape)
print(data.x[:5])

torch.Size([62, 1])
tensor([[2],
        [3],
        [2],
        [2],
        [0]])


In [24]:
print("Has y?", hasattr(data, "y"))
print("data.y =", data.y)
print("data keys:", data.keys)

Has y? True
data.y = tensor([[50.6260, 49.7300, 50.5730],
        [54.6350, 50.4200, 53.7410],
        [60.1840, 49.4190, 54.5740],
        [65.2950, 47.8680, 51.7930],
        [68.5300, 45.7220, 46.7890],
        [69.1500, 45.1790, 41.3920],
        [67.4630, 47.0740, 35.9690],
        [63.8840, 49.2820, 30.8580],
        [64.7550, 50.7310, 26.3110],
        [59.4740, 46.4180, 23.8180],
        [60.9760, 41.8530, 26.8740],
        [65.4980, 40.0950, 30.2510],
        [71.4990, 42.5820, 32.5850],
        [76.4630, 45.2270, 33.5600],
        [81.8040, 57.4910, 38.8030],
        [81.7050, 59.9350, 32.8550],
        [79.2120, 62.2700, 27.1690],
        [76.9600, 61.5410, 23.2130],
        [76.1220, 55.6070, 20.6240],
        [75.5710, 49.4040, 18.7840],
        [75.6910, 43.8140, 17.1700],
        [72.6600, 39.6990, 15.3920],
        [62.8600, 40.7090, 10.5690],
        [61.7150, 44.5510,  6.4290],
        [64.1340, 46.8100,  1.7150],
        [69.1480, 46.9860, -1.8750],
        [74.5300,

In [41]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class TinyForwardGNN(nn.Module):
    def __init__(self, hidden=64):
        super().__init__()
        self.conv1 = SAGEConv(4, hidden)     # 4 after one-hot
        self.conv2 = SAGEConv(hidden, hidden)
        self.lin   = nn.Linear(hidden, 3)

    def forward(self, data):
        # data.x: [N,1] integers -> [N,4] one-hot float
        x = F.one_hot(data.x.squeeze(-1), num_classes=4).float()
        edge_index = data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        return self.lin(x)  # [N,3]

In [42]:
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TinyForwardGNN(hidden=64).to(device)
data = data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
loss_fn = nn.MSELoss()

for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()

    pred = model(data)              # [N,3]
    loss = loss_fn(pred, data.y)    # [N,3] vs [N,3]

    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch:03d} | Loss: {loss.item():.6f}")


Epoch 020 | Loss: 0.811098
Epoch 040 | Loss: 0.638463
Epoch 060 | Loss: 0.509135
Epoch 080 | Loss: 0.387754
Epoch 100 | Loss: 0.291490
Epoch 120 | Loss: 0.219822
Epoch 140 | Loss: 0.168273
Epoch 160 | Loss: 0.130806
Epoch 180 | Loss: 0.101922
Epoch 200 | Loss: 0.079308


In [43]:
model.eval()
with torch.no_grad():
    pred = model(data)  # normalized space
    pred_un = pred * data.y_std + data.y_mean
    true_un = data.y * data.y_std + data.y_mean   # should equal pos

    # sanity check: true_un equals original pos
    print("max |true_un - pos|:", (true_un - data.pos).abs().max().item())

    per_node_rmse = torch.sqrt(((pred_un - true_un) ** 2).sum(dim=1)).mean().item()

print("Mean per-node RMSE (real units):", per_node_rmse)


max |true_un - pos|: 3.814697265625e-06
Mean per-node RMSE (real units): 4.920711040496826


In [44]:
with torch.no_grad():
    mean_pred = pos.mean(dim=0, keepdim=True).repeat(pos.size(0), 1)
    baseline_rmse = torch.sqrt(((mean_pred - pos) ** 2).sum(dim=1)).mean().item()
print("Baseline RMSE (predict mean):", baseline_rmse)

Baseline RMSE (predict mean): 21.004024505615234
