import torch

# Load the tensor and map it to the CPU
data = torch.load('data_o_new2.pt', map_location=torch.device('cpu'))

# Print the shape of the tensor
print(data)



MRCGNN Revised Model Definition

In [1]:
!pip install torch==2.3.1
!pip install torch_geometric==2.6.1
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv

class Discriminator(nn.Module):
    def __init__(self, n_h):
        super(Discriminator, self).__init__()
        self.f_k = nn.Bilinear(32, 32, 1)
        for m in self.modules():
            self.weights_init(m)
    def weights_init(self, m):
        if isinstance(m, nn.Bilinear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)
    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        c_x = c.expand_as(h_pl)
        sc_1 = self.f_k(h_pl, c_x)
        sc_2 = self.f_k(h_mi, c_x)
        if s_bias1 is not None:
            sc_1 += s_bias1
        if s_bias2 is not None:
            sc_2 += s_bias2
        logits = torch.cat((sc_1, sc_2), 1)
        return logits

class AvgReadout(nn.Module):
    def __init__(self):
        super(AvgReadout, self).__init__()
    def forward(self, seq, msk=None):
        if msk is None:
            return torch.mean(seq, 0)
        else:
            msk = torch.unsqueeze(msk, -1)
            return torch.sum(seq * msk, 0) / torch.sum(msk)

class MRCGNN(nn.Module):
    def __init__(self, feature, hidden1, hidden2, decoder1, dropout, zhongzi):
        super(MRCGNN, self).__init__()

        # RGCN layers for the main (data_o) branch
        self.encoder_o1 = RGCNConv(feature, hidden1, num_relations=65)
        self.encoder_o2 = RGCNConv(hidden1, hidden2, num_relations=65)

        # Two-element parameter for layer attention
        self.attt = nn.Parameter(torch.tensor([0.5, 0.5]))
        self.disc = Discriminator(hidden2 * 2)
        self.dropout = dropout
        self.sigm = nn.Sigmoid()
        self.read = AvgReadout()
        
        # Final classifier: prediction solely from data_o branch.
        # Each node's final representation is a concatenation of (hidden1 + hidden2).
        # For a pair of entities, the dimension becomes 2*(hidden1+hidden2).
        self.classifier = nn.Linear(2 * (hidden1 + hidden2), 65)

        # We no longer load any pretrained features for skip connection.

    def forward(self, data_o, data_s, data_a, idx):
        # Process data_o branch
        x_o, adj, e_type = data_o.x, data_o.edge_index, data_o.edge_type
        e_type1 = data_a.edge_type
        e_type = torch.tensor(e_type, dtype=torch.int64)
        e_type1 = torch.tensor(e_type1, dtype=torch.int64)

        # Main branch for prediction (data_o)
        x1_o = F.relu(self.encoder_o1(x_o, adj, e_type))
        x1_o = F.dropout(x1_o, self.dropout, training=self.training)
        x2_o = self.encoder_o2(x1_o, adj, e_type)

        # Contrastive learning branches (unused in prediction)
        x_a = data_s.x
        x1_o_a = F.relu(self.encoder_o1(x_a, adj, e_type))
        x1_o_a = F.dropout(x1_o_a, self.dropout, training=self.training)
        x2_o_a = self.encoder_o2(x1_o_a, adj, e_type)

        x1_o_a_a = F.relu(self.encoder_o1(x_o, adj, e_type1))
        x1_o_a_a = F.dropout(x1_o_a_a, self.dropout, training=self.training)
        x2_o_a_a = self.encoder_o2(x1_o_a_a, adj, e_type1)

        # Readout for contrastive learning
        h_os = self.read(x2_o)
        h_os = self.sigm(h_os)
        ret_os = self.disc(h_os, x2_o, x2_o_a)
        ret_os_a = self.disc(h_os, x2_o, x2_o_a_a)

        # For final prediction, use only data_o branch:
        final = torch.cat((self.attt[0] * x1_o, self.attt[1] * x2_o), dim=1)

        a = [int(i) for i in list(idx[0])]
        b = [int(i) for i in list(idx[1])]
        aa = torch.tensor(a, dtype=torch.long)
        bb = torch.tensor(b, dtype=torch.long)
        entity1 = final[aa]
        entity2 = final[bb]
        concatenate = torch.cat((entity1, entity2), dim=1)
        log = self.classifier(concatenate)

        return log, ret_os, ret_os_a, x2_o

    def predict(self, data_o, idx):
        """
        New prediction method that uses only data_o and idx.
        """
        x_o, adj, e_type = data_o.x, data_o.edge_index, data_o.edge_type
        e_type = torch.tensor(e_type, dtype=torch.int64)
        # Process the main branch
        x1_o = F.relu(self.encoder_o1(x_o, adj, e_type))
        x1_o = F.dropout(x1_o, self.dropout, training=self.training)
        x2_o = self.encoder_o2(x1_o, adj, e_type)
        final = torch.cat((self.attt[0] * x1_o, self.attt[1] * x2_o), dim=1)
        
        a = [int(i) for i in list(idx[0])]
        b = [int(i) for i in list(idx[1])]
        aa = torch.tensor(a, dtype=torch.long)
        bb = torch.tensor(b, dtype=torch.long)
        entity1 = final[aa]
        entity2 = final[bb]
        concatenate = torch.cat((entity1, entity2), dim=1)
        log = self.classifier(concatenate)
        return log

