# Flow Matching Examples

A study of Flow Matching, based on Meta's ["Flow Matching Guide and Code"](https://ai.meta.com/research/publications/flow-matching-guide-and-code/) paper.

### Imports and Global Variables

* We also redirect some modules to their locally cloned versions using `go work`.


In [1]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx" "${HOME}/Projects/gopjrt" "${HOME}/Projects/gonb"
%goworkfix

	- Added replace rule for module "github.com/gomlx/gopjrt" to local directory "/home/janpf/Projects/gopjrt".
	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".


In [2]:
import (
    "bytes"
    "flag"
    colors "image/color"
    "github.com/gomlx/gomlx/backends"
    _ "github.com/gomlx/gomlx/backends/xla"
    fm "github.com/gomlx/gomlx/examples/FlowMatching"
    . "github.com/gomlx/gomlx/graph"
    "github.com/gomlx/gomlx/ml/context"
    "github.com/janpfeifer/gonb/gonbui"
    "github.com/janpfeifer/must"
    "gonum.org/v1/plot"
    "gonum.org/v1/plot/plotter"
    plotvg "gonum.org/v1/plot/vg"
)

// backend executes every graph built in this notebook. backends.New picks the
// default registered backend — presumably XLA, registered by the blank import
// of backends/xla above (TODO confirm).
var backend = backends.New()

// Blank reference that keeps the dot-import of the graph package in use, even
// in cells that do not otherwise mention one of its names.
var _ *Node

### Plotting A Histogram with [gonum.org/v1/plot](https://github.com/gonum/plot)

We define the `HistogramXYs` function to render a set of (X, Y) points — samples from a 2D distribution — as an SVG scatter plot.

In [3]:
// Blue is the glyph color used by HistogramXYs.
var Blue = colors.RGBA{0, 0, 0xFF, 0xFF}

// HistogramXYs renders the points in xys as a scatter plot and returns it as an SVG string.
//
// Despite its name, it draws one glyph per point (a scatter plot), not a binned histogram.
// Each row of xys is one point: xy[0] is X and xy[1] is Y, so every row must have at
// least 2 elements (typically xys is shaped [numPoints, 2]).
// The axes always cover at least [-3, 3] on both X and Y, so that plots of different
// distributions share a comparable scale; points outside that range widen the axes.
// width and height are the dimensions of the SVG, in points.
//
// It panics (via must.M1) if the plot cannot be built or rendered.
func HistogramXYs(title string, xys [][]float32, width, height int) string {
	p := plot.New()
	p.Title.Text = title

	pts := make(plotter.XYs, len(xys))
	for i, xy := range xys {
		pts[i].X, pts[i].Y = float64(xy[0]), float64(xy[1])
	}
	scatter := must.M1(plotter.NewScatter(pts))
	// Style only this scatter, instead of mutating gonum's package-global defaults.
	scatter.GlyphStyle.Radius = plotvg.Points(1.5)
	scatter.GlyphStyle.Color = Blue
	p.Add(scatter)

	// p.Add sets the axis ranges from the data; widen them to at least [-3, 3].
	// This replaces the old trick of adding fake points, which were drawn as real data.
	p.X.Min, p.X.Max = min(p.X.Min, -3), max(p.X.Max, 3)
	p.Y.Min, p.Y.Max = min(p.Y.Min, -3), max(p.Y.Max, 3)

	writer := must.M1(p.WriterTo(plotvg.Points(float64(width)), plotvg.Points(float64(height)), "svg"))
	var buf bytes.Buffer
	// Previously the rendering error was silently dropped; surface it instead.
	must.M1(writer.WriteTo(&buf))
	return buf.String()
}

## Section 2: Quick tour and key concepts / Code 1

This is the GoMLX version of `Code 1` from the paper, originally written in PyTorch.

But first, let's plot our source $p_{t=0}(X)$ and target $q(X) = p_{t=1}(X)$ distributions:

In [4]:
%%
numPoints := 200
ctx := context.New()

