In [1]:
# DONE

# FT-HMC implemented for 8x8 2D QED (using SiLU as activation function).

# Try to minimize size of the force in training. No significant improvements.

# Some tests on ergodicity
# (calculate the probability of generating the configs obtained via conventional HMC).

# TODO

# Plot the force size distribution
# Is the large force from the original action or from the Field-Transformation determinant?
# If from the determinant, then the fermion force won't cause problems for HMC

# Use the same Field-Transformation for larger system (say 16x16, 32x32, 64x64, etc)
# Study how the delta H depends on the system size ( perhaps delta H ~ sqrt(volume) )

# Study the auto-correlation for observables, topo, plaq, flowed plaq, etc.

# Improving the Field-Transformation to reduce force.

In [2]:
import math
import os
import sys
from functools import reduce
from timeit import default_timer as timer

import numpy as np
import torch

from field_transformation import *

In [3]:
# From Xiao-Yong

class Param:
    """Container for the lattice and HMC run parameters.

    Attributes set by ``__init__``:
        beta, lat, tau, nstep, ntraj, nrun, nprint, seed, randinit, nth,
        nth_interop: stored as passed in.
        nd: number of lattice dimensions, ``len(lat)``.
        volume: product of the entries of ``lat``.
        dt: leapfrog step size, ``tau / nstep``.
    """
    def __init__(self, beta = 6.0, lat = [64, 64], tau = 2.0, nstep = 50, ntraj = 256, nrun = 4, nprint = 256, seed = 11*13, randinit = False, nth = int(os.environ.get('OMP_NUM_THREADS', '2')), nth_interop = 2):
        self.beta = beta
        self.lat = lat
        self.nd = len(lat)
        self.volume = reduce(lambda x,y:x*y, lat)
        self.tau = tau
        self.nstep = nstep
        self.dt = self.tau / self.nstep
        self.ntraj = ntraj
        self.nrun = nrun
        self.nprint = nprint
        self.seed = seed
        self.randinit = randinit
        self.nth = nth
        self.nth_interop = nth_interop
    def initializer(self):
        """Return an initial field of shape ``(nd, *lat)``.

        The field is uniform in [-pi, pi) when ``randinit`` is true and all
        zeros otherwise.
        """
        # Use this instance's own settings (not a module-level `param`), and
        # convert `lat` to a tuple so that a list-valued `lat` also works.
        shape = (self.nd,) + tuple(self.lat)
        if self.randinit:
            return torch.empty(shape).uniform_(-math.pi, math.pi)
        return torch.zeros(shape)
    def summary(self):
        """Return a multi-line, human-readable listing of the run parameters."""
        return f"""latsize = {self.lat}
volume = {self.volume}
beta = {self.beta}
trajs = {self.ntraj}
tau = {self.tau}
steps = {self.nstep}
seed = {self.seed}
nth = {self.nth}
nth_interop = {self.nth_interop}
"""
    def uniquestr(self):
        """Return the output file name derived from this instance's parameters."""
        lat = ".".join(str(x) for x in self.lat)
        return f"out_l{lat}_b{self.beta}_n{self.ntraj}_t{self.tau}_s{self.nstep}.out"

def action(param, f):
    """Return ``-beta * sum(cos(plaqphase(f)))`` as a 0-dim tensor."""
    cos_plaq = torch.cos(plaqphase(f))
    total = torch.sum(cos_plaq)
    return (-param.beta)*total

def force(param, f):
    """Return the gradient of ``action(param, f)`` with respect to ``f``.

    The gradient is taken on a detached view of ``f``, so the caller's tensor
    is left untouched: its ``requires_grad`` flag and ``.grad`` attribute are
    not modified, and non-leaf inputs are accepted.

    Args:
        param: object providing ``beta`` (forwarded to ``action``).
        f: field tensor.

    Returns:
        Tensor with the same shape as ``f`` holding d(action)/d(f).
    """
    x = f.detach().requires_grad_(True)
    s = action(param, x)
    # autograd.grad returns the gradient directly instead of accumulating it
    # into x.grad.
    ff, = torch.autograd.grad(s, x)
    return ff

def plaqphase(f):
    """Combine the two components ``f[0]`` and ``f[1]`` into the plaquette phase.

    Evaluates ``f[0] - f[1] - roll(f[0], -1, dims=1) + roll(f[1], -1, dims=0)``.
    """
    link0 = f[0,:]
    link1 = f[1,:]
    shifted0 = torch.roll(link0, shifts=-1, dims=1)
    shifted1 = torch.roll(link1, shifts=-1, dims=0)
    return link0 - link1 - shifted0 + shifted1

def topocharge(f):
    """Return ``floor(0.1 + sum(regularize(plaqphase(f))) / (2*pi))``."""
    winding = torch.sum(regularize(plaqphase(f))) / (2*math.pi)
    return torch.floor(0.1 + winding)

def regularize(f):
    """Wrap every element of ``f`` into the interval [-pi, pi)."""
    two_pi = 2*math.pi
    scaled = (f - math.pi) / two_pi
    fractional = scaled - torch.floor(scaled)
    return two_pi*(fractional - 0.5)

def leapfrog(param, x, p):
    """Integrate ``(x, p)`` with ``param.nstep`` leapfrog steps of size ``param.dt``.

    The position gets a half step at the start and at the end; in between,
    momentum and position updates alternate.  The plaquette of the starting
    field and the norm of the first force are printed once per call.

    Returns:
        Tuple ``(x_, p_)`` with the final position and momentum.
    """
    step = param.dt
    # Opening half step in position, then the first full momentum update.
    pos = x + 0.5*step*p
    first_force = force(param, pos)
    mom = p + (-step)*first_force
    print(f'plaq(x) {action(param, x) / (-param.beta*param.volume)}  force.norm {torch.linalg.norm(first_force)}')
    remaining = param.nstep-1
    for _ in range(remaining):
        pos = pos + step*mom
        mom = mom + (-step)*force(param, pos)
    # Closing half step in position.
    pos = pos + 0.5*step*mom
    return (pos, mom)
def hmc(param, x):
    """Perform one HMC trajectory with an accept/reject step.

    Returns:
        Tuple ``(dH, exp_mdH, acc, newx)``: the energy change, ``exp(-dH)``,
        the acceptance flag, and the resulting field (the regularized proposal
        when accepted, otherwise the input ``x``).
    """
    p = torch.randn_like(x)
    h_start = action(param, x) + 0.5*torch.sum(p*p)
    x_end, p_end = leapfrog(param, x, p)
    # Wrap the proposal into [-pi, pi) before evaluating its action.
    proposal = regularize(x_end)
    h_end = action(param, proposal) + 0.5*torch.sum(p_end*p_end)
    threshold = torch.rand([], dtype=torch.float64)
    dH = h_end-h_start
    exp_mdH = torch.exp(-dH)
    acc = threshold < exp_mdH
    if acc:
        return (dH, exp_mdH, acc, proposal)
    return (dH, exp_mdH, acc, x)

def put(s):
    """Write ``s`` to ``sys.stdout`` without adding a newline; returns what ``write`` returns."""
    return sys.stdout.write(s)

In [4]:
# Module-level run configuration.  Besides being passed explicitly to run(),
# the name `param` is also read as a global by train_step().
param = Param(
    beta = 2.0,
    lat = (8, 8),
    tau = 2, # trajectory length (0.3 noted as an alternative)
    nstep = 8, # leapfrog steps per trajectory (3 noted as an alternative)
    # ADJUST ME
    ntraj = 2, # trajectories per run (larger values noted: 2**16, 2**10, 2**15)
    #
    nprint = 2,
    seed = 1331)

