<a href="https://colab.research.google.com/github/draginverse/dragin-healthcare/blob/feature%2Fg-retriever/scripts/retrieval/gretriever.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch_geometric
!pip install pcst_fast
!pip install torch_scatter -f https://data.pyg.org/
!pip install datasets

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m15.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1
Collecting pcst_fast
  Downloading pcst_fast-1.0.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting pybind11>=2.1.0 (from pcst_fast)
  Downloading pybind11-2.13.6-py3-none-any.whl.metadata (9.5 kB)
Downloading pcst_fast-1.0.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [

# Simulate the input graphs

In [2]:
# Toy knowledge graphs used as retrieval candidates.
# Each graph is a list of (source, relation, target) string triples; the
# position of a graph in this list is the index reported by the retrieval cells.
toy_graphs = [
    [
        ("asthma", "caused_by", "allergens"),
        ("inhaler", "treats", "asthma"),
        ("asthma", "symptom", "shortness of breath")
    ],
    [
        ("copd", "risk_factor", "smoking"),
        ("oxygen therapy", "treats", "copd"),
        ("copd", "symptom", "chronic cough")
    ],
    [
        ("bronchitis", "caused_by", "virus"),
        ("bronchitis", "symptom", "chest discomfort"),
        ("rest", "helps_with", "bronchitis")
    ],
    [
        ("pneumonia", "caused_by", "bacteria"),
        ("antibiotics", "treats", "pneumonia"),
        ("pneumonia", "symptom", "fever")
    ],
    [
        ("covid-19", "affects", "lungs"),
        ("vaccine", "prevents", "covid-19"),
        ("covid-19", "symptom", "loss of smell")
    ]
]


Transform the triple graphs into PyG `Data` objects (the format the graph encoder expects)

In [4]:
from transformers import AutoModel, AutoTokenizer
import torch
from torch_geometric.data import Data, Batch

# 1. Load the MiniLM checkpoint (tokenizer + encoder) and place the encoder
#    on the GPU when one is available, otherwise on the CPU.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)

# 2. Text to embedding function (optimized for MiniLM)
def text2embedding(texts):
    """Embed one string or a list of strings with the module-level MiniLM encoder.

    Uses the globals ``tokenizer``, ``model`` and ``device``.

    Args:
        texts: A string or a list/tuple of strings.

    Returns:
        A CPU float tensor with one row per input text. An empty list/tuple
        yields a ``[0, hidden_size]`` tensor without calling the model.
    """
    if isinstance(texts, (list, tuple)) and len(texts) == 0:
        return torch.zeros(0, model.config.hidden_size)

    inputs = tokenizer(texts, padding=True, truncation=True,
                       max_length=128, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean pooling over real tokens only. padding=True pads shorter texts up to
    # the longest one in the batch; a plain .mean(dim=1) would average those
    # padding positions in, making an embedding depend on its batch neighbours.
    token_embeddings = outputs.last_hidden_state
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
    return (summed / counts).cpu()

# 3. Graph transformation function
def transform_graphs_to_pyg(triple_graphs, embedding_dim=384):
    """Convert graphs given as (source, relation, target) triples into PyG ``Data`` objects.

    Node names are lower-cased, de-duplicated and sorted, so node indices are
    deterministic for a given graph. Nodes are embedded from ``"node: <name>"``
    and edges from ``"relation: <rel>"`` via ``text2embedding``.

    Args:
        triple_graphs: Iterable of graphs, each an iterable of
            ``(src, rel, dst)`` string triples.
        embedding_dim: Feature width used for the empty ``x`` / ``edge_attr``
            tensors of a graph without nodes or edges. Defaults to 384, the
            value previously hard-coded for the empty-edge case.

    Returns:
        List of ``Data`` objects (one per input graph) with ``x``,
        ``edge_index`` of shape ``[2, num_edges]``, ``edge_attr`` and ``num_nodes``.
    """
    pyg_graphs = []

    for triples in triple_graphs:
        triples = list(triples)  # allow one-shot iterables; we traverse it more than once

        node_names = sorted({name.lower() for src, _, dst in triples for name in (src, dst)})
        node_map = {name: idx for idx, name in enumerate(node_names)}

        edge_list = [[node_map[src.lower()], node_map[dst.lower()]] for src, _, dst in triples]
        edge_texts = [f"relation: {rel}" for _, rel, _ in triples]
        node_texts = [f"node: {name}" for name in node_names]

        # Do not send empty lists to the encoder; emit correctly shaped empty tensors instead.
        node_embeddings = text2embedding(node_texts) if node_texts else torch.zeros(0, embedding_dim)
        edge_embeddings = text2embedding(edge_texts) if edge_texts else torch.zeros(0, embedding_dim)

        # reshape(-1, 2) keeps the [2, num_edges] layout (and the long dtype)
        # even when the graph has no edges, where a bare torch.tensor([]) would
        # be a 1-D float tensor.
        edge_index = torch.tensor(edge_list, dtype=torch.long).reshape(-1, 2).t().contiguous()

        pyg_graphs.append(Data(
            x=node_embeddings,
            edge_index=edge_index,
            edge_attr=edge_embeddings,
            num_nodes=len(node_names),
        ))

    return pyg_graphs

# ======== Verify consistency ===============
# Sanity-check cell: convert the toy graphs, batch them, print the resulting
# shapes and assert that the batch vector is consistent with the graph list.

# Transform the graphs
pyg_graphs = transform_graphs_to_pyg(toy_graphs)

# Create a batch of graphs for processing
toy_graph_batch = Batch.from_data_list(pyg_graphs)

# Print information about the first graph
print("First graph in PyG format:")
print(pyg_graphs[0])
# Disabled debug prints: node_encoder / edge_type_encoder are not defined in the visible cells.
#print("\nNode mapping:", {name: idx for idx, name in enumerate(node_encoder.classes_)})
#print("Edge type mapping:", {name: idx for idx, name in enumerate(edge_type_encoder.classes_)})
print("\nBatch information:")
print(toy_graph_batch)
# batch[i] is the index of the graph that node i belongs to
print("Batch vector:", toy_graph_batch.batch)

print("Total nodes:", toy_graph_batch.num_nodes)
print("Batch vector max index:", toy_graph_batch.batch.max())
print("Batch vector length:", len(toy_graph_batch.batch))

# Every node must map to an existing graph, and there must be one batch entry per node.
assert toy_graph_batch.batch.max() < len(pyg_graphs), "Batch indices exceed graph count"
assert len(toy_graph_batch.batch) == toy_graph_batch.num_nodes, "Batch vector length mismatch"

First graph in PyG format:
Data(x=[4, 384], edge_index=[2, 3], edge_attr=[3, 384], num_nodes=4)

Batch information:
DataBatch(x=[20, 384], edge_index=[2, 15], edge_attr=[15, 384], num_nodes=20, batch=[20], ptr=[6])
Batch vector: tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4])
Total nodes: 20
Batch vector max index: tensor(4)
Batch vector length: 20


