In [1]:
%matplotlib inline
%run ../../import_envs.py
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

from training_v2 import *

probtorch: 0.0+5a2c637 torch: 0.4.1 cuda: True


In [2]:
## Load dataset
data_path = "../rings_fixed_radius"
Data = torch.from_numpy(np.load(data_path + '/obs.npy')).float()
FIXED_RADIUS = 1.5

NUM_DATASETS, N, D = Data.shape
K = 3 ## number of clusters
SAMPLE_SIZE = 10
NUM_HIDDEN_GLOBAL = 8
NUM_HIDDEN_LOCAL = 64
NUM_HIDDEN_DEC = 64
NUM_STATS = 16

MCMC_SIZE = 10
BATCH_SIZE = 20
NUM_EPOCHS = 1000
LEARNING_RATE = 1e-4
ONLY_FORWARD = True
CUDA = torch.cuda.is_available()
PATH = 'halo-ep-generative-onlyf-%dsteps-%dsamples' % (MCMC_SIZE, SAMPLE_SIZE)
DEVICE = torch.device('cuda:1')

obs_rad = torch.ones(1) * FIXED_RADIUS
noise_sigma = torch.ones(1) * 0.05
if CUDA:
    obs_rad = obs_rad.cuda().to(DEVICE)
    noise_sigma = noise_sigma.cuda().to(DEVICE)
Train_Params = (NUM_EPOCHS, NUM_DATASETS, SAMPLE_SIZE, BATCH_SIZE, CUDA, DEVICE, PATH)
Model_Params = (obs_rad, noise_sigma, N, K, D, MCMC_SIZE, ONLY_FORWARD)

In [3]:
from local_enc import *
from global_oneshot import *
from global_enc import *
from decoder import *
## if reparameterize continuous variables
Reparameterized = False
# initialization
dec_x = Dec_x(D, NUM_HIDDEN_DEC, CUDA, DEVICE, Reparameterized)
enc_z = Enc_z(K, D, NUM_HIDDEN_LOCAL, CUDA, DEVICE)
enc_eta = Enc_mu(K, D, NUM_HIDDEN_GLOBAL, NUM_STATS, CUDA, DEVICE, Reparameterized)
oneshot_eta = Oneshot_mu(K, D, NUM_HIDDEN_GLOBAL, NUM_STATS, CUDA, DEVICE, Reparameterized)
if CUDA:
    dec_x.cuda().to(DEVICE)
    enc_z.cuda().to(DEVICE)
    enc_eta.cuda().to(DEVICE)
    oneshot_eta.cuda().to(DEVICE)
        
optimizer_enc =  torch.optim.Adam(list(oneshot_eta.parameters())+list(enc_eta.parameters())+list(enc_z.parameters()),lr=LEARNING_RATE, betas=(0.9, 0.99))
optimizer_dec =  torch.optim.Adam(list(dec_x.parameters()),lr=LEARNING_RATE, betas=(0.9, 0.99))

models = (oneshot_eta, enc_eta, enc_z, dec_x)

In [4]:
from ag_ep_v2 import *
train_v2(models, EUBO_init_eta_v2, optimizer_enc, optimizer_dec, Data, Model_Params, Train_Params)

epoch: 0\1000 (174s),  symKL_DB_eta: 77.671,  symKL_DB_z: 417.158,  gap: 123.187,  loss: 1636.438,  ess: 2.574
epoch: 1\1000 (121s),  symKL_DB_eta: 118.352,  symKL_DB_z: 419.417,  gap: 150.533,  loss: 2152.051,  ess: 2.540
epoch: 2\1000 (120s),  symKL_DB_eta: 206.768,  symKL_DB_z: 425.265,  gap: 200.022,  loss: 2535.035,  ess: 2.506
epoch: 3\1000 (120s),  symKL_DB_eta: 284.743,  symKL_DB_z: 434.542,  gap: 259.138,  loss: 2931.612,  ess: 2.475
epoch: 4\1000 (120s),  symKL_DB_eta: 359.084,  symKL_DB_z: 452.758,  gap: 308.365,  loss: 3373.371,  ess: 2.428
epoch: 5\1000 (120s),  symKL_DB_eta: 438.781,  symKL_DB_z: 477.749,  gap: 367.986,  loss: 3867.807,  ess: 2.372
epoch: 6\1000 (121s),  symKL_DB_eta: 519.982,  symKL_DB_z: 495.973,  gap: 413.705,  loss: 4454.877,  ess: 2.334
epoch: 7\1000 (121s),  symKL_DB_eta: 604.996,  symKL_DB_z: 528.771,  gap: 477.128,  loss: 5183.162,  ess: 2.277
epoch: 8\1000 (121s),  symKL_DB_eta: 685.318,  symKL_DB_z: 568.020,  gap: 550.521,  loss: 6011.960,  ess:

ValueError: NaN log prob encountered in nodewith name: x_recon

In [5]:
torch.save(dec_x.state.dict(), "../weights/dec-x-%s" % PATH)
torch.save(enc_z.state_dict(), "../weights/enc-z-%s" % PATH)
torch.save(enc_mu.state_dict(), "../weights/enc-mu-%s" % PATH)
torch.save(oneshot_mu.state_dict(), "../weights/oneshot-mu-%s" % PATH)

AttributeError: 'Dec_x' object has no attribute 'state'

In [None]:
BATCH_SIZE_TEST = 50
RESAMPLE=True
DETACH=True
obs, metric_step, reused = test(models, EUBO_init_eta, Data, Model_Params, Train_Params)
(q_mu, _, q_z, _) = reused

In [None]:
%time plot_samples(obs, q_mu, q_z, K, PATH)

In [None]:
incremental_gap = symkls_test.cpu().data.numpy()[1:]
M = incremental_gap.shape[0]
overall_gap = np.zeros(M)
for m in range(M):
    overall_gap[m] = incremental_gap[:m+1].sum()

In [None]:
fig = plt.figure(figsize=(12,6))
ax = fig.add_subplot(111)
plt.yscale("log")
ax.plot(incremental_gap, label="incremental gap")
ax.plot(overall_gap, label='overall gap')
ax.legend(fontsize=14)
ax.set_xlabel('Steps')