In [4]:
import warnings

warnings.filterwarnings("ignore")

from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.loader import DataLoader

dataset = PygGraphPropPredDataset(name="ogbg-moltox21", root="../dataset/")
print(dataset)

split_idx = dataset.get_idx_split()
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False)

PygGraphPropPredDataset(7831)


## Exploration


In [None]:
# Load first batch
batch = next(iter(train_loader))
print(batch)

DataBatch(edge_index=[2, 1012], edge_attr=[1012, 3], x=[497, 9], y=[32, 12], num_nodes=497, batch=[497], ptr=[33])


**PyG molecular batch structure**

- **x [num_nodes, num_node_features]**  
  Atom feature matrix. Each row is one atom across all molecules; features are categorical descriptors (atom type, valence-related classes, aromaticity classes).

- **edge_index [2, num_edges]**  
  Bond connectivity. Each column is a directed bond between two atom indices in the concatenated node list. Undirected bonds appear twice (u,v) and (v,u).

- **edge_attr [num_edges, num_edge_features]**  
  Bond feature matrix. Each row encodes bond type and stereochemical categories for the corresponding column in `edge_index`.

- **y [num_graphs, num_tasks]**  
  Graph-level toxicity labels. Each row corresponds to one moleculeâ€™s 12-task target vector.

- **num_nodes**  
  Total atoms across the batch.

- **batch [num_nodes]**  
  For atom index i, `batch[i]` gives the molecule index. Supports graph-wise pooling.

- **ptr [num_graphs+1]**  
  Prefix sums of node counts. Graph g occupies node indices `[ptr[g], ptr[g+1])`. Used to unbatch and slice per-molecule subgraphs.


In [6]:
# Make a small GCN model to train on the dataset
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool


class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)  # Global pooling
        return x

In [7]:
# Train the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN(
    in_channels=dataset.num_node_features,
    hidden_channels=64,
    out_channels=dataset.num_tasks,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

for epoch in range(1, 6):
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # Parse Long to Float for binary classification
        x = batch.x.float()
        # Change nan labels to 0 for loss computation
        batch.y[batch.y != batch.y] = 0
        out = model(x, batch.edge_index, batch.batch)
        loss = F.binary_cross_entropy_with_logits(out, batch.y.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    total_loss /= len(train_loader.dataset)
    print(f"Epoch {epoch}, Loss: {total_loss:.4f}")

# Evaluate the model
model.eval()
correct = 0
for batch in test_loader:
    batch = batch.to(device)
    out = model(batch.x.float(), batch.edge_index, batch.batch)
    pred = (out > 0).float()
    correct += (pred == batch.y).sum().item()
accuracy = correct / (len(test_loader.dataset) * dataset.num_tasks)
print(f"Test Accuracy: {accuracy:.4f}")

Epoch 1, Loss: 0.2328
Epoch 2, Loss: 0.2069
Epoch 3, Loss: 0.2039
Epoch 4, Loss: 0.2035
Epoch 5, Loss: 0.2031
Test Accuracy: 0.6790
