# Fast Fourier Transform (FFT) Example

## Machine-Learned Inverse FFT

Let's start with a sinusoidal curve: that's the input (`x`) we want to learn, given only its FFT (`y`).
The goal is to compute an inverse real FFT (`InverseRealFFT`) by gradient descent.

For this problem the input (`x`) is real (`float32`) and the label (`y`), its FFT, is complex (`complex64`).

In [1]:
!*rm -f go.work && go work init
!*go work use . "/opt/janpf/Projects/gomlx.gomlx"
%goworkfix

	- added replace rule for module "github.com/gomlx/gomlx" to local directory "/opt/janpf/Projects/gomlx.gomlx".

In [96]:
import (
    . "github.com/gomlx/gomlx/graph"
    . "github.com/gomlx/gomlx/types/exceptions"
    mg "github.com/gomlx/gomlx/examples/notebook/gonb/margaid"
    "github.com/janpfeifer/gonb/gonbui"
)

// manager is created once, at initialization, and is the graph manager used to
// build and execute every graph in this notebook.
var manager = NewManager()

// Sampling of the sinusoidal curve x.
const (
	// NumPoints is the number of samples in the curve x.
	NumPoints = 100
	// Frequency is the number of full sine periods spanned by the NumPoints samples.
	Frequency = 2.0
)

// DTypes of the problem.
const (
	// RealDType is the dtype of the curve x (and therefore of the learned curve).
	RealDType = shapes.Float32
	// ComplexDType is the dtype of the FFT y of x.
	ComplexDType = shapes.Complex64
)

// CalculateXY returns (x, y) of our problem: x is a sampled sinusoidal curve and
// y = RealFFT(x) is its FFT.
//
// x has shape [1, NumPoints] and dtype RealDType, and covers Frequency full sine
// periods. y has dtype ComplexDType with NumPoints/2+1 frequency bins on its last
// axis (shape [1 51] for NumPoints=100).
func CalculateXY() (x, y tensor.Tensor) {
	e := NewExec(manager, func(g *Graph) (*Node, *Node) {
		// Angles 0, step, 2*step, ... chosen so the NumPoints samples span Frequency periods.
		angles := Iota(g, shapes.Make(RealDType, 1, NumPoints), 1)
		angles = MulScalar(angles, 2.0*math.Pi*Frequency/float64(NumPoints))
		curve := Sin(angles)
		return curve, RealFFT(curve)
	})
	res := e.Call()
	return res[0], res[1]
}

// Plot draws one line per tensor in tensors, using its flattened values. Each line
// is labeled with the entry of names at the same index. Tensors without a matching
// name get an empty label.
//
// Only Float32 and Float64 tensors are accepted; any other dtype panics.
//
// If displayId is empty, the chart is rendered with margaid's Plot(). Otherwise the
// HTML of the GoNB display identified by displayId is replaced with the chart, which
// allows updating a plot in place.
func Plot(displayId string, width, height int, tensors []tensor.Tensor, names []string) {
	chart := mg.New(width, height)
	for idx, t := range tensors {
		var series []float64
		switch t.DType() {
		case shapes.F32:
			// Widen float32 values to float64, which is what the plotting library takes.
			narrow := t.Local().Flat().([]float32)
			series = make([]float64, len(narrow))
			for j, v := range narrow {
				series[j] = float64(v)
			}
		case shapes.F64:
			series = t.Local().Flat().([]float64)
		default:
			Panicf("only float32 and float64 tensor dtypes are accepted by Plot, got t.shape=%s", t.Shape())
		}

		label := ""
		if idx < len(names) {
			label = names[idx]
		}
		chart.AddValues(label, "", series)
	}

	if displayId != "" {
		gonbui.UpdateHTML(displayId, chart.PlotToHTML())
		return
	}
	chart.Plot()
}

%%
x, y := CalculateXY()
// Show the shapes of the curve and of its FFT, then plot the curve.
fmt.Printf("x: shape=%s\ny: shape=%s\n", x.Shape(), y.Shape())
Plot("", 1024, 320, []tensor.Tensor{x}, nil)

2023-08-06 08:34:18.050942: E external/xla/xla/stream_executor/stream_executor_internal.h:124] SetPriority unimplemented for this stream.


