In [1]:
import torch
import math
import sys
import os
from timeit import default_timer as timer
from functools import reduce
from field_transformation import *

In [2]:
# From Xiao-Yong

class Param:
    def __init__(self, beta = 6.0, lat = [64, 64], tau = 2.0, nstep = 50, ntraj = 256, nrun = 4, nprint = 256, seed = 11*13, randinit = False, nth = int(os.environ.get('OMP_NUM_THREADS', '2')), nth_interop = 2):
        self.beta = beta
        self.lat = lat
        self.nd = len(lat)
        self.volume = reduce(lambda x,y:x*y, lat)
        self.tau = tau
        self.nstep = nstep
        self.dt = self.tau / self.nstep
        self.ntraj = ntraj
        self.nrun = nrun
        self.nprint = nprint
        self.seed = seed
        self.randinit = randinit
        self.nth = nth
        self.nth_interop = nth_interop
    def initializer(self):
        if self.randinit:
            return torch.empty((param.nd,) + param.lat).uniform_(-math.pi, math.pi)
        else:
            return torch.zeros((param.nd,) + param.lat)
    def summary(self):
        return f"""latsize = {self.lat}
volume = {self.volume}
beta = {self.beta}
trajs = {self.ntraj}
tau = {self.tau}
steps = {self.nstep}
seed = {self.seed}
nth = {self.nth}
nth_interop = {self.nth_interop}
"""
    def uniquestr(self):
        lat = ".".join(str(x) for x in self.lat)
        return f"out_l{lat}_b{param.beta}_n{param.ntraj}_t{param.tau}_s{param.nstep}.out"

def action(param, f):
    return (-param.beta)*torch.sum(torch.cos(plaqphase(f)))

def force(param, f):
    f.requires_grad_(True)
    s = action(param, f)
    f.grad = None
    s.backward()
    ff = f.grad
    f.requires_grad_(False)
    return ff

plaqphase = lambda f: f[0,:] - f[1,:] - torch.roll(f[0,:], shifts=-1, dims=1) + torch.roll(f[1,:], shifts=-1, dims=0)
topocharge = lambda f: torch.floor(0.1 + torch.sum(regularize(plaqphase(f))) / (2*math.pi))
def regularize(f):
    p2 = 2*math.pi
    f_ = (f - math.pi) / p2
    return p2*(f_ - torch.floor(f_) - 0.5)

def leapfrog(param, x, p):
    dt = param.dt
    x_ = x + 0.5*dt*p
    f = force(param, x_)
    p_ = p + (-dt)*f
    print(f'plaq(x) {action(param, x) / (-param.beta*param.volume)}  force.norm {torch.linalg.norm(f)}')
    for i in range(param.nstep-1):
        x_ = x_ + dt*p_
        p_ = p_ + (-dt)*force(param, x_)
    x_ = x_ + 0.5*dt*p_
    return (x_, p_)
def hmc(param, x):
    p = torch.randn_like(x)
    act0 = action(param, x) + 0.5*torch.sum(p*p)
    x_, p_ = leapfrog(param, x, p)
    xr = regularize(x_)
    act = action(param, xr) + 0.5*torch.sum(p_*p_)
    prob = torch.rand([], dtype=torch.float64)
    dH = act-act0
    exp_mdH = torch.exp(-dH)
    acc = prob < exp_mdH
    newx = xr if acc else x
    return (dH, exp_mdH, acc, newx)

put = lambda s: sys.stdout.write(s)

In [3]:
param = Param(
    beta = 2.0,
    lat = (8, 8),
    tau = 2, # 0.3
    nstep = 8, # 3
    # ADJUST ME
    ntraj = 2, # 2**16 # 2**10 # 2**15
    #
    nprint = 2,
    seed = 1331)

In [4]:
torch.manual_seed(param.seed)

torch.set_num_threads(param.nth)
torch.set_num_interop_threads(param.nth_interop)
os.environ["OMP_NUM_THREADS"] = str(param.nth)
os.environ["KMP_BLOCKTIME"] = "0"
os.environ["KMP_SETTINGS"] = "1"
os.environ["KMP_AFFINITY"]= "granularity=fine,verbose,compact,1,0"

torch.set_default_tensor_type(torch.DoubleTensor)

