From 0591db6411878a390deec294a751b9c6fd97b0f3 Mon Sep 17 00:00:00 2001 From: Sebastien Binet Date: Tue, 11 Mar 2025 15:16:13 +0100 Subject: [PATCH] fit: export Func1D.Hessian Signed-off-by: Sebastien Binet --- fit/curve1d_example_test.go | 131 ++++++++++++++++++++++++++++++++++++ fit/fit.go | 8 +++ 2 files changed, 139 insertions(+) diff --git a/fit/curve1d_example_test.go b/fit/curve1d_example_test.go index 519154a95..7b30069b7 100644 --- a/fit/curve1d_example_test.go +++ b/fit/curve1d_example_test.go @@ -5,6 +5,7 @@ package fit_test import ( + "fmt" "image/color" "log" "math" @@ -13,7 +14,10 @@ import ( "go-hep.org/x/hep/hbook" "go-hep.org/x/hep/hplot" "gonum.org/v1/gonum/floats" + "gonum.org/v1/gonum/mat" "gonum.org/v1/gonum/optimize" + "gonum.org/v1/gonum/stat" + "gonum.org/v1/gonum/stat/distuv" "gonum.org/v1/plot/plotter" "gonum.org/v1/plot/vg" ) @@ -289,3 +293,130 @@ func ExampleCurve1D_powerlaw() { } } } + +func ExampleCurve1D_hessian() { + var ( + cst = 3.0 + mean = 30.0 + sigma = 20.0 + want = []float64{cst, mean, sigma} + ) + + xdata, ydata, err := readXY("testdata/gauss-data.txt") + if err != nil { + log.Fatal(err) + } + + // use a small sample + xdata = xdata[:min(25, len(xdata))] + ydata = ydata[:min(25, len(ydata))] + + gauss := func(x, cst, mu, sigma float64) float64 { + v := (x - mu) + return cst * math.Exp(-v*v/sigma) + } + + f1d := fit.Func1D{ + F: func(x float64, ps []float64) float64 { + return gauss(x, ps[0], ps[1], ps[2]) + }, + X: xdata, + Y: ydata, + Ps: []float64{10, 10, 10}, + } + res, err := fit.Curve1D(f1d, nil, &optimize.NelderMead{}) + if err != nil { + log.Fatal(err) + } + + if err := res.Status.Err(); err != nil { + log.Fatal(err) + } + if got := res.X; !floats.EqualApprox(got, want, 1e-3) { + log.Fatalf("got= %v\nwant=%v\n", got, want) + } + + inv := mat.NewSymDense(len(res.Location.X), nil) + f1d.Hessian(inv, res.Location.X) + // fmt.Printf("hessian: %1.2e\n", mat.Formatted(inv, mat.Prefix(" "))) + + popt := res.Location.X + pcov := mat.NewDense(len(popt), len(popt), nil) + { + var chol mat.Cholesky + if ok := chol.Factorize(inv); !ok { + log.Fatalf("cov-matrix not positive semi-definite") + } + + err := chol.InverseTo(inv) + if err != nil { + log.Fatalf("could not inverse matrix: %+v", err) + } + pcov.Copy(inv) + } + + // compute goodness-of-fit. + gof := newGoF(f1d.X, f1d.Y, popt, func(x float64) float64 { + return f1d.F(x, popt) + }) + + pcov.Scale(gof.SSE/float64(len(f1d.X)-len(popt)), pcov) + + // fmt.Printf("pcov: %1.2e\n", mat.Formatted(pcov, mat.Prefix(" "))) + + var ( + n = float64(len(f1d.X)) // number of data points + ndf = n - float64(len(popt)) // number of degrees of freedom + t = distuv.StudentsT{ + Mu: 0, + Sigma: 1, + Nu: ndf, + }.Quantile(0.5 * (1 + 0.95)) + ) + + for i, p := range popt { + sigma := math.Sqrt(pcov.At(i, i)) + fmt.Printf("c%d: %1.5e [%1.5e, %1.5e] -- truth: %g\n", i, p, p-sigma*t, p+sigma*t, want[i]) + } + // Output: + //c0: 2.99999e+00 [2.99999e+00, 3.00000e+00] -- truth: 3 + //c1: 3.00000e+01 [3.00000e+01, 3.00000e+01] -- truth: 30 + //c2: 2.00000e+01 [2.00000e+01, 2.00000e+01] -- truth: 20 +} + +type GoF struct { + SSE float64 // Sum of squares due to error + Rsquare float64 // R-Square is the square of the correlation between the response values and the predicted response values + NdF int // Number of degrees of freedom + AdjRsquare float64 // Degrees of freedom adjusted R-Square + RMSE float64 // Root mean squared error +} + +func newGoF(xs, ys, ps []float64, f func(float64) float64) GoF { + switch { + case len(xs) != len(ys): + panic("invalid lengths") + } + + var gof GoF + + var ( + ye = make([]float64, len(ys)) + nn = float64(len(xs) - 1) + vv = float64(len(xs) - len(ps)) + ) + + for i, x := range xs { + ye[i] = f(x) + dy := ys[i] - ye[i] + gof.SSE += dy * dy + gof.RMSE += dy * dy + } + + gof.Rsquare = stat.RSquaredFrom(ye, ys, nil) + gof.AdjRsquare = 1 - ((1 - gof.Rsquare) * nn / vv) + gof.RMSE = math.Sqrt(gof.RMSE / float64(len(ys)-len(ps))) + gof.NdF = len(ys) - len(ps) + + return gof +} diff --git a/fit/fit.go b/fit/fit.go index df27fd1de..c29f9e3a8 100644 --- a/fit/fit.go +++ b/fit/fit.go @@ -87,6 +87,14 @@ func (f *Func1D) init() { } } +// Hessian computes the hessian matrix at the provided x point. +func (f *Func1D) Hessian(hess *mat.SymDense, x []float64) { + if f.hess == nil { + f.init() + } + f.hess(hess, x) +} + // FuncND describes a multivariate function F(x0, x1... xn; p0, p1... pn) // for which the parameters ps can be found with a fit. type FuncND struct {