# Rational Function Initialization

Creating initial coefficients to approximate arbitrary functions for rational functions


In [1]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gemma" "${HOME}/Projects/gomlx" "${HOME}/Projects/gopjrt" "${HOME}/Projects/gonb" 
%goworkfix

	- Added replace rule for module "github.com/gomlx/gopjrt" to local directory "/home/janpf/Projects/gopjrt".
	- Added replace rule for module "github.com/gomlx/gemma" to local directory "/home/janpf/Projects/gemma".
	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".


## Learning Rational Function parameters

This can be used to create initialization values for the rational functions, to mimic any given function.

Buf first let's add the imports and some plotting functions:

In [2]:
//#@title Imports
import (
    "github.com/gomlx/gomlx/backends"
    . "github.com/gomlx/gomlx/pkg/core/graph"
    "github.com/gomlx/gomlx/pkg/core/tensors"
    "github.com/gomlx/gomlx/pkg/ml/context"
    "github.com/gomlx/gomlx/pkg/ml/datasets"
    "github.com/gomlx/gomlx/pkg/ml/layers/activations"
    "github.com/gomlx/gomlx/pkg/ml/layers/rational"
    "github.com/gomlx/gomlx/pkg/ml/train"
    "github.com/janpfeifer/must"

    _ "github.com/gomlx/gomlx/backends/default"

    // Plotting
	gonbplotly "github.com/janpfeifer/gonb/gonbui/plotly"
    
    grob "github.com/MetalBlueberry/go-plotly/generated/v2.34.0/graph_objects"
    ptypes "github.com/MetalBlueberry/go-plotly/pkg/types"
)

var (
    _ = Add
    Backend = backends.MustNew()
)

In [3]:
// PlotXY
func PlotXY(title string, xs, ys []float64) *grob.Fig {
    fig := &grob.Fig{
        Layout: &grob.Layout{
            Title: &grob.LayoutTitle{
                Text: ptypes.S(title),
            },
            Xaxis: &grob.LayoutXaxis{
                Showgrid: ptypes.True,
            },
            Yaxis: &grob.LayoutYaxis{
                Showgrid: ptypes.True,
            },
        },
        Data: []ptypes.Trace{
            &grob.Scatter{
    			// Type: grob.TraceTypeScatter,
    			Line: &grob.ScatterLine{
    				Shape: grob.ScatterLineShapeLinear,
    			},
    			Mode: "lines+markers",
    			X:    ptypes.DataArray(xs),
    			Y:    ptypes.DataArray(ys),
    		},
        },
    }
    return fig
}

In [4]:
// PlotFuncs

var PlotNumPoints = 1001

// PlotFn takes as input a context for variables (can be ignored) and a vector of xs, and it should return the name of the function and the corresponding ys.
type PlotFn func(ctx *context.Context, xs *Node) (name string, ys *Node)

// PlotFuncs takes the title of the graph, the range of X (minX and maxX) and the functions to plot and returns a figure.
func PlotFuncs(title string, minX, maxX float64, ctx *context.Context, fns ...PlotFn) *grob.Fig {
    var fnNames []string
    allValues := context.MustNewExec(Backend, ctx, func (ctx *context.Context, g *Graph) []*Node {
        xs := Iota(g, shapes.Make(dtypes.Float64, PlotNumPoints), 0)
        xs = MulScalar(xs, (maxX-minX)/float64(PlotNumPoints-1))
        xs = AddScalar(xs, minX)

        outputs := make([]*Node, 0, len(fns)+1)
        outputs = append(outputs, xs)
        fnNames = make([]string, 0, len(fns))
        for _, fn := range fns {
            name, ys := fn(ctx, xs)
            fnNames = append(fnNames, name)
            outputs = append(outputs, ys)
        }
        return outputs
    }).MustExec()
    xs := tensors.CopyFlatData[float64](allValues[0])
    allYs := xslices.Map(allValues[1:], func(t *tensors.Tensor) []float64 {
        return tensors.CopyFlatData[float64](t)
    })
    fig := &grob.Fig{
        Layout: &grob.Layout{
            Title: &grob.LayoutTitle{
                Text: ptypes.S(title),
            },
            Xaxis: &grob.LayoutXaxis{
                Showgrid: ptypes.True,
            },
            Yaxis: &grob.LayoutYaxis{
                Showgrid: ptypes.True,
            },
            // Legend: &grob.LayoutLegend{
                //Y:       -0.2,
                //X:       1.0,
                //X anchor: grob.LayoutLegendX anchorRight,
                //Y anchor: grob.LayoutLegendY anchorTop,
            // },
        },
        Data: make([]ptypes.Trace, len(fns)),
    }
    for fnIdx := range len(fns) {
        fig.Data[fnIdx] = &grob.Scatter{
            Name: ptypes.S(fnNames[fnIdx]),
            // Type: grob.TraceTypeScatter,
            Line: &grob.ScatterLine{
                Shape: grob.ScatterLineShapeLinear,
            },
            Mode: "lines",
            X:    ptypes.DataArray(xs),
            Y:    ptypes.DataArray(allYs[fnIdx]),
        }
    }
    return fig
}


