In [1]:
%matplotlib inline
%run ../../import_envs.py
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

probtorch: 0.0+5a2c637 torch: 0.4.1 cuda: True


In [2]:
## Load dataset
# obs.npy must hold a 3-D array (NUM_DATASETS, N, D), as unpacked below.
data_path = "../rings_fixed_radius"
Data = torch.from_numpy(np.load(data_path + '/obs.npy')).float()
FIXED_RADIUS = 1.5

NUM_DATASETS, N, D = Data.shape
K = 3 ## number of clusters
SAMPLE_SIZE = 10
NUM_HIDDEN_GLOBAL = 8
NUM_HIDDEN_LOCAL = 64
NUM_STATS = 16

MCMC_SIZE = 10
BATCH_SIZE = 20
NUM_EPOCHS = 1000
LEARNING_RATE = 1e-4
PRIOR_FLAG = False
ONLY_FORWARD = True
NOISE_SIGMA = 0.05  # value of the `noise_sigma` tensor passed in Model_Params
SEED = 0  # fixes parameter initialisation and sampling so reruns are reproducible

CUDA = torch.cuda.is_available()
GPU_INDEX = 1  # machine-specific GPU choice; change to 0 on single-GPU hosts
PATH = 'ep-onlyf-%dsteps-%dsamples' % (MCMC_SIZE, SAMPLE_SIZE)
# Fall back to CPU when no GPU is present instead of naming a CUDA device that cannot be used.
DEVICE = torch.device('cuda:%d' % GPU_INDEX) if CUDA else torch.device('cpu')

np.random.seed(SEED)
torch.manual_seed(SEED)

# Move straight to DEVICE: calling `.cuda()` first would place the tensor on the default GPU
# and then copy it a second time. On CPU, `.to(DEVICE)` is a no-op.
obs_rad = (torch.ones(1) * FIXED_RADIUS).to(DEVICE)
noise_sigma = (torch.ones(1) * NOISE_SIGMA).to(DEVICE)

# NOTE(review): these tuples are presumably unpacked positionally by `train`/`test`; keep the order.
Train_Params = (NUM_EPOCHS, NUM_DATASETS, SAMPLE_SIZE, BATCH_SIZE, CUDA, DEVICE, PATH)
Model_Params = (obs_rad, noise_sigma, N, K, D, MCMC_SIZE, PRIOR_FLAG, ONLY_FORWARD)

In [3]:
from local_enc import *
from global_oneshot import *
from global_enc import *
## if reparameterize continuous variables
# Passed through to Enc_mu / Oneshot_mu as their last constructor argument.
Reparameterized = False

# Initialisation. Construction order (enc_z, enc_eta, then oneshot_eta) matches the original,
# so parameter-init RNG draws are consumed in the same sequence.
enc_z = Enc_z(K, D, NUM_HIDDEN_LOCAL, CUDA, DEVICE)
enc_eta = Enc_mu(K, D, NUM_HIDDEN_GLOBAL, NUM_STATS, CUDA, DEVICE, Reparameterized)
if PRIOR_FLAG:
    models = (enc_eta, enc_z)
else:
    # `oneshot_eta` exists only in this branch; later cells must check PRIOR_FLAG before using it.
    oneshot_eta = Oneshot_mu(K, D, NUM_HIDDEN_GLOBAL, NUM_STATS, CUDA, DEVICE, Reparameterized)
    models = (oneshot_eta, enc_eta, enc_z)
# NOTE(review): `models` is presumably unpacked positionally by train/test — keep this order.

if CUDA:
    for module in models:
        # Single transfer straight to DEVICE (`.cuda()` first would copy via the default GPU).
        module.to(DEVICE)

# One optimizer over every model's parameters, in the same order as `models`.
optimizer = torch.optim.Adam([p for module in models for p in module.parameters()],
                             lr=LEARNING_RATE, betas=(0.9, 0.99))

In [None]:
# NOTE(review): star import; presumably provides `train`/`test` (and possibly EUBO_init_eta) — confirm.
from ag_ep import *

# Full training run; `train` prints the per-epoch metrics shown in this cell's output.
train(models,
      EUBO_init_eta,
      optimizer,
      Data,
      Model_Params,
      Train_Params)

