In [2]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx"
%goworkfix

	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".


## Downloading Dataset

The method `mag.Download()` will download the dataset to the data directory, if not yet downloaded.
It then converts the dataset to tensors, which are then available for use.

The tensors are then saved for faster access. Once saved, the next call to `mag.Download()` takes about 1/2s.

In [3]:
%env GOMLX_PLATFORM=GPU

Set: GOMLX_PLATFORM="GPU"


In [4]:
import (
    "flag"
    "path/filepath"

    mag "github.com/gomlx/gomlx/examples/ogbnmag"
    "github.com/janpfeifer/must"
)

// flagDataDir is the directory where downloaded and generated dataset files
// are cached. It is set with the -data flag and defaults to "~/work/ogbnmag".
var flagDataDir = flag.String(
    "data",
    "~/work/ogbnmag",
    "Directory to cache downloaded and generated dataset files.",
)

%%
start := time.Now()
must.M(mag.Download(*flagDataDir))
magSampler := must.M1(mag.NewSampler(*flagDataDir))

fmt.Printf("Elapsed: %s\n", time.Since(start))
fmt.Printf("%s\n", magSampler.String())

Elapsed: 1.050133099s
Sampler: 4 node types, 8 edge types
	NodeType "fields_of_study": 59,965 items
	NodeType "papers": 736,389 items
	NodeType "authors": 1,134,649 items
	NodeType "institutions": 8,740 items
	EdgeType "cites": ["papers"]->["papers"], 5,416,271 edges
	EdgeType "citedBy": ["papers"]->["papers"], 5,416,271 edges
	EdgeType "affiliatedWith": ["authors"]->["institutions"], 1,043,998 edges
	EdgeType "affiliations": ["institutions"]->["authors"], 1,043,998 edges
	EdgeType "hasTopic": ["papers"]->["fields_of_study"], 7,505,078 edges
	EdgeType "topicHasPapers": ["fields_of_study"]->["papers"], 7,505,078 edges
	EdgeType "writes": ["authors"]->["papers"], 7,145,660 edges
	EdgeType "writtenBy": ["papers"]->["authors"], 7,145,660 edges


## FNN model, No Graph

The first model will only use the paper features, and no relations. It serves as a baseline.

