In [1]:
import pickle
import torch

In [2]:
# Load the validation graphs dumped from graphium2.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
# The context manager closes the file handle as soon as the data has been read.
with open("val_graphs_graphium2_new.pkl", "rb") as graphium2_file:
    graphium2 = pickle.load(graphium2_file)
graphium2[:2]

[{'labels': DataBatch(graph_zinc=[1, 3], x=[24, 1], edge_index=[2, 52], graph_qm9=[1, 19], graph_tox21=[1, 12], batch=[24], ptr=[2]),
  'features': DataBatch(edge_index=[2, 52], edge_weight=[52], num_nodes=24, feat=[24, 85], edge_feat=[52, 13], laplacian_eigvec=[24, 8], laplacian_eigval=[24, 8], rw_return_probs=[24, 16], batch=[24], ptr=[2])},
 {'labels': DataBatch(graph_zinc=[1, 3], x=[21, 1], edge_index=[2, 46], graph_tox21=[1, 12], graph_qm9=[1, 19], batch=[21], ptr=[2]),
  'features': DataBatch(edge_index=[2, 46], edge_weight=[46], num_nodes=21, feat=[21, 85], edge_feat=[46, 13], laplacian_eigvec=[21, 8], laplacian_eigval=[21, 8], rw_return_probs=[21, 16], batch=[21], ptr=[2])}]

In [3]:
# Load the validation graphs dumped from graphium3.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
# The context manager closes the file handle as soon as the data has been read.
with open("val_graphs_graphium3_new.pkl", "rb") as graphium3_file:
    graphium3 = pickle.load(graphium3_file)
graphium3[:2]

[{'labels': DataBatch(graph_qm9=[1, 19], graph_zinc=[1, 3], graph_tox21=[1, 12]),
  'features': DataBatch(edge_index=[2, 16], edge_weight=[16], num_nodes=8, feat=[8, 85], edge_feat=[16, 13], laplacian_eigvec=[8, 8], laplacian_eigval=[8, 8], rw_return_probs=[8, 16], batch=[8], ptr=[2])},
 {'labels': DataBatch(graph_zinc=[1, 3], graph_qm9=[1, 19], graph_tox21=[1, 12]),
  'features': DataBatch(edge_index=[2, 42], edge_weight=[42], num_nodes=20, feat=[20, 85], edge_feat=[42, 13], laplacian_eigvec=[20, 8], laplacian_eigval=[20, 8], rw_return_probs=[20, 16], batch=[20], ptr=[2])}]

In [4]:
def is_same_graph(graph1, graph2, tolerance=0.1):
    """
    Check if two graphs are the same by comparing their node/edge counts and their label values.

    Parameters:
      graph1, graph2: Mappings with a "features" entry exposing `num_nodes` and `num_edges`,
        and a "labels" entry indexable by "graph_qm9", "graph_tox21" and "graph_zinc"
        (each a tensor).
      tolerance (float): Absolute tolerance (`atol`) passed to `torch.allclose` for label values.

    Returns:
      bool: True when the node count, the edge count, and every label tensor (shape, NaN
        positions, and non-NaN values within tolerance) agree; False otherwise.
    """

    features1, features2 = graph1["features"], graph2["features"]
    if (features1.num_nodes != features2.num_nodes) or (features1.num_edges != features2.num_edges):
        return False

    for label_key in ("graph_qm9", "graph_tox21", "graph_zinc"):
        label1 = graph1["labels"][label_key]
        label2 = graph2["labels"][label_key]

        # torch.allclose broadcasts, so the shapes have to be compared explicitly first.
        if label1.shape != label2.shape:
            return False

        # equal_nan=True only lets a NaN match a NaN at the same position. Dropping the NaNs
        # from each tensor separately before comparing would line up values that belong to
        # different label columns and could report a false match.
        if not torch.allclose(label1, label2, atol=tolerance, equal_nan=True):
            return False

    return True
    

def non_empty_label(graph):
    """
    List the label sets of `graph` that hold at least one non-NaN value.

    Looks up "graph_qm9", "graph_tox21" and "graph_zinc" under graph["labels"] and returns
    the matching short names ("qm9", "tox21", "zinc"), in that order, for every tensor that
    is not entirely NaN.
    """
    label_tensors = graph["labels"]
    return [
        task
        for task in ("qm9", "tox21", "zinc")
        if not label_tensors[f"graph_{task}"].isnan().all()
    ]

In [5]:
# Sanity check of is_same_graph on graphium2: every graph should match itself and nothing else,
# so all collected (ii, jj) pairs are expected to lie on the diagonal.
matches = [
    (ii, jj)
    for ii, g in enumerate(graphium2)
    for jj, gg in enumerate(graphium2)
    if is_same_graph(g, gg)
]
print(matches[:5])
matches_ten = torch.as_tensor(matches)
# Number of diagonal pairs versus total number of pairs found.
num_self_matches = (matches_ten[:, 0] == matches_ten[:, 1]).sum().tolist()
print(num_self_matches, "/", len(matches_ten))

