In [51]:
import time

import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import subgraph

In [43]:
dataset = Planetoid("../data", "cora", transform=NormalizeFeatures())

In [44]:
num_nodes = dataset[0].num_nodes
num_features = dataset.num_node_features
num_classes = dataset.num_classes

In [45]:
data = dataset[0]

In [46]:
def drop_random_edges(edge_index: torch.Tensor, p_drop: float = 0.2) -> Data:
    num_edges = edge_index.shape[1]
    num_edges_drop = int(p_drop * num_edges)
    mask = torch.ones(num_edges, dtype=bool)
    drop_indices = np.random.choice(num_edges, num_edges_drop, replace=False)
    mask[drop_indices] = False
    edge_index_dropped = edge_index[:, mask]

    return edge_index_dropped


def add_random_edges(edge_index: torch.Tensor, _num_nodes: int, p_add: float = 0.2) -> torch.Tensor:
    num_edges = edge_index.shape[1]
    num_edges_add = int(p_add * num_edges)
    edge_index_np = edge_index.cpu().numpy()
    edges_add = np.random.randint(low=0, high=_num_nodes, size=(2, num_edges_add))

    # Ensure new edges are not duplicates
    for i in range(num_edges_add):
        while edges_add[:, i] in edge_index_np.T:
            edges_add[:, i] = np.random.randint(low=0, high=_num_nodes, size=2)

    # Add new edges to graph
    edge_index_added = torch.cat([edge_index, torch.tensor(edges_add, dtype=torch.long)], dim=1)

    return edge_index_added


def edge_perturbation(_data: Data, p_drop: float, p_add: float) -> Data:
    start_time = time.time()

    edge_index = _data.edge_index
    _num_nodes = _data.num_nodes
    edge_index_dropped = drop_random_edges(edge_index, p_drop)
    edge_index_added = add_random_edges(edge_index_dropped, _num_nodes, p_add)

    _data.edge_index = edge_index_added

    end_time = time.time()
    print("Perturbation took {:.2f} seconds".format(end_time - start_time))

    return _data

In [47]:
data_perturbed = edge_perturbation(data, 0.2, 0.2)

Perturbation took 31.41 seconds


In [48]:
def drop_random_nodes(_data: Data, p_drop: float = 0.2) -> Data:
    _num_nodes = _data.num_nodes
    num_nodes_drop = int(p_drop * _num_nodes)
    mask = torch.ones(_num_nodes, dtype=bool)
    drop_indices = torch.randperm(_num_nodes)[:num_nodes_drop]
    mask[drop_indices] = False

    # Drop from edge index
    edge_index = _data.edge_index
    rows, cols = edge_index
    mask_edge = mask[rows] & mask[cols]
    edge_index_dropped = edge_index[:, mask_edge]

    # Re-index edge index
    remaining_nodes = torch.arange(_num_nodes)[mask]
    map_new_indices = torch.full((_num_nodes,), -1, dtype=torch.long)
    map_new_indices[remaining_nodes] = torch.arange(remaining_nodes.size(0))
    edge_index_dropped = map_new_indices[edge_index_dropped]

    # Drop nodes from node features if present
    if _data.x is not None:
        _data.x = _data.x[mask]

    _data.edge_index = edge_index_dropped
    _data.num_nodes = remaining_nodes.size(0)

    return _data

In [56]:
def get_subgraph(_data: Data, p_sample: float = 0.2, walk_length: int = 10, max_attempts: int = 100) -> Data:
    edge_index = _data.edge_index
    _num_nodes = _data.num_nodes
    num_nodes_sample = int(p_sample * _num_nodes)

    sampled_nodes = torch.tensor([], dtype=torch.long)
    attempts = 0

    while sampled_nodes.size(0) < num_nodes_sample and attempts < max_attempts:
        start_node = torch.randint(0, _num_nodes, (1,))
        walk_nodes = [start_node.item()]
        current_node = start_node.item()

        for _ in range(walk_length):
            neighbors = edge_index[1, edge_index[0] == current_node]
            if neighbors.size(0) == 0:
                break  # current_node is a leaf node
            current_node = neighbors[torch.randint(0, neighbors.size(0), (1,))].item()
            walk_nodes.append(current_node)

        sampled_nodes = torch.unique(torch.cat([sampled_nodes, torch.tensor(walk_nodes, dtype=torch.long)]))
        attempts += 1

    subgraph_nodes, subgraph_edges = subgraph(sampled_nodes, edge_index)
    data_subgraph = _data.clone()
    data_subgraph.edge_index = subgraph_edges

    # If node features exist, extract subgraph node features
    if _data.x is not None:
        data_subgraph.x = _data.x[sampled_nodes]

    data_subgraph.num_nodes = sampled_nodes.size(0)

    return data_subgraph

In [57]:
def attribute_masking(_data: Data, p_mask: float = 0.2) -> Data:
    _num_nodes = _data.num_nodes
    num_node_features = _data.x.size(1)
    num_features_mask = int(p_mask * num_node_features)
    mask = torch.ones(num_node_features, dtype=bool)
    mask_indices = torch.randperm(num_node_features)[:num_features_mask]
    mask[mask_indices] = False
    _data.x_masked = _data.x.clone()
    _data.x_masked[:, mask] = 0
    _data.attr_mask = mask

    return _data