# gnn.py

In [5]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TransformerConv, GATConv


class GCN(torch.nn.Module):
    """Stack of GCNConv layers with BatchNorm, ReLU and dropout between them.

    The stack is: one input layer, ``num_layers - 2`` hidden layers (none when
    that is not positive) and one output layer. ``num_heads`` is accepted but
    not used by this class.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout, num_heads=-1):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()

        # Input block plus the hidden blocks: each convolution is paired with a
        # BatchNorm over its hidden_channels-wide output.
        width_in = in_channels
        for _ in range(1 + max(num_layers - 2, 0)):
            self.convs.append(GCNConv(width_in, hidden_channels))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
            width_in = hidden_channels

        # Output convolution; nothing is applied after it in forward().
        self.convs.append(GCNConv(hidden_channels, out_channels))
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every convolution, then every batch-norm layer."""
        for layer in (*self.convs, *self.bns):
            layer.reset_parameters()

    def forward(self, x, adj_t, edge_attr):
        """Return ``(node_embeddings, edge_attr)``; ``edge_attr`` is passed through unchanged."""
        *inner_convs, last_conv = self.convs
        for conv, bn in zip(inner_convs, self.bns):
            x = F.relu(bn(conv(x, adj_t)))
            x = F.dropout(x, p=self.dropout, training=self.training)
        return last_conv(x, adj_t), edge_attr


