In [None]:
# DONE

# FT-HMC implemented for 8x8 2D QED (using SiLU as activation function).

# Try to minimize size of the force in training. No significant improvements.

# Some test on ergodicity
# (calculate the probablity of generating the configs obtained via conventional HMC).

# TODO

# Plot the force size distribution
# Is the large force from the original action or Field-Transformation the determinant?
# If from the determinant, then the fermion force won't cause problem for HMC

# Use the same Field-Transformation for larger system (say 16x16, 32x32, 64x64, etc)
# Study how the delta H depends on the system size ( perhaps delta H ~ sqrt(volume) )

# Study the auto-correlation for observables, topo, plaq, flowed plaq, etc.

# Improving the Field-Transformation to reduce force.

In [1]:
import torch
import math
import sys
import os
from timeit import default_timer as timer
from functools import reduce
from field_transformation import *

In [2]:
# From Xiao-Yong

class Param:
    def __init__(self, beta = 6.0, lat = [64, 64], tau = 2.0, nstep = 50, ntraj = 256, nrun = 4, nprint = 256, seed = 11*13, randinit = False, nth = int(os.environ.get('OMP_NUM_THREADS', '2')), nth_interop = 2):
        self.beta = beta
        self.lat = lat
        self.nd = len(lat)
        self.volume = reduce(lambda x,y:x*y, lat)
        self.tau = tau
        self.nstep = nstep
        self.dt = self.tau / self.nstep
        self.ntraj = ntraj
        self.nrun = nrun
        self.nprint = nprint
        self.seed = seed
        self.randinit = randinit
        self.nth = nth
        self.nth_interop = nth_interop
    def initializer(self):
        if self.randinit:
            return torch.empty((param.nd,) + param.lat).uniform_(-math.pi, math.pi)
        else:
            return torch.zeros((param.nd,) + param.lat)
    def summary(self):
        return f"""latsize = {self.lat}
volume = {self.volume}
beta = {self.beta}
trajs = {self.ntraj}
tau = {self.tau}
steps = {self.nstep}
seed = {self.seed}
nth = {self.nth}
nth_interop = {self.nth_interop}
"""
    def uniquestr(self):
        lat = ".".join(str(x) for x in self.lat)
        return f"out_l{lat}_b{param.beta}_n{param.ntraj}_t{param.tau}_s{param.nstep}.out"

def action(param, f):
    return (-param.beta)*torch.sum(torch.cos(plaqphase(f)))

def force(param, f):
    f.requires_grad_(True)
    s = action(param, f)
    f.grad = None
    s.backward()
    ff = f.grad
    f.requires_grad_(False)
    return ff

plaqphase = lambda f: f[0,:] - f[1,:] - torch.roll(f[0,:], shifts=-1, dims=1) + torch.roll(f[1,:], shifts=-1, dims=0)
topocharge = lambda f: torch.floor(0.1 + torch.sum(regularize(plaqphase(f))) / (2*math.pi))
def regularize(f):
    p2 = 2*math.pi
    f_ = (f - math.pi) / p2
    return p2*(f_ - torch.floor(f_) - 0.5)

def leapfrog(param, x, p):
    dt = param.dt
    x_ = x + 0.5*dt*p
    f = force(param, x_)
    p_ = p + (-dt)*f
    print(f'plaq(x) {action(param, x) / (-param.beta*param.volume)}  force.norm {torch.linalg.norm(f)}')
    for i in range(param.nstep-1):
        x_ = x_ + dt*p_
        p_ = p_ + (-dt)*force(param, x_)
    x_ = x_ + 0.5*dt*p_
    return (x_, p_)
def hmc(param, x):
    p = torch.randn_like(x)
    act0 = action(param, x) + 0.5*torch.sum(p*p)
    x_, p_ = leapfrog(param, x, p)
    xr = regularize(x_)
    act = action(param, xr) + 0.5*torch.sum(p_*p_)
    prob = torch.rand([], dtype=torch.float64)
    dH = act-act0
    exp_mdH = torch.exp(-dH)
    acc = prob < exp_mdH
    newx = xr if acc else x
    return (dH, exp_mdH, acc, newx)

put = lambda s: sys.stdout.write(s)

In [3]:
param = Param(
    beta = 2.0,
    lat = (8, 8),
    tau = 2, # 0.3
    nstep = 8, # 3
    # ADJUST ME
    ntraj = 2, # 2**16 # 2**10 # 2**15
    #
    nprint = 2,
    seed = 1331)

In [4]:
torch.manual_seed(param.seed)

torch.set_num_threads(param.nth)
torch.set_num_interop_threads(param.nth_interop)
os.environ["OMP_NUM_THREADS"] = str(param.nth)
os.environ["KMP_BLOCKTIME"] = "0"
os.environ["KMP_SETTINGS"] = "1"
os.environ["KMP_AFFINITY"]= "granularity=fine,verbose,compact,1,0"

torch.set_default_tensor_type(torch.DoubleTensor)

