Skip to content

Commit

Permalink
update relative residual calculation
Browse files Browse the repository at this point in the history
add preconditioning to CG
  • Loading branch information
jcosborn committed Mar 4, 2024
1 parent 7e66650 commit 58d52bd
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 35 deletions.
2 changes: 2 additions & 0 deletions src/base/basicOps.nim
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ template `:=`*[R,X:SomeNumber](r: R; x: X) =
r = R(x)
template `:=`*[R,X:SomeNumber](r: R; x: ptr X) =
r = R(x[])
template `:=`*[R:SomeNumber](r: R; x: bool) =
r = R(if x: 1 else: 0)
proc `+=`*(r: var float32; x: SomeNumber) {.alwaysInline.} =
r = r + float32(x)
proc `-=`*(r: var float32; x: SomeNumber) {.alwaysInline.} =
Expand Down
4 changes: 3 additions & 1 deletion src/maths/complexProxy.nim
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,11 @@ template `+=`*[R,X:ComplexProxy](r: R, x: X) = iadd(r,x)
]#

proc sqrt*(x: ComplexProxy): auto =
mixin copySign
let n = sqrt(x.norm2)
let r = sqrt(0.5*(n + x.re))
let i = select(x.im<0, -1, 1)*sqrt(0.5*(n - x.re))
#let i = select(x.im<0, -1, 1)*sqrt(0.5*(n - x.re))
let i = copySign(sqrt(0.5*(n - x.re)), x.im)
newComplexP(r, i)

