# Supervised SRBM

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import os, sys

import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid , save_image

import random

# Change accordingly to your directory structure
sys.path.append('../')
import RBM

## MNIST

In [None]:
# Number of hidden units; passed to the model as n_h and reused as the number
# of singular values in the SVD plots further down.
N = 10
# Initialization scheme.
# Initialize w with Cholesky solution W_phi and mass=5 sigma=1
# init_cond = {'w':torch.DoubleTensor(W_phi.copy()),'m':5., 'sig':1.}
# init_cond = {'m':3., 'sig':1., 'm_scheme':'local'}
# init_cond = {'w_sig':1e-1, 'm':3., 'sig':1., 'm_scheme':'global'}
# NOTE(review): the meaning of each key is defined by RBM.SRBM, which is not
# visible here; presumably w_sig = std of the initial couplings, m = initial
# mass, sig = sigma, m_scheme = mass parametrisation -- confirm in RBM.py.
init_cond = {'w_sig':1e-1, 'm':12., 'sig':1., 'm_scheme':0}

# Reproducibility
# Seed all three RNGs *before* the model is built and the dataset is split,
# so both steps draw from a fixed state.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# Initialize SRBM
# n_v=784 matches the flattened 28x28 MNIST images used below.
# k=10 is presumably the number of sampling steps -- TODO confirm in RBM.py.
rbm = RBM.SRBM(n_v=784,n_h=N,k=10,init_cond=init_cond)

# For autograd if implemented
# train_op = optim.SGD(rbm.parameters(),1e-7)
# train_op = optim.Adam(rbm.parameters(),1e-5)

# Check initial coupling matrix
print(rbm.w)

# Training parameters
lr = 1e-3

epochs = 10
batch_size = 31
# save_int is forwarded to rbm.fit and is also the stride of the epoch axis
# used by the plotting cells.
save_int = 2

train_ds = datasets.MNIST('../data',
                          train=True,
                          download = True,
                          transform = transforms.Compose(
                              [transforms.ToTensor()])
                         )

# Use only part of data because it is memory intensive
# random_split requires the lengths to sum to the dataset size; the second
# part (validation_data) is not used in the cells shown here.
train_ds, validation_data = random_split(train_ds, [3000, 57000])

# No shuffle argument is given, so batches come in a fixed order. With 3000
# samples and batch_size=31 the final batch is a partial one of 24 samples.
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=batch_size
)

In [None]:
# Save and load model
# Round-trip check: write the current model to ../models/ and rebuild `rbm`
# from that file, so `rbm` is the reloaded instance from here on.
rbm.save('../models/')
# rbm.name is used as the file stem of the saved model -- presumably set or
# refreshed by rbm.save(); TODO confirm in RBM.py.
saved_model = rbm.name
print(saved_model)
rbm = RBM.SRBM(load='../models/'+saved_model+'.npz')

