Skip to content

Commit

Permalink
start adding Hasenbusch to staggered autograd
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosborn committed Jun 19, 2024
1 parent 969e3ed commit 41f9a9d
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 41 deletions.
164 changes: 125 additions & 39 deletions src/experimental/stagag.nim
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ let
warmmd = (intParam("warmmd", 1) != 0)
ntrain = intParam("ntrain", 10)
trajs = intParam("trajs", 10)
nf = floatParam("nf", 1)
nf = intParam("nf", 1)
mass = floatParam("mass", 0.1)
hmasses = floatSeqParam("hmasses") # Hasenbusch masses
arsq = floatParam("arsq", 1e-20)
frsq = floatParam("frsq", 1e-12)
seed0 = defaultComm.broadcast(int(1000*epochTime()))
Expand Down Expand Up @@ -77,6 +78,7 @@ echoparam(ntrain)
echoparam(trajs)
echoparam(nf)
echoparam(mass)
echoparam(hmasses)
echoparam(arsq)
echoparam(frsq)
echoparam(seed)
Expand Down Expand Up @@ -112,12 +114,16 @@ var
p = lo.newgauge
f = lo.newgauge
g0 = lo.newgauge
phi = lo.ColorVector()
psi = lo.ColorVector()
phi = newseq[typeof(lo.ColorVector())](1+hmasses.len)
psi = newseq[typeof(lo.ColorVector())](1+hmasses.len)
ftmp = lo.ColorVector()
for i in 0..<phi.len:
phi[i] = lo.ColorVector()
psi[i] = lo.ColorVector()
type
Gauge = typeof(g)
GaugeV = AgVar[Gauge]
Cvec = typeof(phi)
Cvec = typeof(phi[0])
CvecV = AgVar[Cvec]
template newGaugeV(c: AgTape, x: Gauge): auto = newGaugeFV(c, x)
case infn
Expand Down Expand Up @@ -194,6 +200,8 @@ momvs.add tape.newGaugeV(p0)
let vtau = pushParam(tau)
#vtau.doGrad = false
var nff = 0
var vhmasses = newSeq[type pushParam(0.0)](hmasses.len)
for i in 0..<vhmasses.len: vhmasses[i] = pushParam(hmasses[i])

proc pushTemp =
ptemps.add tape.newFloatV
Expand All @@ -210,6 +218,10 @@ proc `+`(x: FloatV, y: FloatV): FloatV =
pushTemp()
result = ptemps[^1]
add(result, x, y)
proc `-`(x: FloatV, y: float): FloatV =
pushTemp()
result = ptemps[^1]
sub(result, x, y)
proc `-`(x: FloatV, y: FloatV): FloatV =
pushTemp()
result = ptemps[^1]
Expand All @@ -226,6 +238,14 @@ proc `/`(x: FloatV, y: SomeNumber): FloatV =
pushTemp()
result = ptemps[^1]
divd(result, x, float y)
proc `/`(x: SomeNumber, y: FloatV): FloatV =
pushTemp()
result = ptemps[^1]
divd(result, float x, y)
proc `/`(x: FloatV, y: FloatV): FloatV =
pushTemp()
result = ptemps[^1]
divd(result, x, y)

proc addT(veps: FloatV) =
pushGauge()
Expand Down Expand Up @@ -259,11 +279,14 @@ proc addGF(va, vb: FloatV) =
mul(momvs[^1], momvs[^2], gaugevs[^1])
addGx(va, p, momvs[^1])

proc addFf(g: GaugeV) =
proc addFf(g: GaugeV, i = 0) =
if nf == 0: return
pushCvec()
let cv = cvecvs[^1]
stag.agradSolve(g, cv, phi, mass, spf)
if i == 0:
stag.agradSolve(g, cv, phi[i], mass, spf)
else:
stag.agradSolve(g, cv, phi[i], vhmasses[i-1], spf)
pushMom()
stag.agradStagDeriv(momvs[^1], cv)
pushMom()
Expand All @@ -272,12 +295,25 @@ proc addFf(g: GaugeV) =
projtah(momvs[^1], momvs[^2])
nff += 1