In [5]:
// Example functions and plotting

func swish(ctx *context.Context, x *Node) (string, *Node) {
    return "Swish(x)", activations.Swish(x)
}

func gelu(ctx *context.Context, x *Node) (string, *Node) {
    return "Gelu(x)", activations.Gelu(x)
}

%%
// fig :=PlotFuncs("test", -5.0, 5.0, context.New(), swish, gelu, target)
// gonbplotly.DisplayFig(fig)

### Estimate of the Variance Gain

The KAT paper [1] argues that by estimating the "gain" defined as $\alpha = \mathbb{E}[\frac{Var(x)}{F(x)^2}]$, and since we assume $x \sim \mathcal{N}(0, 1)$, we can make $Var[d_{in}wF(x)] = 1$, and it does so empirically.

The function `estimateVarianceGain` calculates that $\alpha$ that is also included in the cache entry, so one can initialize the GR-KAN layers in a variance preserving way.

In [6]:
// estimateVarianceGain used to adjust the W parameter for different curves:
func estimateVarianceGain(ctx *context.Context, fn PlotFn, numPoints int) float64 {
    return tensors.ToScalar[float64](
        context.MustExecOnce(Backend, ctx, func (ctx *context.Context, g *Graph) *Node {
            // input has a variance of 1
            // rng := RngStateFromSeed(42)
            // input := RandomNormal(g, shapes.Make(dtypes.Float64, numPoints))
            input := ctx.RandomNormal(g, shapes.Make(dtypes.Float64, numPoints))
            _, values := fn(ctx, input)
            return Inverse(ReduceAllMean(Square(values)))
        }))
}

func swish(ctx *context.Context, x *Node) (string, *Node) {
    return "Swish(x)", activations.Swish(x)
}

func relu(ctx *context.Context, x *Node) (string, *Node) {
    return "Relu(x)", activations.Relu(x)
}

func identity(ctx *context.Context, x *Node) (string, *Node) {
    return "Identity(x)", x
}

%%
numP := 10_000_000
fmt.Printf("Identity(x) inverse of gain: %g\n", estimateVarianceGain(context.New(), identity, numP))
fmt.Printf("Swish(x) inverse of gain: %g\n", estimateVarianceGain(context.New(), swish, numP))
fmt.Printf("Relu(x) inverse of gain: %g\n", estimateVarianceGain(context.New(), relu, numP))

Identity(x) inverse of gain: 1.0001459406385227
Swish(x) inverse of gain: 2.81095940825933
Relu(x) inverse of gain: 2.001818138065777


### Estimate Initial Values Using Gradient Descent

Now let's create a gradient descent optimizer for an arbitrary univariate function.

The cell below defines `GenerateRationalCacheLine(target PlotFn, numeratorDegrees, denominatorDegrees int, rationalVersion string)` that conveniently outputs
a cache line that can be copy&pasted to the file `cache.go` and used from there.

In [10]:
// GenerateRationalCacheLine
var (
    BatchSize = 10_000
    NumSteps = 50_000
)

