In [2]:
import torch
from torch_geometric.data import Data


# 1. Your MoleculeData Class definition
class MoleculeData(Data):
    """PyG ``Data`` container that also accepts the molecule fields ``a``, ``e`` and ``c``.

    ``x`` and ``edge_index`` go to ``Data``'s standard node-feature and
    connectivity slots; ``a``, ``e``, ``c`` and any further keyword arguments are
    stored as extra attributes on the graph. Their meaning is defined by the
    caller (presumably adjacency / edge features / a graph-level value — TODO
    confirm). Every argument defaults to ``None``, so the class can be
    instantiated with no arguments at all.
    """

    def __init__(self, x=None, a=None, e=None, edge_index=None, c=None, **kwargs):
        # Forwarded in the same order as a, e, c, then caller extras.
        extra_fields = {"a": a, "e": e, "c": c}
        super().__init__(x=x, edge_index=edge_index, **extra_fields, **kwargs)


# 2. The permuted_subgraph function
def permuted_subgraph(data, subset):
    """Return the node-induced subgraph of ``data``, with nodes ordered as ``subset``.

    Node ``subset[i]`` of the input becomes node ``i`` of the output. A full
    permutation of ``range(num_nodes)`` therefore relabels the whole graph, and a
    shorter index list extracts the induced subgraph. An edge survives only if
    both of its endpoints are kept; surviving edges keep their original relative
    order.

    Args:
        data: Graph with an ``edge_index`` tensor of shape ``(2, num_edges)`` and
            an optional node-feature tensor ``x`` indexed along dim 0.
        subset: 1-D integer tensor (or sequence) of node indices defining the new
            node order, or a boolean mask of length ``num_nodes``; a mask keeps the
            selected nodes in their original order.

    Returns:
        A new ``MoleculeData`` with ``x``, ``edge_index`` and ``num_nodes`` set.
        NOTE(review): other attributes of ``data`` (``a``, ``e``, ``c``, ...) are
        not carried over. If ``e`` holds per-edge features it would need the same
        edge mask applied — confirm with callers.
    """
    edge_index = data.edge_index
    # Take the device from edge_index: x may legitimately be None (see below).
    device = edge_index.device
    num_nodes = data.num_nodes

    subset = torch.as_tensor(subset, device=device)
    if subset.dtype == torch.bool:
        # Boolean mask -> indices of the kept nodes, in original order.
        subset = subset.nonzero(as_tuple=False).view(-1)
    subset = subset.to(dtype=torch.long)

    # old node id -> new node id; -1 marks nodes that are dropped.
    mapping = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
    mapping[subset] = torch.arange(subset.numel(), device=device)

    row, col = edge_index
    keep = (mapping[row] >= 0) & (mapping[col] >= 0)
    new_edge_index = mapping[edge_index[:, keep]]

    new_x = data.x[subset] if data.x is not None else None

    new_data = MoleculeData(x=new_x, edge_index=new_edge_index)
    # Set explicitly: without x, PyG would infer num_nodes from edge_index and
    # silently drop isolated nodes of the subgraph.
    new_data.num_nodes = subset.numel()
    return new_data


# --- THE TEST CASE ---


def test_permuted_subgraph():
    """Check that permuted_subgraph reorders node features and rewires edges.

    Uses a 3-node path graph 0->1->2 and the permutation [2, 0, 1]. Every
    check is an assertion, so a regression makes the test fail instead of
    only printing a message.
    """
    print("=== Setting up Test Graph ===")
    # Nodes: 0, 1, 2
    # Features (X): [100, 200, 300] (easy to track)
    # Edges: 0->1 and 1->2
    x = torch.tensor([[100.0], [200.0], [300.0]])
    edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)

    data = MoleculeData(x=x, edge_index=edge_index)
    data.num_nodes = 3

    print(f"Original X:\n{data.x.squeeze().numpy()}")
    print(f"Original Edges:\n{data.edge_index.numpy()}")

    # Define a permutation (subset).
    # New graph order: old node 2, then old node 0, then old node 1.
    subset = torch.tensor([2, 0, 1], dtype=torch.long)
    print(f"\nApplying Permutation/Subset: {subset.numpy()}")
    print(
        "Expected New Order: Node 2 (300) -> Index 0, Node 0 (100) -> Index 1, Node 1 (200) -> Index 2"
    )

    # Run function
    new_data = permuted_subgraph(data, subset)

    # --- INSPECTION ---
    print("\n=== Results ===")

    # Check features
    print(f"New X:\n{new_data.x.squeeze().numpy()}")
    expected_x = torch.tensor([300.0, 100.0, 200.0])
    assert torch.equal(new_data.x.squeeze(), expected_x), (
        "❌ Features did not shuffle correctly!"
    )
    assert new_data.num_nodes == 3, f"❌ Expected 3 nodes, got {new_data.num_nodes}"
    print("✅ Features shuffled correctly.")

    # Check edge rewiring
    # Old edge (0->1): old 0 is now index 1, old 1 is now index 2 -> new edge (1->2)
    # Old edge (1->2): old 1 is now index 2, old 2 is now index 0 -> new edge (2->0)
    print(f"New Edges:\n{new_data.edge_index.numpy()}")

    # Edge order is irrelevant, so compare as a set of (src, dst) pairs. Also
    # check the edge count so extra or duplicated edges are caught.
    new_edges = {tuple(pair) for pair in new_data.edge_index.t().tolist()}
    expected_edges = {(1, 2), (2, 0)}
    assert new_data.edge_index.size(1) == len(expected_edges), (
        f"❌ Edge rewiring failed: expected {len(expected_edges)} edges, "
        f"got {new_data.edge_index.size(1)}"
    )
    assert new_edges == expected_edges, (
        f"❌ Edge rewiring failed: got {sorted(new_edges)}, expected {sorted(expected_edges)}"
    )
    print("✅ Edges rewired correctly (0->1 became 1->2; 1->2 became 2->0).")


# Run the check when executed as a script or notebook cell, but not on import.
if __name__ == "__main__":
    test_permuted_subgraph()


=== Setting up Test Graph ===
Original X:
[100. 200. 300.]
Original Edges:
[[0 1]
 [1 2]]

Applying Permutation/Subset: [2 0 1]
Expected New Order: Node 2 (300) -> Index 0, Node 0 (100) -> Index 1, Node 1 (200) -> Index 2

=== Results ===
New X:
[300. 100. 200.]
✅ Features shuffled correctly.
New Edges:
[[1 2]
 [2 0]]
✅ Edges rewired correctly (0->1 became 1->2; 1->2 became 2->0).
