-
Notifications
You must be signed in to change notification settings - Fork 0
/
num.lua
71 lines (52 loc) · 1.69 KB
/
num.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
local template = require 'template'
-- Handle to the Lua registry; used below to publish tables under 'GSL.*' keys.
local REG = debug.getregistry()
-- Module table. It is assigned without `local`, so it is also reachable as a
-- global `num` in addition to being returned at the end of the file.
-- NOTE(review): confirm no other code relies on the global before making it local.
num = {}
--- Build an ODE solver object from a specification table.
-- The specification is validated and completed on a private copy, so the
-- caller's table is never modified. (Previously missing fields were filled
-- in place and `spec.method` was cleared, which broke reuse of a spec.)
-- @tparam table spec solver parameters.
--   Required: `N` (number), `eps_abs` (number).
--   Optional: `eps_rel` (default 0), `a_y` (default 1), `a_dydt` (default 0),
--   `method` ('rkf45' (default) or 'rk8pd').
--   Any other fields are forwarded unchanged to `template.load`.
-- @return the object built by the loaded template's `new()`, whose metatable
--   exposes the template's `step`, `init` and `evolve` functions.
-- @raise error if `spec` is not a table, a required field has the wrong type,
--   or `method` is not a known method name.
function num.ode(spec)
  if type(spec) ~= 'table' then
    error('ode specification should be a table', 2)
  end

  local required = {N = 'number', eps_abs = 'number'}
  local defaults = {eps_rel = 0, a_y = 1, a_dydt = 0}
  local is_known = {rkf45 = true, rk8pd = true}

  -- Shallow copy: extra user fields are forwarded to the template, but the
  -- caller's table itself is left untouched.
  local params = {}
  for k, v in pairs(spec) do params[k] = v end

  -- Known keys are read through `spec[k]` so that values inherited via a
  -- metatable on `spec` are still honoured, as they were before.
  for k, tp in pairs(required) do
    local v = spec[k]
    if type(v) ~= tp then
      error(string.format('parameter %s should be a %s', k, tp), 2)
    end
    params[k] = v
  end
  for k, v in pairs(defaults) do
    -- Same rule as before: an absent (or false) field takes the default.
    params[k] = spec[k] or v
  end

  local method = spec.method or 'rkf45'
  if not is_known[method] then
    error('unknown ode method: ' .. tostring(method), 2)
  end
  -- `method` selects which template to load; it is not a template parameter.
  params.method = nil

  local ode = template.load(method, params)
  REG['GSL.help_hook'].ODE = ode
  local mt = {
    __index = {step = ode.step, init = ode.init, evolve = ode.evolve}
  }
  return setmetatable(ode.new(), mt)
end
-- Methods shared by every nlinfit object. Each takes the fit object first
-- and forwards the remaining arguments to the same-named function of the
-- object's `lm` template instance, returning whatever that function returns.
local NLINFIT_METHODS = {}

function NLINFIT_METHODS.set(fit, fdf, x0)
  local lm = fit.lm
  return lm.set(fdf, x0)
end

function NLINFIT_METHODS.iterate(fit)
  local lm = fit.lm
  return lm.iterate()
end

function NLINFIT_METHODS.test(fit, epsabs, epsrel)
  local lm = fit.lm
  return lm.test(epsabs, epsrel)
end
-- Metatable for nlinfit objects.
-- Reading `chisq` calls lm.chisq() on every access. Any other missing key
-- resolves first to a shared method in NLINFIT_METHODS, then to the
-- same-named field of the object's `lm` instance.
local NLINFIT = {}

function NLINFIT.__index(fit, key)
  if key == 'chisq' then
    return fit.lm.chisq()
  end
  local method = NLINFIT_METHODS[key]
  if method then
    return method
  end
  return fit.lm[key]
end
-- Publish the shared nlinfit method table in the Lua registry under the key
-- 'GSL.NLINFIT' (presumably so native/help code can find it — confirm).
-- NOTE(review): this stores the method table, not the NLINFIT metatable;
-- verify that is what consumers of this registry key expect.
REG['GSL.NLINFIT'] = NLINFIT_METHODS
--- Create a nonlinear least-squares fit object.
-- @tparam table spec must contain `n` (number of data points) and `p`
--   (number of fit parameters), both positive integers.
-- @return a table whose `lm` field holds the 'lmfit' template instantiated
--   with N = n and P = p. Its NLINFIT metatable provides `set`, `iterate`,
--   `test`, a computed `chisq`, and falls back to fields of `lm`.
-- @raise error if `spec` is not a table, or if `n`/`p` are missing or are not
--   positive integers. Errors are reported at the caller's position.
function num.nlinfit(spec)
  if type(spec) ~= 'table' then
    error('nlinfit specification should be a table', 2)
  end
  if not spec.n then error('number of points "n" not specified', 2) end
  if not spec.p then error('number of parameters "p" not specified', 2) end

  -- Finite, strictly positive whole number. Non-numbers are rejected here
  -- rather than crashing in a comparison; fractional values are rejected
  -- because the message promises integers.
  local function is_positive_integer(x)
    return type(x) == 'number' and x > 0 and x < math.huge
       and x == math.floor(x)
  end

  local n, p = spec.n, spec.p
  if not (is_positive_integer(n) and is_positive_integer(p)) then
    error('"n" and "p" should be positive integers', 2)
  end

  local fit = { lm = template.load('lmfit', {N = n, P = p}) }
  return setmetatable(fit, NLINFIT)
end
-- Module value handed back to `require`.
return num