In [1]:
# Full Python Script: BERT Embeddings + PyG Heterogeneous Graph for R-GAT

import torch
from transformers import BertTokenizer, BertModel
from torch_geometric.data import HeteroData

# 1. Load BERT
bert_model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
model = BertModel.from_pretrained(bert_model_name)

# 2. Sample instance data aligned to schema
instance_data = {
    "type": "Person",
    "name": "Ada Lovelace",
    "birthdate": "1815-12-10",
    "address": {
        "city": "London",
        "country": "United Kingdom"
    }
}

# 3. Function to convert value to BERT embedding
@torch.no_grad()
def get_bert_embedding(text: str) -> torch.Tensor:
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=16)
    outputs = model(**inputs)
    return outputs.last_hidden_state[:, 0, :]  # CLS token

# 4. Build node feature dictionary
node_features = {
    "name_value": get_bert_embedding(instance_data["name"]),
    "birthdate_value": get_bert_embedding(instance_data["birthdate"]),
    "city_value": get_bert_embedding(instance_data["address"]["city"]),
    "country_value": get_bert_embedding(instance_data["address"]["country"])
}

# 5. Define heterogeneous graph
data = HeteroData()

# Add nodes with features
for node_name, embedding in node_features.items():
    node_type = "Value"
    data[node_type].x = torch.cat([data[node_type].x, embedding], dim=0) if node_type in data else embedding

# Map node names to indices in each type
value_node_index = {name: i for i, name in enumerate(node_features.keys())}

# Add typed nodes for fields
fields = ["name", "birthdate", "city", "country"]
data["Field"].x = torch.eye(len(fields))  # one-hot for demo
field_index = {f: i for i, f in enumerate(fields)}

# Add schema node
data["Schema"].x = torch.tensor([[1.0]])  # dummy feature

# 6. Define edges and types
# Connect schema to fields
schema_edges = [(0, field_index[f]) for f in fields]
data["Schema", "has_field", "Field"].edge_index = torch.tensor(schema_edges).t().contiguous()

# Connect fields to values
field_value_edges = [
    (field_index["name"], value_node_index["name_value"]),
    (field_index["birthdate"], value_node_index["birthdate_value"]),
    (field_index["city"], value_node_index["city_value"]),
    (field_index["country"], value_node_index["country_value"]),
]
data["Field", "field_value", "Value"].edge_index = torch.tensor(field_value_edges).t().contiguous()

# Reverse relations (optional for R-GAT)
data["Field", "rev_has_field", "Schema"].edge_index = data["Schema", "has_field", "Field"].edge_index.flip(0)
data["Value", "rev_field_value", "Field"].edge_index = data["Field", "field_value", "Value"].edge_index.flip(0)

# 7. Print graph structure
print(data)
print("Schema -> Field edges:", data["Schema", "has_field", "Field"].edge_index)
print("Field -> Value edges:", data["Field", "field_value", "Value"].edge_index)
print("Value node feature shape:", data["Value"].x.shape)
print("Field node feature shape:", data["Field"].x.shape)
print("Schema node feature shape:", data["Schema"].x.shape)

# This graph is now ready for use in a relational GNN (R-GAT, HGT, etc.)


  from .autonotebook import tqdm as notebook_tqdm


HeteroData(
  Value={ x=[1, 768] },
  Field={ x=[4, 4] },
  Schema={ x=[1, 1] },
  (Schema, has_field, Field)={ edge_index=[2, 4] },
  (Field, field_value, Value)={ edge_index=[2, 4] },
  (Field, rev_has_field, Schema)={ edge_index=[2, 4] },
  (Value, rev_field_value, Field)={ edge_index=[2, 4] }
)
Schema -> Field edges: tensor([[0, 0, 0, 0],
        [0, 1, 2, 3]])
Field -> Value edges: tensor([[0, 1, 2, 3],
        [0, 1, 2, 3]])
Value node feature shape: torch.Size([1, 768])
Field node feature shape: torch.Size([4, 4])
Schema node feature shape: torch.Size([1, 1])