class GraphTransformer(torch.nn.Module):
    """Stack of TransformerConv layers with BatchNorm, ReLU and dropout between them.

    Every layer is built with ``edge_dim=in_channels``, so ``edge_attr`` is
    expected to have the same feature width as the input node features at
    every layer. Each layer's per-head width is ``channels // num_heads``.

    Args:
        in_channels: Width of the input node features (and of ``edge_attr``).
        hidden_channels: Width of the hidden representations; must be divisible
            by ``num_heads`` so the layer outputs match the BatchNorm layers.
        out_channels: Requested output width. The last layer is built with
            ``out_channels // num_heads`` channels per head.
        num_layers: One input layer, ``num_layers - 2`` hidden layers and one
            output layer are created.
        dropout: Dropout probability used inside the layers and between them.
        num_heads: Number of attention heads; must be a positive integer. The
            default of -1 is kept for signature compatibility only and is
            rejected, so callers must pass a real value.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_channels``. Such configurations previously built a model
            that could only fail later, inside ``forward``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout, num_heads=-1):
        super().__init__()
        if num_heads <= 0:
            raise ValueError(f"num_heads must be a positive integer, got {num_heads}")
        if hidden_channels % num_heads != 0:
            raise ValueError(
                f"hidden_channels ({hidden_channels}) must be divisible by num_heads ({num_heads})"
            )

        def make_conv(width_in, width_out):
            # width_out is split evenly across the heads.
            return TransformerConv(in_channels=width_in, out_channels=width_out // num_heads,
                                   heads=num_heads, edge_dim=in_channels, dropout=dropout)

        self.convs = torch.nn.ModuleList()
        self.convs.append(make_conv(in_channels, hidden_channels))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(make_conv(hidden_channels, hidden_channels))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(make_conv(hidden_channels, out_channels))
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every convolution and batch-norm layer."""
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t, edge_attr):
        """Return ``(node_embeddings, edge_attr)``; ``edge_attr`` itself is returned unchanged."""
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index=adj_t, edge_attr=edge_attr)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index=adj_t, edge_attr=edge_attr)
        return x, edge_attr

class GAT(torch.nn.Module):
    """Stack of GATConv layers with BatchNorm, ReLU and dropout between them.

    All layers use ``heads=num_heads`` with ``concat=False``. The stack is one
    input layer, ``num_layers - 2`` hidden layers and one output layer.

    NOTE(review): the GATConv layers are constructed without an ``edge_dim``
    argument although ``edge_attr`` is passed to them in ``forward`` — confirm
    whether edge features actually influence the output in this configuration.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout, num_heads=4):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()

        # Input block plus the hidden blocks, each paired with a BatchNorm.
        width_in = in_channels
        for _ in range(1 + max(num_layers - 2, 0)):
            self.convs.append(GATConv(width_in, hidden_channels, heads=num_heads, concat=False))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
            width_in = hidden_channels

        # Output layer; nothing is applied after it in forward().
        self.convs.append(GATConv(hidden_channels, out_channels, heads=num_heads, concat=False))
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise every convolution, then every batch-norm layer."""
        for layer in (*self.convs, *self.bns):
            layer.reset_parameters()

    def forward(self, x, edge_index, edge_attr):
        """Return ``(node_embeddings, edge_attr)``; ``edge_attr`` itself is returned unchanged."""
        *inner_convs, last_conv = self.convs
        for conv, bn in zip(inner_convs, self.bns):
            x = F.relu(bn(conv(x, edge_index=edge_index, edge_attr=edge_attr)))
            x = F.dropout(x, p=self.dropout, training=self.training)
        return last_conv(x, edge_index=edge_index, edge_attr=edge_attr), edge_attr


