In [1]:
import torch as tr
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions
from torch.nn.parameter import Parameter

import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

import time

In [2]:
# Choose the compute device. Later cells pass `device` to tensor constructors
# and `.to(...)`, so it must always be bound: prefer Apple MPS, otherwise fall
# back to CUDA if present, else CPU.
if tr.backends.mps.is_available():
    device = tr.device("mps")
else:
    print ("MPS device not found.")
    device = tr.device("cuda" if tr.cuda.is_available() else "cpu")

In [3]:
mass = -0.2         # enters the action as m^2 (coefficient of phi^2 / 2)
lam = 0.5           # quartic coupling; the action uses lam / 24 * phi^4
Nd = 2              # number of lattice dimensions
mtil = mass + 2 * Nd  # on-site phi^2 coefficient after expanding the nearest-neighbour Laplacian


def action(phi):
    """Lattice phi^4 action with periodic boundaries, one value per configuration.

    S[phi] = sum_x [ mtil/2 * phi_x^2 + lam/24 * phi_x^4 - sum_mu phi_x * phi_{x+mu} ]

    Args:
        phi: tensor of shape (batch, L_1, ..., L_Nd) — dim 0 is the batch,
            dims 1..Nd are the lattice directions.

    Returns:
        Tensor of shape (batch,).

    Raises:
        ValueError: if ``phi`` does not have ``Nd + 1`` dimensions.
    """
    if phi.dim() != Nd + 1:
        raise ValueError(
            f"expected phi with {Nd + 1} dims (batch + {Nd} lattice dims), got shape {tuple(phi.shape)}"
        )
    phi2 = phi * phi
    # Sum the on-site terms over every lattice dimension (works for any Nd,
    # unlike a hard-coded 'bxy' einsum).
    A = (0.5 * mtil * phi2 + (lam / 24.0) * phi2 * phi2).flatten(1).sum(1)
    for mu in range(1, Nd + 1):
        # roll by -1 along mu gives phi_{x+mu} with periodic wrap-around.
        A = A - (phi * tr.roll(phi, shifts=-1, dims=mu)).flatten(1).sum(1)
    return A


In [4]:
class RealNVP(nn.Module):
    """Real NVP normalizing flow built from masked affine coupling layers.

    Layer ``i`` leaves the components where ``mask[i] == 1`` unchanged and maps
    the remaining ones as ``x -> x * exp(s) + t``, where ``s`` and ``t`` are
    computed by networks from the unchanged components only. Each layer is
    therefore invertible in closed form, with log-Jacobian ``sum(s)``.

    Args:
        nets: zero-argument callable returning a new scale network.
        nett: zero-argument callable returning a new translation network.
        mask: tensor of shape ``(n_layers, D)`` holding one 0/1 mask per layer.
        prior: base distribution exposing ``sample`` and ``log_prob``; its event
            is assumed to be the last dimension (TODO confirm for other priors).
    """

    def __init__(self, nets, nett, mask, prior):
        super().__init__()
        self.prior = prior
        # Masks are fixed, non-trainable state: a buffer follows .to(device) and
        # is saved in state_dict under "mask", but is not in .parameters().
        self.register_buffer("mask", mask)
        # One (s, t) network pair per mask row, sized from the `mask` argument
        # rather than from any global variable.
        n_layers = len(mask)
        self.t = nn.ModuleList([nett() for _ in range(n_layers)])
        self.s = nn.ModuleList([nets() for _ in range(n_layers)])

    def g(self, z):
        """Forward pass: base (noise) space -> target space.

        Args:
            z: tensor whose last dimension has size D.

        Returns:
            ``x = g(z)`` with the same shape as ``z``.
        """
        x = z
        for i in range(len(self.t)):
            keep = self.mask[i]
            x_ = x * keep
            s = self.s[i](x_) * (1 - keep)
            t = self.t[i](x_) * (1 - keep)
            x = x_ + (1 - keep) * (x * tr.exp(s) + t)
        return x

    def f(self, x):
        """Inverse pass: target space -> base space.

        Args:
            x: tensor whose last dimension has size D (e.g. ``(batch, D)``).

        Returns:
            ``(z, log_det_J)`` where ``z = g^{-1}(x)`` and ``log_det_J`` is the
            per-sample ``log|det dz/dx|``, of shape ``x.shape[:-1]``.
        """
        # Reduce over the last (event) dim only, so leading dims such as the
        # (batch, 1, D) tensors produced by sample() are preserved and line up
        # with prior.log_prob(z).
        log_det_J, z = x.new_zeros(x.shape[:-1]), x
        for i in reversed(range(len(self.t))):
            keep = self.mask[i]
            z_ = keep * z
            s = self.s[i](z_) * (1 - keep)
            t = self.t[i](z_) * (1 - keep)
            z = (1 - keep) * (z - t) * tr.exp(-s) + z_
            log_det_J -= s.sum(dim=-1)
        return z, log_det_J

    def log_prob(self, x):
        """Log-density of the flow at ``x`` via the change-of-variables formula."""
        z, logp = self.f(x)
        return self.prior.log_prob(z) + logp

    def sample(self, batchSize):
        """Draw ``batchSize`` samples from the flow.

        For a prior with event shape ``(D,)`` the result has shape
        ``(batchSize, 1, D)`` (the prior is sampled with shape ``(batchSize, 1)``).
        """
        z = self.prior.sample((batchSize, 1))
        return self.g(z)
    