[(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
30 / 30


In [6]:
# Sanity check of is_same_graph on graphium3: every graph should match itself and nothing else,
# so all collected (ii, jj) pairs are expected to lie on the diagonal.
matches = [
    (ii, jj)
    for ii, g in enumerate(graphium3)
    for jj, gg in enumerate(graphium3)
    if is_same_graph(g, gg)
]
print(matches[:5])
matches_ten = torch.as_tensor(matches)
# Number of diagonal pairs versus total number of pairs found.
num_self_matches = (matches_ten[:, 0] == matches_ten[:, 1]).sum().tolist()
print(num_self_matches, "/", len(matches_ten))

[(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
30 / 30


In [7]:
# Pair up graphs across the two dumps: compare every graphium2 graph against every
# graphium3 graph and record the (graphium2 index, graphium3 index) of each match.
matches = []
for ii, g2 in enumerate(graphium2):
    for jj, g3 in enumerate(graphium3):
        if not is_same_graph(g2, g3):
            continue
        # Also report which label sets are populated for the matched graph.
        print(ii, jj, "Found match", non_empty_label(g2))
        matches.append((ii, jj))
matches

0 15 Found match ['zinc']
1 14 Found match ['zinc']
2 13 Found match ['tox21']
3 5 Found match ['tox21']
5 25 Found match ['tox21']
6 2 Found match ['tox21']
8 8 Found match ['tox21']
11 22 Found match ['tox21']
12 18 Found match ['tox21']
23 17 Found match ['tox21']


[(0, 15),
 (1, 14),
 (2, 13),
 (3, 5),
 (5, 25),
 (6, 2),
 (8, 8),
 (11, 22),
 (12, 18),
 (23, 17)]

In [8]:
# Since there are no matches for QM9, print the QM9 labels of both datasets to inspect them visually.
# Only the first 5 label values of each graph that has QM9 labels are shown.

LABEL = "qm9"
for header, graphs in (
    ("Graphium 2 \n_________________________", graphium2),
    ("\n\nGraphium 3 \n_________________________", graphium3),
):
    print(header)
    for ii, graph in enumerate(graphs):
        if LABEL in non_empty_label(graph):
            print(ii, graph["labels"][f"graph_{LABEL}"][0, :5])

Graphium 2 
_________________________
7 tensor([-1.0777,  0.9168,  0.2093,  0.4276, -0.0869], dtype=torch.float64)
9 tensor([-0.6980,  1.8550,  2.3131, -1.4528,  0.6577], dtype=torch.float64)
10 tensor([ 3.5083, -0.8910,  0.4519, -1.9198, -0.1023], dtype=torch.float64)
14 tensor([ 0.6422, -0.9315,  0.0098, -1.3373,  0.3904], dtype=torch.float64)
16 tensor([ 2.7967, -2.0454, -1.4393, -0.4990, -0.2233], dtype=torch.float64)
17 tensor([-0.7930,  2.7671,  3.6079, -0.9251, -0.1265], dtype=torch.float64)
19 tensor([ 0.4698, -0.1997, -0.3652, -0.3209, -0.3607], dtype=torch.float64)
25 tensor([ 2.0471,  1.0800,  1.0429, -0.8319,  0.3090], dtype=torch.float64)
27 tensor([-0.6739, -0.1215,  0.4055,  0.5390,  0.7896], dtype=torch.float64)
28 tensor([ 1.8100,  6.9263,  4.9668,  0.1054, -1.6760], dtype=torch.float64)


Graphium 3 
_________________________
0 tensor([ 1.8190, -0.8910,  0.4519, -1.7172, -0.1023], dtype=torch.float64)
10 tensor([ 0.6422, -0.9315,  0.0098, -1.3373,  0.3904], dtype=torc

In [9]:
from scipy.optimize import linear_sum_assignment

def find_row_mapping(graph1, graph2, p: float = 2.0) -> torch.Tensor:
    """
    Find which node rows in graph1 correspond to which node rows in graph2 using the
    Hungarian algorithm.

    Each graph's rows are described by the column-wise concatenation of its "feat",
    "laplacian_eigval" and "rw_return_probs" feature tensors (cast to float32). The pairing
    minimises the total squared p-norm distance between paired rows.

    Parameters:
      graph1: Mapping whose "features" entry can be indexed by "feat", "laplacian_eigval"
        and "rw_return_probs", each a 2D tensor with one row per node.
      graph2: Same structure as graph1.
      p (float): The norm degree to use in torch.cdist (default is 2 for Euclidean distance).

    Returns:
      torch.Tensor: A float tensor with one row per assigned pair and 3 columns
        [i, j, error]: row i of graph1 is paired with row j of graph2, and `error` is the
        squared distance between the two rows. The indices are stored as floats, so cast
        them with `.long()` before using them for indexing.
    """

    # Node descriptors used for the pairing. "laplacian_eigvec" is not part of them.
    descriptor_keys = ("feat", "laplacian_eigval", "rw_return_probs")
    A = torch.cat([graph1["features"][key] for key in descriptor_keys], dim=1).float()
    B = torch.cat([graph2["features"][key] for key in descriptor_keys], dim=1).float()

    # Pairwise cost matrix: squared p-norm distances (squared Euclidean for the default p=2)
    cost_matrix = torch.cdist(A, B, p=p) ** 2  # Shape: (rows of A, rows of B)

    # SciPy works on NumPy arrays; detach so inputs that track gradients can be converted too
    cost_np = cost_matrix.detach().cpu().numpy()

    # Solve the assignment problem using SciPy's linear_sum_assignment (Hungarian algorithm)
    row_ind, col_ind = linear_sum_assignment(cost_np)

    # Index tensors live on the same device as the cost matrix
    rows = torch.as_tensor(row_ind, dtype=torch.long, device=cost_matrix.device)
    cols = torch.as_tensor(col_ind, dtype=torch.long, device=cost_matrix.device)

    # Gather the error (cost) for each assignment pair
    errors = cost_matrix[rows, cols]

    # Stack indices and errors column-wise; the indices are cast explicitly to the error dtype
    # instead of relying on implicit integer/float promotion inside torch.stack
    mapping = torch.stack([rows.to(errors.dtype), cols.to(errors.dtype), errors], dim=1)
    return mapping


In [10]:
# For each matched pair of molecules, look for a row-to-row mapping between their node features.
# Columns of the printed mapping: row order in the graphium2 graph, row order in the graphium3
# graph, and the error (distance) between the two paired rows.
# In the recorded run there is always a perfect match (the distances are always 0), but the
# order of the rows differs between the two datasets.
for idx2, idx3 in matches:
    mapping = find_row_mapping(graphium2[idx2], graphium3[idx3])
    print(mapping[:5], f"\nDistances sum to {mapping[:, 2].sum()}")

tensor([[ 0.,  0.,  0.],
        [ 1., 21.,  0.],
        [ 2.,  7.,  0.],
        [ 3., 22.,  0.],
        [ 4., 23.,  0.]]) 
Distances sum to 0.0
tensor([[ 0.,  0.,  0.],
        [ 1.,  2.,  0.],
        [ 2.,  5.,  0.],
        [ 3., 19.,  0.],
        [ 4., 14.,  0.]]) 
Distances sum to 0.0
tensor([[ 0.,  0.,  0.],
        [ 1.,  8.,  0.],
        [ 2., 11.,  0.],
        [ 3., 19.,  0.],
        [ 4., 18.,  0.]]) 
Distances sum to 0.0
tensor([[0., 0., 0.],
        [1., 8., 0.],
        [2., 9., 0.],
        [3., 3., 0.],
        [4., 1., 0.]]) 
Distances sum to 0.0
tensor([[0., 0., 0.],
        [1., 7., 0.],
        [2., 5., 0.],
        [3., 2., 0.],
        [4., 8., 0.]]) 
Distances sum to 0.0
tensor([[0., 0., 0.],
        [1., 9., 0.],
        [2., 4., 0.],
        [3., 1., 0.],
        [4., 7., 0.]]) 
Distances sum to 0.0
tensor([[0., 0., 0.],
        [1., 8., 0.],
        [2., 9., 0.],
        [3., 2., 0.],
        [4., 3., 0.]]) 
Distances sum to 0.0
tensor([[0., 0., 0.],
  

In [11]:
# For each matched pair of molecules, align the rows with the mapping found above and print the
# per-column mean absolute difference between their Laplacian eigenvectors.
# abs() is applied to both sides first, so a pure sign flip of a column does not count as a difference.
for idx2, idx3 in matches:
    graph2, graph3 = graphium2[idx2], graphium3[idx3]
    mapping = find_row_mapping(graph2, graph3)
    rows2 = mapping[:, 0].long()
    rows3 = mapping[:, 1].long()
    eigvec2 = graph2["features"]["laplacian_eigvec"][rows2]
    eigvec3 = graph3["features"]["laplacian_eigvec"][rows3]
    column_diff = (eigvec2.abs() - eigvec3.abs()).abs().mean(dim=0)
    print(column_diff.numpy())

[0.     0.     0.     0.     0.     0.     0.     0.0487]
[0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0.]
[0.     0.     0.     0.2756 0.2229 0.     0.     0.    ]
[0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0.]
[0.     0.     0.3613 0.2842 0.     0.1685 0.1224 0.    ]
[0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0.]
