In [1]:
import os
import torch
# Store the installed PyTorch version in the TORCH environment variable,
# then display it.
os.environ.update({'TORCH': torch.__version__})
print(torch.__version__)

2.3.1+cu121


In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

In [3]:
def visualize(h, color, random_state=42):
    """Project embeddings to 2-D with t-SNE and show them as a scatter plot.

    Args:
        h: Tensor of embeddings; it is detached, moved to the CPU and
            converted to a NumPy array before being handed to t-SNE.
        color: Per-point values forwarded to ``scatter`` as ``c`` and mapped
            through the "Set2" colormap.
        random_state: Seed forwarded to ``TSNE`` so that repeated calls give
            the same layout. Pass ``None`` to get the unseeded behaviour.
    """
    embedding = h.detach().cpu().numpy()
    z = TSNE(n_components=2, random_state=random_state).fit_transform(embedding)

    # Explicit Axes interface instead of the implicit pyplot state machine.
    _, ax = plt.subplots(figsize=(7, 7))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()

In [4]:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

In [6]:
# Load the Cora dataset via Planetoid (stored under ./tmp/Planetoid) and
# apply NormalizeFeatures as the transform.
dataset = Planetoid(
    root="./tmp/Planetoid",
    name="Cora",
    transform=NormalizeFeatures(),
)

# Print a short summary of the dataset, preceded by a blank line.
print()
print(
    f'Dataset: {dataset}:',
    '======================',
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
    sep='\n',
)

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index



Dataset: Cora():
Number of graphs: 1
Number of features: 1433
Number of classes: 7


Processing...
Done!


In [7]:
# Take the first graph object from the dataset.
data = dataset[0]

# Blank line first, then the graph's summary representation.
print('', data, sep='\n')


Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])


In [8]:
# Report basic structural statistics of the graph, one per line.
print(
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Average node degree: {data.num_edges / data.num_nodes:.2f}',
    f'Number of training nodes: {data.train_mask.sum()}',
    f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}',
    f'Has isolated nodes: {data.has_isolated_nodes()}',
    f'Has self-loops: {data.has_self_loops()}',
    f'Is undirected: {data.is_undirected()}',
    sep='\n',
)

Number of nodes: 2708
Number of edges: 10556
Average node degree: 3.90
Number of training nodes: 140
Training node label rate: 0.05
Has isolated nodes: False
Has self-loops: False
Is undirected: True


## Training a Multi-layer Perceptron Network (MLP)

In [9]:
from torch.nn import Linear
import torch.nn.functional as F

In [10]:
class MLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> dropout -> Linear.

    Args:
        hidden_channels: Width of the hidden layer.
        in_channels: Input feature size. When ``None`` it is read from the
            notebook-level ``dataset.num_features``.
        out_channels: Number of outputs. When ``None`` it is read from the
            notebook-level ``dataset.num_classes``.
        dropout: Dropout probability applied to the hidden activations while
            the module is in training mode.
    """

    def __init__(self, hidden_channels, in_channels=None, out_channels=None,
                 dropout=0.5):
        super().__init__()
        # Fall back to the global ``dataset`` only when the caller does not
        # supply explicit dimensions, so the class is usable on its own.
        if in_channels is None:
            in_channels = dataset.num_features
        if out_channels is None:
            out_channels = dataset.num_classes
        self.dropout = dropout
        # Seed before creating the layers so their initial weights are the
        # same on every construction. NOTE: this reseeds PyTorch's global RNG.
        torch.manual_seed(12345)
        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x):
        """Return the output of the second linear layer (no softmax applied)."""
        x = self.lin1(x)
        x = x.relu()
        # Dropout is active only when the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
        return x

In [11]:
# Build the MLP with a 16-unit hidden layer and print its layer summary.
model = MLP(16)
print(model)

MLP(
  (lin1): Linear(in_features=1433, out_features=16, bias=True)
  (lin2): Linear(in_features=16, out_features=7, bias=True)
)


In [12]:
# Loss function: cross-entropy with the default mean reduction made explicit.
criterion = torch.nn.CrossEntropyLoss(reduction='mean')

In [13]:
# Adam over all model parameters, with learning rate 0.01 and weight decay 5e-4.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-2,
    weight_decay=5e-4,
)

In [14]:
def train():
    """Run one optimisation step over the training nodes and return the loss.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``data`` objects.
    """
    model.train()  # training mode, so dropout in the model is active
    optimizer.zero_grad()  # reset gradients left over from the previous step

    mask = data.train_mask
    logits = model(data.x)  # forward pass over every node
    # Only the nodes selected by the training mask contribute to the loss.
    loss = criterion(logits[mask], data.y[mask])

    loss.backward()  # back-propagate
    optimizer.step()  # apply the parameter update
    return loss

In [None]:
def test():
    model.eval()
    out = model(data.x)
    pred = out.argmax(dim=1) # Use the class with highest probability
    test_correct = pred[data.test_mask] == data.y[data.test_mask]