# inorm2, redot, iredot, dot, idot
Expand Down
5 changes: 4 additions & 1 deletion src/physics/stagSolve.nim
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ proc solveXX*(s: Staggered; r,x: Field; m: SomeNumber; sp0: var SolverParams;
let flops = (s.g.len*4*72+60)*r.l.nEven*sp.iterations
sp.flops = flops.float
if sp0.verbosity>0:
echo "solveXX(QEX): ", sp.getStats
if parEven:
echo "solveEE(QEX): ", sp.getStats
else:
echo "solveOO(QEX): ", sp.getStats
of sbQuda:
tic()
if parEven:
Expand Down
5 changes: 4 additions & 1 deletion src/simd/simdArray.nim
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ template map021(T,L,op1,op2:untyped):untyped {.dirty.} =
bind forStatic
#forStatic i, 0, L-1:
for i in 0..<L:
result[][i] = op2(x[][i], y[][i])
result[][i] := op2(x[][i], y[][i])
template map021x(T1,T2,TR,L,op1,op2:untyped):untyped {.dirty.} =
proc op1*(x:T1,y:T2):TR {.alwaysInline,noInit.} =
mixin `[]`
Expand Down Expand Up @@ -360,6 +360,8 @@ template makeSimdArray2*(L:typed;B,F:typedesc;N0,N:typed,T:untyped) {.dirty.} =
map021(T, L, `-`, `-`)
map021(T, L, `*`, `*`)
map021(T, L, `/`, `/`)
#map021(T, L, `<`, `<`)
map021(T, L, copySign, copySign)

map110(T, L, assign, assign)
map110(T, L, neg, neg)
Expand Down Expand Up @@ -447,6 +449,7 @@ template makeSimdArray2*(L:typed;B,F:typedesc;N0,N:typed,T:untyped) {.dirty.} =
template `*`*(x:T; y:SomeNumber):T = mul(x, y.to(type(T)))
template `/`*(x:SomeNumber; y:T):T = divd(x.to(type(T)), y)
template `/`*(x:T; y:SomeNumber):T = divd(x, y.to(type(T)))
template `<`*(x:T; y:SomeNumber):T = `<`(x, y.to(type(T)))
template `+=`*(r:var T; x:SomeNumber) = iadd(r, x.to(type(T)))
template `-=`*(r:var T; x:SomeNumber) = isub(r, x.to(type(T)))
template `*=`*(r:var T; x:SomeNumber) = imul(r, x.to(type(T)))
Expand Down
8 changes: 8 additions & 0 deletions src/simd/simdWrap.nim
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ f1(cos)
f1(acos)
f1(load1)
f2(atan2)
f2(copySign)
f2s(`+`)
f2s(`-`)
f2s(`*`)
Expand All @@ -169,6 +170,9 @@ f2s(sub)
f2s(mul)
f2s(min)
f2s(max)
f2s(`==`)
f2s(`<`)
#f2s('>')


# special cases
Expand Down Expand Up @@ -223,3 +227,7 @@ template `+=`*(x: SomeNumber, y: Simd) =
template exp*(xx: Simd[Indexed]): untyped =
let x = xx
exp(x[][x.indexedIdx])

#template select*(x: Simd[T], y,z: SomeNumber): untyped =
# mixin f
# asSimd(f(x[], y))
18 changes: 17 additions & 1 deletion src/simd/simdX86Ops.nim
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ template basicDefs(T,F,N,P,S:untyped) {.dirty.} =
template assign1*(r: var T; x: SomeNumber) =
r = `P "_set1_" S`(F(x))
template assign*(r: var T; x: SomeNumber) = assign1(r, x)
template assign*(r: var T; x: bool) = assign1(r, (if x: 1 else: 0))
macro assign*(r:var T; x:varargs[SomeNumber]):auto =
template setX:auto {.gensym.} = `P "_setr_" S`()
template setF(x):auto {.gensym.} = F(x)
Expand All @@ -158,7 +159,7 @@ template basicDefs(T,F,N,P,S:untyped) {.dirty.} =
#proc assign*(r:var T; x:T) {.alwaysInline.} =
# r = x
#template `=`*(r: var T; x: T) = {.emit: [r, " = ", x].}
template assign*(r: T; x:T) =
template assign*(r: T; x: T) =
r = x
#proc assign*(r:var array[N,F]; x:T) {.alwaysInline.} =
# assign(r[0].addr, x)
Expand Down Expand Up @@ -205,11 +206,23 @@ template basicDefs(T,F,N,P,S:untyped) {.dirty.} =
template neg*(x:T):T = sub(`P "_setzero_" S`(), x)
#template inv*(x:T):T = `P "_rcp_" S`(x)
template inv*(x:T):T = divd(1.0,x)
template `and`*(x,y:T):T = `P "_and_" S`(x,y)
template andnot*(x,y:T):T = `P "_andnot_" S`(x,y)
template `or`*(x,y:T):T = `P "_or_" S`(x,y)
#template `==`*(x,y:T):auto = `P "_cmp_" S "_mask"`(x,y,MM_CMPINT_EQ)
#template `<`*(x,y:T):auto = `P "_cmp_" S "_mask"`(x,y,MM_CMPLT_EQ)
#template `<`*(x,y:T):auto = `P "_cmp_" S "_mask"`(x,y,MM_CMPLT_EQ)
#template `<`*(x,y:T):auto = `P "_cmp_" S`(x,y,CMP_LT_OS)

binaryMixed(T, add, add)
binaryMixed(T, sub, sub)
binaryMixed(T, mul, mul)
binaryMixed(T, divd, divd)
binaryMixed(T, `and`, `and`)
binaryMixed(T, andnot, andnot)
binaryMixed(T, `or`, `or`)
binaryMixed(T, `==`, `==`)
binaryMixed(T, `<`, `<`)

template neg*(r:var T; x:T) = r = neg(x)
template add*(r: T; x,y:T) = r = add(x,y)
Expand Down Expand Up @@ -263,6 +276,7 @@ template basicDefs(T,F,N,P,S:untyped) {.dirty.} =
template `/=`*(r: T, x:T) = idiv(r,x)

unaryMixedVar(T, `:=`, assign)
template `:=`*(r: T; x:bool) = assign(r,x)
template `:=`*(r: T; x:openArray[SomeNumber]) = assign(r,x)
unaryMixedVar(T, `+=`, iadd)
unaryMixedVar(T, `-=`, isub)
Expand Down Expand Up @@ -293,6 +307,8 @@ basicDefs(m256d, float64, 4, mm256, pd)
basicDefs(m512, float32, 16, mm512, ps)
basicDefs(m512d, float64, 8, mm512, pd)

proc copySign*[T:SimdX86](to,frm: T): T {.alwaysInline.} =
result = `or`(`and`(frm, -0.0), andnot(to, -0.0))

proc simdReduce*(r:var SomeNumber; x:m128) {.alwaysInline.} =
let y = mm_hadd_ps(x, x)
Expand Down
3 changes: 3 additions & 0 deletions src/simd/simdX86Types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ when defined(AVX512):
SimdH16* = Simd[m512h]

template eval*(x: SimdX86): untyped = x

#var CMP_EQ_OS {.importc: "_CMP_EQ_OS", imm.} = cint
var CMP_LT_OS* {.importc: "_CMP_LT_OS", imm.}: cint
78 changes: 66 additions & 12 deletions src/solvers/cg.nim
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,34 @@ export solverBase
type
CgState*[T] = object
r,Ap,b: T
p,x: T
b2,r2,r2old,r2stop: float
p,x,z: T
b2,r2,r2old,r2stop,rz,rzold: float
iterations: int
precon: bool

proc reset*(cgs: var CgState) =
cgs.b2 = -1
cgs.iterations = 0
cgs.r2old = 1.0
cgs.rzold = 1.0
cgs.r2stop = 0.0

proc newCgState*[T](x,b: T): CgState[T] =
result.r = newOneOf(b)
result.Ap = newOneOf(b)
proc newCgState*[T](x,b: T; precon=false): CgState[T] =
result.r = newOneOf b
result.Ap = newOneOf b
result.b = b
result.p = newOneOf(x)
result.p = newOneOf x
result.x = x
result.precon = precon
if precon:
result.z = newOneof b
else:
result.z = result.r
result.reset

# solves: A x = b
proc solve*(state: var CgState; op: auto; sp: var SolverParams) =
mixin apply
mixin apply, applyPrecon
tic()
let vrb = sp.verbosity
template verb(n:int; body:untyped):untyped =
Expand All @@ -41,15 +48,25 @@ proc solve*(state: var CgState; op: auto; sp: var SolverParams) =
onNoSync(sub):
body

const precon = compiles(op.applyPrecon(state.z, state.r))
if precon != state.precon:
state.precon = precon
if precon:
state.z = newOneOf state.r
else:
state.z = state.r

let
r = state.r
p = state.p
Ap = state.Ap
x = state.x
b = state.b
z = state.z
var
b2 = state.b2
r2 = state.r2
rz = state.rz

if b2<0: # first call
mythreads:
Expand All @@ -61,14 +78,17 @@ proc solve*(state: var CgState; op: auto; sp: var SolverParams) =
mythreads:
x := 0
r := 0
if precon:
z := 0
r2 = 0.0
rz = 0.0
else:
threads:
op.apply(Ap, x)
subset:
r := b - Ap
p := 0
r2 = r.norm2
p := 0
verb(3):
echo("p2: ", p.norm2)
echo("r2: ", r2)
Expand All @@ -78,22 +98,31 @@ proc solve*(state: var CgState; op: auto; sp: var SolverParams) =
let maxits = sp.maxits
var itn0 = state.iterations
var r2o0 = state.r2old
var rzo0 = state.rzold

toc("cg setup")
if r2 > r2stop:
threads:
var itn = itn0
var r2o = r2o0
var rzo = rzo0
verb(1):
#echo(-1, " ", r2)
echo(itn, " ", r2/b2)

while itn<maxits and r2>r2stop:
tic()
let beta = r2/r2o
when precon:
op.applyPrecon(z, r)
subset:
rz = r.redot z
else:
rz = r2
let beta = rz/rzo
r2o = r2
rzo = rz
subset:
p := r + beta*p
p := z + beta*p
toc("p update", flops=2*numNumbers(r[0])*sub.lenOuter)
verb(3):
echo "beta: ", beta
Expand All @@ -103,7 +132,7 @@ proc solve*(state: var CgState; op: auto; sp: var SolverParams) =
subset:
let pAp = p.redot(Ap)
toc("pAp", flops=2*numNumbers(p[0])*sub.lenOuter)
let alpha = r2/pAp
let alpha = rz/pAp
x += alpha*p
toc("x", flops=2*numNumbers(p[0])*sub.lenOuter)
r -= alpha*Ap
Expand Down Expand Up @@ -132,6 +161,7 @@ proc solve*(state: var CgState; op: auto; sp: var SolverParams) =
if threadNum==0:
itn0 = itn
r2o0 = r2o
rzo0 = rzo
#var fr2: float
#op.apply(Ap, x)
#subset:
Expand All @@ -144,6 +174,8 @@ proc solve*(state: var CgState; op: auto; sp: var SolverParams) =
state.iterations = itn0
state.r2old = r2o0
state.r2 = r2
state.rzold = rzo0
state.rz = rz
verb(1):
echo state.iterations, " acc r2:", r2/b2
#threads:
Expand Down Expand Up @@ -178,12 +210,25 @@ when isMainModule:
var v1 = lo.ColorVector()
var v2 = lo.ColorVector()
var v3 = lo.ColorVector()

type opArgs = object
m: type(m)
var oa = opArgs(m: m)
proc apply*(oa: opArgs; r: type(v1); x: type(v1)) =
r := oa.m*x
#mul(r, m, x)
type opArgsP = object
m: type(m)
var oap = opArgsP(m: m)
proc apply*(oa: opArgsP; r: type(v1); x: type(v1)) =
r := oa.m*x
#mul(r, m, x)
proc applyPrecon*(oa: opArgsP; r: type(v1); x: type(v1)) =
for e in r:
let t = sqrt(1.0 / m[e][0,0])
r[e] := t * x[e]
#mul(r, m, x)

var sp:SolverParams
sp.r2req = 1e-20
sp.maxits = 200
Expand Down Expand Up @@ -231,4 +276,13 @@ when isMainModule:
sp.maxits += 10
cg.solve(oa, sp)
let c = cg.x.norm2
echo cg.iterations, ": ", c
echo cg.iterations, ": ", c, " ", cg.r2

v2 := 0
cg.reset
sp.maxits = 0
while cg.r2 > cg.r2stop:
sp.maxits += 10
cg.solve(oap, sp)
let c = cg.x.norm2
echo cg.iterations, ": ", c, " ", cg.r2
Loading

0 comments on commit 58d52bd

Please sign in to comment.