epoch: 0\1000 (26s),  symKL_DB_eta: 250691.314,  symKL_DB_z: 176464.750,  loss: -167381.188,  ess: 2.140
epoch: 1\1000 (26s),  symKL_DB_eta: 238708.322,  symKL_DB_z: 176772.908,  loss: -154984.423,  ess: 2.154
epoch: 2\1000 (26s),  symKL_DB_eta: 221759.499,  symKL_DB_z: 177079.631,  loss: -143491.819,  ess: 2.180
epoch: 3\1000 (26s),  symKL_DB_eta: 199695.606,  symKL_DB_z: 175031.881,  loss: -133132.605,  ess: 2.213
epoch: 4\1000 (26s),  symKL_DB_eta: 175460.592,  symKL_DB_z: 172076.405,  loss: -122710.675,  ess: 2.250
epoch: 5\1000 (26s),  symKL_DB_eta: 150715.672,  symKL_DB_z: 168337.454,  loss: -113289.351,  ess: 2.295
epoch: 6\1000 (26s),  symKL_DB_eta: 127510.824,  symKL_DB_z: 164329.736,  loss: -104418.787,  ess: 2.348
epoch: 7\1000 (26s),  symKL_DB_eta: 106666.779,  symKL_DB_z: 158661.002,  loss: -95777.132,  ess: 2.407
epoch: 8\1000 (26s),  symKL_DB_eta: 88675.807,  symKL_DB_z: 150908.695,  loss: -88625.642,  ess: 2.470
epoch: 9\1000 (26s),  symKL_DB_eta: 74410.445,  symKL_DB_z

In [None]:
import os

# Persist the trained encoder weights under ../weights, tagged with the run PATH.
os.makedirs("../weights", exist_ok=True)
torch.save(enc_z.state_dict(), "../weights/enc-z-%s" % PATH)
# The global encoder object is named `enc_eta`; the file name keeps its "enc-mu" prefix.
torch.save(enc_eta.state_dict(), "../weights/enc-mu-%s" % PATH)
# `oneshot_eta` is only constructed when PRIOR_FLAG is False.
if not PRIOR_FLAG:
    torch.save(oneshot_eta.state_dict(), "../weights/oneshot-mu-%s" % PATH)

In [None]:
# Evaluation-time settings.
# NOTE(review): BATCH_SIZE_TEST / RESAMPLE / DETACH are not passed to `test` below, which
# receives Train_Params (built with BATCH_SIZE) instead — confirm whether they were meant to be used.
BATCH_SIZE_TEST, RESAMPLE, DETACH = 50, True, True
obs, metric_step, reused = test(models, EUBO_init_eta, Data, Model_Params, Train_Params)
# Keep the 1st and 3rd entries of `reused` (used for plotting); the 2nd and 4th are discarded.
q_mu, _, q_z, _ = reused

In [None]:
# Plot the test-run outputs (obs, q_mu, q_z) via plot_samples; %time reports how long the call takes.
# NOTE(review): PATH is presumably used by plot_samples to name/save the figure — confirm.
%time plot_samples(obs, q_mu, q_z, K, PATH)

In [None]:
# NOTE(review): `symkls_test` is not defined in any visible cell, so this fails on a fresh kernel.
# It presumably should be derived from the `test(...)` outputs (e.g. `metric_step`) — TODO confirm.
# Entry 0 is dropped; the remaining entries are treated as per-step gaps.
incremental_gap = symkls_test.cpu().data.numpy()[1:]
M = incremental_gap.shape[0]
# overall_gap[m] == incremental_gap[:m+1].sum(): a single O(M) running sum replaces
# re-summing every prefix (O(M^2)). float64 matches the previous np.zeros(M) buffer.
overall_gap = np.cumsum(incremental_gap, dtype=np.float64)

In [None]:
# Per-step gap and its running total on a shared, log-scaled y axis.
fig, ax = plt.subplots(figsize=(12, 6))
ax.set_yscale("log")
ax.plot(incremental_gap, label="incremental gap")
ax.plot(overall_gap, label='overall gap')
ax.legend(fontsize=14)
ax.set_xlabel('Steps')