func loopTrain(ctx *context.Context, targetFn, trainableFn PlotFn, numSteps int) error {
    ds := datasets.NewConstantDataset()
    var targetName, trainableName string
    // lossFn takes the predictions and the labels from the target, and return the mean-squared-loss.
    lossFn := func(labels, logits []*Node) *Node {
        predicted, label := logits[0], logits[1]
        return Sqrt(ReduceAllMean(Square(Sub(predicted, label))))
    }
    modelFn := func(ctx *context.Context, spec any, inputs []*Node) []*Node {
        g := inputs[0].Graph()
        input := ctx.RandomNormal(g, shapes.Make(dtypes.Float64, BatchSize))
        minX, maxX := -10.0,10.0
        xs := Iota(g, shapes.Make(dtypes.Float64, BatchSize), 0)
        xs = MulScalar(xs, (maxX-minX)/float64(BatchSize-1))
        xs = AddScalar(xs, minX)
        input = Concatenate([]*Node{input, xs}, -1)

        var target, predicted *Node
        trainableName, predicted = trainableFn(ctx, input)
        targetName, target = targetFn(ctx, input)        
        return []*Node{predicted, target}
    }
    trainer := train.NewTrainer(Backend, ctx, modelFn, lossFn, optimizers.ByName(ctx, "adam"), nil, nil)
    loop := train.NewLoop(trainer)
	commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.
    metrics, err := loop.RunSteps(ds, numSteps)
    if err != nil { 
        return err
    }
    fmt.Printf("%s/%s: loss=%v\n", targetName, trainableName, metrics[0])
    return nil
}

// GenerateRationalCacheLine takes the target function and approximates with a rational function with the given parameters
// and outputs the cache line that can be added to the `cache.go` file.
//
// See rational.New for details on the parameters.
func GenerateRationalCacheLine(approximateName string, targetFn PlotFn, numeratorDegrees, denominatorDegrees int, rationalVersion string) {
    ctx := context.New()
    learnableFn := func (ctx *context.Context, x *Node) (string, *Node) {
        ctx = ctx.In("rational")
        output := rational.New(ctx, x).
            WithDegrees(numeratorDegrees, denominatorDegrees).
            WithInputGroups(1).
            Version(rationalVersion).
            Approximate("identity").
            Done()
        return "Rational(x)", output
    }
    ctx.SetParam(optimizers.LearningRateKey, 1.0e-4)
    err := loopTrain(ctx, targetFn, learnableFn, NumSteps)
    if err != nil {
        fmt.Fprintf(os.Stderr, "%+v\n", err)
        os.Exit(1)
    }
    ctx = ctx.Reuse()  // After function is trained, we want to reuse the learned values.
    gonbplotly.DisplayFig(PlotFuncs("Target vs Learned", -5, 5, ctx, targetFn, learnableFn))

    gain := estimateVarianceGain(ctx, learnableFn, 10_000_000)
    numT := ctx.InspectVariable("/rational", "numeratorCoeffs").Value()
    denT := ctx.InspectVariable("/rational", "denominatorCoeffs").Value()

    fmt.Println("Cache entry line:\n")
    fmt.Printf("\t\tinitCacheKey{Approximation: %q, Version: %q, NumeratorDegree: %d, DenominatorDegree: %d}: &initCacheValue{\n"+
               "\t\t\tNum: %#v,\n\t\t\tDen: %#v,\n\t\t\tGainEstimate: %g},\n",
               approximateName, rationalVersion, numeratorDegrees, denominatorDegrees,
               tensors.CopyFlatData[float64](numT), tensors.CopyFlatData[float64](denT),
               gain)
}


**Example 1**: Learning the $Swish(x)$ activation curve, using 6/5 degrees rational function, version "B":

In [11]:
func swish(ctx *context.Context, x *Node) (string, *Node) {
    return "Swish(x)", activations.Swish(x)
}

%%
BatchSize=50_000
NumSteps=100_000
GenerateRationalCacheLine("swish", swish, 6, 5, "B")

Swish(x)/Rational(x): loss=float64(0.012)


Cache entry line:

		initCacheKey{Approximation: "swish", Version: "B", NumeratorDegree: 6, DenominatorDegree: 5}: &initCacheValue{
			Num: []float64{-0.0012626242939438315, 0.5621299881691898, 0.2948359391204854, 0.13423005084150424, 0.038218664721966625, 0.00499405512353664, 0.0002334292300551995},
			Den: []float64{-0.3317051900332943, -0.013564178055387946, -0.07887396541471477, 0.0003131085291739652, -0.00046738201074876176},
			GainEstimate: 2.8132747357829495},


## Reading from `rationals_config.json`