x: shape=(Float32)[1 100]
y: shape=(Complex64)[1 51]


In [94]:
%rm TrainInverseRealFFT

. removed func TrainInverseRealFFT


### Train the model

When you run it, you'll see the plot of the learned curve ("learnedX") converge towards "x", the original sinusoidal curve; the plot is refreshed every 10 training steps.

In [102]:
import (
    . "github.com/gomlx/gomlx/graph"
    "github.com/gomlx/gomlx/ml/context"
    "github.com/gomlx/gomlx/ml/data"
    "github.com/gomlx/gomlx/ml/train"
)

// flagNumSteps sets how many gradient descent steps TrainInverseRealFFT runs.
var flagNumSteps = flag.Int("steps", 1000, "Number of gradient descent steps to perform")

// flagLearningRate sets the learning rate that TrainInverseRealFFT stores in the
// training context before creating the optimizer.
var flagLearningRate = flag.Float64("learning_rate", 0.1, "Initial learning rate.")

// TrainInverseRealFFT learns, by gradient descent, a real-valued curve "learnedX"
// whose RealFFT matches y, the FFT of the sinusoidal curve x returned by CalculateXY.
//
// The number of steps and the learning rate come from the -steps and -learning_rate
// flags. The loss is the mean absolute error between RealFFT(learnedX) and y, and
// the optimizer is Adam.
//
// The "Truth" and "Learned" curves are plotted into a single GoNB display every
// 10 steps, and once more when training ends. The final plot ensures the last
// state is shown even when the number of steps is not a multiple of 10.
func TrainInverseRealFFT() {
	x, y := CalculateXY()
	ctx := context.NewContext(manager)
	ctx.SetParam(optimizers.LearningRateKey, *flagLearningRate)

	// The model's inputs are only used to get the graph. The single trainable
	// parameter is "learnedX", shaped like x, and the prediction is its RealFFT.
	modelFn := func(ctx *context.Context, spec any, inputs []*Node) []*Node {
		g := inputs[0].Graph()
		learnedXVar := ctx.VariableWithShape("learnedX", x.Shape())
		predictedY := RealFFT(learnedXVar.ValueGraph(g))
		return []*Node{predictedY}
	}

	// A single (x, y) example, served over and over.
	dataset, err := data.InMemoryFromData(manager, "dataset", []any{x}, []any{y})
	if err != nil {
		panic(err)
	}
	dataset.BatchSize(1, false).Infinite(true)

	opt := optimizers.Adam().Done()
	trainer := train.NewTrainer(
		manager, ctx, modelFn,
		losses.MeanAbsoluteError,
		opt,
		nil, nil) // trainMetrics, evalMetrics

	loop := train.NewLoop(trainer)
	commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.

	// Reserve one HTML display, so each new plot replaces the previous one in place.
	displayId := gonbui.UniqueID()
	gonbui.UpdateHTML(displayId, "")
	plotLearnedX := func() {
		learnedXVar := ctx.InspectVariable(context.RootScope, "learnedX")
		if learnedXVar == nil {
			// NOTE(review): assumes InspectVariable returns nil when "learnedX" was never
			// created (e.g. -steps=0); there is nothing to plot in that case.
			return
		}
		Plot(displayId, 1024, 320, []tensor.Tensor{x, learnedXVar.Value()}, []string{"Truth", "Learned"})
	}
	train.EveryNSteps(loop, 10, "plot", 0, func(_ *train.Loop, _ []tensor.Tensor) error {
		plotLearnedX()
		return nil
	})

	// Loop for given number of steps.
	_, err = loop.RunSteps(dataset, *flagNumSteps)
	if err != nil {
		panic(err)
	}

	// EveryNSteps skips the last steps when -steps is not a multiple of 10: always
	// show the final learned curve.
	plotLearnedX()
}

%% --steps=800 --learning_rate=0.01
%env GOMLX_PLATFORM Host
fmt.Println(manager.Platform())
TrainInverseRealFFT()

Set: GOMLX_PLATFORM="Host"
Host


2023-08-06 08:35:22.174327: E external/xla/xla/stream_executor/stream_executor_internal.h:124] SetPriority unimplemented for this stream.


Training (800 steps):  100% [[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m[32m=[0m] (3799 steps/s)[0m [loss=0.011] [~loss=0.023]        
