# Rational Function Initialization

Creating initial coefficients for rational functions so that they approximate arbitrary target functions.


In [1]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gemma" "${HOME}/Projects/gomlx" "${HOME}/Projects/gopjrt" "${HOME}/Projects/gonb" 
%goworkfix

	- Added replace rule for module "github.com/gomlx/gemma" to local directory "/home/janpf/Projects/gemma".
	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".
	- Added replace rule for module "github.com/gomlx/gopjrt" to local directory "/home/janpf/Projects/gopjrt".


## Learning Rational Function parameters

This can be used to create initialization values for the rational functions, to mimic any given function.

But first, let's add the imports and some plotting functions:

In [2]:
//#@title Imports
import (
    "github.com/gomlx/gomlx/backends"
    . "github.com/gomlx/gomlx/graph"
    "github.com/gomlx/gomlx/ml/context"
    "github.com/gomlx/gomlx/ml/data"
    "github.com/gomlx/gomlx/ml/layers/activations"
    "github.com/gomlx/gomlx/ml/layers/rational"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/gomlx/gomlx/types/tensors"
    "github.com/janpfeifer/must"

    _ "github.com/gomlx/gomlx/backends/xla"

    // Plotting
	gonbplotly "github.com/janpfeifer/gonb/gonbui/plotly"
    grob "github.com/MetalBlueberry/go-plotly/graph_objects"
)

var (
    // NOTE(review): presumably references a symbol of the dot-imported graph package so that
    // the import counts as used even in cells that call nothing from it — confirm.
    _ = Add
    // Backend is the backend instance returned by backends.New(); the cells below pass it to
    // context.NewExec, context.ExecOnce and train.NewTrainer.
    Backend = backends.New()
)

%%
// Emit a fixed marker string as this cell's output.
fmt.Printf("Imports")

Imports

In [3]:
// PlotXY returns a figure with the given title holding a single scatter trace of ys against xs,
// drawn with linear line segments plus markers. Grid lines are enabled on both axes.
func PlotXY(title string, xs, ys []float64) *grob.Fig {
    layout := &grob.Layout{
        Title: &grob.LayoutTitle{Text: title},
        Xaxis: &grob.LayoutXaxis{Showgrid: grob.True},
        Yaxis: &grob.LayoutYaxis{Showgrid: grob.True},
    }
    trace := &grob.Scatter{
        Type: grob.TraceTypeScatter,
        Line: &grob.ScatterLine{Shape: grob.ScatterLineShapeLinear},
        Mode: "lines+markers",
        X:    xs,
        Y:    ys,
    }
    return &grob.Fig{
        Layout: layout,
        Data:   grob.Traces{trace},
    }
}


In [4]:
// PlotFuncs

// PlotNumPoints is the number of evenly spaced x values on which PlotFuncs evaluates each function.
var PlotNumPoints = 1001

// PlotFn takes as input a context for variables (can be ignored) and a vector of xs, and it should return the name of the function and the corresponding ys.
type PlotFn func(ctx *context.Context, xs *Node) (name string, ys *Node)