# Registry mapping a short model name to the GNN class it selects.
# GraphEncoder looks its backbone up here via args.gnn_model_name.
load_gnn_model = {
    'gcn': GCN,
    'gat': GAT,
    'gt': GraphTransformer,
}

# graph_encoder.py (embeds graphs)

In [6]:
import torch
import torch.nn as nn
from torch_scatter import scatter
#from gnn import load_gnn_model

class GraphEncoder(nn.Module):
    """GNN backbone plus an MLP projector that maps each graph to a 4096-d vector.

    Args:
        args: Object exposing ``gnn_model_name`` (a key of ``load_gnn_model``),
            ``gnn_in_dim``, ``gnn_hidden_dim``, ``gnn_num_layers``,
            ``gnn_dropout`` and ``gnn_num_heads``. The backbone's output width
            is ``gnn_hidden_dim``.
    """

    def __init__(self, args):
        super().__init__()
        self.graph_encoder = load_gnn_model[args.gnn_model_name](
            in_channels=args.gnn_in_dim,
            out_channels=args.gnn_hidden_dim,
            hidden_channels=args.gnn_hidden_dim,
            num_layers=args.gnn_num_layers,
            dropout=args.gnn_dropout,
            num_heads=args.gnn_num_heads,
        )
        # Projection head: gnn_hidden_dim -> 2048 -> 4096.
        self.projector = nn.Sequential(
            nn.Linear(args.gnn_hidden_dim, 2048),
            nn.Sigmoid(),
            nn.Linear(2048, 4096),
        )

    def encode(self, graphs):
        """Encode a batch of graphs (or a single graph) into projected embeddings.

        Args:
            graphs: Object with ``x``, ``edge_index``, ``edge_attr`` and
                optionally ``batch``, supporting ``.to(device)``. When
                ``batch`` is missing or ``None`` (a single un-batched graph),
                all nodes are treated as belonging to one graph.

        Returns:
            Tensor of shape ``[num_graphs, 4096]``: mean-pooled node embeddings
            passed through the projector.
        """
        graphs = graphs.to(next(self.parameters()).device)
        n_embeds, _ = self.graph_encoder(graphs.x, graphs.edge_index.long(), graphs.edge_attr)

        batch = getattr(graphs, "batch", None)
        if batch is None:
            # Single graph without a batch vector: pool every node into graph 0
            # instead of failing inside scatter().
            batch = torch.zeros(n_embeds.size(0), dtype=torch.long, device=n_embeds.device)

        g_embeds = scatter(n_embeds, batch, dim=0, reduce='mean')
        return self.projector(g_embeds)

# pcst_retrieval.py

In [None]:
import torch
import numpy as np
from pcst_fast import pcst_fast
from torch_geometric.data.data import Data


