<a href="https://colab.research.google.com/github/draginverse/dragin-healthcare/blob/feature%2Fg-retriever/scripts/retrieval/gretriever.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
!pip install torch_geometric
!pip install pcst_fast
!pip install torch_scatter -f https://data.pyg.org/
!pip install datasets

Looking in links: https://data.pyg.org/


# simulate graphs input

In [44]:
toy_graphs = [
    [
        ("asthma", "caused_by", "allergens"),
        ("inhaler", "treats", "asthma"),
        ("asthma", "symptom", "shortness of breath")
    ],
    [
        ("asthma", "treated_by", "inhaler"),
        ("inhaler", "treats", "asthma"),
        ("asthma", "symptom", "shortness of breath")
    ],
    [
        ("copd", "risk_factor", "smoking"),
        ("oxygen therapy", "treats", "copd"),
        ("copd", "symptom", "chronic cough")
    ],
    [
        ("bronchitis", "caused_by", "virus"),
        ("bronchitis", "symptom", "chest discomfort"),
        ("rest", "helps_with", "bronchitis")
    ],
    [
        ("pneumonia", "caused_by", "bacteria"),
        ("antibiotics", "treats", "pneumonia"),
        ("pneumonia", "symptom", "fever")
    ],
    [
        ("covid-19", "affects", "lungs"),
        ("vaccine", "prevents", "covid-19"),
        ("covid-19", "symptom", "loss of smell")
    ]
]


transform graphs to pyg (needed for the encoder)

In [45]:
from transformers import AutoModel, AutoTokenizer
import torch
from torch_geometric.data import Data, Batch

# 1. Load the MiniLM model
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# 2. Text to embedding function (optimized for MiniLM)
def text2embedding(texts):
    inputs = tokenizer(texts, padding=True, truncation=True,
                      max_length=128, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Use mean pooling instead of CLS token for better performance
    embeddings = outputs.last_hidden_state.mean(dim=1).cpu()
    return embeddings

# 3. Graph transformation function
def transform_graphs_to_pyg(triple_graphs):
    pyg_graphs = []

    for triples in triple_graphs:
        # Extract all node names
        node_names = set()
        for src, _, dst in triples:
            node_names.update([src.lower(), dst.lower()])
        node_names = sorted(node_names)
        node_map = {name: idx for idx, name in enumerate(node_names)}

        # Edge list and texts
        edge_list = []
        edge_texts = []
        for src, rel, dst in triples:
            edge_list.append([node_map[src.lower()], node_map[dst.lower()]])
            edge_texts.append(f"relation: {rel}")

        # Node embeddings
        node_texts = [f"node: {name}" for name in node_names]
        node_embeddings = text2embedding(node_texts)
        edge_embeddings = text2embedding(edge_texts) if edge_texts else torch.zeros(0, 384)

        # Create PyG graph
        pyg_graph = Data(
            x=node_embeddings,
            edge_index=torch.tensor(edge_list).t().contiguous(),
            edge_attr=edge_embeddings,
            num_nodes=len(node_names)
        )
        pyg_graphs.append(pyg_graph)

    return pyg_graphs

# ======== Verify consistency ===============
# Transform the graphs
pyg_graphs = transform_graphs_to_pyg(toy_graphs)

# Create a batch of graphs for processing
toy_graph_batch = Batch.from_data_list(pyg_graphs)

# Print information about the first graph
print("First graph in PyG format:")
print(pyg_graphs[0])
#print("\nNode mapping:", {name: idx for idx, name in enumerate(node_encoder.classes_)})
#print("Edge type mapping:", {name: idx for idx, name in enumerate(edge_type_encoder.classes_)})
print("\nBatch information:")
print(toy_graph_batch)
print("Batch vector:", toy_graph_batch.batch)

print("Total nodes:", toy_graph_batch.num_nodes)
print("Batch vector max index:", toy_graph_batch.batch.max())
print("Batch vector length:", len(toy_graph_batch.batch))

assert toy_graph_batch.batch.max() < len(pyg_graphs), "Batch indices exceed graph count"
assert len(toy_graph_batch.batch) == toy_graph_batch.num_nodes, "Batch vector length mismatch"

First graph in PyG format:
Data(x=[4, 384], edge_index=[2, 3], edge_attr=[3, 384], num_nodes=4)

Batch information:
DataBatch(x=[23, 384], edge_index=[2, 18], edge_attr=[18, 384], num_nodes=23, batch=[23], ptr=[7])
Batch vector: tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5])
Total nodes: 23
Batch vector max index: tensor(5)
Batch vector length: 23


