In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Import libraries
import numpy as np
import torch
from torch_geometric.loader import NeighborLoader
import utilities

In [3]:
# Define synthetic data parameters
# (all of these are forwarded to utilities.generate_samples for both the
# training and the test set)
num_samples = 8000
num_dimensions = 2
tau = 1
sigma = 1
phi = 10

# Define NN-GLS parameters
# num_neighbors is shared by the data generator, the NeighborLoader, and the
# NNGLS model; sigma_init is the value handed to the model as its `sigma`.
num_neighbors = 4
sigma_init = 0.1

# Define training parameters
num_epochs = 1000
batch_size = 400
learning_rate = 0.0001

# Define constants
# NUM_FEATURES is only passed to the model, not to the data generator.
# NOTE(review): presumably it must match the feature width that
# utilities.generate_samples produces -- confirm against utilities.
NUM_FEATURES = 5

# Define reproducibility parameters
seed = 0

# Seed the global torch and NumPy generators once, up front, so that a fresh
# "Restart Kernel -> Run All" draws the same random numbers on every run.
# NOTE(review): assumes utilities and the loader draw from these global
# generators; if they also use Python's `random` or their own RNG, seed that too.
torch.manual_seed(seed)
np.random.seed(seed)

In [4]:
# Generate synthetic data
# Both datasets come from two separate calls to the same generator with the
# same settings: the first call yields the training set, the second the test set.
generator_args = (num_samples, num_dimensions, num_neighbors, tau, sigma, phi)
train_data, test_data = (utilities.generate_samples(*generator_args) for _ in range(2))

In [5]:
# Compute benchmark MSE (always predicting the mean)
# The constant prediction is the mean of the *training* targets; it is scored
# against the *test* targets.
train_target_mean = train_data.y.mean()
benchmark_preds = torch.full(test_data.y.shape, train_target_mean)
benchmark_mse = torch.nn.functional.mse_loss(input=benchmark_preds, target=test_data.y)
print(f'Benchmark MSE: {benchmark_mse:.3f}')

Benchmark MSE: 2.923


In [6]:
# Create dataloader, model, and optimizer

# Mini-batch loader over the training graph: every one of the num_samples
# nodes is used as a seed node, and num_neighbors neighbours are sampled
# (without replacement) for a single hop around each seed.
train_loader = NeighborLoader(
    train_data,
    # torch.arange yields the same int64 index tensor as
    # torch.tensor(range(...)) without materialising a Python range first.
    input_nodes=torch.arange(num_samples),
    num_neighbors=[num_neighbors],
    batch_size=batch_size,
    replace=False,
    shuffle=True,
)

# NN-GLS model; sigma_init is passed as the model's `sigma` argument.
model = utilities.NNGLS(
    num_features=NUM_FEATURES,
    num_neighbors=num_neighbors,
    num_dimensions=num_dimensions,
    sigma=sigma_init,
)
# Report the model's sigma/phi/tau before any training step has been taken.
print(f"sigma: {model.sigma.item():.3f}, phi: {model.phi.item():.3f}, tau: {model.tau.item():.3f}")

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

sigma: 0.100, phi: 5.995, tau: 0.247


In [7]:
# Training/evaluation loop
# Each epoch: one pass over the mini-batches minimising the MSE between the
# model's decorrelated predictions and decorrelated targets, followed by an
# evaluation on the full held-out test set.
for epoch in range(num_epochs):
    # Train for one epoch
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        decorrelated_preds, decorrelated_targets, _ = model(batch)
        # NeighborLoader puts the seed nodes first and reports their count as
        # batch.batch_size. Using that (rather than the configured batch_size)
        # keeps sampled neighbour nodes out of the loss when the final batch
        # is smaller than batch_size.
        num_seed_nodes = batch.batch_size
        loss = torch.nn.functional.mse_loss(
            decorrelated_preds[:num_seed_nodes],
            decorrelated_targets[:num_seed_nodes],
        )
        loss.backward()
        optimizer.step()
    # Compute predictions on held-out test set
    model.eval()
    # No gradients are needed for evaluation; skipping autograd avoids building
    # a graph over the whole test set every epoch.
    with torch.no_grad():
        decorrelated_preds, decorrelated_targets, preds = model(test_data)
        loss = torch.nn.functional.mse_loss(decorrelated_preds, decorrelated_targets)
        metric = torch.nn.functional.mse_loss(preds, test_data.y)
    # "\r" with end="" rewrites the same output line on every epoch.
    print(f"\rEpoch {epoch}, Loss: {loss.item():.3f}, Metric: {metric.item():.3f}, sigma: {model.sigma.item():.3f}, phi: {model.phi.item():.3f}, tau: {model.tau.item():.3f}", end="")

Epoch 999, Loss: 0.484, Metric: 1.325, sigma: 0.651, phi: 4.473, tau: 1.5340