In [1]:
import torch
import math
import sys
import os
from timeit import default_timer as timer
from functools import reduce
from field_transformation import *

In [2]:
# From Xiao-Yong

class Param:
    def __init__(self, beta = 6.0, lat = [64, 64], tau = 2.0, nstep = 50, ntraj = 256, nrun = 4, nprint = 256, seed = 11*13, randinit = False, nth = int(os.environ.get('OMP_NUM_THREADS', '2')), nth_interop = 2):
        self.beta = beta
        self.lat = lat
        self.nd = len(lat)
        self.volume = reduce(lambda x,y:x*y, lat)
        self.tau = tau
        self.nstep = nstep
        self.dt = self.tau / self.nstep
        self.ntraj = ntraj
        self.nrun = nrun
        self.nprint = nprint
        self.seed = seed
        self.randinit = randinit
        self.nth = nth
        self.nth_interop = nth_interop
    def initializer(self):
        if self.randinit:
            return torch.empty((param.nd,) + param.lat).uniform_(-math.pi, math.pi)
        else:
            return torch.zeros((param.nd,) + param.lat)
    def summary(self):
        return f"""latsize = {self.lat}
volume = {self.volume}
beta = {self.beta}
trajs = {self.ntraj}
tau = {self.tau}
steps = {self.nstep}
seed = {self.seed}
nth = {self.nth}
nth_interop = {self.nth_interop}
"""
    def uniquestr(self):
        lat = ".".join(str(x) for x in self.lat)
        return f"out_l{lat}_b{param.beta}_n{param.ntraj}_t{param.tau}_s{param.nstep}.out"

def action(param, f):
    return (-param.beta)*torch.sum(torch.cos(plaqphase(f)))

def force(param, f):
    f.requires_grad_(True)
    s = action(param, f)
    f.grad = None
    s.backward()
    ff = f.grad
    f.requires_grad_(False)
    return ff

plaqphase = lambda f: f[0,:] - f[1,:] - torch.roll(f[0,:], shifts=-1, dims=1) + torch.roll(f[1,:], shifts=-1, dims=0)
topocharge = lambda f: torch.floor(0.1 + torch.sum(regularize(plaqphase(f))) / (2*math.pi))
def regularize(f):
    p2 = 2*math.pi
    f_ = (f - math.pi) / p2
    return p2*(f_ - torch.floor(f_) - 0.5)

def leapfrog(param, x, p):
    dt = param.dt
    x_ = x + 0.5*dt*p
    f = force(param, x_)
    p_ = p + (-dt)*f
    print(f'plaq(x) {action(param, x) / (-param.beta*param.volume)}  force.norm {torch.linalg.norm(f)}')
    for i in range(param.nstep-1):
        x_ = x_ + dt*p_
        p_ = p_ + (-dt)*force(param, x_)
    x_ = x_ + 0.5*dt*p_
    return (x_, p_)
def hmc(param, x):
    p = torch.randn_like(x)
    act0 = action(param, x) + 0.5*torch.sum(p*p)
    x_, p_ = leapfrog(param, x, p)
    xr = regularize(x_)
    act = action(param, xr) + 0.5*torch.sum(p_*p_)
    prob = torch.rand([], dtype=torch.float64)
    dH = act-act0
    exp_mdH = torch.exp(-dH)
    acc = prob < exp_mdH
    newx = xr if acc else x
    return (dH, exp_mdH, acc, newx)

put = lambda s: sys.stdout.write(s)

In [3]:
param = Param(
    beta = 2.0,
    lat = (8, 8),
    tau = 2, # 0.3
    nstep = 8, # 3
    # ADJUST ME
    ntraj = 2, # 2**16 # 2**10 # 2**15
    #
    nprint = 2,
    seed = 1331)

In [4]:
torch.manual_seed(param.seed)

torch.set_num_threads(param.nth)
torch.set_num_interop_threads(param.nth_interop)
os.environ["OMP_NUM_THREADS"] = str(param.nth)
os.environ["KMP_BLOCKTIME"] = "0"
os.environ["KMP_SETTINGS"] = "1"
os.environ["KMP_AFFINITY"]= "granularity=fine,verbose,compact,1,0"

torch.set_default_tensor_type(torch.DoubleTensor)