# gnn.py

In [46]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TransformerConv, GATConv


class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout, num_heads=-1):
        super(GCN, self).__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels))
        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t, edge_attr):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x, edge_attr


class GraphTransformer(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout, num_heads=-1):
        super(GraphTransformer, self).__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(TransformerConv(in_channels=in_channels, out_channels=hidden_channels//num_heads, heads=num_heads, edge_dim=in_channels, dropout=dropout))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(TransformerConv(in_channels=hidden_channels, out_channels=hidden_channels//num_heads, heads=num_heads, edge_dim=in_channels, dropout=dropout,))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(TransformerConv(in_channels=hidden_channels, out_channels=out_channels//num_heads, heads=num_heads, edge_dim=in_channels, dropout=dropout,))
        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t, edge_attr):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index=adj_t, edge_attr=edge_attr)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index=adj_t, edge_attr=edge_attr)
        return x, edge_attr

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout, num_heads=4):
        super(GAT, self).__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(GATConv(in_channels, hidden_channels, heads=num_heads, concat=False))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(hidden_channels, hidden_channels, heads=num_heads, concat=False))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(GATConv(hidden_channels, out_channels, heads=num_heads, concat=False))
        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, edge_index, edge_attr):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index=edge_index, edge_attr=edge_attr)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x,edge_index=edge_index, edge_attr=edge_attr)
        return x, edge_attr


load_gnn_model = {
    'gcn': GCN,
    'gat': GAT,
    'gt': GraphTransformer,
}

# Training

In [36]:
from torch.optim import AdamW
from torch_geometric.loader import DataLoader
import copy

# New external training utilities
class GNNTrainingUtils:
    @staticmethod
    def train_model(model, train_graphs, train_targets, test_graphs, test_targets, config=None):
        """External training function that works with your original GraphEncoder"""
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        criterion = nn.MSELoss()
        optimizer = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])

        best_loss = float('inf')
        best_model = None

        config = config or {
            'batch_size': 32,
            'epochs': 10,
            'learning_rate': 1e-4,
            'weight_decay': 1e-5
        }


        for epoch in range(config['epochs']):
            model.train()
            train_loss = 0
            indices = torch.randperm(len(pyg_graphs))  # Random batch ordering
            for i in range(0, len(indices), config['batch_size']):
                batch_idx = indices[i:i+config['batch_size']].tolist()

                # Create batch from selected graphs
                batch_graphs = [pyg_graphs[idx] for idx in batch_idx]
                batch = Batch.from_data_list(batch_graphs).to(device)
                batch_targets = targets[batch_idx].to(device)

                # Training step
                optimizer.zero_grad()
                embeddings = model.encode(batch, training_mode=True)
                loss = criterion(embeddings, batch_targets)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            # Evaluation phase
            model.eval()
            test_loss = 0
            with torch.no_grad():
                test_batch = Batch.from_data_list(test_graphs).to(device)
                embeddings = model.encode(test_batch,training_mode=True)
                test_loss = criterion(embeddings, test_targets.to(device))

            # Track best model
            if test_loss < best_loss:
                best_loss = test_loss
                best_model = copy.deepcopy(model.state_dict())
            print(f"Epoch {epoch+1}, Loss: {train_loss/(len(indices)/config['batch_size']):.4f}")

        # Load best model weights
        model.load_state_dict(best_model)
        return model

    '''@staticmethod
    def evaluate(model, data_loader, criterion, device):
        """Evaluation function"""
        model.eval()
        total_loss = 0
        with torch.no_grad():
            for graphs, targets in data_loader:
                graphs, targets = graphs.to(device), targets.to(device)
                outputs = model.encode(graphs, training_mode=True)
                total_loss += criterion(outputs, targets).item()
        return total_loss / len(data_loader)'''

    @staticmethod
    def save_model(model, path):
        """Save helper that works with your original class"""
        torch.save(model.state_dict(), path)

    @staticmethod
    def load_model(model_class, args, path):
        """Load helper that works with your original class"""
        model = model_class(args)
        model.load_state_dict(torch.load(path))
        return model