// sample runs a graph that generates points once and returns them as [numPoints][2]float32.
sample := func(genFn func(ctx *context.Context, g *Graph) *Node) [][]float32 {
	return context.ExecOnce(backend, ctx, genFn).Value().([][]float32)
}

// Source distribution: normally distributed points.
normalPoints := sample(func(ctx *context.Context, g *Graph) *Node {
	return ctx.RandomNormal(g, shapes.Make(dtypes.F32, numPoints, 2))
})
// Target distribution: the "two moons" dataset.
moonsPoints := sample(func(ctx *context.Context, g *Graph) *Node {
	return fm.MakeMoons(ctx, g, numPoints)
})

gonbui.DisplayHTMLF("<table><tr><td>%s</td><td>%s</td></tr></table>",
	HistogramXYs("Source Distribution: Normal", normalPoints, 200, 200),
	HistogramXYs("Target Distribution: Moons", moonsPoints, 200, 200))

0,1
Source Distribution: Normal -3 0 3 -3 0 3,Target Distribution: Moons -3 0 3 -3 0 3


We now define the two core functions:

* `dψdt(ctx, xy, t)`: the model of the "slope" (velocity) of the ODE, $\frac{d}{dt}\psi(X, t)$, that we want to learn;
* `step(ctx, xy, tStart, tEnd)`: the ODE step function, which moves $X_{t_{start}}$ to $X_{t_{end}}$ in a single step using the predicted $\frac{d}{dt}\psi(X, t)$.

In [5]:
// dψdt is the learned model of the "slope" (velocity) d/dt ψ(X, t) of the flow.
//
// xy is a batch of points; it must be rank-2 with the batch on axis 0, presumably
// [batchSize, 2] (TODO confirm for other feature sizes).
// t may be:
//   - a scalar, shared by every example in the batch;
//   - one time per example, shaped [batchSize]; or
//   - one time per example, shaped [batchSize, 1].
//
// t is concatenated in front of xy's features. The result is fed to a feed-forward
// network with 3 hidden layers of 64 units and approximate-GELU activations. Its
// variables live under the "dψdt" scope of ctx. The network has 2 outputs per
// example: the predicted slope.
func dψdt(ctx *context.Context, xy, t *Node) *Node {
	batchSize := xy.Shape().Dimensions[0]
	switch {
	case t.IsScalar():
		// One time for the whole batch: replicate it per example.
		t = BroadcastToDims(t, batchSize, 1)
	case t.Rank() == 1:
		// One time per example: add the trailing feature axis required by Concatenate.
		t = Reshape(t, batchSize, 1)
	}
	inputs := Concatenate([]*Node{t, xy}, -1)
	return fnn.New(ctx.In("dψdt"), inputs, /*numOutputs*/ 2).
		NumHiddenLayers(3, 64).
		Activation(activations.TypeGeluApprox).
		Done()
}

// step advances the points xy from time tStart to time tEnd with one explicit
// midpoint (second-order Runge-Kutta) step of the ODE d/dt X = dψdt(X, t).
//
// A half-size Euler step first estimates the state at the middle of the interval.
// The slope evaluated there is then used for the full step from xy.
// tStart and the midpoint time are passed to dψdt as its t argument.
func step(ctx *context.Context, xy, tStart, tEnd *Node) *Node {
	slopeAtStart := dψdt(ctx, xy, tStart)

	dt := Sub(tEnd, tStart)
	halfDt := DivScalar(dt, 2)
	tMid := Add(tStart, halfDt)
	xyMid := Add(xy, Mul(slopeAtStart, halfDt))

	slopeAtMid := dψdt(ctx, xyMid, tMid)
	return Add(xy, Mul(slopeAtMid, dt))
}

* A simple training loop:

In [6]:
// DType is the dtype of the random source points and times sampled during training.
var DType = dtypes.Float32

// trainStep builds the graph of one flow-matching training step.
//
// It proceeds as follows:
//   - Sample batchSize target points from the moons distribution and batchSize
//     normally distributed source points.
//   - Draw one random time t per example, shaped [batchSize, 1].
//   - Take the point xyT on the straight line from source to target at time t.
//   - Train dψdt to predict that line's constant velocity (target - source), with
//     a mean squared error loss.
//
// opt adds the variable updates for that loss to the graph.
func trainStep(ctx *context.Context, g *Graph, batchSize int, opt optimizers.Interface) {
	target := fm.MakeMoons(ctx, g, batchSize)
	source := ctx.RandomNormal(g, shapes.Make(DType, batchSize, 2))
	t := ctx.RandomUniform(g, shapes.Make(DType, batchSize, 1))

	// xyT = (1-t)*source + t*target: source at t=0, target at t=1.
	xyT := Add(Mul(OneMinus(t), source), Mul(t, target))
	// d/dt of xyT: the velocity the model should learn.
	wantSlope := Sub(target, source)
	gotSlope := dψdt(ctx, xyT, t)

	loss := losses.MeanSquaredError([]*Node{wantSlope}, []*Node{gotSlope})
	opt.UpdateGraph(ctx, g, loss)
}

// train runs numSteps executions of trainStep on ctx with batches of batchSize
// examples, optimizing with Adam at a learning rate of 0.01.
// The trained model variables are kept in ctx.
func train(ctx *context.Context, numSteps, batchSize int) {
	adam := optimizers.Adam().LearningRate(0.01).Done()
	trainExec := context.NewExec(backend, ctx, func(ctx *context.Context, g *Graph) {
		trainStep(ctx, g, batchSize, adam)
	})
	for range numSteps {
		// The training graph has no outputs: nothing to read back.
		trainExec.Call()
	}
}

%%
ctx := context.New().Checked(false)
start := time.Now()
train(ctx, 100, 256)
fmt.Printf("Training 100 steps in %s\n", time.Since(start))

Training 100 steps in 1.162776425s


* **Plotting results**

In [7]:
%%
ctx := context.New().Checked(false)
numTrainSteps := 10_000
batchSize := 256
start := time.Now()
train(ctx, numTrainSteps, batchSize)
fmt.Printf("Training %d steps (batchSize=%d) in %s\n", numTrainSteps, batchSize, time.Since(start))

numPoints := 100
numPlots := 9
svgPlots := make([]string, 0, numPlots)
// xy_0 are a normally distributed points.
xy := context.ExecOnce(backend, ctx, func (ctx *context.Context, g *Graph) *Node {
    return ctx.RandomNormal(g, shapes.Make(dtypes.F32, numPoints, 2))
})
stepExec := context.NewExec(backend, ctx, step)
for pIdx := range numPlots {
    tEnd := float32(pIdx) / float32(numPlots-1)  // From 0.0 to 1.0
    if pIdx > 0 {
        // If not the initial state, take one step forward from tStart to tEnd
        tStart := float32(pIdx-1) / float32(numPlots-1)
        xy = stepExec.Call(xy, tStart, tEnd)[0]
    }
    svgPlot := HistogramXYs(fmt.Sprintf("t=%.2f", tEnd), xy.Value().([][]float32), 200, 200)
    svgPlots = append(svgPlots, svgPlot)
}

gonbui.DisplayHTMLF("<table><tr><td>\n%s\n</td></tr></table>", strings.Join(svgPlots, "\n</td><td>\n"))

Training 10000 steps (batchSize=256) in 3.378114957s


0,1,2,3,4,5,6,7,8
t=0.00 -3 0 3 -3 0 3,t=0.12 -3 0 3 -3 0 3,t=0.25 -3 0 3 -3 0 3,t=0.38 -3 0 3 -3 0 3,t=0.50 -3 0 3 -3 0 3,t=0.62 -3 0 3 -3 0 3,t=0.75 -3 0 3 -3 0 3,t=0.88 -3 0 3 -3 0 3,t=1.00 -3 0 3 -3 0 3
