In [28]:
from numpy import log
from torch import nn

import re
import time
from collections import namedtuple
from math import sqrt

import numpy as np
import torch
import torchvision
from uncertainties import ufloat

import os
from glob import glob



import argparse
# Module-level tensor dtype read by PixelCNN.sample(); it is reassigned from
# --dtype in the configuration cell further down.
default_dtype_torch = torch.float32

In [29]:
# Command-line interface of the original training script. In this notebook the
# arguments are supplied as a hard-coded list in the next cell.
parser = argparse.ArgumentParser()

# --- Physical model: Hamiltonian, lattice geometry, size and temperature.
group = parser.add_argument_group('physics parameters')
group.add_argument(
    '--ham',
    type=str,
    default='afm',
    choices=['afm', 'fm'],
    help='Hamiltonian model')
group.add_argument(
    '--lattice',
    type=str,
    default='tri',
    choices=['sqr', 'tri'],
    help='lattice shape')
group.add_argument(
    '--boundary',
    type=str,
    default='periodic',
    choices=['open', 'periodic'],
    help='boundary condition')
group.add_argument(
    '--L',
    type=int,
    default=4,
    help='number of sites on each edge of the lattice')
group.add_argument('--beta', type=float, default=1, help='beta = 1 / k_B T')

# --- Network architecture.
group = parser.add_argument_group('network parameters')
group.add_argument(
    '--net',
    type=str,
    default='pixelcnn',
    choices=['made', 'pixelcnn', 'bernoulli'],
    help='network type')
group.add_argument('--net_depth', type=int, default=3, help='network depth')
group.add_argument('--net_width', type=int, default=64, help='network width')
group.add_argument(
    '--half_kernel_size', type=int, default=1, help='(kernel_size - 1) // 2')
group.add_argument(
    '--dtype',
    type=str,
    default='float32',
    choices=['float32', 'float64'],
    help='dtype')
group.add_argument('--bias', action='store_true', help='use bias')
group.add_argument(
    '--z2', action='store_true', help='use Z2 symmetry in sample and loss')
group.add_argument('--res_block', action='store_true', help='use res block')
group.add_argument(
    '--x_hat_clip',
    type=float,
    default=0,
    help='value to clip x_hat around 0 and 1, 0 for disabled')
group.add_argument(
    '--final_conv',
    action='store_true',
    help='add an additional conv layer before sigmoid')
group.add_argument(
    '--epsilon',
    type=float,
    default=1e-7,
    help='small number to avoid 0 in division and log')

# --- Optimisation.
group = parser.add_argument_group('optimizer parameters')
group.add_argument(
    '--seed', type=int, default=0, help='random seed, 0 for randomized')
group.add_argument(
    '--optimizer',
    type=str,
    default='adam',
    choices=['sgd', 'sgdm', 'rmsprop', 'adam', 'adam0.5'],
    help='optimizer')
group.add_argument(
    '--batch_size', type=int, default=10**3, help='number of samples')
group.add_argument('--lr', type=float, default=1e-3, help='learning rate')
group.add_argument(
    '--max_step', type=int, default=10**4, help='maximum number of steps')
group.add_argument(
    '--lr_schedule', action='store_true', help='use learning rate scheduling')
group.add_argument(
    '--beta_anneal',
    type=float,
    default=0,
    help='speed to change beta from 0 to final value, 0 for disabled')
group.add_argument(
    '--clip_grad',
    type=float,
    default=0,
    help='global norm to clip gradients, 0 for disabled')

# --- Logging, checkpointing and device selection.
group = parser.add_argument_group('system parameters')
group.add_argument(
    '--no_stdout',
    action='store_true',
    help='do not print log to stdout, for better performance')
group.add_argument(
    '--clear_checkpoint', action='store_true', help='clear checkpoint')
group.add_argument(
    '--print_step',
    type=int,
    default=1,
    help='number of steps to print log, 0 for disabled')
group.add_argument(
    '--save_step',
    type=int,
    default=100,
    help='number of steps to save network weights, 0 for disabled')
group.add_argument(
    '--visual_step',
    type=int,
    default=100,
    help='number of steps to visualize samples, 0 for disabled')
group.add_argument(
    '--save_sample', action='store_true', help='save samples on print_step')
group.add_argument(
    '--print_sample',
    type=int,
    default=1,
    help='number of samples to print to log on visual_step, 0 for disabled')
group.add_argument(
    '--print_grad',
    action='store_true',
    help='print summary of gradients for each parameter on visual_step')
