In [1]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx"
%goworkfix

	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".


## Downloading Dataset

The method `mag.Download()` will download the dataset to the data directory, if not yet downloaded.
It then converts the dataset to tensors, which are then available for use.

The tensor is then saved for faster access. After saved, the next call to `mag.Download()` will take 1/2s.

In [2]:
%env GOMLX_PLATFORM=CUDA

Set: GOMLX_PLATFORM="CUDA"


In [3]:
import (
    "flag"
    mag "github.com/gomlx/gomlx/examples/ogbnmag"
    "github.com/janpfeifer/must"
)

var (
    flagDataDir   = flag.String("data", "~/work/ogbnmag", "Directory to cache downloaded and generated dataset files.")
)

%%
start := time.Now()
must.M(mag.Download(*flagDataDir))
magSampler := must.M1(mag.NewSampler(*flagDataDir))

fmt.Printf("Elapsed: %s\n", time.Since(start))
fmt.Printf("%s\n", magSampler.String())

Elapsed: 1.135631171s
Sampler: 4 node types, 8 edge types
	NodeType "authors": 1,134,649 items
	NodeType "institutions": 8,740 items
	NodeType "fields_of_study": 59,965 items
	NodeType "papers": 736,389 items
	EdgeType "affiliatedWith": ["authors"]->["institutions"], 1,043,998 edges
	EdgeType "affiliations": ["institutions"]->["authors"], 1,043,998 edges
	EdgeType "hasTopic": ["papers"]->["fields_of_study"], 7,505,078 edges
	EdgeType "topicHasPapers": ["fields_of_study"]->["papers"], 7,505,078 edges
	EdgeType "writes": ["authors"]->["papers"], 7,145,660 edges
	EdgeType "writtenBy": ["papers"]->["authors"], 7,145,660 edges
	EdgeType "cites": ["papers"]->["papers"], 5,416,271 edges
	EdgeType "citedBy": ["papers"]->["papers"], 5,416,271 edges


## FNN model, No Graph

The first model will only use the paper features, and no relations. It serves as a baseline.

