In [None]:
%matplotlib inline
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from ffebm.data import load_mnist
from ffebm.data_noise import DATA_NOISE_sampler
from ffebm.sgld import SGLD_sampler
from ffebm.nets.conjugate_vanilla_ebm import Energy_function

# Select the compute device. DEVICE is bound on every path so the sampler and
# EBM constructors further down do not raise NameError on CPU-only machines.
CUDA = torch.cuda.is_available()
if CUDA:
    # Keep the originally configured second GPU when it exists; otherwise fall
    # back to the first visible GPU instead of failing on a single-GPU host.
    DEVICE = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    DEVICE = torch.device('cpu')
print('torch:', torch.__version__, 'CUDA:', CUDA)
## EBM hyper-parameters
# Alternative values left over in the notebook (presumably the grid that was
# searched -- confirm against the training script):
#     DNSs = [3e-2, 1e-2]
#     SGLDNSs = [7.5e-3, 1.5e-2]
#     REGs = [1e-2, 1e-3]
latent_dim = 128
lr = 5e-5
data_noise_std = 3e-2
sgld_noise_std = 7.5e-3
sgld_step_size = 2
sgld_num_steps = 50
buffer_size = 5000
buffer_percent = 0.95
reg_alpha = 1e-3

# Checkpoint tag: each hyper-parameter becomes a `name=value` field and the
# fields are joined with '-'. The tag is later appended to 'ebm-' to form the
# weight file name.
LOAD_VERSION = '-'.join([
    'mnist',
    'conjugate_sgld',
    'latentdim=%d' % latent_dim,
    'lr=%.2E' % lr,
    'data_noise_std=%.2E' % data_noise_std,
    'sgld_noise_std=%.2E' % sgld_noise_std,
    'sgld_step_size=%.2E' % sgld_step_size,
    'sgld_num_steps=%.2E' % sgld_num_steps,
    'buffer_size=%d' % buffer_size,
    'buffer_percent=%.2f' % buffer_percent,
    'reg_alpha=%.2E' % reg_alpha,
])
# Resulting file name for the values above:
# ebm-mnist-conjugate_sgld-latentdim=128-lr=5.00E-05-data_noise_std=3.00E-02-sgld_noise_std=7.50E-03-sgld_step_size=2.00E+00-sgld_num_steps=5.00E+01-buffer_size=5000-buffer_percent=0.95-reg_alpha=1.00E-03
print('Initialize data noise sampler...')
# A strictly positive std builds a sampler; exactly zero disables data noise
# (sampler is None). Anything else (negative or NaN) is rejected.
if data_noise_std > 0:
    data_noise_sampler = DATA_NOISE_sampler(data_noise_std, CUDA, DEVICE)
elif data_noise_std == 0.0:
    data_noise_sampler = None
else:
    # Same exception type as before, now with a message naming the bad value.
    raise ValueError('data_noise_std must be a non-negative number, got %r' % (data_noise_std,))

print('Initialize sgld sampler...')
# Sampler used for drawing images in this notebook. No replay buffer is
# configured (buffer_size / buffer_percent are None) and gradient clipping is off.
# NOTE(review): noise_std=1e-3 differs from the `sgld_noise_std` (7.5e-3) set in
# the hyper-parameter block, while step_size=2 equals `sgld_step_size`.
# Confirm whether the different noise level is intentional for sampling.
sgld_sampler = SGLD_sampler(
    noise_std=1e-3,
    step_size=2,
    buffer_size=None,
    buffer_percent=None,
    grad_clipping=False,
    CUDA=CUDA,
    DEVICE=DEVICE,
)

# Build the energy network and restore its trained weights.
# NOTE: despite the log message, no optimizer is created in this cell.
print('Initialize EBM and optimizer...')
ebm = Energy_function(latent_dim=latent_dim, CUDA=CUDA, DEVICE=DEVICE)
if CUDA:
    # Make DEVICE the current CUDA device so .cuda() places the parameters there.
    with torch.cuda.device(DEVICE):
        ebm.cuda()
# map_location remaps the stored tensors onto the device used here, so a
# checkpoint saved on a different GPU index (or loaded on a CPU-only machine)
# can still be deserialized.
ebm.load_state_dict(torch.load('../weights/grid_search/ebm-%s' % LOAD_VERSION,
                               map_location=DEVICE if CUDA else 'cpu'))