In [None]:
# Destructive shell escape: removes every .npz file in ../models/ (presumably
# including anything written by rbm.save). Run only when a clean slate is wanted.
!rm ../models/*.npz

In [None]:
# Train the model
# The returned `history` is the record every plotting cell below reads
# (keys used in this notebook: 'loss', 'dw', 'w', 'm').
history = rbm.fit(
    train_dl,
    epochs,
    lr,
    beta=0.5,
    l2=1e-2,
    verbose=True,
    lr_decay=0,
    save_int=save_int,
)

# Training result

In [None]:
# Output directory and file-name prefix used by the (currently commented-out)
# plt.savefig calls in the cells below.
plot_dir = '../plots/'
model_name ='normal'
# Epoch axis shared by all history plots: 0, save_int, 2*save_int, ..., epochs.
# NOTE(review): assumes rbm.fit stores one history entry at epoch 0 and then
# one every save_int epochs, so len(x) matches the history length -- confirm.
x = np.arange(0,epochs+1,save_int)

In [None]:
# Loss function
# Plot the loss values stored in the training history against the epoch axis.
plt.plot(x, history['loss'])
plt.xlabel('epoch')
plt.ylabel(r'$\mathcal{L}$')
plt.title('KL')
# plt.savefig(plot_dir+model_name+'_lc.jpg')
plt.show()

In [None]:
# Gradient
# For every stored snapshot, collapse the gradient over axes 1 and 2 (done as
# two successive axis=1 reductions) and plot the mean, min and max curves.
for reduce_fn, tag in ((np.mean, 'mean'), (np.min, 'min'), (np.max, 'max')):
    per_snapshot = reduce_fn(reduce_fn(history['dw'], axis=1), axis=1)
    plt.plot(x, per_snapshot, label=tag)
plt.xlabel('epoch')
plt.ylabel(r'$\frac{d L}{dw}$')
plt.title('dW')
plt.legend()
# plt.savefig(plot_dir+model_name+'_lc.jpg')
plt.show()

In [None]:
# SVD of coupling matrix squared
# s_hist[i] holds the N singular values of the coupling matrix stored at
# history step i; they are plotted squared.
s_hist = np.zeros((len(x),N))
for i in range(len(x)):
    # Only the singular values are needed. compute_uv=False skips building
    # U and V^T, which the default full SVD would allocate at every step
    # just to be thrown away.
    s_ = np.linalg.svd(history['w'][i], compute_uv=False)
    s_hist[i] = s_

plt.plot(x,s_hist**2)
plt.grid(True)
plt.xlabel('epoch')
plt.ylabel(r'$w_{\alpha}^2$')
plt.title(r'$w^2$ evolution')
# plt.savefig(plot_dir+model_name+'_w.jpg')
plt.show()

In [None]:
# Last few steps
# Zoom on (at most) the final ten stored snapshots of the squared singular values.
recent = slice(-10, None)
plt.plot(x[recent], s_hist[recent]**2, '.-')
plt.title(r'$w^2$ evolution')
plt.xlabel('epoch')
plt.ylabel(r'$w_{\alpha}^2$')
plt.grid(True)
# plt.savefig(plot_dir+model_name+'_w.jpg')
plt.show()

In [None]:
# SVD of coupling matrix not squared
# Same singular-value history as above (s_hist), plotted without squaring.
plt.plot(x, s_hist, '-')
plt.title(r'$w_{\alpha}$ evolution')
plt.xlabel('epoch')
plt.ylabel(r'$w_{\alpha}$')
plt.grid(True)
# plt.savefig(plot_dir+model_name+'_w.jpg')
plt.show()

In [None]:
# Last few steps
# Zoom on (at most) the final ten stored snapshots of the singular values.
recent = slice(-10, None)
plt.plot(x[recent], s_hist[recent], '.-')
plt.title(r'$w_{\alpha}$ evolution')
plt.xlabel('epoch')
plt.ylabel(r'$w_{\alpha}$')
plt.grid(True)
# plt.savefig(plot_dir+model_name+'_w.jpg')
plt.show()

In [None]:
# Kernel eigenvalues
# For every stored snapshot build K = diag(m^2) - sig^2 * w^T w (784 x 784)
# and record its eigenvalues in ascending order.
# NOTE: this overwrites the s_hist created by the singular-value cells.
s_hist = np.zeros((len(x),784))
# mu2 is not used in the cells visible here; kept in case later cells rely on it.
mu2 = np.diag(np.ones(784))

for i in range(len(x)):
    WW_ = history['w'][i].T@history['w'][i]
    # NOTE(review): rbm.sig is the model's *current* value and is applied to
    # every historical snapshot -- confirm that sig is not trained.
    K_ = -rbm.sig**2 * WW_ + np.diag(history['m'][i]**2)
    if i ==0:
        # Keep a copy of the initial kernel for later inspection.
        K_i = K_.copy()
    # K_ is real symmetric (assuming rbm.sig is a scalar), so use the
    # symmetric solver: eigvalsh always returns real eigenvalues, already in
    # ascending order. The general eigvals solver may return complex values
    # with tiny imaginary parts for (near-)degenerate spectra, which would
    # then be silently truncated when stored in the float array.
    s_ = np.linalg.eigvalsh(K_)
    s_hist[i] = s_

plt.plot(x,s_hist)
plt.grid(True)
plt.xlabel('epoch')
plt.ylabel(r'$K_{\alpha}$')
plt.title('K eigenvalue')
# plt.savefig(plot_dir+model_name+'_K.jpg')
plt.show()

In [None]:
# Last few steps
# Zoom on (at most) the final ten stored snapshots of the kernel eigenvalues.
recent = slice(-10, None)
plt.plot(x[recent], s_hist[recent], '.-')
plt.title('K eigenvalue')
plt.xlabel('epoch')
plt.ylabel(r'$K_{\alpha}$')
plt.grid(True)
# plt.savefig(plot_dir+model_name+'_K.jpg')
plt.show()

In [None]:
# Mass parameter of the model
# Plot the stored mass values against the epoch axis.
plt.plot(x,history['m'])
plt.title('Mass evolution')
plt.xlabel('epoch')
plt.ylabel('mass')
# No legend here: none of the plotted curves carries a label, so
# plt.legend() only emitted a "no artists with labels" warning and drew
# nothing. Add label=... to the plot call before re-enabling it.
# plt.legend()
plt.grid(True)
# plt.savefig(plot_dir+model_name+'_mass.jpg')
plt.show()

In [None]:
# Last few steps
# Zoom on (at most) the final ten stored snapshots of the mass parameter.
recent = slice(-10, None)
plt.plot(x[recent], history['m'][recent], '.-')
# plt.axhline(np.sqrt(m**2 + 2. + 2.), ls='--', color='C3', label='Minimum Cholesky mass limit')
plt.xlabel('epoch')
plt.ylabel('mass')
plt.title('Mass evolution')
# plt.legend()
plt.grid(True)
# plt.savefig(plot_dir+model_name+'_mass.jpg')
plt.show()

In [None]:
# K_rbm off diagonal part
# Kernel of the current model, K = diag(m^2) - sig^2 * w^T w, shown with its
# diagonal zeroed out so the off-diagonal structure is visible.
Mss = np.diag((rbm.m**2).data.numpy())
Kin = (-rbm.sig**2 * (rbm.w.t() @ rbm.w)).data.numpy()
K = Kin + Mss

# Subtract the diagonal part to leave only the off-diagonal entries.
K_off = K - np.diag(np.diag(K))
plt.imshow(K_off, cmap='gray', vmin=K_off.min(), vmax=K_off.max())
plt.colorbar()
plt.title('K off diagonal')
# plt.savefig(plot_dir+model_name+'_K_img.jpg')
plt.show()

In [None]:
# Coupling matrix as image
# Render the current coupling matrix as a grey-scale image with a colour bar.
w_rbm = rbm.w.data.numpy()
plt.imshow(w_rbm, cmap='gray')
plt.colorbar()
plt.title('W')
# NOTE(review): the file name below is the same as in the K off-diagonal cell;
# change it before uncommenting, or one plot will overwrite the other.
# plt.savefig(plot_dir+model_name+'_K_img.jpg')
plt.show()

In [None]:
# Evolution of det(w w^T)
# w is stored as (hidden, visible): the kernel cell adds w.T @ w to a
# 784-entry diagonal. Hence w.T @ w is a 784 x 784 matrix of rank <= n_hidden
# and its determinant is identically zero (up to rounding). The informative
# quantity, and the one named in the heading, is the determinant of the small
# Gram matrix w @ w.T.
det_hist = np.zeros(len(x))
for i in range(len(x)):
    w_i = history['w'][i]
    det = np.linalg.det(w_i @ w_i.T)
    det_hist[i] = det

plt.plot(x,det_hist,'C9')
plt.grid(True)
plt.xlabel('epoch')
plt.ylabel(r'det($ww^T$)')
plt.title(r'det($ww^T$) evolution')
# plt.savefig(plot_dir+model_name+'_w.jpg')
plt.show()

## Regenerated image

In [None]:
from torchvision.utils import make_grid

In [None]:
def show_img(file_name,img):
    """Display a channel-first image tensor with matplotlib.

    Parameters
    ----------
    file_name : str
        Currently unused: nothing is written to disk. Kept so existing
        calls such as ``show_img("real", grid)`` keep working.
    img : torch.Tensor
        3-D tensor laid out as (channels, height, width), e.g. the output
        of ``make_grid``.
    """
    # imshow expects channel-last data, so move axis 0 to the end.
    # detach().cpu() makes this work for tensors that track gradients or
    # live on a GPU; for plain CPU tensors it is a no-op.
    npimg = np.transpose(img.detach().cpu().numpy(),(1,2,0))
    plt.imshow(npimg)

In [None]:
# Run through the whole loader; the loop body does nothing, its only purpose
# is to leave `data` bound to the final batch. With the 3000-sample subset and
# batch_size = 31 configured above, that final batch holds 24 samples, which
# is why the cells below work with 24 images.
for data in train_dl:
    pass

In [None]:
# Shape of the image tensor of the batch left over from the loop above
# (data[0] holds the images; data[1] presumably the labels).
data[0].shape

In [None]:
# Pass up to 24 images of the current batch through the model, each flattened
# to a 784-vector. The meaning of the five return values is defined by
# RBM.SRBM.forward (not visible here); the cells below display `v` under the
# label "real" and a rescaled `v_` -- presumably the input and its
# reconstruction, TODO confirm.
p_v, v_, _, _, v = rbm.forward(data[0][0:24].reshape(-1,784))

In [None]:
# Reshape `v` back to 24 single-channel 28x28 images, tile them into one grid
# and display it. The "real" argument is only a label; show_img does not save.
show_img("real",make_grid(v.view(24,1,28,28).data))

In [None]:
# Min-max rescale v_ (second output of rbm.forward) so its values span [0, 1],
# then show it as a grid of 24 single-channel 28x28 images.
v_shifted = v_ - v_.min()
v_scaled = v_shifted / v_shifted.max()
scaled_grid = make_grid(v_scaled.view(24,1,28,28).data)
show_img("real", scaled_grid)