proc addFx(veps: FloatV, p,g: GaugeV) =
proc addFx(veps: FloatV, p0,g: GaugeV) =
var p = p0
if nf == 0: return
addFf(g)
pushMom()
let va = (-0.5/mass) * veps
xpay(momvs[^1], p, va, momvs[^2])
for i in 0..hmasses.len:
addFf(g, i)
pushMom()
var va: FloatV
if i == hmasses.len: # last term is just inverse (no ratio)
if i == 0:
va = (-0.5/mass) * veps
else:
va = -0.5 * (veps/vhmasses[i-1])
else:
if i == 0:
va = (-0.5/mass) * veps * (vhmasses[0]*vhmasses[0]-mass*mass)
else:
va = -0.5*veps*(vhmasses[i]*vhmasses[i]/vhmasses[i-1]-vhmasses[i-1])
xpay(momvs[^1], p, va, momvs[^2])
p = momvs[^1]

proc addF(veps: FloatV) =
if nf == 0: return
Expand Down Expand Up @@ -1058,9 +1094,18 @@ var
gav = tape.newFloatV
hgv = tape.newFloatV
hv = tape.newFloatV
psiv = tape.newAgVar(psi)
faxv = tape.newFloatV
fav = tape.newFloatV
dphi = newSeq[type ftmp](hmasses.len)
dphiv = newSeq[type tape.newAgVar(dphi[0])](hmasses.len)
psiv = newSeq[type tape.newAgVar(psi[0])](1+hmasses.len)
faxv = newSeq[type tape.newFloatV](1+hmasses.len)
fav = newSeq[type tape.newFloatV](1+hmasses.len)
for i in 0..<dphi.len:
dphi[i] = ftmp.newOneOf
dphiv[i] = tape.newAgVar(dphi[i])
for i in 0..<psiv.len:
psiv[i] = tape.newAgVar(psi[i])
faxv[i] = tape.newFloatV()
fav[i] = tape.newFloatV()
proc addAction(p: GaugeV, g: GaugeV) =
norm2subtract(p2xv, p, 8.0)
mul(p2v, 0.5, p2xv)
Expand All @@ -1071,10 +1116,30 @@ proc addAction(p: GaugeV, g: GaugeV) =
hv = hgv
else:
add(hgv, p2v, gav)
stag.agradSolve(g, psiv, phi, mass, spa)
norm2subtract(faxv, psiv, 3.0)
mul(fav, 0.5, faxv)
add(hv, hgv, fav)
for i in 0..<phi.len-1:
stag.agradD(g, dphiv[i], phi[i], vhmasses[i])
#mul(dphiv[i], 1.0, phi[i])
if i == 0:
stag.agradSolve(g, psiv[i], dphiv[i], mass, spa)
else:
stag.agradSolve(g, psiv[i], dphiv[i], vhmasses[i-1], spa)
norm2subtract(faxv[i], psiv[i], 3.0)
if phi.len == 1:
stag.agradSolve(g, psiv[^1], phi[^1], mass, spa)
else:
stag.agradSolve(g, psiv[^1], phi[^1], vhmasses[^1], spa)
norm2subtract(faxv[^1], psiv[^1], 3.0)
for i in 0..<(faxv.len-1):
if i == 0:
add(fav[0], faxv[0], faxv[1])
else:
add(fav[i], fav[i-1], faxv[i+1])
if fav.len == 1:
mul(fav[0], 0.5, faxv[0])
else:
mul(fav[^1], 0.5, fav[^2])
#mul(fav[^1], 0.5, faxv[^1])
add(hv, hgv, fav[^1])

proc setupAction =
tape.addTrack
Expand All @@ -1094,6 +1159,15 @@ proc init(m: var Met) =
init(r)
m.verbosity = 1

template masses(i: int): float =
if i==0: mass
else: vhmasses[i-1].obj

template masses(bi: BackwardsIndex): float =
let i = 1 + hmasses.len - int(bi)
if i==0: mass
else: vhmasses[i-1].obj

proc start*(m: var Met) =
tic()
m.state = 0
Expand All @@ -1103,38 +1177,50 @@ proc start*(m: var Met) =
p.randomTAH r
for i in 0..<p.len:
p0[i] := p[i]
if nf != 0:
threadBarrier()
psi.gaussian r
threadBarrier()
if nf != 0:
threads:
stag.rephase
threadBarrier()
stag.D(phi, psi, mass)
threadBarrier()
phi.odd := 0
for i in 0..<phi.len:
threads:
psi[i].gaussian r
threadBarrier()
if i != phi.len-1:
stag.D(ftmp, psi[i], masses(i))
else:
stag.D(phi[i], psi[i], masses(i))
if i != phi.len-1:
stag.solve(phi[i], ftmp, masses(i+1), spa)
threads:
phi[i].odd := 0
threads:
stag.rephase
toc("init p, phi")