def retrieval_via_pcst(graph, q_emb, textual_nodes, textual_edges, topk=3, topk_e=3, cost_e=0.5):
    """Extract a query-relevant subgraph by solving a prize-collecting Steiner tree problem.

    Nodes and edges are scored by cosine similarity to ``q_emb``; the scores are
    turned into rank-based prizes, the instance is handed to ``pcst_fast`` and
    the selected nodes/edges are returned as a re-indexed subgraph together
    with a CSV-style textual description.

    Args:
        graph: Graph object with ``x``, ``edge_index``, ``edge_attr``,
            ``num_nodes`` and ``num_edges``. The ``.numpy()`` calls below imply
            its tensors live on the CPU.
        q_emb: Query embedding compared against ``graph.x`` and
            ``graph.edge_attr`` along the last dimension.
        textual_nodes: Table supporting ``len``, ``.iloc`` and ``.to_csv``
            (presumably a pandas DataFrame with one row per node — TODO confirm).
        textual_edges: Table like ``textual_nodes`` with at least the columns
            ``src``, ``edge_attr`` and ``dst``, one row per edge.
        topk: Number of top-ranked nodes that receive a prize; ``<= 0`` gives
            all nodes a zero prize.
        topk_e: Number of top-ranked distinct edge scores that receive a prize;
            ``<= 0`` gives all edges a zero prize.
        cost_e: Base cost of including an edge.

    Returns:
        Tuple ``(data, desc)``: ``data`` is a ``Data`` object holding the
        selected subgraph with nodes re-indexed from 0, ``desc`` is the node
        CSV followed by the edge CSV.
    """
    # Decay factor used to keep edge prizes strictly ordered and to nudge cost_e below the top prize.
    c = 0.01
    # Nothing to describe/select: return the whole graph unchanged (as a fresh Data object).
    if len(textual_nodes) == 0 or len(textual_edges) == 0:
        desc = textual_nodes.to_csv(index=False) + '\n' + textual_edges.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
        graph = Data(x=graph.x, edge_index=graph.edge_index, edge_attr=graph.edge_attr, num_nodes=graph.num_nodes)
        return graph, desc

    # Fixed solver settings passed to pcst_fast below.
    root = -1  # unrooted
    num_clusters = 1
    pruning = 'gw'
    verbosity_level = 0
    # ---- Node prizes: the topk most similar nodes get prizes topk, topk-1, ..., 1; all others 0.
    if topk > 0:
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, graph.x)
        topk = min(topk, graph.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
    else:
        n_prizes = torch.zeros(graph.num_nodes)

    # ---- Edge prizes: ranked over distinct similarity values, so tied edges share a rank.
    if topk_e > 0:
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, graph.edge_attr)
        topk_e = min(topk_e, e_prizes.unique().size(0))

        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
        # Edges scoring below the topk_e-th distinct value get no prize.
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            indices = e_prizes == topk_e_values[k]
            # Rank prize (topk_e - k) is split across tied edges and capped by
            # the previous rank's (slightly decayed) prize.
            value = min((topk_e-k)/sum(indices), last_topk_e_value)
            e_prizes[indices] = value
            last_topk_e_value = value*(1-c)
        # reduce the cost of the edges such that at least one edge is selected
        cost_e = min(cost_e, e_prizes.max().item()*(1-c/2))
    else:
        e_prizes = torch.zeros(graph.num_edges)

    # ---- Build the solver instance.
    # Edges whose prize exceeds cost_e are replaced by a "virtual" node carrying
    # the surplus prize, connected to both endpoints by zero-cost edges.
    costs = []
    edges = []
    vritual_n_prizes = []
    virtual_edges = []
    virtual_costs = []
    mapping_n = {}  # virtual node id -> original edge index
    mapping_e = {}  # position in `edges` -> original edge index
    for i, (src, dst) in enumerate(graph.edge_index.T.numpy()):
        prize_e = e_prizes[i]
        if prize_e <= cost_e:
            # Ordinary edge: pay the part of the cost not covered by its prize.
            mapping_e[len(edges)] = i
            edges.append((src, dst))
            costs.append(cost_e - prize_e)
        else:
            # Virtual node ids continue after the real node ids.
            virtual_node_id = graph.num_nodes + len(vritual_n_prizes)
            mapping_n[virtual_node_id] = i
            virtual_edges.append((src, virtual_node_id))
            virtual_edges.append((virtual_node_id, dst))
            virtual_costs.append(0)
            virtual_costs.append(0)
            vritual_n_prizes.append(prize_e - cost_e)

    prizes = np.concatenate([n_prizes, np.array(vritual_n_prizes)])
    # Number of ordinary edges; solver edge indices >= num_edges belong to virtual edges.
    num_edges = len(edges)
    if len(virtual_costs) > 0:
        costs = np.array(costs+virtual_costs)
        edges = np.array(edges+virtual_edges)

    # From here on `edges` holds the edge indices selected by the solver.
    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters, pruning, verbosity_level)

    # ---- Map the solution back to original node / edge indices.
    selected_nodes = vertices[vertices < graph.num_nodes]
    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
    virtual_vertices = vertices[vertices >= graph.num_nodes]
    if len(virtual_vertices) > 0:
        virtual_vertices = vertices[vertices >= graph.num_nodes]
        # Each selected virtual node stands for one original edge.
        virtual_edges = [mapping_n[i] for i in virtual_vertices]
        selected_edges = np.array(selected_edges+virtual_edges)

    edge_index = graph.edge_index[:, selected_edges]
    # Include the endpoints of every selected edge, not just the solver's vertices.
    selected_nodes = np.unique(np.concatenate([selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))

    # ---- Textual description of the subgraph.
    n = textual_nodes.iloc[selected_nodes]
    e = textual_edges.iloc[selected_edges]
    desc = n.to_csv(index=False)+'\n'+e.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])

    # ---- Re-index the selected nodes to 0..len(selected_nodes)-1 and build the subgraph.
    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}

    x = graph.x[selected_nodes]
    edge_attr = graph.edge_attr[selected_edges]
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]
    edge_index = torch.LongTensor([src, dst])
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, num_nodes=len(selected_nodes))

    return data, desc

