# Open Graph Benchmark (OGB) Node Prediction for Microsoft Academic Graph (OGBN-MAG)

The [Open Graph Benchmark (OGB)](https://ogb.stanford.edu/) is a collection of realistic, large-scale, and diverse benchmark datasets for machine learning on graphs.

This notebook demonstrates a [Graph Neural Network (GNN)](https://en.wikipedia.org/wiki/Graph_neural_network) model based on the [TF-GNN](https://github.com/tensorflow/gnn) [OGBN-MAG tutorial](https://colab.research.google.com/github/tensorflow/gnn/blob/master/examples/notebooks/ogbn_mag_e2e.ipynb), but implemented with GoMLX.

The OGBN-MAG task is [described here](https://ogb.stanford.edu/docs/nodeprop/#ogbn-mag). This demo comes with a library that downloads and parses the data and converts it to tensors for fast use.

The model is experimental, but it includes a basic GNN library that can be used for different projects.

**EXPERIMENTAL**: so far it has only been used for OGBN-MAG.

See also the [OGBN-MAG Leaderboard](https://ogb.stanford.edu/docs/leader_nodeprop/#ogbn-mag). Take it with a grain of salt: different entries use different tricks, some of which may be considered leakage, some use extra data from outside the dataset, and some are heavily overfit to the task. Still, it's a fun dataset to work with.

See the `demo` subdirectory for a command-line version of the trainer, which can be run on a cloud machine, for instance. It saves data points that can later be plotted in this notebook.


In [1]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx" "${HOME}/Projects/gonb"
%goworkfix

	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".


## Downloading Dataset

The method `mag.Download()` downloads the dataset to the data directory if it hasn't been downloaded yet. It then converts the dataset to tensors, which become available for use.

The tensors are also saved to disk for faster access: once they are saved, subsequent calls to `mag.Download()` take about half a second.

In [3]:
import (
    "flag"
    mag "github.com/gomlx/gomlx/examples/ogbnmag"
    "github.com/janpfeifer/must"
    . "github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
)

var (
    flagDataDir   = flag.String("data", "~/work/ogbnmag", "Directory to cache downloaded and generated dataset files.")
    
    manager = NewManager()
)

%%
start := time.Now()
must.M(mag.Download(*flagDataDir))
magSampler := must.M1(mag.NewSampler(*flagDataDir))
magStrategy := mag.MagStrategy(magSampler, mag.BatchSize, mag.TrainSplit)


fmt.Printf("Elapsed: %s\n", time.Since(start))
fmt.Printf("\n%s\n", magSampler)
fmt.Printf("\n%s\n", magStrategy)

Elapsed: 942.678854ms

Sampler: 4 node types, 8 edge types, Frozen
	NodeType "papers": 736,389 items
	NodeType "authors": 1,134,649 items
	NodeType "institutions": 8,740 items
	NodeType "fields_of_study": 59,965 items
	EdgeType "writtenBy": ["papers"]->["authors"], 7,145,660 edges
	EdgeType "cites": ["papers"]->["papers"], 5,416,271 edges
	EdgeType "citedBy": ["papers"]->["papers"], 5,416,271 edges
	EdgeType "affiliatedWith": ["authors"]->["institutions"], 1,043,998 edges
	EdgeType "affiliations": ["institutions"]->["authors"], 1,043,998 edges
	EdgeType "hasTopic": ["papers"]->["fields_of_study"], 7,505,078 edges
	EdgeType "topicHasPapers": ["fields_of_study"]->["papers"], 7,505,078 edges
	EdgeType "writes": ["authors"]->["papers"], 7,145,660 edges

Sampling strategy: (13 Rules)
> Rule "seeds": type=Node, nodeType="papers", Shape=(Int32)[128] (size=128), NodeSet.size=629571
  > Rule "citations": type=Edge, nodeType="papers", Shape=(Int32)[128 8] (size=1024), SourceRule="seeds", EdgeTyp

## FNN model, No Graph

The first model will only use the paper features, and no relations. It serves as a baseline.

In a quick experiment, without much hyperparameter tuning, we got 27.27% test accuracy. This is in line with the corresponding result in [the leaderboard](https://ogb.stanford.edu/docs/leader_nodeprop/#ogbn-mag), which is the Multi-Layer Perceptron (MLP) entry at the bottom.

In [4]:
func config() (ctx *context.Context) {
    ctx = context.NewContext(manager)
    ctx.RngStateReset()
    ctx.SetParams(map[string]any{
        "train_steps": 1_000_000, 
        "batch_size": 128,
        "optimizer": "adamw", 
        optimizers.LearningRateKey: 0.0001,
        layers.L2RegularizationKey: 1e-4,
        "normalization": "layer",
        "dropout": 0.1,
        "hidden_layers": 2,
        "num_nodes": 256,
        "plots": true,
    })
    return
}

In [6]:
import "github.com/gomlx/gomlx/examples/ogbnmag/fnn"

%%
must.M(mag.Download(*flagDataDir))
ctx := config()
ctx.SetParam("checkpoint", path.Join(*flagDataDir, "fnn-baseline"))
ctx.SetParam("num_checkpoints", 10)
err := fnn.Train(ctx)
if  err != nil {
    fmt.Printf("%+v\n", err)
}


> papers features input shape: (Float32)[128 129]
Training (1000000 steps):  100% [========================================] (2052 steps/s) [loss=2.793] [~loss=3.078] [~acc=28.30%]


	[Step 1000000] median train step: 407 microseconds

Results on seeds_train:
	Mean Loss (#loss): 3.083
	Mean Accuracy (#acc): 28.07%
Results on seeds_valid:
	Mean Loss (#loss): 3.186
	Mean Accuracy (#acc): 26.28%
Results on seeds_test:
	Mean Loss (#loss): 3.114
	Mean Accuracy (#acc): 27.20%



## GNN (Graph Neural Networks)

To use GNNs we first need graphs to run the model on. The issue is that real-world graphs (social networks, relational databases, etc.), including OGBN-MAG, are generally too large to run the model over the whole graph at every training step.

So instead we train on sampled sub-graphs. Inference can also use sampled sub-graphs, but later we show a workaround that does inference one node type at a time, without any sampling.

### Sampling Sub-Graphs

We follow the same sampling strategy used by the [TensorFlow GNN](https://github.com/tensorflow/gnn) library, described in its [OGBN-MAG notebook](https://github.com/tensorflow/gnn/blob/main/examples/notebooks/ogbn_mag_e2e.ipynb).

In [7]:
import "github.com/gomlx/gomlx/examples/ogbnmag/sampler"

// MagStrategy takes a sampler created by [ogbnmag.NewSampler], a desired batch size, and the set of
// seed ids to sample from ([ogbnmag.TrainSplit], [ogbnmag.ValidSplit] or [ogbnmag.TestSplit]) and
// returns a sampling strategy, that can be used to create datasets.
func MagStrategy(magSampler *sampler.Sampler, batchSize int, seedIdsCandidates tensor.Tensor) *sampler.Strategy {
	strategy := magSampler.NewStrategy()
	var seeds *sampler.Rule
	if seedIdsCandidates == nil {
		seeds = strategy.Nodes("seeds", "papers", batchSize)
	} else {
		seedIdsData := seedIdsCandidates.Local().FlatCopy().([]int32)
		seeds = strategy.NodesFromSet("seeds", "papers", batchSize, seedIdsData)
	}
	citations := seeds.FromEdges("citations", "cites", 8)

	// Authors
	seedsAuthors := seeds.FromEdges("seedsAuthors", "writtenBy", 8)
	citationsAuthors := citations.FromEdges("citationsAuthors", "writtenBy", 8)

	// Co-authored papers
	coauthoredPapers := seedsAuthors.FromEdges("coauthoredPapers", "writes", 8)
	coauthoredFromCitations := citationsAuthors.FromEdges("coauthoredFromCitations", "writes", 8)

	// Affiliations
	_ = seedsAuthors.FromEdges("authorsInstitutions", "affiliatedWith", 8)
	_ = citationsAuthors.FromEdges("citationAuthorsInstitutions", "affiliatedWith", 8)

	// Topics
	_ = seeds.FromEdges("seedsTopics", "hasTopic", 8)
	_ = coauthoredPapers.FromEdges("coauthoredTopics", "hasTopic", 8)
	_ = citations.FromEdges("citationsTopics", "hasTopic", 8)
	_ = coauthoredFromCitations.FromEdges("coauthoredFromCitationsTopics", "hasTopic", 8)

	return strategy
}


### Benchmarking Sampling

Let's benchmark the sampler to make sure it isn't slowing down training:

1. Sequential sampling of one sub-graph: `BenchmarkSequentialMagSampler`
2. Parallel sampling of one sub-graph: `BenchmarkParallelMagSampler`
3. Sampling of one full epoch, using the parallel sampler: `BenchmarkOneEpochMagSampler`

### Training Model

We use the vanilla GNN trainer and model defined in the `gnn` package.

In [74]:
import (
    "github.com/gomlx/gomlx/examples/ogbnmag/gnn"
)

func configGnn(baseDir string) *context.Context {
    must.M(mag.Download(baseDir))
    ctx := context.NewContext(manager)
    ctx.RngStateReset()
    
    stepsPerEpoch := mag.TrainSplit.Shape().Size() / mag.BatchSize + 1
    numEpochs := 10  // Taken from TF-GNN OGBN-MAG notebook.
    numTrainSteps := numEpochs * stepsPerEpoch
    
    ctx.SetParams(map[string]any{
        "train_steps": numTrainSteps,
        
        optimizers.ParamOptimizer: "adam", 
        optimizers.ParamLearningRate: 0.001,
        optimizers.ParamCosineScheduleSteps:  numTrainSteps,
        
        layers.ParamL2Regularization: 1e-5,
        layers.ParamDropoutRate: 0.2,

        mag.ParamEmbedDropoutRate: 0.0,
        
        gnn.ParamEdgeDropoutRate: 0.0,
        gnn.ParamNumGraphUpdates: 2,
        gnn.ParamReadoutHiddenLayers: 2,
        gnn.ParamPoolingType: "mean|sum",
        gnn.ParamUsePathToRootStates: false,
        
        "plots": true,
    })
    return ctx
}

%%
mag.BatchSize = 128  // Default is 128.
ctx := configGnn(*flagDataDir)
// ctx.SetParam("checkpoint", path.Join(*flagDataDir, "gnn-small_batch"))
// ctx.SetParam("checkpoint", path.Join(*flagDataDir, "gnn-use_path_to_root"))
ctx.SetParam("checkpoint", path.Join(*flagDataDir, "gnn-reg_2"))
ctx.SetParam("num_checkpoints", 3)
ctx.SetParams(map[string]any{
    // "train_steps": 100,
    // "plots": false,
    // gnn.ParamEdgeDropoutRate: 0.1,
    // mag.ParamEmbedDropoutRate: 0.1,
    gnn.ParamNumGraphUpdates: 2,
    // gnn.ParamUsePathToRootStates: true,
    // gnn.ParamReadoutHiddenLayers: 2,
    layers.ParamDropoutRate: 0.25,
    layers.ParamL2Regularization: 3e-4,
})

// err := gnn.Train(ctx, *flagDataDir)
// if  err != nil {
//     fmt.Printf("%+v\n", err)
// }

In [14]:
%%
ctx := configGnn(*flagDataDir)
ctx.SetParam("checkpoint", path.Join(*flagDataDir, "gnn-baseline_5"))

_, trainEvalDS, validEvalDS, testEvalDS := must.M4(gnn.MakeDatasets(*flagDataDir))
_, _, _ = trainEvalDS, validEvalDS, testEvalDS
must.M(gnn.Eval(ctx, *flagDataDir, trainEvalDS, validEvalDS, testEvalDS))

loading: "checkpoint-n0000171-20240229-012409-step-00049190"
Model in "/home/janpf/work/ogbnmag/gnn-baseline_5" trained for 49190 steps.
Results on train:
	Mean Loss+Regularization (#loss+): 1.542
	Mean Loss (#loss): 1.457
	Mean Accuracy (#acc): 57.49%
	elapsed 1m37.457905276s (train)
Results on valid:
	Mean Loss+Regularization (#loss+): 1.927
	Mean Loss (#loss): 1.841
	Mean Accuracy (#acc): 48.60%
	elapsed 11.140775268s (valid)
Results on test:
	Mean Loss+Regularization (#loss+): 1.922
	Mean Loss (#loss): 1.837
	Mean Accuracy (#acc): 48.10%
	elapsed 7.647081029s (test)


In [122]:
%%
ctx := configGnn(*flagDataDir)
ctx.SetParam("checkpoint", path.Join(*flagDataDir, "gnn-baseline_17"))

_, trainEvalDS, validEvalDS, testEvalDS := must.M4(mag.MakeDatasets(*flagDataDir))
_, _, _ = trainEvalDS, validEvalDS, testEvalDS
must.M(mag.Eval(ctx, *flagDataDir, validEvalDS, testEvalDS))

loading: "checkpoint-n0000142-20240305-235708-step-00049190"
Model in "/home/janpf/work/ogbnmag/gnn-baseline_17" trained for 49190 steps.
Results on valid:
	Mean Loss+Regularization (#loss+): 1.920
	Mean Loss (#loss): 1.834
	Mean Accuracy (#acc): 48.64%
	elapsed 9.335698266s (valid)
Results on test:
	Mean Loss+Regularization (#loss+): 1.917
	Mean Loss (#loss): 1.830
	Mean Accuracy (#acc): 48.38%
	elapsed 5.863748141s (test)


In [8]:
import (
    "github.com/gomlx/gomlx/examples/notebook/gonb/plots"
    "github.com/janpfeifer/gonb/gonbui/plotly"
    grob "github.com/MetalBlueberry/go-plotly/graph_objects"
)

func renamePoint(p *plots.Point) {
    p.MetricName = strings.Replace(p.MetricName, "Eval on ", "", -1)
    
    p.MetricName = strings.Replace(p.MetricName, "Accuracy", "acc", -1)
    p.MetricName = strings.Replace(p.MetricName, "Loss+Regularization", "loss+", -1)
    p.MetricName = strings.Replace(p.MetricName, "Loss", "loss", -1)
    
    p.MetricName = strings.Replace(p.MetricName, " Mean ", "#", -1)
    p.MetricName = strings.Replace(p.MetricName, " Moving Average ", "~", -1)
}

func FilterPoints(points plots.Points) {
    points.Map(renamePoint)
    batchSize := float64(4919)
    _ = batchSize
    points.Filter(func (p plots.Point) bool {
        return strings.Index(p.MetricName, "Batch") < 0 &&
            // strings.Index(p.MetricName, "~") < 0 &&
            strings.Index(p.MetricName, "Train:~loss") < 0 &&
            strings.Index(p.MetricName, "Train:~acc") < 0 &&
            strings.Index(p.MetricName, "loss+") < 0 &&
            // math.Abs(math.Round(p.Step / batchSize)*batchSize - p.Step) < 10 &&
            true
    })
}

%%
checkpoints := []string{"gnn-baseline_24", "gnn-baseline_26"}
prefixes := []string{"(A)", "(B)"}

var merged plots.Points
for ii, checkpoint := range checkpoints {
    pts := plots.NewPoints(must.M1(plots.LoadPointsFromCheckpoint(path.Join(*flagDataDir, checkpoint))))
    FilterPoints(pts)
    pts.Map(func(p *plots.Point) {
       p.MetricName = prefixes[ii] + p.MetricName 
    })
    
    if merged == nil {
        merged = pts
    } else {
        merged.Add(pts)
    }
}

fmt.Printf("%s\n", merged)

╭───────┬───────────────┬───────────────┬───────────────┬───────────────┬────────────────┬────────────────┬────────────────┬────────────────╮
│ Step  │ (A)train:#acc │ (A)valid:#acc │ (B)train:#acc │ (B)valid:#acc │ (A)train:#loss │ (A)valid:#loss │ (B)train:#loss │ (B)valid:#loss │
├───────┼───────────────┼───────────────┼───────────────┼───────────────┼────────────────┼────────────────┼────────────────┼────────────────┤
│ 200   │ 0.180916      │ 0.171514      │ 0.166444      │ 0.122666      │ 3.773040       │ 3.754423       │ 3.932029       │ 4.029837       │
│ 440   │ 0.247725      │ 0.240095      │ 0.228974      │ 0.250166      │ 3.233717       │ 3.228036       │ 3.304488       │ 3.325431       │
│ 728   │ 0.279135      │ 0.266171      │ 0.276081      │ 0.279186      │ 2.960663       │ 2.991662       │ 2.985939       │ 3.023475       │
│ 1074  │ 0.305296      │ 0.280638      │ 0.302901      │ 0.301041      │ 2.745518       │ 2.798317       │ 2.795963       │ 2.841690       │
│ 1489

In [15]:
import (
    stdplots "github.com/gomlx/gomlx/examples/notebook/gonb/plots"
    "github.com/gomlx/gomlx/examples/notebook/gonb/plotly"
)
%%
versions := []int{71, 72}

plots := plotly.New()
for _, version := range versions {
    // checkpoint := fmt.Sprintf("gnn-baseline_%d", version)
    checkpoint := fmt.Sprintf("small_%d", version)
    prefix := fmt.Sprintf("(%d) ", version)
    must.M(plots.LoadCheckpointData(path.Join(*flagDataDir, checkpoint), 
        func (pt *stdplots.Point) bool {
            renamePoint(pt)
            pt.MetricName = prefix+pt.MetricName
            return strings.Index(pt.MetricName, "Batch") < 0 &&
                // strings.Index(pt.MetricName, "~") < 0 &&
                strings.Index(pt.MetricName, "Train:~loss") < 0 &&
                strings.Index(pt.MetricName, "Train:~acc") < 0 &&
                strings.Index(pt.MetricName, "loss+") < 0 &&
                // math.Abs(math.Round(pt.Step / batchSize)*batchSize - pt.Step) < 10 &&
                true
        }))
}

// plots.Dynamic()
// plots.DynamicPlot(false)
plots.Plot()

### Layer-wise Inference Experiments

In [25]:
import (
    mldata "github.com/gomlx/gomlx/ml/data"
    "github.com/gomlx/gomlx/ml/context/checkpoints"
)

func loadCtx(checkpointName string) (ctx *context.Context) {
    ctx = context.NewContext(manager)
	mag.UploadOgbnMagVariables(ctx)

    // Exclude from saving all the variables created by the `mag` package -- especially the frozen paper embeddings,
    // which take up most of the space.
    var varsToExclude []*context.Variable
    ctx.InAbsPath(mag.OgbnMagVariablesScope).EnumerateVariablesInScope(func(v *context.Variable) {
        varsToExclude = append(varsToExclude, v)
    })
    _ = must.M1(checkpoints.
                Build(ctx).
                DirFromBase(checkpointName, *flagDataDir).
                ExcludeVarsFromSaving(varsToExclude...).
                Done())
    return ctx
}

%%
must.M(mag.Download(*flagDataDir))
magSampler := must.M1(mag.NewSampler(*flagDataDir))
magStrategy := mag.MagStrategy(magSampler, mag.BatchSize, nil)
ctx := loadCtx("small_71")
_ = mag.LayerWiseInference(ctx, magStrategy)

fmt.Println("Done.")

Rule "citations": NumNodes=736389
Rule "citationsAuthors": NumNodes=1134649
Rule "papersByCitationAuthors": NumNodes=736389
Rule "papersByCitationAuthorsTopics": NumNodes=59965
Rule "citationAuthorsInstitutions": NumNodes=8740
Rule "citationsTopics": NumNodes=59965
Rule "seedsBase": NumNodes=736389
Rule "seedsAuthors": NumNodes=1134649
Rule "papersByAuthors": NumNodes=736389
Rule "papersByAuthorsTopics": NumNodes=59965
Rule "authorsInstitutions": NumNodes=8740
Rule "seedsTopics": NumNodes=59965


panic: Missing OGBN-MAG dataset variables ("EdgeWrites"), pls call UploadOgbnMagVariables() on context first.

goroutine 1 [running]:
github.com/gomlx/exceptions.Panicf(...)
	/home/janpf/src/go/pkg/mod/github.com/gomlx/exceptions@v0.0.3/exceptions.go:92
github.com/gomlx/gomlx/examples/ogbnmag.getMagVar(0xc0001a02a0?, 0xc00017a900, {0xc00012c8e0, 0xa})
	/home/janpf/Projects/gomlx/examples/ogbnmag/model.go:30 +0xae
github.com/gomlx/gomlx/examples/ogbnmag.createEdgesIndices(0xc0000d0900, 0xc00017a900)
	/home/janpf/Projects/gomlx/examples/ogbnmag/lwinference.go:105 +0x159
github.com/gomlx/gomlx/examples/ogbnmag.createEdgesInputs(0xc00017a900?, 0xc00017a900, 0xc0000d0180, {0xc00017e248, 0x26, 0x41})
	/home/janpf/Projects/gomlx/examples/ogbnmag/lwinference.go:94 +0x36
github.com/gomlx/gomlx/examples/ogbnmag.LayerWiseInference.BuildLayerWiseInferenceModel.func1(0xc0000d0240, 0xc00017a900)
	/home/janpf/Projects/gomlx/examples/ogbnmag/lwinference.go:44 +0xc8
reflect.Value.call({0x7540c0?, 0xc00