In [None]:
from knowledge_engine.mpgnn import MPGNN
import torch, torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

In [None]:

# Graph encoder: an MPGNN configured with 3 'GIN' layers, continuous node
# features of width x_dim, no edge features, and 256-d hidden/output sizes.
genc = MPGNN(
    nfeat_node=x_dim,
    nfeat_edge=0,
    nhid=256,
    nout=256,
    nlayer_gnn=3,
    node_type='continuous',
    edge_type='none',
    gnn_type='GIN',
    pooling='mean',
)
genc = genc.cuda()

# --- Relation embeddings for typed decoder ---
# One learnable 256-d vector per relation type (e.g., is_a, part_of, regulates).
n_rel = num_relation_types
rel_emb = nn.Embedding(n_rel, 256)
rel_emb = rel_emb.cuda()

# DistMult decoder
def score(u, r, v):
    """Return the DistMult score of (head, relation, tail) embeddings.

    The three inputs are multiplied element-wise and the product is summed
    over the last dimension, so the result has one fewer dimension than the
    (broadcast) inputs.

    Args:
        u: Head/subject embeddings.
        r: Relation embeddings.
        v: Tail/object embeddings.

    Returns:
        The trilinear score for each triple.
    """
    relation_scaled = u * r
    return (relation_scaled * v).sum(-1)

# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
bce = nn.BCEWithLogitsLoss()
# A single AdamW optimizer updates the encoder and the relation embeddings.
optim = torch.optim.AdamW(
    [*genc.parameters(), *rel_emb.parameters()],
    lr=3e-4,
)

# Build PyG Data
# edge_type: [E] int64
data = Data(x=node_feats, edge_index=edge_index, edge_type=edge_type)
# For huge ontologies, use NeighborLoader to sample subgraphs per step:
loader = NeighborLoader(
    data,
    num_neighbors=[20, 20],
    batch_size=4096,
    input_nodes=None,
    shuffle=True,
)

temp = 0.2  # InfoNCE temperature
def info_nce(anchor, pos, neg):  # all [B, D], neg [B, K, D]
    """Return the mean InfoNCE loss computed on cosine similarities.

    Each anchor is compared with its positive and its K negatives; the
    similarities are divided by the module-level ``temp`` and fed to a
    cross-entropy whose target class is always index 0 (the positive).

    Args:
        anchor: Anchor embeddings, shape [B, D].
        pos: Positive embeddings, shape [B, D].
        neg: Negative embeddings, shape [B, K, D].

    Returns:
        Scalar tensor holding the loss averaged over the B anchors.
    """
    B = anchor.size(0)
    pos_sim = torch.cosine_similarity(anchor, pos, dim=-1) / temp  # [B]
    # Use cosine_similarity for the negatives too: dividing by raw norms
    # yields 0/0 = NaN when an anchor or negative embedding has zero norm,
    # whereas cosine_similarity clamps the norms with a small epsilon.
    # expand_as gives a broadcast view of the anchors (no copy).
    anchor_rep = anchor.unsqueeze(1).expand_as(neg)  # [B, K, D]
    neg_sim = torch.cosine_similarity(anchor_rep, neg, dim=-1) / temp  # [B, K]
    logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1)  # [B, 1+K]
    labels = torch.zeros(B, dtype=torch.long, device=anchor.device)
    return nn.CrossEntropyLoss()(logits, labels)

# One optimization pass over the neighbor-sampled subgraphs. Every step
# combines a DistMult link-prediction loss with a cosine InfoNCE term.
K = 5  # negatives sampled per edge for the contrastive term (loop-invariant)
for it, batch in enumerate(loader):
    batch = batch.to('cuda')
    ei = batch.edge_index  # [2, E_sub]
    if ei.size(1) == 0:
        # A sampled subgraph without edges produces empty score tensors, and
        # the mean-reduced losses over empty tensors are NaN; skip the step.
        continue

    # 1) Node embeddings on sampled subgraph
    # All-zero placeholder edge attributes, created on the batch's own device.
    H = genc.gnns(
        genc.input_encoder(batch.x),
        ei,
        edge_attr=torch.zeros(ei.size(1), 1, device=ei.device),
    )
    # H: [N_sub, 256]; we won't pool (we need per-node)

    # 2) Gather positive edges in the sampled subgraph
    et = batch.edge_type   # [E_sub]
    u, v = ei[0], ei[1]    # child,parent or subject,object
    r = rel_emb(et)

    # 3) Negative sampling (corrupt tail)
    # Tails are drawn uniformly from the subgraph's nodes, so a draw can
    # coincide with a true tail.
    v_neg = torch.randint_like(v, high=H.size(0))
    # (optionally avoid sampling true neighbors)

    # 4) DistMult LP loss
    pos_s = score(H[u], r, H[v])
    neg_s = score(H[u], r, H[v_neg])
    y_pos = torch.ones_like(pos_s)
    y_neg = torch.zeros_like(neg_s)
    L_link = bce(torch.cat([pos_s, neg_s], 0), torch.cat([y_pos, y_neg], 0))

    # 5) Hierarchical contrastive (use parent as positive; random non-ancestors as negatives)
    # Build a small negative set per u (sample K random nodes)
    neg_idx = torch.randint(0, H.size(0), (u.size(0), K), device=H.device)
    L_contrast = info_nce(H[u], H[v], H[neg_idx])  # cosine InfoNCE

    # Total objective: link prediction plus the contrastive term weighted 0.2.
    L = L_link + 0.2 * L_contrast
    optim.zero_grad(set_to_none=True)
    L.backward()
    optim.step()

# After training, run a full forward once to export per-node embeddings.
# Switch to eval mode first so train-only behavior (e.g. dropout, batch-norm
# statistics updates) is disabled; call genc.train() before resuming training.
genc.eval()
with torch.no_grad():
    data = data.to('cuda')
    # Use the same node-level path as the training loop (input_encoder ->
    # gnns) instead of genc(data): the encoder is configured with
    # pooling='mean', and the losses were optimized on these un-pooled
    # per-node outputs, so these are the embeddings that were actually trained.
    H_full = genc.gnns(
        genc.input_encoder(data.x),
        data.edge_index,
        edge_attr=torch.zeros(
            data.edge_index.size(1), 1, device=data.edge_index.device
        ),
    )
torch.save(H_full.detach().cpu(), "ontology_embeddings.pt")