group.add_argument(
    '--cuda', type=int, default=-1, help='ID of GPU to use, -1 for disabled')
group.add_argument(
    '--out_infix',
    type=str,
    default='',
    help='infix in output filename to distinguish repeated runs')
# Trailing ';' keeps the returned _StoreAction repr out of the cell output.
group.add_argument(
    '-o',
    '--out_dir',
    type=str,
    default='out',
    help='directory prefix for output, empty for disabled');

_StoreAction(option_strings=['-o', '--out_dir'], dest='out_dir', nargs=None, const=None, default='out', type=<class 'str'>, choices=None, required=False, help='directory prefix for output, empty for disabled', metavar=None)

In [30]:
# Hard-coded command line for this notebook run, parsed exactly as the script
# would parse sys.argv. NOTE: `--cuda 0` needs a CUDA device to be present.
arg_str = ('--ham fm --lattice sqr --L 4 --beta .1 '
           '--net pixelcnn --net_depth 1 --net_width 64 --half_kernel_size 6 '
           '--bias --z2 --beta_anneal 0.998 --clip_grad 1 --cuda 0').split()
args = parser.parse_args(arg_str)

In [31]:
# args.ham = 'fm'
# args.lattice = 'sqr'
# args.L = 4
# args.beta = .1
# args.net = 'pixelcnn'
# args.net_depth = 3
# args.net_width = 64
# args.half_kernel_size = 6
# args.bias = True
# args.z2 = True
# args.beta_anneal = 0.998
# args.clip_grad = 1
# args.cuda = 0

In [32]:
# PixelCNN
class ResBlock(nn.Module):
    """Residual wrapper around an arbitrary module.

    forward(x) returns ``x + block(x)``, so the wrapped block must preserve
    the shape of its input.
    """

    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        residual = self.block(x)
        return x + residual