// PlotFuncs evaluates every function in fns on PlotNumPoints evenly spaced x values spanning
// [minX, maxX] and returns a figure titled title with one line trace per function. Each trace
// is named with the name its function reports.
//
// ctx is handed to each function, so functions that depend on context variables can be plotted.
func PlotFuncs(title string, minX, maxX float64, ctx *context.Context, fns ...PlotFn) *grob.Fig {
    var labels []string
    results := context.NewExec(Backend, ctx, func(ctx *context.Context, g *Graph) []*Node {
        // xs[i] = minX + i*(maxX-minX)/(PlotNumPoints-1)
        xs := Iota(g, shapes.Make(dtypes.Float64, PlotNumPoints), 0)
        xs = AddScalar(MulScalar(xs, (maxX-minX)/float64(PlotNumPoints-1)), minX)

        // The first output is xs itself, followed by one ys per function, in order.
        labels = make([]string, 0, len(fns))
        nodes := make([]*Node, 1, len(fns)+1)
        nodes[0] = xs
        for _, fn := range fns {
            label, ys := fn(ctx, xs)
            labels = append(labels, label)
            nodes = append(nodes, ys)
        }
        return nodes
    }).Call()

    xValues := tensors.CopyFlatData[float64](results[0])
    traces := make(grob.Traces, 0, len(fns))
    for i := range fns {
        traces = append(traces, &grob.Scatter{
            Name: labels[i],
            Type: grob.TraceTypeScatter,
            Line: &grob.ScatterLine{Shape: grob.ScatterLineShapeLinear},
            Mode: "lines",
            X:    xValues,
            Y:    tensors.CopyFlatData[float64](results[i+1]),
        })
    }

    return &grob.Fig{
        Layout: &grob.Layout{
            Title: &grob.LayoutTitle{Text: title},
            Xaxis: &grob.LayoutXaxis{Showgrid: grob.True},
            Yaxis: &grob.LayoutYaxis{Showgrid: grob.True},
            // A Legend (&grob.LayoutLegend{...}) could be set here to reposition the legend.
        },
        Data: traces,
    }
}


In [5]:
// Example functions and plotting

// swish is a PlotFn: it applies activations.Swish to x and labels the curve "Swish(x)".
// The context is not used.
func swish(ctx *context.Context, x *Node) (string, *Node) {
    ys := activations.Swish(x)
    return "Swish(x)", ys
}

// gelu is a PlotFn: it applies activations.Gelu to x and labels the curve "Gelu(x)".
// The context is not used.
func gelu(ctx *context.Context, x *Node) (string, *Node) {
    ys := activations.Gelu(x)
    return "Gelu(x)", ys
}

%%
// fig :=PlotFuncs("test", -5.0, 5.0, context.New(), swish, gelu, target)
// gonbplotly.DisplayFig(fig)

In [6]:
// estimateVarianceGain used to adjust the W parameter for different curves:

// swish is a PlotFn: it applies activations.Swish to x and labels the curve "Swish(x)".
// The context is not used.
func swish(ctx *context.Context, x *Node) (string, *Node) {
    ys := activations.Swish(x)
    return "Swish(x)", ys
}

// relu is a PlotFn: it applies activations.Relu to x and labels the curve "Relu(x)".
// The context is not used.
func relu(ctx *context.Context, x *Node) (string, *Node) {
    ys := activations.Relu(x)
    return "Relu(x)", ys
}

// identity is a PlotFn that returns x unchanged, labelled "Identity(x)".
// The context is not used.
func identity(ctx *context.Context, x *Node) (string, *Node) {
    const label = "Identity(x)"
    return label, x
}

// estimateVarianceGain draws numPoints samples with ctx.RandomNormal, applies fn to them and
// returns the inverse of the mean of the squared outputs, i.e. 1 / mean(fn(x)^2).
//
// The name reported by fn is ignored; ctx is passed through to fn.
func estimateVarianceGain(ctx *context.Context, fn PlotFn, numPoints int) float64 {
    result := context.ExecOnce(Backend, ctx, func(ctx *context.Context, g *Graph) *Node {
        // Samples come from the context's random number generator.
        samples := ctx.RandomNormal(g, shapes.Make(dtypes.Float64, numPoints))
        _, ys := fn(ctx, samples)
        meanSquare := ReduceAllMean(Square(ys))
        return Inverse(meanSquare)
    })
    return tensors.ToScalar[float64](result)
}

%%
// Number of random samples used for each estimate below.
numP := 10_000_000
// Each estimate runs on a fresh context; the printed value is what estimateVarianceGain
// returns, namely 1 / mean(f(x)^2).
fmt.Printf("Identity(x) inverse of gain: %g\n", estimateVarianceGain(context.New(), identity, numP))
fmt.Printf("Swish(x) inverse of gain: %g\n", estimateVarianceGain(context.New(), swish, numP))
fmt.Printf("Relu(x) inverse of gain: %g\n", estimateVarianceGain(context.New(), relu, numP))


Identity(x) inverse of gain: 1.0003968298250188
Swish(x) inverse of gain: 2.810054071831942
Relu(x) inverse of gain: 1.99979725500813


Now let's create a gradient descent optimizer for an arbitrary univariate function.

The cell below defines `GenerateRationalCacheLine(target PlotFn, numeratorDegrees, denominatorDegrees int, rationalVersion string)` that conveniently outputs
a cache line that can be copy&pasted to the file `cache.go` and used from there.

In [7]:
// GenerateRationalCacheLine
var (
    // BatchSize is the number of random samples, and also the number of grid points, that
    // loopTrain generates per training step; the two sets are concatenated into one input.
    BatchSize = 10_000
    // NumSteps is the number of training steps GenerateRationalCacheLine passes to loopTrain.
    NumSteps = 50_000
)

// loopTrain trains the variables used by trainableFn so that its output matches targetFn,
// running numSteps steps of the "adam" optimizer on a mean-squared-error loss.
//
// Every step evaluates both functions on the same input: BatchSize values from
// ctx.RandomNormal concatenated with BatchSize evenly spaced values covering [-10, 10].
// The dataset passed to the loop is data.NewConstantDataset(); the model only uses its
// input to get hold of the graph.
//
// A progress bar is attached to the loop. On success it prints
// "<target name>/<trainable name>: loss=<first metric>" and returns nil; if the loop fails,
// its error is returned unchanged.
func loopTrain(ctx *context.Context, targetFn, trainableFn PlotFn, numSteps int) error {
    var targetName, trainableName string

    // The model returns {predicted, target}; both are produced here because the inputs are
    // generated inside the graph rather than read from the dataset.
    model := func(ctx *context.Context, spec any, inputs []*Node) []*Node {
        g := inputs[0].Graph()

        // Random samples first, then a regular grid over [lo, hi].
        randomXs := ctx.RandomNormal(g, shapes.Make(dtypes.Float64, BatchSize))
        lo, hi := -10.0, 10.0
        gridXs := Iota(g, shapes.Make(dtypes.Float64, BatchSize), 0)
        gridXs = AddScalar(MulScalar(gridXs, (hi-lo)/float64(BatchSize-1)), lo)
        xs := Concatenate([]*Node{randomXs, gridXs}, -1)

        var predicted, target *Node
        trainableName, predicted = trainableFn(ctx, xs)
        targetName, target = targetFn(ctx, xs)
        return []*Node{predicted, target}
    }

    // Mean squared error between the two model outputs; the labels argument is not used.
    mse := func(_, outputs []*Node) *Node {
        diff := Sub(outputs[0], outputs[1])
        return ReduceAllMean(Square(diff))
    }

    dataset := data.NewConstantDataset()
    optimizer := optimizers.ByName(ctx, "adam")
    trainer := train.NewTrainer(Backend, ctx, model, mse, optimizer, nil, nil)
    loop := train.NewLoop(trainer)
    commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.

    metrics, err := loop.RunSteps(dataset, numSteps)
    if err != nil {
        return err
    }
    fmt.Printf("%s/%s: loss=%v\n", targetName, trainableName, metrics[0])
    return nil
}

// GenerateRationalCacheLine takes the target function and approximates it with a rational function
// with the given parameters, then prints the cache line that can be added to the `cache.go` file.
//
// Steps:
//   - A rational function is built in the "rational" scope of a fresh context, with the given
//     degrees and version, one input group and Approximate("identity").
//   - It is trained against targetFn with loopTrain for NumSteps steps, at learning rate 1e-4.
//     If training fails, the error is printed to stderr and the process exits with status 1.
//   - The target and the learned function are plotted over [-5, 5].
//   - estimateVarianceGain is evaluated on the learned function with 10,000,000 samples.
//   - The learned "numeratorCoeffs" and "denominatorCoeffs" variables are printed, together with
//     the gain estimate, as an initCacheKey/initCacheValue entry keyed by approximateName.
//
// See rational.New for details on the parameters.
func GenerateRationalCacheLine(approximateName string, targetFn PlotFn, numeratorDegrees, denominatorDegrees int, rationalVersion string) {
    ctx := context.New()
    learnableFn := func(ctx *context.Context, x *Node) (string, *Node) {
        ctx = ctx.In("rational")
        output := rational.New(ctx, x).
            WithDegrees(numeratorDegrees, denominatorDegrees).
            WithInputGroups(1).
            Version(rationalVersion).
            Approximate("identity").
            Done()
        return "Rational(x)", output
    }
    ctx.SetParam(optimizers.LearningRateKey, 1.0e-4)
    if err := loopTrain(ctx, targetFn, learnableFn, NumSteps); err != nil {
        fmt.Fprintf(os.Stderr, "%+v\n", err)
        os.Exit(1)
    }
    ctx = ctx.Reuse() // After function is trained, we want to reuse the learned values.
    gonbplotly.DisplayFig(PlotFuncs("Target vs Learned", -5, 5, ctx, targetFn, learnableFn))

    gain := estimateVarianceGain(ctx, learnableFn, 10_000_000)
    numT := ctx.InspectVariable("/rational", "numeratorCoeffs").Value()
    denT := ctx.InspectVariable("/rational", "denominatorCoeffs").Value()

    // fmt.Print rather than fmt.Println: the output bytes are the same (header plus one blank
    // line), without the redundant trailing newline in a Println argument that `go vet` reports.
    fmt.Print("Cache entry line:\n\n")
    fmt.Printf("\t\tinitCacheKey{Approximation: %q, Version: %q, NumeratorDegree: %d, DenominatorDegree: %d}: &initCacheValue{\n"+
        "\t\t\tNum: %#v,\n\t\t\tDen: %#v,\n\t\t\tGainEstimate: %g},\n",
        approximateName, rationalVersion, numeratorDegrees, denominatorDegrees,
        tensors.CopyFlatData[float64](numT), tensors.CopyFlatData[float64](denT),
        gain)
}


In [8]:
// swish is a PlotFn: it applies activations.Swish to x and labels the curve "Swish(x)".
// The context is not used.
func swish(ctx *context.Context, x *Node) (string, *Node) {
    ys := activations.Swish(x)
    return "Swish(x)", ys
}

%%
// Override the package-level defaults (BatchSize = 10_000, NumSteps = 50_000) for this run.
BatchSize=20_000
NumSteps=30_000
// Fit a version "B" rational with numerator degree 6 and denominator degree 5 to swish and
// print the resulting cache entry under the name "swish".
GenerateRationalCacheLine("swish", swish, 6, 5, "B")