In [None]:
# Launches the training script ../cebm_sgld.py from inside the notebook with the
# CLI flags below (150 epochs on MNIST).
# NOTE(review): several flags differ from the hyper-parameter cell above
# (data_noise_std 1e-2 vs 3e-2, lr 1e-4 vs 5e-5, sgld_num_steps 5 vs 50,
# --device=0 vs 'cuda:1'), so this run does not correspond to the checkpoint
# loaded via LOAD_VERSION.
# NOTE(review): by default %run copies the script's top-level names into the
# notebook namespace when it finishes, which may shadow variables defined
# earlier -- confirm this is intended.
%run ../cebm_sgld.py --seed=1 --device=0 --dataset=mnist --data_dir=../../../sebm_data/ --batch_size=100 --data_noise_std=1e-2 --optimizer=Adam --lr=1e-4 --hidden_dim=[128] --latent_dim=128 --num_epochs=150 --sgld_noise_std=7.5e-3 --sgld_lr=2.0 --sgld_num_steps=5 --regularize_factor=1e-3

Experiment with cebm-dataset=mnist-seed=1-lr=0.0001-latentdim=128-data_noise_std=0.01-sgld_noise_std=0.0075-sgld_lr=2.0-sgld_num_steps=5-buffer_size=5000-buffer_percent=0.95-reg_alpha=0.001-act=Swish-arch=simplenet
Load MNIST dataset...
Initialize EBM...
Initialize sgld sampler...
Start training...


In [None]:
# Leftover post-mortem debugging, disabled so "Restart & Run All" cannot stop
# at an interactive debugger prompt. Un-comment and run manually right after a
# cell raises to inspect the failing frame.
# %debug

In [None]:
def log_partition(nat1, nat2):
    """
    Log partition of a normal distribution given its natural parameters.

    Evaluates, elementwise,
        -nat1^2 / (4 * nat2) - 0.5 * log(-2 * nat2).

    nat2 must provide a ``.log()`` method (e.g. a torch tensor); the logarithm
    is finite only where nat2 < 0.
    """
    quadratic_term = - 0.25 * (nat1 ** 2) / nat2
    half_log_term = 0.5 * (-2 * nat2).log()
    return quadratic_term - half_log_term
# Prior in natural parameters: nat1 = 0 and nat2 = -1/2 for each of 10 dims.
prior_nat1 = torch.zeros(10)
prior_nat2 = torch.full((10,), -0.5)
# Random first-order statistic that is added to the prior's nat1.
tx1 = torch.randn(10)
# Log partition of the prior minus log partition of the shifted parameters.
E_hao = (log_partition(prior_nat1, prior_nat2)
         - log_partition(tx1 + prior_nat1, prior_nat2))

In [None]:
from numbers import Number
def log_norm(lam1, lam2):
    """
    Evaluate 0.5 * log(-2 * lam2) - lam1^2 / (4 * lam2).

    lam1 must support ``.pow`` (e.g. a torch tensor). lam2 may be a plain
    Python number, in which case ``math.log`` is used, or a tensor, in which
    case ``torch.log`` is used.

    NOTE(review): the log term enters with the opposite sign compared to
    ``log_partition`` in this notebook; the two agree when -2 * lam2 == 1.
    Confirm which convention is intended.
    """
    # Local import: `math` is not imported anywhere else in the notebook, so
    # the scalar-lam2 path previously failed with NameError.
    import math

    log = math.log if isinstance(lam2, Number) else torch.log
    return 0.5 * log(-2 * lam2) - lam1.pow(2) / (4 * lam2)

# Second implementation evaluated at the same shifted natural parameters.
NE_babak = log_norm(lam1=tx1 + prior_nat1, lam2=prior_nat2)

In [None]:
# Elementwise difference between the two energy computations (cell output).
-NE_babak - E_hao

In [None]:
# Draw a batch of images from the loaded EBM with a non-persistent SGLD chain.
test_batch_size = 10
images_ebm = sgld_sampler.sgld_update(
    ebm=ebm,
    batch_size=test_batch_size,
    pixels_size=28,
    num_steps=5000,
    persistent=False,
)
# Squeeze dim 1, move to CPU and detach from the graph; then clamp to [-1, 1]
# and rescale linearly to [0, 1] for display.
images_ebm = images_ebm.squeeze(1).cpu().detach().clamp(min=-1, max=1) * 0.5 + 0.5

In [None]:
# Show the sampled images on a gap-free grid with 10 images per row.
n_cols = 10
# Ceiling division (at least one row): the previous floor division produced
# zero rows for batch sizes below 10 and indexed past the grid when the batch
# size was not a multiple of 10.
n_rows = max(1, (test_batch_size + n_cols - 1) // n_cols)
gs = gridspec.GridSpec(n_rows, n_cols)
gs.update(left=0.0, bottom=0.0, right=1.0, top=1.0, wspace=0, hspace=0)
# One inch per image: 10 inches wide, one inch of height per row.
fig = plt.figure(figsize=(10, 10 * n_rows / n_cols))
for i in range(test_batch_size):
    row, col = divmod(i, n_cols)
    ax = fig.add_subplot(gs[row, col])
    ax.imshow(images_ebm[i], cmap='gray', vmin=0, vmax=1.0)
    ax.set_xticks([])
    ax.set_yticks([])
# plt.savefig('ebm_samples.png')