In a quick experiment, without much hyperparameter tuning, we got 27.27% on test accuracy, which is inline with the corresponding results in [the leaderboard](https://ogb.stanford.edu/docs/leader_nodeprop/#ogbn-mag) (the Multi Layer Perceptron MLP entry in the bottom).

In [5]:
import (
	. "github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
)
var manager = NewManager()

func config() (ctx *context.Context) {
    ctx = context.NewContext(manager)
    ctx.RngStateReset()
    ctx.SetParams(map[string]any{
        "train_steps": 1_000_000, 
        "batch_size": 128,
        "optimizer": "adamw", 
        optimizers.LearningRateKey: 0.0001,
        layers.L2RegularizationKey: 1e-4,
        "normalization": "layer",
        "dropout": 0.1,
        "hidden_layers": 2,
        "num_nodes": 256,
        "plots": true,
    })
    return
}

In [6]:
import "github.com/gomlx/gomlx/examples/ogbnmag/fnn"

%%
must.M(mag.Download(*flagDataDir))
ctx := config()
ctx.SetParam("checkpoint", path.Join(*flagDataDir, "fnn-baseline"))
ctx.SetParam("num_checkpoints", 10)
err := fnn.Train(ctx)
if  err != nil {
    fmt.Printf("%+v\n", err)
}


> papers features input shape: (Float32)[128 129]
Training (1000000 steps):  100% [[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m] (2052 steps/s)[0m [loss=2.793] [~loss=3.078] [~acc=28.30%]           


	[Step 1000000] median train step: 407 microseconds

Results on seeds_train:
	Mean Loss (#loss): 3.083
	Mean Accuracy (#acc): 28.07%
Results on seeds_valid:
	Mean Loss (#loss): 3.186
	Mean Accuracy (#acc): 26.28%
Results on seeds_test:
	Mean Loss (#loss): 3.114
	Mean Accuracy (#acc): 27.20%



## GNN (Graph Neural Networks)

To use GNNs we need first graphs to run the model on. The issue is, generally the graphs in real life are too large (social networks, relational datases, etc.), including OGNB-MAG.

So instead we use sampled sub-graphs to train. For inference we can also use sampled subgraphs, but later we show a work around where we can do inference, a node type at a time, where we don't need sampling.

### Sampling Sub-Graphs

We follow the same sampling strategy used in [TensorFlow GNN](https://github.com/tensorflow/gnn) library, describe it its [OGBN-MAG notebook](https://github.com/tensorflow/gnn/blob/main/examples/notebooks/ogbn_mag_e2e.ipynb).

In [7]:
import "github.com/gomlx/gomlx/examples/ogbnmag/sampler"

// MagStrategy takes a sampler created by [ogbnmag.NewSampler], a desired batch size, and the set of
// seed ids to sample from ([ogbnmag.TrainSplit], [ogbnmag.ValidSplit] or [ogbnmag.TestSplit]) and
// returns a sampling strategy, that can be used to create datasets.
func MagStrategy(magSampler *sampler.Sampler, batchSize int, seedIdsCandidates tensor.Tensor) *sampler.Strategy {
	strategy := magSampler.NewStrategy()
	var seeds *sampler.Rule
	if seedIdsCandidates == nil {
		seeds = strategy.Nodes("seeds", "papers", batchSize)
	} else {
		seedIdsData := seedIdsCandidates.Local().FlatCopy().([]int32)
		seeds = strategy.NodesFromSet("seeds", "papers", batchSize, seedIdsData)
	}
	citations := seeds.FromEdges("citations", "cites", 8)

	// Authors
	seedsAuthors := seeds.FromEdges("seedsAuthors", "writtenBy", 8)
	citationsAuthors := citations.FromEdges("citationsAuthors", "writtenBy", 8)

	// Co-authored papers
	coauthoredPapers := seedsAuthors.FromEdges("coauthoredPapers", "writes", 8)
	coauthoredFromCitations := citationsAuthors.FromEdges("coauthoredFromCitations", "writes", 8)

	// Affiliations
	_ = seedsAuthors.FromEdges("authorsInstitutions", "affiliatedWith", 8)
	_ = citationsAuthors.FromEdges("citationAuthorsInstitutions", "affiliatedWith", 8)

	// Topics
	_ = seeds.FromEdges("seedsTopics", "hasTopic", 8)
	_ = coauthoredPapers.FromEdges("coauthoredTopics", "hasTopic", 8)
	_ = citations.FromEdges("citationsTopics", "hasTopic", 8)
	_ = coauthoredFromCitations.FromEdges("coauthoredFromCitationsTopics", "hasTopic", 8)

	return strategy
}


### Benchmarking Sampling

Let's get some benchmarking on the sampler, to make sure it's not slowing training:

1. Benchmark sequential the sampling of one sub-graph: `BenchmarkSequentialMagSampler`
2. Benchmark parallel sampling of one sub-graph: `BenchmarkParallelMagSampler`
3. Benchmark parallel sampling of one epoch (using parallel sampler): `BenchmarkOneEpochMagSampler`

### Training Model

We use the vanilla GNN trainer and model defined in the `gnn` package.

In [41]:
import (
    "github.com/gomlx/gomlx/examples/ogbnmag/gnn"
)

func configGnn() *context.Context {
    ctx := context.NewContext(manager)
    ctx.RngStateReset()

    stepsPerEpoch := mag.SplitTrain.Shape().Size() / gnn.BatchSize + 1
    numEpochs := 10  // Taken from TF-GNN OGBN-MAG notebook.
    
    ctx.SetParams(map[string]any{
        "train_steps": numEpochs + stepsPerEpoch, 
        "optimizer": "adamw", 
        optimizers.LearningRateKey: 0.001,
        layers.ParamL2Regularization: 1e-5,
        layers.ParamDropoutRate: 0.2,
        gnn.ParamEdgeDropoutRate: 0.0,
        "plots": true,
    })
    return ctx
}

%%
ctx := configGnn()
ctx.SetParam("checkpoint", path.Join(*flagDataDir, "gnn-baseline"))
ctx.SetParam("num_checkpoints", 10)
err := gnn.Train(ctx, *flagDataDir)
if  err != nil {
    fmt.Printf("%+v\n", err)
}


loading: "checkpoint-n0000320-20240225-202908-step-00087774"
loading: "checkpoint-n0000318-20240225-202707-step-00087190"


2024-02-25 20:32:30.569105: E external/xla/xla/stream_executor/stream_executor_internal.h:182] SetPriority unimplemented for this stream.


loading: "checkpoint-n0000319-20240225-202808-step-00087482"
> restarting training from global_step=87774


Training (1000 steps):  100% [[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m] (2 steps/s)[0m [loss=1.726] [~loss=1.904] [~acc=51.53%]          


	[Step 88774] median train step: 202642 microseconds
Median training step duration: 202.642867ms

Results on train:


	Mean Loss (#loss): 1.919
	Mean Accuracy (#acc): 50.54%
Results on valid:
	Mean Loss (#loss): 2.160
	Mean Accuracy (#acc): 44.83%
Results on test:
	Mean Loss (#loss): 2.123
	Mean Accuracy (#acc): 45.22%



In [48]:
%%
ctx := configGnn()
ctx.SetParam("checkpoint", path.Join(*flagDataDir, "gnn-baseline"))
ctx.SetParam("num_checkpoints", 10)

_, trainEvalDS, validEvalDS, testEvalDS := must.M4(gnn.MakeDatasets(*flagDataDir))
_, _, _ = trainEvalDS, validEvalDS, testEvalDS
must.M(gnn.Eval(ctx, *flagDataDir, validEvalDS, testEvalDS))

loading: "checkpoint-n0000325-20240225-204246-step-00088774"
loading: "checkpoint-n0000323-20240225-203818-step-00088438"


2024-02-25 21:04:19.547889: E external/xla/xla/stream_executor/stream_executor_internal.h:182] SetPriority unimplemented for this stream.


loading: "checkpoint-n0000324-20240225-203918-step-00088733"
Model in "/home/janpf/work/ogbnmag/gnn-baseline" trained for 88774 steps.
Results on valid:
	Mean Loss (#loss): 2.148
	Mean Accuracy (#acc): 44.92%
	elapsed 18.635094382s (valid)
Results on test:
	Mean Loss (#loss): 2.116
	Mean Accuracy (#acc): 45.29%
	elapsed 12.188028404s (test)