# compute graph embeddings & retrieve the most relevant graph:

In [7]:
import torch
from sentence_transformers import SentenceTransformer
from types import SimpleNamespace
#from src.model.graph_encoder import GraphEncoder
#from src.model.gnn import load_gnn_model

def retrieve_relevant_graph(toy_graphs, query, text_encoder=None, graph_encoder=None, text_projection=None):
    """Return the graph whose encoded representation is most similar to ``query``.

    Args:
        toy_graphs: Non-empty list of graphs, each a list of
            ``(src, rel, dst)`` triples (the input of ``transform_graphs_to_pyg``).
        query: Query string.
        text_encoder: Optional object with
            ``encode(query, convert_to_tensor=True)``. Defaults to a freshly
            loaded ``all-MiniLM-L6-v2`` SentenceTransformer.
        graph_encoder: Optional module with ``eval()`` and ``encode(batch)``.
            Defaults to a newly constructed GAT-based ``GraphEncoder``.
        text_projection: Optional module mapping the query embedding into the
            graph-embedding space. Defaults to a new ``Linear(384, 4096)``.

    Returns:
        Tuple ``(selected_graph, top_idx, sims)``: the PyG graph built from
        ``toy_graphs[top_idx]``, its index, and the cosine similarity of the
        query to every graph.

    Raises:
        ValueError: If ``toy_graphs`` is empty.

    Note:
        The default ``graph_encoder`` and ``text_projection`` are randomly
        initialised on every call (no weights are loaded), so with the defaults
        the similarities are not meaningful and vary between calls. Pass
        trained modules to get a usable ranking and to avoid reloading models.
    """
    if len(toy_graphs) == 0:
        raise ValueError("toy_graphs must contain at least one graph")

    # 1. Text encoder
    if text_encoder is None:
        text_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    # 2. Graph encoder (default: 2-layer GAT over 384-d SBERT features)
    if graph_encoder is None:
        graph_encoder = GraphEncoder(SimpleNamespace(
            gnn_model_name="gat",
            gnn_in_dim=384,        # SBERT dim
            gnn_hidden_dim=128,
            gnn_num_layers=2,
            gnn_dropout=0.1,
            gnn_num_heads=4,       # Needed for GAT
        ))
    graph_encoder.eval()  # Disable dropout during inference

    # 3. Encode all graphs in a single batched forward pass
    pyg_graphs = transform_graphs_to_pyg(toy_graphs)
    batch = Batch.from_data_list(pyg_graphs)
    with torch.no_grad():
        graph_reprs = graph_encoder.encode(batch)  # [num_graphs, 4096]

    # 4. Encode the query once and project it into the graph-embedding space
    if text_projection is None:
        text_projection = torch.nn.Linear(384, 4096)
    with torch.no_grad():
        q_emb = text_encoder.encode(query, convert_to_tensor=True)
        # The sentence encoder may return a tensor on another device (e.g. CUDA)
        # than the projection / graph encoder; align devices before combining.
        proj_param = next(text_projection.parameters(), None)
        if proj_param is not None:
            q_emb = q_emb.to(proj_param.device)
        q_emb = text_projection(q_emb).to(graph_reprs.device)

    # 5. Compare and select the most similar graph
    sims = torch.nn.functional.cosine_similarity(q_emb.unsqueeze(0), graph_reprs)  # [num_graphs]
    top_idx = torch.argmax(sims).item()

    # 6. Return the graph built from *this call's* input, not a notebook global
    selected_graph = pyg_graphs[top_idx]

    return selected_graph, top_idx, sims


