In [1]:
import dgl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import dgl.data

# Load the Cora citation dataset; its statistics are printed below.
dataset = dgl.data.CoraGraphDataset()

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [3]:
# The Cora graph; train() below reads its features, labels and split masks from g.ndata.
g = dataset[0]

In [4]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    """Two-layer graph convolutional network for node classification.

    Args:
        in_feats: size of each input node feature vector.
        h_feats: hidden size produced by the first GraphConv.
        num_classes: number of output classes (width of the per-node logits).
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return raw per-node class logits (softmax is left to F.cross_entropy)."""
        hidden = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, hidden)

In [5]:
def train(g, model, num_epochs=100, lr=0.01, log_every=5):
    """Train ``model`` for full-batch node classification on graph ``g``.

    Node features, labels and boolean split masks are read from ``g.ndata`` under the
    keys 'feat', 'label', 'train_mask', 'val_mask' and 'test_mask'. Only training
    nodes contribute to the loss. The accuracies reported for an epoch are measured
    on that epoch's forward pass, i.e. before the optimizer step.

    Args:
        g: graph whose ``ndata`` holds the keys listed above.
        model: module called as ``model(g, features)`` returning per-node logits.
        num_epochs: number of full-batch optimization steps (default 100).
        lr: Adam learning rate (default 0.01).
        log_every: print a progress line every ``log_every`` epochs; 0 disables logging.

    Returns:
        ``(best_val_acc, best_test_acc)`` as floats: the highest validation accuracy
        seen and the test accuracy from that same epoch (model selection on validation).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_val_acc = 0.0
    best_test_acc = 0.0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    def masked_accuracy(pred, mask):
        # Fraction of the nodes selected by `mask` whose prediction matches the label.
        return (pred[mask] == labels[mask]).float().mean().item()

    for e in range(num_epochs):
        # Forward pass over the whole graph.
        logits = model(g, features)
        pred = logits.argmax(1)

        # Only the nodes in the training set contribute to the loss.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        val_acc = masked_accuracy(pred, val_mask)
        test_acc = masked_accuracy(pred, test_mask)

        # Keep the test accuracy of the epoch with the best validation accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and e % log_every == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss.item(), val_acc, best_val_acc, test_acc, best_test_acc))

    return best_val_acc, best_test_acc
torch.manual_seed(0)  # reproducible weight initialisation across Restart & Run All
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)



In epoch 0, loss: 1.946, val acc: 0.136 (best 0.136), test acc: 0.130 (best 0.130)
In epoch 5, loss: 1.891, val acc: 0.510 (best 0.510), test acc: 0.501 (best 0.501)
In epoch 10, loss: 1.810, val acc: 0.568 (best 0.574), test acc: 0.591 (best 0.585)
In epoch 15, loss: 1.709, val acc: 0.572 (best 0.574), test acc: 0.595 (best 0.585)
In epoch 20, loss: 1.586, val acc: 0.592 (best 0.592), test acc: 0.615 (best 0.615)
In epoch 25, loss: 1.442, val acc: 0.662 (best 0.662), test acc: 0.663 (best 0.663)
In epoch 30, loss: 1.285, val acc: 0.694 (best 0.694), test acc: 0.696 (best 0.696)
In epoch 35, loss: 1.121, val acc: 0.720 (best 0.720), test acc: 0.718 (best 0.718)
In epoch 40, loss: 0.959, val acc: 0.726 (best 0.726), test acc: 0.722 (best 0.722)
In epoch 45, loss: 0.805, val acc: 0.744 (best 0.744), test acc: 0.733 (best 0.733)
In epoch 50, loss: 0.666, val acc: 0.750 (best 0.752), test acc: 0.741 (best 0.742)
In epoch 55, loss: 0.546, val acc: 0.754 (best 0.754), test acc: 0.747 (best 0

In [6]:
import dgl
import numpy as np
import torch

In [7]:
# Star-shaped graph: node 0 is the centre, with an edge to each of nodes 1-5

In [8]:
# Edge i goes from node 0 to node i + 1.
src_nodes = [0, 0, 0, 0, 0]
dst_nodes = [1, 2, 3, 4, 5]
g = dgl.graph((src_nodes, dst_nodes), num_nodes=6)

In [9]:
# Equivalently, PyTorch LongTensors (int64 tensors) also work as edge endpoints.
src_ids = torch.tensor([0, 0, 0, 0, 0], dtype=torch.long)
dst_ids = torch.tensor([1, 2, 3, 4, 5], dtype=torch.long)
g = dgl.graph((src_ids, dst_ids), num_nodes=6)

In [10]:
# You can omit the number of nodes argument if you can tell the number of nodes from the edge list alone.
# Here the largest node ID is 5, so the graph has 6 nodes (confirmed by num_nodes() further down).
g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]))

# Assigning features to a graph

In [11]:
# Feature tensors are sized from the graph: the leading dimension is the number of
# nodes for g.ndata and the number of edges for g.edata (6 and 5 for this star graph).
n_nodes, n_edges = g.num_nodes(), g.num_edges()

# a 3-d feature vector for every node
g.ndata['x'] = torch.randn(n_nodes, 3)

# a 4-d feature vector for every edge
g.edata['a'] = torch.randn(n_edges, 4)

# a 5x4 feature matrix for every node -- DGL supports multi-dimensional features
g.ndata['y'] = torch.randn(n_nodes, 5, 4)

In [12]:
# The (6, 3) random node feature tensor assigned above.
print(g.ndata['x'])

tensor([[-0.1922,  1.7157, -0.7712],
        [ 0.2557,  1.8470, -1.5688],
        [ 0.8149, -1.0861,  1.2249],
        [-1.0675,  0.4671, -0.3582],
        [-0.5609,  0.0621, -0.3328],
        [ 1.1135,  0.8310,  0.4783]])


In [13]:
# The (6, 5, 4) node feature tensor: one 5x4 matrix per node.
print(g.ndata['y'])

tensor([[[ 5.1740e-01, -1.1059e+00,  9.4387e-01, -1.0236e+00],
         [ 1.1906e-01,  8.8992e-01, -1.1529e+00,  7.6569e-01],
         [-1.2758e+00,  6.8034e-02, -2.0166e+00, -2.9958e-01],
         [ 6.3827e-01,  1.1994e+00,  4.6846e-01,  7.0318e-01],
         [-5.8710e-02,  2.5240e-01,  4.6506e-01,  1.3793e-01]],

        [[-8.8068e-02,  1.9922e+00,  6.5358e-01,  8.9653e-01],
         [-1.1256e+00, -7.6091e-01, -2.1759e+00, -2.3135e+00],
         [ 6.7946e-01, -5.2826e-01,  7.4475e-01, -9.7166e-01],
         [ 4.0705e-01, -5.7047e-01,  3.6864e-01, -5.5453e-01],
         [ 1.7703e-01,  5.4992e-01, -3.1467e-01,  1.3404e+00]],

        [[ 1.2805e+00,  1.0879e+00, -1.0830e+00, -1.8653e+00],
         [-6.6984e-01, -7.2599e-01,  5.6298e-02,  1.5494e-01],
         [-2.8336e-01,  2.4520e-01,  9.0225e-01, -7.2946e-01],
         [ 1.8777e-01, -1.6443e-01,  1.7212e-01,  1.0218e-01],
         [-1.7572e+00,  2.9157e-01, -3.0790e-01, -1.0068e+00]],

        [[ 1.1311e-01, -7.3879e-01, -1.6125e+00, 

In [14]:
# The (5, 4) edge feature tensor: one row per edge.
print(g.edata['a'])

tensor([[ 0.7856,  1.0833, -1.3352,  1.0333],
        [ 0.6354,  0.7014,  2.0382, -0.4109],
        [-0.8996, -0.1350, -0.2898,  0.4645],
        [-1.6151, -0.0227,  0.4396,  0.0025],
        [-0.2631,  0.3728,  0.7462, -1.1694]])


In [15]:
# Printed in order: number of nodes, number of edges, out-degree of the centre
# node 0, and its in-degree. The graph is directed and every edge leaves node 0,
# so the in-degree of the centre should be 0.
for stat in (g.num_nodes(), g.num_edges(), g.out_degrees(0), g.in_degrees(0)):
    print(stat)

6
5
5
0


In [16]:
# Node-induced subgraph on nodes 0, 1, 3: only edges with both endpoints in that set
# survive (0->1 and 0->3, i.e. original edge IDs 0 and 2 as printed further down).
sg1 = g.subgraph([0, 1, 3])

In [17]:
# Node features of sg1 are the rows of g's features for nodes 0, 1 and 3.
sg1.ndata

{'x': tensor([[-0.1922,  1.7157, -0.7712],
        [ 0.2557,  1.8470, -1.5688],
        [-1.0675,  0.4671, -0.3582]]), 'y': tensor([[[ 5.1740e-01, -1.1059e+00,  9.4387e-01, -1.0236e+00],
         [ 1.1906e-01,  8.8992e-01, -1.1529e+00,  7.6569e-01],
         [-1.2758e+00,  6.8034e-02, -2.0166e+00, -2.9958e-01],
         [ 6.3827e-01,  1.1994e+00,  4.6846e-01,  7.0318e-01],
         [-5.8710e-02,  2.5240e-01,  4.6506e-01,  1.3793e-01]],

        [[-8.8068e-02,  1.9922e+00,  6.5358e-01,  8.9653e-01],
         [-1.1256e+00, -7.6091e-01, -2.1759e+00, -2.3135e+00],
         [ 6.7946e-01, -5.2826e-01,  7.4475e-01, -9.7166e-01],
         [ 4.0705e-01, -5.7047e-01,  3.6864e-01, -5.5453e-01],
         [ 1.7703e-01,  5.4992e-01, -3.1467e-01,  1.3404e+00]],

        [[ 1.1311e-01, -7.3879e-01, -1.6125e+00,  9.8744e-01],
         [ 1.6376e-01, -1.6192e+00, -6.7076e-01, -5.7997e-01],
         [ 1.6943e-03, -6.1730e-01, -1.0794e+00,  1.7375e-01],
         [ 6.0237e-01,  1.0906e+00,  4.2878e-02, -2.3

In [18]:
# Edge-induced subgraph: keeps edges 0, 1, 3 and their endpoint nodes (0, 1, 2, 4).
sg2 = g.edge_subgraph([0, 1, 3])

In [19]:
# g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]), num_nodes=6)
# edge_subgraph selects by edge ID (not by node ID).
# That means edge_subgraph([0, 1, 3]) asks for the subgraph containing
# edges 0 (0 -> 1), 1 (0 -> 2) and 3 (0 -> 4), together with their endpoint nodes.
# Features of the endpoint nodes 0, 1, 2, 4 of the selected edges.
sg2.ndata

{'x': tensor([[-0.1922,  1.7157, -0.7712],
        [ 0.2557,  1.8470, -1.5688],
        [ 0.8149, -1.0861,  1.2249],
        [-0.5609,  0.0621, -0.3328]]), 'y': tensor([[[ 0.5174, -1.1059,  0.9439, -1.0236],
         [ 0.1191,  0.8899, -1.1529,  0.7657],
         [-1.2758,  0.0680, -2.0166, -0.2996],
         [ 0.6383,  1.1994,  0.4685,  0.7032],
         [-0.0587,  0.2524,  0.4651,  0.1379]],

        [[-0.0881,  1.9922,  0.6536,  0.8965],
         [-1.1256, -0.7609, -2.1759, -2.3135],
         [ 0.6795, -0.5283,  0.7447, -0.9717],
         [ 0.4070, -0.5705,  0.3686, -0.5545],
         [ 0.1770,  0.5499, -0.3147,  1.3404]],

        [[ 1.2805,  1.0879, -1.0830, -1.8653],
         [-0.6698, -0.7260,  0.0563,  0.1549],
         [-0.2834,  0.2452,  0.9023, -0.7295],
         [ 0.1878, -0.1644,  0.1721,  0.1022],
         [-1.7572,  0.2916, -0.3079, -1.0068]],

        [[ 0.2448, -1.7043, -0.5305,  1.4876],
         [ 1.5271, -1.0480, -1.1199, -0.2606],
         [ 0.2892, -0.3558, -0.425

In [20]:
# The original IDs of each node in sg1
print(sg1.ndata[dgl.NID])
# The original IDs of each edge in sg1
print(sg1.edata[dgl.EID])
# The original IDs of each node in sg2
print(sg2.ndata[dgl.NID])
# The original IDs of each edge in sg2
print(sg2.edata[dgl.EID])

tensor([0, 1, 3])
tensor([0, 2])
tensor([0, 1, 2, 4])
tensor([0, 1, 3])


# Transforming a unidirectional graph into a bidirectional one

In [21]:
# Add a reverse edge i -> 0 for every edge 0 -> i, making the star bidirectional.
newg = dgl.add_reverse_edges(g)
newg.edges()

(tensor([0, 0, 0, 0, 0, 1, 2, 3, 4, 5]),
 tensor([1, 2, 3, 4, 5, 0, 0, 0, 0, 0]))

In [22]:
# dgl.add_reverse_edges copies node features by default (copy_ndata=True) but not
# edge features (copy_edata=False), so the edge feature 'a' is lost here; pass copy_edata=True to keep it.
# Empty: the edge feature 'a' of g was not carried over to newg.
newg.edata

{}

# save/load graphs

In [23]:
# Save a single graph, then a list of graphs, to disk.
dgl.save_graphs('graph.dgl', g)
dgl.save_graphs('graphs.dgl', [g, sg1, sg2])

# Load them back: load_graphs returns a pair whose first element is the list of
# graphs; the second element is ignored here.
(g,), _ = dgl.load_graphs('graph.dgl')
print(g)
(g, sg1, sg2), _ = dgl.load_graphs('graphs.dgl')
for loaded_graph in (g, sg1, sg2):
    print(loaded_graph)

Graph(num_nodes=6, num_edges=5,
      ndata_schemes={'y': Scheme(shape=(5, 4), dtype=torch.float32), 'x': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'a': Scheme(shape=(4,), dtype=torch.float32)})
Graph(num_nodes=6, num_edges=5,
      ndata_schemes={'y': Scheme(shape=(5, 4), dtype=torch.float32), 'x': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'a': Scheme(shape=(4,), dtype=torch.float32)})
Graph(num_nodes=3, num_edges=2,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(5, 4), dtype=torch.float32), 'x': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'a': Scheme(shape=(4,), dtype=torch.float32)})
Graph(num_nodes=4, num_edges=3,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(5, 4), dtype=torch.float32), 'x': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'a': Scheme

# Graph classification

In [24]:
# Load the PROTEINS graph-classification dataset (1113 graphs, 2 classes) via GINDataset.
# self_loop=True asks GINDataset to add a self-loop edge to every node.
dataset = dgl.data.GINDataset('PROTEINS', self_loop=True)

In [25]:
# dim_nfeats: length of each node's feature vector; gclasses: number of graph labels.
print('Node feature dimensionality:', dataset.dim_nfeats)
print('Number of graph categories:', dataset.gclasses)

Node feature dimensionality: 3
Number of graph categories: 2


In [26]:
# Float buffer with one slot per graph; the next cell fills it with the graph labels.
t = torch.zeros(len(dataset))

In [27]:
# Each dataset item is a (graph, label) pair; store the label of graph idx in t[idx].
for idx, (_graph, graph_label) in enumerate(dataset):
    t[idx] = graph_label

In [28]:
# The PROTEINS dataset is exactly the format I want for sites:
# it contains 1113 graphs, each paired with a tensor holding its graph label.

In [29]:
# dataset[i] is a (graph, label) pair; keep only the graph of the first item.
g = dataset[0][0]

In [30]:
# Node data: a one-hot 3-d 'attr' feature and a per-node 'label'.
g.ndata

{'attr': tensor([[1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.]]), 'label': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [31]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
train_fraction = 0.8  # the remaining 20% of the graphs are held out for testing
num_train = int(num_examples * train_fraction)

In [32]:
train_sampler = SubsetRandomSampler(torch.arange(num_train))

In [33]:
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

In [34]:
# Both loaders draw 5 graphs per batch from `dataset`; they differ only in the sampler.
loader_kwargs = {'batch_size': 5, 'drop_last': False}
train_dataloader = GraphDataLoader(dataset, sampler=train_sampler, **loader_kwargs)
test_dataloader = GraphDataLoader(dataset, sampler=test_sampler, **loader_kwargs)

In [35]:
# Manual iterator over the training loader, used to inspect a single batch.
it = iter(train_dataloader)

In [36]:
# One batch is [batched_graph, labels] (see the printed output).
batch = next(it)
print(batch)

[Graph(num_nodes=764, num_edges=3388,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), tensor([0, 0, 0, 1, 1])]


In [37]:
# (src, dst) edge endpoints of the batched graph. Node IDs span all graphs in the
# batch as one consecutive range (764 nodes in total here).
batch[0].all_edges()

(tensor([  0,   0,   0,  ..., 763, 763, 763]),
 tensor([  1,  71,   0,  ..., 761, 762, 763]))

In [38]:
# The first element is the batched graph. This single element contains all (in this case 5) graphs
# that we want to consider in each batch.

In [39]:
# Unpack the (batched_graph, labels) pair and show per-graph node/edge counts.
batched_graph, labels = batch
print('Number of nodes for each graph element in the batch:', batched_graph.batch_num_nodes())
print('Number of edges for each graph element in the batch:', batched_graph.batch_num_edges())

Number of nodes for each graph element in the batch: tensor([620,  69,  54,  13,   8])
Number of edges for each graph element in the batch: tensor([2718,  319,  252,   63,   36])


In [40]:
# Recover the original graph elements from the minibatch
# (dgl.unbatch splits the batched graph back into its individual graphs, in batch order).
graphs = dgl.unbatch(batched_graph)
print('The original graphs in the minibatch:')
print(graphs)

The original graphs in the minibatch:
[Graph(num_nodes=620, num_edges=2718,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=69, num_edges=319,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=54, num_edges=252,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=13, num_edges=63,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), Graph(num_nodes=8, num_edges=36,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})]


In [41]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    """Three-layer GCN for graph classification with a mean-pooling readout.

    Args:
        in_feats: size of each input node feature vector.
        h_feats: hidden size of the two intermediate GraphConv layers.
        num_classes: number of graph classes (width of the output scores).
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, h_feats)
        self.conv3 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return one row of class scores per graph in the (batched) graph ``g``."""
        h = F.relu(self.conv1(g, in_feat))
        h = F.relu(self.conv2(g, h))
        h = self.conv3(g, h)
        # local_scope: the temporary 'h' node feature is discarded on exit instead
        # of being left behind on the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            return dgl.mean_nodes(g, 'h')

In [42]:
torch.manual_seed(0)  # reproducible weight init and sampler order across re-runs
model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

num_epochs = 50
model.train()
for epoch in range(num_epochs):
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata['attr'].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluation only: switch to eval mode and skip autograd bookkeeping.
model.eval()
num_correct = 0
num_tests = 0
with torch.no_grad():
    for batched_graph, labels in test_dataloader:
        pred = model(batched_graph, batched_graph.ndata['attr'].float())
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)

print('Test accuracy:', num_correct / num_tests)

Test accuracy: 0.2914798206278027


In [43]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import StratifiedKFold
from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch.conv import GINConv
from dgl.nn.pytorch.glob import SumPooling
import argparse

class MLP(nn.Module):
    """Two-layer MLP used as the update function inside each GINConv layer.

    Computes ``linears[1](relu(batch_norm(linears[0](x))))``. Both linear layers
    are created without a bias term.

    Args:
        input_dim: size of the incoming node representation.
        hidden_dim: width of the hidden layer (and of the batch norm).
        output_dim: size of the produced node representation.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        first = nn.Linear(input_dim, hidden_dim, bias=False)
        second = nn.Linear(hidden_dim, output_dim, bias=False)
        self.linears = nn.ModuleList([first, second])
        self.batch_norm = nn.BatchNorm1d(hidden_dim)

    def forward(self, x):
        hidden = self.linears[0](x)
        hidden = F.relu(self.batch_norm(hidden))
        return self.linears[1](hidden)
    
