In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import torch
from sebm_mnist.data import load_data
from sebm_mnist.modules.sgld import SGLD_sampler
from sebm_mnist.modules.energy_function_cnn import Energy_function

CUDA = torch.cuda.is_available()
if CUDA:
    DEVICE = torch.device('cuda:0')
print('torch:', torch.__version__, 'CUDA:', CUDA)
## model hyper-parameters
# hidden_dim = 1024
# pixels_dim = 28*28
## EBM hyper-parameters
sgld_init_sample_std = 0.1
sgld_noise_std = 0.0075
LOAD_VERSION = 'ebm-cnn-v1' 
print('Load trained energy function...')
ef = Energy_function()
if CUDA:
    ef.cuda().to(DEVICE)   
# ef.load_state_dict(torch.load('../weights/ef-%s' % LOAD_VERSION))
# for p in ef.parameters():
#     p.requires_grad = False
print('Initialize SGLD sampler...')
sgld_sampler = SGLD_sampler(sgld_init_sample_std, sgld_noise_std, CUDA, DEVICE)

torch: 1.3.0 CUDA: True
Load trained energy function...
Initialize SGLD sampler...


In [5]:
a = torch.randn(5,4)
torch.nn.Sigmoid(a)

TypeError: __init__() takes 1 positional argument but 2 were given

In [2]:
    ## data directory
    print('Load MNIST dataset...')
    DATA_DIR = '/home/hao/Research/sebm_data/'
    train_data, test_data = load_data(DATA_DIR, 10)
    

Load MNIST dataset...


In [3]:
for (images, _) in train_data:
    break
if CUDA:
    images = images.cuda().to(DEVICE)
Ex = ef.forward(images)
Ex.shape

torch.Size([10, 64, 12, 12])

In [None]:
visual_sample_size = 50
ebm_images = sgld_sampler.sgld_update(ef=ef, 
                                      sample_size=1, 
                                      batch_size=visual_sample_size, 
                                      pixels_size=784, 
                                      num_steps=1000, 
                                      step_size=2,
                                      buffer_size=None,
                                      buffer_percent=None,
                                      persistent=False)
ebm_images = ebm_images.squeeze(0).cpu().view(visual_sample_size, 28, 28)
ebm_images = torch.clamp(ebm_images, min=-1, max=1) * 0.5 + 0.5

In [None]:
gs = gridspec.GridSpec(int(visual_sample_size/10), 10)
gs.update(left=0.0 , bottom=0.0, right=1.0, top=1.0, wspace=0, hspace=0)
fig = plt.figure(figsize=(15, 15*int(visual_sample_size/10)/ 10))
for i in range(visual_sample_size):
    ax = fig.add_subplot(gs[int(i/10), i%10])
    ax.imshow(ebm_images[i], cmap='gray', vmin=0, vmax=1.0)
    ax.set_xticks([])
    ax.set_yticks([])