# Open Graph Benchmark (OGB) Node Prediction for Microsoft Academic Graph (OGBN-MAG)

The [Open Graph Benchmark (OGB)](https://ogb.stanford.edu/) is a collection of realistic, large-scale, and diverse benchmark datasets for machine learning on graphs.

This notebook demonstrates a [Graph Neural Network (GNN)](https://en.wikipedia.org/wiki/Graph_neural_network) based on the [TF-GNN](https://github.com/tensorflow/gnn) [OGBN-MAG tutorial](https://colab.research.google.com/github/tensorflow/gnn/blob/master/examples/notebooks/ogbn_mag_e2e.ipynb), but using GoMLX.

The [OGBN-MAG task is described here](https://ogb.stanford.edu/docs/nodeprop/#ogbn-mag). This demo comes with a library that downloads the data, parses it, and converts it to tensors for fast use.

The model is experimental, but it includes a basic GNN library that can be used for different projects.

**EXPERIMENTAL**: so far it has been used only for OGBN-MAG.

See also the [OGBN-MAG Leaderboard](https://ogb.stanford.edu/docs/leader_nodeprop/#ogbn-mag). Take it with a grain of salt: different models use different tricks, some of which may be considered leaking or rely on extra data from outside the dataset, and in some cases the models are heavily overfit to the task. Still, it's a fun dataset to work with.

See the subdirectory `demo` for a command-line version of the trainer (which can be run on a cloud machine). It saves data points that can be plotted in the notebook later.


In [4]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx" "${HOME}/Projects/gonb"
%goworkfix

	- Replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb" already exists.
	- Replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx" already exists.


## Downloading Dataset

The method `mag.Download()` downloads the dataset to the data directory if it has not been downloaded yet.
It then converts the dataset to tensors, which become available for use.

The tensors are also saved for faster access. Once they are saved, later calls to `mag.Download()` take about 1 to 2 seconds.

In [8]:
import (
    "flag"
    . "github.com/gomlx/gomlx/pkg/core/graph"
    mag "github.com/gomlx/gomlx/examples/ogbnmag"
    "github.com/janpfeifer/must"
	"github.com/gomlx/gomlx/backends"

    _ "github.com/gomlx/gomlx/backends/default"
)

var (
    flagDataDir   = flag.String("data", "~/work/ogbnmag", "Directory to cache downloaded and generated dataset files.")
    backend = backends.MustNew()
    _ *Node = nil
)

%%
start := time.Now()
must.M(mag.Download(*flagDataDir))
fmt.Printf("Elapsed: %s\n", time.Since(start))

Elapsed: 1.973077748s


In [9]:
%%
start := time.Now()
must.M(mag.Download(*flagDataDir))
fmt.Printf("Elapsed: %s\n", time.Since(start))


results := graph.NewExec(backend, func (x *Node) []*Node {
    mean := ReduceAllMean(x)
    variance := ReduceAllMean(Square(Sub(x, mean)))
    stddev := Sqrt(variance)
    return []*Node {
        ReduceAllMin(x),
        ReduceAllMax(x),
        mean,
        variance,
        stddev,
    }
}).Call(mag.PapersEmbeddings)
for _, t := range results {
    fmt.Printf("\t%s\n", t)
}


Elapsed: 1.63463491s
	float32(-1.439)
	float32(1.697)
	float32(0.01028)
	float32(0.05374)
	float32(0.2318)


In [None]:
%%
start := time.Now()
must.M(mag.Download(*flagDataDir))
fmt.Printf("Elapsed: %s\n", time.Since(start))

tensors.MutableFlatData[float32](mag.PapersEmbeddings, func (flat []float32) {
    // Note: this sorts the flat data of the tensor in place.
    slices.Sort(flat)
    numQuantiles := 20
    fmt.Printf("%g\n", flat[0])  // Minimum.
    for ii := range numQuantiles-1 {
        idx := (ii+1)*len(flat)/numQuantiles
        fmt.Printf("%g\n", flat[idx])
    }
    fmt.Printf("%g\n", flat[len(flat)-1])  // Maximum.
})

Elapsed: 1.577930017s


## FNN model, No Graph

The first model will only use the paper features, and no relations. It serves as a baseline.

In a quick experiment, without much hyperparameter tuning, we got 27.27% test accuracy, which is in line with the corresponding results in [the leaderboard](https://ogb.stanford.edu/docs/leader_nodeprop/#ogbn-mag) (the Multi-Layer Perceptron (MLP) entry at the bottom).

In [5]:
import (
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/ml/layers/regularizers"
)
func config() (ctx *context.Context) {
    ctx = context.NewContext()
    ctx.RngStateReset()
    ctx.SetParams(map[string]any{
        "train_steps": 400_000, 
        "batch_size": 128,
        "optimizer": "adamw", 
        optimizers.LearningRateKey: 0.0001,
        regularizers.ParamL2: 1e-4,
        "normalization": "layer",
        "dropout": 0.1,
        "hidden_layers": 2,
        "num_nodes": 256,
        "plots": true,
    })
    return
}

In [6]:
import "github.com/gomlx/gomlx/examples/ogbnmag/fnn"

%%
must.M(mag.Download(*flagDataDir))
ctx := config()
ctx.SetParam("checkpoint", path.Join(*flagDataDir, "fnn-baseline"))
ctx.SetParam("num_checkpoints", 10)
ctx.SetParam("train_steps", 400_000)

// Use KAN (Kolmogorov–Arnold Networks): the settings below override the FNN baseline ones set above.
ctx.SetParam("checkpoint", path.Join(*flagDataDir, "kan-baseline"))
ctx.SetParam("kan", true)
ctx.SetParam("num_nodes", 48)
ctx.SetParam("hidden_layers", 2)
// ctx.SetParam("kan_bspline_magnitude_l1", 1e-3)

err := fnn.Train(backend, ctx)
if err != nil {
    fmt.Printf("%+v\n", err)
}


loading: "checkpoint-n0000033-20240722-103831-step-00400000"
loading: "checkpoint-n0000031-20240722-103650-step-00380971"
loading: "checkpoint-n0000032-20240722-103751-step-00393020"
> restarting training from global_step=400000



Results on seeds_train:
	Mean Loss+Regularization (#loss+): 3.390
	Mean Loss (#loss): 3.170
	Mean Accuracy (#acc): 24.04%
Results on seeds_valid:
	Mean Loss+Regularization (#loss+): 3.448
	Mean Loss (#loss): 3.227
	Mean Accuracy (#acc): 23.30%
Results on seeds_test:
	Mean Loss+Regularization (#loss+): 3.370
	Mean Loss (#loss): 3.149
	Mean Accuracy (#acc): 24.30%



## GNN (Graph Neural Networks)

To use GNNs we first need graphs to run the model on. The issue is that real-life graphs (social networks, relational databases, etc.) are generally too large to process whole, and that includes OGBN-MAG.

So instead we train on sampled sub-graphs. For inference we can also use sampled sub-graphs, but later we show a workaround that runs inference one layer of nodes at a time, with no sampling needed.

### Sampling Sub-Graphs

We follow the same sampling strategy used in the [TensorFlow GNN](https://github.com/tensorflow/gnn) library, described in its [OGBN-MAG notebook](https://github.com/tensorflow/gnn/blob/main/examples/notebooks/ogbn_mag_e2e.ipynb).

The `magSampler` variable loads the definition of the data graph: its node and edge sets. The `magStrategy` defines how to sample from those nodes and edges. See its specification in [gomlx/examples/ogbnmag/sampling.go](https://github.com/gomlx/gomlx/blob/main/examples/ogbnmag/sampling.go); it's only some 20 lines long.

Here is the printout of the sampler and the strategy we are using:

In [4]:
%%
must.M(mag.Download(*flagDataDir))
magSampler := must.M1(mag.NewSampler(*flagDataDir))
magStrategy := mag.NewSamplerStrategy(magSampler, mag.BatchSize, mag.TrainSplit)
fmt.Printf("%s\n", magSampler)
fmt.Printf("\n%s\n", magStrategy)

Sampler: 4 node types, 8 edge types, Frozen
	NodeType "papers": 736,389 items
	NodeType "authors": 1,134,649 items
	NodeType "institutions": 8,740 items
	NodeType "fields_of_study": 59,965 items
	EdgeType "citedBy": ["papers"]->["papers"], 5,416,271 edges
	EdgeType "affiliatedWith": ["authors"]->["institutions"], 1,043,998 edges
	EdgeType "affiliations": ["institutions"]->["authors"], 1,043,998 edges
	EdgeType "hasTopic": ["papers"]->["fields_of_study"], 7,505,078 edges
	EdgeType "topicHasPapers": ["fields_of_study"]->["papers"], 7,505,078 edges
	EdgeType "writes": ["authors"]->["papers"], 7,145,660 edges
	EdgeType "writtenBy": ["papers"]->["authors"], 7,145,660 edges
	EdgeType "cites": ["papers"]->["papers"], 5,416,271 edges

Sampling strategy: (13 Rules)
> Rule "seeds": type=Node, nodeType="papers", Shape=(Int32)[128] (size=128), NodeSet.size=629571
  > Rule "citations": type=Edge, nodeType="papers", Shape=(Int32)[128 8] (size=1024), SourceRule="seeds", EdgeType="cites"
    > Rule "c

### Training Model

We use the vanilla GNN trainer and model defined in the `gnn` package.

In [74]:
import (
    "github.com/gomlx/gomlx/examples/ogbnmag/gnn"
)

func configGnn(baseDir string) *context.Context {
    must.M(mag.Download(baseDir))
    ctx := context.NewContext()
    ctx.RngStateReset()
    
    stepsPerEpoch := mag.TrainSplit.Shape().Size() / mag.BatchSize + 1
    numEpochs := 10  // Taken from TF-GNN OGBN-MAG notebook.
    numTrainSteps := numEpochs * stepsPerEpoch
    
    ctx.SetParams(map[string]any{
        "train_steps": numTrainSteps,
        
        optimizers.ParamOptimizer: "adam", 
        optimizers.ParamLearningRate: 0.001,
        optimizers.ParamCosineScheduleSteps:  numTrainSteps,
        
        layers.ParamL2Regularization: 1e-5,
        layers.ParamDropoutRate: 0.2,

        mag.ParamEmbedDropoutRate: 0.0,
        
        gnn.ParamEdgeDropoutRate: 0.0,
        gnn.ParamNumGraphUpdates: 2,
        gnn.ParamReadoutHiddenLayers: 2,
        gnn.ParamPoolingType: "mean|sum",
        gnn.ParamUsePathToRootStates: false,
        
        "plots": true,
    })
    return ctx
}

%%
mag.BatchSize = 128  // Default is 128.
ctx := configGnn(*flagDataDir)
// ctx.SetParam("checkpoint", path.Join(*flagDataDir, "gnn-small_batch"))
// ctx.SetParam("checkpoint", path.Join(*flagDataDir, "gnn-use_path_to_root"))
ctx.SetParam("checkpoint", path.Join(*flagDataDir, "gnn-reg_2"))
ctx.SetParam("num_checkpoints", 3)
ctx.SetParams(map[string]any{
    // "train_steps": 100,
    // "plots": false,
    // gnn.ParamEdgeDropoutRate: 0.1,
    // mag.ParamEmbedDropoutRate: 0.1,
    gnn.ParamNumGraphUpdates: 2,
    // gnn.ParamUsePathToRootStates: true,
    // gnn.ParamReadoutHiddenLayers: 2,
    layers.ParamDropoutRate: 0.25,
    layers.ParamL2Regularization: 3e-4,
})

// err := gnn.Train(ctx, *flagDataDir)
// if  err != nil {
//     fmt.Printf("%+v\n", err)
// }

In [14]:
%%
ctx := configGnn(*flagDataDir)
ctx.SetParam("checkpoint", path.Join(*flagDataDir, "gnn-baseline_5"))

_, trainEvalDS, validEvalDS, testEvalDS := must.M4(gnn.MakeDatasets(*flagDataDir))
_, _, _ = trainEvalDS, validEvalDS, testEvalDS
must.M(gnn.Eval(ctx, *flagDataDir, trainEvalDS, validEvalDS, testEvalDS))

loading: "checkpoint-n0000171-20240229-012409-step-00049190"
Model in "/home/janpf/work/ogbnmag/gnn-baseline_5" trained for 49190 steps.
Results on train:
	Mean Loss+Regularization (#loss+): 1.542
	Mean Loss (#loss): 1.457
	Mean Accuracy (#acc): 57.49%
	elapsed 1m37.457905276s (train)
Results on valid:
	Mean Loss+Regularization (#loss+): 1.927
	Mean Loss (#loss): 1.841
	Mean Accuracy (#acc): 48.60%
	elapsed 11.140775268s (valid)
Results on test:
	Mean Loss+Regularization (#loss+): 1.922
	Mean Loss (#loss): 1.837
	Mean Accuracy (#acc): 48.10%
	elapsed 7.647081029s (test)


In [122]:
%%
ctx := configGnn(*flagDataDir)
ctx.SetParam("checkpoint", path.Join(*flagDataDir, "gnn-baseline_17"))

_, trainEvalDS, validEvalDS, testEvalDS := must.M4(mag.MakeDatasets(*flagDataDir))
_, _, _ = trainEvalDS, validEvalDS, testEvalDS
must.M(mag.Eval(ctx, *flagDataDir, validEvalDS, testEvalDS))

loading: "checkpoint-n0000142-20240305-235708-step-00049190"
Model in "/home/janpf/work/ogbnmag/gnn-baseline_17" trained for 49190 steps.
Results on valid:
	Mean Loss+Regularization (#loss+): 1.920
	Mean Loss (#loss): 1.834
	Mean Accuracy (#acc): 48.64%
	elapsed 9.335698266s (valid)
Results on test:
	Mean Loss+Regularization (#loss+): 1.917
	Mean Loss (#loss): 1.830
	Mean Accuracy (#acc): 48.38%
	elapsed 5.863748141s (test)


In [7]:
import (
    stdplots "github.com/gomlx/gomlx/examples/notebook/gonb/plots"
    "github.com/gomlx/gomlx/examples/notebook/gonb/plotly"
)

func filterPoints(pt *stdplots.Point) bool {
    // Remove substrings
    for _, s := range []string{"Eval on "} {
        pt.MetricName = strings.Replace(pt.MetricName, s, "", -1)
    }
    // Replace substrings
    for _, pair := range [][2]string{
        {"Accuracy", "acc"}, {"Loss+Regularization", "loss+"}, {"Loss", "loss"}, {" Mean ", "#"}, {" Moving Average ", "~"}, 
        {"Train", "train"}, {"Validation", "valid"}, {"Test", "test"}, {": layer-wise eval", ":#acc"}} {
        pt.MetricName = strings.Replace(pt.MetricName, pair[0], pair[1], -1)
    }
    for _, exclude := range []string{ "Batch", "Train:~loss", "Train:~acc", "loss+" } {
        if strings.Index(pt.MetricName, exclude) != -1 { return false }
    }
    return true
}

func plotVersions(versions ...int) {
    plots := plotly.New()
    for _, version := range versions {
        // checkpoint := fmt.Sprintf("gnn-baseline_%d", version)
        checkpoint := fmt.Sprintf("small_%d", version)
        prefix := fmt.Sprintf("(%d) ", version)
        must.M(plots.LoadCheckpointData(path.Join(*flagDataDir, checkpoint),
            func (pt *stdplots.Point) bool {
                if !filterPoints(pt) { return false }
                pt.MetricName = prefix+pt.MetricName
                return true
            }))
    }
    plots.Plot()
}

%%
plotVersions(81, 82)

In [1]:
!echo $GONB_TMP_DIR

/tmp/gonb_b9950945
