In [16]:
from torch_geometric.data import Data
import torch
import random

In [2]:
torch.manual_seed(0)

<torch._C.Generator at 0x1819ea8bf50>

In [33]:
x=torch.tensor([[1,2,3,4,5],
               [2,3,4,5,6],
               [3,4,5,6,7],
               [4,5,6,7,8],
               [6,7,8,9,0],
               [1,1,1,1,1]])

In [34]:
edge=torch.tensor([[1,2],
                  [2,3],
                  [0,1],
                  [4,5],
                  [5,5],
                  [3,4],
                  [1,5],
                  [1,4]])

In [35]:
edge_index=edge.reshape((2,8))

In [36]:
data=Data(x=x,edge_index=edge_index)

In [37]:
# train, test, val masks for each node
train_mask = torch.tensor([True] * round(edge_index.unique().size(0) * 0.8) + 
                          [False]* (edge_index.unique().size(0) - round(edge_index.unique().size(0) * 0.8)))
test_mask = torch.tensor([False] * round(edge_index.unique().size(0) * 0.8) + 
                         [True]* (round(edge_index.unique().size(0) * 0.1)) + 
                         [False]* (edge_index.unique().size(0) - round(edge_index.unique().size(0) * 0.8) 
                                   - round(edge_index.unique().size(0) * 0.1)))
val_mask = torch.tensor([False] * round(edge_index.unique().size(0) * 0.8) + 
                        [False]* (round(edge_index.unique().size(0) * 0.1)) + 
                        [True]* (edge_index.unique().size(0) - round(edge_index.unique().size(0) * 0.8) 
                                 - round(edge_index.unique().size(0) * 0.1)))

new_data = Data(edge_index=edge_index, 
                x=x, 
                train_mask = train_mask, 
                val_mask=val_mask, 
                test_mask=test_mask)

In [43]:
new_data

Data(x=[6, 5], train_mask=[6], val_mask=[6], test_mask=[6], val_pos_edge_index=[2, 0], test_pos_edge_index=[2, 0], train_pos_edge_index=[2, 10], train_neg_adj_mask=[6, 6], val_neg_edge_index=[2, 0], test_neg_edge_index=[2, 0])

In [10]:
new_data.train_mask

tensor([ True,  True,  True,  True,  True, False])

In [11]:
new_data.val_mask

tensor([False, False, False, False, False, False])

In [12]:
new_data.test_mask

tensor([False, False, False, False, False,  True])

In [24]:
def train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1, undirected=True):
    r"""Splits the edges of a :obj:`torch_geometric.data.Data` object
    into positive and negative train/val/test edges, and adds attributes of
    `train_pos_edge_index`, `train_neg_adj_mask`, `val_pos_edge_index`,
    `val_neg_edge_index`, `test_pos_edge_index`, and `test_neg_edge_index`
    to :attr:`data`.
    Args:
        data (Data): The data object.
        val_ratio (float, optional): The ratio of positive validation
            edges. (default: :obj:`0.05`)
        test_ratio (float, optional): The ratio of positive test
            edges. (default: :obj:`0.1`)
    :rtype: :class:`torch_geometric.data.Data`
    """

    assert 'batch' not in data  # No batch-mode.

    random.seed(77)
    torch.manual_seed(77)

    num_nodes = data.num_nodes
    row, col = data.edge_index
    data.edge_index = None

    # Return upper triangular portion.
    mask = row < col
    row, col = row[mask], col[mask]

    n_v = int(math.floor(val_ratio * row.size(0)))
    n_t = int(math.floor(test_ratio * row.size(0)))

    # Positive edges.
    perm = torch.randperm(row.size(0))
    row, col = row[perm], col[perm]

    r, c = row[:n_v], col[:n_v]
    data.val_pos_edge_index = torch.stack([r, c], dim=0)
    r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t]
    data.test_pos_edge_index = torch.stack([r, c], dim=0)

    r, c = row[n_v + n_t:], col[n_v + n_t:]
    data.train_pos_edge_index = torch.stack([r, c], dim=0)
    if undirected:
        data.train_pos_edge_index = to_undirected(data.train_pos_edge_index)

    # Negative edges.
    neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8)
    neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)
    neg_adj_mask[row, col] = 0

    neg_row, neg_col = neg_adj_mask.nonzero(as_tuple=False).t()
    perm = torch.randperm(neg_row.size(0))[:n_v + n_t]
    neg_row, neg_col = neg_row[perm], neg_col[perm]

    neg_adj_mask[neg_row, neg_col] = 0
    data.train_neg_adj_mask = neg_adj_mask

    row, col = neg_row[:n_v], neg_col[:n_v]
    data.val_neg_edge_index = torch.stack([row, col], dim=0)

    row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]
    data.test_neg_edge_index = torch.stack([row, col], dim=0)

    return data

In [39]:
a=train_test_split_edges(new_data)

In [23]:
from torch_geometric.utils import to_undirected

In [42]:
new_data.edge_index

In [44]:
a

Data(x=[6, 5], train_mask=[6], val_mask=[6], test_mask=[6], val_pos_edge_index=[2, 0], test_pos_edge_index=[2, 0], train_pos_edge_index=[2, 10], train_neg_adj_mask=[6, 6], val_neg_edge_index=[2, 0], test_neg_edge_index=[2, 0])

In [45]:
a.train_pos_edge_index

tensor([[0, 1, 1, 2, 2, 3, 3, 4, 5, 5],
        [1, 0, 5, 3, 5, 2, 4, 3, 1, 2]])

In [49]:
edge.reshape((2,8))

tensor([[1, 2, 2, 3, 0, 1, 4, 5],
        [5, 5, 3, 4, 1, 5, 1, 4]])

In [50]:
a.train_neg_adj_mask

tensor([[False, False,  True,  True,  True,  True],
        [False, False,  True,  True,  True, False],
        [False, False, False, False,  True, False],
        [False, False, False, False, False,  True],
        [False, False, False, False, False,  True],
        [False, False, False, False, False, False]])