In [30]:
import torch
import torch_geometric
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

In [3]:
dataset = TUDataset(root="data/TUDataset", name='COLLAB')

Downloading https://www.chrsmrrs.com/graphkerneldatasets/COLLAB.zip
Extracting data\TUDataset\COLLAB\COLLAB.zip
Processing...
Done!


### About Collab Dataset
COLLAB is a scientific collaboration dataset. A graph corresponds to a researcher’s ego network, i.e., the researcher and its collaborators are nodes and an edge indicates collaboration between two researchers. A researcher’s ego network has three possible labels, i.e., High Energy Physics, Condensed Matter Physics, and Astro Physics, which are the fields that the researcher belongs to. The dataset has 5,000 graphs and each graph has label 0, 1, or 2.
https://paperswithcode.com/dataset/collab 
https://networkrepository.com/COLLAB.php

In [24]:
print(f'Dataset: {dataset}')
print(f'Num Graphs: {len(dataset)}')
print(f'Num Nodes: {dataset.num_nodes}')
print(f'Num classes: {dataset.num_classes}')

Dataset: COLLAB(5000)
Num Graphs: 5000
Num Nodes: 372474
Num classes: 3


In [133]:
torch_geometric.utils.degree(dataset[4].edge_index[0], dataset[4].num_nodes)

# dataset[0].edge_index

tensor([42., 46., 39., 45.,  8., 45., 39., 45., 42., 45.,  7., 45., 45., 45.,
        42., 47., 45., 45., 45., 45., 47., 42., 45., 45., 47., 45., 39., 45.,
        42., 45., 45., 45., 45., 45., 45., 45., 45., 45., 45., 47., 47., 45.,
        45., 45., 45., 45., 45., 42.])

In [166]:
d = (torch_geometric.utils.degree(dataset[0].edge_index[0], num_nodes=data.num_nodes))
# dataset[0]

tensor([44., 44., 44.,  ...,  0.,  0.,  0.])

In [173]:
max_degree = 0
for data in dataset:
    deg = torch_geometric.utils.degree(data.edge_index[1], num_nodes=data.num_nodes)
    max_degree = max(max_degree, max(deg).item())
# assign to one hot degree for each data (OneHotDegree receive maximum degree parameter)
dataset.transform = torch_geometric.transforms.OneHotDegree(int(max_degree))

In [181]:
split = 0.8
seed = 123

num_split = round(len(dataset) * split)
torch.manual_seed(seed)
dataset.shuffle()

COLLAB(5000)

In [182]:
train_dataset = dataset[:num_split]
test_dataset = dataset[num_split:]
print('Train: ', train_dataset)
print('Test: ', test_dataset)

Train:  COLLAB(4000)
Test:  COLLAB(1000)


In [183]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

### Model Construction

In [92]:
from torch.nn import Linear
import torch.nn.functional as F

# Graph neural network models
from torch_geometric.nn import GCNConv

# pooling method (for readout layer)
from torch_geometric.nn import global_mean_pool

In [93]:
class GCN(torch.nn.Module):
    def __init__(self, data, hidden_channels):
        super(GCN, self).__init__()
        # seed
        torch.manual_seed(42)
        self.conv1 = GCNConv(data.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, data.num_classes)
        
    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)        
        
        x = global_mean_pool(x, batch)
        
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

In [94]:
def train(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    model.train()
    
    for data in loader:
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return out, loss

def test(model, loader):
    model.eval()
    correct = 0
    for data in loader:
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct/len(loader.dataset)

In [174]:
model = GCN(dataset, 42)

In [184]:
list_loss = []
list_train_acc = []
list_test_acc = []

for epoch in range(0, 10):
    out, loss = train(model, train_loader)
    train_acc = test(model, train_loader)
    test_acc = test(model, test_loader)
    
    list_train_acc.append(round(train_acc, 4))
    list_test_acc.append(round(test_acc, 4))
    list_loss.append(round(loss.item(), 4))

    print(f"epoch: {epoch+1} train_acc: {train_acc:.4f} loss: {loss:.4f} test_acc: {test_acc:.4f}")

epoch: 1 train_acc: 0.8625 loss: 0.4549 test_acc: 0.6190
epoch: 2 train_acc: 0.8700 loss: 0.2385 test_acc: 0.6260
epoch: 3 train_acc: 0.8798 loss: 0.3288 test_acc: 0.5460
epoch: 4 train_acc: 0.8768 loss: 0.3962 test_acc: 0.5580
epoch: 5 train_acc: 0.8832 loss: 0.3722 test_acc: 0.6570
epoch: 6 train_acc: 0.8840 loss: 0.2360 test_acc: 0.6690
epoch: 7 train_acc: 0.8850 loss: 0.0912 test_acc: 0.6370
epoch: 8 train_acc: 0.8938 loss: 0.1360 test_acc: 0.6190
epoch: 9 train_acc: 0.8632 loss: 0.3285 test_acc: 0.7300
epoch: 10 train_acc: 0.8992 loss: 0.3347 test_acc: 0.6850


In [185]:
# for data in train_loader:
#     print(data.y)
# #     break
for tl in train_loader:
    print(tl.x[0])
    break

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 