class GIN(nn.Module):
    """Graph Isomorphism Network for graph classification.

    The model stacks ``num_layers - 1`` GINConv layers. Each uses a two-layer
    ``MLP`` as its update function and is followed by BatchNorm and ReLU. The node
    representations after every layer, including the raw input features, are
    sum-pooled per graph. Each pooled vector goes through its own linear head plus
    dropout, and the per-layer scores are summed.

    Args:
        input_dim: size of the input node features.
        hidden_dim: size of the hidden node representations.
        output_dim: number of output scores per graph (number of classes).
        num_layers: number of layers counting the input layer, i.e. the model has
            ``num_layers - 1`` GINConv layers and ``num_layers`` linear heads.
            Must be >= 1. Defaults to 5 (five-layer GIN).
        learn_eps: forwarded to ``GINConv``; set True to learn epsilon.

    Raises:
        ValueError: if ``num_layers`` < 1.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=5, learn_eps=False):
        super().__init__()
        if num_layers < 1:
            raise ValueError("num_layers must be >= 1, got {}".format(num_layers))
        self.ginlayers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        # GINConv stack with MLP update and sum-neighbour aggregation; the input
        # layer itself has no convolution.
        for layer in range(num_layers - 1):
            mlp_in = input_dim if layer == 0 else hidden_dim
            mlp = MLP(mlp_in, hidden_dim, hidden_dim)
            self.ginlayers.append(GINConv(mlp, learn_eps=learn_eps))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        # One linear head per layer output (input layer included), applied after pooling.
        self.linear_prediction = nn.ModuleList()
        for layer in range(num_layers):
            head_in = input_dim if layer == 0 else hidden_dim
            self.linear_prediction.append(nn.Linear(head_in, output_dim))
        self.drop = nn.Dropout(0.5)
        self.pool = SumPooling()  # change to mean readout (AvgPooling) on social network datasets

    def forward(self, g, h):
        """Return per-graph scores: the sum over layers of dropout(head(sum_pool(rep)))."""
        # Hidden representation at each layer, including the input layer.
        hidden_rep = [h]
        for conv, norm in zip(self.ginlayers, self.batch_norms):
            h = F.relu(norm(conv(g, h)))
            hidden_rep.append(h)
        score_over_layer = 0
        for rep, head in zip(hidden_rep, self.linear_prediction):
            score_over_layer += self.drop(head(self.pool(g, rep)))
        return score_over_layer
    
def split_fold10(labels, fold_idx=0):
    """Return ``(train_idx, valid_idx)`` for fold ``fold_idx`` of a stratified 10-fold split.

    The folds are shuffled with ``random_state=0``, so the split is reproducible.
    ``labels`` drives the stratification; the feature argument given to the
    splitter is a dummy array of zeros.
    """
    splitter = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
    dummy_features = np.zeros(len(labels))
    folds = list(splitter.split(dummy_features, labels))
    train_idx, valid_idx = folds[fold_idx]
    return train_idx, valid_idx

def evaluate(dataloader, device, model):
    """Return the graph-classification accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and the forward passes run under
    ``torch.no_grad()``, since no gradients are needed for evaluation.

    Args:
        dataloader: iterable of ``(batched_graph, labels)`` pairs.
        device: device the graph and labels are moved to before the forward pass.
        model: module called as ``model(graph, features)`` returning per-graph logits.

    Returns:
        Fraction of correctly classified graphs, as a float.

    Note:
        The node features are popped from ``ndata['attr']`` of each batched graph
        and passed to the model explicitly.
    """
    model.eval()
    total = 0
    total_correct = 0
    with torch.no_grad():
        for batched_graph, labels in dataloader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            feat = batched_graph.ndata.pop('attr')
            total += len(labels)
            logits = model(batched_graph, feat)
            _, predicted = torch.max(logits, 1)
            total_correct += (predicted == labels).sum().item()
    return total_correct / total

def train(train_loader, val_loader, device, model, num_epochs=400, lr=0.001):
    """Train a graph classifier with cross-entropy, Adam and a step LR schedule.

    After every epoch the accuracy on both loaders is computed with ``evaluate`` and
    a progress line is printed.

    Args:
        train_loader: iterable of ``(batched_graph, labels)`` training batches; node
            features are popped from ``ndata['attr']``.
        val_loader: iterable of validation batches in the same format.
        device: device graphs and labels are moved to.
        model: module called as ``model(graph, features)`` returning per-graph logits.
        num_epochs: number of passes over ``train_loader`` (default 400).
        lr: initial Adam learning rate (default 0.001); halved every 50 epochs.
    """
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    for epoch in range(num_epochs):
        model.train()  # evaluate() leaves the model in eval mode
        total_loss = 0.0
        num_batches = 0
        for batched_graph, labels in train_loader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            feat = batched_graph.ndata.pop('attr')
            logits = model(batched_graph, feat)
            loss = loss_fcn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        scheduler.step()
        train_acc = evaluate(train_loader, device, model)
        valid_acc = evaluate(val_loader, device, model)
        # max(..., 1): an empty train_loader would otherwise divide by zero.
        print("Epoch {:05d} | Loss {:.4f} | Train Acc. {:.4f} | Validation Acc. {:.4f} "
              . format(epoch, total_loss / max(num_batches, 1), train_acc, valid_acc))

In [44]:
labels = [graph_label for _graph, graph_label in dataset]
train_idx, val_idx = split_fold10(labels)

In [45]:
# Pinned host memory only helps when batches are later copied to a GPU.
use_pinned_memory = torch.cuda.is_available()
train_loader = GraphDataLoader(dataset, sampler=SubsetRandomSampler(train_idx),
                               batch_size=128, pin_memory=use_pinned_memory)
val_loader = GraphDataLoader(dataset, sampler=SubsetRandomSampler(val_idx),
                             batch_size=128, pin_memory=use_pinned_memory)

In [46]:
# create GIN model: input size = node feature length, output size = number of graph classes
in_size, out_size = dataset.dim_nfeats, dataset.gclasses
hidden_size = 16
model = GIN(in_size, hidden_size, out_size).to('cpu')

In [47]:
# train(train_loader, val_loader, 'cpu', model)  # use the stratified fold-10 loaders built above