In [5]:
L = 8       # linear extent of the periodic L x L lattice (a torus)
V = L ** 2  # number of lattice sites

In [6]:
# Checkerboard parity of each lattice site: X[i, j] = i, Y[i, j] = j, and
# mm[i, j] = (i + j) % 2, i.e. 0 on "even" sites and 1 on "odd" sites.
X, Y = np.meshgrid(np.arange(L), np.arange(L), indexing='ij')
mm = (X + Y) % 2
mm

array([[0, 1, 0, 1, 0, 1, 0, 1],
       [1, 0, 1, 0, 1, 0, 1, 0],
       [0, 1, 0, 1, 0, 1, 0, 1],
       [1, 0, 1, 0, 1, 0, 1, 0],
       [0, 1, 0, 1, 0, 1, 0, 1],
       [1, 0, 1, 0, 1, 0, 1, 0],
       [0, 1, 0, 1, 0, 1, 0, 1],
       [1, 0, 1, 0, 1, 0, 1, 0]])

In [7]:
# Flatten the L x L checkerboard into a length-V mask in row-major order,
# matching the x.view(batch, L, L) used when evaluating the action.
lm = np.reshape(mm, V)

In [8]:
# Sanity check of the base distribution: V independent standard normals,
# reinterpreted as one V-dimensional event so log_prob gives one value per sample.
tt = distributions.Normal(tr.zeros(V, device=device), tr.ones(V, device=device))
prior = distributions.Independent(tt, reinterpreted_batch_ndims=1)
z = tr.squeeze(prior.sample((10, 1)))
prior.log_prob(z)

tensor([-83.9537, -89.2414, -88.3813, -90.8288, -93.3673, -89.3514, -88.1707,
        -87.8941, -85.0303, -86.6884], device='mps:0')

In [9]:
def nets():
    """Build a fresh scale network V -> 2V -> V; the final Tanh bounds s to (-1, 1)."""
    return nn.Sequential(nn.Linear(V, 2 * V), nn.LeakyReLU(), nn.Linear(2 * V, V), nn.Tanh())


def nett():
    """Build a fresh translation network V -> 2V -> V (unbounded output)."""
    return nn.Sequential(nn.Linear(V, 2 * V), nn.LeakyReLU(), nn.Linear(2 * V, V))