In [5]:
def run(param, field = param.initializer()):
    with open(param.uniquestr(), "w") as O:
        params = param.summary()
        O.write(params)
        put(params)
        plaq, topo = (action(param, field) / (-param.beta*param.volume), topocharge(field))
        status = f"Initial configuration:  plaq: {plaq}  topo: {topo}\n"
        O.write(status)
        put(status)
        ts = []
        for n in range(param.nrun):
            t = -timer()
            for i in range(param.ntraj):
                dH, exp_mdH, acc, field = hmc(param, field)
                plaq = action(param, field) / (-param.beta*param.volume)
                topo = topocharge(field)
                ifacc = "ACCEPT" if acc else "REJECT"
                status = f"Traj: {n*param.ntraj+i+1:4}  {ifacc}:  dH: {dH:< 12.8}  exp(-dH): {exp_mdH:< 12.8}  plaq: {plaq:< 12.8}  topo: {topo:< 3.3}\n"
                O.write(status)
                if (i+1) % (param.ntraj//param.nprint) == 0:
                    put(status)
            t += timer()
            ts.append(t)
        print("Run times: ", ts)
        print("Per trajectory: ", [t/param.ntraj for t in ts])
    return field
field = run(param)
field = torch.reshape(field,(1,)+field.shape)
field_run = field

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 2
tau = 2
steps = 8
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 1.0  topo: 0.0
plaq(x) 1.0  force.norm 9.246826593536493
Traj:    1  ACCEPT:  dH: -2.1286122    exp(-dH):  8.4031965    plaq:  0.87165922   topo:  0.0
plaq(x) 0.8716592163190823  force.norm 13.60065719819702
Traj:    2  ACCEPT:  dH: -0.43065065   exp(-dH):  1.5382581    plaq:  0.85113562   topo:  0.0
plaq(x) 0.851135619433136  force.norm 17.54309727438331
Traj:    3  ACCEPT:  dH: -0.52286113   exp(-dH):  1.686847     plaq:  0.81229714   topo:  0.0
plaq(x) 0.8122971355705815  force.norm 17.226420992986522
Traj:    4  ACCEPT:  dH:  0.039753441  exp(-dH):  0.96102636   plaq:  0.78167924   topo:  0.0
plaq(x) 0.7816792404187667  force.norm 17.04668749206334
Traj:    5  ACCEPT:  dH:  0.42993024   exp(-dH):  0.65055447   plaq:  0.80392524   topo:  0.0
plaq(x) 0.803925239737245  force.norm 17.257422008841356
Traj:    6  ACCEPT:  dH: -0.96340329   exp(-dH):  2.62

  Variable._execution_engine.run_backward(


In [6]:
def ft_flow(flow, f):
    for layer in flow:
        f, lJ = layer.forward(f)
    return f.detach()

def ft_flow_inv(flow, f):
    for layer in reversed(flow):
        f, lJ = layer.reverse(f)
    return f.detach()

def ft_action(param, flow, f):
    y = f
    logJy = 0.0
    for layer in flow:
        y, lJ = layer.forward(y)
        logJy += lJ
    action = U1GaugeAction(param.beta)
    s = action(y) - logJy
    return s

def ft_force(param, flow, field, create_graph = False):
    # f is the field follows the transformed distribution (close to prior distribution)
    f = field
    f.requires_grad_(True)
    s = ft_action(param, flow, f)
    ss = torch.sum(s)
    # f.grad = None
    ff, = torch.autograd.grad(ss, f, create_graph = create_graph)
    f.requires_grad_(False)
    return ff

In [31]:
def train_step(model, action, optimizer, metrics, batch_size, with_force = False, pre_model = None):
    layers, prior = model['layers'], model['prior']
    optimizer.zero_grad()
    #
    xi = None
    if pre_model != None:
        pre_layers, pre_prior = pre_model['layers'], pre_model['prior']
        pre_xi = pre_prior.sample_n(batch_size)
        x = ft_flow(pre_layers, pre_xi)
        xi = ft_flow_inv(layers, x)
    #
    xi, x, logq = apply_flow_to_prior(prior, layers, batch_size=batch_size, xi=xi)
    logp = -action(x)
    #
    force_size = torch.tensor(0.0)
    dkl = calc_dkl(logp, logq)
    loss = torch.tensor(0.0)
    if with_force:
        assert pre_model != None
        force = ft_force(param, layers, xi, True)
        force_size = torch.sum(torch.square(force)) / 1000
        loss = force_size
    else:
        loss = dkl
    #
    loss.backward()
    #
    # minimization target
    # loss mini
    # -> (logq - logp) mini
    # -> (action - logJ) mini
    #
    optimizer.step()
    ess = compute_ess(logp, logq)
    #
    print(grab(loss),
          grab(force_size),
          grab(dkl),
          grab(ess),
          torch.linalg.norm(ft_force(param, layers, xi)))
    #
    metrics['loss'].append(grab(loss))
    metrics['force'].append(grab(force_size))
    metrics['dkl'].append(grab(dkl))
    metrics['logp'].append(grab(logp))
    metrics['logq'].append(grab(logq))
    metrics['ess'].append(grab(ess))

def flow_train(param, with_force = False, pre_model = None):  # packaged from original ipynb by Xiao-Yong Jin
    # Theory
    lattice_shape = param.lat
    link_shape = (2,*param.lat)
    beta = param.beta
    u1_action = U1GaugeAction(beta)
    # Model
    prior = MultivariateUniform(torch.zeros(link_shape), 2*np.pi*torch.ones(link_shape))
    #
    n_layers = 16
    n_s_nets = 2
    hidden_sizes = [8,8]
    kernel_size = 3
    layers = make_u1_equiv_layers(lattice_shape=lattice_shape, n_layers=n_layers, n_mixture_comps=n_s_nets,
                                  hidden_sizes=hidden_sizes, kernel_size=kernel_size)
    set_weights(layers)
    model = {'layers': layers, 'prior': prior}
    # Training
    base_lr = .001
    if with_force:
        base_lr = .0001
    optimizer = torch.optim.Adam(model['layers'].parameters(), lr=base_lr)
    #
    # ADJUST ME
    N_era = 10
    N_epoch = 100
    #
    batch_size = 64
    print_freq = N_epoch # epochs
    plot_freq = 1 # epochs
    history = {
        'loss' : [],
        'force' : [],
        'dkl' : [],
        'logp' : [],
        'logq' : [],
        'ess' : []
    }
    for era in range(N_era):
        for epoch in range(N_epoch):
            train_step(model, u1_action, optimizer, history, batch_size,
                       with_force = with_force, pre_model = pre_model)
            if epoch % print_freq == 0:
                print_metrics(history, print_freq, era, epoch)
    return model,u1_action

def flow_eval(model, u1_action):  # packaged from original ipynb by Xiao-Yong Jin
    ensemble_size = 1024
    u1_ens = make_mcmc_ensemble(model, u1_action, 64, ensemble_size)
    print("Accept rate:", np.mean(u1_ens['accepted']))
    Q = grab(topo_charge(torch.stack(u1_ens['x'], axis=0)))
    X_mean, X_err = bootstrap(Q**2, Nboot=100, binsize=16)
    print(f'Topological susceptibility = {X_mean:.2f} +/- {X_err:.2f}')
    print(f'... vs HMC estimate = 1.23 +/- 0.02')

In [8]:
pre_flow_model, flow_act = flow_train(param)
flow_eval(pre_flow_model,flow_act)
pre_flow = pre_flow_model['layers']

-235.34537692516128 0.0 -235.34537692516128 0.044984338671787305 tensor(179.4703)
== Era 0 | Epoch 0 metrics ==
	loss -235.345
	force 0
	dkl -235.345
	logp 0.326149
	logq -235.019
	ess 0.0449843
-239.1383882108473 0.0 -239.1383882108473 0.019970337139737032 tensor(175.4217)
-239.51936855245063 0.0 -239.51936855245063 0.01603148961419505 tensor(175.1459)
-243.8019138166205 0.0 -243.8019138166205 0.015768733372372272 tensor(167.5826)
-246.55102833203563 0.0 -246.55102833203563 0.01712442356040787 tensor(167.7365)
-248.18771779704565 0.0 -248.18771779704565 0.024465609045932344 tensor(162.6325)
-251.1713251794792 0.0 -251.1713251794792 0.015640384689887384 tensor(158.9293)
-253.6046561143978 0.0 -253.6046561143978 0.01766402489483935 tensor(156.8462)
-256.6035089734507 0.0 -256.6035089734507 0.016136360062499903 tensor(158.3880)
-258.2002477654431 0.0 -258.2002477654431 0.02161296312517467 tensor(158.9310)
-260.66395259167905 0.0 -260.66395259167905 0.060672253207140425 tensor(157.5602)
-

-282.2183491919442 0.0 -282.2183491919442 0.035489424739586224 tensor(176.1671)
-281.80745413557906 0.0 -281.80745413557906 0.03223995457395436 tensor(176.2737)
-281.5786820609213 0.0 -281.5786820609213 0.08399880218640958 tensor(186.4566)
-282.62302465971686 0.0 -282.62302465971686 0.03741876243752672 tensor(172.6518)
-282.24374305783067 0.0 -282.24374305783067 0.05084161874060404 tensor(180.6271)
-281.409349402846 0.0 -281.409349402846 0.13656446152717197 tensor(181.3924)
-282.784865970826 0.0 -282.784865970826 0.0744670922018435 tensor(182.1504)
-283.26606936989566 0.0 -283.26606936989566 0.02623195903615557 tensor(179.7745)
-282.5064218255832 0.0 -282.5064218255832 0.05677787623545568 tensor(174.7246)
-282.245081479515 0.0 -282.245081479515 0.12671190051680237 tensor(189.0983)
-282.62423284798393 0.0 -282.62423284798393 0.042411842654495095 tensor(176.2400)
-282.1892687908897 0.0 -282.1892687908897 0.0886350361750449 tensor(182.5670)
-283.13641900257517 0.0 -283.13641900257517 0.05

-284.8429597529703 0.0 -284.8429597529703 0.1212049823963425 tensor(188.5751)
-284.9351600879518 0.0 -284.9351600879518 0.04529845050025381 tensor(160.4054)
-284.92479100690605 0.0 -284.92479100690605 0.10299279725705703 tensor(162.8252)
-284.7006045154459 0.0 -284.7006045154459 0.0634903546981934 tensor(184.7761)
-284.23417909309563 0.0 -284.23417909309563 0.06553069828929119 tensor(183.3095)
-284.91079694828477 0.0 -284.91079694828477 0.1650090862175595 tensor(159.2117)
-284.86244739779977 0.0 -284.86244739779977 0.07048635421890902 tensor(166.3495)
-284.7051552926587 0.0 -284.7051552926587 0.05878212832837798 tensor(177.9648)
-284.898315807862 0.0 -284.898315807862 0.09834594584974636 tensor(171.2399)
-284.14618266483353 0.0 -284.14618266483353 0.06698763027745851 tensor(169.6701)
-284.46044108063006 0.0 -284.46044108063006 0.16226879887856469 tensor(175.0261)
-284.91939853223977 0.0 -284.91939853223977 0.06593788277972827 tensor(176.2942)
-285.2429056183637 0.0 -285.2429056183637 0

-284.9924276153818 0.0 -284.9924276153818 0.0682682104973811 tensor(160.3706)
-285.47598577640906 0.0 -285.47598577640906 0.1306804690480359 tensor(164.7346)
-285.774458633553 0.0 -285.774458633553 0.21663605512261588 tensor(181.0084)
-285.30548104891386 0.0 -285.30548104891386 0.16002888497508885 tensor(155.8888)
-285.36912382809015 0.0 -285.36912382809015 0.14315807526410207 tensor(150.5734)
-284.9779429399015 0.0 -284.9779429399015 0.0994562994601455 tensor(170.0647)
-285.701175592648 0.0 -285.701175592648 0.23117299012903894 tensor(173.1442)
-284.8885695860637 0.0 -284.8885695860637 0.021625846217532194 tensor(162.5646)
-285.3294141783652 0.0 -285.3294141783652 0.02891149803817181 tensor(154.7157)
-285.58379607904914 0.0 -285.58379607904914 0.1065485742056109 tensor(175.4341)
-284.86548068856814 0.0 -284.86548068856814 0.1513955935533572 tensor(159.7446)
-285.3089959551804 0.0 -285.3089959551804 0.16908856438189449 tensor(151.7761)
-285.1538291026305 0.0 -285.1538291026305 0.085683

-284.88405082152804 0.0 -284.88405082152804 0.0440837282273959 tensor(179.2421)
-285.7150312671116 0.0 -285.7150312671116 0.16567702480950963 tensor(159.0576)
-285.03933336994277 0.0 -285.03933336994277 0.14370817615085807 tensor(171.1891)
-285.8721464961258 0.0 -285.8721464961258 0.11958246018082601 tensor(168.3437)
-285.8337436501543 0.0 -285.8337436501543 0.05933580034129706 tensor(147.9022)
-285.56063784161165 0.0 -285.56063784161165 0.05754454112307764 tensor(170.3893)
-285.91978820175734 0.0 -285.91978820175734 0.0691287882063723 tensor(161.4032)
-285.51218808909346 0.0 -285.51218808909346 0.1042647469233061 tensor(169.6048)
-286.0137515907934 0.0 -286.0137515907934 0.16950230707171174 tensor(170.5170)
-285.6325431759516 0.0 -285.6325431759516 0.12297660716010761 tensor(171.7185)
-286.22892481212875 0.0 -286.22892481212875 0.035725046026131003 tensor(159.0035)
-285.403712218378 0.0 -285.403712218378 0.07808400725196686 tensor(160.8714)
-285.44471471219066 0.0 -285.44471471219066 

-286.47472173797644 0.0 -286.47472173797644 0.09182258305495207 tensor(149.4296)
-285.9607796128446 0.0 -285.9607796128446 0.18862357261061502 tensor(156.8301)
-285.6111727568119 0.0 -285.6111727568119 0.14114451681140844 tensor(205.5014)
-286.40606666537434 0.0 -286.40606666537434 0.23250984215594372 tensor(180.4642)
-286.3391684910078 0.0 -286.3391684910078 0.027495911775428542 tensor(172.7053)
-286.0323313692354 0.0 -286.0323313692354 0.05365137375771418 tensor(179.8557)
-286.16294886419064 0.0 -286.16294886419064 0.19698324732441733 tensor(208.0853)
-285.56547420220267 0.0 -285.56547420220267 0.2738823271308236 tensor(201.6251)
-286.0119648392606 0.0 -286.0119648392606 0.13935801292991742 tensor(170.5158)
-286.0334230540245 0.0 -286.0334230540245 0.17305758299458665 tensor(186.8160)
-285.78114986768156 0.0 -285.78114986768156 0.22943843882789655 tensor(175.7230)
-285.87816378429034 0.0 -285.87816378429034 0.18345726597562845 tensor(178.9020)
-285.9998484833079 0.0 -285.999848483307

-286.3183407226188 0.0 -286.3183407226188 0.05497635291636582 tensor(168.3196)
-286.39846727923185 0.0 -286.39846727923185 0.02999809625746632 tensor(183.9822)
-285.71793775061946 0.0 -285.71793775061946 0.2697860454202388 tensor(240.4530)
-286.89171004273675 0.0 -286.89171004273675 0.13042178985037003 tensor(190.7966)
-286.5078881198041 0.0 -286.5078881198041 0.1564416181632313 tensor(168.1657)
-286.3552168647048 0.0 -286.3552168647048 0.05647459326396601 tensor(155.8566)
-286.2218199331704 0.0 -286.2218199331704 0.049171846473235054 tensor(191.9443)
-286.1500915658622 0.0 -286.1500915658622 0.1631084134551966 tensor(196.8652)
-286.1225216480007 0.0 -286.1225216480007 0.2755557042262862 tensor(266.8346)
-286.5420371069774 0.0 -286.5420371069774 0.16808801799583012 tensor(164.7570)
-286.1599318492364 0.0 -286.1599318492364 0.0773500092606213 tensor(169.0632)
-286.18675701061915 0.0 -286.18675701061915 0.17755604816627615 tensor(172.2632)
-286.0491820692249 0.0 -286.0491820692249 0.2084

-286.48745329945166 0.0 -286.48745329945166 0.3351923603298981 tensor(183.4778)
-286.4086725866524 0.0 -286.4086725866524 0.2339367033684603 tensor(190.1558)
-286.3786840831285 0.0 -286.3786840831285 0.2852631636108011 tensor(177.7794)
-286.19242061668507 0.0 -286.19242061668507 0.2755282730440587 tensor(170.9211)
-286.3353751015657 0.0 -286.3353751015657 0.20449066532426072 tensor(171.3001)
-286.5412564001293 0.0 -286.5412564001293 0.0903277178347152 tensor(151.1315)
-286.64954102488946 0.0 -286.64954102488946 0.20758501131154133 tensor(154.3481)
-286.5632346647491 0.0 -286.5632346647491 0.13914443749305375 tensor(201.0176)
-286.79489591194033 0.0 -286.79489591194033 0.09587419141806518 tensor(165.4573)
-286.31831843890495 0.0 -286.31831843890495 0.1202599021096059 tensor(149.7972)
-286.66820240355975 0.0 -286.66820240355975 0.39738996218866346 tensor(167.7489)
-286.1537156245861 0.0 -286.1537156245861 0.15127741793434388 tensor(195.1199)
-285.92559520353893 0.0 -285.92559520353893 0.

-286.8727808855105 0.0 -286.8727808855105 0.2119762225984894 tensor(198.1626)
-286.4026814334459 0.0 -286.4026814334459 0.053351743906554165 tensor(174.8487)
-285.7469882067933 0.0 -285.7469882067933 0.18192247011366758 tensor(173.1592)
-286.41491595182043 0.0 -286.41491595182043 0.3231778264226082 tensor(235.9829)
-286.05824123879245 0.0 -286.05824123879245 0.1212738108742854 tensor(178.4911)
-286.06622801777286 0.0 -286.06622801777286 0.13319284648792587 tensor(254.3342)
-285.8847073045669 0.0 -285.8847073045669 0.20306200277060413 tensor(186.1448)
-286.58413832945365 0.0 -286.58413832945365 0.15970677779913178 tensor(197.4611)
-286.220231684355 0.0 -286.220231684355 0.26816636434025265 tensor(163.8824)
-286.2588822341328 0.0 -286.2588822341328 0.24439753779200485 tensor(196.8515)
-286.1987582742004 0.0 -286.1987582742004 0.1692952245667373 tensor(156.9427)
-286.36796379927716 0.0 -286.36796379927716 0.17812086778101455 tensor(197.4068)
-286.47062444364667 0.0 -286.47062444364667 0.3

-286.6543886427441 0.0 -286.6543886427441 0.027139273083411825 tensor(187.2683)
-287.11984830590984 0.0 -287.11984830590984 0.23946253513908966 tensor(180.7781)
-286.44726839079055 0.0 -286.44726839079055 0.2990186078514791 tensor(198.1157)
-286.71132282319854 0.0 -286.71132282319854 0.26285098549513436 tensor(166.8573)
-286.96558113287847 0.0 -286.96558113287847 0.2241727284520199 tensor(194.2638)
-286.04790435054997 0.0 -286.04790435054997 0.10465375528569884 tensor(189.9127)
-286.8215774662575 0.0 -286.8215774662575 0.26091876294329636 tensor(244.6845)
-286.1405829925794 0.0 -286.1405829925794 0.2026252107283514 tensor(181.0103)
-286.83591925113876 0.0 -286.83591925113876 0.15957090316272263 tensor(175.1791)
-286.55747095961726 0.0 -286.55747095961726 0.16052866625494644 tensor(192.5709)
-286.71901436589553 0.0 -286.71901436589553 0.201083817478051 tensor(224.3452)
-286.98951987543694 0.0 -286.98951987543694 0.24446415642835975 tensor(198.9069)
-286.97806731676167 0.0 -286.978067316

In [32]:
flow_model, flow_act = flow_train(param, with_force=True, pre_model=pre_flow_model)
flow_eval(flow_model,flow_act)
flow = flow_model['layers']
# flow.eval()

27.77397873000283 27.77397873000283 -324.7726341120924 0.03714393979584539 tensor(165.4058)
== Era 0 | Epoch 0 metrics ==
	loss 27.774
	force 27.774
	dkl -324.773
	logp 85.8983
	logq -238.874
	ess 0.0371439
26.91407619795939 26.91407619795939 -322.95765631459903 0.04061575398475373 tensor(162.8706)
26.767884668397304 26.767884668397304 -325.36317072662706 0.015734281402194558 tensor(162.3905)
26.24992180370059 26.24992180370059 -324.74301718258414 0.04104593314037602 tensor(160.8240)
26.63525599015806 26.63525599015806 -322.7901953025006 0.04243541695345945 tensor(162.0254)
26.81934919795798 26.81934919795798 -323.8843819176688 0.02651314548844088 tensor(162.5620)
26.19574095421433 26.19574095421433 -320.73880448920755 0.03226041012604028 tensor(160.7309)
25.77215270731929 25.77215270731929 -323.2999835264045 0.018805823847524303 tensor(159.3821)
24.34105313985693 24.34105313985693 -324.60268076663067 0.01619048364981672 tensor(154.8887)
24.221254247091423 24.221254247091423 -323.78034

13.009517298151831 13.009517298151831 -309.56884817605516 0.02295520505408384 tensor(113.0542)
13.144325807431873 13.144325807431873 -309.32251728352537 0.020842333623546448 tensor(113.6733)
13.035029344437012 13.035029344437012 -309.4694201366011 0.0775134300486318 tensor(113.1695)
12.53797674384126 12.53797674384126 -308.5255957202083 0.026206231087198852 tensor(111.0059)
12.458744035060564 12.458744035060564 -307.71903256805024 0.01680695848751508 tensor(110.6865)
12.828874297272916 12.828874297272916 -308.2911372681497 0.023990594704324866 tensor(112.2857)
12.154067392536254 12.154067392536254 -309.25290861185056 0.03259338328569028 tensor(109.2751)
12.497425346803494 12.497425346803494 -306.4734808546465 0.05326524846558484 tensor(110.8241)
12.121203929423121 12.121203929423121 -308.2776705269762 0.03326609600935348 tensor(109.0797)
12.089815614075984 12.089815614075984 -308.0969273621632 0.017238037130384545 tensor(108.9841)
10.806682009773917 10.806682009773917 -309.200360337764

17.1950742795009 17.1950742795009 -300.3809682395894 0.018798142279514943 tensor(128.8142)
16.586543326954015 16.586543326954015 -301.09909198766985 0.03644359916797689 tensor(126.7845)
14.966017056882498 14.966017056882498 -301.97647155671586 0.0354412521795978 tensor(120.1424)
16.20105645907035 16.20105645907035 -301.04257765612573 0.01651509306785345 tensor(124.9642)
15.913252370453009 15.913252370453009 -302.0116778646428 0.039363673419194355 tensor(124.1844)
18.02482124113124 18.02482124113124 -301.4853181661258 0.016110428166420773 tensor(131.8740)
16.333901354846862 16.333901354846862 -302.8924691465151 0.03169177237015928 tensor(125.2017)
17.26626051809003 17.26626051809003 -301.26428201288934 0.02507109899157521 tensor(128.6958)
16.03562904976514 16.03562904976514 -302.9059297992959 0.08735869540297513 tensor(124.1890)
15.95982375499482 15.95982375499482 -302.36641187346163 0.07673066490139471 tensor(123.6649)
18.906658143866064 18.906658143866064 -302.1328158426803 0.03314728

35.35429699620677 35.35429699620677 -309.44301014785674 0.0243454032887555 tensor(182.2585)
38.863100078561736 38.863100078561736 -309.4482762827648 0.043294395329799014 tensor(191.3257)
37.21088919406989 37.21088919406989 -310.04823623395305 0.017376955480751313 tensor(186.8121)
31.79659083495742 31.79659083495742 -308.16741739219106 0.01683312082929542 tensor(172.8128)
36.6395714488692 36.6395714488692 -308.61567090146536 0.06606079056073691 tensor(183.3813)
31.81777167971699 31.81777167971699 -309.4214985995179 0.01729362995133938 tensor(172.2053)
35.108514327720776 35.108514327720776 -309.57791530253974 0.03585271883948045 tensor(180.0924)
34.46002990477125 34.46002990477125 -310.15044766897154 0.038331615698732405 tensor(178.7407)
32.34470350292456 32.34470350292456 -309.5443231926607 0.03439944299601824 tensor(173.3701)
36.300631776551015 36.300631776551015 -308.80787133763357 0.01571988714123884 tensor(183.8771)
37.44997248627826 37.44997248627826 -308.86459734469264 0.015735891

45.0503749247797 45.0503749247797 -315.2507622547365 0.051520419576561234 tensor(197.8630)
44.48686161912643 44.48686161912643 -313.3129579933605 0.03517826725861532 tensor(195.9832)
41.39164805193615 41.39164805193615 -313.4585500331968 0.03712566751817073 tensor(187.8514)
43.12674010588215 43.12674010588215 -314.1976178385104 0.01695097952399715 tensor(192.6038)
41.089460299644536 41.089460299644536 -315.84910682253934 0.025425626763451364 tensor(193.3237)
42.90785973805504 42.90785973805504 -314.42229813320694 0.05627110783066065 tensor(197.1323)
47.346610005043594 47.346610005043594 -315.5797255298402 0.024900452146717694 tensor(199.6550)
43.27006133696745 43.27006133696745 -317.3404148080214 0.03341599882890113 tensor(194.3889)
52.93393678286721 52.93393678286721 -316.8484975028126 0.03360927946728099 tensor(207.4920)
47.768320750398246 47.768320750398246 -317.30574006998086 0.03230070831985299 tensor(203.3396)
45.669868230247765 45.669868230247765 -316.05832950059465 0.0182215418

28.665706219857036 28.665706219857036 -310.7225036254682 0.01616279978166587 tensor(166.3245)
28.807414044628743 28.807414044628743 -310.0172954560383 0.018465405212787705 tensor(169.0318)
27.738665609914108 27.738665609914108 -311.45603862417204 0.054659857930243486 tensor(163.4161)
30.826510014182347 30.826510014182347 -311.52465239370935 0.020131368071415405 tensor(172.3857)
28.619654902453735 28.619654902453735 -311.74545093718905 0.016937443572485673 tensor(166.5080)
28.689098849514288 28.689098849514288 -311.38031249925433 0.03027369128116943 tensor(166.7232)
29.024080298414884 29.024080298414884 -310.2851840886251 0.02310738257491598 tensor(168.2898)
29.788112519105063 29.788112519105063 -311.2054602321869 0.022887066548563133 tensor(170.2704)
32.979183148661 32.979183148661 -311.9624475876771 0.02337092320016644 tensor(177.6737)
29.39557202586504 29.39557202586504 -309.8032782208809 0.019687483380012454 tensor(168.9984)
29.3943589586803 29.3943589586803 -312.827794099047 0.0175

80.63370145848394 80.63370145848394 -321.7941037193 0.08383641438267128 tensor(257.4823)
93.92926239966253 93.92926239966253 -319.9509026116176 0.032339620359379435 tensor(257.3936)
86.87167978159343 86.87167978159343 -320.97831850583844 0.03705830383363055 tensor(241.9349)
71.69264790439364 71.69264790439364 -321.6852656308044 0.01630415082118613 tensor(246.9860)
70.32155824389375 70.32155824389375 -320.0616075782882 0.01661250852837919 tensor(235.5833)
74.0201816891283 74.0201816891283 -321.69650642360114 0.0443225174536213 tensor(241.6833)
73.44863018614755 73.44863018614755 -318.9778896026628 0.04542662808671604 tensor(244.5254)
75.18949280149442 75.18949280149442 -319.21320336154315 0.03184169959659128 tensor(250.2024)
79.97491922107112 79.97491922107112 -321.3646861176413 0.018434647396538106 tensor(249.6901)
85.46661567605835 85.46661567605835 -319.223448820679 0.027911253721062266 tensor(259.7973)
79.34782271636779 79.34782271636779 -320.1823346962834 0.017749103369652235 tenso

181.65608768570743 181.65608768570743 -323.47354436727284 0.045019551157424426 tensor(408.0710)
155.12556107958767 155.12556107958767 -322.78427137797996 0.06257573112442112 tensor(408.0915)
155.70709944210273 155.70709944210273 -322.19702026029216 0.04203302553759664 tensor(386.7660)
154.2939438878576 154.2939438878576 -324.0813765313353 0.025152359792144658 tensor(389.1889)
130.0926802616391 130.0926802616391 -321.81524742004854 0.018315291004942832 tensor(355.8395)
136.7728301757141 136.7728301757141 -323.3256254430799 0.0198683622398773 tensor(384.4378)
135.1534894513842 135.1534894513842 -322.886484779454 0.028554687286129705 tensor(372.9541)
140.6655586488424 140.6655586488424 -322.35172329532105 0.035174959984422015 tensor(360.1203)
141.79888153106856 141.79888153106856 -323.08621018318865 0.027008585033358692 tensor(376.5234)
153.88782916348052 153.88782916348052 -323.11633541314416 0.017297137531767523 tensor(391.0686)
135.27238321189066 135.27238321189066 -322.219090871224 0.

79.4273730494417 79.4273730494417 -318.7759146044166 0.019853609200635405 tensor(280.2026)
71.0194834346879 71.0194834346879 -318.66786138077623 0.016423150801170853 tensor(264.2678)
68.89047978827622 68.89047978827622 -319.26151999400406 0.04181828671497601 tensor(260.9306)
74.27185951222205 74.27185951222205 -317.4243042673273 0.027928298531312908 tensor(271.2451)
== Era 7 | Epoch 0 metrics ==
	loss 99.1425
	force 99.1425
	dkl -320.292
	logp 85.9464
	logq -234.346
	ess 0.0331792
76.70587507555432 76.70587507555432 -318.05710098431655 0.01667322302795106 tensor(275.5978)
78.11218690104243 78.11218690104243 -317.66766468801984 0.018697585256455172 tensor(277.9992)
73.72820878453487 73.72820878453487 -318.6173740349351 0.07366837406687728 tensor(269.7089)
73.95378617606819 73.95378617606819 -320.4279857212785 0.07055327758134822 tensor(269.5216)
69.4738971367191 69.4738971367191 -317.90371647725505 0.030587208108588587 tensor(262.0267)
66.16230736777696 66.16230736777696 -316.9361934751

82.74288480358658 82.74288480358658 -317.2621327873443 0.01622506634547729 tensor(277.2914)
85.99712333859195 85.99712333859195 -318.6204254508401 0.05380546533297986 tensor(287.7774)
76.80904345915064 76.80904345915064 -318.2623810209533 0.05065862122462366 tensor(273.8736)
74.1363738046789 74.1363738046789 -319.6510503410915 0.01577946850353168 tensor(279.6327)
97.52560455945566 97.52560455945566 -317.34839986656965 0.02002376881935342 tensor(301.7239)
88.96878198141486 88.96878198141486 -317.714147514894 0.017438766551435576 tensor(298.7170)
76.86626464391331 76.86626464391331 -319.5871053315606 0.01745455281142562 tensor(273.0981)
74.48266370798734 74.48266370798734 -318.35380740854316 0.03521906030082006 tensor(265.4572)
75.19791365641704 75.19791365641704 -318.3088258606043 0.04763871451653718 tensor(267.1702)
66.34998948782567 66.34998948782567 -318.3620008048792 0.04038686444101792 tensor(258.6650)
71.4511849554269 71.4511849554269 -318.69978020464083 0.05816223576693072 tensor

626.7233868574317 626.7233868574317 -337.36729207625706 0.019293629860757076 tensor(1412.7034)
535.446528799626 535.446528799626 -337.91369129802183 0.038495076068937734 tensor(1231.3895)
489.71455337312864 489.71455337312864 -335.6095387313769 0.021939974128057385 tensor(1237.6834)
426.150472470369 426.150472470369 -334.0369309341992 0.015724482293290314 tensor(1090.7742)
476.75774474507864 476.75774474507864 -333.5088935160295 0.015942482204443963 tensor(1058.9312)
498.2129368254296 498.2129368254296 -334.11756547358436 0.029732072415502797 tensor(1118.3511)
387.17725557713584 387.17725557713584 -329.3743623879847 0.015626432470939734 tensor(967.1091)
386.96467865257284 386.96467865257284 -331.1445942459946 0.01778524344830195 tensor(940.8708)
382.5078208701152 382.5078208701152 -328.9895116918243 0.04039735317903083 tensor(906.9559)
374.6321528398615 374.6321528398615 -328.26740397922913 0.019492240990197154 tensor(928.9940)
312.3595396743618 312.3595396743618 -327.06104198210676 0.

3998.431829674206 3998.431829674206 -331.7322923991422 0.02403784415869733 tensor(3844.9855)
5085.606012190049 5085.606012190049 -333.48727839120147 0.05287876052769839 tensor(4292.3186)
4714.050596369251 4714.050596369251 -331.0224902957007 0.015784129707359516 tensor(4401.0227)
5759.189270904075 5759.189270904075 -333.36919972814206 0.04969036948341942 tensor(4762.5804)
5272.183558920063 5272.183558920063 -334.3072087616177 0.0156250091719157 tensor(5383.0500)
4367.983791367229 4367.983791367229 -336.7480240684247 0.03273619429633967 tensor(5870.0304)
5137.712593458661 5137.712593458661 -336.36943618622126 0.016920905531488662 tensor(6668.5153)
4697.755637820138 4697.755637820138 -339.425695200841 0.0400388511256287 tensor(5686.7754)
6518.90617715728 6518.90617715728 -336.939915781685 0.019000968422224096 tensor(8027.8045)
4634.630899732005 4634.630899732005 -337.9641708094486 0.022934430341241367 tensor(8259.9521)
4310.881377053196 4310.881377053196 -341.34003629509095 0.01568644836

In [33]:
def test_force(x = None):
    model = flow_model
    layers, prior = model['layers'], model['prior']
    if x == None:
        pre_model = pre_flow_model
        pre_layers, pre_prior = pre_model['layers'], pre_model['prior']
        pre_xi = pre_prior.sample_n(1)
        x = ft_flow(pre_layers, pre_xi)
    xi = ft_flow_inv(layers, x)
    f = ft_force(param, layers, xi)
    f_s = torch.linalg.norm(f)
    print(f_s)

test_force()
test_force(field_run)

tensor(238.7274)
tensor(144.2265)


In [34]:
field_run = run(param, field_run[0])
field_run = torch.reshape(field_run,(1,)+field_run.shape)

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 4
tau = 0.3
steps = 8
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.6730993066920293  topo: -1.0
plaq(x) 0.6730993066920293  force.norm 20.228018493084267
Traj:    1  ACCEPT:  dH:  0.010293555  exp(-dH):  0.98975924   plaq:  0.70425209   topo:  0.0
plaq(x) 0.7042520923127957  force.norm 18.926189663613435
Traj:    2  ACCEPT:  dH:  0.0059908377  exp(-dH):  0.99402707   plaq:  0.75833399   topo:  0.0
plaq(x) 0.7583339946270954  force.norm 18.513867027891838
Traj:    3  ACCEPT:  dH: -0.0070925583  exp(-dH):  1.0071178    plaq:  0.72418522   topo:  0.0
plaq(x) 0.7241852154893258  force.norm 19.360503196991594
Traj:    4  ACCEPT:  dH: -0.0021589002  exp(-dH):  1.0021612    plaq:  0.7386098    topo:  0.0
plaq(x) 0.7386097994599291  force.norm 19.821337880741375
Traj:    5  ACCEPT:  dH: -0.0015942778  exp(-dH):  1.0015955    plaq:  0.71898587   topo:  0.0
plaq(x) 0.7189858691427513  force.norm 20.164667252684573
Traj:    6 

In [35]:
flows = flow

field = torch.clone(field_run)

print(f'plaq(field) {action(param, field[0]) / (-param.beta*param.volume)}')
# field.requires_grad_(True)
x = field
logJ = 0.0
for layer in reversed(flows):
    x, lJ = layer.reverse(x)
    logJ += lJ

# x is the prior distribution now
    
x.requires_grad_(True)
    
y = x
logJy = 0.0
for layer in flows:
    y, lJ = layer.forward(y)
    logJy += lJ
    
s = action(param, y[0]) - logJy

print(logJ,logJy)


# print("eff_action", s + 136.3786)

print("original_action", action(param, y[0]) + 91)

print("eff_action", s + 56)

s.backward()

f = x.grad

x.requires_grad_(False)

print(f'plaq(x) {action(param, x[0]) / (-param.beta*param.volume)}  logJ {logJ}  force.norm {torch.linalg.norm(f)}')

print(f'plaq(y) {action(param, y[0]) / (-param.beta*param.volume)}')

print(f'plaq(x) {action(param, field_run[0]) / (-param.beta*param.volume)}  force.norm {torch.linalg.norm(force(param, field_run[0]))}')


plaq(field) 0.6937467217995373
tensor([-16.4901], grad_fn=<AddBackward0>) tensor([16.4901], grad_fn=<AddBackward0>)
original_action tensor(2.2005, grad_fn=<AddBackward0>)
eff_action tensor([-49.2896], grad_fn=<AddBackward0>)
plaq(x) -0.11987797017477399  logJ tensor([-16.4901], grad_fn=<AddBackward0>)  force.norm 163.17526172583203
plaq(y) 0.6937464660899486
plaq(x) 0.6937467217995373  force.norm 19.321995366088096


In [36]:
print(x.shape)
x = ft_flow_inv(flow, field_run)
# x = field_run
#for layer in reversed(flows):
#    x, lJ = layer.reverse(x)
ff = ft_force(param, flow, x)
print(torch.linalg.norm(ff))
fff = ft_force(param, flow, x)
print(torch.linalg.norm(fff))

torch.Size([1, 2, 8, 8])
tensor(163.1753)
tensor(163.1753)


In [37]:
x = ft_flow_inv(flow, field_run)
ft_action(param, flow, x)

tensor([-105.2896], grad_fn=<SubBackward0>)

In [38]:
def ft_leapfrog(param, flow, x, p):
    dt = param.dt
    x_ = x + 0.5*dt*p
    f = ft_force(param, flow, x_)
    p_ = p + (-dt)*f
    print(f'force.norm {torch.linalg.norm(f)} ft_action {float(ft_action(param, flow, x_).detach())} pp_action {0.5*torch.sum(p_*p_)}')
    for i in range(param.nstep-1):
        x_ = x_ + dt*p_
        f = ft_force(param, flow, x_)
        # print(f'force.norm {torch.linalg.norm(f)} ft_action {float(ft_action(param, flow, x_).detach())} pp_action {0.5*torch.sum(p_*p_)}')
        p_ = p_ + (-dt)*f
    x_ = x_ + 0.5*dt*p_
    return (x_, p_)

def ft_hmc(param, flow, field):
    x = ft_flow_inv(flow, field)
    p = torch.randn_like(x)
    act0 = ft_action(param, flow, x).detach() + 0.5*torch.sum(p*p)
    x_, p_ = ft_leapfrog(param, flow, x, p)
    xr = regularize(x_)
    act = ft_action(param, flow, xr).detach() + 0.5*torch.sum(p_*p_)
    prob = torch.rand([], dtype=torch.float64)
    dH = act-act0
    exp_mdH = torch.exp(-dH)
    acc = prob < exp_mdH
    # ADJUST ME
    newx = xr if acc else x
    # newx = xr
    newfield = ft_flow(flow, newx)
    return (float(dH), float(exp_mdH), acc, newfield)

In [39]:
def ft_run(param, flow, field = param.initializer()):
    with open(param.uniquestr(), "w") as O:
        params = param.summary()
        O.write(params)
        put(params)
        plaq, topo = (action(param, field) / (-param.beta*param.volume), topocharge(field))
        status = f"Initial configuration:  plaq: {plaq}  topo: {topo}\n"
        O.write(status)
        put(status)
        ts = []
        for n in range(param.nrun):
            t = -timer()
            for i in range(param.ntraj):
                field_run = torch.reshape(field,(1,)+field.shape)
                dH, exp_mdH, acc, field_run = ft_hmc(param, flow, field_run)
                field = field_run[0]
                plaq = action(param, field) / (-param.beta*param.volume)
                topo = topocharge(field)
                ifacc = "ACCEPT" if acc else "REJECT"
                status = f"Traj: {n*param.ntraj+i+1:4}  {ifacc}:  dH: {dH:< 12.8}  exp(-dH): {exp_mdH:< 12.8}  plaq: {plaq:< 12.8}  topo: {topo:< 3.3}\n"
                O.write(status)
                if (i+1) % (param.ntraj//param.nprint) == 0:
                    put(status)
            t += timer()
            ts.append(t)
        print("Run times: ", ts)
        print("Per trajectory: ", [t/param.ntraj for t in ts])
    return field

In [40]:
param = Param(
    beta = 2.0,
    lat = (8, 8),
    tau = 0.3, # 0.3
    nstep = 8, # 3
    # ADJUST ME
    ntraj = 4, # 2**16 # 2**10 # 2**15
    #
    nprint = 4,
    seed = 1331)

field_run = ft_run(param, pre_flow, field_run[0])
field_run = torch.reshape(field_run,(1,)+field_run.shape)

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 4
tau = 0.3
steps = 8
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.6937467217995373  topo: 2.0
force.norm 27.69617606416228 ft_action -52.508123582432155 pp_action 64.60041730663544
Traj:    1  ACCEPT:  dH: -0.11410416   exp(-dH):  1.1208689    plaq:  0.68258013   topo: -1.0
force.norm 16.447812263770363 ft_action -52.44443526393059 pp_action 59.352502653638425
Traj:    2  ACCEPT:  dH: -0.87384304   exp(-dH):  2.3961015    plaq:  0.58139032   topo:  1.0
force.norm 20.116609343336297 ft_action -55.03639829945146 pp_action 62.855754350205906
Traj:    3  ACCEPT:  dH: -0.23774691   exp(-dH):  1.2683881    plaq:  0.62613819   topo: -1.0
force.norm 35.893419051851865 ft_action -53.659288588988986 pp_action 68.88548817263468
Traj:    4  ACCEPT:  dH:  0.91658479   exp(-dH):  0.3998824    plaq:  0.58430707   topo: -1.0
force.norm 45.50642162021287 ft_action -56.07181051740887 pp_action 70.36270051627473
Traj:    5  REJECT:  

In [41]:
param = Param(
    beta = 2.0,
    lat = (8, 8),
    tau = 0.3, # 0.3
    nstep = 8, # 3
    # ADJUST ME
    ntraj = 4, # 2**16 # 2**10 # 2**15
    #
    nprint = 4,
    seed = 1331)

field_run = ft_run(param, flow, field_run[0])
field_run = torch.reshape(field_run,(1,)+field_run.shape)

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 4
tau = 0.3
steps = 8
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.6665683039476181  topo: -1.0
force.norm 307.9423040019176 ft_action -94.68608880721163 pp_action 102.85123465185504
Traj:    1  REJECT:  dH:  442.48482    exp(-dH):  6.7808737e-193  plaq:  0.6665684    topo: -1.0
force.norm 162.31905425742116 ft_action -97.04324594803019 pp_action 80.23882027527453
Traj:    2  REJECT:  dH:  424.77143    exp(-dH):  3.3427981e-185  plaq:  0.66656831   topo: -1.0
force.norm 101.61936938519396 ft_action -97.67048838705216 pp_action 65.61765389100694
Traj:    3  REJECT:  dH:  403.71406    exp(-dH):  4.6688661e-176  plaq:  0.66656887   topo: -1.0
force.norm 167.2469477034903 ft_action -92.25949067835325 pp_action 72.7943947867188
Traj:    4  REJECT:  dH:  270.36149    exp(-dH):  3.8326354e-118  plaq:  0.66656869   topo: -1.0
force.norm 302.6711325323761 ft_action -92.71316521527926 pp_action 113.13539241805272
Traj:    5  