This notebook serves as a gentle introduction to the normflow package. Let's begin by importing some standard libraries.

In [1]:
import os
import sys

Now let's import the key objects from the normflow library: the model, the neural network, the action, and the prior.

In [2]:
from normflow import np, torch, Model
from normflow import backward_sanitychecker
from normflow.nn import DistConvertor_
from normflow.action import ScalarPhi4Action
from normflow.prior import NormalPrior

Next, we define the parameters of the scalar field theory (the mass squared m_sq and the quartic coupling lambd) and the machine learning parameters.

In [None]:
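m_sq = -1.2        # mass squared (negative, so the potential is a double well)
lambd = 0.5        # quartic coupling
knots_len = 10     # number of knots used by the DistConvertor_ network
n_epochs = 1000    # number of training epochs
batch_size = 1024  # total number of samples per training batch
lat_shape = 1      # a single lattice site, i.e. effectively a zero-dimensional problem
nranks = 1         # number of ranks the batch is split across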
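For reference, here is a minimal sketch of the one-site theory these parameters define. It assumes, as an illustration rather than something taken from the normflow documentation, that with kappa=0 and a single site ScalarPhi4Action reduces to S(phi) = (m_sq/2) phi^2 + lambd phi^4. Since m_sq < 0 and lambd > 0, the potential is a double well, and the exact log Z = log of the integral of exp(-S) over phi can be computed by direct quadrature, giving a target value for the trained flow.

```python
import numpy as np

# Parameters copied from the cell above.
m_sq, lambd = -1.2, 0.5

def action(phi):
    """ASSUMED one-site phi^4 action with kappa = 0: S = m_sq/2 * phi^2 + lambd * phi^4."""
    return 0.5 * m_sq * phi**2 + lambd * phi**4

# Stationary points: dS/dphi = m_sq*phi + 4*lambd*phi^3 = 0, so for m_sq < 0
# the two minima sit at phi = +/- sqrt(-m_sq / (4*lambd)).
phi_min = np.sqrt(-m_sq / (4 * lambd))
s_min = action(phi_min)  # analytically equal to -m_sq**2 / (16*lambd)

# Reference log Z from a fine Riemann sum; exp(-S) is negligible at |phi| = 6,
# so truncating the integration range there is harmless.
phi = np.linspace(-6.0, 6.0, 120001)
dphi = phi[1] - phi[0]
log_z = np.log(np.sum(np.exp(-action(phi))) * dphi)

print(f"minima at phi = +/-{phi_min:.4f}, S_min = {s_min:.4f}, log Z = {log_z:.4f}")
```

Under this assumed action the quadrature gives log Z of roughly 1.11, in line with the log(z) column that model.fit prints during training.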

It's time to instantiate the neural network, the prior, and the action, wrap them in a Model, and train it.
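The training log reports, among other things, the loss, the effective sample size (ess) and an estimate of log(z). As a hypothetical, normflow-free sketch of what such quantities measure, the block below replaces the flow by a plain standard normal q (roughly what an untrained flow on top of NormalPrior would produce) and evaluates, from one batch, the reverse-KL loss E_q[log q + S], the importance-sampling estimate of log Z, and the normalized ESS. The action form is the same single-site, kappa=0 assumption as before, and these are textbook estimators that may differ in detail from what model.fit computes internally.

```python
import numpy as np

m_sq, lambd = -1.2, 0.5
rng = np.random.default_rng(0)

def action(phi):
    # ASSUMED one-site phi^4 action with kappa = 0.
    return 0.5 * m_sq * phi**2 + lambd * phi**4

def log_q(phi):
    # Log-density of a standard normal, standing in for an untrained flow.
    return -0.5 * phi**2 - 0.5 * np.log(2 * np.pi)

n = 100_000
phi = rng.standard_normal(n)
log_w = -action(phi) - log_q(phi)  # log of unnormalized weights exp(-S)/q

# Reverse-KL loss: E_q[log q + S] = KL(q || p) - log Z, so it is bounded below by -log Z.
loss = np.mean(-log_w)

# log Z = log E_q[w], computed with a max-shift for numerical stability.
shift = log_w.max()
w = np.exp(log_w - shift)
log_z = shift + np.log(np.mean(w))

# Normalized effective sample size (sum w)^2 / (n * sum w^2), which lies in (0, 1].
ess = np.sum(w) ** 2 / (n * np.sum(w**2))

print(f"loss = {loss:.3f}, log Z = {log_z:.4f}, ess = {ess:.3f}")
```

For this q the loss comes out near -0.52, log Z near 1.11 and the ESS near 0.87, close to the Epoch 1 line of the log. Training lowers the loss toward -log Z and drives the ESS toward 1, which is the trend the log shows.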

In [3]:
net_ = DistConvertor_(knots_len, symmetric=True)

action_dict = dict(kappa=0, m_sq=m_sq, lambd=lambd)
prior = NormalPrior(shape=lat_shape)
action = ScalarPhi4Action(**action_dict)

model = Model(net_=net_, prior=prior, action=action)

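snapshot_path = None  # do not save model snapshots

hyperparam = dict(lr=0.01, weight_decay=0.)  # optimizer learning rate and weight decay

fit_kwargs = dict(
    n_epochs=n_epochs,
    save_every=None,
    batch_size=batch_size // nranks,  # per-rank batch size
    hyperparam=hyperparam,
    checkpoint_dict=dict(print_stride=100, snapshot_path=snapshot_path)  # print progress every 100 epochs
)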

model.fit(**fit_kwargs)

backward_sanitychecker(model)


Not saving model snapshots

>>> Training progress (cpu) <<<

Note: log(q/p) is estimated with normalized p; mean & error are obtained from samples in a batch

Epoch: 1 | loss: -0.530407 | ess: 0.870286 | rho: 0.818643 | log(z): 1.11124(38) | log(q/p): 0.6(36) | accept_rate: 0.797(9)
Epoch: 10 | loss: -0.812281 | ess: 0.895045 | rho: 0.808877 | log(z): 1.10289(33) | log(q/p): 0.3(22) | accept_rate: 0.808(10)
Epoch: 100 | loss: -1.08288 | ess: 0.991174 | rho: 0.794849 | log(z): 1.115250(92) | log(q/p): 0.03(86) | accept_rate: 0.966(8)
Epoch: 200 | loss: -1.1083 | ess: 0.996703 | rho: 0.983689 | log(z): 1.111443(56) | log(q/p): 0.00(10) | accept_rate: 0.982(5)
Epoch: 300 | loss: -1.11372 | ess: 0.998849 | rho: 0.996756 | log(z): 1.114207(33) | log(q/p): 0.000(30) | accept_rate: 0.988(3)
Epoch: 400 | loss: -1.11222 | ess: 0.997468 | rho: 0.991354 | log(z): 1.113817(49) | log(q/p): 0.002(63) | accept_rate: 0.979(4)
Epoch: 500 | loss: -1.11301 | ess: 0.99917 | rho: 0.998099 | log(z): 1.11343