Collecting torch==2.3.1
  Downloading torch-2.3.1-cp312-none-macosx_11_0_arm64.whl.metadata (26 kB)
Downloading torch-2.3.1-cp312-none-macosx_11_0_arm64.whl (61.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 MB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 2.5.1
    Uninstalling torch-2.5.1:
      Successfully uninstalled torch-2.5.1
Successfully installed torch-2.3.1


Load My Model

In [2]:
model = MRCGNN(feature=128, hidden1=64, hidden2=32, decoder1=512, dropout=0.5, zhongzi=0)
model.load_state_dict(torch.load("model_mrcgnn.pt", map_location=torch.device('cpu')))
model.eval()

model.load_state_dict(torch.load("model_mrcgnn.pt", map_location='cpu'))
model.eval()

MRCGNN(
  (encoder_o1): RGCNConv(128, 64, num_relations=65)
  (encoder_o2): RGCNConv(64, 32, num_relations=65)
  (disc): Discriminator(
    (f_k): Bilinear(in1_features=32, in2_features=32, out_features=1, bias=True)
  )
  (sigm): Sigmoid()
  (read): AvgReadout()
  (classifier): Linear(in_features=192, out_features=65, bias=True)
)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
from torch_geometric.utils import k_hop_subgraph, to_networkx
import networkx as nx
import numpy as np

##############################################################################
# 1) LOAD YOUR DATA
##############################################################################
data = torch.load('data_o_new2.pt', map_location=torch.device('cpu'))
print("Data object:", data)
# data is a single graph with data.x, data.edge_index, data.edge_type, etc.
# e.g.: Data(x=[572, 128], edge_index=[2, 52112], edge_type=[52112])


##############################################################################
# 2) EXTRACT COMPUTATION TREES & CREATE CONCEPT VECTOR
##############################################################################
L = 3  # number of hops for each node’s subgraph (the “computation tree”)

# We'll collect:
#   node_ctree_codes[v] = string code for node v’s L-hop subgraph
#   unique_ctree_codes   = dict {code_str -> assigned_id}
node_ctree_codes = []
unique_ctree_codes = {}

def simple_dfs_code(G, root):
    """
    Simple DFS code from a NetworkX graph G (treated as a tree),
    starting at node 'root'. This is just a placeholder function:
    you'd use a canonical labeling or something more robust in production.
    """
    code = []
    for n in nx.dfs_preorder_nodes(G, source=root):
        code.append(str(n))
    return "-".join(code)

for v in range(data.num_nodes):
    # Extract L-hop subgraph
    subset, sub_edge_index, mapping, _ = k_hop_subgraph(
        node_idx=v,
        num_hops=L,
        edge_index=data.edge_index,
        relabel_nodes=True
    )
    # Build a small Data object
    sub_data = pyg.data.Data(
        x=data.x[subset],
        edge_index=sub_edge_index
    )
    # Convert to NetworkX
    G_sub = to_networkx(sub_data, to_undirected=True)
    
    # In PyG’s relabel_nodes=True, the root node v becomes "0" in sub_data.
    ctree_code = simple_dfs_code(G_sub, root=0)
    node_ctree_codes.append(ctree_code)
    if ctree_code not in unique_ctree_codes:
        unique_ctree_codes[ctree_code] = len(unique_ctree_codes)

# --- Step 2: Build the concept vector for the entire graph. ---
# This vector is length (# unique ctree codes),
# and entry i = count of nodes that have that code.
concept_vector = np.zeros(len(unique_ctree_codes), dtype=int)
for code in node_ctree_codes:
    idx = unique_ctree_codes[code]
    concept_vector[idx] += 1

print("Number of unique ctree codes:", len(unique_ctree_codes))
print("Concept vector (frequency of each code):")
print(concept_vector)

# Print out the first 5 unique computation tree codes and their frequencies:
print("Sample unique computation tree codes and frequencies:")
for i, (code, idx) in enumerate(unique_ctree_codes.items()):
    if i >= 5:
        break
    freq = concept_vector[idx]
    print(f"Code {idx}: {code}")
    print(f"Frequency: {freq}\n")

Data object: Data(x=[572, 128], edge_index=[2, 52112], edge_type=[52112])
Number of unique ctree codes: 24
Concept vector (frequency of each code):
[522   5   4   1   9   9   1   1   2   1   1   1   1   1   3   1   1   1
   1   1   1   1   1   2]
Sample unique computation tree codes and frequencies:
Code 0: 0-408-22-477-474-214-16-130-182-365-33-63-57-240-250-8-527-366-223-113-233-31-311-154-61-228-437-426-310-161-418-274-567-13-406-279-60-346-114-19-320-398-166-53-120-211-556-502-348-563-522-18-40-545-224-225-202-321-388-503-383-167-352-193-419-410-111-183-370-330-194-252-264-151-273-349-195-216-363-507-534-47-128-123-253-255-288-394-54-555-282-77-354-324-339-405-99-44-226-520-17-209-326-415-298-43-266-338-490-483-89-548-169-500-376-291-145-64-441-488-323-565-529-178-95-160-232-347-105-454-414-355-198-399-448-436-269-157-28-29-424-480-416-49-132-51-537-335-340-242-119-236-88-485-146-501-450-131-258-452-382-482-11-20-333-451-542-290-286-245-515-237-7-155-281-197-73-261-91-510-430-557-2

DEFINE "REMOVE CONCEPTS" FUNCTION

In [4]:
def remove_concepts_from_graph(data, node_ctree_codes, unique_ctree_codes, subset_of_codes):
    """
    Prune out nodes whose L-hop code is not in 'subset_of_codes'.
    Returns a new Data object and a mapping dict (old_index -> new_index).
    """
    keep_nodes_list = []
    for v in range(data.num_nodes):
        if node_ctree_codes[v] in subset_of_codes:
            keep_nodes_list.append(v)
    
    mapping = {old: new for new, old in enumerate(keep_nodes_list)}
    
    if not keep_nodes_list:
        # No nodes remain; return empty data and empty mapping
        new_data = pyg.data.Data(
            x=torch.empty((0, data.x.shape[1])),
            edge_index=torch.empty((2, 0), dtype=torch.long),
            edge_type=torch.empty((0,), dtype=torch.long)
        )
        return new_data, mapping

    keep_nodes = torch.tensor(keep_nodes_list, dtype=torch.long)
    # Filter node features
    x_new = data.x[keep_nodes]

    # Filter edges and corresponding edge_type
    edges = []
    e_types = []
    keep_set = set(keep_nodes_list)
    for i in range(data.edge_index.size(1)):
        src = data.edge_index[0, i].item()
        dst = data.edge_index[1, i].item()
        if (src in keep_set) and (dst in keep_set):
            edges.append([src, dst])
            e_types.append(data.edge_type[i])  # No .item() here
    if edges:
        edges = torch.tensor(edges, dtype=torch.long).t().contiguous()
        e_types = torch.tensor(e_types, dtype=torch.int64)
    else:
        edges = torch.empty((2,0), dtype=torch.long)
        e_types = torch.empty((0,), dtype=torch.int64)
    
    # Relabel node IDs
    edges = _relabel_edge_index(edges, keep_nodes)

    new_data = pyg.data.Data(x=x_new, edge_index=edges, edge_type=e_types)
    return new_data, mapping

def _relabel_edge_index(edge_index, keep_nodes):
    old_to_new = {old: i for i, old in enumerate(keep_nodes.tolist())}
    new_edges = []
    for i in range(edge_index.size(1)):
        src_old = edge_index[0, i].item()
        dst_old = edge_index[1, i].item()
        new_edges.append([old_to_new[src_old], old_to_new[dst_old]])
    if len(new_edges) == 0:
        return torch.empty((2, 0), dtype=torch.long)
    new_edges = torch.tensor(new_edges, dtype=torch.long).t().contiguous()
    return new_edges

DEFINE A VALUE FUNCTION THAT RUNS MRCGNN

In [5]:
def value_function(model, data, idx,
                   subset_of_codes, node_ctree_codes, unique_ctree_codes,
                   target_class=None):
    """
    1) Remove concepts not in 'subset_of_codes'
    2) Update idx based on the new node numbering
    3) Run model.predict() and return a scalar value.
    """
    # 1) Prune the graph and get the mapping from old indices to new indices.
    modified_data, mapping = remove_concepts_from_graph(
        data, node_ctree_codes, unique_ctree_codes, subset_of_codes
    )
    
    # 2) Update idx: For each index in idx, if it exists in mapping, use the new index.
    new_src = []
    new_dst = []
    for src, dst in zip(idx[0].tolist(), idx[1].tolist()):
        if src in mapping and dst in mapping:
            new_src.append(mapping[src])
            new_dst.append(mapping[dst])
    if len(new_src) == 0 or len(new_dst) == 0:
        return 0.0  # no valid pairs remain

    new_idx = (torch.tensor(new_src, dtype=torch.long),
               torch.tensor(new_dst, dtype=torch.long))
    
    # 3) Forward pass
    with torch.no_grad():
        out = model.predict(modified_data, new_idx)  # shape: [num_pairs, 65]
        if out.shape[0] == 0:
            return 0.0

        if target_class is None:
            # Use the predicted class for each pair and average the logit.
            preds = out.argmax(dim=1)
            chosen_logits = out[torch.arange(out.size(0)), preds]
            val = chosen_logits.mean().item()
        else:
            chosen_logits = out[:, target_class]
            val = chosen_logits.mean().item()
    return val


SHAPLEY VALUES (SAMPLING APPROACH)

In [6]:
def shapley_values(model, data, idx,
                   unique_ctree_codes, node_ctree_codes,
                   target_class=None,
                   num_samples=50):
    """
    Approximate Shapley values via random permutations.
    """
    concepts = list(unique_ctree_codes.keys())  # DFS-code strings
    M = len(concepts)
    shap = np.zeros(M, dtype=float)

    for _ in range(num_samples):
        perm = np.random.permutation(M)
        current_subset = set()
        old_val = value_function(
            model, data, idx,
            current_subset, node_ctree_codes, unique_ctree_codes,
            target_class
        )
        for j in range(M):
            c_idx = perm[j]
            c_code = concepts[c_idx]
            new_subset = current_subset.union({c_code})
            new_val = value_function(
                model, data, idx,
                new_subset, node_ctree_codes, unique_ctree_codes,
                target_class
            )
            shap[c_idx] += (new_val - old_val)
            current_subset = new_subset
            old_val = new_val

    shap /= num_samples
    code2shap = {concepts[i]: shap[i] for i in range(M)}
    return code2shap


RUN SHAPLEY & DISPLAY RESULTS

In [7]:
shap_vals = shapley_values(
    model=model,
    data=data,
    idx=idx,
    unique_ctree_codes=unique_ctree_codes,
    node_ctree_codes=node_ctree_codes,
    target_class=None,  # or an int specifying which class
    num_samples=50
)

# Sort Shapley results
sorted_shap = sorted(shap_vals.items(), key=lambda x: x[1], reverse=True)
print("\nTop-10 Concepts by Shapley Value:")
for c, val in sorted_shap[:10]:
    print(f"{c} => {val:.4f}")

# Done!

TypeError: 'int' object is not subscriptable

In [None]:
# 
print(sorted_shap[-1][1])