In [5]:
def run(param, field = param.initializer()):
    with open(param.uniquestr(), "w") as O:
        params = param.summary()
        O.write(params)
        put(params)
        plaq, topo = (action(param, field) / (-param.beta*param.volume), topocharge(field))
        status = f"Initial configuration:  plaq: {plaq}  topo: {topo}\n"
        O.write(status)
        put(status)
        ts = []
        for n in range(param.nrun):
            t = -timer()
            for i in range(param.ntraj):
                dH, exp_mdH, acc, field = hmc(param, field)
                plaq = action(param, field) / (-param.beta*param.volume)
                topo = topocharge(field)
                ifacc = "ACCEPT" if acc else "REJECT"
                status = f"Traj: {n*param.ntraj+i+1:4}  {ifacc}:  dH: {dH:< 12.8}  exp(-dH): {exp_mdH:< 12.8}  plaq: {plaq:< 12.8}  topo: {topo:< 3.3}\n"
                O.write(status)
                if (i+1) % (param.ntraj//param.nprint) == 0:
                    put(status)
            t += timer()
            ts.append(t)
        print("Run times: ", ts)
        print("Per trajectory: ", [t/param.ntraj for t in ts])
    return field
field = run(param)
field = torch.reshape(field,(1,)+field.shape)
field_run = field

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 2
tau = 2
steps = 8
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 1.0  topo: 0.0
plaq(x) 1.0  force.norm 9.246826593536493
Traj:    1  ACCEPT:  dH: -2.1286122    exp(-dH):  8.4031965    plaq:  0.87165922   topo:  0.0
plaq(x) 0.8716592163190823  force.norm 13.60065719819702
Traj:    2  ACCEPT:  dH: -0.43065065   exp(-dH):  1.5382581    plaq:  0.85113562   topo:  0.0
plaq(x) 0.851135619433136  force.norm 17.54309727438331
Traj:    3  ACCEPT:  dH: -0.52286113   exp(-dH):  1.686847     plaq:  0.81229714   topo:  0.0
plaq(x) 0.8122971355705815  force.norm 17.226420992986522
Traj:    4  ACCEPT:  dH:  0.039753441  exp(-dH):  0.96102636   plaq:  0.78167924   topo:  0.0
plaq(x) 0.7816792404187667  force.norm 17.04668749206334
Traj:    5  ACCEPT:  dH:  0.42993024   exp(-dH):  0.65055447   plaq:  0.80392524   topo:  0.0
plaq(x) 0.803925239737245  force.norm 17.257422008841356
Traj:    6  ACCEPT:  dH: -0.96340329   exp(-dH):  2.62

  Variable._execution_engine.run_backward(


In [6]:
def ft_flow(flow, f):
    for layer in flow:
        f, lJ = layer.forward(f)
    return f.detach()

def ft_flow_inv(flow, f):
    for layer in reversed(flow):
        f, lJ = layer.reverse(f)
    return f.detach()

def ft_action(param, flow, f):
    y = f
    logJy = 0.0
    for layer in flow:
        y, lJ = layer.forward(y)
        logJy += lJ
    action = U1GaugeAction(param.beta)
    s = action(y) - logJy
    return s

def ft_force(param, flow, field, create_graph = False):
    # f is the field follows the transformed distribution (close to prior distribution)
    f = field
    f.requires_grad_(True)
    s = ft_action(param, flow, f)
    ss = torch.sum(s)
    # f.grad = None
    ff, = torch.autograd.grad(ss, f, create_graph = create_graph)
    f.requires_grad_(False)
    return ff

In [47]:
def train_step(model, action, optimizer, metrics, batch_size, with_force = False, pre_model = None):
    layers, prior = model['layers'], model['prior']
    optimizer.zero_grad()
    #
    xi = None
    if pre_model != None:
        pre_layers, pre_prior = pre_model['layers'], pre_model['prior']
        pre_xi = pre_prior.sample_n(batch_size)
        x = ft_flow(pre_layers, pre_xi)
        xi = ft_flow_inv(layers, x)
    #
    xi, x, logq = apply_flow_to_prior(prior, layers, batch_size=batch_size, xi=xi)
    logp = -action(x)
    #
    force_size = torch.tensor(0.0)
    dkl = calc_dkl(logp, logq)
    loss = torch.tensor(0.0)
    if with_force:
        assert pre_model != None
        force = ft_force(param, layers, xi, True)
        force_size = torch.sum(torch.square(force))
        loss = force_size
    else:
        loss = dkl
    #
    loss.backward()
    #
    # minimization target
    # loss mini
    # -> (logq - logp) mini
    # -> (action - logJ) mini
    #
    optimizer.step()
    ess = compute_ess(logp, logq)
    #
    print(grab(loss),
          grab(force_size),
          grab(dkl),
          grab(ess),
          torch.linalg.norm(ft_force(param, layers, xi)))
    #
    metrics['loss'].append(grab(loss))
    metrics['force'].append(grab(force_size))
    metrics['dkl'].append(grab(dkl))
    metrics['logp'].append(grab(logp))
    metrics['logq'].append(grab(logq))
    metrics['ess'].append(grab(ess))

def flow_train(param, with_force = False, pre_model = None):  # packaged from original ipynb by Xiao-Yong Jin
    # Theory
    lattice_shape = param.lat
    link_shape = (2,*param.lat)
    beta = param.beta
    u1_action = U1GaugeAction(beta)
    # Model
    prior = MultivariateUniform(torch.zeros(link_shape), 2*np.pi*torch.ones(link_shape))
    #
    n_layers = 16
    n_s_nets = 2
    hidden_sizes = [8,8]
    kernel_size = 3
    layers = make_u1_equiv_layers(lattice_shape=lattice_shape, n_layers=n_layers, n_mixture_comps=n_s_nets,
                                  hidden_sizes=hidden_sizes, kernel_size=kernel_size)
    set_weights(layers)
    model = {'layers': layers, 'prior': prior}
    # Training
    base_lr = .001
    optimizer = torch.optim.Adam(model['layers'].parameters(), lr=base_lr)
    optimizer_wf = torch.optim.Adam(model['layers'].parameters(), lr=base_lr / 100.0)
    #
    # ADJUST ME
    N_era = 10
    N_epoch = 100
    #
    batch_size = 64
    print_freq = N_epoch # epochs
    plot_freq = 1 # epochs
    history = {
        'loss' : [],
        'force' : [],
        'dkl' : [],
        'logp' : [],
        'logq' : [],
        'ess' : []
    }
    for era in range(N_era):
        for epoch in range(N_epoch):
            train_step(model, u1_action, optimizer, history, batch_size)
            if with_force:
                train_step(model, u1_action, optimizer_wf, history, batch_size,
                           with_force = with_force, pre_model = pre_model)
            if epoch % print_freq == 0:
                print_metrics(history, print_freq, era, epoch)
    return model,u1_action

def flow_eval(model, u1_action):  # packaged from original ipynb by Xiao-Yong Jin
    ensemble_size = 1024
    u1_ens = make_mcmc_ensemble(model, u1_action, 64, ensemble_size)
    print("Accept rate:", np.mean(u1_ens['accepted']))
    Q = grab(topo_charge(torch.stack(u1_ens['x'], axis=0)))
    X_mean, X_err = bootstrap(Q**2, Nboot=100, binsize=16)
    print(f'Topological susceptibility = {X_mean:.2f} +/- {X_err:.2f}')
    print(f'... vs HMC estimate = 1.23 +/- 0.02')

In [8]:
pre_flow_model, flow_act = flow_train(param)
flow_eval(pre_flow_model,flow_act)
pre_flow = pre_flow_model['layers']

-235.34537692516128 0.0 -235.34537692516128 0.044984338671787305 tensor(179.4703)
== Era 0 | Epoch 0 metrics ==
	loss -235.345
	force 0
	dkl -235.345
	logp 0.326149
	logq -235.019
	ess 0.0449843
-239.1383882108473 0.0 -239.1383882108473 0.019970337139737032 tensor(175.4217)
-239.51936855245063 0.0 -239.51936855245063 0.01603148961419505 tensor(175.1459)
-243.8019138166205 0.0 -243.8019138166205 0.015768733372372272 tensor(167.5826)
-246.55102833203563 0.0 -246.55102833203563 0.01712442356040787 tensor(167.7365)
-248.18771779704565 0.0 -248.18771779704565 0.024465609045932344 tensor(162.6325)
-251.1713251794792 0.0 -251.1713251794792 0.015640384689887384 tensor(158.9293)
-253.6046561143978 0.0 -253.6046561143978 0.01766402489483935 tensor(156.8462)
-256.6035089734507 0.0 -256.6035089734507 0.016136360062499903 tensor(158.3880)
-258.2002477654431 0.0 -258.2002477654431 0.02161296312517467 tensor(158.9310)
-260.66395259167905 0.0 -260.66395259167905 0.060672253207140425 tensor(157.5602)
-

-282.2183491919442 0.0 -282.2183491919442 0.035489424739586224 tensor(176.1671)
-281.80745413557906 0.0 -281.80745413557906 0.03223995457395436 tensor(176.2737)
-281.5786820609213 0.0 -281.5786820609213 0.08399880218640958 tensor(186.4566)
-282.62302465971686 0.0 -282.62302465971686 0.03741876243752672 tensor(172.6518)
-282.24374305783067 0.0 -282.24374305783067 0.05084161874060404 tensor(180.6271)
-281.409349402846 0.0 -281.409349402846 0.13656446152717197 tensor(181.3924)
-282.784865970826 0.0 -282.784865970826 0.0744670922018435 tensor(182.1504)
-283.26606936989566 0.0 -283.26606936989566 0.02623195903615557 tensor(179.7745)
-282.5064218255832 0.0 -282.5064218255832 0.05677787623545568 tensor(174.7246)
-282.245081479515 0.0 -282.245081479515 0.12671190051680237 tensor(189.0983)
-282.62423284798393 0.0 -282.62423284798393 0.042411842654495095 tensor(176.2400)
-282.1892687908897 0.0 -282.1892687908897 0.0886350361750449 tensor(182.5670)
-283.13641900257517 0.0 -283.13641900257517 0.05

-284.8429597529703 0.0 -284.8429597529703 0.1212049823963425 tensor(188.5751)
-284.9351600879518 0.0 -284.9351600879518 0.04529845050025381 tensor(160.4054)
-284.92479100690605 0.0 -284.92479100690605 0.10299279725705703 tensor(162.8252)
-284.7006045154459 0.0 -284.7006045154459 0.0634903546981934 tensor(184.7761)
-284.23417909309563 0.0 -284.23417909309563 0.06553069828929119 tensor(183.3095)
-284.91079694828477 0.0 -284.91079694828477 0.1650090862175595 tensor(159.2117)
-284.86244739779977 0.0 -284.86244739779977 0.07048635421890902 tensor(166.3495)
-284.7051552926587 0.0 -284.7051552926587 0.05878212832837798 tensor(177.9648)
-284.898315807862 0.0 -284.898315807862 0.09834594584974636 tensor(171.2399)
-284.14618266483353 0.0 -284.14618266483353 0.06698763027745851 tensor(169.6701)
-284.46044108063006 0.0 -284.46044108063006 0.16226879887856469 tensor(175.0261)
-284.91939853223977 0.0 -284.91939853223977 0.06593788277972827 tensor(176.2942)
-285.2429056183637 0.0 -285.2429056183637 0

-284.9924276153818 0.0 -284.9924276153818 0.0682682104973811 tensor(160.3706)
-285.47598577640906 0.0 -285.47598577640906 0.1306804690480359 tensor(164.7346)
-285.774458633553 0.0 -285.774458633553 0.21663605512261588 tensor(181.0084)
-285.30548104891386 0.0 -285.30548104891386 0.16002888497508885 tensor(155.8888)
-285.36912382809015 0.0 -285.36912382809015 0.14315807526410207 tensor(150.5734)
-284.9779429399015 0.0 -284.9779429399015 0.0994562994601455 tensor(170.0647)
-285.701175592648 0.0 -285.701175592648 0.23117299012903894 tensor(173.1442)
-284.8885695860637 0.0 -284.8885695860637 0.021625846217532194 tensor(162.5646)
-285.3294141783652 0.0 -285.3294141783652 0.02891149803817181 tensor(154.7157)
-285.58379607904914 0.0 -285.58379607904914 0.1065485742056109 tensor(175.4341)
-284.86548068856814 0.0 -284.86548068856814 0.1513955935533572 tensor(159.7446)
-285.3089959551804 0.0 -285.3089959551804 0.16908856438189449 tensor(151.7761)
-285.1538291026305 0.0 -285.1538291026305 0.085683

-284.88405082152804 0.0 -284.88405082152804 0.0440837282273959 tensor(179.2421)
-285.7150312671116 0.0 -285.7150312671116 0.16567702480950963 tensor(159.0576)
-285.03933336994277 0.0 -285.03933336994277 0.14370817615085807 tensor(171.1891)
-285.8721464961258 0.0 -285.8721464961258 0.11958246018082601 tensor(168.3437)
-285.8337436501543 0.0 -285.8337436501543 0.05933580034129706 tensor(147.9022)
-285.56063784161165 0.0 -285.56063784161165 0.05754454112307764 tensor(170.3893)
-285.91978820175734 0.0 -285.91978820175734 0.0691287882063723 tensor(161.4032)
-285.51218808909346 0.0 -285.51218808909346 0.1042647469233061 tensor(169.6048)
-286.0137515907934 0.0 -286.0137515907934 0.16950230707171174 tensor(170.5170)
-285.6325431759516 0.0 -285.6325431759516 0.12297660716010761 tensor(171.7185)
-286.22892481212875 0.0 -286.22892481212875 0.035725046026131003 tensor(159.0035)
-285.403712218378 0.0 -285.403712218378 0.07808400725196686 tensor(160.8714)
-285.44471471219066 0.0 -285.44471471219066 

-286.47472173797644 0.0 -286.47472173797644 0.09182258305495207 tensor(149.4296)
-285.9607796128446 0.0 -285.9607796128446 0.18862357261061502 tensor(156.8301)
-285.6111727568119 0.0 -285.6111727568119 0.14114451681140844 tensor(205.5014)
-286.40606666537434 0.0 -286.40606666537434 0.23250984215594372 tensor(180.4642)
-286.3391684910078 0.0 -286.3391684910078 0.027495911775428542 tensor(172.7053)
-286.0323313692354 0.0 -286.0323313692354 0.05365137375771418 tensor(179.8557)
-286.16294886419064 0.0 -286.16294886419064 0.19698324732441733 tensor(208.0853)
-285.56547420220267 0.0 -285.56547420220267 0.2738823271308236 tensor(201.6251)
-286.0119648392606 0.0 -286.0119648392606 0.13935801292991742 tensor(170.5158)
-286.0334230540245 0.0 -286.0334230540245 0.17305758299458665 tensor(186.8160)
-285.78114986768156 0.0 -285.78114986768156 0.22943843882789655 tensor(175.7230)
-285.87816378429034 0.0 -285.87816378429034 0.18345726597562845 tensor(178.9020)
-285.9998484833079 0.0 -285.999848483307

-286.3183407226188 0.0 -286.3183407226188 0.05497635291636582 tensor(168.3196)
-286.39846727923185 0.0 -286.39846727923185 0.02999809625746632 tensor(183.9822)
-285.71793775061946 0.0 -285.71793775061946 0.2697860454202388 tensor(240.4530)
-286.89171004273675 0.0 -286.89171004273675 0.13042178985037003 tensor(190.7966)
-286.5078881198041 0.0 -286.5078881198041 0.1564416181632313 tensor(168.1657)
-286.3552168647048 0.0 -286.3552168647048 0.05647459326396601 tensor(155.8566)
-286.2218199331704 0.0 -286.2218199331704 0.049171846473235054 tensor(191.9443)
-286.1500915658622 0.0 -286.1500915658622 0.1631084134551966 tensor(196.8652)
-286.1225216480007 0.0 -286.1225216480007 0.2755557042262862 tensor(266.8346)
-286.5420371069774 0.0 -286.5420371069774 0.16808801799583012 tensor(164.7570)
-286.1599318492364 0.0 -286.1599318492364 0.0773500092606213 tensor(169.0632)
-286.18675701061915 0.0 -286.18675701061915 0.17755604816627615 tensor(172.2632)
-286.0491820692249 0.0 -286.0491820692249 0.2084

-286.48745329945166 0.0 -286.48745329945166 0.3351923603298981 tensor(183.4778)
-286.4086725866524 0.0 -286.4086725866524 0.2339367033684603 tensor(190.1558)
-286.3786840831285 0.0 -286.3786840831285 0.2852631636108011 tensor(177.7794)
-286.19242061668507 0.0 -286.19242061668507 0.2755282730440587 tensor(170.9211)
-286.3353751015657 0.0 -286.3353751015657 0.20449066532426072 tensor(171.3001)
-286.5412564001293 0.0 -286.5412564001293 0.0903277178347152 tensor(151.1315)
-286.64954102488946 0.0 -286.64954102488946 0.20758501131154133 tensor(154.3481)
-286.5632346647491 0.0 -286.5632346647491 0.13914443749305375 tensor(201.0176)
-286.79489591194033 0.0 -286.79489591194033 0.09587419141806518 tensor(165.4573)
-286.31831843890495 0.0 -286.31831843890495 0.1202599021096059 tensor(149.7972)
-286.66820240355975 0.0 -286.66820240355975 0.39738996218866346 tensor(167.7489)
-286.1537156245861 0.0 -286.1537156245861 0.15127741793434388 tensor(195.1199)
-285.92559520353893 0.0 -285.92559520353893 0.

-286.8727808855105 0.0 -286.8727808855105 0.2119762225984894 tensor(198.1626)
-286.4026814334459 0.0 -286.4026814334459 0.053351743906554165 tensor(174.8487)
-285.7469882067933 0.0 -285.7469882067933 0.18192247011366758 tensor(173.1592)
-286.41491595182043 0.0 -286.41491595182043 0.3231778264226082 tensor(235.9829)
-286.05824123879245 0.0 -286.05824123879245 0.1212738108742854 tensor(178.4911)
-286.06622801777286 0.0 -286.06622801777286 0.13319284648792587 tensor(254.3342)
-285.8847073045669 0.0 -285.8847073045669 0.20306200277060413 tensor(186.1448)
-286.58413832945365 0.0 -286.58413832945365 0.15970677779913178 tensor(197.4611)
-286.220231684355 0.0 -286.220231684355 0.26816636434025265 tensor(163.8824)
-286.2588822341328 0.0 -286.2588822341328 0.24439753779200485 tensor(196.8515)
-286.1987582742004 0.0 -286.1987582742004 0.1692952245667373 tensor(156.9427)
-286.36796379927716 0.0 -286.36796379927716 0.17812086778101455 tensor(197.4068)
-286.47062444364667 0.0 -286.47062444364667 0.3

-286.6543886427441 0.0 -286.6543886427441 0.027139273083411825 tensor(187.2683)
-287.11984830590984 0.0 -287.11984830590984 0.23946253513908966 tensor(180.7781)
-286.44726839079055 0.0 -286.44726839079055 0.2990186078514791 tensor(198.1157)
-286.71132282319854 0.0 -286.71132282319854 0.26285098549513436 tensor(166.8573)
-286.96558113287847 0.0 -286.96558113287847 0.2241727284520199 tensor(194.2638)
-286.04790435054997 0.0 -286.04790435054997 0.10465375528569884 tensor(189.9127)
-286.8215774662575 0.0 -286.8215774662575 0.26091876294329636 tensor(244.6845)
-286.1405829925794 0.0 -286.1405829925794 0.2026252107283514 tensor(181.0103)
-286.83591925113876 0.0 -286.83591925113876 0.15957090316272263 tensor(175.1791)
-286.55747095961726 0.0 -286.55747095961726 0.16052866625494644 tensor(192.5709)
-286.71901436589553 0.0 -286.71901436589553 0.201083817478051 tensor(224.3452)
-286.98951987543694 0.0 -286.98951987543694 0.24446415642835975 tensor(198.9069)
-286.97806731676167 0.0 -286.978067316

In [48]:
flow_model, flow_act = flow_train(param, with_force=True, pre_model=pre_flow_model)
flow_eval(flow_model,flow_act)
flow = flow_model['layers']
# flow.eval()

-234.1049267240735 0.0 -234.1049267240735 0.029680248119994864 tensor(180.2339)
21739.546722845327 21739.546722845327 -322.0913006781591 0.027003106659758434 tensor(147.3328)
== Era 0 | Epoch 0 metrics ==
	loss 10752.7
	force 10869.8
	dkl -278.098
	logp 42.957
	logq -235.141
	ess 0.0283417
-236.25039963366106 0.0 -236.25039963366106 0.016675070530259576 tensor(176.9464)
21198.78064525335 21198.78064525335 -318.63883044463967 0.028715342619460765 tensor(145.4881)
-238.81326578165553 0.0 -238.81326578165553 0.01881355300787671 tensor(172.0803)
19091.394608603572 19091.394608603572 -317.11527159276955 0.017443370356200617 tensor(138.0659)
-239.66416952878052 0.0 -239.66416952878052 0.018499183480395916 tensor(169.8493)
17234.53732401005 17234.53732401005 -314.9936872239797 0.06508884330417708 tensor(131.1823)
-243.99699154242143 0.0 -243.99699154242143 0.015628975618486952 tensor(168.1516)
16994.97538714464 16994.97538714464 -313.2089804609801 0.03749801879381865 tensor(130.2652)
-245.586

17588.370137844176 17588.370137844176 -292.6676490064457 0.03972894758231132 tensor(132.5929)
-279.6749213162015 0.0 -279.6749213162015 0.02719592167278899 tensor(183.6588)
15682.06365914967 15682.06365914967 -292.66326168998467 0.016069312339225275 tensor(125.1989)
-280.2841202021322 0.0 -280.2841202021322 0.043314077837702615 tensor(181.0820)
16255.22364044438 16255.22364044438 -292.81258065611524 0.07868678126131966 tensor(127.4725)
-279.79595405335385 0.0 -279.79595405335385 0.020976136635223745 tensor(178.0183)
17439.25487882264 17439.25487882264 -292.6060632275021 0.02750860718604779 tensor(132.0270)
-280.75750401167903 0.0 -280.75750401167903 0.05353286558089714 tensor(174.5673)
16901.92675767087 16901.92675767087 -293.5275028540224 0.06581131959413922 tensor(129.9788)
-280.317106220563 0.0 -280.317106220563 0.03203826574969674 tensor(173.0950)
17309.25081036592 17309.25081036592 -292.6488053299377 0.10170481089574196 tensor(131.5376)
-281.1524063854301 0.0 -281.1524063854301 0.

-282.4807059955758 0.0 -282.4807059955758 0.14291288706490882 tensor(178.9289)
18874.459959940825 18874.459959940825 -291.3795182703992 0.07030003236513079 tensor(137.3579)
-282.1063038722565 0.0 -282.1063038722565 0.07005811817688795 tensor(164.6192)
15999.889675571672 15999.889675571672 -291.2544725156252 0.043254657819785124 tensor(126.4677)
-282.34430767812455 0.0 -282.34430767812455 0.07313710620485946 tensor(166.8378)
19204.322144249734 19204.322144249734 -290.7417599014404 0.022355992121180914 tensor(138.5538)
-282.70811692118724 0.0 -282.70811692118724 0.15274205614409392 tensor(170.4010)
19300.292802728596 19300.292802728596 -291.77067320995775 0.08613840755735863 tensor(138.9054)
-281.7883220278575 0.0 -281.7883220278575 0.0593760286293341 tensor(175.5788)
20172.757869619654 20172.757869619654 -290.50538137531873 0.10015470749820889 tensor(141.9935)
-283.2855742424162 0.0 -283.2855742424162 0.02429837068164689 tensor(177.8057)
18925.014204152798 18925.014204152798 -291.071907

-283.82997813525446 0.0 -283.82997813525446 0.047971942496735955 tensor(162.2761)
20804.782074619536 20804.782074619536 -289.44289218399405 0.06847707611261382 tensor(144.1516)
-284.7385084301931 0.0 -284.7385084301931 0.11174980110338288 tensor(166.9953)
22501.828172289876 22501.828172289876 -289.62507885055743 0.10384940232615705 tensor(149.8188)
-283.8451274162702 0.0 -283.8451274162702 0.08128298190489573 tensor(174.0630)
18415.706056955838 18415.706056955838 -289.9874309181903 0.031027101025085442 tensor(135.6195)
-283.7715257288183 0.0 -283.7715257288183 0.10132101029586715 tensor(169.6645)
19033.440745112046 19033.440745112046 -289.84569041454176 0.02308953856057568 tensor(137.8827)
-283.7548951755654 0.0 -283.7548951755654 0.04441672174412427 tensor(159.3728)
20058.049617360866 20058.049617360866 -289.25959598637985 0.10853510406708548 tensor(141.4752)
-283.94828615480503 0.0 -283.94828615480503 0.09967296532172916 tensor(179.2638)
17909.82804161143 17909.82804161143 -289.71119

21702.097014781455 21702.097014781455 -288.90719782226654 0.03420746491515093 tensor(147.2025)
-284.2908859607855 0.0 -284.2908859607855 0.03028627110814998 tensor(160.1179)
22008.040140366575 22008.040140366575 -288.1981431112241 0.13246143640640562 tensor(148.1549)
-284.80801591290987 0.0 -284.80801591290987 0.08562709695745312 tensor(163.1786)
18980.43848593627 18980.43848593627 -288.52334922439263 0.06203418940906437 tensor(137.6249)
-284.9972138922167 0.0 -284.9972138922167 0.05267259035368084 tensor(152.0160)
21442.29901169374 21442.29901169374 -288.6456242339548 0.1515201676585606 tensor(146.2547)
-284.35336359711584 0.0 -284.35336359711584 0.13939284782777597 tensor(154.4785)
22504.364843388328 22504.364843388328 -288.56804675856284 0.06255730633829308 tensor(149.8352)
-284.5409371831063 0.0 -284.5409371831063 0.12389276316335401 tensor(147.0434)
22246.156831054708 22246.156831054708 -288.46771428478314 0.11106726731121917 tensor(148.8996)
-284.4790251101251 0.0 -284.4790251101

24656.225097452232 24656.225097452232 -288.13103980136066 0.12892492053704682 tensor(156.8658)
-285.3239273678372 0.0 -285.3239273678372 0.10935601480289914 tensor(176.9002)
20339.993140699466 20339.993140699466 -288.67041560794354 0.14771263925462197 tensor(142.3877)
-285.1149211116293 0.0 -285.1149211116293 0.06907624891136874 tensor(170.7116)
21019.88433495773 21019.88433495773 -288.1622083315353 0.08650820532043481 tensor(144.8833)
-285.1543752442691 0.0 -285.1543752442691 0.16410688611091423 tensor(160.9448)
21237.712172664877 21237.712172664877 -288.160092435269 0.20930462667707614 tensor(145.5633)
-285.98261428168223 0.0 -285.98261428168223 0.05253283680017044 tensor(164.3468)
21623.971805392794 21623.971805392794 -288.0335244809429 0.15737441500017496 tensor(146.9096)
-285.22544960577426 0.0 -285.22544960577426 0.08252912298224813 tensor(162.6209)
18729.03645649703 18729.03645649703 -288.5359742131693 0.1355635084578931 tensor(136.6928)
-284.73737064622935 0.0 -284.737370646229

-285.6395679446849 0.0 -285.6395679446849 0.189803658085953 tensor(177.2715)
23736.039197236834 23736.039197236834 -288.59030441328684 0.07892917846152321 tensor(153.8006)
-285.6286233196395 0.0 -285.6286233196395 0.16739066414144196 tensor(173.8059)
22934.583554890472 22934.583554890472 -287.74821344971707 0.03177230089077971 tensor(151.2852)
-285.2912541019243 0.0 -285.2912541019243 0.07983554968642283 tensor(178.0631)
26542.771042466586 26542.771042466586 -288.31325754061754 0.0995507139604263 tensor(162.7172)
-285.1016305577672 0.0 -285.1016305577672 0.1623297537290794 tensor(181.4990)
21630.029392891203 21630.029392891203 -288.0866142751 0.01931296062160344 tensor(146.8889)
-285.1569965750868 0.0 -285.1569965750868 0.06010995764032584 tensor(156.1860)
21006.32085844381 21006.32085844381 -288.504727560447 0.11393078347306904 tensor(144.7551)
-285.795609027149 0.0 -285.795609027149 0.027051541823417815 tensor(153.2983)
22420.83898399689 22420.83898399689 -287.9758584628669 0.1413625

-285.59863099796934 0.0 -285.59863099796934 0.13172428852815715 tensor(176.6647)
23816.473791843986 23816.473791843986 -288.0307338032169 0.2977903026770003 tensor(154.1471)
-284.76637098379933 0.0 -284.76637098379933 0.15670106370053438 tensor(175.5907)
22732.759980873587 22732.759980873587 -287.4613813540058 0.11122707331272921 tensor(150.6735)
-285.77892874104487 0.0 -285.77892874104487 0.2708603443417837 tensor(137.7887)
34535.30546739172 34535.30546739172 -287.403838543594 0.16384373704639865 tensor(185.3063)
-285.4465654882779 0.0 -285.4465654882779 0.22951208420258104 tensor(184.8705)
26290.907475603395 26290.907475603395 -287.5615606527141 0.06460463555810093 tensor(161.9541)
-285.43206028636763 0.0 -285.43206028636763 0.07068949924083168 tensor(183.9613)
28280.369113625653 28280.369113625653 -287.44167170966676 0.15340307835844244 tensor(167.8564)
-285.9161366764289 0.0 -285.9161366764289 0.05670197833011408 tensor(173.1990)
31535.870940305795 31535.870940305795 -287.684144334

23496.128615607646 23496.128615607646 -287.4481114099446 0.2037598367915497 tensor(153.0783)
-285.43601816019304 0.0 -285.43601816019304 0.12496777196688955 tensor(163.9036)
23564.78970301167 23564.78970301167 -287.78969334025265 0.09461456266220286 tensor(153.2727)
-285.80969831550976 0.0 -285.80969831550976 0.1775223495473102 tensor(162.5509)
33606.5352045975 33606.5352045975 -287.8478515873164 0.14385998116781806 tensor(182.9608)
-285.46710983497337 0.0 -285.46710983497337 0.12158987711490679 tensor(164.1634)
24154.99628226481 24154.99628226481 -287.5382734324858 0.09478677140045956 tensor(155.0495)
-285.8822971087269 0.0 -285.8822971087269 0.2393002340281489 tensor(190.0392)
25044.70292029891 25044.70292029891 -287.78831998281026 0.07402751249121538 tensor(158.0938)
-285.64353078112424 0.0 -285.64353078112424 0.08923302448109266 tensor(175.4746)
24618.65469031376 24618.65469031376 -288.00020733887624 0.1441483439514376 tensor(156.8475)
-285.35107343268373 0.0 -285.35107343268373 0.

27267.170536149373 27267.170536149373 -287.822280009254 0.19437015917453287 tensor(164.5239)
-286.0485895467856 0.0 -286.0485895467856 0.11746598262739422 tensor(167.7474)
22585.878752978006 22585.878752978006 -287.3744201194878 0.08114590138906702 tensor(149.7327)
-286.04754452420906 0.0 -286.04754452420906 0.06303136790692598 tensor(180.7571)
24489.9455959846 24489.9455959846 -287.39209731726265 0.13621361023841255 tensor(156.3169)
-285.84668254089706 0.0 -285.84668254089706 0.07215850768320632 tensor(165.3823)
23951.959337591346 23951.959337591346 -287.8623283728598 0.06858447059207005 tensor(154.4725)
-285.5275074411133 0.0 -285.5275074411133 0.032770782273452266 tensor(232.8009)
36542.526883288454 36542.526883288454 -287.9535929252522 0.11639888712301497 tensor(190.4938)
-285.9424930221997 0.0 -285.9424930221997 0.16558907095922815 tensor(172.6666)
27521.4515012092 27521.4515012092 -287.28742977037814 0.13911951482172863 tensor(165.7849)
-286.02294593338485 0.0 -286.02294593338485

-285.97445514215053 0.0 -285.97445514215053 0.059554076657697415 tensor(180.0116)
22371.83515530373 22371.83515530373 -287.7409068121698 0.1342988447370125 tensor(149.3614)
-286.22677254588746 0.0 -286.22677254588746 0.06454372723056026 tensor(176.9072)
34288.92797143833 34288.92797143833 -287.67135476770585 0.2235014042312175 tensor(183.8168)
-285.4158049718267 0.0 -285.4158049718267 0.034222366902710015 tensor(199.4721)
30027.126975443884 30027.126975443884 -287.57095624136986 0.1925138446756468 tensor(172.9297)
-286.0944798970328 0.0 -286.0944798970328 0.08213694727213951 tensor(175.4336)
31668.292281166778 31668.292281166778 -287.58677210733333 0.2638756605798576 tensor(176.4767)
-285.92638374620395 0.0 -285.92638374620395 0.0744843127845293 tensor(201.6243)
25434.210584188637 25434.210584188637 -287.8002671469246 0.22105251197373663 tensor(159.2399)
-285.8314811429243 0.0 -285.8314811429243 0.07871473160117777 tensor(163.6023)
29800.301620626647 29800.301620626647 -287.45714386515

-285.745570847045 0.0 -285.745570847045 0.12995129658382154 tensor(169.6874)
24216.906951959943 24216.906951959943 -287.67580792856074 0.06652737325669816 tensor(155.4655)
-285.9054902084646 0.0 -285.9054902084646 0.13675028014015841 tensor(158.5349)
33252.09411846067 33252.09411846067 -287.6150014927216 0.2924818945567518 tensor(181.9930)
-285.65118362575487 0.0 -285.65118362575487 0.22427391307447153 tensor(181.6058)
26831.449106897424 26831.449106897424 -287.58708672934705 0.289729505852617 tensor(163.6063)
-285.5671973616445 0.0 -285.5671973616445 0.0880381513419362 tensor(175.3864)
32142.528279795584 32142.528279795584 -287.6141296660113 0.11654521534727667 tensor(179.0673)
-285.94071900798883 0.0 -285.94071900798883 0.05977913995315427 tensor(165.9598)
27127.713258536165 27127.713258536165 -287.61442518039 0.10999137021650045 tensor(164.2598)
-286.41471527284045 0.0 -286.41471527284045 0.12210632488142013 tensor(188.7223)
48177.68207024736 48177.68207024736 -287.4012870978532 0.1

-286.2235438621625 0.0 -286.2235438621625 0.22765072777677584 tensor(202.2905)
35955.046630263096 35955.046630263096 -287.33926656206376 0.22959791234936439 tensor(189.4094)
-286.1721994062106 0.0 -286.1721994062106 0.16527123515734518 tensor(178.1206)
34524.16431815109 34524.16431815109 -287.42589842899906 0.22856287852671042 tensor(183.3455)
-286.0502086442439 0.0 -286.0502086442439 0.12062674883193604 tensor(183.5395)
38328.59975203832 38328.59975203832 -287.74774645907166 0.04716791104594132 tensor(195.6222)
-286.20479811234543 0.0 -286.20479811234543 0.2185418639218416 tensor(190.2509)
31570.518437935465 31570.518437935465 -287.156914989274 0.16167680825184313 tensor(176.9201)
-286.3282387121558 0.0 -286.3282387121558 0.2138323577468179 tensor(206.3267)
28107.99129046846 28107.99129046846 -286.68063373601194 0.15138664632274357 tensor(167.4656)
-285.47692521268993 0.0 -285.47692521268993 0.05135345704717133 tensor(183.7330)
31926.99649568144 31926.99649568144 -287.77429082430234 0

-285.90159766650925 0.0 -285.90159766650925 0.18689555263277136 tensor(178.6400)
29302.24325664845 29302.24325664845 -287.52641541521496 0.2977202533024342 tensor(171.0138)
-286.1957280156796 0.0 -286.1957280156796 0.15390818356930158 tensor(217.5358)
35480.979695792565 35480.979695792565 -287.5686647133613 0.23205930868124008 tensor(187.9015)
-285.9958945482799 0.0 -285.9958945482799 0.056247031466104214 tensor(160.1315)
35972.52858495606 35972.52858495606 -287.43726443657 0.20142393295715794 tensor(189.3919)
-285.78867696698524 0.0 -285.78867696698524 0.22767975063345652 tensor(153.9529)
33611.773077429316 33611.773077429316 -287.3658815822519 0.10355706775763862 tensor(182.9252)
-285.80994790602233 0.0 -285.80994790602233 0.07955622489641592 tensor(173.3980)
37548.75162216857 37548.75162216857 -287.3611852461448 0.21728268277673488 tensor(192.6393)
-285.84140271642576 0.0 -285.84140271642576 0.0484315922181892 tensor(173.5173)
31248.079954687953 31248.079954687953 -287.0932108033753

-286.1228533711768 0.0 -286.1228533711768 0.20091007105013536 tensor(181.7849)
27861.85487746679 27861.85487746679 -287.3324030819549 0.16021709811163998 tensor(166.6279)
-286.24402854848114 0.0 -286.24402854848114 0.17020675754593093 tensor(215.6605)
31748.558512302072 31748.558512302072 -287.3033418458452 0.24937818626950012 tensor(178.0052)
-286.0685170094599 0.0 -286.0685170094599 0.20811368222901414 tensor(169.1880)
30150.180079978883 30150.180079978883 -287.6739167553884 0.24212361002212002 tensor(173.4069)
-285.94365649697977 0.0 -285.94365649697977 0.09220945835519559 tensor(200.0513)
45637.1730235642 45637.1730235642 -287.1572303821458 0.22399631456998806 tensor(213.2759)
-285.8758881661748 0.0 -285.8758881661748 0.1362172164367375 tensor(206.4420)
42746.899630124 42746.899630124 -287.34677683263936 0.03268972445234289 tensor(206.3393)
-286.8380397636411 0.0 -286.8380397636411 0.13474220554130567 tensor(184.1085)
34625.16100218191 34625.16100218191 -287.3629368817866 0.0277843

-286.09707680826546 0.0 -286.09707680826546 0.1106302868761784 tensor(191.3049)
56158.08276143624 56158.08276143624 -287.33248896354763 0.17151008890692332 tensor(235.8843)
-286.2763957518143 0.0 -286.2763957518143 0.0748618011582997 tensor(173.2954)
29247.91350033273 29247.91350033273 -287.3798074063824 0.31458768531878684 tensor(170.7359)
-286.46878911456577 0.0 -286.46878911456577 0.10756301082849375 tensor(195.2921)
31950.826712480135 31950.826712480135 -287.25828799371 0.23999525418729828 tensor(178.6188)
-286.4451307236196 0.0 -286.4451307236196 0.08525938452059076 tensor(173.7399)
32422.722634581118 32422.722634581118 -286.9740522737057 0.13682637306144496 tensor(179.8348)
-286.32036925142927 0.0 -286.32036925142927 0.07369280810878871 tensor(215.0411)
27548.18381287833 27548.18381287833 -286.8277954139902 0.12395876528832371 tensor(165.8498)
-286.5535191877483 0.0 -286.5535191877483 0.07489620142065255 tensor(220.1715)
28740.270232658324 28740.270232658324 -287.4239030591767 0.

-286.2359568478972 0.0 -286.2359568478972 0.10964535065922734 tensor(157.0696)
27409.379678261677 27409.379678261677 -287.3206645629312 0.32318859142586054 tensor(165.3119)
-286.28346193704033 0.0 -286.28346193704033 0.05906125254538689 tensor(171.4154)
29310.595723419246 29310.595723419246 -287.3598031382814 0.0605816192474873 tensor(171.0425)
-286.5760376439953 0.0 -286.5760376439953 0.15632470155736997 tensor(191.8697)
41499.76400288858 41499.76400288858 -287.33367458393377 0.15504086937714237 tensor(203.3443)
-285.9869401828669 0.0 -285.9869401828669 0.1415508110885771 tensor(210.4469)
33686.596503595705 33686.596503595705 -287.2998433597515 0.24211596955203085 tensor(183.2082)
-285.805362331583 0.0 -285.805362331583 0.179503995724682 tensor(208.8126)
24442.3698976447 24442.3698976447 -287.4576678137736 0.10239279448244755 tensor(156.3107)
-286.3368579453281 0.0 -286.3368579453281 0.27322939931185963 tensor(175.6816)
30684.58749871014 30684.58749871014 -286.6908453228343 0.21343035

-285.8377515716842 0.0 -285.8377515716842 0.04141448428140166 tensor(178.9268)
42673.94409703878 42673.94409703878 -287.33638112160486 0.13621789372501697 tensor(205.5529)
-286.137552571873 0.0 -286.137552571873 0.12516057304832987 tensor(181.7403)
36508.47572104824 36508.47572104824 -287.37130530352044 0.17573850386569517 tensor(190.7566)
-286.61299861148495 0.0 -286.61299861148495 0.19579955516300773 tensor(212.6949)
26390.808983115578 26390.808983115578 -287.7990417212087 0.07764392460491559 tensor(162.3233)
-286.0715727385889 0.0 -286.0715727385889 0.2342834383202294 tensor(171.3243)
53327.87200412627 53327.87200412627 -287.2365082711774 0.1552429499607073 tensor(224.8826)
-286.6631815370057 0.0 -286.6631815370057 0.17243420640911467 tensor(166.8250)
30738.7720376419 30738.7720376419 -287.684371391425 0.18899359095382198 tensor(174.7737)
-286.44360914823034 0.0 -286.44360914823034 0.031237464151758675 tensor(200.6415)
45609.89175685283 45609.89175685283 -287.47102205451506 0.140361

-286.79296850267394 0.0 -286.79296850267394 0.19348807319933753 tensor(196.8218)
31341.43033944954 31341.43033944954 -287.5824977046474 0.10536654463340854 tensor(176.8835)
-285.9632298073358 0.0 -285.9632298073358 0.07833071108448085 tensor(191.6170)
31737.807431841688 31737.807431841688 -287.3611050497205 0.23085085291971794 tensor(177.9197)
-285.8109582028708 0.0 -285.8109582028708 0.027463723952055596 tensor(182.2287)
36485.5763231229 36485.5763231229 -287.35652949505925 0.15765525921140866 tensor(190.6938)
-286.27116976124637 0.0 -286.27116976124637 0.3221139845999118 tensor(201.3813)
27782.214976860676 27782.214976860676 -287.1734248250954 0.06173815759191047 tensor(166.4287)
-286.27748108117595 0.0 -286.27748108117595 0.31798676961827826 tensor(201.4387)
49317.4147689144 49317.4147689144 -287.0326686383324 0.20439823830244858 tensor(221.1722)
-286.4122715389313 0.0 -286.4122715389313 0.21890462467772417 tensor(186.2884)
42321.65500986652 42321.65500986652 -287.2384996215387 0.08

39210.02350242318 39210.02350242318 -287.6695826325291 0.27356612072676656 tensor(197.9065)
-285.7981433249939 0.0 -285.7981433249939 0.16585147772563394 tensor(255.2066)
55649.9282890732 55649.9282890732 -286.78489621990127 0.254690854443281 tensor(235.4638)
== Era 9 | Epoch 0 metrics ==
	loss 18215.8
	force 18359
	dkl -286.768
	logp 85.373
	logq -201.395
	ess 0.174072
-286.30325503827373 0.0 -286.30325503827373 0.03146893137159631 tensor(157.7512)
41537.63011911955 41537.63011911955 -287.3079504167587 0.16526018926782796 tensor(203.5865)
-286.4902241704207 0.0 -286.4902241704207 0.23878832786003606 tensor(235.4229)
50562.77908111176 50562.77908111176 -287.1716450249444 0.18346941388190582 tensor(224.3691)
-286.26361473988004 0.0 -286.26361473988004 0.14437712275654443 tensor(186.7813)
64993.88393393259 64993.88393393259 -286.5587908664894 0.2591140299459224 tensor(243.8630)
-286.02087010326596 0.0 -286.02087010326596 0.09971441904731501 tensor(202.8669)
27604.310587212436 27604.31058

25746.36169843235 25746.36169843235 -287.2484758276016 0.373895050656021 tensor(160.3458)
-286.712312258325 0.0 -286.712312258325 0.11875125104990862 tensor(180.0188)
36850.968136521115 36850.968136521115 -287.3796565195601 0.07393183349814654 tensor(191.7631)
-286.2694054412454 0.0 -286.2694054412454 0.21447021784109124 tensor(204.8493)
29654.170343633254 29654.170343633254 -287.6464203264579 0.23469756422480018 tensor(172.0753)
-286.38845607103383 0.0 -286.38845607103383 0.1560179989112319 tensor(207.2412)
62920.564336726435 62920.564336726435 -287.3375330934416 0.2244569457149788 tensor(249.9557)
-286.27579895673057 0.0 -286.27579895673057 0.12700357631360154 tensor(176.9656)
29848.355580832 29848.355580832 -287.14108285463203 0.12494974027693473 tensor(172.5053)
-285.92783684029166 0.0 -285.92783684029166 0.23943697686574533 tensor(188.6402)
29715.16796664685 29715.16796664685 -287.2579958187305 0.22036970083795437 tensor(172.1306)
-286.5256944875374 0.0 -286.5256944875374 0.056236

-286.04729253927485 0.0 -286.04729253927485 0.14588488651847623 tensor(219.8661)
71274.64995170795 71274.64995170795 -287.60906938879054 0.10463679773317969 tensor(263.5270)
-286.1804407824062 0.0 -286.1804407824062 0.22046740814903004 tensor(142.8677)
31684.28190448216 31684.28190448216 -287.5192731156961 0.10059094230338218 tensor(177.9007)
-286.4349524231052 0.0 -286.4349524231052 0.15850758652430957 tensor(210.2154)
39606.21084388472 39606.21084388472 -286.836283606641 0.2560664678509775 tensor(198.7780)
-286.3793708344189 0.0 -286.3793708344189 0.06251126071215006 tensor(204.4644)
46263.70511826304 46263.70511826304 -286.9551933002819 0.09694581754069422 tensor(214.1238)
-286.299130801615 0.0 -286.299130801615 0.14051959049435053 tensor(213.2960)
32273.74033850515 32273.74033850515 -287.0556131740846 0.22731382138810036 tensor(179.4479)
-286.4931108070857 0.0 -286.4931108070857 0.3055842500500312 tensor(185.2134)
29106.969331044256 29106.969331044256 -287.1909607774686 0.308029304

In [56]:
def test_force(x = None):
    model = flow_model
    layers, prior = model['layers'], model['prior']
    if x == None:
        pre_model = pre_flow_model
        pre_layers, pre_prior = pre_model['layers'], pre_model['prior']
        pre_xi = pre_prior.sample_n(1)
        x = ft_flow(pre_layers, pre_xi)
    xi = ft_flow_inv(layers, x)
    f = ft_force(param, layers, xi)
    f_s = torch.linalg.norm(f)
    print(f_s)

test_force()
test_force(field_run)

tensor(13.9890)
tensor(27.7164)


In [55]:
field_run = run(param, field_run[0])
field_run = torch.reshape(field_run,(1,)+field_run.shape)

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 4
tau = 0.3
steps = 8
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.7437021844970354  topo: -2.0
plaq(x) 0.7437021844970354  force.norm 17.277193617757607
Traj:    1  ACCEPT:  dH: -0.0023679755  exp(-dH):  1.0023708    plaq:  0.74241928   topo: -2.0
plaq(x) 0.7424192800643952  force.norm 17.589645411352034
Traj:    2  ACCEPT:  dH:  0.0061712245  exp(-dH):  0.99384778   plaq:  0.74241454   topo: -2.0
plaq(x) 0.742414535905497  force.norm 17.248077164748352
Traj:    3  ACCEPT:  dH: -0.024960518  exp(-dH):  1.0252746    plaq:  0.67577322   topo: -2.0
plaq(x) 0.6757732248925473  force.norm 20.3000956970994
Traj:    4  ACCEPT:  dH:  0.010877182  exp(-dH):  0.98918176   plaq:  0.70781983   topo: -2.0
plaq(x) 0.7078198267301161  force.norm 18.939517607204863
Traj:    5  ACCEPT:  dH:  0.016930222  exp(-dH):  0.98321229   plaq:  0.74959694   topo: -2.0
plaq(x) 0.7495969360621778  force.norm 17.578089619545146
Traj:    6  ACCE

In [57]:
flows = flow

field = torch.clone(field_run)

print(f'plaq(field) {action(param, field[0]) / (-param.beta*param.volume)}')
# field.requires_grad_(True)
x = field
logJ = 0.0
for layer in reversed(flows):
    x, lJ = layer.reverse(x)
    logJ += lJ

# x is the prior distribution now
    
x.requires_grad_(True)
    
y = x
logJy = 0.0
for layer in flows:
    y, lJ = layer.forward(y)
    logJy += lJ
    
s = action(param, y[0]) - logJy

print(logJ,logJy)


# print("eff_action", s + 136.3786)

print("original_action", action(param, y[0]) + 91)

print("eff_action", s + 56)

s.backward()

f = x.grad

x.requires_grad_(False)

print(f'plaq(x) {action(param, x[0]) / (-param.beta*param.volume)}  logJ {logJ}  force.norm {torch.linalg.norm(f)}')

print(f'plaq(y) {action(param, y[0]) / (-param.beta*param.volume)}')

print(f'plaq(x) {action(param, field_run[0]) / (-param.beta*param.volume)}  force.norm {torch.linalg.norm(force(param, field_run[0]))}')


plaq(field) 0.6650769360835047
tensor([31.3118], grad_fn=<AddBackward0>) tensor([-31.3118], grad_fn=<AddBackward0>)
original_action tensor(5.8702, grad_fn=<AddBackward0>)
eff_action tensor([2.1819], grad_fn=<AddBackward0>)
plaq(x) -0.005778317209685344  logJ tensor([31.3118], grad_fn=<AddBackward0>)  force.norm 27.7163554152893
plaq(y) 0.6650768830501995
plaq(x) 0.6650769360835047  force.norm 16.994930571703513


In [58]:
print(x.shape)
x = ft_flow_inv(flow, field_run)
# x = field_run
#for layer in reversed(flows):
#    x, lJ = layer.reverse(x)
ff = ft_force(param, flow, x)
print(torch.linalg.norm(ff))
fff = ft_force(param, flow, x)
print(torch.linalg.norm(fff))

torch.Size([1, 2, 8, 8])
tensor(27.7164)
tensor(27.7164)


In [59]:
x = ft_flow_inv(flow, field_run)
ft_action(param, flow, x)

tensor([-53.8181], grad_fn=<SubBackward0>)

In [60]:
def ft_leapfrog(param, flow, x, p):
    dt = param.dt
    x_ = x + 0.5*dt*p
    f = ft_force(param, flow, x_)
    p_ = p + (-dt)*f
    print(f'force.norm {torch.linalg.norm(f)} ft_action {float(ft_action(param, flow, x_).detach())} pp_action {0.5*torch.sum(p_*p_)}')
    for i in range(param.nstep-1):
        x_ = x_ + dt*p_
        f = ft_force(param, flow, x_)
        # print(f'force.norm {torch.linalg.norm(f)} ft_action {float(ft_action(param, flow, x_).detach())} pp_action {0.5*torch.sum(p_*p_)}')
        p_ = p_ + (-dt)*f
    x_ = x_ + 0.5*dt*p_
    return (x_, p_)

def ft_hmc(param, flow, field):
    x = ft_flow_inv(flow, field)
    p = torch.randn_like(x)
    act0 = ft_action(param, flow, x).detach() + 0.5*torch.sum(p*p)
    x_, p_ = ft_leapfrog(param, flow, x, p)
    xr = regularize(x_)
    act = ft_action(param, flow, xr).detach() + 0.5*torch.sum(p_*p_)
    prob = torch.rand([], dtype=torch.float64)
    dH = act-act0
    exp_mdH = torch.exp(-dH)
    acc = prob < exp_mdH
    # ADJUST ME
    newx = xr if acc else x
    # newx = xr
    newfield = ft_flow(flow, newx)
    return (float(dH), float(exp_mdH), acc, newfield)

In [61]:
def ft_run(param, flow, field = param.initializer()):
    with open(param.uniquestr(), "w") as O:
        params = param.summary()
        O.write(params)
        put(params)
        plaq, topo = (action(param, field) / (-param.beta*param.volume), topocharge(field))
        status = f"Initial configuration:  plaq: {plaq}  topo: {topo}\n"
        O.write(status)
        put(status)
        ts = []
        for n in range(param.nrun):
            t = -timer()
            for i in range(param.ntraj):
                field_run = torch.reshape(field,(1,)+field.shape)
                dH, exp_mdH, acc, field_run = ft_hmc(param, flow, field_run)
                field = field_run[0]
                plaq = action(param, field) / (-param.beta*param.volume)
                topo = topocharge(field)
                ifacc = "ACCEPT" if acc else "REJECT"
                status = f"Traj: {n*param.ntraj+i+1:4}  {ifacc}:  dH: {dH:< 12.8}  exp(-dH): {exp_mdH:< 12.8}  plaq: {plaq:< 12.8}  topo: {topo:< 3.3}\n"
                O.write(status)
                if (i+1) % (param.ntraj//param.nprint) == 0:
                    put(status)
            t += timer()
            ts.append(t)
        print("Run times: ", ts)
        print("Per trajectory: ", [t/param.ntraj for t in ts])
    return field

In [62]:
param = Param(
    beta = 2.0,
    lat = (8, 8),
    tau = 0.3, # 0.3
    nstep = 8, # 3
    # ADJUST ME
    ntraj = 4, # 2**16 # 2**10 # 2**15
    #
    nprint = 4,
    seed = 1331)

field_run = ft_run(param, pre_flow, field_run[0])
field_run = torch.reshape(field_run,(1,)+field_run.shape)

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 4
tau = 0.3
steps = 8
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.6650769360835047  topo: -2.0
force.norm 43.14710579583169 ft_action -52.65360514082754 pp_action 54.40667637595364
Traj:    1  ACCEPT:  dH:  1.2235686    exp(-dH):  0.29417849   plaq:  0.75765287   topo: -1.0
force.norm 11.452935560391936 ft_action -54.717053159258846 pp_action 63.849036746279836
Traj:    2  REJECT:  dH:  6.4956287    exp(-dH):  0.0015100256  plaq:  0.75765285   topo: -1.0
force.norm 9.706801843716663 ft_action -54.253345921099644 pp_action 71.39910043057215
Traj:    3  REJECT:  dH:  21.778267    exp(-dH):  3.4819198e-10  plaq:  0.75765294   topo: -1.0
force.norm 7.322195729744178 ft_action -54.43555353066791 pp_action 78.34502061205053
Traj:    4  REJECT:  dH:  0.90184084   exp(-dH):  0.40582192   plaq:  0.757653     topo: -1.0
force.norm 7.940512719574733 ft_action -54.46198155193571 pp_action 65.98772651493746
Traj:    5  REJECT: 

In [63]:
param = Param(
    beta = 2.0,
    lat = (8, 8),
    tau = 0.3, # 0.3
    nstep = 8, # 3
    # ADJUST ME
    ntraj = 4, # 2**16 # 2**10 # 2**15
    #
    nprint = 4,
    seed = 1331)

field_run = ft_run(param, flow, field_run[0])
field_run = torch.reshape(field_run,(1,)+field_run.shape)

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 4
tau = 0.3
steps = 8
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.7219939947428274  topo: -1.0
force.norm 12.850247301324904 ft_action -54.674058705928616 pp_action 69.11325969070492
Traj:    1  ACCEPT:  dH: -0.17963889   exp(-dH):  1.1967851    plaq:  0.7063525    topo:  1.0
force.norm 10.647557474246932 ft_action -54.82313721225872 pp_action 54.83069931193151
Traj:    2  ACCEPT:  dH:  0.45601364   exp(-dH):  0.63380519   plaq:  0.72654844   topo:  1.0
force.norm 14.022019960942641 ft_action -54.31656866008814 pp_action 74.5318041795674
Traj:    3  ACCEPT:  dH:  0.69681575   exp(-dH):  0.49816908   plaq:  0.70065655   topo: -1.0
force.norm 8.411214509645118 ft_action -56.278041687291264 pp_action 51.900920731963865
Traj:    4  ACCEPT:  dH: -0.078952832  exp(-dH):  1.0821533    plaq:  0.73216541   topo:  0.0
force.norm 12.640280095559083 ft_action -55.573268976040985 pp_action 57.668959452647144
Traj:    5  ACCEPT: