In [None]:
import sys

# Make the parent directory importable for the project-local imports used in
# this notebook. The membership check keeps the cell idempotent: re-running it
# no longer appends a duplicate ".." entry to sys.path each time.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
from data.dataset import MyKarateClub
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.sampler import NegativeSampling
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import to_networkx

## Dataset Overview

In [None]:
# Instantiate the project-local dataset wrapper imported from data.dataset.
dataset = MyKarateClub()

# Bare last expression so the notebook displays the repr of the dataset's data object.
dataset.data

In [None]:
# Display the dataset summary. get_summary() is provided by MyKarateClub
# (data.dataset module); see that module for exactly what it reports.
dataset.get_summary()

In [None]:
# Convert the complete PyG graph to NetworkX and draw it on a spiral layout,
# annotating every node with its node id.
full_graph = to_networkx(dataset.data)
spiral_positions = nx.spiral_layout(full_graph)
nx.draw(full_graph, pos=spiral_positions, with_labels=True)

## Edge Split into train/val/test for Link Prediction

In [None]:
# Randomly partition the edges into train / validation / test sets for link
# prediction. No split ratios are passed, so RandomLinkSplit's defaults apply.
# Negative edges are deliberately not added to the training split here.
link_split = RandomLinkSplit(
    is_undirected=True,
    add_negative_train_samples=False,
    split_labels=False,
)
train_data, val_data, test_data = link_split(dataset.data)

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(5, 15))

# One panel per split, titled so the figure can be read on its own.
splits = {"train": train_data, "val": val_data, "test": test_data}
for (split_name, split_data), ax in zip(splits.items(), axes.flatten()):
    split_graph = to_networkx(split_data)
    # Map each graph node (its position in x) to the name that
    # dataset.node_index_map gives for the id stored in x at that position.
    # Keying by position guarantees the keys match split_graph's nodes.
    # Assumes x holds one id per node — confirm against MyKarateClub.
    node_labels = {
        node: dataset.node_index_map[node_id]
        for node, node_id in enumerate(split_data.x.squeeze().tolist())
    }
    # networkx expects the mapping under `labels`; the singular `label` is a
    # matplotlib legend string and would leave the raw node ids on the plot.
    nx.draw(
        split_graph,
        ax=ax,
        labels=node_labels,
        with_labels=True,
        pos=nx.spiral_layout(split_graph),
    )
    ax.set_title(f"{split_name} split")
    print(split_data)

## Link Neighbor Loader

In [None]:
# Settings shared by both loaders: two sampling iterations (10 neighbors, then
# 5), one seed edge per batch, and a fixed iteration order.
loader_kwargs = dict(num_neighbors=[10, 5], batch_size=1, shuffle=False)

train_loader = LinkNeighborLoader(
    train_data,
    edge_label_index=train_data.edge_label_index,
    edge_label=train_data.edge_label,
    # The training split was created without negative edges
    # (add_negative_train_samples=False), so sample one negative per positive
    # on the fly.
    neg_sampling=NegativeSampling(mode="binary", amount=1),
    **loader_kwargs,
)
val_loader = LinkNeighborLoader(
    val_data,
    edge_label_index=val_data.edge_label_index,
    edge_label=val_data.edge_label,
    # Do not sample additional negative edges for validation; the labels in
    # val_data.edge_label are used as-is.
    # https://github.com/pyg-team/pytorch_geometric/discussions/9164
    # `neg_sampling_ratio` is deprecated in favor of `neg_sampling`, so the
    # non-deprecated argument is used to express "no negative sampling":
    # https://pytorch-geometric.readthedocs.io/en/2.5.3/modules/loader.html?highlight=LinkNeighborLoader#torch_geometric.loader.LinkNeighborLoader
    neg_sampling=None,
    **loader_kwargs,
)

In [None]:
# For every training batch, print the distinct edge labels and how often each occurs.
for train_batch in train_loader:
    label_values_and_counts = train_batch.edge_label.unique(return_counts=True)
    print(label_values_and_counts)

In [None]:
# For every validation batch, print the distinct edge labels and how often each occurs.
for val_batch in val_loader:
    label_values_and_counts = val_batch.edge_label.unique(return_counts=True)
    print(label_values_and_counts)

In [None]:
# Take the first sampled training batch and draw its subgraph.
batch = next(iter(train_loader))

batch_graph = to_networkx(batch)
# The nodes of batch_graph are batch-local indices, while batch.n_id lists the
# original node id for each local index. Key the mapping by local index so the
# keys match the drawn graph, and look the name up by the original id.
node_labels = {
    local_idx: dataset.node_index_map[global_id]
    for local_idx, global_id in enumerate(batch.n_id.tolist())
}
# networkx expects the mapping under `labels`; the singular `label` is a
# matplotlib legend string and would leave the raw local indices on the plot.
nx.draw(
    batch_graph,
    labels=node_labels,
    with_labels=True,
    pos=nx.spiral_layout(batch_graph),
)