proc getH*(m: Met): float =
tic()
var p2 = 0.0
var f2 = 0.0
if nf != 0:
threads:
stag.rephase
stag.solve(psi, phi, mass, spa)
toc("fa solve")
#echo "psi e: ", psi.even.norm2
#echo "psi o: ", psi.odd.norm2
threads:
var p2t = 0.0
for i in 0..<p.len:
p2t += p[i].norm2subtract(8.0)
threadMaster: p2 = p2t
if nf != 0:
toc("p2")
var f2 = 0.0
if nf != 0:
threads:
stag.rephase
var psi2 = psi.norm2subtract(3.0)
threadMaster: f2 = psi2
for i in 0..<phi.len-1:
threads:
stag.D(ftmp, phi[i], masses(i+1))
#ftmp := phi[i]
stag.solve(psi[i], ftmp, masses(i), spa)
stag.solve(psi[^1], phi[^1], masses(^1), spa)
threads:
for i in 0..<psi.len:
var psi2 = psi[i].norm2subtract(3.0)
threadMaster: f2 += psi2
stag.rephase
toc("fa solve")
let
ga0 = gc.actionA g
fa0 = 0.5*f2
Expand All @@ -1148,13 +1234,13 @@ proc getH*(m: Met): float =
echo &"Begin H: {h0} T: {t0} Sg: {ga0} Sf: {fa0}"
tape.setTrack 1
tape.run
echo &" H: {hv.obj} T: {p2v.obj} Sg: {gav.obj} Sf: {fav.obj}"
echo &" H: {hv.obj} T: {p2v.obj} Sg: {gav.obj} Sf: {fav[^1].obj}"
tape.setTrack 0
else:
echo &"End H: {h0} T: {t0} Sg: {ga0} Sf: {fa0}"
tape.setTrack 2
tape.run
echo &" H: {hv.obj} T: {p2v.obj} Sg: {gav.obj} Sf: {fav.obj}"
echo &" H: {hv.obj} T: {p2v.obj} Sg: {gav.obj} Sf: {fav[^1].obj}"
tape.setTrack 0

# w = sum_i w_i
Expand Down
25 changes: 23 additions & 2 deletions src/hmc/agradOps.nim
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,25 @@ proc mulgrad1(c: float, r: var float, y: float) =
proc mulgrad2(c: float, x: float, r: var float) =
r += x * c

proc mul[F:Field](r: F, x: float, y: F) =
threads:
r := x * y
proc mulgrad1[F:Field](c: F, r: var float, y: F) =
var rr = 0.0
threads:
var l: typeof(redot(y[0][], c[0][]))
for s in c:
l += redot(y[s][], c[s][])
var m = simdReduce l
threadRankSum m
threadSingle:
rr = m
r += rr
proc mulgrad2[F:Field](c: F, x: float, r: F) =
threads:
for s in c:
r[s][] += x * c[s][]

proc mul[G:GaugeF](r: G, x: float, y: G) =
threads:
for mu in 0..<r.len:
Expand Down Expand Up @@ -406,6 +425,7 @@ proc agradDbck[I,O](op: AgOp[I,O]) {.nimcall.} =
when g is AgVar:
if g.doGrad:
#g.grad += rephase [outer(c shift x') - outer(x shift c')]
for mu in 0..<g.grad.len: g.grad[mu] *= 2.0
s.rephase g.grad
s.stagD2deriv(g.grad, r.grad, x.maybeObj)
s.rephase g.grad
Expand Down Expand Up @@ -508,6 +528,7 @@ proc agradSolve(c: var AgTape, s,g,r,x,m,p: auto) =
var op = newAgOp((s,g,x,m,p), r, agradSolvefwd, agradSolvebck)
c.add op
template agradSolve*(s: Staggered, g,r,x,m,p: auto) =
## g: gauge, r: result, x: src, m: mass, p: solve params
r.ctx.agradSolve(s, g, r, x, m, addr p)

when isMainModule:
Expand Down Expand Up @@ -644,8 +665,8 @@ when isMainModule:
echo " ", (nv.obj-n0)/eps
v1 -= eps * c

#testAgradD()
testAgradD()
#testAgradSolve()
testAgradStagDeriv()
#testAgradStagDeriv()

qexFinalize()

0 comments on commit 41f9a9d

Please sign in to comment.