In [5]:
# Seed torch's global RNG from the run configuration.
torch.manual_seed(param.seed)

# Thread counts for intra-op and inter-op parallelism.
torch.set_num_threads(param.nth)
torch.set_num_interop_threads(param.nth_interop)
# NOTE(review): these variables are set after torch was imported; whether the
# OpenMP runtime still picks them up at this point has not been verified here.
os.environ["OMP_NUM_THREADS"] = str(param.nth)
os.environ["KMP_BLOCKTIME"] = "0"
os.environ["KMP_SETTINGS"] = "1"
os.environ["KMP_AFFINITY"]= "granularity=fine,verbose,compact,1,0"

# Make float64 the default floating-point dtype.  set_default_dtype replaces
# the deprecated torch.set_default_tensor_type(torch.DoubleTensor).
torch.set_default_dtype(torch.float64)

In [103]:
def run(param, field = None):
    """Run conventional HMC and log every trajectory.

    Performs ``param.nrun`` runs of ``param.ntraj`` trajectories each.  Every
    status line is written to the file named by ``param.uniquestr()``; a
    subset of them is also echoed to stdout.

    Args:
        param: run configuration (see ``Param``).
        field: starting field; ``param.initializer()`` is used when ``None``.

    Returns:
        The field after the last trajectory.
    """
    if field is None:
        field = param.initializer()
    # Echo one status line every `print_every` trajectories.  Clamping to 1
    # keeps the modulus below from dividing by zero when nprint > ntraj.
    print_every = max(1, param.ntraj // param.nprint)
    with open(param.uniquestr(), "w") as O:
        params = param.summary()
        O.write(params)
        put(params)
        plaq, topo = (action(param, field) / (-param.beta*param.volume), topocharge(field))
        status = f"Initial configuration:  plaq: {plaq}  topo: {topo}\n"
        O.write(status)
        put(status)
        ts = []
        for n in range(param.nrun):
            # Wall-clock time of this run: start with -t0, add t1 afterwards.
            t = -timer()
            for i in range(param.ntraj):
                dH, exp_mdH, acc, field = hmc(param, field)
                plaq = action(param, field) / (-param.beta*param.volume)
                topo = topocharge(field)
                ifacc = "ACCEPT" if acc else "REJECT"
                status = f"Traj: {n*param.ntraj+i+1:4}  {ifacc}:  dH: {dH:< 12.8}  exp(-dH): {exp_mdH:< 12.8}  plaq: {plaq:< 12.8}  topo: {topo:< 3.3}\n"
                O.write(status)
                if (i+1) % print_every == 0:
                    put(status)
            t += timer()
            ts.append(t)
        print("Run times: ", ts)
        print("Per trajectory: ", [t/param.ntraj for t in ts])
    return field
field = run(param)

latsize = (12, 12)
volume = 144
beta = 2.0
trajs = 4
tau = 0.5
steps = 64
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 1.0  topo: 0.0
plaq(x) 1.0  force.norm 0.3902628323006594
Traj:    1  ACCEPT:  dH: -0.0061020683  exp(-dH):  1.0061207    plaq:  0.80160241   topo:  0.0
plaq(x) 0.8016024084919849  force.norm 26.353286360548907
Traj:    2  ACCEPT:  dH: -0.00098832182  exp(-dH):  1.0009888    plaq:  0.75800764   topo:  0.0
plaq(x) 0.7580076383603013  force.norm 28.28470315611182
Traj:    3  ACCEPT:  dH: -0.00045810993  exp(-dH):  1.0004582    plaq:  0.74109523   topo:  0.0
plaq(x) 0.7410952313494887  force.norm 28.779010749455527
Traj:    4  ACCEPT:  dH:  0.00021325972  exp(-dH):  0.99978676   plaq:  0.74094164   topo:  0.0
plaq(x) 0.7409416399750001  force.norm 28.464612960452314
Traj:    5  ACCEPT:  dH: -0.00017601601  exp(-dH):  1.000176     plaq:  0.72011859   topo:  0.0
plaq(x) 0.7201185944633461  force.norm 27.835833225543123
Traj:    6  ACCEPT:  dH: -0.000271

In [7]:
def ft_flow(flow, f):
    """Push ``f`` through every layer's ``forward`` in order; return the detached result.

    The per-layer log-Jacobian returned by ``forward`` is discarded.
    """
    out = f
    for layer in flow:
        out, _ = layer.forward(out)
    return out.detach()

def ft_flow_inv(flow, f):
    """Apply every layer's ``reverse`` starting from the last layer; return the detached result.

    The per-layer log-Jacobian returned by ``reverse`` is discarded.
    """
    out = f
    for layer in reversed(flow):
        out, _ = layer.reverse(out)
    return out.detach()

def ft_action(param, flow, f):
    """Return the effective action of ``f`` under the field transformation.

    ``f`` is pushed forward through ``flow`` while the log-Jacobians returned
    by the layers are summed; the result is
    ``U1GaugeAction(param.beta)(y) - sum(logJ)``, with ``y`` the transformed
    field.  Nothing is detached here, so the result stays differentiable.
    """
    total_logJ = 0.0
    y = f
    for layer in flow:
        y, layer_logJ = layer.forward(y)
        total_logJ = total_logJ + layer_logJ
    gauge_action = U1GaugeAction(param.beta)
    return gauge_action(y) - total_logJ

def ft_force(param, flow, field, create_graph = False):
    """Return the gradient of the summed ``ft_action`` with respect to ``field``.

    ``field`` lives in the transformed variables (the input side of ``flow``).
    Its ``requires_grad`` flag is switched on for the computation and switched
    off again before returning.  Pass ``create_graph=True`` to make the
    returned gradient itself differentiable.
    """
    field.requires_grad_(True)
    total = torch.sum(ft_action(param, flow, field))
    grad, = torch.autograd.grad(total, field, create_graph = create_graph)
    field.requires_grad_(False)
    return grad

In [8]:
def train_step(model, action, optimizer, metrics, batch_size, with_force = False, pre_model = None):
    """Run one optimizer step on the flow and record diagnostics.

    Args:
        model: dict with keys ``'layers'`` and ``'prior'``.
        action: callable evaluated on the flowed samples; ``logp = -action(x)``.
        optimizer: optimizer over the parameters of ``model['layers']``.
        metrics: dict of lists with keys ``'loss'``, ``'force'``, ``'dkl'``,
            ``'logp'``, ``'logq'``, ``'ess'``; one entry is appended to each.
        batch_size: number of samples drawn.
        with_force: when true, the loss is the summed squared force instead of
            the KL estimate; requires ``pre_model``.
        pre_model: optional dict with the same layout as ``model``.  When
            given, its samples are flowed forward and then pulled back through
            ``model['layers']`` to obtain the starting points ``xi``.

    Note:
        The force evaluations read the module-level ``param`` (for ``beta``);
        it is not an argument of this function.
    """
    layers, prior = model['layers'], model['prior']
    optimizer.zero_grad()
    #
    xi = None
    if pre_model is not None:
        pre_layers, pre_prior = pre_model['layers'], pre_model['prior']
        pre_xi = pre_prior.sample_n(batch_size)
        x = ft_flow(pre_layers, pre_xi)
        xi = ft_flow_inv(layers, x)
    #
    xi, x, logq = apply_flow_to_prior(prior, layers, batch_size=batch_size, xi=xi)
    logp = -action(x)
    #
    # force_size stays 0 in the KL branch so that it can still be logged.
    force_size = torch.tensor(0.0)
    dkl = calc_dkl(logp, logq)
    if with_force:
        assert pre_model is not None
        # create_graph=True keeps the force differentiable w.r.t. the weights.
        force = ft_force(param, layers, xi, True)
        force_size = torch.sum(torch.square(force))
        loss = force_size
    else:
        loss = dkl
    #
    loss.backward()
    #
    # minimization target
    # loss mini
    # -> (logq - logp) mini
    # -> (action - logJ) mini
    #
    optimizer.step()
    ess = compute_ess(logp, logq)
    #
    # The last column re-evaluates the force norm after the optimizer step.
    print(grab(loss),
          grab(force_size),
          grab(dkl),
          grab(ess),
          torch.linalg.norm(ft_force(param, layers, xi)))
    #
    metrics['loss'].append(grab(loss))
    metrics['force'].append(grab(force_size))
    metrics['dkl'].append(grab(dkl))
    metrics['logp'].append(grab(logp))
    metrics['logq'].append(grab(logq))
    metrics['ess'].append(grab(ess))

def flow_train(param, with_force = False, pre_model = None):  # packaged from original ipynb by Xiao-Yong Jin
    """Build a flow model for ``param.lat`` / ``param.beta`` and train it.

    Every epoch performs one KL-loss ``train_step``; when ``with_force`` is
    true it is followed by a force-loss ``train_step`` that uses a second
    optimizer with a 100x smaller learning rate.

    Returns:
        Tuple ``(model, u1_action)`` where ``model`` is a dict with keys
        ``'layers'`` and ``'prior'``.
    """
    # --- theory ---
    lattice_shape = param.lat
    link_shape = (2,*param.lat)
    u1_action = U1GaugeAction(param.beta)
    # --- model ---
    prior = MultivariateUniform(torch.zeros(link_shape), 2*np.pi*torch.ones(link_shape))
    layers = make_u1_equiv_layers(lattice_shape=lattice_shape, n_layers=24, n_mixture_comps=2,
                                  hidden_sizes=[8,8], kernel_size=3)
    set_weights(layers)
    model = {'layers': layers, 'prior': prior}
    # --- optimizers (both act on the same parameters) ---
    base_lr = .001
    optimizer = torch.optim.Adam(model['layers'].parameters(), lr=base_lr)
    optimizer_wf = torch.optim.Adam(model['layers'].parameters(), lr=base_lr / 100.0)
    # --- schedule ---
    # ADJUST ME
    N_era = 10
    N_epoch = 100
    batch_size = 64
    # Equal to N_epoch, so metrics are printed only at epoch 0 of each era.
    print_freq = N_epoch
    history = {key: [] for key in ('loss', 'force', 'dkl', 'logp', 'logq', 'ess')}
    for era in range(N_era):
        for epoch in range(N_epoch):
            train_step(model, u1_action, optimizer, history, batch_size)
            if with_force:
                train_step(model, u1_action, optimizer_wf, history, batch_size,
                           with_force = with_force, pre_model = pre_model)
            if epoch % print_freq == 0:
                print_metrics(history, print_freq, era, epoch)
    return model,u1_action

def flow_eval(model, u1_action):  # packaged from original ipynb by Xiao-Yong Jin
    """Generate an ensemble with ``make_mcmc_ensemble`` and print summary statistics.

    Prints the mean of the ``'accepted'`` flags and a bootstrap estimate of
    the mean of ``Q**2`` next to a hard-coded HMC reference value.
    """
    ensemble_size = 1024
    ensemble = make_mcmc_ensemble(model, u1_action, 64, ensemble_size)
    accept_rate = np.mean(ensemble['accepted'])
    print("Accept rate:", accept_rate)
    configs = torch.stack(ensemble['x'], axis=0)
    Q = grab(topo_charge(configs))
    chi_mean, chi_err = bootstrap(Q**2, Nboot=100, binsize=16)
    print(f'Topological susceptibility = {chi_mean:.2f} +/- {chi_err:.2f}')
    print('... vs HMC estimate = 1.23 +/- 0.02')

In [9]:
pre_flow_model, flow_act = flow_train(param)
flow_eval(pre_flow_model,flow_act)
pre_flow = pre_flow_model['layers']

-235.80379178382825 0.0 -235.80379178382825 0.015742986900554163 tensor(173.9276)
== Era 0 | Epoch 0 metrics ==
	loss -235.804
	force 0
	dkl -235.804
	logp 0.693854
	logq -235.11
	ess 0.015743
-240.22592204900425 0.0 -240.22592204900425 0.01726504117436026 tensor(169.3064)
-241.73488535507386 0.0 -241.73488535507386 0.016400911777968756 tensor(167.0606)
-247.46169379805445 0.0 -247.46169379805445 0.015810153849540865 tensor(160.3539)
-253.25240939507336 0.0 -253.25240939507336 0.01941417680419557 tensor(158.7390)
-256.7516621657513 0.0 -256.7516621657513 0.015804690643426364 tensor(154.5671)
-257.21331835493976 0.0 -257.21331835493976 0.01562887239859159 tensor(157.1432)
-261.58546505634735 0.0 -261.58546505634735 0.020297799438260253 tensor(151.7423)
-265.36695333901866 0.0 -265.36695333901866 0.0156412599546812 tensor(151.7542)
-266.2862499543668 0.0 -266.2862499543668 0.016146700456890064 tensor(155.4196)
-270.4000967148002 0.0 -270.4000967148002 0.01831122343147412 tensor(156.1291)

-282.30318004954523 0.0 -282.30318004954523 0.03519414026309184 tensor(173.6366)
-281.71068933998566 0.0 -281.71068933998566 0.044429059801842655 tensor(189.9906)
-282.02463668393204 0.0 -282.02463668393204 0.028696246313953674 tensor(171.9115)
-283.132869239756 0.0 -283.132869239756 0.019255968703855045 tensor(165.2447)
-282.0692816718184 0.0 -282.0692816718184 0.023057457037706616 tensor(170.5575)
-282.6374101397821 0.0 -282.6374101397821 0.03463303707863622 tensor(174.0301)
-282.6942481521188 0.0 -282.6942481521188 0.07775488816440008 tensor(182.1365)
-282.8336707077077 0.0 -282.8336707077077 0.11739860957901044 tensor(173.4950)
-282.4292215109003 0.0 -282.4292215109003 0.04353524233107781 tensor(174.8651)
-282.6392043268243 0.0 -282.6392043268243 0.057467884980354454 tensor(174.7497)
-282.52266144131096 0.0 -282.52266144131096 0.1869519973917699 tensor(175.9306)
-282.3960594537949 0.0 -282.3960594537949 0.10639430865819953 tensor(175.4130)
-283.18745889396143 0.0 -283.1874588939614

-285.4745453428999 0.0 -285.4745453428999 0.08551370972478504 tensor(159.2134)
-285.74109108264855 0.0 -285.74109108264855 0.09694266245844649 tensor(166.2664)
-285.18165772424607 0.0 -285.18165772424607 0.18581720592138307 tensor(160.6161)
-284.9456329769272 0.0 -284.9456329769272 0.037887675923891334 tensor(165.5799)
-285.13679041170724 0.0 -285.13679041170724 0.0709500809313619 tensor(161.8093)
-284.8899231454957 0.0 -284.8899231454957 0.1292122816474257 tensor(175.4481)
-285.48417859246416 0.0 -285.48417859246416 0.09883028991631716 tensor(170.3434)
-285.2071173702128 0.0 -285.2071173702128 0.16868826186902278 tensor(159.7044)
-285.5761087567977 0.0 -285.5761087567977 0.10562976216060008 tensor(153.3455)
-284.95738302568657 0.0 -284.95738302568657 0.1613954717331739 tensor(166.5366)
-285.52831036667357 0.0 -285.52831036667357 0.06433101101357216 tensor(157.1581)
-285.53802653925254 0.0 -285.53802653925254 0.08475295738354054 tensor(168.5715)
-285.5215177825453 0.0 -285.521517782545

-285.21536683697167 0.0 -285.21536683697167 0.2189449581050638 tensor(171.8966)
-285.48660306425217 0.0 -285.48660306425217 0.11321807829973943 tensor(187.9403)
-285.5833949046988 0.0 -285.5833949046988 0.06373153825501945 tensor(158.9213)
-285.48727051687985 0.0 -285.48727051687985 0.08399532584933865 tensor(182.5070)
-285.54272284288584 0.0 -285.54272284288584 0.049358031771688855 tensor(175.0742)
-285.1236794662193 0.0 -285.1236794662193 0.1780201144879282 tensor(176.3286)
-285.84530449928945 0.0 -285.84530449928945 0.1956267839023526 tensor(155.9885)
-286.1826945515368 0.0 -286.1826945515368 0.20304134682894795 tensor(165.7703)
-285.68631223265453 0.0 -285.68631223265453 0.2116240683831535 tensor(173.1717)
-285.77136434638396 0.0 -285.77136434638396 0.23513026087984623 tensor(168.5351)
-285.317134575164 0.0 -285.317134575164 0.09205607013542556 tensor(185.0386)
-285.3653766845781 0.0 -285.3653766845781 0.07107531753230847 tensor(166.2538)
-285.29057236143143 0.0 -285.29057236143143

-286.2298298684607 0.0 -286.2298298684607 0.20246256551853137 tensor(171.2551)
-286.18202929008567 0.0 -286.18202929008567 0.15544772488882455 tensor(171.7185)
-286.30195926420214 0.0 -286.30195926420214 0.20824601214829314 tensor(164.9257)
-286.8255649181762 0.0 -286.8255649181762 0.29026847375326403 tensor(173.1734)
-286.41775810069305 0.0 -286.41775810069305 0.10261035785947044 tensor(182.5192)
-286.356919937346 0.0 -286.356919937346 0.057770890006515106 tensor(164.2266)
-286.17847129446477 0.0 -286.17847129446477 0.09900306937627477 tensor(172.7846)
-286.4338164251212 0.0 -286.4338164251212 0.0916416689717375 tensor(175.3758)
-286.51255808334076 0.0 -286.51255808334076 0.08394243379648351 tensor(184.6040)
-286.49338604286265 0.0 -286.49338604286265 0.15191005068858007 tensor(171.0245)
-286.1197153753396 0.0 -286.1197153753396 0.0958604712796014 tensor(221.3763)
-286.27289834545667 0.0 -286.27289834545667 0.1455096301059904 tensor(202.0609)
-286.68317381021313 0.0 -286.6831738102131

-286.1030487674826 0.0 -286.1030487674826 0.14757232188571232 tensor(177.4414)
-286.79881426928785 0.0 -286.79881426928785 0.19925762210835296 tensor(175.7745)
-286.8895556518922 0.0 -286.8895556518922 0.20663419603924182 tensor(212.8828)
-286.73436938298016 0.0 -286.73436938298016 0.2396118599507676 tensor(198.8368)
-286.3219745816881 0.0 -286.3219745816881 0.1414377985216021 tensor(222.2222)
-286.2504391436577 0.0 -286.2504391436577 0.14604459521012364 tensor(155.7067)
-286.80813812640076 0.0 -286.80813812640076 0.2522846437389289 tensor(207.6366)
-286.7118538872095 0.0 -286.7118538872095 0.10008081763760789 tensor(187.6564)
-286.7248982663979 0.0 -286.7248982663979 0.24895956847581732 tensor(175.3379)
-286.25743717191267 0.0 -286.25743717191267 0.2283220077208355 tensor(182.6109)
-286.70812550199093 0.0 -286.70812550199093 0.12354206327816915 tensor(155.5527)
-287.02429185487904 0.0 -287.02429185487904 0.2074911297505572 tensor(183.9654)
-286.7605332923435 0.0 -286.7605332923435 0.1

-286.37424945966757 0.0 -286.37424945966757 0.27631506150752144 tensor(187.5516)
-286.87764579615214 0.0 -286.87764579615214 0.33560314012720543 tensor(196.6000)
-286.6968894805847 0.0 -286.6968894805847 0.18031224362893333 tensor(187.4604)
-286.5287983371965 0.0 -286.5287983371965 0.20499660178731807 tensor(196.2932)
-286.2364603642544 0.0 -286.2364603642544 0.11616869579152166 tensor(197.9263)
-286.1971812865029 0.0 -286.1971812865029 0.11281678713560853 tensor(253.9944)
-287.031403529598 0.0 -287.031403529598 0.08090413860815306 tensor(195.7730)
-286.32109874125695 0.0 -286.32109874125695 0.2542890375275158 tensor(172.7984)
-286.5775680322279 0.0 -286.5775680322279 0.2325613031681031 tensor(189.0569)
-286.5379995265555 0.0 -286.5379995265555 0.059621040490759954 tensor(224.7545)
-286.85048280028434 0.0 -286.85048280028434 0.07016776209457493 tensor(265.4737)
-286.8202146132348 0.0 -286.8202146132348 0.25341311544609363 tensor(286.0131)
-286.6823554320838 0.0 -286.6823554320838 0.180

-286.829989939537 0.0 -286.829989939537 0.2897993723674256 tensor(158.2422)
-286.55234764676885 0.0 -286.55234764676885 0.22451236753008907 tensor(182.5948)
-286.9383151640146 0.0 -286.9383151640146 0.14170541945339032 tensor(295.1155)
-286.9684003038342 0.0 -286.9684003038342 0.17478064276434813 tensor(227.1028)
-287.10364243493234 0.0 -287.10364243493234 0.32195787571074136 tensor(218.9781)
-286.3360277810375 0.0 -286.3360277810375 0.24192631806098625 tensor(239.8895)
-286.43053024083815 0.0 -286.43053024083815 0.2601425447803293 tensor(164.5041)
-287.04702616368945 0.0 -287.04702616368945 0.1019783067028248 tensor(257.8064)
-286.6678767203065 0.0 -286.6678767203065 0.23081487720553823 tensor(212.7097)
-286.8493991942497 0.0 -286.8493991942497 0.18800788629736237 tensor(256.1802)
-286.67042896794663 0.0 -286.67042896794663 0.3784958505739057 tensor(181.2171)
-287.34412030840406 0.0 -287.34412030840406 0.33425404121055735 tensor(247.9975)
-286.9731702036803 0.0 -286.9731702036803 0.09

-287.1013126907318 0.0 -287.1013126907318 0.21329969181483255 tensor(193.4902)
-286.75604462110755 0.0 -286.75604462110755 0.2194159025806836 tensor(201.2356)
-287.0479481087499 0.0 -287.0479481087499 0.23300483909460953 tensor(223.6758)
-286.7347007892731 0.0 -286.7347007892731 0.37456966217586135 tensor(215.4751)
-286.59911385224126 0.0 -286.59911385224126 0.08058265640562165 tensor(185.7359)
-286.9823837360541 0.0 -286.9823837360541 0.35942636246267073 tensor(177.2091)
-287.4609491473624 0.0 -287.4609491473624 0.3299042787363679 tensor(167.8750)
-286.5736689158928 0.0 -286.5736689158928 0.34961766350318535 tensor(295.1423)
-286.8294812932953 0.0 -286.8294812932953 0.30271053184380814 tensor(179.9932)
-287.1030216724084 0.0 -287.1030216724084 0.3387798203593732 tensor(164.6811)
-286.815283436428 0.0 -286.815283436428 0.31125346821440225 tensor(171.8524)
-286.8567798727716 0.0 -286.8567798727716 0.053799360710257295 tensor(168.7261)
-286.7602947419879 0.0 -286.7602947419879 0.19893814

-287.1974354358955 0.0 -287.1974354358955 0.24741621583893092 tensor(414.6770)
-286.9126605344217 0.0 -286.9126605344217 0.19326177991835528 tensor(197.4445)
-287.1525202454744 0.0 -287.1525202454744 0.15849238446212519 tensor(180.5447)
-286.99595224483573 0.0 -286.99595224483573 0.13534719086020094 tensor(221.3385)
-287.0839793721429 0.0 -287.0839793721429 0.2945747959313512 tensor(205.0207)
-286.92071639395357 0.0 -286.92071639395357 0.24238989951667803 tensor(198.5116)
-286.9655107787159 0.0 -286.9655107787159 0.3045876064012117 tensor(302.7847)
-287.01217130365694 0.0 -287.01217130365694 0.21326813705089678 tensor(141.8040)
-286.72580910681984 0.0 -286.72580910681984 0.25328739579993514 tensor(171.2045)
-286.8274692717415 0.0 -286.8274692717415 0.3325019587180933 tensor(198.5363)
-286.9714275810692 0.0 -286.9714275810692 0.22347510584246394 tensor(280.2572)
-286.91179021224457 0.0 -286.91179021224457 0.22762529192807113 tensor(431.2482)
-286.9297917601647 0.0 -286.9297917601647 0.1

In [12]:
# Choose the flow used by the cells below: either retrain with the extra
# force loss, or reuse the model trained with the KL loss only.
train_force = False
flow_model = None
if train_force:
    flow_model, flow_act = flow_train(param, with_force=True, pre_model=pre_flow_model)
else:
    # flow_act is not reassigned in this branch; the value produced together
    # with pre_flow_model is reused.
    flow_model = pre_flow_model
flow_eval(flow_model,flow_act)
flow = flow_model['layers']
# flow.eval()

Accept rate: 0.3291015625
Topological susceptibility = 1.15 +/- 0.10
... vs HMC estimate = 1.23 +/- 0.02


In [13]:
def test_force(x = None):
    """Print the norm of the transformed-field force for one configuration.

    Reads the module-level ``flow_model``, ``pre_flow_model`` and ``param``.

    Args:
        x: batched configuration; when ``None``, one sample is drawn from the
            prior of ``pre_flow_model`` and pushed through its layers.
    """
    model = flow_model
    layers, prior = model['layers'], model['prior']
    # Identity check: `x == None` is not a reliable test for a tensor argument.
    if x is None:
        pre_model = pre_flow_model
        pre_layers, pre_prior = pre_model['layers'], pre_model['prior']
        pre_xi = pre_prior.sample_n(1)
        x = ft_flow(pre_layers, pre_xi)
    # Pull x back to the transformed variables, where the force is evaluated.
    xi = ft_flow_inv(layers, x)
    f = ft_force(param, layers, xi)
    f_s = torch.linalg.norm(f)
    print(f_s)

test_force()
test_force(torch.reshape(field,(1,)+field.shape))

tensor(16.7122)
tensor(11.7961)


In [104]:
field = run(param, field)

latsize = (12, 12)
volume = 144
beta = 2.0
trajs = 4
tau = 0.5
steps = 64
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.6436154345146917  topo: -1.0
plaq(x) 0.6436154345146917  force.norm 28.595181015284275
Traj:    1  ACCEPT:  dH:  0.00048705497  exp(-dH):  0.99951306   plaq:  0.66561836   topo: -2.0
plaq(x) 0.6656183553517631  force.norm 28.11460650971445
Traj:    2  ACCEPT:  dH: -0.0012323184  exp(-dH):  1.0012331    plaq:  0.62357274   topo: -1.0
plaq(x) 0.6235727353424116  force.norm 30.429344370433586
Traj:    3  ACCEPT:  dH:  0.00029376435  exp(-dH):  0.99970628   plaq:  0.59412058   topo:  0.0
plaq(x) 0.5941205758120947  force.norm 28.538774002763446
Traj:    4  ACCEPT:  dH: -0.00022747014  exp(-dH):  1.0002275    plaq:  0.62219674   topo:  1.0
plaq(x) 0.6221967424532626  force.norm 29.916468389422278
Traj:    5  ACCEPT:  dH:  0.0015577067  exp(-dH):  0.99844351   plaq:  0.69874304   topo:  0.0
plaq(x) 0.6987430410613997  force.norm 27.98159411963013
Traj:

In [15]:
# Manual cross-check of the transformed-field force: pull field_run back
# through the flow, push it forward again, and differentiate
# action(y) - logJ with respect to the pulled-back field.
# NOTE(review): `field_run` is not assigned at top level in any visible cell
# (it only appears as a local inside ft_run), so this cell relies on leftover
# kernel state.
flows = flow

print(f'plaq(field_run[0]) {action(param, field_run[0]) / (-param.beta*param.volume)}')
# field.requires_grad_(True)
x = field_run
logJ = 0.0
# Reverse pass: accumulate the log-Jacobians reported by layer.reverse.
for layer in reversed(flows):
    x, lJ = layer.reverse(x)
    logJ += lJ

# x is the prior distribution now
    
x.requires_grad_(True)
    
y = x
logJy = 0.0
# Forward pass: accumulate the log-Jacobians reported by layer.forward.
for layer in flows:
    y, lJ = layer.forward(y)
    logJy += lJ
    
# Effective action in the transformed variables.
s = action(param, y[0]) - logJy

print(logJ,logJy)


# print("eff_action", s + 136.3786)

# NOTE(review): the offsets 91 and 56 below are ad-hoc shifts for display;
# their origin is not evident from this notebook.
print("original_action", action(param, y[0]) + 91)

print("eff_action", s + 56)

s.backward()

f = x.grad

x.requires_grad_(False)

print(f'plaq(x) {action(param, x[0]) / (-param.beta*param.volume)}  logJ {logJ}  force.norm {torch.linalg.norm(f)}')

print(f'plaq(y) {action(param, y[0]) / (-param.beta*param.volume)}')

# Force of the untransformed action on the same configuration, for comparison.
print(f'plaq(x) {action(param, field_run[0]) / (-param.beta*param.volume)}  force.norm {torch.linalg.norm(force(param, field_run[0]))}')


plaq(field_run[0]) 0.7302893102810509
tensor([40.8650], grad_fn=<AddBackward0>) tensor([-40.8650], grad_fn=<AddBackward0>)
original_action tensor(-2.4770, grad_fn=<AddBackward0>)
eff_action tensor([3.3880], grad_fn=<AddBackward0>)
plaq(x) 0.03995800269130618  logJ tensor([40.8650], grad_fn=<AddBackward0>)  force.norm 7.234809490104591
plaq(y) 0.7302892218853114
plaq(x) 0.7302893102810509  force.norm 19.320145488334514


In [16]:
# Repeat the force evaluation through the ft_flow_inv / ft_force helpers.
# NOTE(review): relies on `x` and `field_run` left over from earlier cells.
print(x.shape)
x = ft_flow_inv(flow, field_run)
# x = field_run
#for layer in reversed(flows):
#    x, lJ = layer.reverse(x)
ff = ft_force(param, flow, x)
print(torch.linalg.norm(ff))
# Second evaluation on the same input, printed for comparison with the first.
fff = ft_force(param, flow, x)
print(torch.linalg.norm(fff))

torch.Size([1, 2, 8, 8])
tensor(7.2348)
tensor(7.2348)


In [17]:
x = ft_flow_inv(flow, field_run)
ft_action(param, flow, x)

tensor([-52.6120], grad_fn=<SubBackward0>)

In [18]:
def flattern(l):
    """Concatenate the items of every inner iterable of ``l`` into one flat list."""
    flat = []
    for inner in l:
        flat.extend(inner)
    return flat

def average(l):
    """Return the arithmetic mean of ``l`` (``sum(l) / len(l)``)."""
    total = sum(l)
    return total / len(l)

def sub_avg(l):
    """Return a NumPy array holding the items of ``l`` with their mean subtracted."""
    mean = average(l)
    centered = [value - mean for value in l]
    return np.array(centered)

In [19]:
# Module-level log: ft_leapfrog appends one list of per-step records per call.
ft_hmc_info_list = []

def _ft_step_info(param, flow, x_, f, p, p_, mom_norm):
    """Build one diagnostics record: (|force|, ft_action(x_), cos angle between p and p_)."""
    # torch.sqrt keeps the whole expression in torch instead of routing a
    # tensor through np.sqrt.
    cos_pp = torch.sum(p*p_) / torch.sqrt(mom_norm*torch.sum(p_*p_))
    return np.array((float(torch.linalg.norm(f)),
                     float(ft_action(param, flow, x_).detach()),
                     float(cos_pp)))

def ft_leapfrog(param, flow, x, p):
    """Leapfrog integration in the transformed variables, with diagnostics.

    Uses ``param.nstep`` steps of size ``param.dt``, with a half step in
    position at the start and at the end.  One record per force evaluation is
    collected, a one-line summary is printed, and the list of records is
    appended to the module-level ``ft_hmc_info_list``.

    Returns:
        Tuple ``(x_, p_)`` with the final position and momentum.
    """
    mom_norm = torch.sum(p*p)
    info_list = []
    dt = param.dt
    x_ = x + 0.5*dt*p
    f = ft_force(param, flow, x_)
    p_ = p + (-dt)*f
    # The first record is taken after the momentum update ...
    info_list.append(_ft_step_info(param, flow, x_, f, p, p_, mom_norm))
    for i in range(param.nstep-1):
        x_ = x_ + dt*p_
        f = ft_force(param, flow, x_)
        # ... while the records in the loop are taken before it.
        info_list.append(_ft_step_info(param, flow, x_, f, p, p_, mom_norm))
        p_ = p_ + (-dt)*f
    x_ = x_ + 0.5*dt*p_
    # Summary: RMS force norm, (first, last) effective action, last cosine.
    print(np.sqrt(average([l[0]**2 for l in info_list])),
          (info_list[0][1], info_list[-1][1]),
          info_list[-1][2])
    ft_hmc_info_list.append(info_list)
    return (x_, p_)

def ft_hmc(param, flow, field):
    """Perform one field-transformation HMC trajectory.

    ``field`` is pulled back through ``flow``, evolved with ``ft_leapfrog``
    under the effective action, accepted or rejected, and the resulting point
    is pushed forward through ``flow`` again.

    Returns:
        Tuple ``(dH, exp_mdH, acc, newfield)`` with ``dH`` and ``exp_mdH`` as
        Python floats.
    """
    x = ft_flow_inv(flow, field)
    p = torch.randn_like(x)
    h_start = ft_action(param, flow, x).detach() + 0.5*torch.sum(p*p)
    x_end, p_end = ft_leapfrog(param, flow, x, p)
    # Wrap the proposal into [-pi, pi) before evaluating its action.
    proposal = regularize(x_end)
    h_end = ft_action(param, flow, proposal).detach() + 0.5*torch.sum(p_end*p_end)
    threshold = torch.rand([], dtype=torch.float64)
    dH = h_end-h_start
    exp_mdH = torch.exp(-dH)
    acc = threshold < exp_mdH
    # ADJUST ME: use `proposal` unconditionally to switch off the reject step.
    if acc:
        kept = proposal
    else:
        kept = x
    newfield = ft_flow(flow, kept)
    return (float(dH), float(exp_mdH), acc, newfield)

In [20]:
def ft_run(param, flow, field = None):
    """Run field-transformation HMC and log every trajectory.

    Performs ``param.nrun`` runs of ``param.ntraj`` trajectories each.  Every
    status line is written to the file named by ``param.uniquestr()``; a
    subset of them is also echoed to stdout.

    Args:
        param: run configuration (see ``Param``).
        flow: sequence of flow layers handed to ``ft_hmc``.
        field: unbatched starting field; ``param.initializer()`` is used when
            ``None``.

    Returns:
        The unbatched field after the last trajectory.

    Note:
        ``ft_leapfrog`` appends its diagnostics to the module-level
        ``ft_hmc_info_list``; this function does not reset that list, so it
        keeps growing across calls.
    """
    # Identity check: `field == None` is not a reliable test for a tensor.
    if field is None:
        field = param.initializer()
    # Echo one status line every `print_every` trajectories.  Clamping to 1
    # keeps the modulus below from dividing by zero when nprint > ntraj.
    print_every = max(1, param.ntraj // param.nprint)
    with open(param.uniquestr(), "w") as O:
        params = param.summary()
        O.write(params)
        put(params)
        plaq, topo = (action(param, field) / (-param.beta*param.volume), topocharge(field))
        status = f"Initial configuration:  plaq: {plaq}  topo: {topo}\n"
        O.write(status)
        put(status)
        ts = []
        for n in range(param.nrun):
            # Wall-clock time of this run: start with -t0, add t1 afterwards.
            t = -timer()
            for i in range(param.ntraj):
                # ft_hmc works on a batch; add and strip a leading batch axis.
                field_run = torch.reshape(field,(1,)+field.shape)
                dH, exp_mdH, acc, field_run = ft_hmc(param, flow, field_run)
                field = field_run[0]
                plaq = action(param, field) / (-param.beta*param.volume)
                topo = topocharge(field)
                ifacc = "ACCEPT" if acc else "REJECT"
                status = f"Traj: {n*param.ntraj+i+1:4}  {ifacc}:  dH: {dH:< 12.8}  exp(-dH): {exp_mdH:< 12.8}  plaq: {plaq:< 12.8}  topo: {topo:< 3.3}\n"
                O.write(status)
                if (i+1) % print_every == 0:
                    put(status)
            t += timer()
            ts.append(t)
        print("Run times: ", ts)
        print("Per trajectory: ", [t/param.ntraj for t in ts])
    return field

In [21]:
# Reconfigure the global `param` for a short FT-HMC run and continue from the
# current `field`.
# NOTE(review): `field` comes from earlier cells; ft_run does not check that
# its shape matches param.lat.
param = Param(
    beta = 2.0,
    lat = (8, 8),
    tau = 0.5, # trajectory length (0.3 noted as an alternative)
    nstep = 64, # leapfrog steps per trajectory (3 noted as an alternative)
    # ADJUST ME
    ntraj = 4, # trajectories per run (larger values noted: 2**16, 2**10, 2**15)
    nprint = 4,
    #
    seed = 1331)

# field = ft_run(param, pre_flow)
field = ft_run(param, pre_flow, field)

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 4
tau = 0.5
steps = 64
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.6627686419055754  topo: 0.0
13.378684424948641 (-53.73971180677989, -53.659984669959464) 0.9732045603479264
Traj:    1  ACCEPT:  dH: -0.0028627377  exp(-dH):  1.0028668    plaq:  0.72662017   topo:  0.0
20.26347831375218 (-53.562232506413864, -52.51380989435269) 0.9547163537573339
Traj:    2  REJECT:  dH:  0.34848501   exp(-dH):  0.7057565    plaq:  0.72661999   topo:  0.0
16.331966844133525 (-53.60969968082007, -53.5585371652662) 0.9871465395562814
Traj:    3  ACCEPT:  dH: -0.01379607   exp(-dH):  1.0138917    plaq:  0.6983424    topo:  0.0
15.456432820850809 (-53.48774725216437, -51.383267278283995) 0.9653108529207802
Traj:    4  ACCEPT:  dH:  0.0012047187  exp(-dH):  0.99879601   plaq:  0.65318636   topo:  0.0
17.38078591052218 (-51.4535395027086, -54.30936491400867) 0.9909948866061801
Traj:    5  REJECT:  dH:  0.37634973   exp(-dH):  0.68636225 

In [22]:
# Same configuration and call as the preceding FT-HMC cell; running it again
# continues the chain from the `field` that cell produced.
# NOTE(review): `field` comes from earlier cells; ft_run does not check that
# its shape matches param.lat.
param = Param(
    beta = 2.0,
    lat = (8, 8),
    tau = 0.5, # trajectory length (0.3 noted as an alternative)
    nstep = 64, # leapfrog steps per trajectory (3 noted as an alternative)
    # ADJUST ME
    ntraj = 4, # trajectories per run (larger values noted: 2**16, 2**10, 2**15)
    nprint = 4,
    #
    seed = 1331)

# field = ft_run(param, pre_flow)
field = ft_run(param, pre_flow, field)

latsize = (8, 8)
volume = 64
beta = 2.0
trajs = 4
tau = 0.5
steps = 64
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.7185859884287391  topo: -1.0
35.58825429661411 (-53.64953457231129, -54.05251106745193) 0.957981301570924
Traj:    1  ACCEPT:  dH: -0.2646876    exp(-dH):  1.3030238    plaq:  0.68012964   topo:  0.0
26.996907862728886 (-54.19039581798784, -55.528032313186635) 0.9705934656072241
Traj:    2  ACCEPT:  dH: -0.53375325   exp(-dH):  1.7053208    plaq:  0.72417941   topo:  1.0
27.88171560805188 (-55.551799406136276, -54.5756659006458) 0.9700067180225045
Traj:    3  ACCEPT:  dH: -0.056386845  exp(-dH):  1.0580069    plaq:  0.69634589   topo: -1.0
30.950330945550775 (-54.69092790549618, -52.0511193358186) 0.8944220460873022
Traj:    4  ACCEPT:  dH:  0.49485686   exp(-dH):  0.60965816   plaq:  0.60480921   topo:  1.0
42.92329937977941 (-52.085588556007835, -51.74406041828097) 0.9394773190295843
Traj:    5  ACCEPT:  dH:  1.4368181    exp(-dH):  0.23768285   p

In [23]:
# Standard deviation of the effective action (second entry of every record)
# over all leapfrog steps logged in ft_hmc_info_list.
action_list = np.array([l[1] for l in flattern(ft_hmc_info_list)])
# Subtract the mean so that the RMS below measures the fluctuation.
action_list = sub_avg(action_list)
np.sqrt(average(action_list**2))

1.389480087367922

In [24]:
# RMS of the force norm (first entry of every record) over all leapfrog steps
# logged in ft_hmc_info_list.
force_list = np.array([l[0] for l in flattern(ft_hmc_info_list)])
np.sqrt(average(force_list**2))

26.10260986658055

In [25]:
print(np.array(force_list[0:300]))

[ 18.27711285  34.3038331   38.47971804  22.92079991  14.89793795
  18.08681988  20.31581029  22.14914023  23.87103188  23.34911165
  19.31782271  15.31122775  14.29842581  16.53974364  20.60968096
  22.45969878  19.81038763  14.50076622   9.74271184   7.16144983
   6.43791031   6.51055981   6.79499594   7.1529161    7.57889242
   8.06840814   8.57790798   9.01304499   9.2346246    9.09240327
   8.49076822   7.45884179   6.1723308    4.90264743   3.92519736
   3.41259392   3.33438466   3.51030509   3.78384734   4.08238111
   4.38564792   4.69295284   5.00689449   5.32765208   5.65228937
   5.97601477   6.29380474   6.60160793   6.89683138   7.17811448
   7.44458231   7.69486476   7.92618572   8.13380598   8.31116147
   8.45151134   8.55369476   8.63963128   8.80064034   9.28327141
  10.49678124  12.55137939  14.37275549  13.96333735  11.64322749
   9.55646993   7.99003083   7.13012197   6.80463101   6.76225949
   6.8546393    6.99590927   7.07011514   6.90727963   6.40603871
   5.93640

In [45]:
# Switch the global `param` to a 12x12 lattice and run conventional HMC.
# NOTE(review): `field` is reused from earlier cells; run() does not check
# that its shape matches param.lat.
param = Param(
    beta = 2.0,
    lat = (12, 12),
    tau = 0.5, # trajectory length (0.3 noted as an alternative)
    nstep = 64, # leapfrog steps per trajectory (3 noted as an alternative)
    # ADJUST ME
    ntraj = 4, # trajectories per run (larger values noted: 2**16, 2**10, 2**15)
    nprint = 4,
    #
    seed = 1331)

# field = param.initializer()

field = run(param, field)

latsize = (12, 12)
volume = 144
beta = 2.0
trajs = 4
tau = 0.5
steps = 64
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.7016228019160151  topo: -2.0
plaq(x) 0.7016228019160151  force.norm 27.473372391179087
Traj:    1  ACCEPT:  dH: -0.00080053206  exp(-dH):  1.0008009    plaq:  0.678064     topo:  0.0
plaq(x) 0.6780639990092816  force.norm 28.79359239911523
Traj:    2  ACCEPT:  dH:  0.0011502039  exp(-dH):  0.99885046   plaq:  0.74059871   topo:  0.0
plaq(x) 0.7405987100486567  force.norm 27.80254480347343
Traj:    3  ACCEPT:  dH: -0.0012113997  exp(-dH):  1.0012121    plaq:  0.70697632   topo: -1.0
plaq(x) 0.7069763193143671  force.norm 30.515196608105374
Traj:    4  ACCEPT:  dH:  0.0015065459  exp(-dH):  0.99849459   plaq:  0.72056544   topo:  0.0
plaq(x) 0.720565440451751  force.norm 27.413912752174674
Traj:    5  ACCEPT:  dH: -0.00072322231  exp(-dH):  1.0007235    plaq:  0.69602558   topo:  0.0
plaq(x) 0.696025580670386  force.norm 28.34971124354062
Traj:    

In [97]:
field_run.shape

torch.Size([1, 2, 12, 12])

In [35]:
param.initializer()

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [96]:
force(param, field_run[0])

tensor([[[ 1.7676,  0.1112, -1.1515,  0.0148, -1.6742,  2.7194, -3.3007,
           0.3404,  0.8066,  2.7574, -0.4600, -1.9311],
         [ 1.2636, -0.3440, -1.2371,  1.0828,  2.4481, -1.8026,  1.8030,
          -2.8691, -0.6625,  2.9842,  0.5663, -3.2327],
         [ 1.0764,  0.5406, -0.9713,  0.1759, -1.2286, -0.8030,  0.1867,
           0.1817,  1.1472,  0.1158, -0.0745, -0.3469],
         [ 1.7824, -2.2857, -1.2961,  1.7109,  2.1828, -1.3668,  0.4849,
           0.7049, -1.7105,  1.9339, -1.3661, -0.7745],
         [-0.7513, -2.1414,  2.1709,  0.8740, -0.0108, -3.3313,  0.9532,
          -1.1550,  1.2115, -0.5403,  1.7388,  0.9816],
         [ 0.7189, -1.3359,  0.4805,  1.3458, -1.1233, -0.9647,  2.4161,
          -0.0155, -1.5422,  0.0277,  0.2828, -0.2902],
         [-2.2164,  2.0370, -0.7510,  1.3763, -0.2672, -1.2977, -1.6096,
           3.0478, -2.5262,  0.8780,  1.9587, -0.6297],
         [ 0.8151, -0.4133, -0.5078,  1.6036, -2.1299,  1.0449, -1.5394,
           0.0590,  1.07

In [86]:
def get_nets(layers):
    """Collect the `plaq_coupling.net` object of every layer, in order.

    Args:
        layers: iterable of layer objects, each exposing `.plaq_coupling.net`.

    Returns:
        A new list holding one net per layer, in the same order as `layers`.
    """
    return [layer.plaq_coupling.net for layer in layers]

In [92]:
def make_u1_equiv_layers_net(*, lattice_shape, nets):
    """Wrap each given net in a coupling layer and return them as a ModuleList.

    One GaugeEquivCouplingLayer is built per entry of `nets`; the nets are
    used as passed in. The mask settings cycle with the layer index:
    `mask_mu` alternates 0, 1 and `mask_off` advances through 0..3 every two
    layers, so the pattern repeats every 8 layers.

    Args:
        lattice_shape: forwarded as `mask_shape` to NCPPlaqCouplingLayer and
            as `lattice_shape` to GaugeEquivCouplingLayer.
        nets: sequence of nets, one per layer to build.

    Returns:
        torch.nn.ModuleList of the coupling layers, in the order of `nets`.
    """
    def build_layer(index, net):
        # periodically loop through all arrangements of maskings
        direction = index % 2
        offset = (index // 2) % 4
        plaq = NCPPlaqCouplingLayer(
            net, mask_shape=lattice_shape, mask_mu=direction, mask_off=offset)
        return GaugeEquivCouplingLayer(
            lattice_shape=lattice_shape, mask_mu=direction, mask_off=offset,
            plaq_coupling=plaq)

    return torch.nn.ModuleList(
        [build_layer(index, net) for index, net in enumerate(nets)])

In [93]:
# Count the nets that get_nets extracts from `flow` (one per layer).
len(get_nets(flow))

24

In [94]:
# Rebuild the coupling layers for the current lattice shape (param.lat),
# reusing the nets already held by `flow` rather than creating new ones.
new_flow = make_u1_equiv_layers_net(lattice_shape = param.lat, nets = get_nets(flow))

In [101]:
# Call ft_run with the rebuilt layers (new_flow) and `field`.
ft_run(param, new_flow, field)

latsize = (12, 12)
volume = 144
beta = 2.0
trajs = 4
tau = 0.5
steps = 64
seed = 1331
nth = 2
nth_interop = 2
Initial configuration:  plaq: 0.6481015786072608  topo: 2.0
23.769701649797888 (-116.53445539389065, -120.53610026514147) 0.965123015498082
Traj:    1  ACCEPT:  dH:  0.12061243   exp(-dH):  0.88637742   plaq:  0.65311072   topo: -2.0
32.16286069243797 (-120.1481012358675, -120.23093118538995) 0.9419354022939522
Traj:    2  ACCEPT:  dH:  0.0054811496  exp(-dH):  0.99453384   plaq:  0.65460432   topo:  2.0
39.85399081188922 (-120.16897302727948, -122.11547603755164) 0.9641809328701209
Traj:    3  ACCEPT:  dH: -0.031140951  exp(-dH):  1.0316309    plaq:  0.69446443   topo: -2.0
43.151113410659725 (-122.34084850976468, -119.03432826546131) 0.912262122722401
Traj:    4  ACCEPT:  dH: -0.11953146   exp(-dH):  1.1269687    plaq:  0.63606857   topo: -1.0
37.32056471850772 (-118.79559013169148, -120.57127954351374) 0.8971280972684773
Traj:    5  ACCEPT:  dH:  0.063506931  exp(-dH):  0.93

tensor([[[2.3752e+00, 3.7908e+00, 4.3528e+00, 2.5210e+00, 4.7764e+00,
          5.4460e+00, 4.3141e+00, 5.1648e+00, 3.5046e+00, 4.6180e+00,
          1.0509e+00, 1.6771e+00],
         [4.8399e+00, 1.5831e+00, 1.4803e+00, 4.7336e+00, 5.1794e+00,
          1.3116e+00, 1.0175e+00, 3.4732e+00, 2.8924e+00, 1.1776e+00,
          6.2588e+00, 1.9701e+00],
         [4.8401e-01, 6.6562e-01, 4.7643e+00, 3.3379e+00, 5.0044e+00,
          3.4488e+00, 3.2837e+00, 5.6005e+00, 5.1433e+00, 5.3443e+00,
          1.8785e+00, 4.6364e+00],
         [7.2549e-01, 5.4411e+00, 4.0904e+00, 1.7220e+00, 5.6641e+00,
          5.6044e+00, 6.8245e-01, 5.5383e+00, 4.4251e-01, 1.0098e+00,
          2.8486e+00, 5.2598e+00],
         [2.0633e+00, 1.9559e+00, 3.9353e+00, 4.3933e+00, 2.8059e+00,
          3.3580e+00, 2.2833e+00, 3.0077e+00, 5.1626e+00, 3.5316e+00,
          5.1970e+00, 1.5598e+00],
         [5.3014e+00, 8.9498e-01, 1.8614e-01, 4.6153e+00, 5.6853e+00,
          4.7606e+00, 4.4472e-01, 4.4610e+00, 6.1098e+0

In [105]:
# Evaluate action() on `field` using the current param.
action(param, field)

tensor(-209.9513)