# graph_encoder.py (embeds graphs)

In [29]:
import torch
import torch.nn as nn
from torch_scatter import scatter
#from gnn import load_gnn_model

class GraphEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.graph_encoder = load_gnn_model[args.gnn_model_name](
            in_channels=args.gnn_in_dim,
            out_channels=args.gnn_hidden_dim,
            hidden_channels=args.gnn_hidden_dim,
            num_layers=args.gnn_num_layers,
            dropout=args.gnn_dropout,
            num_heads=args.gnn_num_heads,
        )
        self.projector = nn.Sequential(
            nn.Linear(args.gnn_hidden_dim, 2048),
            nn.Sigmoid(),
            nn.Linear(2048, 4096),
        )

        # Add this adapter ONLY for training
        self.training_adapter = nn.Linear(4096, 384)  # Projects to MiniLM dimension

    def encode(self, graphs, training_mode=False):
        graphs = graphs.to(next(self.parameters()).device)
        n_embeds, _ = self.graph_encoder(graphs.x, graphs.edge_index.long(), graphs.edge_attr)
        g_embeds = scatter(n_embeds, graphs.batch, dim=0, reduce='mean')
        projected_embeds = self.projector(g_embeds)
        # Only use adapter during training
        if training_mode:
            return self.training_adapter(projected_embeds)
        return projected_embeds

# pcst_retrieval.py

In [None]:
import torch
import numpy as np
from pcst_fast import pcst_fast
from torch_geometric.data.data import Data

def retrieval_via_pcst(graph, q_emb, textual_nodes, textual_edges, topk=3, topk_e=3, cost_e=0.5):
    c = 0.01
    if len(textual_nodes) == 0 or len(textual_edges) == 0:
        desc = textual_nodes.to_csv(index=False) + '\n' + textual_edges.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
        return graph, desc

    # === Project graph features to match query dim ===
    projection = nn.Linear(graph.x.size(1), q_emb.size(0), bias=False)
    with torch.no_grad():
        projected_node_x = projection(graph.x)           # shape: [num_nodes, 4096]
        projected_edge_attr = projection(graph.edge_attr)  # shape: [num_edges, 4096]

    # === Proceed with projected features ===
    if topk > 0:
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, projected_node_x)
        topk = min(topk, graph.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
    else:
        n_prizes = torch.zeros(graph.num_nodes)

    if topk_e > 0:
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, projected_edge_attr)
        topk_e = min(topk_e, e_prizes.unique().size(0))
        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            indices = e_prizes == topk_e_values[k]
            value = min((topk_e - k) / sum(indices), last_topk_e_value)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - c)
        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
    else:
        e_prizes = torch.zeros(graph.num_edges)

    # === Rest of the PCST logic remains unchanged ===
    costs = []
    edges = []
    vritual_n_prizes = []
    virtual_edges = []
    virtual_costs = []
    mapping_n = {}
    mapping_e = {}
    for i, (src, dst) in enumerate(graph.edge_index.T.numpy()):
        prize_e = e_prizes[i]
        if prize_e <= cost_e:
            mapping_e[len(edges)] = i
            edges.append((src, dst))
            costs.append(cost_e - prize_e)
        else:
            virtual_node_id = graph.num_nodes + len(vritual_n_prizes)
            mapping_n[virtual_node_id] = i
            virtual_edges.append((src, virtual_node_id))
            virtual_edges.append((virtual_node_id, dst))
            virtual_costs.append(0)
            virtual_costs.append(0)
            vritual_n_prizes.append(prize_e - cost_e)

    prizes = np.concatenate([n_prizes, np.array(vritual_n_prizes)])
    num_edges = len(edges)
    if len(virtual_costs) > 0:
        costs = np.array(costs + virtual_costs)
        edges = np.array(edges + virtual_edges)

    vertices, edges = pcst_fast(edges, prizes, costs, -1, 1, 'gw', 0)

    selected_nodes = vertices[vertices < graph.num_nodes]
    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
    virtual_vertices = vertices[vertices >= graph.num_nodes]
    if len(virtual_vertices) > 0:
        virtual_edges = [mapping_n[i] for i in virtual_vertices]
        selected_edges = np.array(selected_edges + virtual_edges)

    edge_index = graph.edge_index[:, selected_edges]
    selected_nodes = np.unique(np.concatenate([selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))

    n = textual_nodes.iloc[selected_nodes]
    e = textual_edges.iloc[selected_edges]
    desc = n.to_csv(index=False) + '\n' + e.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])

    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
    x = graph.x[selected_nodes]
    edge_attr = graph.edge_attr[selected_edges]
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]
    edge_index = torch.LongTensor([src, dst])
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, num_nodes=len(selected_nodes))

    return data, desc