In [132]:
def run(param, field = param.initializer()):
    with open(param.uniquestr(), "w") as O:
        params = param.summary()
        O.write(params)
        put(params)
        plaq, topo = (action(param, field) / (-param.beta*param.volume), topocharge(field))
        status = f"Initial configuration:  plaq: {plaq}  topo: {topo}\n"
        O.write(status)
        put(status)
        ts = []
        for n in range(param.nrun):
            t = -timer()
            for i in range(param.ntraj):
                dH, exp_mdH, acc, field = hmc(param, field)
                plaq = action(param, field) / (-param.beta*param.volume)
                topo = topocharge(field)
                ifacc = "ACCEPT" if acc else "REJECT"
                status = f"Traj: {n*param.ntraj+i+1:4}  {ifacc}:  dH: {dH:< 12.8}  exp(-dH): {exp_mdH:< 12.8}  plaq: {plaq:< 12.8}  topo: {topo:< 3.3}\n"
                O.write(status)
                if (i+1) % (param.ntraj//param.nprint) == 0:
                    put(status)
            t += timer()
            ts.append(t)
        print("Run times: ", ts)
        print("Per trajectory: ", [t/param.ntraj for t in ts])
    return field
field = run(param)
field_run = torch.reshape(field,(1,)+field.shape)

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 4
tau = 0.5
steps = 64
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 1.0  topo: 0.0
plaq(x) 1.0  force.norm 0.2414368926997877
Traj:    1  ACCEPT:  dH: -0.0021568039  exp(-dH):  1.0021591    plaq:  0.84102958   topo:  0.0
plaq(x) 0.8410295798623219  force.norm 15.417172564787883
Traj:    2  ACCEPT:  dH: -0.0015585396  exp(-dH):  1.0015598    plaq:  0.73777785   topo:  0.0
plaq(x) 0.7377778477217429  force.norm 20.79966906732188
Traj:    3  ACCEPT:  dH:  0.00036747194  exp(-dH):  0.9996326    plaq:  0.73153926   topo:  0.0
plaq(x) 0.7315392570615669  force.norm 19.07617281780043
Traj:    4  ACCEPT:  dH:  0.00095658134  exp(-dH):  0.99904388   plaq:  0.80278571   topo:  0.0
plaq(x) 0.8027857055412537  force.norm 16.828897482626882
Traj:    5  ACCEPT:  dH: -0.0011914106  exp(-dH):  1.0011921    plaq:  0.62739512   topo:  0.0
plaq(x) 0.6273951210698728  force.norm 19.693776271513325
Traj:    6  ACCEPT:  dH:  0.00091658616 

In [6]:
def ft_flow(flow, f):
    for layer in flow:
        f, lJ = layer.forward(f)
    return f.detach()

def ft_flow_inv(flow, f):
    for layer in reversed(flow):
        f, lJ = layer.reverse(f)
    return f.detach()

def ft_action(param, flow, f):
    y = f
    logJy = 0.0
    for layer in flow:
        y, lJ = layer.forward(y)
        logJy += lJ
    action = U1GaugeAction(param.beta)
    s = action(y) - logJy
    return s

def ft_force(param, flow, field, create_graph = False):
    # f is the field follows the transformed distribution (close to prior distribution)
    f = field
    f.requires_grad_(True)
    s = ft_action(param, flow, f)
    ss = torch.sum(s)
    # f.grad = None
    ff, = torch.autograd.grad(ss, f, create_graph = create_graph)
    f.requires_grad_(False)
    return ff

In [7]:
def train_step(model, action, optimizer, metrics, batch_size, with_force = False, pre_model = None):
    layers, prior = model['layers'], model['prior']
    optimizer.zero_grad()
    #
    xi = None
    if pre_model != None:
        pre_layers, pre_prior = pre_model['layers'], pre_model['prior']
        pre_xi = pre_prior.sample_n(batch_size)
        x = ft_flow(pre_layers, pre_xi)
        xi = ft_flow_inv(layers, x)
    #
    xi, x, logq = apply_flow_to_prior(prior, layers, batch_size=batch_size, xi=xi)
    logp = -action(x)
    #
    force_size = torch.tensor(0.0)
    dkl = calc_dkl(logp, logq)
    loss = torch.tensor(0.0)
    if with_force:
        assert pre_model != None
        force = ft_force(param, layers, xi, True)
        force_size = torch.sum(torch.square(force))
        loss = force_size
    else:
        loss = dkl
    #
    loss.backward()
    #
    # minimization target
    # loss mini
    # -> (logq - logp) mini
    # -> (action - logJ) mini
    #
    optimizer.step()
    ess = compute_ess(logp, logq)
    #
    print(grab(loss),
          grab(force_size),
          grab(dkl),
          grab(ess),
          torch.linalg.norm(ft_force(param, layers, xi)))
    #
    metrics['loss'].append(grab(loss))
    metrics['force'].append(grab(force_size))
    metrics['dkl'].append(grab(dkl))
    metrics['logp'].append(grab(logp))
    metrics['logq'].append(grab(logq))
    metrics['ess'].append(grab(ess))

def flow_train(param, with_force = False, pre_model = None):  # packaged from original ipynb by Xiao-Yong Jin
    # Theory
    lattice_shape = param.lat
    link_shape = (2,*param.lat)
    beta = param.beta
    u1_action = U1GaugeAction(beta)
    # Model
    prior = MultivariateUniform(torch.zeros(link_shape), 2*np.pi*torch.ones(link_shape))
    #
    n_layers = 24
    n_s_nets = 2
    hidden_sizes = [8,8]
    kernel_size = 3
    layers = make_u1_equiv_layers(lattice_shape=lattice_shape, n_layers=n_layers, n_mixture_comps=n_s_nets,
                                  hidden_sizes=hidden_sizes, kernel_size=kernel_size)
    set_weights(layers)
    model = {'layers': layers, 'prior': prior}
    # Training
    base_lr = .001
    optimizer = torch.optim.Adam(model['layers'].parameters(), lr=base_lr)
    optimizer_wf = torch.optim.Adam(model['layers'].parameters(), lr=base_lr / 100.0)
    #
    # ADJUST ME
    N_era = 10
    N_epoch = 100
    #
    batch_size = 64
    print_freq = N_epoch # epochs
    plot_freq = 1 # epochs
    history = {
        'loss' : [],
        'force' : [],
        'dkl' : [],
        'logp' : [],
        'logq' : [],
        'ess' : []
    }
    for era in range(N_era):
        for epoch in range(N_epoch):
            train_step(model, u1_action, optimizer, history, batch_size)
            if with_force:
                train_step(model, u1_action, optimizer_wf, history, batch_size,
                           with_force = with_force, pre_model = pre_model)
            if epoch % print_freq == 0:
                print_metrics(history, print_freq, era, epoch)
    return model,u1_action

def flow_eval(model, u1_action):  # packaged from original ipynb by Xiao-Yong Jin
    ensemble_size = 1024
    u1_ens = make_mcmc_ensemble(model, u1_action, 64, ensemble_size)
    print("Accept rate:", np.mean(u1_ens['accepted']))
    Q = grab(topo_charge(torch.stack(u1_ens['x'], axis=0)))
    X_mean, X_err = bootstrap(Q**2, Nboot=100, binsize=16)
    print(f'Topological susceptibility = {X_mean:.2f} +/- {X_err:.2f}')
    print(f'... vs HMC estimate = 1.23 +/- 0.02')

In [8]:
pre_flow_model, flow_act = flow_train(param)
flow_eval(pre_flow_model,flow_act)
pre_flow = pre_flow_model['layers']

-235.80379178382825 0.0 -235.80379178382825 0.015742986900554163 tensor(173.9276)
== Era 0 | Epoch 0 metrics ==
	loss -235.804
	force 0
	dkl -235.804
	logp 0.693854
	logq -235.11
	ess 0.015743
-240.22592204900425 0.0 -240.22592204900425 0.01726504117436026 tensor(169.3064)
-241.73488535507386 0.0 -241.73488535507386 0.016400911777968756 tensor(167.0606)
-247.46169379805445 0.0 -247.46169379805445 0.015810153849540865 tensor(160.3539)
-253.25240939507336 0.0 -253.25240939507336 0.01941417680419557 tensor(158.7390)
-256.7516621657513 0.0 -256.7516621657513 0.015804690643426364 tensor(154.5671)
-257.21331835493976 0.0 -257.21331835493976 0.01562887239859159 tensor(157.1432)
-261.58546505634735 0.0 -261.58546505634735 0.020297799438260253 tensor(151.7423)
-265.36695333901866 0.0 -265.36695333901866 0.0156412599546812 tensor(151.7542)
-266.2862499543668 0.0 -266.2862499543668 0.016146700456890064 tensor(155.4196)
-270.4000967148002 0.0 -270.4000967148002 0.01831122343147412 tensor(156.1291)

-282.30318004954523 0.0 -282.30318004954523 0.03519414026309184 tensor(173.6366)
-281.71068933998566 0.0 -281.71068933998566 0.044429059801842655 tensor(189.9906)
-282.02463668393204 0.0 -282.02463668393204 0.028696246313953674 tensor(171.9115)
-283.132869239756 0.0 -283.132869239756 0.019255968703855045 tensor(165.2447)
-282.0692816718184 0.0 -282.0692816718184 0.023057457037706616 tensor(170.5575)
-282.6374101397821 0.0 -282.6374101397821 0.03463303707863622 tensor(174.0301)
-282.6942481521188 0.0 -282.6942481521188 0.07775488816440008 tensor(182.1365)
-282.8336707077077 0.0 -282.8336707077077 0.11739860957901044 tensor(173.4950)
-282.4292215109003 0.0 -282.4292215109003 0.04353524233107781 tensor(174.8651)
-282.6392043268243 0.0 -282.6392043268243 0.057467884980354454 tensor(174.7497)
-282.52266144131096 0.0 -282.52266144131096 0.1869519973917699 tensor(175.9306)
-282.3960594537949 0.0 -282.3960594537949 0.10639430865819953 tensor(175.4130)
-283.18745889396143 0.0 -283.1874588939614

-285.4745453428999 0.0 -285.4745453428999 0.08551370972478504 tensor(159.2134)
-285.74109108264855 0.0 -285.74109108264855 0.09694266245844649 tensor(166.2664)
-285.18165772424607 0.0 -285.18165772424607 0.18581720592138307 tensor(160.6161)
-284.9456329769272 0.0 -284.9456329769272 0.037887675923891334 tensor(165.5799)
-285.13679041170724 0.0 -285.13679041170724 0.0709500809313619 tensor(161.8093)
-284.8899231454957 0.0 -284.8899231454957 0.1292122816474257 tensor(175.4481)
-285.48417859246416 0.0 -285.48417859246416 0.09883028991631716 tensor(170.3434)
-285.2071173702128 0.0 -285.2071173702128 0.16868826186902278 tensor(159.7044)
-285.5761087567977 0.0 -285.5761087567977 0.10562976216060008 tensor(153.3455)
-284.95738302568657 0.0 -284.95738302568657 0.1613954717331739 tensor(166.5366)
-285.52831036667357 0.0 -285.52831036667357 0.06433101101357216 tensor(157.1581)
-285.53802653925254 0.0 -285.53802653925254 0.08475295738354054 tensor(168.5715)
-285.5215177825453 0.0 -285.521517782545

-285.21536683697167 0.0 -285.21536683697167 0.2189449581050638 tensor(171.8966)
-285.48660306425217 0.0 -285.48660306425217 0.11321807829973943 tensor(187.9403)
-285.5833949046988 0.0 -285.5833949046988 0.06373153825501945 tensor(158.9213)
-285.48727051687985 0.0 -285.48727051687985 0.08399532584933865 tensor(182.5070)
-285.54272284288584 0.0 -285.54272284288584 0.049358031771688855 tensor(175.0742)
-285.1236794662193 0.0 -285.1236794662193 0.1780201144879282 tensor(176.3286)
-285.84530449928945 0.0 -285.84530449928945 0.1956267839023526 tensor(155.9885)
-286.1826945515368 0.0 -286.1826945515368 0.20304134682894795 tensor(165.7703)
-285.68631223265453 0.0 -285.68631223265453 0.2116240683831535 tensor(173.1717)
-285.77136434638396 0.0 -285.77136434638396 0.23513026087984623 tensor(168.5351)
-285.317134575164 0.0 -285.317134575164 0.09205607013542556 tensor(185.0386)
-285.3653766845781 0.0 -285.3653766845781 0.07107531753230847 tensor(166.2538)
-285.29057236143143 0.0 -285.29057236143143

-286.2298298684607 0.0 -286.2298298684607 0.20246256551853137 tensor(171.2551)
-286.18202929008567 0.0 -286.18202929008567 0.15544772488882455 tensor(171.7185)
-286.30195926420214 0.0 -286.30195926420214 0.20824601214829314 tensor(164.9257)
-286.8255649181762 0.0 -286.8255649181762 0.29026847375326403 tensor(173.1734)
-286.41775810069305 0.0 -286.41775810069305 0.10261035785947044 tensor(182.5192)
-286.356919937346 0.0 -286.356919937346 0.057770890006515106 tensor(164.2266)
-286.17847129446477 0.0 -286.17847129446477 0.09900306937627477 tensor(172.7846)
-286.4338164251212 0.0 -286.4338164251212 0.0916416689717375 tensor(175.3758)
-286.51255808334076 0.0 -286.51255808334076 0.08394243379648351 tensor(184.6040)
-286.49338604286265 0.0 -286.49338604286265 0.15191005068858007 tensor(171.0245)
-286.1197153753396 0.0 -286.1197153753396 0.0958604712796014 tensor(221.3763)
-286.27289834545667 0.0 -286.27289834545667 0.1455096301059904 tensor(202.0609)
-286.68317381021313 0.0 -286.6831738102131

-286.1030487674826 0.0 -286.1030487674826 0.14757232188571232 tensor(177.4414)
-286.79881426928785 0.0 -286.79881426928785 0.19925762210835296 tensor(175.7745)
-286.8895556518922 0.0 -286.8895556518922 0.20663419603924182 tensor(212.8828)
-286.73436938298016 0.0 -286.73436938298016 0.2396118599507676 tensor(198.8368)
-286.3219745816881 0.0 -286.3219745816881 0.1414377985216021 tensor(222.2222)
-286.2504391436577 0.0 -286.2504391436577 0.14604459521012364 tensor(155.7067)
-286.80813812640076 0.0 -286.80813812640076 0.2522846437389289 tensor(207.6366)
-286.7118538872095 0.0 -286.7118538872095 0.10008081763760789 tensor(187.6564)
-286.7248982663979 0.0 -286.7248982663979 0.24895956847581732 tensor(175.3379)
-286.25743717191267 0.0 -286.25743717191267 0.2283220077208355 tensor(182.6109)
-286.70812550199093 0.0 -286.70812550199093 0.12354206327816915 tensor(155.5527)
-287.02429185487904 0.0 -287.02429185487904 0.2074911297505572 tensor(183.9654)
-286.7605332923435 0.0 -286.7605332923435 0.1

-286.37424945966757 0.0 -286.37424945966757 0.27631506150752144 tensor(187.5516)
-286.87764579615214 0.0 -286.87764579615214 0.33560314012720543 tensor(196.6000)
-286.6968894805847 0.0 -286.6968894805847 0.18031224362893333 tensor(187.4604)
-286.5287983371965 0.0 -286.5287983371965 0.20499660178731807 tensor(196.2932)
-286.2364603642544 0.0 -286.2364603642544 0.11616869579152166 tensor(197.9263)
-286.1971812865029 0.0 -286.1971812865029 0.11281678713560853 tensor(253.9944)
-287.031403529598 0.0 -287.031403529598 0.08090413860815306 tensor(195.7730)
-286.32109874125695 0.0 -286.32109874125695 0.2542890375275158 tensor(172.7984)
-286.5775680322279 0.0 -286.5775680322279 0.2325613031681031 tensor(189.0569)
-286.5379995265555 0.0 -286.5379995265555 0.059621040490759954 tensor(224.7545)
-286.85048280028434 0.0 -286.85048280028434 0.07016776209457493 tensor(265.4737)
-286.8202146132348 0.0 -286.8202146132348 0.25341311544609363 tensor(286.0131)
-286.6823554320838 0.0 -286.6823554320838 0.180

-286.829989939537 0.0 -286.829989939537 0.2897993723674256 tensor(158.2422)
-286.55234764676885 0.0 -286.55234764676885 0.22451236753008907 tensor(182.5948)
-286.9383151640146 0.0 -286.9383151640146 0.14170541945339032 tensor(295.1155)
-286.9684003038342 0.0 -286.9684003038342 0.17478064276434813 tensor(227.1028)
-287.10364243493234 0.0 -287.10364243493234 0.32195787571074136 tensor(218.9781)
-286.3360277810375 0.0 -286.3360277810375 0.24192631806098625 tensor(239.8895)
-286.43053024083815 0.0 -286.43053024083815 0.2601425447803293 tensor(164.5041)
-287.04702616368945 0.0 -287.04702616368945 0.1019783067028248 tensor(257.8064)
-286.6678767203065 0.0 -286.6678767203065 0.23081487720553823 tensor(212.7097)
-286.8493991942497 0.0 -286.8493991942497 0.18800788629736237 tensor(256.1802)
-286.67042896794663 0.0 -286.67042896794663 0.3784958505739057 tensor(181.2171)
-287.34412030840406 0.0 -287.34412030840406 0.33425404121055735 tensor(247.9975)
-286.9731702036803 0.0 -286.9731702036803 0.09

-287.1013126907318 0.0 -287.1013126907318 0.21329969181483255 tensor(193.4902)
-286.75604462110755 0.0 -286.75604462110755 0.2194159025806836 tensor(201.2356)
-287.0479481087499 0.0 -287.0479481087499 0.23300483909460953 tensor(223.6758)
-286.7347007892731 0.0 -286.7347007892731 0.37456966217586135 tensor(215.4751)
-286.59911385224126 0.0 -286.59911385224126 0.08058265640562165 tensor(185.7359)
-286.9823837360541 0.0 -286.9823837360541 0.35942636246267073 tensor(177.2091)
-287.4609491473624 0.0 -287.4609491473624 0.3299042787363679 tensor(167.8750)
-286.5736689158928 0.0 -286.5736689158928 0.34961766350318535 tensor(295.1423)
-286.8294812932953 0.0 -286.8294812932953 0.30271053184380814 tensor(179.9932)
-287.1030216724084 0.0 -287.1030216724084 0.3387798203593732 tensor(164.6811)
-286.815283436428 0.0 -286.815283436428 0.31125346821440225 tensor(171.8524)
-286.8567798727716 0.0 -286.8567798727716 0.053799360710257295 tensor(168.7261)
-286.7602947419879 0.0 -286.7602947419879 0.19893814

-287.1974354358955 0.0 -287.1974354358955 0.24741621583893092 tensor(414.6770)
-286.9126605344217 0.0 -286.9126605344217 0.19326177991835528 tensor(197.4445)
-287.1525202454744 0.0 -287.1525202454744 0.15849238446212519 tensor(180.5447)
-286.99595224483573 0.0 -286.99595224483573 0.13534719086020094 tensor(221.3385)
-287.0839793721429 0.0 -287.0839793721429 0.2945747959313512 tensor(205.0207)
-286.92071639395357 0.0 -286.92071639395357 0.24238989951667803 tensor(198.5116)
-286.9655107787159 0.0 -286.9655107787159 0.3045876064012117 tensor(302.7847)
-287.01217130365694 0.0 -287.01217130365694 0.21326813705089678 tensor(141.8040)
-286.72580910681984 0.0 -286.72580910681984 0.25328739579993514 tensor(171.2045)
-286.8274692717415 0.0 -286.8274692717415 0.3325019587180933 tensor(198.5363)
-286.9714275810692 0.0 -286.9714275810692 0.22347510584246394 tensor(280.2572)
-286.91179021224457 0.0 -286.91179021224457 0.22762529192807113 tensor(431.2482)
-286.9297917601647 0.0 -286.9297917601647 0.1

In [9]:
flow_model, flow_act = flow_train(param, with_force=True, pre_model=pre_flow_model)
flow_eval(flow_model,flow_act)
flow = flow_model['layers']
# flow.eval()

-238.39049113095706 0.0 -238.39049113095706 0.020874773022601868 tensor(172.2697)
20057.912707492145 20057.912707492145 -319.6221631477296 0.03606409550273815 tensor(141.4678)
== Era 0 | Epoch 0 metrics ==
	loss 9909.76
	force 10029
	dkl -279.006
	logp 45.1304
	logq -233.876
	ess 0.0284694
-239.99995031784758 0.0 -239.99995031784758 0.016018551334671154 tensor(168.6970)
18508.217711007655 18508.217711007655 -316.4459660502251 0.04013515196336657 tensor(135.8981)
-243.95157348742728 0.0 -243.95157348742728 0.026683909450260914 tensor(165.5899)
15870.26093456487 15870.26093456487 -315.65977477143747 0.018370012111602607 tensor(125.8410)
-245.25201064574748 0.0 -245.25201064574748 0.052705307415307594 tensor(160.9198)
14462.621701326689 14462.621701326689 -311.91091435015744 0.029296820164363094 tensor(120.1388)
-251.90469794323698 0.0 -251.90469794323698 0.01562503155286332 tensor(156.8551)
12739.94778084604 12739.94778084604 -309.25396699240855 0.05697726152577854 tensor(112.7597)
-255.

19074.303911960218 19074.303911960218 -292.55488633410727 0.0866930698406297 tensor(138.0980)
-280.2719772857525 0.0 -280.2719772857525 0.03044307838361087 tensor(179.9664)
17675.65055218959 17675.65055218959 -293.0546726317576 0.03566032242706418 tensor(132.9348)
-281.2607820372472 0.0 -281.2607820372472 0.049040705419754776 tensor(166.9287)
18459.102758276156 18459.102758276156 -291.75652409496115 0.05663009784234521 tensor(135.8517)
-280.70231286811327 0.0 -280.70231286811327 0.16896579870956277 tensor(165.9623)
17493.2839483881 17493.2839483881 -291.8698132200958 0.09248786586521678 tensor(132.2484)
-280.00946562950685 0.0 -280.00946562950685 0.04202988730583567 tensor(184.7994)
19971.03077950465 19971.03077950465 -293.20921397977315 0.046751445464503126 tensor(141.3075)
-280.82711625033005 0.0 -280.82711625033005 0.025582369575553807 tensor(186.7338)
18695.943570009884 18695.943570009884 -292.13422255456567 0.046325159838130896 tensor(136.7147)
-281.3903984979996 0.0 -281.39039849

-282.21602421670696 0.0 -282.21602421670696 0.06633353701993606 tensor(177.8588)
17353.61844073992 17353.61844073992 -290.4366023840206 0.02306281185007687 tensor(131.7117)
-282.66544634543055 0.0 -282.66544634543055 0.11601004393354111 tensor(175.1146)
17905.928761597108 17905.928761597108 -291.023764986694 0.03375794109880063 tensor(133.7803)
-281.9435188496797 0.0 -281.9435188496797 0.04116358603410218 tensor(183.3974)
17929.90416840589 17929.90416840589 -291.35470171683716 0.016280301845685935 tensor(133.8678)
-282.77258779360545 0.0 -282.77258779360545 0.05102493252003544 tensor(176.2602)
21362.565025908672 21362.565025908672 -291.40344821922963 0.060298693192348214 tensor(146.1096)
-283.0361827050161 0.0 -283.0361827050161 0.04582138798409073 tensor(178.2340)
20184.261067007392 20184.261067007392 -291.0750682504206 0.08161775266269498 tensor(142.0265)
-282.7121363608883 0.0 -282.7121363608883 0.08947900223654791 tensor(173.4838)
19293.2963387733 19293.2963387733 -291.348682019932

-284.97034821845057 0.0 -284.97034821845057 0.05489626602713098 tensor(156.7565)
21747.42151630036 21747.42151630036 -289.73452293867035 0.27004309813497857 tensor(147.1258)
-284.35942794650475 0.0 -284.35942794650475 0.039209840901709744 tensor(172.6830)
24862.240926072147 24862.240926072147 -289.2757507958735 0.0883683956731514 tensor(157.5366)
-284.08686589810645 0.0 -284.08686589810645 0.10881850403625802 tensor(158.8496)
19024.76989134109 19024.76989134109 -289.5788590478145 0.06943457285789091 tensor(137.7176)
-284.05288039793027 0.0 -284.05288039793027 0.05249959113759645 tensor(163.2786)
21486.664915048288 21486.664915048288 -289.9207789709983 0.16762521689208396 tensor(146.1552)
-283.7914084989188 0.0 -283.7914084989188 0.10377805064375248 tensor(153.5166)
23754.536093235263 23754.536093235263 -289.44319126067705 0.036920420967849754 tensor(153.9685)
-283.9381614085331 0.0 -283.9381614085331 0.03293872831962461 tensor(166.1496)
19618.284835824626 19618.284835824626 -289.503012

18298.470412036724 18298.470412036724 -289.4863359918005 0.15355262066470948 tensor(135.1448)
-285.02951740536474 0.0 -285.02951740536474 0.11109000582092264 tensor(155.0115)
19580.854288265717 19580.854288265717 -288.92839256242416 0.08109112257438797 tensor(139.6969)
-285.4408052182038 0.0 -285.4408052182038 0.021615436321249445 tensor(144.4478)
16497.63152134316 16497.63152134316 -288.9393163230589 0.17641735184551338 tensor(128.2933)
-285.13051297704817 0.0 -285.13051297704817 0.0825030927473174 tensor(161.8545)
19370.655111159354 19370.655111159354 -288.62535650558254 0.25962391825012415 tensor(139.0632)
-285.0005297664599 0.0 -285.0005297664599 0.07712382991772826 tensor(152.6422)
24880.953939679843 24880.953939679843 -288.3325535085609 0.040363882630917114 tensor(157.5512)
-285.15110304257604 0.0 -285.15110304257604 0.14272318531314257 tensor(157.6855)
23318.50188031238 23318.50188031238 -288.7553229882891 0.2011232265833868 tensor(152.5372)
-285.54812235913835 0.0 -285.54812235

17991.1991049188 17991.1991049188 -288.93744747733984 0.27988902412768246 tensor(134.0252)
-285.2878355129742 0.0 -285.2878355129742 0.14425865210315303 tensor(146.0366)
23822.36436512699 23822.36436512699 -288.4192506261963 0.24517478462743597 tensor(154.0776)
-285.20803486806074 0.0 -285.20803486806074 0.12367578463774954 tensor(158.6803)
18858.785099364482 18858.785099364482 -288.6823613774401 0.27427815882881634 tensor(137.2206)
-285.20568167578773 0.0 -285.20568167578773 0.12474643861027426 tensor(151.6942)
19741.23174903314 19741.23174903314 -288.69060247458344 0.18583149152907427 tensor(140.3560)
-285.5615952259924 0.0 -285.5615952259924 0.1680074144199956 tensor(166.5738)
16504.939188366916 16504.939188366916 -289.08287134583605 0.1270856646693557 tensor(128.3706)
-285.33578134789553 0.0 -285.33578134789553 0.0926829781025436 tensor(172.0526)
17587.118422592128 17587.118422592128 -288.7417273179901 0.05958857771715249 tensor(132.4808)
-285.55356241975704 0.0 -285.55356241975704

-285.53763442104605 0.0 -285.53763442104605 0.09599160616781392 tensor(170.1124)
23754.089019382453 23754.089019382453 -288.7235253701389 0.1613040681459115 tensor(154.0080)
-285.72831831727854 0.0 -285.72831831727854 0.1235902473945449 tensor(162.9724)
17304.90739847014 17304.90739847014 -288.5763581253353 0.13929163804340786 tensor(131.4343)
-285.38829961949045 0.0 -285.38829961949045 0.12279025840028317 tensor(166.6758)
25043.492012365678 25043.492012365678 -288.93613553567616 0.17126945640334468 tensor(157.9998)
-285.0893675861591 0.0 -285.0893675861591 0.05617487453320584 tensor(170.5899)
20798.94460200907 20798.94460200907 -288.4787070112312 0.17251358817923354 tensor(144.0878)
-285.47146581930417 0.0 -285.47146581930417 0.2026030470268418 tensor(155.4018)
14236.773501892023 14236.773501892023 -288.9521894494561 0.1502202164909821 tensor(119.2451)
-285.69571716139916 0.0 -285.69571716139916 0.11213292682051056 tensor(148.7692)
20590.56081089483 20590.56081089483 -288.428688223134

-285.9727788302298 0.0 -285.9727788302298 0.12691777558267686 tensor(150.7071)
23991.12162384852 23991.12162384852 -288.1267906095826 0.2108380600256449 tensor(154.7497)
-285.4457599610498 0.0 -285.4457599610498 0.2402225155748003 tensor(149.8018)
16015.220833971094 16015.220833971094 -288.01324799823567 0.1418136373618562 tensor(126.3893)
-285.802192550592 0.0 -285.802192550592 0.15206985470940582 tensor(140.6666)
23508.491520222524 23508.491520222524 -288.58547639125135 0.10629678672506042 tensor(152.9864)
-285.72681426076537 0.0 -285.72681426076537 0.05240698969669682 tensor(150.0914)
22166.642261876383 22166.642261876383 -288.37908508291207 0.21285082213348386 tensor(148.6772)
-285.9468360498366 0.0 -285.9468360498366 0.03363963641293479 tensor(139.2536)
16341.13222246667 16341.13222246667 -288.66027524944536 0.07643104392244816 tensor(127.6184)
-286.2475127616262 0.0 -286.2475127616262 0.08669552732828305 tensor(163.2563)
22681.740753915576 22681.740753915576 -288.2428717721371 0.

-285.84248992121695 0.0 -285.84248992121695 0.23692807383514264 tensor(166.4314)
25940.0245935195 25940.0245935195 -287.88865334250136 0.07839522415993602 tensor(160.8006)
-285.6718281653317 0.0 -285.6718281653317 0.13563702080189755 tensor(144.5219)
18923.776507327777 18923.776507327777 -288.72843160717764 0.31500855275963896 tensor(137.4351)
-285.43856044614273 0.0 -285.43856044614273 0.15470151187424339 tensor(167.4265)
16128.638245916009 16128.638245916009 -288.2249567522944 0.19043288972256137 tensor(126.9576)
-286.37810940793327 0.0 -286.37810940793327 0.07451264252475999 tensor(137.5084)
24233.38618001946 24233.38618001946 -287.92780859398863 0.2143310202177221 tensor(155.5237)
-286.29609122288025 0.0 -286.29609122288025 0.19574977023178083 tensor(155.7765)
17342.47314066839 17342.47314066839 -288.20951348967355 0.12175219307197903 tensor(131.5379)
-286.453787833296 0.0 -286.453787833296 0.20209054957300882 tensor(152.7309)
19334.035184447614 19334.035184447614 -288.264521168760

-285.7891383372605 0.0 -285.7891383372605 0.06248868399890766 tensor(133.8619)
26339.413130920733 26339.413130920733 -288.3366352916199 0.303353342687243 tensor(161.6762)
-286.8091074038364 0.0 -286.8091074038364 0.22921665142266612 tensor(149.4481)
24530.77492300021 24530.77492300021 -288.3807601161063 0.06580238241256914 tensor(156.2593)
-285.9327669183945 0.0 -285.9327669183945 0.1078548915231454 tensor(169.6685)
27715.986211459338 27715.986211459338 -288.61551656932977 0.1347412227341016 tensor(165.5600)
-286.49388664194777 0.0 -286.49388664194777 0.1178020460828718 tensor(158.4133)
20650.71013084868 20650.71013084868 -288.3013639111995 0.33566176627126787 tensor(143.4453)
-286.47842033919744 0.0 -286.47842033919744 0.16643771551096587 tensor(147.2449)
23119.136522983405 23119.136522983405 -288.11441958642786 0.1964907163482189 tensor(151.9224)
-286.45441445382437 0.0 -286.45441445382437 0.23473916496289046 tensor(142.6987)
28851.966695506166 28851.966695506166 -288.3696823222349 0

18895.885852294337 18895.885852294337 -288.23468196823126 0.08652410333678036 tensor(137.3147)
-286.0357347030575 0.0 -286.0357347030575 0.06303162518883486 tensor(148.6247)
21366.226644062423 21366.226644062423 -287.8182779835371 0.15092198468198792 tensor(145.7757)
-286.4390622426534 0.0 -286.4390622426534 0.24371656082298926 tensor(155.9149)
21349.980247907366 21349.980247907366 -288.14595814573266 0.06967805892506161 tensor(145.9711)
-286.0743325279866 0.0 -286.0743325279866 0.09968182830889098 tensor(154.6899)
25933.48449976956 25933.48449976956 -288.3338789769083 0.31328546481085495 tensor(160.7440)
-286.04419974283104 0.0 -286.04419974283104 0.20265728791476284 tensor(142.2100)
28286.71770831826 28286.71770831826 -288.2900640955386 0.2806460464760222 tensor(167.6769)
-286.11302821412926 0.0 -286.11302821412926 0.15188006402789223 tensor(171.5121)
23035.563225666534 23035.563225666534 -288.02967548653373 0.1874591709867953 tensor(151.0670)
-286.0003694748567 0.0 -286.000369474856

21004.595306300325 21004.595306300325 -288.1263275555921 0.11356998347654314 tensor(144.8088)
-286.7316114054173 0.0 -286.7316114054173 0.25765855051364217 tensor(162.0850)
27486.345059527033 27486.345059527033 -288.17531488207777 0.21794707293214455 tensor(165.5880)
-286.4177362746666 0.0 -286.4177362746666 0.20290664543582615 tensor(147.9444)
21789.83114430288 21789.83114430288 -287.9287793243426 0.2521325218614762 tensor(147.4409)
-286.52861049072476 0.0 -286.52861049072476 0.3349203946601021 tensor(157.1442)
28452.793854154825 28452.793854154825 -288.07207762848446 0.2067386712385994 tensor(168.2729)
-286.36623504848234 0.0 -286.36623504848234 0.11566125312031084 tensor(161.7498)
19295.272249246584 19295.272249246584 -288.256536903667 0.25270739599905656 tensor(138.6347)
-286.50177636845643 0.0 -286.50177636845643 0.2479246441416943 tensor(135.1678)
23468.706016696782 23468.706016696782 -288.183900458903 0.22883807657119773 tensor(152.7620)
-285.98986154249246 0.0 -285.989861542492

-286.26529293621377 0.0 -286.26529293621377 0.23019251517748654 tensor(178.1651)
38051.85850573397 38051.85850573397 -288.2805918247161 0.11639097822079157 tensor(193.6200)
-286.15725198127495 0.0 -286.15725198127495 0.15217844429383787 tensor(170.9882)
24068.298719227234 24068.298719227234 -288.4126729932877 0.22047240248852862 tensor(154.9802)
-286.3998923270982 0.0 -286.3998923270982 0.08799783339216699 tensor(192.1408)
31155.793516366513 31155.793516366513 -287.7396437973993 0.1900492002142713 tensor(165.2300)
-286.18708732452467 0.0 -286.18708732452467 0.236339556678481 tensor(164.7383)
23991.054970006076 23991.054970006076 -288.3999586696394 0.24719226260307406 tensor(154.3022)
-286.34541372258235 0.0 -286.34541372258235 0.3166566301764982 tensor(198.2927)
19139.571112408343 19139.571112408343 -288.0411633728559 0.1437149402574481 tensor(137.6514)
-286.35077954088854 0.0 -286.35077954088854 0.11720489586063657 tensor(177.7323)
39351.39501644089 39351.39501644089 -288.293026246234

-286.5271002091637 0.0 -286.5271002091637 0.23421152448287902 tensor(178.8134)
28083.164372094536 28083.164372094536 -287.87097274404607 0.07771987717004385 tensor(167.3376)
-286.3084798717028 0.0 -286.3084798717028 0.09477030661128111 tensor(172.7973)
38001.2969573314 38001.2969573314 -288.4246949466288 0.15482162870590901 tensor(187.4294)
-286.7259243174136 0.0 -286.7259243174136 0.21815974982414155 tensor(157.8940)
42353.06960313262 42353.06960313262 -288.3543930421031 0.098974755739988 tensor(204.1474)
-286.2539869136395 0.0 -286.2539869136395 0.2661109860697287 tensor(182.1572)
29939.2240606571 29939.2240606571 -287.8188255271603 0.250193772985456 tensor(172.7877)
-286.17339937717384 0.0 -286.17339937717384 0.18425248530110067 tensor(157.4664)
17867.103693779423 17867.103693779423 -288.1755481739189 0.24109495289432117 tensor(133.4522)
-286.35123979665144 0.0 -286.35123979665144 0.2563819155810669 tensor(165.0757)
31639.657432629057 31639.657432629057 -287.9621558901609 0.16553006

-286.7071166408974 0.0 -286.7071166408974 0.11011083523495957 tensor(163.0576)
27047.312760188615 27047.312760188615 -288.02501604274676 0.13929618263217972 tensor(164.0448)
-286.338304731965 0.0 -286.338304731965 0.20278238759698625 tensor(188.0359)
27036.115288853376 27036.115288853376 -288.32541645284425 0.19765010778053144 tensor(164.1406)
-286.17104559687345 0.0 -286.17104559687345 0.12963896984818063 tensor(171.9740)
26198.820565770857 26198.820565770857 -288.0244239714888 0.19654968143459495 tensor(161.7284)
-286.0716673714779 0.0 -286.0716673714779 0.2589845632979061 tensor(178.3325)
38576.85609328673 38576.85609328673 -288.2725275278226 0.1139989899046807 tensor(194.8865)
-286.9411133372706 0.0 -286.9411133372706 0.26117118294645775 tensor(137.6470)
30697.201116537024 30697.201116537024 -288.1462391308553 0.06179755720175978 tensor(171.0779)
-286.56067166093294 0.0 -286.56067166093294 0.10689779818358022 tensor(152.9505)
20095.12253086365 20095.12253086365 -288.2604735350841 0

-286.1593824027167 0.0 -286.1593824027167 0.21904573652351167 tensor(173.6671)
30863.277331212194 30863.277331212194 -288.1651013641319 0.3668576948163544 tensor(175.0245)
-286.38317231351033 0.0 -286.38317231351033 0.16657214248163335 tensor(168.9308)
19397.84998823823 19397.84998823823 -288.21507011772957 0.24461906575252218 tensor(139.2582)
-286.3334340587272 0.0 -286.3334340587272 0.2061305844188384 tensor(178.3494)
22763.721442012513 22763.721442012513 -287.9668428581885 0.4178549918622683 tensor(150.7251)
-286.10314102281933 0.0 -286.10314102281933 0.14410917688624766 tensor(182.8177)
28272.26746343081 28272.26746343081 -288.16879986376694 0.24080344593533562 tensor(167.9817)
-286.4683634208486 0.0 -286.4683634208486 0.1301804100055399 tensor(164.9295)
26452.17103949813 26452.17103949813 -288.1622126304362 0.3384602958991995 tensor(162.1245)
-286.3623536913355 0.0 -286.3623536913355 0.2386075671540342 tensor(162.6375)
27730.86676605287 27730.86676605287 -287.8856646458763 0.20970

-286.46680864059454 0.0 -286.46680864059454 0.17811990241393133 tensor(161.7926)
18368.201997614706 18368.201997614706 -288.0401608694202 0.1565551691749182 tensor(135.5836)
-286.7446843528391 0.0 -286.7446843528391 0.2223876022961817 tensor(142.8008)
19159.563534029465 19159.563534029465 -288.24104121484663 0.3403819207189742 tensor(138.3662)
-286.66589698771077 0.0 -286.66589698771077 0.11926935246438929 tensor(179.9168)
19720.231479552895 19720.231479552895 -288.1468123411204 0.04931820966894226 tensor(140.3726)
-286.727269492541 0.0 -286.727269492541 0.28200011738281017 tensor(166.4664)
22862.63740444296 22862.63740444296 -288.317985704099 0.19336761602411634 tensor(151.0355)
-286.5590683289122 0.0 -286.5590683289122 0.18003751605221358 tensor(172.2588)
41292.54466266175 41292.54466266175 -288.1407576480872 0.10579771687757626 tensor(202.8666)
-286.5533418583237 0.0 -286.5533418583237 0.26665062486165286 tensor(159.1823)
22824.040861170364 22824.040861170364 -288.2968732097847 0.20

-286.75092011761325 0.0 -286.75092011761325 0.34275138305498437 tensor(199.3850)
32347.47522167852 32347.47522167852 -288.30155466823555 0.11092364189080878 tensor(179.5158)
-286.3911711164683 0.0 -286.3911711164683 0.33871452832670973 tensor(159.1175)
25206.347640718872 25206.347640718872 -288.09897223076234 0.14719876793443207 tensor(158.4368)
-286.78210547284374 0.0 -286.78210547284374 0.28313825679700616 tensor(165.7690)
27413.609543628852 27413.609543628852 -288.1080974875202 0.2751069622284433 tensor(165.4174)
-286.5621706889951 0.0 -286.5621706889951 0.1365422488626961 tensor(172.4679)
29119.780830527714 29119.780830527714 -288.35075050827845 0.2668265279494765 tensor(170.4814)
-286.66796139690706 0.0 -286.66796139690706 0.1483785267188906 tensor(168.3702)
35163.51300108803 35163.51300108803 -288.02676649191056 0.1392273799625386 tensor(186.6374)
-286.6097197198169 0.0 -286.6097197198169 0.16472176987536283 tensor(183.9521)
19589.47485530219 19589.47485530219 -287.9097502854225 

-286.51930116492264 0.0 -286.51930116492264 0.06405212793384879 tensor(215.0865)
35884.69514594404 35884.69514594404 -288.13402159526373 0.08047683269013448 tensor(188.7271)
-286.7397521021133 0.0 -286.7397521021133 0.25480351058260664 tensor(197.0777)
26266.43593771653 26266.43593771653 -287.70267171931573 0.12961092282707823 tensor(161.9964)
-286.47902637172245 0.0 -286.47902637172245 0.24913547966201666 tensor(196.4692)
54036.361713042534 54036.361713042534 -288.15589486356055 0.2764031057390767 tensor(228.6135)
-286.6446305768037 0.0 -286.6446305768037 0.17197276365846656 tensor(171.5409)
41086.12627811538 41086.12627811538 -287.68033087824193 0.22889120900044846 tensor(195.3223)
-286.4985733440152 0.0 -286.4985733440152 0.2047595873203937 tensor(362.4746)
22953.703311735724 22953.703311735724 -288.0255232021808 0.03675248769941552 tensor(151.5921)
-286.61566725413377 0.0 -286.61566725413377 0.11286276310703298 tensor(198.3934)
24342.761333641218 24342.761333641218 -288.06345955024

-286.5465946582092 0.0 -286.5465946582092 0.3106887177359616 tensor(151.6154)
37744.75130477442 37744.75130477442 -288.21555902280113 0.24859357724795048 tensor(193.9659)
== Era 9 | Epoch 0 metrics ==
	loss 18132.8
	force 18276.2
	dkl -287.392
	logp 86.0082
	logq -201.384
	ess 0.21213
-286.7994813077259 0.0 -286.7994813077259 0.28261589224075423 tensor(153.6920)
48471.05078074326 48471.05078074326 -288.0052064826268 0.1009070287250021 tensor(219.7139)
-286.85376824051775 0.0 -286.85376824051775 0.33517589586212326 tensor(165.7589)
33746.05028692072 33746.05028692072 -287.9028733455503 0.18191579912698824 tensor(183.1985)
-287.1522697165326 0.0 -287.1522697165326 0.13435406722848278 tensor(180.9764)
34521.203468984444 34521.203468984444 -288.39701317916837 0.3045218436354892 tensor(185.6068)
-286.5751603536142 0.0 -286.5751603536142 0.13374491215843706 tensor(165.2813)
30514.787916131267 30514.787916131267 -288.290567170231 0.2503197911825974 tensor(174.3634)
-286.79294477235817 0.0 -28

41728.80759001081 41728.80759001081 -288.0645146060877 0.24301370879335946 tensor(202.5762)
-286.43408676461956 0.0 -286.43408676461956 0.051264529076139105 tensor(184.1731)
54332.77085542517 54332.77085542517 -288.12421237906005 0.2584622761176537 tensor(231.8702)
-286.38283618355547 0.0 -286.38283618355547 0.11559642931067778 tensor(169.0819)
38544.32349241026 38544.32349241026 -287.92077583986855 0.13768289692728827 tensor(195.3228)
-286.606497860421 0.0 -286.606497860421 0.26143236374224615 tensor(164.2697)
43510.75198994574 43510.75198994574 -288.0788436430789 0.26509352258875724 tensor(206.2704)
-286.52392919422215 0.0 -286.52392919422215 0.26972428818003247 tensor(151.9041)
20824.026510915624 20824.026510915624 -288.14703865292046 0.2693228691417599 tensor(144.3956)
-286.51370545232976 0.0 -286.51370545232976 0.2744329472704612 tensor(144.7935)
27267.345108840553 27267.345108840553 -288.21089977840734 0.3613486120494314 tensor(160.1815)
-286.7340540479189 0.0 -286.7340540479189 

-286.4380275852351 0.0 -286.4380275852351 0.29134181944869486 tensor(195.8758)
28393.78202624868 28393.78202624868 -287.85431138416635 0.3549198184242779 tensor(168.4265)
-286.4305420022208 0.0 -286.4305420022208 0.33681136068207657 tensor(267.7959)
40644.092943764685 40644.092943764685 -288.0413815593246 0.13444518922518958 tensor(200.4770)
-286.5434480800014 0.0 -286.5434480800014 0.22563497877244196 tensor(148.5231)
32051.80340789499 32051.80340789499 -288.41377260816137 0.2061858654964539 tensor(178.8105)
-286.6824022989845 0.0 -286.6824022989845 0.31110680037134364 tensor(170.9833)
35591.507500784006 35591.507500784006 -288.1790112815279 0.2657964766796887 tensor(188.3058)
-286.8379840400337 0.0 -286.8379840400337 0.11233995482930213 tensor(173.9845)
63991.91891154476 63991.91891154476 -288.3369908918387 0.17521890147739627 tensor(249.5865)
Accept rate: 0.265625
Topological susceptibility = 0.84 +/- 0.12
... vs HMC estimate = 1.23 +/- 0.02


In [10]:
def test_force(x = None):
    model = flow_model
    layers, prior = model['layers'], model['prior']
    if x == None:
        pre_model = pre_flow_model
        pre_layers, pre_prior = pre_model['layers'], pre_model['prior']
        pre_xi = pre_prior.sample_n(1)
        x = ft_flow(pre_layers, pre_xi)
    xi = ft_flow_inv(layers, x)
    f = ft_force(param, layers, xi)
    f_s = torch.linalg.norm(f)
    print(f_s)

test_force()
test_force(field_run)

tensor(20.7829)
tensor(24.9007)


In [11]:
field_run = run(param, field_run[0])
field_run = torch.reshape(field_run,(1,)+field_run.shape)

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 2
tau = 2
steps = 8
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.6627686419055754  topo: 0.0
plaq(x) 0.6627686419055754  force.norm 21.091837451367958
Traj:    1  ACCEPT:  dH: -0.49041166   exp(-dH):  1.6329883    plaq:  0.62473448   topo:  2.0
plaq(x) 0.6247344841194027  force.norm 21.8840629070871
Traj:    2  ACCEPT:  dH:  0.66073851   exp(-dH):  0.51646978   plaq:  0.66851193   topo:  0.0
plaq(x) 0.6685119259671379  force.norm 19.583396204466677
Traj:    3  REJECT:  dH:  0.62826715   exp(-dH):  0.5335155    plaq:  0.66851193   topo:  0.0
plaq(x) 0.6685119259671379  force.norm 19.540474594479193
Traj:    4  REJECT:  dH:  0.25176081   exp(-dH):  0.77743067   plaq:  0.66851193   topo:  0.0
plaq(x) 0.6685119259671379  force.norm 18.73962959721593
Traj:    5  ACCEPT:  dH:  0.26645544   exp(-dH):  0.76609014   plaq:  0.73566717   topo:  1.0
plaq(x) 0.735667166271016  force.norm 18.993473339448784
Traj:    6  ACCEPT:  d

In [133]:
flows = flow

print(f'plaq(field_run[0]) {action(param, field_run[0]) / (-param.beta*param.volume)}')
# field.requires_grad_(True)
x = field_run
logJ = 0.0
for layer in reversed(flows):
    x, lJ = layer.reverse(x)
    logJ += lJ

# x is the prior distribution now
    
x.requires_grad_(True)
    
y = x
logJy = 0.0
for layer in flows:
    y, lJ = layer.forward(y)
    logJy += lJ
    
s = action(param, y[0]) - logJy

print(logJ,logJy)


# print("eff_action", s + 136.3786)

print("original_action", action(param, y[0]) + 91)

print("eff_action", s + 56)

s.backward()

f = x.grad

x.requires_grad_(False)

print(f'plaq(x) {action(param, x[0]) / (-param.beta*param.volume)}  logJ {logJ}  force.norm {torch.linalg.norm(f)}')

print(f'plaq(y) {action(param, y[0]) / (-param.beta*param.volume)}')

print(f'plaq(x) {action(param, field_run[0]) / (-param.beta*param.volume)}  force.norm {torch.linalg.norm(force(param, field_run[0]))}')


plaq(field_run[0]) 0.5729559235604118
tensor([18.9736], grad_fn=<AddBackward0>) tensor([-18.9736], grad_fn=<AddBackward0>)
original_action tensor(17.6616, grad_fn=<AddBackward0>)
eff_action tensor([1.6352], grad_fn=<AddBackward0>)
plaq(x) -0.20566029157065324  logJ tensor([18.9736], grad_fn=<AddBackward0>)  force.norm 15.38544434986163
plaq(y) 0.5729560026167126
plaq(x) 0.5729559235604118  force.norm 20.91876473600617


In [13]:
print(x.shape)
x = ft_flow_inv(flow, field_run)
# x = field_run
#for layer in reversed(flows):
#    x, lJ = layer.reverse(x)
ff = ft_force(param, flow, x)
print(torch.linalg.norm(ff))
fff = ft_force(param, flow, x)
print(torch.linalg.norm(fff))

torch.Size([1, 2, 8, 8])
tensor(30.5185)
tensor(30.5185)


In [14]:
x = ft_flow_inv(flow, field_run)
ft_action(param, flow, x)

tensor([-54.6079], grad_fn=<SubBackward0>)

In [110]:
def flattern(l):
    return [x for y in l for x in y]

def average(l):
    return sum(l) / len(l)

def sub_avg(l):
    avg = average(l)
    return np.array([x - avg for x in l])

In [116]:
ft_hmc_info_list = []
def ft_leapfrog(param, flow, x, p):
    mom_norm = torch.sum(p*p)
    info_list = []
    dt = param.dt
    x_ = x + 0.5*dt*p
    f = ft_force(param, flow, x_)
    p_ = p + (-dt)*f
    info = np.array((float(torch.linalg.norm(f)),
                     float(ft_action(param, flow, x_).detach()),
                     float(torch.sum(p*p_)/np.sqrt(mom_norm*torch.sum(p_*p_)))))
    info_list.append(info)
    for i in range(param.nstep-1):
        x_ = x_ + dt*p_
        f = ft_force(param, flow, x_)
        info = np.array((float(torch.linalg.norm(f)),
                        float(ft_action(param, flow, x_).detach()),
                        float(torch.sum(p*p_)/np.sqrt(mom_norm*torch.sum(p_*p_)))))
        info_list.append(info)
        p_ = p_ + (-dt)*f
    x_ = x_ + 0.5*dt*p_
    print(np.sqrt(average([l[0]**2 for l in info_list])),
          (info_list[0][1], info_list[-1][1]),
          info_list[-1][2])
    ft_hmc_info_list.append(info_list)
    return (x_, p_)

def ft_hmc(param, flow, field):
    x = ft_flow_inv(flow, field)
    p = torch.randn_like(x)
    act0 = ft_action(param, flow, x).detach() + 0.5*torch.sum(p*p)
    x_, p_ = ft_leapfrog(param, flow, x, p)
    xr = regularize(x_)
    act = ft_action(param, flow, xr).detach() + 0.5*torch.sum(p_*p_)
    prob = torch.rand([], dtype=torch.float64)
    dH = act-act0
    exp_mdH = torch.exp(-dH)
    acc = prob < exp_mdH
    # ADJUST ME
    newx = xr if acc else x
    # newx = xr
    newfield = ft_flow(flow, newx)
    return (float(dH), float(exp_mdH), acc, newfield)

In [125]:
def ft_run(param, flow, field = None):
    if field == None:
        field = param.initializer()
    ft_hmc_info_list = []
    with open(param.uniquestr(), "w") as O:
        params = param.summary()
        O.write(params)
        put(params)
        plaq, topo = (action(param, field) / (-param.beta*param.volume), topocharge(field))
        status = f"Initial configuration:  plaq: {plaq}  topo: {topo}\n"
        O.write(status)
        put(status)
        ts = []
        for n in range(param.nrun):
            t = -timer()
            for i in range(param.ntraj):
                field_run = torch.reshape(field,(1,)+field.shape)
                dH, exp_mdH, acc, field_run = ft_hmc(param, flow, field_run)
                field = field_run[0]
                plaq = action(param, field) / (-param.beta*param.volume)
                topo = topocharge(field)
                ifacc = "ACCEPT" if acc else "REJECT"
                status = f"Traj: {n*param.ntraj+i+1:4}  {ifacc}:  dH: {dH:< 12.8}  exp(-dH): {exp_mdH:< 12.8}  plaq: {plaq:< 12.8}  topo: {topo:< 3.3}\n"
                O.write(status)
                if (i+1) % (param.ntraj//param.nprint) == 0:
                    put(status)
            t += timer()
            ts.append(t)
        print("Run times: ", ts)
        print("Per trajectory: ", [t/param.ntraj for t in ts])
    return field

In [159]:
param = Param(
    beta = 2.0,
    lat = (8, 8),
    tau = 0.5, # 0.3
    nstep = 64, # 3
    # ADJUST ME
    ntraj = 4, # 2**16 # 2**10 # 2**15
    nprint = 4,
    #
    seed = 1331)

# field = ft_run(param, pre_flow)
field = ft_run(param, pre_flow, field)

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 4
tau = 0.5
steps = 64
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.7629951284651968  topo: -1.0
13.22268982123488 (-53.34181291527922, -54.876554449401496) 0.9655721914591343
Traj:    1  ACCEPT:  dH:  0.0030688103  exp(-dH):  0.99693589   plaq:  0.68926726   topo: -1.0
21.852915596175038 (-54.85201954833743, -52.21010307281895) 0.9534584595204836
Traj:    2  ACCEPT:  dH: -0.079341423  exp(-dH):  1.0825739    plaq:  0.63487933   topo: -2.0
20.332651317477275 (-52.29965736350433, -52.47600352890689) 0.9622497587200374
Traj:    3  ACCEPT:  dH: -0.048030786  exp(-dH):  1.049203     plaq:  0.69466522   topo: -1.0
27.52295044775975 (-52.42782795609963, -53.22876146049952) 0.9555116848770403
Traj:    4  ACCEPT:  dH:  1.252823     exp(-dH):  0.28569713   plaq:  0.70510825   topo:  0.0
21.40765677353727 (-53.19172805413828, -54.22214120009521) 0.9576481950578482
Traj:    5  ACCEPT:  dH: -0.0062113797  exp(-dH):  1.0062307  

In [151]:
param = Param(
    beta = 2.0,
    lat = (8, 8),
    tau = 0.5, # 0.3
    nstep = 64, # 3
    # ADJUST ME
    ntraj = 4, # 2**16 # 2**10 # 2**15
    nprint = 4,
    #
    seed = 1331)

# field = ft_run(param, pre_flow)
field = ft_run(param, pre_flow, field)

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 4
tau = 0.5
steps = 64
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.635313317826161  topo: 0.0
20.78557074806181 (-50.70482292494995, -53.69532464920362) 0.94443367886955
Traj:    1  ACCEPT:  dH: -0.1493157    exp(-dH):  1.1610395    plaq:  0.66644803   topo:  1.0
28.214570380777243 (-53.64952198450136, -54.10207440344347) 0.9293616738812421
Traj:    2  ACCEPT:  dH:  0.096385415  exp(-dH):  0.90811395   plaq:  0.76428727   topo:  1.0
18.847970377810025 (-54.09168794183857, -54.316912234755506) 0.9535372516104789
Traj:    3  ACCEPT:  dH:  0.13176509   exp(-dH):  0.87654688   plaq:  0.74819831   topo:  0.0
17.4003441477011 (-54.33298820802238, -56.12852759824926) 0.9407703828248026
Traj:    4  ACCEPT:  dH: -0.13929567   exp(-dH):  1.1494639    plaq:  0.75147884   topo:  0.0
16.556162048299715 (-56.11146213578058, -54.08786736404058) 0.977222502981708
Traj:    5  ACCEPT:  dH: -0.010062502  exp(-dH):  1.0101133    plaq:

In [157]:
action_list = np.array([l[1] for l in flattern(ft_hmc_info_list)])
action_list = sub_avg(action_list)
np.sqrt(average(action_list**2))

1.3476780398265165

In [158]:
force_list = np.array([l[0] for l in flattern(ft_hmc_info_list)])
np.sqrt(average(force_list**2))

24.652070381082513

In [156]:
print(np.array(force_list[0:300]))

[ 17.32528814  24.41830328  35.86739086  40.38043113  28.76035006
  23.42855236  29.58922539  33.54458561  33.65143712  33.86083941
  33.49291895  31.83862211  27.47076189  24.25968801  22.01808465
  18.63787514  14.72785278  10.78893991   9.04838945  12.19247229
  15.60845521  16.04484037  14.48102181  13.42641911  14.66221209
  17.15532298  17.43832809  12.67743819   6.41088996  12.52344938
  15.78009459  10.51769265   8.67199085  11.21120182  12.65614775
  12.97190556  12.88566952  12.88993966  12.04653558   9.08958067
  17.7111367   14.50481818   7.32864662   8.0942277    7.76810601
   7.34379062   7.18410535   7.13059923   7.06343922   6.96975024
   6.95233258   7.24697015   8.13272086   9.6659413   11.37533083
  12.02592045 431.99648976  17.3099667   18.31988376  18.31030973
  17.84059728  17.20681364  16.7225185   14.38948063   6.68696119
   5.84353042   6.36965709   7.5283624    8.78319064  10.00651765
  11.24919959  12.34706306  12.60897163  13.51418753  24.01165632
  31.29971