Training (30000 steps):  100% [[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m] (3615 steps/s)[0m [step=29999] [loss+=0.000] [~loss+=0.001] [~loss=0.001]        
Swish(x)/Rational(x): loss=(Float64): 0.00035127144006569727


Cache entry line:

		initCacheKey{Approximation: "swish", Version: "B", NumeratorDegree: 6, DenominatorDegree: 5}: &initCacheValue{
			Num: []float64{0.016692362691065492, 0.5235379987754564, 0.20960774317448108, 0.04770533539960032, 0.013941059440460883, 0.0026524214293613614, 0.00016557036763404634},
			Den: []float64{-0.11426863049462892, 0.02100977104137373, -0.04450592690506476, 0.0015574138000169233, -0.0003577561069601771},
			GainEstimate: 2.8468086450136547},


## Reading from `rationals_config.json`

It takes as input the file downloaded from [github.com/ml-research/rational_activations/rational/rationals_config.json](https://github.com/ml-research/rational_activations/blob/master/rational/rationals_config.json)

**Note**: at the time of writing, the JSON file was malformed and contained an invalid entry: a `CoefficientsConfig` sits outside of a `DegreesConfig` (in terms of the Go schema described below). The file required a bit of manual editing before it could be parsed.

It outputs the various configurations that can be copy&pasted to Go.

In [None]:
import (
    "encoding/json"
    "regexp"
    "strconv"

    "github.com/janpfeifer/must"
    "github.com/gomlx/gomlx/ml/data"
)

// CoefficientsConfig holds one entry of the rationals configuration file: the initial
// coefficients of a rational function together with two bounds.
type CoefficientsConfig struct {
    // Numerator is read from the "init_w_numerator" JSON field.
    Numerator []float64 `json:"init_w_numerator"`
    // Denominator is read from the "init_w_denominator" JSON field.
    Denominator []float64 `json:"init_w_denominator"`
    // UpperBound is read from the "ub" JSON field.
    UpperBound float64 `json:"ub"`
    // LowerBound is read from the "lb" JSON field.
    LowerBound float64 `json:"lb"`
}

// DegreesConfig maps the name of the target function (the one being approximated) to its approximation coefficients.
type DegreesConfig map[string]CoefficientsConfig

// RationalsConfig maps a configuration name, which encodes the version and the degrees, to a DegreesConfig.
type RationalsConfig map[string]DegreesConfig


// rationalVersionRe matches configuration names of the form "Rational_version_<V><num>/<den>",
// where <V> is a single upper-case letter and <num>, <den> are decimal degrees.
// It is compiled once instead of on every parseRationalVersion call.
var rationalVersionRe = regexp.MustCompile(`^Rational_version_([A-Z])(\d+)/(\d+)$`)

// parseRationalVersion splits a configuration name such as "Rational_version_A5/4" into its
// version letter ("A"), numerator degree (5) and denominator degree (4).
//
// The first result reports whether str is a valid configuration name. It is false, and the
// other results are zero values, when str does not match the expected form or when a degree
// does not fit in an int.
func parseRationalVersion(str string) (bool, string, int, int) {
    matches := rationalVersionRe.FindStringSubmatch(str)
    if matches == nil {
        return false, "", 0, 0
    }

    // The pattern guarantees digits only, so Atoi can only fail on out-of-range values;
    // treat those as invalid rather than silently using a clamped degree.
    numerator, err := strconv.Atoi(matches[2])
    if err != nil {
        return false, "", 0, 0
    }
    denominator, err := strconv.Atoi(matches[3])
    if err != nil {
        return false, "", 0, 0
    }

    return true, matches[1], numerator, denominator
}


func ConvertRationalsConfig(filePath string) {
    filePath = data.ReplaceTildeInDir(filePath)
    f := must.M1(os.Open(filePath))
    defer f.Close()

    var config RationalsConfig
    dec := json.NewDecoder(f)
    must.M(dec.Decode(&config))
    for configName, DegreesConfig := range config {
        valid, version, numeratorDegree, denominatorDegree := parseRationalVersion(configName)
        if !valid {
            fmt.Printf("skipping %q\n", configName)
            continue
        }
        for approx, values := range DegreesConfig {
            _ = values
            fmt.Printf("\tinitCacheKey{Approximation:%q, Version:%q, NumeratorDegree:%d, DenominatorDegree:%d}: initCacheValue{\n",
                       approx, version, numeratorDegree, denominatorDegree)
            fmt.Printf("\t\tNumeratorCoefficients: %#v, DenominatorCoefficients: %#v},\n", values.Numerator, values.Denominator)
        }
    }
}

%%
// Alternative input location, kept for reference:
// ConvertRationalsConfig("~/Downloads/rationals_config.json")
// Convert the copy of the file given by a relative path.
ConvertRationalsConfig("rationals_config.json")