It takes as input the file downloaded from [github.com/ml-research/rational_activations/rational/rationals_config.json](https://github.com/ml-research/rational_activations/blob/master/rational/rationals_config.json)

It outputs the various configurations that can be copy&pasted to Go.

In [35]:
import (
    "encoding/json"
    "regexp"
    "strconv"

    "github.com/janpfeifer/must"
    "github.com/gomlx/gomlx/pkg/ml/datasets"
)

type CoefficientsConfig struct {
    Numerator []float64 `json:"init_w_numerator"`
    Denominator []float64 `json:"init_w_denominator"`
    UpperBound float64 `json:"ub"`
    LowerBound float64 `json:"lb"`
}

// RationalsConfig maps a version name and degrees configuration to its coefficients.
type RationalsConfig map[string]CoefficientsConfig


func parseRationalVersion(str string) (bool, string, int, int, string) {
    re := regexp.MustCompile(`^Rational_version_([A-Z])(\d+)/(\d+).(\w+)$`)
    matches := re.FindStringSubmatch(str)

    if matches == nil {
        return false, "", 0, 0, ""
    }

    version := matches[1]
    numerator, _ := strconv.Atoi(matches[2])
    denominator, _ := strconv.Atoi(matches[3])
    approxFunc := matches[4]

    return true, version, numerator, denominator, approxFunc
}

func ConvertRationalsConfig(filePath string) {
    filePath = must.M1(fsutil.ReplaceTildeInDir(filePath))
    f := must.M1(os.Open(filePath))
    defer f.Close()

    config := make(RationalsConfig)
    generic := make(map[string]json.RawMessage)
    dec := json.NewDecoder(f)
    must.M(dec.Decode(&generic))
    for n1, raw1 := range generic {
        var coef CoefficientsConfig
        if err := json.Unmarshal(raw1, &coef); err == nil && len(coef.Numerator) > 0 {
            // Take coefficients at level-1:
            fmt.Printf("n1=%s -> %+v\n", n1, coef)
            config[n1] = coef
            continue
        }
        generic2 := make(map[string]json.RawMessage)
        if err := json.Unmarshal(raw1, &generic2); err != nil {
            fmt.Printf("Unknown format for JSON entry %q, skipping it.\n", n1)
            continue
        }
        for n2, raw2 := range generic2 {
            var coef2 CoefficientsConfig
            name := n1+"."+n2
            if err := json.Unmarshal(raw2, &coef2); err != nil {
                fmt.Printf("Unknown format for JSON entry %q, skipping it.\n", name)
                continue
            }
            config[name] = coef2
        }
    }
    for configName, coef := range config {
        valid, version, numeratorDegree, denominatorDegree, approx := parseRationalVersion(configName)
        if !valid {
            fmt.Printf("skipping %q\n", configName)
            continue
        }
        fmt.Printf("\tinitCacheKey{Approximation:%q, Version:%q, NumeratorDegree:%d, DenominatorDegree:%d}: initCacheValue{\n",
                   approx, version, numeratorDegree, denominatorDegree)
        fmt.Printf("\t\tNumeratorCoefficients: %#v\n\t\tDenominatorCoefficients: %#v},\n", coef.Numerator, coef.Denominator)
    }
}

%%
// ConvertRationalsConfig("~/Downloads/rationals_config.json")
ConvertRationalsConfig("rationals_config.json")


n1=identity -> {Numerator:[0 1 0 0 0 0] Denominator:[0 0 0 0] UpperBound:-3 LowerBound:3}
	initCacheKey{Approximation:"sigmoid", Version:"B", NumeratorDegree:5, DenominatorDegree:4}: initCacheValue{
		NumeratorCoefficients: []float64{0.5000000002774382, 0.2500039727332485, 0.05544474230118124, 0.006888449237990345, 0.00048491391666921244, 1.5646015289718136e-05}
		DenominatorCoefficients: []float64{7.956371345839366e-06, 0.11088550952772189, 7.76547864226066e-07, 0.0009697684428133153}},
	initCacheKey{Approximation:"relu", Version:"B", NumeratorDegree:5, DenominatorDegree:4}: initCacheValue{
		NumeratorCoefficients: []float64{0.033897129202224346, 0.4999985439606278, 1.6701363611130988, 1.9901021632350815, 0.9413089613384323, 0.1509133373584318}
		DenominatorCoefficients: []float64{-2.1040152094202414e-05, 3.980247851167207, -3.166344237241501e-05, 0.30183382300945066}},
	initCacheKey{Approximation:"swish", Version:"A", NumeratorDegree:5, DenominatorDegree:4}: initCacheValue{
		Numerat