Run a single retrieval

In [10]:
query = "how can i treat asthma?"
# Get the most similar graph
# NOTE: retrieve_relevant_graph(toy_graphs, query) constructs a new GraphEncoder
# and a new Linear projection on each call without loading trained weights, so
# the score printed below comes from randomly initialised parameters.
selected_graph, top_idx, sims = retrieve_relevant_graph(toy_graphs, query)
#print(f"Most relevant graph index: {top_idx}")
#print(selected_graph)

# Look up the original triple list for the winning index
matched_graph = toy_graphs[top_idx]
# Print with formatting
print("=== Most Similar Graph ===")
print(f"Match Score: {sims[top_idx]:.3f}")
print(f"Graph : {matched_graph}")

Most relevant graph index: 3
Data(x=[4, 384], edge_index=[2, 3], edge_attr=[3, 384], num_nodes=4)
=== Most Similar Graph ===
Match Score: 0.019
Graph : [('pneumonia', 'caused_by', 'bacteria'), ('antibiotics', 'treats', 'pneumonia'), ('pneumonia', 'symptom', 'fever')]


Repeat the retrieval 50 times to check how stable the result is

In [14]:
from collections import defaultdict
import numpy as np

query = "how can i treat asthma?"
number_of_retrievals = 50

# Run retrievals
# Each call to retrieve_relevant_graph(toy_graphs, query) rebuilds its encoders,
# so every iteration is an independent run; record (winning index, winning score).
all_results = []
for _ in range(number_of_retrievals):
    _, top_idx, sims = retrieve_relevant_graph(toy_graphs, query)
    all_results.append((top_idx, sims[top_idx].item()))

# Analyze
# Group the winning scores by the index of the graph that won.
stats = defaultdict(list)
for idx, score in all_results:
    stats[idx].append(score)

print("\n=== Detailed Statistics ===")
# Graphs are listed by descending mean winning score (not by how often they won).
for idx in sorted(stats.keys(), key=lambda x: -np.mean(stats[x])):
    scores = stats[idx]
    graph = toy_graphs[idx]
    print(f"\nGraph {idx}:")
    print(f"  Frequency: {len(scores)}/{number_of_retrievals}")
    print(f"  Avg Score: {np.mean(scores):.3f} ± {np.std(scores):.3f}")
    print(f"  Preview : {graph[0]}...")


=== Detailed Statistics ===

Graph 0:
  Frequency: 4/50
  Avg Score: 0.007 ± 0.009
  Preview : ('asthma', 'caused_by', 'allergens')...

Graph 3:
  Frequency: 19/50
  Avg Score: 0.001 ± 0.018
  Preview : ('pneumonia', 'caused_by', 'bacteria')...

Graph 4:
  Frequency: 7/50
  Avg Score: 0.000 ± 0.016
  Preview : ('covid-19', 'affects', 'lungs')...

Graph 1:
  Frequency: 12/50
  Avg Score: -0.001 ± 0.020
  Preview : ('copd', 'risk_factor', 'smoking')...

Graph 2:
  Frequency: 8/50
  Avg Score: -0.005 ± 0.014
  Preview : ('bronchitis', 'caused_by', 'virus')...
