# Unsupervised SRBM

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
%matplotlib inline
import os, sys

import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid , save_image

import random

# Change accordingly to your directory structure
sys.path.append('../')
import RBM

In [None]:
# Lattice size: used below as the side length of K_phi and as the RBM's n_v.
N = 10

In [None]:
# Run label: names the output sub-directories and prefixes every saved figure.
project_name = "normal"
Nh = 10  # passed to RBM.SRBM as n_h
mu = 3   # passed to RBM.SRBM through init_cond['m']

print(project_name)
plot_dir = "../plots/" + project_name + '/'
model_dir = "../models/" + project_name + '/'
# Create the output folders directly instead of shelling out to `mkdir -p`:
# no shell string is built from project_name, it works on any OS, and a
# failure (e.g. missing permissions) raises instead of being silently ignored.
os.makedirs(plot_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

In [None]:
# Load the saved array stored under ../data/scalar_field/<conf_id>/<conf_id>.npy
# and report how many entries it has along its first axis.
conf_id = 'phi_1d_10N_2m'
data_path = f'../data/scalar_field/{conf_id}/{conf_id}.npy'
data_file = np.load(data_path)
n_data = len(data_file)
print(n_data)

In [None]:
def S(field, m):
    """Return the action per site of a periodic 1-D lattice field.

    Evaluates

        (1/N) * sum_i 0.5 * phi_i * ((2 + m_i**2) * phi_i - phi_{i+1} - phi_{i-1})

    with periodic neighbours (indices taken modulo N).

    Parameters
    ----------
    field : indexable sequence of length N
        Field value at every lattice site.
    m : scalar or indexable sequence of length N
        Mass. A scalar is applied to every site (matching ``S_fast``);
        a sequence supplies one mass per site.

    Returns
    -------
    The summed action divided by the number of sites N.
    """
    N = len(field)
    # np.ndim is 0 for plain numbers and 0-d arrays/tensors, which cannot be
    # indexed with m[i]; treat those as one mass shared by all sites.
    uniform_mass = np.ndim(m) == 0
    s = 0.
    for i in range(N):
        m_i = m if uniform_mass else m[i]
        s += -0.5*field[i]*(field[(i+1)%N] + field[(i-1)%N] - (2.+ m_i**2)*field[i])

    return s/N

def S_fast(field,m=2.):
    """Vectorised action per site of a periodic lattice field.

    Combines the on-site term (m**2 + 2) * phi**2 with the two
    nearest-neighbour products obtained from torch.roll, sums half of the
    result and divides by field.shape[0].

    Parameters
    ----------
    field : torch.Tensor
        Field values; neighbours are taken with torch.roll(field, +-1).
    m : scalar, default 2.
        Mass applied to every site.
    """
    n_sites = field.shape[0]
    phi_sq = field**2
    forward = torch.roll(field,-1)
    backward = torch.roll(field,1)
    # Same left-to-right accumulation as: mass term, +2 phi^2, -fwd, -bwd.
    density = m**2 * phi_sq + 2.*phi_sq - field*forward - field*backward
    return torch.sum(0.5*density)/n_sites

In [None]:
# Target kernel on a periodic chain: (2 + m^2) on the diagonal and -1 between
# nearest neighbours (indices wrap around modulo N).
m = 2
K_phi = np.zeros((N,N))

for i in range(N):
    # Write the neighbour couplings first so the diagonal entry, assigned
    # last, takes precedence whenever a neighbour index coincides with i.
    K_phi[i, (i+1) % N] = -1
    K_phi[i, (i-1) % N] = -1
    K_phi[i, i] = 2 + m**2
print(K_phi)

In [None]:
# Closed-form values m^2 + 2 - 2 cos(2 pi k / N) for k = 0..N-1, evaluated in
# one vectorised expression.
mode_index = np.arange(N)
eig_phi = m**2 + 2 - 2*np.cos(2*np.pi*mode_index/N)

In [None]:
# Show the closed-form values stored in eig_phi.
print(eig_phi)

In [None]:
# Numerical eigenvalues of K_phi, for comparison with eig_phi (order may differ).
print(np.linalg.eigvals(K_phi))

In [None]:
# Seed every RNG *before* anything random is drawn. The initial weights come
# from np.random, so seeding afterwards would leave them different on each run.
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Initial weights: small Gaussian noise, shape (Nh, N); handed to the RBM
# together with the initial 'm', 'sig' and 'm_scheme' settings.
W_phi = np.random.normal(0,0.1,size=(Nh, N))
init_cond = {'w':torch.DoubleTensor(W_phi.copy()),'m':mu, 'sig':1., 'm_scheme':0}

rbm = RBM.SRBM(n_v=N,n_h=Nh,k=3,init_cond=init_cond)

print(rbm.w.size())

# Training schedule used by the fit and by the plotting cells below.
epochs = 1000
batch_size = 64
save_int = 100

lr = 5e-1

# init_field = torch.ones((batch_size,N))

# Target kernel as a double-precision tensor for rbm.unsup_fit.
K_phi_tc = torch.DoubleTensor(K_phi)

In [None]:
# Save and load model
rbm.name = project_name
rbm.save(model_dir)
saved_model = rbm.name
print(saved_model)
rbm = RBM.SRBM(load=model_dir+saved_model+'.npz')
!rm ../models/normal/*.npz

In [None]:
# Settings forwarded to rbm.unsup_fit; their exact meaning is defined in the
# RBM module (presumably momentum, L2 penalty and learning-rate decay --
# TODO confirm against RBM.SRBM.unsup_fit).
beta = 0.8
l2 = 0.0001
lr_decay=0.92
# Fit against the target kernel K_phi_tc. The returned history is read below
# through the keys 'loss', 'dw', 'w' and 'm'.
history = rbm.unsup_fit(K_phi_tc, epochs, lr, beta=beta, l2=l2, 
                        batch_size=batch_size, lr_decay=lr_decay, save_int=save_int)

In [None]:
# Show the training configuration stored on the model object.
print(rbm.train_config)

In [None]:
# Plot parameters: pick a size preset, then derive the shared tick/label
# keyword dictionaries, colours and figure sizes used by every figure below.
plt_size='big'
# plt_size='small'

if plt_size == 'big':
    tick_size, tick_width, tick_length = 52, 3, 18
    label_size, legend_size = 62, 52
    fs, lw = 12.4, 4.
    ms = 4.*lw
elif plt_size == 'small':
    tick_size, tick_width, tick_length = 32, 1.3, 12
    label_size, legend_size = 35, 40
    fs, lw = 6.2, 3.
    ms = 5.*lw

# Major-tick styling (passed to plt.tick_params).
tprm = dict(axis='both',
            labelsize=tick_size,
            direction='in',
            width=tick_width,
            length=tick_length,
            pad=18)

# Minor-tick styling: thinner and shorter than the major ticks.
tprm2 = dict(axis='both',
             direction='in',
             width=tick_width*0.8,
             length=tick_length*0.3,
             pad=18)

# Axis-label text styling: sans-serif variant ...
lprm = dict(fontsize=label_size, fontname='DejaVu Sans')

# ... and a larger serif variant.
lprm2 = dict(fontsize=label_size+10, fontname='DejaVu Serif')

# Colour palette for the curves.
rbm_c, rbm_c2, rbm_c3, rbm_c4 = '#0072B2', '#D55E00', '#009E73', '#56B4E9'
mcmc_c = '#000000'

# Figure sizes (width, height) and output resolution.
figsize = (1.618*fs,fs)
figsize2 = (fs*1.3,fs)
dpi=300

# Render all text through LaTeX with a serif font.
plt.rcParams.update({"text.usetex": True, "font.family": "serif"})

# Training result

In [None]:
# Horizontal axis for the history plots: one value per entry of
# history['loss'], spaced save_int apart and starting at 0 (assumes one record
# is stored every save_int epochs -- TODO confirm against unsup_fit).
x = np.arange(0,len(history['loss'])*save_int,save_int)

In [None]:
# Learning curve: the recorded history['loss'] values against the epoch axis x.
plt.figure(figsize=figsize)
plt.plot(x,history['loss'],
         '.--', c=rbm_c, lw=lw, ms=ms)
# plt.title('KL')
plt.xlabel('Epoch', **lprm)
# NOTE(review): the axis is labelled 'Log likelihood' while the plotted series
# is history['loss'] -- confirm which quantity unsup_fit records.
plt.ylabel('Log likelihood', **lprm)

# Major ticks use tprm; minor ticks use tprm2 with the left-hand ones hidden.
plt.tick_params(which='major', **tprm)
plt.minorticks_on()
plt.tick_params(which='minor', **tprm2, left=False)

# Five equal x intervals between 0 and epochs.
xran = np.arange(0, epochs+1, int(epochs/5), dtype=int)
plt.xticks(ticks=xran, labels=xran)
plt.xlim(xran[0], epochs)

# yran = np.round(np.arange(-0.12,0.06,0.04),2)
# print(yran)
# plt.yticks(ticks=yran, labels=yran)
# plt.ylim(yran[0], yran[-1])
plt.tight_layout()

# plt.savefig(plot_dir+project_name+'_lc.pdf',dpi=dpi)
plt.show()

In [None]:
# Number of recorded loss values.
len(history['loss'])

In [None]:
# Length of the epoch axis; equals len(history['loss']) by construction of x.
len(x)

In [None]:
# Summary of the stored history['dw'] entries: mean, min and max per record.
# Reducing over axis 1 twice collapses the two trailing axes, leaving one value
# per record (assumes history['dw'] stacks into a 3-D array -- TODO confirm).
plt.figure(figsize=figsize)
plt.plot(x,np.mean(np.mean(history['dw'], axis=1),axis=1), label='mean')
plt.plot(x,np.min(np.min(history['dw'], axis=1),axis=1), label='min')
plt.plot(x,np.max(np.max(history['dw'], axis=1),axis=1), label='max')
# plt.title('dW')
plt.xlabel('Epoch', **lprm)
plt.ylabel(r'$\frac{d \mathcal{L}}{dw}$', **lprm2)
plt.minorticks_on()
plt.tick_params(which='major', **tprm)
plt.tick_params(which='minor', **tprm2)

# plt.xticks(fontsize=tick_size)
# plt.yticks(fontsize=tick_size)
plt.legend(fontsize=legend_size)
# Limit both axes to roughly four major tick intervals.
plt.locator_params(axis='both', nbins=4)
plt.tight_layout()
# plt.savefig(plot_dir+project_name+'_grad.pdf', dpi=dpi)
plt.show()

In [None]:
# Singular values of every stored weight snapshot, one row per snapshot.
# A matrix of shape (n_h, n_v) has min(n_h, n_v) singular values, so the row
# length is taken from the SVD result itself rather than fixed to rbm.n_h
# (which only matched while n_h <= n_v). compute_uv=False skips the unused
# U and V factors.
n_snapshots = epochs//save_int+1
s_hist = np.array([np.linalg.svd(history['w'][i], compute_uv=False)
                   for i in range(n_snapshots)])

In [None]:
# Number of recorded loss values (should match the number of rows in s_hist).
print(len(history['loss']))

In [None]:
# Squared entries of s_hist against the epoch axis, one curve per column,
# all drawn in the same colour.
plt.figure(figsize=figsize)
y = s_hist**2
plt.plot(x,y,
         '.--', c=rbm_c, lw=lw, ms=ms
        )
# plt.grid(True)
plt.xlabel('Epoch', **lprm)
plt.ylabel(r'$\xi_{\alpha}^2$', **lprm2)

plt.tick_params(which='major', **tprm)
plt.minorticks_on()
plt.tick_params(which='minor', **tprm2)

# Five equal x intervals between 0 and epochs.
xran = np.arange(0,epochs+1,int(epochs/5))
plt.xticks(ticks=xran, labels=xran)
plt.xlim(0, epochs)

# NOTE(review): yran is chosen per project here, then unconditionally
# overwritten below, and the lines that would apply it are commented out --
# it currently has no effect on this figure.
if project_name == 'normal_m1':
    yran = np.round(np.arange(-0.2,4.8,0.8),2)
else:
    yran = np.round(np.arange(0.6,6.0,0.8),2)
    
yran = np.round(np.arange(-0.8,5.8,0.8),2)

# print(yran)
# plt.yticks(ticks=yran, labels=yran)
# plt.ylim(yran[0], yran[-1])
plt.locator_params(axis='x', nbins=6, tight=True)

plt.tight_layout()
# plt.title(r'$w^2$ evolution')
plt.savefig(plot_dir+project_name+'_w2.pdf',dpi=dpi)
plt.show()

In [None]:
# Same data as the '_w2' figure, but each column of s_hist**2 is drawn
# separately so that neighbouring curves alternate between two colours.
plt.figure(figsize=figsize)
y = s_hist**2

for i in range(y.shape[1]):
    # Even columns use rbm_c, odd columns use rbm_c4.
    if i%2 == 0 :
        c=rbm_c
    else:
        c=rbm_c4
    plt.plot(x,y.T[i],
             '.--', c=c, lw=lw, ms=ms
            )
# plt.grid(True)
plt.xlabel('Epoch', **lprm)
plt.ylabel(r'$\xi_{\alpha}^2$', **lprm2)

plt.tick_params(which='major', **tprm)
plt.minorticks_on()
plt.tick_params(which='minor', **tprm2)

xran = np.arange(0,epochs+1,int(epochs/5))
plt.xticks(ticks=xran, labels=xran)
plt.xlim(0, epochs)

# NOTE(review): as in the '_w2' figure, yran is overwritten right after the
# if/else and never applied (the yticks/ylim lines are commented out).
if project_name == 'normal_m1':
    yran = np.round(np.arange(-0.2,4.8,0.8),2)
else:
    yran = np.round(np.arange(0.6,6.0,0.8),2)
    
yran = np.round(np.arange(-0.8,5.8,0.8),2)

# print(yran)
# plt.yticks(ticks=yran, labels=yran)
# plt.ylim(yran[0], yran[-1])
plt.locator_params(axis='x', nbins=6, tight=True)

plt.tight_layout()
# plt.title(r'$w^2$ evolution')
plt.savefig(plot_dir+project_name+'_w2_alt.pdf',dpi=dpi)
plt.show()

In [None]:
# Zoom on the tail of the run: squared s_hist values for the last 11 records
# only (with epochs=1000 and save_int=100 that is the whole history).
plt.figure(figsize=figsize)
y = s_hist[-11:]**2
plt.plot(x[-11:],y,
         '.--', c=rbm_c, lw=lw*1.5, ms=ms*1.2)
# plt.grid(True)
plt.xlabel('Epoch', **lprm)
plt.ylabel(r'$\xi_{\alpha}^2$', **lprm2)

plt.tick_params(which='major', **tprm)
plt.minorticks_on()
plt.tick_params(which='minor', **tprm2)
# Extra padding between ticks and tick labels for this figure.
plt.tick_params(pad=30)

# xran = np.arange(x[-11],epochs+1, 2000, dtype=int)
# plt.xticks(ticks=xran, labels=xran)
# plt.xlim(xran[0], epochs)

# NOTE(review): yran is computed but not applied (the lines using it are
# commented out).
if project_name == 'normal_m1':
    yran = np.round(np.arange(-0.2,4.8,0.8),2)
else:
    yran = np.round(np.arange(0.6,6.0,0.8),2)
# plt.yticks(yran, yran)
# plt.ylim(yran[0], yran[-1])
# plt.locator_params(axis='both', nbins=6, tight=True)

plt.tight_layout()
plt.savefig(plot_dir+project_name+'_w2_last.pdf',dpi=dpi)
plt.show()

In [None]:
# Un-squared s_hist values, one curve per column with alternating colours.
plt.figure(figsize=figsize)
y = s_hist
for i in range(y.shape[1]):
    # Even columns use rbm_c, odd columns use rbm_c4.
    if i%2 == 0 :
        c=rbm_c
    else:
        c=rbm_c4
    plt.plot(x,y.T[i],
             '.--', c=c, lw=lw*0.8, ms=ms*0.8
            )
# plt.grid(True)
plt.xlabel('Epoch', **lprm)
plt.ylabel(r'$\xi_{\alpha}$', **lprm2)

plt.tick_params(which='major', **tprm)
plt.minorticks_on()
plt.tick_params(which='minor', **tprm2)

xran = np.arange(0,epochs+1,int(epochs/5))
plt.xticks(ticks=xran, labels=xran)
plt.xlim(0, epochs)

# Project-specific y ticks; here they ARE applied below.
# NOTE(review): this cell tests for 'normal_m2' while the neighbouring cells
# test for 'normal_m1' -- confirm which project name is intended.
if project_name == 'normal_m2':
    yran = np.round(np.arange(-0.2,2.4,0.2),2)
else:
    yran = np.round(np.arange(-0.2,2.8,0.4),2)

# yran = np.round(np.arange(0,3,0.5),2)
print(yran)
plt.yticks(ticks=yran, labels=yran)
plt.ylim(yran[0], yran[-1])
# plt.locator_params(axis='x', nbins=6, tight=False)

plt.tight_layout()
# plt.title(r'$w^2$ evolution')
plt.savefig(plot_dir+project_name+'_w_alt.pdf',dpi=dpi)
plt.show()

In [None]:
# Un-squared s_hist values against the epoch axis, all curves in one colour.
plt.figure(figsize=figsize)
y = s_hist
plt.plot(x,y,
         '.--', c=rbm_c, lw=lw, ms=ms
        )
# plt.grid(True)
plt.xlabel('Epoch', **lprm)
plt.ylabel(r'$\xi_{\alpha}$', **lprm2)

plt.tick_params(which='major', **tprm)
plt.minorticks_on()
plt.tick_params(which='minor', **tprm2)

xran = np.arange(0,epochs+1,int(epochs/5))
plt.xticks(ticks=xran, labels=xran)
plt.xlim(0, epochs)

# NOTE(review): yran is computed but not applied (the lines using it are
# commented out).
if project_name == 'normal_m1':
    yran = np.round(np.arange(-0.2,2.6,0.4),2)
else:
    yran = np.round(np.arange(-0.2,2.8,0.4),2)

# yran = np.arange(0,np.ceil(np.max(y)+0.1),0.5)
# print(yran)
# plt.yticks(ticks=yran, labels=yran)
# plt.ylim(yran[0], yran[-1])
plt.locator_params(axis='x', nbins=6, tight=False)

plt.tight_layout()
# plt.title(r'$w^2$ evolution')
plt.savefig(plot_dir+project_name+'_w.pdf',dpi=dpi)
plt.show()

In [None]:
# Spectrum of the model kernel K = -sig^2 W^T W + diag(m^2) for every stored
# snapshot. NOTE: this re-uses the name s_hist, replacing the singular-value
# history computed earlier.
n_snapshots = epochs//save_int+1
s_hist = np.zeros((n_snapshots,N))
mu2 = np.diag(np.ones(N))  # N x N identity; not used in this cell


for i in range(n_snapshots):
    WW_ = history['w'][i].T@history['w'][i]
    K_ = -rbm.sig**2 * WW_ + np.diag(history['m'][i]**2)
    if i ==0:
        # Keep a copy of the kernel at the first snapshot.
        K_i = K_.copy()
    # K_ is symmetric by construction (W^T W plus a diagonal), so use the
    # symmetric solver: it returns real eigenvalues already in ascending order
    # instead of possibly-complex, unsorted values whose imaginary parts would
    # be dropped when stored in the float array.
    s_ = np.linalg.eigvalsh(K_)
    s_hist[i] = s_

In [None]:
# Kernel of the current model, split into its two contributions:
# Kin = -sig^2 W^T W and Mss = diag(m^2); K is their sum (NumPy arrays).
gram_term = -rbm.sig**2 * (rbm.w.t() @ rbm.w)
mass_sq = rbm.m**2
Kin = gram_term.data.numpy()
Mss = np.diag(mass_sq.data.numpy())
K = Kin + Mss

In [None]:
# Compare the eigenvalue history stored in s_hist with the reference values
# eig_phi (horizontal dashed lines) and with the squared first component of
# history['m'].
plt.figure(figsize=figsize)
for eig in eig_phi:
    plt.axhline(eig,ls='--',c=mcmc_c, lw=lw, alpha=0.4)
# Re-draw one reference line with a label so the legend gets a single entry.
plt.axhline(eig_phi[0], ls='--', c=mcmc_c, lw=lw, alpha=0.4,
           label=r'$\kappa_{\alpha}$')



plt.plot(x,s_hist, '.--',
         c=rbm_c, lw=lw*1.2, ms=ms, alpha=0.8)
# Re-draw the first column with a label, again only for the legend entry.
plt.plot(x,s_hist[:,0], '--.',
         c=rbm_c, lw=lw*1.2, ms=ms, alpha=0.8, label=r'$\lambda_{\alpha}$')

plt.plot(x, np.array(history['m']).T[0]**2, '-',
         c=rbm_c2, lw=lw*1.5, ms=ms, label=r'$\mu^2$')

# plt.grid(True)
plt.xlabel('Epoch', **lprm)
plt.ylabel(r'$\lambda_{\alpha}$', **lprm2)

plt.tick_params(which='major', **tprm)
plt.minorticks_on()
plt.tick_params(which='minor', **tprm2, left=False)

# plt.xticks(fontsize=tick_size)
# plt.yticks(fontsize=tick_size)
# plt.title('K eigenvalue')

xran = np.arange(0,epochs+1,int(epochs/5))
plt.xticks(ticks=xran, labels=xran)
plt.xlim(0, epochs)

# Project-specific y range; the last two branches produce the same ticks.
if project_name=='normal_m1':
    yran = np.arange(3,10,1)
elif project_name=='normal':
    yran = np.arange(3,11,1)
else:
    yran = np.arange(3,11)
    
# print(yran)
plt.yticks(ticks=yran, labels=yran)
plt.ylim(yran[0], yran[-1])

plt.locator_params(axis='x', nbins=6, tight=False)
# plt.locator_params(axis='y', nbins=6, tight=True)

# Labelled artists were created in the order kappa, lambda, mu^2; list them in
# reverse order in the legend.
handles, labels = plt.gca().get_legend_handles_labels()
order = [2,1,0]
legend=plt.legend([handles[idx] for idx in order],
           [labels[idx] for idx in order],
           loc='upper right', fontsize=legend_size,
           framealpha=1
          )
# legend.get_frame().set_alpha(None)
# legend.get_frame().set_facecolor((1, 1, 1, 0.1))

plt.tight_layout()

plt.savefig(plot_dir+project_name+'_K.pdf', dpi=dpi)
plt.show()

In [None]:
# Evolution of the first component of history['m'], with a reference line at
# sqrt(m^2 + 4), i.e. the square root of the largest value in eig_phi's formula.
plt.figure(figsize=figsize)
plt.plot(x,np.array(history['m']).T[0],
        '-', c=rbm_c2, lw=lw, ms=ms, label=r'$\mu$')
# NOTE(review): the line is drawn at sqrt(max kappa) but labelled max(kappa).
plt.axhline(np.sqrt(m**2 + 2. + 2.),
            ls='--', color='slategray', lw=lw,
            label=r'$\max (\kappa_{\alpha})$')
# plt.title('Mass evolution')
plt.xlabel('Epoch', **lprm)
plt.ylabel(r'$\mu$', **lprm2)
plt.tick_params(**tprm)
# plt.xticks(fontsize=tick_size)
# plt.yticks(fontsize=tick_size)

plt.legend(loc='lower right',fontsize=legend_size)

plt.tick_params(which='major', **tprm)
plt.minorticks_on()
plt.tick_params(which='minor', **tprm2)

plt.tick_params(pad=30)

# NOTE(review): the tick step here is epochs/save_int (10 for the current
# settings), unlike the epochs/5 used by the other figures; the nbins=6 call
# below then limits how many of these ticks are shown -- confirm intent.
xran = np.arange(0,epochs+1,int(epochs/save_int))
plt.xticks(ticks=xran, labels=xran)
plt.xlim(xran[0], epochs)

if project_name=='normal_m1':
    yran = np.round(np.arange(2.6,3.,0.1, dtype=float),2)
else:
    yran = np.round(np.arange(2.4,3.2,0.2, dtype=float),2)
plt.yticks(ticks=yran, labels=yran)
plt.ylim(yran[0], yran[-1])
plt.locator_params(axis='both', nbins=6, tight=True)

# plt.grid(True)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_m.pdf', dpi=dpi)
plt.show()

In [None]:
# Heat map of the model's current weight matrix rbm.w.
plt.figure(figsize=figsize2)
plt.imshow(rbm.w.data.numpy(), cmap='binary')
cbar = plt.colorbar()
# Colour-bar tick labels slightly smaller than the axis tick labels.
cbar.ax.tick_params(labelsize=tick_size-8)
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)
# plt.title('W')
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_w_final.pdf', dpi=dpi)
plt.show()

In [None]:
# Heat map of the first stored weight snapshot, history['w'][0].
plt.figure(figsize=figsize2)
plt.imshow(history['w'][0], cmap='binary')
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=tick_size-8)
# plt.title('W')
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_w_init.pdf', dpi=dpi)
plt.show()

In [None]:
# Heat map of the target kernel K_phi with a fixed colour range [-1, 6]
# (its entries are -1, 0 and 2 + m^2 = 6 for m = 2).
plt.figure(figsize=figsize2)
plt.imshow(K_phi, cmap='binary', vmax=6, vmin=-1)
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=tick_size-8)
# plt.title('W')
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_K_phi.pdf', dpi=dpi)
plt.show()

In [None]:
# Heat map of the model kernel K (automatic colour range).
plt.figure(figsize=figsize2)
plt.imshow(K, cmap='binary')
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=tick_size-8)
# plt.title('W')
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_K_rbm.pdf', dpi=dpi)
plt.show()

In [None]:
# For every stored weight snapshot, record the plain sum and the sum of
# squares of three parts of the matrix: the main diagonal, the strict upper
# triangle, and the strict lower triangle (taken as the strict upper triangle
# of the transpose).
n_hist = len(history['w'])
diag_w, up_w, down_w = (np.empty(n_hist) for _ in range(3))
diag_norm, up_norm, down_norm = (np.empty(n_hist) for _ in range(3))

for i, w in enumerate(history['w']):
    diag_part = np.diag(w)
    upper_part = np.triu(w, k=1)
    lower_part = np.triu(w.T, k=1)

    # Plain sums of each part.
    diag_w[i] = np.sum(diag_part)
    up_w[i] = np.sum(upper_part)
    down_w[i] = np.sum(lower_part)

    # Sums of squared entries of each part.
    diag_norm[i] = np.sum(diag_part**2)
    up_norm[i] = np.sum(upper_part**2)
    down_norm[i] = np.sum(lower_part**2)

In [None]:
# plt.title('Comparison of sum of each part')

# Sums of the diagonal, strict upper and strict lower parts of the stored
# weight snapshots against the epoch axis.
plt.figure(figsize=figsize)
plt.plot(x, diag_w,
         '--', c='k', lw=lw, ms=ms,
         label='Diagonal sum', alpha=1)
plt.plot(x, up_w,
         '.-', c='r', lw=lw, ms=ms,
         label='Upper triangle sum', alpha=1)
plt.plot(x, down_w,
         '-.', c='b', lw=lw, ms=ms,
         label='Lower triangle sum', alpha=1)

# plt.axhline(diag_phi, label='Cholesky diagonal',
#             ls='--', c='C0')
# plt.axhline(up_phi, label='Cholesky upper', ls='--', c='C1')
# plt.axhline(down_phi, label='Cholesky lower', ls='--', c='C2')

plt.tick_params(which='major', **tprm)
plt.minorticks_on()
plt.tick_params(which='minor', **tprm2)
# plt.xticks(fontsize=tick_size)
# plt.yticks(fontsize=tick_size)
plt.legend(loc='best', fontsize=legend_size-10,
          framealpha=1)
plt.xlabel("Epoch", **lprm)
plt.ylabel('',**lprm2)

plt.tick_params(pad=30)

# Five equal x intervals, as in the other figures. The previous hard-coded
# step of 100000 left a single tick at 0 whenever epochs < 100000.
xran = np.arange(0,epochs+1,int(epochs/5))
plt.xticks(ticks=xran, labels=xran)
plt.xlim(xran[0], epochs)

# Project-specific y range.
if project_name in ('cholesky', 'symmetric', 'cholesky_opt', 'symmetric_opt') :
    yran = np.arange(-0.5,20.5,5)
elif project_name == "normal_m1":
    yran = np.arange(-5, 5, 2)
else:
    yran = np.arange(-11, 6, 2)
    
plt.yticks(ticks=yran, labels=yran)
plt.ylim(yran[0], yran[-1])
# plt.locator_params(axis='x', nbins=6, tight=True)

# plt.grid(True)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_partsum.pdf', dpi=dpi)

plt.show()

In [None]:
# plt.title('Comparison of sum of each part')
# Mean and standard deviation of each part-sum history; used to standardise
# the three curves so they can share one axis.
diag_mean = np.mean(diag_w)
diag_std = np.std(diag_w - diag_mean)
print(diag_mean, diag_std)

up_mean = np.mean(up_w)
up_std = np.std(up_w - up_mean)
print(up_mean, up_std)

down_mean = np.mean(down_w)
down_std = np.std(down_w - down_mean)
print(down_mean, down_std)


# Standardised (zero-mean, unit-variance) part sums against the epoch axis;
# the original mean and spread are reported in each legend entry.
plt.figure(figsize=figsize)
plt.plot(x, (diag_w-diag_mean)/diag_std,
         '--', c='k', lw=lw, ms=ms,
         label=r'Standardized $D$: $\overline{x}=%.2g$, $\sigma=%.2g$'%(diag_mean, diag_std), alpha=1)
plt.plot(x, (up_w-up_mean)/up_std,
         '.-', c='r', lw=lw, ms=ms,
         label=r'Standardized $U$: $\overline{x}=%.2g$, $\sigma=%.2g$'%(up_mean, up_std), alpha=1)
plt.plot(x, (down_w-down_mean)/down_std,
         '-.', c='b', lw=lw, ms=ms,
         label=r'Standardized $L$: $\overline{x}=%.2g$, $\sigma=%.2g$'%(down_mean, down_std), alpha=1)

# plt.axhline(diag_phi, label='Cholesky diagonal',
#             ls='--', c='C0')
# plt.axhline(up_phi, label='Cholesky upper', ls='--', c='C1')
# plt.axhline(down_phi, label='Cholesky lower', ls='--', c='C2')

plt.tick_params(which='major', **tprm)
plt.minorticks_on()
plt.tick_params(which='minor', **tprm2)
# plt.xticks(fontsize=tick_size)
# plt.yticks(fontsize=tick_size)
plt.legend(loc='best',fontsize=legend_size - 5)
plt.xlabel("Epoch", **lprm)
plt.ylabel('',**lprm2)

plt.tick_params(pad=30)

# Integer step, as in the other figures: np.arange with a float step and an
# integer dtype is only reliable when the step happens to be a whole number.
xran = np.arange(0,epochs+1,int(epochs/5))
plt.xticks(ticks=xran, labels=xran)
plt.xlim(xran[0], epochs)

# Project-specific y range. ('symmetric') without a trailing comma is just a
# string, so `in` performed a substring test; compare for equality instead.
if project_name in ('cholesky', 'cholesky_opt'):
    yran = np.arange(-3,7,2)
elif project_name == 'symmetric':
    yran = np.round(np.arange(-3, 6, 2),2)
elif project_name == "normal_m1":
    yran = np.arange(-5, 5, 2)
else:
    yran = np.arange(-11, 6, 2)
    
plt.yticks(ticks=yran, labels=yran)
plt.ylim(yran[0], yran[-1])
# plt.ylim(-3,4)
# plt.locator_params(axis='x', nbins=6, tight=True)
# plt.grid(True)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_partsum2.pdf', dpi=dpi)

plt.show()

In [None]:
# plt.title('Comparison of sum of each part')
# Mean and standard deviation of each sum-of-squares history; used to
# standardise the three curves so they can share one axis.
diag_mean = np.mean(diag_norm)
diag_std = np.std(diag_norm - diag_mean)
print(diag_mean, diag_std)

up_mean = np.mean(up_norm)
up_std = np.std(up_norm - up_mean)
print(up_mean, up_std)

down_mean = np.mean(down_norm)
down_std = np.std(down_norm - down_mean)
print(down_mean, down_std)


# Standardised sums of squares of the diagonal / upper / lower parts against
# the epoch axis; the original mean and spread go into the legend entries.
plt.figure(figsize=figsize)
plt.plot(x, (diag_norm-diag_mean)/diag_std,
         '--', c='k', lw=lw, ms=ms,
         label=r'Standardized $||D||^2$: $\overline{x}=%.2g$, $\sigma=%.2g$'%(diag_mean, diag_std), alpha=1)
plt.plot(x, (up_norm-up_mean)/up_std,
         '.-', c='r', lw=lw, ms=ms,
         label=r'Standardized $||U||^2$: $\overline{x}=%.2g$, $\sigma=%.2g$'%(up_mean, up_std), alpha=1)
plt.plot(x, (down_norm-down_mean)/down_std,
         '-.', c='b', lw=lw, ms=ms,
         label=r'Standardized $||L||^2$: $\overline{x}=%.2g$, $\sigma=%.2g$'%(down_mean, down_std), alpha=1)

# plt.axhline(diag_phi, label='Cholesky diagonal',
#             ls='--', c='C0')
# plt.axhline(up_phi, label='Cholesky upper', ls='--', c='C1')
# plt.axhline(down_phi, label='Cholesky lower', ls='--', c='C2')

plt.tick_params(which='major', **tprm)
plt.minorticks_on()
plt.tick_params(which='minor', **tprm2)
# plt.xticks(fontsize=tick_size)
# plt.yticks(fontsize=tick_size)

# Curves were added in the order D, U, L; show them in the legend as L, D, U.
handles, labels = plt.gca().get_legend_handles_labels()
order = [2,0,1]
legend=plt.legend([handles[idx] for idx in order],
           [labels[idx] for idx in order],
           loc='lower center', fontsize=legend_size-5,
           framealpha=1
          )

plt.xlabel("Epoch", **lprm)
plt.ylabel('',**lprm2)

plt.tick_params(pad=30)

# Integer step, as in the other figures: np.arange with a float step and an
# integer dtype is only reliable when the step happens to be a whole number.
xran = np.arange(0,epochs+1,int(epochs/5))
plt.xticks(ticks=xran, labels=xran)
plt.xlim(xran[0], epochs)

# Project-specific y range. ('symmetric') without a trailing comma is just a
# string, so `in` performed a substring test; compare for equality instead.
if project_name in ('cholesky', 'cholesky_opt'):
    yran = np.arange(-6,8,2)
elif project_name == 'symmetric':
    yran = np.round(np.arange(-3, 6, 2),2)
elif project_name == "normal_m1":
    yran = np.arange(-5, 5, 2)
else:
    yran = np.arange(-11, 6, 2)
    
plt.yticks(ticks=yran, labels=yran)
plt.ylim(yran[0], yran[-1])
# plt.ylim(-3,4)
# plt.locator_params(axis='x', nbins=6, tight=True)
# plt.grid(True)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_partsum_norm.pdf', dpi=dpi)

plt.show()

In [None]:
# Determinant of W^T W for each stored weight snapshot, gathered in one pass.
n_records = epochs//save_int + 1
det_hist = np.array([
    np.linalg.det(history['w'][i].T @ history['w'][i])
    for i in range(n_records)
])

In [None]:
# Determinant history det_hist against the epoch axis.
plt.figure(figsize=figsize)
plt.plot(x,det_hist,
         'C9--.', lw=lw, ms=ms)
plt.grid(True)
plt.xlabel('Epoch', fontsize=label_size)
# NOTE(review): the label reads det(W W^T); the two orderings W^T W and W W^T
# give the same determinant only for a square W -- confirm the label matches
# how det_hist was computed.
plt.ylabel(r'$\det(WW^T)$', fontsize=label_size)
# plt.title(r'det(ww) evolution')
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_det.pdf', dpi=dpi)
plt.show()

In [None]:
def loss(K, K_phi, Zk, b, r):
    """Sum of squared entries of the residual K - Zk*K_phi - b - r."""
    residual = K - Zk*K_phi - b - r
    return np.sum(residual**2)

def dL_dZ(K, K_phi, Zk, b, r):
    """Derivative of the squared-residual loss with respect to the scale Zk."""
    residual = K - Zk*K_phi - b - r
    weighted = 2.*residual*K_phi
    return -np.sum(weighted)

def dL_db(K, K_phi, Zk, b, r):
    """Derivative of the squared-residual loss with respect to the offset b."""
    residual = K - Zk*K_phi - b - r
    doubled = 2.*residual
    return -np.sum(doubled)

def dL_dr(K, K_phi, Zk, b, r):
    """Element-wise derivative of the squared-residual loss with respect to r."""
    residual = K - Zk*K_phi - b - r
    return -2.*residual

def get_scale(model, verbose=True, n_iter=10000, lr=1e-5):
    """Fit the model kernel as K ~= Zk * K_phi + b + r by gradient descent.

    The model kernel is K = -sig^2 W^T W + diag(m^2). Starting from an initial
    scale, a zero offset and a small symmetric residual matrix, the parameters
    (Zk, b, r) are updated with plain gradient descent on the squared-residual
    loss defined by ``loss`` / ``dL_dZ`` / ``dL_db`` / ``dL_dr``. The target
    kernel ``K_phi`` is read from module scope.

    Parameters
    ----------
    model : object exposing ``sig``, ``w`` and ``m``
        ``w`` and ``m`` must be torch tensors (``.t()``, ``.data.numpy()``
        are called on them).
    verbose : bool, default True
        Print iteration, loss, scale and offset every 1000 iterations.
    n_iter : int, default 10000
        Number of gradient-descent iterations.
    lr : float, default 1e-5
        Gradient-descent step size.

    Returns
    -------
    (Zk, b, r, K) : fitted scale, fitted offset, fitted residual matrix and
    the model kernel K as a NumPy array.
    """
    Kin = (-model.sig**2 * (model.w.t() @ model.w)).data.numpy()
    Mss = np.diag((model.m**2).data.numpy())
    K = Kin + Mss
    mu_ = model.m.data.numpy()
    # Initial scale: smallest m^2 divided by (2^2 + 4), i.e. by m^2 + 4 for
    # the target mass m = 2 used when K_phi was built.
    Zk_gd = mu_.min()**2/(2.**2 + 4)

    # Private generator: yields the same numbers as np.random.seed(1234)
    # followed by np.random.normal, but leaves the global NumPy RNG untouched.
    rng = np.random.RandomState(1234)
    r_seed = rng.normal(0.,.01,size=K.shape) # ~ sigma^2
    r_gd = r_seed.T @ r_seed
    b_gd = 0.

    for i in range(n_iter):
        l = loss(K, K_phi, Zk_gd, b_gd, r_gd)
        # All three steps are evaluated at the current parameters before any
        # of them is applied.
        dZ = lr * dL_dZ(K, K_phi, Zk_gd, b_gd, r_gd)
        db = lr * dL_db(K, K_phi, Zk_gd, b_gd, r_gd)
        dr = lr * dL_dr(K, K_phi, Zk_gd, b_gd, r_gd)

        Zk_gd -= dZ
        b_gd -= db
        r_gd -= dr
        if i%1000 == 0 and verbose:
            print(i, l, Zk_gd, b_gd)

    return Zk_gd, b_gd, r_gd, K

In [None]:
# Fit scale, offset and residual for the trained model; K is re-bound to the
# kernel returned by get_scale.
Zk, b, r, K = get_scale(rbm)

In [None]:
# Heat map of the off-diagonal part of K (diagonal entries set to zero), with
# the colour range spanning exactly its min and max. The fitted scale Zk is
# printed as the x-axis label.
plt.figure(figsize=figsize2)
K_off = K - np.diag(np.diag(K))
plt.imshow(K_off, cmap='gray', vmax=K_off.max(), vmin=K_off.min())
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=tick_size-8)
# plt.title('K off diagonal')
plt.xlabel(r"$Z_k$: %.5f"%(Zk), fontsize=tick_size)
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_K_off.pdf', dpi=dpi)
plt.show()

# Regenerated field distribution

In [None]:
# Start from 5000 all-ones rows of length N and push them through the model.
# rbm.forward returns five values, of which the first, second and last are
# kept (the second argument, 100, is presumably the number of sampling
# steps -- TODO confirm against RBM.SRBM.forward).
sample_base = torch.ones((5000,N))
regen, regen_, _, _, in_data = rbm.forward(sample_base,100)

-4.2414788198471065 0.27015507204050776
4.240729140043259 0.28375796566756795
0.0002783112880745477 0.004793137521224511
1.0011501746177673 0.0031526280838754836

In [None]:
# Min, max, mean and standard deviation of regen, one value per line.
for regen_stat in (regen.min(), regen.max(), regen.mean(), regen.std()):
    print(regen_stat)

In [None]:
# Min, max, mean and standard deviation of regen_, one value per line.
for regen_state_stat in (regen_.min(), regen_.max(), regen_.mean(), regen_.std()):
    print(regen_state_stat)

In [None]:
# Min, max, mean and standard deviation of the loaded data, one value per line.
for data_stat in (data_file.min(), data_file.max(), data_file.mean(), data_file.std()):
    print(data_stat)

# Propagator

In [None]:
# Inverse of the target kernel K_phi; used below as the reference covariance.
C_inv_true = np.linalg.inv(K_phi)

In [None]:
# Sample covariance matrices of the two model outputs. The arrays are
# transposed so that np.cov treats each column of the sample array as one
# variable and each row as one observation.
K_inv_gen = np.cov(regen_.data.numpy().T)
K_inv_gen_p = np.cov(regen.data.numpy().T)

In [None]:
# Heat map of the sample covariance K_inv_gen (computed from regen_).
plt.figure(figsize=figsize2)
plt.imshow(K_inv_gen, cmap='binary')
# plt.title('Covariance from regenerated (state) data')
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=tick_size-8)
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_C_regen.pdf', dpi=dpi)

plt.show()

In [None]:
# Heat map of the inverse of the model kernel K.
plt.figure(figsize=figsize2)
plt.imshow(np.linalg.inv(K), cmap='binary')
# plt.title('Covariance from model kernel')
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=tick_size-8)
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_C_rbm.pdf', dpi=dpi)
plt.show()

In [None]:
# Heat map of C_inv_true, the inverse of the target kernel K_phi.
plt.figure(figsize=figsize2)
plt.imshow(C_inv_true, cmap='binary')
# plt.title('Covariance of true kernel')
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=tick_size-8)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_C_true.pdf', dpi=dpi)

plt.show()

In [None]:
# Heat map of the element-wise difference between the reference covariance
# C_inv_true and the sample covariance K_inv_gen (default colour map).
plt.figure(figsize=figsize2)
# plt.title(r"$C_{\phi} - C_{\rm rbm}$")
plt.imshow(C_inv_true - K_inv_gen)
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=tick_size-8)
plt.xticks(fontsize=tick_size)
plt.yticks(fontsize=tick_size)
plt.tight_layout()

plt.savefig(plot_dir+project_name+'_C_diff.pdf', dpi=dpi)

plt.show()

In [None]:
!rm -rf ../plots/normal
!rm -rf ../models/normal