# One coupling layer per mask row: the checkerboard and its complement,
# alternating, repeated three times (6 layers in total).
masks = tr.from_numpy(np.stack([lm, 1 - lm] * 3).astype(np.float32))
normal = distributions.Normal(tr.zeros(V, device=device), tr.ones(V, device=device))
prior = distributions.Independent(normal, 1)
flow = RealNVP(nets, nett, masks, prior).to(device)

In [None]:
batch_size = 2*2048
n_steps = 5001        # number of optimizer steps
log_every = 500       # print progress every `log_every` steps
learning_rate = 1e-4

optimizer = tr.optim.Adam([p for p in flow.parameters() if p.requires_grad], lr=learning_rate)
tic = time.perf_counter()
for t in range(n_steps):
    # Sample directly with shape (batch_size, V); unlike
    # sample((batch_size, 1)).squeeze(), this never drops the batch dim.
    z = prior.sample((batch_size,))
    x = flow.g(z)  # generate a sample from the flow q
    # Reverse KL(q || p) up to the constant log Z, with p = exp(-S) / Z:
    # E_q[log q(x) + S(x)] = KL(q || p) - log Z.
    loss = (flow.log_prob(x) + action(x.view(batch_size, L, L))).mean()
    optimizer.zero_grad()
    # The graph is rebuilt every step, so there is no reason to retain it.
    loss.backward()
    optimizer.step()
    if t % log_every == 0:
        toc = time.perf_counter()
        print('iter %s:' % t, 'loss = %.3f' % loss.item(), 'time = %.3f' % (toc - tic), 'seconds')
        tic = time.perf_counter()

  A = A - tr.einsum('bxy,bxy->b',phi,tr.roll(phi,shifts=-1,dims=mu))


iter 0: loss = 77.825 time = 0.194 seconds
iter 10: loss = 53.114 time = 0.472 seconds
iter 20: loss = 34.328 time = 0.446 seconds
iter 30: loss = 22.750 time = 0.451 seconds
iter 40: loss = 14.825 time = 0.445 seconds
iter 50: loss = 8.634 time = 0.447 seconds
iter 60: loss = 4.280 time = 0.470 seconds
iter 70: loss = 0.529 time = 0.456 seconds
iter 80: loss = -2.431 time = 0.444 seconds
iter 90: loss = -5.033 time = 0.445 seconds
iter 100: loss = -6.800 time = 0.444 seconds
iter 110: loss = -8.632 time = 0.443 seconds
iter 120: loss = -9.974 time = 0.440 seconds
iter 130: loss = -11.084 time = 0.444 seconds
iter 140: loss = -12.048 time = 0.454 seconds
iter 150: loss = -12.965 time = 0.446 seconds
iter 160: loss = -13.554 time = 0.451 seconds
iter 170: loss = -14.314 time = 0.454 seconds
iter 180: loss = -14.910 time = 0.484 seconds
iter 190: loss = -15.258 time = 0.449 seconds
iter 200: loss = -15.885 time = 0.456 seconds
iter 210: loss = -16.098 time = 0.464 seconds
iter 220: loss 

In [None]:
# Shape sanity check: the prior log-density and the log-Jacobian returned by
# flow.f must both have one entry per sample so log_prob can add them.
z = tr.squeeze(prior.sample((batch_size, 1)))
x = flow.g(z)
# NOTE(review): this inverts the prior sample z (not x) and overwrites x;
# only the shapes are inspected below — confirm this is intended.
x, j = flow.f(z)
(prior.log_prob(z).shape, j.shape)

In [None]:
# A small batch of 10 prior samples pushed forward through the flow.
z = tr.squeeze(prior.sample((10, 1)))
x = flow.g(z=z)

In [None]:
# Map the flow samples back to base space; j is the per-sample log|det dz/dx|.
zz, j = flow.f(x=x)

In [None]:
# zz carries autograd history from the flow's trainable networks; the raw
# prior sample z does not.
print(zz.requires_grad, z.requires_grad)

In [None]:
# Mean absolute reconstruction error |f(g(z)) - z|, averaged over both
# samples and lattice sites so the value does not scale with the batch size.
print(tr.mean(tr.abs(zz - z)).item())