# compute graph embeddings & retrieve the most relevant graph:

train GNN before retrieving

In [50]:
from types import SimpleNamespace
from torch.utils.data import TensorDataset
from sklearn.model_selection import train_test_split

# Training
# 2. Create targets (mean node features for each graph)
targets = torch.stack([g.x.mean(dim=0) for g in pyg_graphs])  # Graph-level targets

# 2. Split data (80% train, 20% test)
train_graphs, test_graphs, train_targets, test_targets = train_test_split(
    pyg_graphs, targets, test_size=0.2, random_state=42
)

graph_encoder_args = SimpleNamespace(
        gnn_model_name="gat",
        gnn_in_dim=384,
        gnn_hidden_dim=128,
        gnn_num_layers=2,
        gnn_dropout=0.1,
        gnn_num_heads=4,
    )

# Initialize fresh model
graph_encoder = GraphEncoder(graph_encoder_args)

# Train the model
trained_encoder = GNNTrainingUtils.train_model(
    graph_encoder,
    train_graphs, train_targets,
    test_graphs, test_targets,
    config={
        'batch_size': 2,
        'epochs': 100,
        'learning_rate': 1e-4,
        'weight_decay': 1e-5
    }
)

# Final evaluation (example)
'''
test_batch = Batch.from_data_list(test_graphs).to(device)
with torch.no_grad():
    embeddings = trained_encoder.encode(test_batch)
    test_loss = nn.MSELoss()(embeddings, test_targets.to(device))
print(f"\nFinal Test Loss: {test_loss:.4f}")'''

# Save the trained model
GNNTrainingUtils.save_model(trained_encoder, "/content/pretrained.pth")


Epoch 1, Loss: 0.0670
Epoch 2, Loss: 0.0551
Epoch 3, Loss: 0.0296
Epoch 4, Loss: 0.0309
Epoch 5, Loss: 0.0325
Epoch 6, Loss: 0.0230
Epoch 7, Loss: 0.0216
Epoch 8, Loss: 0.0243
Epoch 9, Loss: 0.0220
Epoch 10, Loss: 0.0191
Epoch 11, Loss: 0.0212
Epoch 12, Loss: 0.0206
Epoch 13, Loss: 0.0184
Epoch 14, Loss: 0.0201
Epoch 15, Loss: 0.0189
Epoch 16, Loss: 0.0180
Epoch 17, Loss: 0.0180
Epoch 18, Loss: 0.0174
Epoch 19, Loss: 0.0167
Epoch 20, Loss: 0.0160
Epoch 21, Loss: 0.0158
Epoch 22, Loss: 0.0144
Epoch 23, Loss: 0.0157
Epoch 24, Loss: 0.0129
Epoch 25, Loss: 0.0122
Epoch 26, Loss: 0.0142
Epoch 27, Loss: 0.0112
Epoch 28, Loss: 0.0135
Epoch 29, Loss: 0.0125
Epoch 30, Loss: 0.0125
Epoch 31, Loss: 0.0111
Epoch 32, Loss: 0.0086
Epoch 33, Loss: 0.0086
Epoch 34, Loss: 0.0076
Epoch 35, Loss: 0.0080
Epoch 36, Loss: 0.0078
Epoch 37, Loss: 0.0081
Epoch 38, Loss: 0.0070
Epoch 39, Loss: 0.0074
Epoch 40, Loss: 0.0059
Epoch 41, Loss: 0.0073
Epoch 42, Loss: 0.0058
Epoch 43, Loss: 0.0080
Epoch 44, Loss: 0.00