class MaskedConv2d(nn.Conv2d):
    """2D convolution whose kernel is masked for raster-order autoregression.

    The mask keeps every kernel row above the centre row and, on the centre
    row, the taps left of the centre. The centre tap itself is kept only when
    ``exclusive=False``. Rows below the centre are always zeroed.

    Args:
        *args, **kwargs: forwarded unchanged to ``nn.Conv2d``.
        exclusive (bool, required keyword): when True the centre tap is masked
            out, so an output site cannot see the input at that same site.
    """

    def __init__(self, *args, **kwargs):
        self.exclusive = kwargs.pop('exclusive')
        super(MaskedConv2d, self).__init__(*args, **kwargs)

        _, _, kh, kw = self.weight.shape
        self.register_buffer('mask', torch.ones([kh, kw]))
        # Centre row: drop the centre tap (if exclusive) and everything right of it.
        self.mask[kh // 2, kw // 2 + (not self.exclusive):] = 0
        # Drop every row below the centre.
        self.mask[kh // 2 + 1:] = 0
        self.weight.data *= self.mask

        # Correction to Xavier initialization: rescale the surviving weights
        # to compensate for the reduced fan-in.
        # A fully masked kernel (e.g. 1x1 with exclusive=True) has
        # mask.sum() == 0. The factor would then be inf, and 0 * inf = NaN,
        # which would leak through `mask * weight` in forward(). Skip the
        # rescale in that case: the layer then outputs only its bias.
        n_active = self.mask.sum()
        if n_active > 0:
            self.weight.data *= torch.sqrt(self.mask.numel() / n_active)

    def forward(self, x):
        # Re-apply the mask on every call so that gradient updates to masked
        # weights never affect the output.
        return nn.functional.conv2d(x, self.mask * self.weight, self.bias,
                                    self.stride, self.padding, self.dilation,
                                    self.groups)

    def extra_repr(self):
        return (super(MaskedConv2d, self).extra_repr() +
                ', exclusive={exclusive}'.format(**self.__dict__))


class PixelCNN(nn.Module):
    """Autoregressive PixelCNN over an L x L grid of +/-1 spins.

    A stack of MaskedConv2d layers followed by a sigmoid gives, at every site
    (i, j), x_hat = q(s_{i,j} = +1 | sites before (i, j) in raster order).

    Keyword args (all required; the notebook passes ``**vars(args)``; other
    keys are ignored): L, net_depth, net_width, half_kernel_size, bias, z2,
    res_block, x_hat_clip, final_conv, epsilon, device.
    """

    def __init__(self, **kwargs):
        super(PixelCNN, self).__init__()
        self.L = kwargs['L']
        self.net_depth = kwargs['net_depth']
        self.net_width = kwargs['net_width']
        self.half_kernel_size = kwargs['half_kernel_size']
        self.bias = kwargs['bias']
        self.z2 = kwargs['z2']
        self.res_block = kwargs['res_block']
        self.x_hat_clip = kwargs['x_hat_clip']
        self.final_conv = kwargs['final_conv']
        self.epsilon = kwargs['epsilon']
        self.device = kwargs['device']

        # Force the first x_hat to be 0.5. With a bias the first site would
        # otherwise get a learned, unconditioned probability. Under z2 the
        # sample is randomly flipped instead.
        if self.bias and not self.z2:
            self.register_buffer('x_hat_mask', torch.ones([self.L] * 2))
            self.x_hat_mask[0, 0] = 0
            self.register_buffer('x_hat_bias', torch.zeros([self.L] * 2))
            self.x_hat_bias[0, 0] = 0.5

        # The first layer is exclusive (it cannot see its own site). It emits
        # the single output channel only when nothing follows it. With
        # final_conv, the trailing PReLU / 1x1 conv expect net_width channels
        # even at net_depth == 1.
        first_out_channels = (1 if self.net_depth == 1 and not self.final_conv
                              else self.net_width)

        layers = []
        layers.append(
            MaskedConv2d(
                1,
                first_out_channels,
                self.half_kernel_size * 2 + 1,
                padding=self.half_kernel_size,
                bias=self.bias,
                exclusive=True))

        # Hidden layers: net_depth - 2 width-preserving blocks.
        for count in range(self.net_depth - 2):
            if self.res_block:
                layers.append(
                    self._build_res_block(self.net_width, self.net_width))
            else:
                layers.append(
                    self._build_simple_block(self.net_width, self.net_width))

        # Last masked block: reduces to 1 channel unless final_conv does it.
        if self.net_depth >= 2:
            layers.append(
                self._build_simple_block(
                    self.net_width, self.net_width if self.final_conv else 1))

        if self.final_conv:
            layers.append(nn.PReLU(self.net_width, init=0.5))
            layers.append(nn.Conv2d(self.net_width, 1, 1))

        layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def _build_simple_block(self, in_channels, out_channels):
        """PReLU followed by a non-exclusive MaskedConv2d."""
        layers = []
        layers.append(nn.PReLU(in_channels, init=0.5))
        layers.append(
            MaskedConv2d(
                in_channels,
                out_channels,
                self.half_kernel_size * 2 + 1,
                padding=self.half_kernel_size,
                bias=self.bias,
                exclusive=False))
        block = nn.Sequential(*layers)
        return block

    def _build_res_block(self, in_channels, out_channels):
        """Residual block: 1x1 conv -> PReLU -> non-exclusive MaskedConv2d.

        The residual add requires out_channels == in_channels.
        """
        layers = []
        layers.append(nn.Conv2d(in_channels, in_channels, 1, bias=self.bias))
        layers.append(nn.PReLU(in_channels, init=0.5))
        layers.append(
            MaskedConv2d(
                in_channels,
                out_channels,
                self.half_kernel_size * 2 + 1,
                padding=self.half_kernel_size,
                bias=self.bias,
                exclusive=False))
        block = ResBlock(nn.Sequential(*layers))
        return block

    def forward(self, x):
        """Return x_hat (same shape as x): per-site probabilities of spin +1."""
        x_hat = self.net(x)

        if self.x_hat_clip:
            # Clip value and preserve gradient: the correction is computed
            # without autograd, so gradients still flow through the
            # unclipped x_hat.
            with torch.no_grad():
                delta_x_hat = torch.clamp(x_hat, self.x_hat_clip,
                                          1 - self.x_hat_clip) - x_hat
            assert not delta_x_hat.requires_grad
            x_hat = x_hat + delta_x_hat

        # Force the first x_hat to be 0.5
        if self.bias and not self.z2:
            x_hat = x_hat * self.x_hat_mask + self.x_hat_bias

        return x_hat

    # sample = +/-1, +1 = up = white, -1 = down = black
    # sample.dtype == default_dtype_torch
    # x_hat = p(x_{i, j} == +1 | x_{0, 0}, ..., x_{i, j - 1})
    # 0 < x_hat < 1
    # x_hat will not be flipped by z2
    def sample(self, batch_size):
        """Draw ``batch_size`` configurations site by site in raster order.

        Returns (sample, x_hat): sample has shape [batch_size, 1, L, L] with
        entries +/-1. x_hat comes from the last forward pass, by which point
        every site before (L-1, L-1) has been filled.
        """
        sample = torch.zeros(
            [batch_size, 1, self.L, self.L],
            dtype=default_dtype_torch,
            device=self.device)
        for i in range(self.L):
            for j in range(self.L):
                x_hat = self.forward(sample)
                sample[:, :, i, j] = torch.bernoulli(
                    x_hat[:, :, i, j]).to(default_dtype_torch) * 2 - 1

        if self.z2:
            # Binary random int 0/1, mapped to a global sign flip of +/-1
            flip = torch.randint(
                2, [batch_size, 1, 1, 1],
                dtype=sample.dtype,
                device=sample.device) * 2 - 1
            sample *= flip

        return sample, x_hat

    def _log_prob(self, sample, x_hat):
        """Bernoulli log-likelihood of +/-1 ``sample`` under ``x_hat``, summed
        over all sites. Returns one value per batch element."""
        mask = (sample + 1) / 2
        log_prob = (torch.log(x_hat + self.epsilon) * mask +
                    torch.log(1 - x_hat + self.epsilon) * (1 - mask))
        log_prob = log_prob.view(log_prob.shape[0], -1).sum(dim=1)
        return log_prob

    def log_prob(self, sample):
        """log q(sample) per batch element.

        With z2, q is the equal mixture of the model and its spin-flipped
        copy: log((q(s) + q(-s)) / 2).
        """
        x_hat = self.forward(sample)
        log_prob = self._log_prob(sample, x_hat)

        if self.z2:
            # Density estimation on inverted sample
            sample_inv = -sample
            x_hat_inv = self.forward(sample_inv)
            log_prob_inv = self._log_prob(sample_inv, x_hat_inv)
            log_prob = torch.logsumexp(
                torch.stack([log_prob, log_prob_inv]), dim=0)
            log_prob = log_prob - log(2)

        return log_prob


In [33]:
# 2D classical Ising model
def i_energy(sample, ham, lattice, boundary):
    """Ising energy of every configuration in a batch.

    Args:
        sample: tensor indexed as ``sample[batch, channel, row, col]`` holding
            spins (the training loop fills it with +/-1).
        ham: 'fm' gives E = -sum s_i s_j. Any other value (e.g. 'afm') gives
            E = +sum s_i s_j.
        lattice: 'tri' adds the diagonal bond (r, c)-(r-1, c-1) on top of the
            square-lattice vertical and horizontal bonds.
        boundary: 'periodic' adds the wrap-around bonds. Any other value
            (e.g. 'open') leaves them out.

    Returns:
        1D tensor with one energy per batch element.
    """
    def bonds(a, b, dims):
        # Sum of spin products over one family of bonds, per batch element.
        return (a * b).sum(dim=dims)

    s = sample
    interior = (1, 2, 3)
    energy = bonds(s[:, :, 1:, :], s[:, :, :-1, :], interior)
    energy = energy + bonds(s[:, :, :, 1:], s[:, :, :, :-1], interior)
    if lattice == 'tri':
        energy = energy + bonds(s[:, :, 1:, 1:], s[:, :, :-1, :-1], interior)

    if boundary == 'periodic':
        edge = (1, 2)
        energy = energy + bonds(s[:, :, 0, :], s[:, :, -1, :], edge)
        energy = energy + bonds(s[:, :, :, 0], s[:, :, :, -1], edge)
        if lattice == 'tri':
            energy = energy + bonds(s[:, :, 0, 1:], s[:, :, -1, :-1], edge)
            energy = energy + bonds(s[:, :, 1:, 0], s[:, :, :-1, -1], edge)
            energy = energy + bonds(s[:, :, 0, 0], s[:, :, -1, -1], 1)

    return -energy if ham == 'fm' else energy

In [34]:
# Resolve --dtype into the matching numpy and torch scalar types.
if args.dtype not in ('float32', 'float64'):
    raise ValueError('Unknown dtype: {}'.format(args.dtype))
default_dtype = getattr(np, args.dtype)
default_dtype_torch = getattr(torch, args.dtype)

# Numpy: raise on floating-point errors, except underflow, which only warns.
np.seterr(all='raise')
np.seterr(under='warn')
np.set_printoptions(precision=8, linewidth=160)

torch.set_default_dtype(default_dtype_torch)
torch.set_printoptions(precision=8, linewidth=160)
torch.backends.cudnn.benchmark = True

# --seed 0 means "draw one at random"; the drawn seed is stored back in args.
args.seed = args.seed or np.random.randint(1, 10**8)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Expose only the selected GPU, which then appears as cuda:0.
if args.cuda >= 0:
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
args.device = torch.device('cuda:0' if args.cuda >= 0 else 'cpu')

# Filled in by init_out_filename() when --out_dir is non-empty.
args.out_filename = None


def get_ham_args_features():
    """Build the two path components that identify a run.

    Returns:
        (ham_args, features): ham_args encodes the physical model, e.g.
        'fm_sqr_periodic_L4_beta0.1'. features encodes the network and
        optimiser options that differ from their defaults.
    """
    opts = vars(args)
    ham_args = '{ham}_{lattice}_{boundary}_L{L}_beta{beta:g}'.format(**opts)

    net_templates = {
        'made': 'nd{net_depth}_nw{net_width}_made',
        'bernoulli': 'bernoulli_nw{net_width}',
    }
    parts = [net_templates.get(
        args.net, 'nd{net_depth}_nw{net_width}_hks{half_kernel_size}')]

    # (condition, suffix template) pairs, appended in this fixed order.
    optional_suffixes = [
        (args.bias, '_bias'),
        (args.z2, '_z2'),
        (args.res_block, '_res'),
        (args.x_hat_clip, '_xhc{x_hat_clip:g}'),
        (args.final_conv, '_fconv'),
        (args.optimizer != 'adam', '_{optimizer}'),
        (args.lr_schedule, '_lrs'),
        (args.beta_anneal, '_ba{beta_anneal:g}'),
        (args.clip_grad, '_cg{clip_grad:g}'),
    ]
    parts.extend(suffix for enabled, suffix in optional_suffixes if enabled)

    features = ''.join(parts).format(**opts)
    return ham_args, features


def init_out_filename():
    """Set args.out_filename to '<out_dir>/<ham_args>/<features>/out<infix>'.

    Does nothing when --out_dir is empty, which leaves args.out_filename
    unchanged.
    """
    if args.out_dir:
        ham_args, features = get_ham_args_features()
        args.out_filename = '{}/{}/{}/out{}'.format(
            args.out_dir, ham_args, features, args.out_infix)


def ensure_dir(filename):
    """Create the parent directory of ``filename`` (recursively) if needed.

    A trailing slash makes the path itself the directory, e.g.
    ``ensure_dir('run_save/')`` creates ``run_save``. A bare file name with no
    directory component is a no-op. An existing directory is not an error.
    Other OS errors (permission denied, a file in the way) are raised instead
    of being silently swallowed.
    """
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)


def init_out_dir():
    """Compute args.out_filename and create its output directories.

    Always creates the run directory. It also creates '<out>_save/' when
    save_step is non-zero and '<out>_img/' when visual_step is non-zero. Does
    nothing when --out_dir is empty.
    """
    if not args.out_dir:
        return
    init_out_filename()
    ensure_dir(args.out_filename)
    for enabled, suffix in ((args.save_step, '_save/'),
                            (args.visual_step, '_img/')):
        if enabled:
            ensure_dir(args.out_filename + suffix)


def clear_log():
    """Truncate '<out_filename>.log' to empty; no-op if output is disabled."""
    if not args.out_filename:
        return
    with open(args.out_filename + '.log', 'w'):
        pass


def clear_err():
    """Truncate '<out_filename>.err' to empty; no-op if output is disabled."""
    if not args.out_filename:
        return
    with open(args.out_filename + '.err', 'w'):
        pass


def my_log(s):
    """Append ``s`` plus a newline to '<out_filename>.log' (if output is
    enabled), and echo it to stdout unless --no_stdout is set."""
    if args.out_filename:
        log_path = args.out_filename + '.log'
        with open(log_path, 'a', newline='\n') as handle:
            handle.write(s + '\n')
    if not args.no_stdout:
        print(s)


def my_err(s):
    """Append ``s`` plus a newline to '<out_filename>.err' (if output is
    enabled), and echo it to stdout unless --no_stdout is set."""
    if args.out_filename:
        err_path = args.out_filename + '.err'
        with open(err_path, 'a', newline='\n') as handle:
            handle.write(s + '\n')
    if not args.no_stdout:
        print(s)


def print_args(print_fn=my_log):
    """Emit one 'name = value' line per parsed argument, then a blank line."""
    # Namespace._get_kwargs() is kept (rather than vars(args)) so the output
    # order matches argparse's own Namespace repr on every Python version.
    lines = ['{} = {}'.format(name, value) for name, value in args._get_kwargs()]
    for line in lines + ['']:
        print_fn(line)


def parse_checkpoint_name(filename):
    """Return the integer step encoded in a checkpoint path such as
    '.../out_save/123.state'."""
    return int(os.path.basename(filename).replace('.state', ''))


def get_last_checkpoint_step():
    """Largest step among '<out_filename>_save/*.state' files, or -1 when
    checkpointing is disabled or no checkpoint exists."""
    if not args.out_filename or not args.save_step:
        return -1
    pattern = '{}_save/*.state'.format(args.out_filename)
    steps = [parse_checkpoint_name(path) for path in glob(pattern)]
    return max(steps, default=-1)


def clear_checkpoint():
    """Delete every '<out_filename>_save/*.state' file; no-op when
    checkpointing is disabled."""
    if args.out_filename and args.save_step:
        for path in glob('{}_save/*.state'.format(args.out_filename)):
            os.remove(path)


def ignore_param(state, net):
    """Make ignored buffers in a checkpoint ``state`` dict follow ``net``.

    Any key containing 'x_hat_mask' or 'x_hat_bias' is overwritten with the
    current ``net``'s value, so load_state_dict keeps those buffers unchanged.
    If ``net`` has no such key (for example it was built with z2=True), the
    entry is dropped rather than raising KeyError. ``state`` is modified in
    place; returns None.
    """
    ignore_param_name_list = ['x_hat_mask', 'x_hat_bias']
    # Query the net's state once instead of once per matching key.
    net_state = net.state_dict()
    for name in list(state.keys()):
        if not any(ignored in name for ignored in ignore_param_name_list):
            continue
        if name in net_state:
            state[name] = net_state[name]
        else:
            del state[name]


In [35]:
# Build the model and optimiser, then train by minimising the variational
# free energy with a score-function (REINFORCE) gradient estimator.
start_time = time.time()

# Create the output directories (no-op when --out_dir is empty) and truncate
# any previous log file.
init_out_dir()
# Checkpoint resume is commented out below, so training always starts at step 0.
last_step = -1
clear_log()

# if args.clear_checkpoint:
#     clear_checkpoint()
# last_step = get_last_checkpoint_step()
# if last_step >= 0:
#     my_log('\nCheckpoint found: {}\n'.format(last_step))
# else:
    # clear_log()
# print_args()


# if args.net == 'made':
#     net = MADE(**vars(args))
# elif args.net == 'pixelcnn':
#     net = PixelCNN(**vars(args))
# elif args.net == 'bernoulli':
#     net = BernoulliMixture(**vars(args))
# else:
#     raise ValueError('Unknown net: {}'.format(args.net))
# NOTE(review): args.net is ignored; only PixelCNN is built here (MADE and
# BernoulliMixture are not defined in this notebook's visible cells).
net = PixelCNN(**vars(args))
net.to(args.device)
my_log('{}\n'.format(net))

# Collect and count the trainable parameters (those with requires_grad).
params = list(net.parameters())
params = list(filter(lambda p: p.requires_grad, params))
nparams = int(sum([np.prod(p.shape) for p in params]))
my_log('Total number of trainable parameters: {}'.format(nparams))
named_params = list(net.named_parameters())

if args.optimizer == 'sgd':
    optimizer = torch.optim.SGD(params, lr=args.lr)
elif args.optimizer == 'sgdm':
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
elif args.optimizer == 'rmsprop':
    optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
elif args.optimizer == 'adam':
    optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
elif args.optimizer == 'adam0.5':
    optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
else:
    raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

if args.lr_schedule:
    # 0.92**80 ~ 1e-3
    # Shrink the LR by 0.92 whenever the mean loss plateaus for 100 steps.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.92, patience=100, threshold=1e-4, min_lr=1e-6)

# if last_step >= 0:
#     state = torch.load('{}_save/{}.state'.format(args.out_filename,
#                                                     last_step))
#     ignore_param(state['net'], net)
#     net.load_state_dict(state['net'])
#     if state.get('optimizer'):
#         optimizer.load_state_dict(state['optimizer'])
#     if args.lr_schedule and state.get('scheduler'):
#         scheduler.load_state_dict(state['scheduler'])

init_time = time.time() - start_time
my_log('init_time = {:.3f}'.format(init_time))

my_log('Training...')
# Wall-clock time spent sampling / training, reset after every log line.
sample_time = 0
train_time = 0
start_time = time.time()
for step in range(last_step + 1, args.max_step + 1):
    optimizer.zero_grad()

    # Phase 1: draw a batch from the model without building an autograd graph.
    sample_start_time = time.time()
    with torch.no_grad():
        sample, x_hat = net.sample(args.batch_size)
    assert not sample.requires_grad
    assert not x_hat.requires_grad
    sample_time += time.time() - sample_start_time

    # Phase 2: one gradient step.
    train_start_time = time.time()

    # log q(sample), differentiable with respect to the network parameters.
    log_prob = net.log_prob(sample)
    # 0.998**9000 ~ 1e-8
    # Annealed inverse temperature, rising from 0 towards args.beta. With
    # --beta_anneal 0 (disabled) this still gives beta = 0 at step 0, because
    # 0.0**0 == 1.0 in Python.
    beta = args.beta * (1 - args.beta_anneal**step)
    with torch.no_grad():
        energy = i_energy(sample, args.ham, args.lattice,args.boundary)
        # Detached per-sample estimate of beta * F: log q(s) + beta * E(s).
        loss = log_prob + beta * energy
    assert not energy.requires_grad
    assert not loss.requires_grad
    # Score-function gradient with the batch-mean loss as a baseline:
    # grad = mean[(loss - mean(loss)) * grad log q(s)].
    loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
    loss_reinforce.backward()

    if args.clip_grad:
        nn.utils.clip_grad_norm_(params, args.clip_grad)

    optimizer.step()

    if args.lr_schedule:
        scheduler.step(loss.mean())

    train_time += time.time() - train_start_time

    if args.print_step and step % args.print_step == 0:
        # Per-site observables. F and F_std divide by the target args.beta,
        # not by the annealed beta used in `loss`.
        free_energy_mean = loss.mean() / args.beta / args.L**2
        free_energy_std = loss.std() / args.beta / args.L**2
        entropy_mean = -log_prob.mean() / args.L**2
        energy_mean = energy.mean() / args.L**2
        # mag is the per-site average over the batch. M is its mean over
        # sites. Q is the mean over sites of the squared per-site averages,
        # not the batch mean of each sample's squared magnetisation.
        mag = sample.mean(dim=0)
        mag_mean = mag.mean()
        mag_sqr_mean = (mag**2).mean()
        # After step 0, report the average time per step since the last log.
        if step > 0:
            sample_time /= args.print_step
            train_time /= args.print_step
        used_time = time.time() - start_time
        my_log(
            'step = {}, F = {:.8g}, F_std = {:.8g}, S = {:.8g}, E = {:.8g}, M = {:.8g}, Q = {:.8g}, lr = {:.3g}, beta = {:.8g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
            .format(
                step,
                free_energy_mean.item(),
                free_energy_std.item(),
                entropy_mean.item(),
                energy_mean.item(),
                mag_mean.item(),
                mag_sqr_mean.item(),
                optimizer.param_groups[0]['lr'],
                beta,
                sample_time,
                train_time,
                used_time,
            ))
        sample_time = 0
        train_time = 0

        if args.save_sample:
            # NOTE(review): this path assumes '<out_filename>_save/' exists.
            # init_out_dir() only creates it when save_step is non-zero. With
            # an empty out_dir, out_filename is None and the path becomes
            # 'None_save/...'.
            state = {
                'sample': sample,
                'x_hat': x_hat,
                'log_prob': log_prob,
                'energy': energy,
                'loss': loss,
            }
            torch.save(state, '{}_save/{}.sample'.format(
                args.out_filename, step))

    # Periodic checkpoint of network, optimiser and (optional) scheduler state.
    if (args.out_filename and args.save_step
            and step % args.save_step == 0):
        state = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        if args.lr_schedule:
            state['scheduler'] = scheduler.state_dict()
        torch.save(state, '{}_save/{}.state'.format(
            args.out_filename, step))

    # Periodic visualisation: save the batch as an image grid.
    if (args.out_filename and args.visual_step
            and step % args.visual_step == 0):
        torchvision.utils.save_image(
            sample,
            '{}_img/{}.png'.format(args.out_filename, step),
            nrow=int(sqrt(sample.shape[0])),
            padding=0,
            normalize=True)

        if args.print_sample:
            # Per-site spread of x_hat across the batch, shaped [L, L].
            x_hat_np = x_hat.view(x_hat.shape[0], -1).cpu().numpy()
            x_hat_std = np.std(x_hat_np, axis=0).reshape([args.L] * 2)

            # For each site, the largest |correlation| of its x_hat with the
            # x_hat of any earlier site in flattened raster order (strict
            # lower triangle).
            x_hat_cov = np.cov(x_hat_np.T)
            x_hat_cov_diag = np.diag(x_hat_cov)
            x_hat_corr = x_hat_cov / (
                np.sqrt(x_hat_cov_diag[:, None] * x_hat_cov_diag[None, :]) +
                args.epsilon)
            x_hat_corr = np.tril(x_hat_corr, -1)
            x_hat_corr = np.max(np.abs(x_hat_corr), axis=1)
            x_hat_corr = x_hat_corr.reshape([args.L] * 2)

            # Histogram of distinct energy values: rows of (energy, count).
            energy_np = energy.cpu().numpy()
            energy_count = np.stack(
                np.unique(energy_np, return_counts=True)).T

            my_log(
                '\nsample\n{}\nx_hat\n{}\nlog_prob\n{}\nenergy\n{}\nloss\n{}\nx_hat_std\n{}\nx_hat_corr\n{}\nenergy_count\n{}\n'
                .format(
                    sample[:args.print_sample, 0],
                    x_hat[:args.print_sample, 0],
                    log_prob[:args.print_sample],
                    energy[:args.print_sample],
                    loss[:args.print_sample],
                    x_hat_std,
                    x_hat_corr,
                    energy_count,
                ))

        if args.print_grad:
            # Per-parameter gradient summary from the step just taken.
            my_log('grad max_abs min_abs mean std')
            for name, param in named_params:
                if param.grad is not None:
                    grad = param.grad
                    grad_abs = torch.abs(grad)
                    my_log('{} {:.3g} {:.3g} {:.3g} {:.3g}'.format(
                        name,
                        torch.max(grad_abs).item(),
                        torch.min(grad_abs).item(),
                        torch.mean(grad).item(),
                        torch.std(grad).item(),
                    ))
                else:
                    my_log('{} None'.format(name))
            my_log('')

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
step = 6439, F = -7.0328207, F_std = 0.012923931, S = 0.6819821, E = -0.21300001, M = 0.0081250006, Q = 0.00066175009, lr = 0.001, beta = 0.099999748, sample_time = 0.003, train_time = 0.002, used_time = 54.703
step = 6440, F = -7.0323992, F_std = 0.013199726, S = 0.68363994, E = -0.19600001, M = 0.027625002, Q = 0.0015357502, lr = 0.001, beta = 0.099999748, sample_time = 0.006, train_time = 0.003, used_time = 54.713
step = 6441, F = -7.0334549, F_std = 0.012489786, S = 0.68174553, E = -0.21600001, M = -0.010500001, Q = 0.0011900001, lr = 0.001, beta = 0.099999749, sample_time = 0.005, train_time = 0.003, used_time = 54.722
step = 6442, F = -7.0328531, F_std = 0.013144091, S = 0.68473536, E = -0.18550001, M = -0.01075, Q = 0.00088850001, lr = 0.001, beta = 0.099999749, sample_time = 0.004, train_time = 0.004, used_time = 54.730
step = 6443, F = -7.0321746, F_std = 0.013216763, S = 0.68216747, E = -0.21050002, M = -0.00075

In [37]:
# Draw a fresh batch from the trained model; no autograd graph is needed.
with torch.no_grad():
    sample, x_hat = net.sample(batch_size=args.batch_size)

In [38]:
# Total (not per-site) energy of every configuration in the fresh batch.
energy = i_energy(sample, ham=args.ham, lattice=args.lattice,
                  boundary=args.boundary)

In [39]:
# Mean total energy per configuration. Divide by args.L**2 to compare with
# the per-site E reported in the training log.
energy.cpu().numpy().mean()

-3.204

In [40]:
# Temperature T = 1 / beta. `beta` is the annealed value left over from the
# final training iteration, not args.beta itself.
1 / beta

10.00000002020286

In [41]:
# Display the model architecture via its module repr.
net

PixelCNN(
  (net): Sequential(
    (0): MaskedConv2d(1, 1, kernel_size=(13, 13), stride=(1, 1), padding=(6, 6), exclusive=True)
    (1): Sigmoid()
  )
)

In [42]:
# NOTE(review): hard-coded values, apparently copied by hand. The first entry
# equals 1 / 10.00000002020286 (the annealed beta of this run), but -2.868
# does not match the mean energy printed above (-3.204); it presumably comes
# from a different sampling run. TODO: compute these from live variables.
[0.0999999997979714, -2.868]

[0.0999999997979714, -2.868]