In [None]:
# Per-sample S(x) + log q(x). If q matched p = exp(-S) / Z exactly, this
# would be the same constant (-log Z) for every sample.
diff = flow.log_prob(x) + action(x.view(-1, L, L))
print(diff)

In [None]:
# Spread of S(x) + log q(x) around its batch mean (all zeros for a perfect flow).
print(diff.sub(diff.mean()))

In [None]:
n_eval = 2000  # number of flow samples used for the reweighting statistics
# Only forward evaluations are needed. no_grad avoids building an autograd
# graph through the networks for every sample; detaching xz alone would
# still let log_prob record one through the flow parameters.
with tr.no_grad():
    z = prior.sample((n_eval, 1)).squeeze()
    xz = flow.g(z)
    diff = action(xz.view(xz.shape[0], L, L)) + flow.log_prob(xz)
diff.std()

In [None]:
# Joint distribution of flattened lattice sites 0 and 1 (neighbours along the
# second lattice axis) under the flow. flow.sample returns (batch, 1, V).
with tr.no_grad():
    x = flow.sample(2000).cpu().numpy()
fig, ax = plt.subplots()
ax.scatter(x[:, 0, 0], x[:, 0, 1], c='r')
ax.set(xlabel=r'$\phi_0$', ylabel=r'$\phi_1$', title='Flow samples: sites 0 and 1');

In [None]:
# Remove the batch mean from the action differences; m_diff keeps the mean
# that was removed. NOTE: `diff` is rebound here, so re-running this cell on
# its own would center twice (m_diff would then be ~0) — re-run the cell that
# computes `diff` first.
m_diff = tr.mean(diff)
diff = diff - m_diff

In [None]:
# Summary of the action differences; expects `diff` to be mean-subtracted,
# with the removed mean held in m_diff.
action_diff_stats = (
    ("max  action diff: ", tr.max(diff.abs())),
    ("min  action diff: ", tr.min(diff.abs())),
    ("mean action diff: ", m_diff),
    ("std  action diff: ", diff.std()),
)
for label, value in action_diff_stats:
    print(label, value.detach().cpu().numpy())

In [None]:
# Importance weights w_i = p(x_i) / q(x_i) up to normalization, with
# p ∝ exp(-S): since diff_i = S(x_i) + log q(x_i) (possibly shifted by a
# constant), w_i ∝ exp(-diff_i). softmax(-diff) * N equals
# exp(-diff) / mean(exp(-diff)) mathematically but cannot overflow to inf/nan.
foo = tr.exp(-diff)  # unnormalized weights; may overflow for very negative diff
w = tr.softmax(-diff, dim=0) * diff.shape[0]

print("mean re-weighting factor: " , w.mean().cpu().detach().numpy())
print("std  re-weighting factor: " , w.std().cpu().detach().numpy())



In [None]:
# Distribution of the reweighting factors on log-spaced bins (about one bin
# edge per 10 samples). Weights outside [hist_lo, hist_hi] are not drawn,
# so report how many were excluded instead of dropping them silently.
hist_lo, hist_hi = 5e-2, 5e1
logbins = np.logspace(np.log10(hist_lo), np.log10(hist_hi), int(w.shape[0] / 10))
w_np = w.detach().cpu().numpy()
print("weights outside histogram range: ", int(((w_np < hist_lo) | (w_np > hist_hi)).sum()))
fig, ax = plt.subplots()
ax.hist(w_np, bins=logbins)
ax.set_xscale('log')
ax.set(xlabel='reweighting factor $w$', ylabel='count', title='Reweighting factors');

In [None]:
# Count trainable parameters. A generator expression is used so no loop
# variable leaks into the notebook namespace (a `for tt in ...` loop would
# overwrite the global `tt` distribution defined earlier).
c = sum(p.numel() for p in flow.parameters() if p.requires_grad)

print("parameter count: ",c)