In [40]:
import torch
from sentence_transformers import SentenceTransformer
from types import SimpleNamespace
#from src.model.graph_encoder import GraphEncoder
#from src.model.gnn import load_gnn_model

def retrieve_relevant_graph(graphs, query, topk=1):
    # 1. Load text encoder
    text_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    # 2. Define GNN parameters
    graph_encoder_args = SimpleNamespace(
        gnn_model_name="gat",
        gnn_in_dim=384,
        gnn_hidden_dim=128,
        gnn_num_layers=2,
        gnn_dropout=0.1,
        gnn_num_heads=4,
    )

    # 3. Create the graph encoder
    # graph_encoder = GraphEncoder(graph_encoder_args)
    # or load a pretrained model
    graph_encoder = GNNTrainingUtils.load_model(GraphEncoder, graph_encoder_args, "/content/pretrained.pth")

    graph_encoder.eval()

    # 4. Encode all graphs
    pyg_graphs = transform_graphs_to_pyg(graphs)
    batch = Batch.from_data_list(pyg_graphs)
    with torch.no_grad():
        graph_reprs = graph_encoder.encode(batch)  # shape: [num_graphs, 4096]

    # 5. Encode query and project
    text_projection = nn.Linear(384, 4096)
    with torch.no_grad():
        q_emb = text_encoder.encode(query, convert_to_tensor=True)
        q_emb = text_projection(q_emb)  # shape: [4096]

    # 6. Compute similarities
    sims = torch.nn.functional.cosine_similarity(q_emb.unsqueeze(0), graph_reprs)

    # 7. Retrieve top-k indices and scores
    top_scores, top_indices = torch.topk(sims, k=topk)
    selected_graphs = [graphs[i] for i in top_indices.tolist()]

    return selected_graphs, top_indices.tolist(), sims, q_emb


retrieve once

In [51]:
query = "how can i treat asthma?"
# Get the most similar graph
selected_graphs, top_indices, sims, q_emb = retrieve_relevant_graph(toy_graphs, query)
#print(f"Most relevant graph index: {top_idx}")
#print(selected_graph)

matched_graph = toy_graphs[top_indices[0]]
# Print with formatting
print("=== Most Similar Graph ===")
print(f"Match Score: {sims[top_indices[0]]:.3f}")
print(f"Graph : {matched_graph}")

=== Most Similar Graph ===
Match Score: 0.016
Graph : [('covid-19', 'affects', 'lungs'), ('vaccine', 'prevents', 'covid-19'), ('covid-19', 'symptom', 'loss of smell')]


retrieve 50x for testing

In [52]:
from collections import defaultdict
import numpy as np

query = "how can i treat asthma?"
number_of_retrievals = 50
top_k = 1

# Run retrievals
all_results = []
for _ in range(number_of_retrievals):
    _, top_indices, sims, _ = retrieve_relevant_graph(toy_graphs, query, topk=top_k)
    for idx in top_indices:
        all_results.append((idx, sims[idx].item()))

# Analyze
stats = defaultdict(list)
for idx, score in all_results:
    stats[idx].append(score)

