In [1]:
import torch
from torch_geometric.datasets import GeometricShapes

# Full GeometricShapes dataset. No `train=` argument is passed, so the constructor's
# default split is used. The printed summary appears as the cell output below.
dataset = GeometricShapes(
    root='data/GeometricShapes',
)
print(dataset)

GeometricShapes(40)


In [2]:
import torch
from torch_geometric.transforms import SamplePoints

# Seed torch's global RNG before anything random happens later in the notebook
# (model weight init in the next cell, point sampling, loader shuffling).
torch.manual_seed(42)

# Sample 256 surface points per shape whenever an item of `dataset` is accessed.
# NOTE(review): the training cell below builds its own splits with SamplePoints(128);
# this transform only affects direct indexing of `dataset`.
dataset.transform = SamplePoints(256)

In [3]:
# load model
from util import read_config
from model import ComposableModel
# Relative path: assumes the kernel's working directory contains `notebook/` (TODO confirm).
model_cfg = read_config("notebook/point_net.yaml")
model = ComposableModel("point_net", model_cfg.modules)
model  # bare last expression -> rich repr of the module tree

ComposableModel(
  (point_net): PointNet(
    (conv1): PointNetLayer()
    (conv2): PointNetLayer()
    (classifier): Linear(in_features=32, out_features=40, bias=True)
  )
)

In [4]:
# train() and test() below read the forward output as model(...)["x"]. That dict output is
# presumably controlled by this flag (TODO confirm in model.py). Assert it so that a fresh
# Restart & Run All fails here, not with an indexing error mid-training.
assert model.point_net.return_dict, "train()/test() expect the model to return a dict"
model.point_net.return_dict  # still echo the flag as the cell output

True

In [5]:
from torch_geometric.loader import DataLoader

# Build the train and test splits identically, each with its own SamplePoints(128) transform.
# PyG applies `transform` on each item access, so points are re-drawn every epoch.
# NOTE(review): cell 2 attached SamplePoints(num=256) to `dataset`; these splits use 128.
train_dataset, test_dataset = (
    GeometricShapes(root='data/GeometricShapes', train=is_train,
                    transform=SamplePoints(128))
    for is_train in (True, False)
)

# Only the training batches are shuffled; evaluation order does not affect accuracy.
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# CrossEntropyLoss takes raw logits (it applies log-softmax internally).
criterion = torch.nn.CrossEntropyLoss()

def train(model, optimizer, loader, loss_fn=None):
    """Run one training epoch over ``loader`` and return the per-graph mean loss.

    Args:
        model: Module called as ``model(data.pos, data.batch)``. Its output is
            indexed with ``["x"]`` to obtain the logits.
        optimizer: Optimizer stepped once per batch.
        loader: Iterable of batches exposing ``pos``, ``batch``, ``y`` and
            ``num_graphs``. It must also have a ``dataset`` attribute whose length
            is used as the normaliser (e.g. a ``torch_geometric.loader.DataLoader``).
        loss_fn: Callable applied as ``loss_fn(logits, data.y)``. Defaults to the
            notebook-level ``criterion`` so existing three-argument calls are unchanged.

    Returns:
        Sum over batches of ``loss * num_graphs``, divided by ``len(loader.dataset)``.
    """
    model.train()
    if loss_fn is None:
        loss_fn = criterion  # notebook-global loss defined in the same cell

    total_loss = 0
    for data in loader:
        optimizer.zero_grad()  # Clear gradients.
        logits = model(data.pos, data.batch)["x"]  # Forward pass.
        loss = loss_fn(logits, data.y)  # Loss computation.
        loss.backward()  # Backward pass.
        optimizer.step()  # Update model parameters.
        # With a mean-reduced loss (CrossEntropyLoss's default), weighting by the batch's
        # graph count makes the epoch value a per-graph mean.
        total_loss += loss.item() * data.num_graphs

    # Normalise by *this* loader's dataset. It was previously hard-wired to the global
    # `train_loader`, which mis-scaled the result for any other loader.
    return total_loss / len(loader.dataset)


@torch.no_grad()
def test(model, loader):
    model.eval()

    total_correct = 0
    for data in loader:
        logits = model(data.pos, data.batch)["x"]
        pred = logits.argmax(dim=-1)
        total_correct += int((pred == data.y).sum())

    return total_correct / len(loader.dataset)

# Train for 50 epochs (1..50), reporting the training loss and test accuracy after each one.
for epoch in range(1, 50 + 1):
    loss = train(model, optimizer, train_loader)
    test_acc = test(model, test_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')

Epoch: 01, Loss: 3.7417, Test Accuracy: 0.0500
Epoch: 02, Loss: 3.6927, Test Accuracy: 0.0250
Epoch: 03, Loss: 3.6678, Test Accuracy: 0.0500
Epoch: 04, Loss: 3.6418, Test Accuracy: 0.0500
Epoch: 05, Loss: 3.6065, Test Accuracy: 0.0250
Epoch: 06, Loss: 3.5239, Test Accuracy: 0.0250
Epoch: 07, Loss: 3.4593, Test Accuracy: 0.0500
Epoch: 08, Loss: 3.4031, Test Accuracy: 0.0750
Epoch: 09, Loss: 3.3605, Test Accuracy: 0.1250
Epoch: 10, Loss: 3.2835, Test Accuracy: 0.1000
Epoch: 11, Loss: 3.2149, Test Accuracy: 0.1000
Epoch: 12, Loss: 3.1579, Test Accuracy: 0.1500
Epoch: 13, Loss: 3.0477, Test Accuracy: 0.2000
Epoch: 14, Loss: 2.9329, Test Accuracy: 0.2250
Epoch: 15, Loss: 2.8375, Test Accuracy: 0.2500
Epoch: 16, Loss: 2.7525, Test Accuracy: 0.3000
Epoch: 17, Loss: 2.4533, Test Accuracy: 0.3000
Epoch: 18, Loss: 2.3596, Test Accuracy: 0.3500
Epoch: 19, Loss: 2.1281, Test Accuracy: 0.4000
Epoch: 20, Loss: 1.9795, Test Accuracy: 0.5250
Epoch: 21, Loss: 1.8829, Test Accuracy: 0.3500
Epoch: 22, Lo