In a quick experiment, without much hyperparameter tuning, we got 27.27% test accuracy, which is in line with the corresponding results in [the leaderboard](https://ogb.stanford.edu/docs/leader_nodeprop/#ogbn-mag) (the Multi Layer Perceptron (MLP) entry at the bottom).

In [40]:
import (
	. "github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
)
// manager is created once per notebook session and handed to
// context.NewContext in config().
// NOTE(review): NewManager presumably comes from the dot-imported gomlx
// `graph` package -- confirm.
var manager = NewManager()

// config returns a new context holding the default hyperparameters.
//
// The context is created on the package-level manager, RngStateReset is
// called on it, and then the hyperparameters below are stored with SetParams.
// Callers may set further parameters on the returned context.
func config() (ctx *context.Context) {
    hyperparams := map[string]any{
        "train_steps":              2_000_000,
        "batch_size":               128,
        "optimizer":                "adamw",
        optimizers.LearningRateKey: 0.0001,
        layers.L2RegularizationKey: 1e-4,
        "normalization":            "layer",
        "dropout":                  0.1,
        "hidden_layers":            2,
        "num_nodes":                256,
        "plots":                    true,
    }

    ctx = context.NewContext(manager)
    ctx.RngStateReset()
    ctx.SetParams(hyperparams)
    return ctx
}

In [41]:
import "github.com/gomlx/gomlx/examples/ogbnmag/fnn"

%%
// Make sure the dataset is available before training.
must.M(mag.Download(*flagDataDir))

// Checkpoints go into a sub-directory of the dataset directory. This is a
// filesystem path, so it is built with filepath.Join (OS-specific separator)
// rather than path.Join, which is meant for slash-separated paths.
ctx := config()
ctx.SetParam("checkpoint", filepath.Join(*flagDataDir, "fnn-baseline"))
ctx.SetParam("num_checkpoints", 10)

// A training error is printed in its verbose (%+v) form rather than causing a
// panic.
if err := fnn.Train(ctx); err != nil {
    fmt.Printf("%+v\n", err)
}


> papers features input shape: (Float32)[128 132]
Training (2000000 steps):   13% [[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m>[0m...................................] (2069 steps/s) [2m19s:13m58s][0m [loss=3.070] [~loss=2.987] [~acc=29.03%]        

^C
signal: interrupt


## GNN (Graph Neural Networks)

To use GNNs we first need graphs to run the model on. The issue is that real-life graphs (social networks, relational databases, etc.) are generally too large, and that includes OGBN-MAG.

So instead we use sampled sub-graphs to train. For inference we can also use sampled sub-graphs, but later we show a workaround where we can do inference one node type at a time, without needing sampling.

### Sampling Sub-Graphs

We follow the same sampling strategy used in the [TensorFlow GNN](https://github.com/tensorflow/gnn) library, described in its [OGBN-MAG notebook](https://github.com/tensorflow/gnn/blob/main/examples/notebooks/ogbn_mag_e2e.ipynb).

In [6]:
import "github.com/gomlx/gomlx/examples/ogbnmag/sampler"

// GnnStrategy builds the sub-graph sampling strategy (12 rules) used for the
// GNN model.
//
// The "seeds" rule samples batchSize nodes of type "papers": from
// seedIdsCandidates when it is non-nil (strategy.NodesFromSet), otherwise with
// strategy.Nodes. Every other rule follows one edge type from a previous rule,
// sampling up to 8 neighbours each time.
func GnnStrategy(magSampler *sampler.Sampler, batchSize int, seedIdsCandidates []int32) *sampler.Strategy {
    // Number of neighbours sampled by every edge rule.
    const fanOut = 8

    strategy := magSampler.NewStrategy()

    // Seed papers, optionally restricted to the given candidate ids.
    var seeds *sampler.Rule
    if seedIdsCandidates != nil {
        seeds = strategy.NodesFromSet("seeds", "papers", batchSize, seedIdsCandidates)
    } else {
        seeds = strategy.Nodes("seeds", "papers", batchSize)
    }

    // Papers reached from the seeds through "cites" edges.
    cited := seeds.FromEdges("citations", "cites", fanOut)

    // Authors of the seeds and of the cited papers.
    authorsOfSeeds := seeds.FromEdges("seedsAuthors", "writtenBy", fanOut)
    authorsOfCited := cited.FromEdges("citationsAuthors", "writtenBy", fanOut)

    // Other papers written by those authors.
    papersByAuthorsOfSeeds := authorsOfSeeds.FromEdges("coauthoredPapers", "writes", fanOut)
    papersByAuthorsOfCited := authorsOfCited.FromEdges("coauthoredFromCitations", "writes", fanOut)

    // Institutions of the authors. These are leaf rules: nothing is sampled
    // from them, so the returned rules are not kept.
    authorsOfSeeds.FromEdges("authorsInstitutions", "affiliatedWith", fanOut)
    authorsOfCited.FromEdges("citationAuthorsInstitutions", "affiliatedWith", fanOut)

    // Topics of every set of papers sampled above (also leaf rules).
    seeds.FromEdges("seedsTopics", "hasTopic", fanOut)
    papersByAuthorsOfSeeds.FromEdges("coauthoredTopics", "hasTopic", fanOut)
    cited.FromEdges("citationsTopics", "hasTopic", fanOut)
    papersByAuthorsOfCited.FromEdges("coauthoredFromCitationsTopics", "hasTopic", fanOut)

    return strategy
}

%%
must.M(mag.Download(*flagDataDir))
magSampler := must.M1(mag.NewSampler(*flagDataDir))
strategy := GnnStrategy(magSampler, 32, mag.TrainSplit.Local().FlatCopy().([]int32))
fmt.Printf("%s\n", strategy)


Sampling strategy: (12 rules)
> Rule "seeds": type=Node, nodeType="papers", shape=(Int32)[32] (size=32), nodeSet.size=629571
  > Rule "citations": type=Edge, nodeType="papers", shape=(Int32)[32 8] (size=256), sourceRule="seeds", edgeType="cites"
    > Rule "citationsAuthors": type=Edge, nodeType="authors", shape=(Int32)[32 8 8] (size=2048), sourceRule="citations", edgeType="writtenBy"
      > Rule "coauthoredFromCitations": type=Edge, nodeType="papers", shape=(Int32)[32 8 8 8] (size=16384), sourceRule="citationsAuthors", edgeType="writes"
        > Rule "coauthoredFromCitationsTopics": type=Edge, nodeType="fields_of_study", shape=(Int32)[32 8 8 8 8] (size=131072), sourceRule="coauthoredFromCitations", edgeType="hasTopic"
      > Rule "citationAuthorsInstitutions": type=Edge, nodeType="institutions", shape=(Int32)[32 8 8 8] (size=16384), sourceRule="citationsAuthors", edgeType="affiliatedWith"
    > Rule "citationsTopics": type=Edge, nodeType="fields_of_study", shape=(Int32)[32 8 8] (si

### Benchmarking Sampling

Let's get some benchmarking on the sampler, to make sure it's not slowing training:

1. Benchmark sequential sampling of one sub-graph: `BenchmarkSequentialMagSampler`
2. Benchmark parallel sampling of one sub-graph: `BenchmarkParallelMagSampler`
3. Benchmark parallel sampling of one epoch (using parallel sampler): `BenchmarkOneEpochMagSampler`

In [38]:
import (
    mldata "github.com/gomlx/gomlx/ml/data"
)

func BenchmarkSequentialMagSampler(b *testing.B) {
    must.M(mag.Download(*flagDataDir))
    magSampler := must.M1(mag.NewSampler(*flagDataDir))
    strategy := GnnStrategy(magSampler, 32, mag.TrainSplit.Local().FlatCopy().([]int32))
    
    ds := strategy.NewDataset("train").Infinite().Shuffle()
    b.ResetTimer()
    for _ = range b.N {
        _, inputs, _, err := ds.Yield()
        if err != nil {
            b.Fatalf("Failed to sample: %+v", err)
        }
        _ = inputs
    }
}

func BenchmarkParallelMagSampler(b *testing.B) {
    must.M(mag.Download(*flagDataDir))
    magSampler := must.M1(mag.NewSampler(*flagDataDir))
    strategy := GnnStrategy(magSampler, 32, mag.TrainSplit.Local().FlatCopy().([]int32))
    
    ds := strategy.NewDataset("train").Infinite().Shuffle()
    parallelDS := mldata.Parallel(ds)
    b.ResetTimer()
    for _ = range b.N {
        _, inputs, _, err := parallelDS.Yield()
        if err != nil {
            b.Fatalf("Failed to sample: %+v", err)
        }
        _ = inputs
    }
}

// BenchmarkOneEpochMagSampler measures a full pass over the shuffled training
// dataset (batches of 32 seeds) read through mldata.Parallel. Each benchmark
// iteration is one epoch: it yields until io.EOF and then resets the dataset.
//
// Besides the time per epoch, it reports the number of batches yielded in one
// epoch as the custom metric "batches/op".
func BenchmarkOneEpochMagSampler(b *testing.B) {
    must.M(mag.Download(*flagDataDir))
    magSampler := must.M1(mag.NewSampler(*flagDataDir))
    strategy := GnnStrategy(magSampler, 32, mag.TrainSplit.Local().FlatCopy().([]int32))

    // No Infinite() here: the loop below relies on io.EOF to detect the end of
    // the epoch.
    ds := strategy.NewDataset("train").Shuffle()
    parallelDS := mldata.Parallel(ds)

    // Everything above is setup and is excluded from the measurement.
    b.ResetTimer()
    var count int // Batches yielded in the current epoch.
    for range b.N {
        count = 0
        for {
            _, _, _, err := parallelDS.Yield()
            if err == io.EOF {
                break
            }
            if err != nil {
                b.Fatalf("Failed to sample: %+v", err)
            }
            count++
        }
        // Rewind so the next iteration (or epoch) starts from the beginning.
        parallelDS.Reset()
    }

    // Previously count was computed and then dropped. Each op is one epoch, so
    // the count of the last epoch is already a per-op value.
    b.ReportMetric(float64(count), "batches/op")
}

%%
%test -test.bench=. -test.run=Benchmark -test.benchtime=5s

goos: linux
goarch: amd64
pkg: gonb_e851bff2
cpu: 12th Gen Intel(R) Core(TM) i9-12900K
BenchmarkOneEpochMagSampler-24      	       3	1829138955 ns/op
BenchmarkParallelMagSampler-24      	   69396	     89149 ns/op
BenchmarkSequentialMagSampler-24    	   12021	    520495 ns/op
PASS