print("\n=== Detailed Statistics ===")
for idx in sorted(stats.keys(), key=lambda x: -np.mean(stats[x])):
    scores = stats[idx]
    graph = toy_graphs[idx]
    print(f"\nGraph {idx}:")
    print(f"  Frequency: {len(scores)}/{number_of_retrievals * top_k}")
    print(f"  Avg Score: {np.mean(scores):.3f} ± {np.std(scores):.3f}")
    print(f"  Preview  : {graph[0]}...")


=== Detailed Statistics ===

Graph 0:
  Frequency: 1/50
  Avg Score: 0.024 ± 0.000
  Preview  : ('asthma', 'caused_by', 'allergens')...

Graph 5:
  Frequency: 3/50
  Avg Score: 0.019 ± 0.003
  Preview  : ('covid-19', 'affects', 'lungs')...

Graph 2:
  Frequency: 11/50
  Avg Score: 0.014 ± 0.018
  Preview  : ('copd', 'risk_factor', 'smoking')...

Graph 4:
  Frequency: 16/50
  Avg Score: 0.010 ± 0.011
  Preview  : ('pneumonia', 'caused_by', 'bacteria')...

Graph 3:
  Frequency: 13/50
  Avg Score: 0.005 ± 0.017
  Preview  : ('bronchitis', 'caused_by', 'virus')...

Graph 1:
  Frequency: 6/50
  Avg Score: 0.003 ± 0.011
  Preview  : ('asthma', 'treated_by', 'inhaler')...


# extract subgraph

In [None]:
import pandas as pd

# 4. Extract subgraphs via PCST
selected_subgraphs = []
descriptions = []
topk_nodes=5
topk_edges=3
cost_e=0.5

for idx in top_indices:
    graph = pyg_graphs[idx]
    triples = toy_graphs[idx]  # List of (src, rel, dst)

    # Unique nodes
    nodes = sorted(set(n.lower() for triple in triples for n in (triple[0], triple[2])))

    textual_nodes = pd.DataFrame({'node': [f"node: {name}" for name in nodes]})
    textual_edges = pd.DataFrame(triples, columns=["src", "edge_attr", "dst"])
    textual_edges["edge_attr"] = textual_edges["edge_attr"].apply(lambda x: f"relation: {x}")

    subgraph, desc = retrieval_via_pcst(
        graph, q_emb, textual_nodes, textual_edges,
        topk=topk_nodes, topk_e=topk_edges, cost_e=cost_e
    )
    selected_subgraphs.append(subgraph)
    descriptions.append(desc)

print(selected_subgraphs)
print(descriptions)

print("\n=== PCST Subgraph Summary ===")
print(f"Nodes: {subgraph.num_nodes}")
print(f"Edges: {subgraph.edge_index.size(1)}")

# Convert edge_index to readable form
print("\nSelected Triples:")
for i in range(subgraph.edge_index.size(1)):
    src_idx = subgraph.edge_index[0, i].item()
    dst_idx = subgraph.edge_index[1, i].item()
    edge_vec = subgraph.edge_attr[i]
    # Try to find matching textual triple (fallback to index if needed)
    try:
        src = textual_nodes.iloc[src_idx]["node"]
        dst = textual_nodes.iloc[dst_idx]["node"]
        rel = textual_edges.iloc[i]["edge_attr"]
    except:
        src, dst, rel = src_idx, dst_idx, "[vector]"
    print(f"  ({src}) --[{rel}]--> ({dst})")

# Print node names
print("\nIncluded Nodes:")
for i in range(subgraph.num_nodes):
    node_name = textual_nodes.iloc[i]["node"]
    print(f"  - {node_name}")

[Data(x=[1, 384], edge_index=[2, 0], edge_attr=[0, 384], num_nodes=1)]
['node\nnode: covid-19\n\nsrc,edge_attr,dst\n']

=== PCST Subgraph Summary ===
Nodes: 1
Edges: 0

Selected Triples:

Included Nodes:
  - node: covid-19
