In [1]:
%matplotlib inline
%run ../../import_envs.py
# Environment banner: record library versions and whether CUDA is usable,
# so the saved output documents the setup this notebook was run with.
env_report = (
    'probtorch:', probtorch.__version__,
    'torch:', torch.__version__,
    'cuda:', torch.cuda.is_available(),
)
print(*env_report)

probtorch: 0.0+5a2c637 torch: 0.4.1 cuda: True


In [2]:
## Load dataset
data_path = "../gmm_dataset_3c"
Data = torch.from_numpy(np.load(data_path + '/obs.npy')).float()

## Sizes are read off the data tensor: first axis indexes datasets; the
## remaining two are presumably (points per dataset, point dimension) -- TODO confirm.
NUM_DATASETS, N, D = Data.shape
K = 3 ## number of clusters
SAMPLE_SIZE = 10 ## each batch of observations is repeated this many times in `test`
NUM_HIDDEN_LOCAL = 32

MCMC_SIZE = 10 ## number of eta/z update sweeps performed by the objective
BATCH_SIZE = 500
NUM_EPOCHS = 250
LEARNING_RATE = 5 * 1e-4
CUDA = torch.cuda.is_available()
PATH = '../neurips/ag-10runs/'
## Prefer the second GPU (the original hard-coded choice), but fall back to the
## first GPU, or to the CPU, so the notebook still runs on machines with fewer devices.
if CUDA:
    DEVICE = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    DEVICE = torch.device('cpu')

## Positional parameter bundles; they are unpacked by position in `test`
## and in the objective, so the order below must not change.
Train_Params = (NUM_DATASETS, SAMPLE_SIZE, BATCH_SIZE, CUDA, DEVICE, PATH)
Model_Params = (N, K, D, MCMC_SIZE)

In [3]:
from local_enc import *
from global_oneshot import *
from global_enc_v1 import *
## if reparameterize continuous variables
Reparameterized = False
# initialization
enc_z = Enc_z(K, D, NUM_HIDDEN_LOCAL, CUDA, DEVICE)
enc_eta = Enc_eta(K, D, CUDA, DEVICE, Reparameterized)
oneshot_eta = Oneshot_eta(K, D, CUDA, DEVICE, Reparameterized)

if CUDA:
    # Move straight to DEVICE. The previous `.cuda().to(DEVICE)` first copied
    # every module to the default GPU and only then to DEVICE.
    enc_z.to(DEVICE)
    enc_eta.to(DEVICE)
    oneshot_eta.to(DEVICE)
models = (oneshot_eta, enc_eta, enc_z)
# One Adam optimizer over all three networks; parameters are gathered in the
# same order as `models` (oneshot_eta, enc_eta, enc_z).
optimizer = torch.optim.Adam(
    [param for module in models for param in module.parameters()],
    lr=LEARNING_RATE,
    betas=(0.9, 0.99),
)

In [9]:
## Evaluate the checkpoints of every saved training run.
## NOTE: `test` and `EUBO_init_eta_test` are defined in cells further down this
## notebook; those cells must have been executed before this one.
NUM_RUNS = 10 ## number of saved runs ("...-%dround" checkpoints) under PATH
logs = []
for i in range(NUM_RUNS):
    print("experiment : %d" % i)
    ## Fresh networks for every run so no weights leak between checkpoints.
    enc_z = Enc_z(K, D, NUM_HIDDEN_LOCAL, CUDA, DEVICE)
    enc_eta = Enc_eta(K, D, CUDA, DEVICE, Reparameterized)
    oneshot_eta = Oneshot_eta(K, D, CUDA, DEVICE, Reparameterized)

    if CUDA:
        ## Single hop to DEVICE (no intermediate copy on the default GPU).
        enc_z.to(DEVICE)
        enc_eta.to(DEVICE)
        oneshot_eta.to(DEVICE)
    ## No optimizer is built here: this cell only evaluates, nothing is trained.
    models = (oneshot_eta, enc_eta, enc_z)

    ## Deserialize checkpoints on the CPU so loading does not depend on the GPU
    ## index they were saved from; load_state_dict then copies the values into
    ## the modules' existing parameters.
    to_cpu = lambda storage, loc: storage
    enc_z.load_state_dict(torch.load(PATH + "enc-z-ag-10runs-%dround" % i, map_location=to_cpu))
    enc_eta.load_state_dict(torch.load(PATH + "enc-eta-ag-10runs-%dround" % i, map_location=to_cpu))
    oneshot_eta.load_state_dict(torch.load(PATH + "oneshot-eta-ag-10runs-%dround" % i, map_location=to_cpu))

    ## Metrics only: skip autograd bookkeeping during evaluation.
    with torch.no_grad():
        run_metrics = test(models, EUBO_init_eta_test, Data, Model_Params, Train_Params)
    logs.append(run_metrics)

experiment : 0
iteration:0/40
experiment : 1
iteration:0/40
experiment : 2
iteration:0/40
experiment : 3
iteration:0/40
experiment : 4
iteration:0/40
experiment : 5
iteration:0/40
experiment : 6
iteration:0/40
experiment : 7
iteration:0/40
experiment : 8
iteration:0/40
experiment : 9
iteration:0/40


In [8]:
def test(models, objective, data, Model_Params, Train_Params):
    """
    Evaluate `objective` on `data` one mini-batch at a time and collect its metrics.

    No parameters are updated here; the function only calls `objective` and
    records what it returns.

    Args:
        models: passed through unchanged to `objective`.
        objective: callable `objective(models, obs, SubTrain_Params)` returning a
            dict of tensors with the keys "DB_eta", "DB_z", "ess_eta", "ess_z".
        data: tensor indexed by dataset along its first axis.
        Model_Params: tuple appended to `(device, S, B)` to form the parameters
            handed to `objective`.
        Train_Params: tuple `(NUM_DATASETS, S, B, CUDA, device, path)`;
            `path` is not used by this function.

    Returns:
        dict mapping each metric key to a list with one numpy array per batch.

    Datasets are visited in a random order (`torch.randperm`), and a trailing
    partial batch (when NUM_DATASETS is not a multiple of B) is dropped.
    """
    Metrics = {"DB_eta" : [], "DB_z" : [], "ess_eta" : [], "ess_z": []}
    (NUM_DATASETS, S, B, CUDA, device, path) = Train_Params

    NUM_BATCHES = NUM_DATASETS // B

    SubTrain_Params = (device, S, B) + Model_Params
    indices = torch.randperm(NUM_DATASETS)
    for step in range(NUM_BATCHES):
        batch_indices = indices[step*B : (step+1)*B]
        # `shuffler` comes from the project environment (presumably permutes the
        # points inside each dataset -- TODO confirm); the batch is then repeated
        # S times along a new leading axis.
        obs = shuffler(data[batch_indices]).repeat(S, 1, 1, 1)
        if CUDA:
            # Go directly to the target device instead of via the default GPU.
            obs = obs.to(device)
        metric_step = objective(models, obs, SubTrain_Params)
        for key in Metrics:
            Metrics[key].append(metric_step[key].cpu().data.numpy())

        if step % 100 == 0:
            print('iteration:%d/%d' % (step, NUM_BATCHES))
    return Metrics

In [5]:
def EUBO_init_eta_test(models, obs, SubTrain_Params):
    """
    Evaluation-only objective: NO resampling, nothing is optimized.

    Runs one initialization step (`Init_step_eta`) followed by `mcmc_size`
    sweeps that alternate an eta update (`enc_eta`) and a z update (`enc_z`),
    recording diagnostics after the initialization and after every sweep.

    Args:
        models: tuple `(oneshot_eta, enc_eta, enc_z)`.
        obs: observation tensor; the metric buffers are allocated on `obs.device`.
        SubTrain_Params: tuple `(device, sample_size, batch_size, N, K, D, mcmc_size)`.

    Returns:
        dict with keys "DB_eta", "DB_z", "ess_eta", "ess_z", each a tensor of
        length `mcmc_size + 1`. Index 0 holds the initialization step (the eta
        and z entries are identical there); index m+1 holds sweep m.
    """
    (_, sample_size, batch_size, N, K, D, mcmc_size) = SubTrain_Params
    # Allocate the metric buffers where the observations live. The previous
    # unconditional `.cuda()` failed on CPU-only runs and detoured through GPU 0.
    metric_device = obs.device
    esss_eta = torch.zeros(mcmc_size+1, device=metric_device)
    esss_z = torch.zeros(mcmc_size+1, device=metric_device)

    DB_eta = torch.zeros(mcmc_size+1, device=metric_device)
    DB_z = torch.zeros(mcmc_size+1, device=metric_device)
    (oneshot_eta, enc_eta, enc_z) = models
    model_os = (oneshot_eta, enc_z)
    obs_tau, obs_mu, state, log_w_f_z, q_eta, p_eta, q_z, p_z = Init_step_eta(model_os, obs, N, K, D, sample_size, batch_size)
    # Normalize the log weights over the first axis (presumably the sample axis
    # -- TODO confirm) and treat the normalized weights as constants.
    w_f_z = F.softmax(log_w_f_z, 0).detach()
    # Weighted mean of the log weights minus their plain mean.
    DB_eta[0] = (w_f_z * log_w_f_z).sum(0).mean() - log_w_f_z.mean()
    DB_z[0] = DB_eta[0] ##
    # Effective sample size 1 / sum(w^2) of the normalized weights.
    esss_z[0] = (1. / (w_f_z**2).sum(0)).mean()
    esss_eta[0] = esss_z[0]
    for m in range(mcmc_size):
        # Resampling of `state` / eta between updates is intentionally disabled
        # in this test variant (see "NO resampling" above).
        q_eta, p_eta, q_nu = enc_eta(obs, state, K, D)
        obs_tau, obs_mu, log_w_eta_f, log_w_eta_b  = Incremental_eta(q_eta, p_eta, obs, state, K, D, obs_tau, obs_mu)
        symkl_detailed_balance_eta, _, w_sym_eta, _ = detailed_balances(log_w_eta_f, log_w_eta_b)
        # Call the module itself rather than `.forward` so registered hooks run.
        q_z, p_z = enc_z(obs, obs_tau, obs_mu, N, K, sample_size, batch_size)
        state, log_w_z_f, log_w_z_b = Incremental_z(q_z, p_z, obs, obs_tau, obs_mu, K, D, state)
        symkl_detailed_balance_z, _, w_sym_z, _ = detailed_balances(log_w_z_f, log_w_z_b)
        ## symmetric KLs as metrics
        DB_eta[m+1] = symkl_detailed_balance_eta
        DB_z[m+1] = symkl_detailed_balance_z
        esss_eta[m+1] = (1. / (w_sym_eta**2).sum(0)).mean()
        esss_z[m+1] = (1. / (w_sym_z**2).sum(0)).mean()
    metric_step = {"DB_eta" : DB_eta, "DB_z" : DB_z, "ess_eta" : esss_eta, "ess_z": esss_z}
    return metric_step

In [10]:
## Average each run's per-batch ESS traces over the batch axis, giving one
## curve per run. Iterating over `logs` itself (instead of a hard-coded
## range(10)) keeps this cell in step with however many runs were evaluated.
## NOTE: the DB ("DB_eta" / "DB_z") metrics are present in `logs` but are not
## aggregated here, so DB_etas and DB_zs stay empty.
DB_etas = []
DB_zs = []
ESSs_eta = [np.array(run_log['ess_eta']).mean(0) for run_log in logs]
ESSs_z = [np.array(run_log['ess_z']).mean(0) for run_log in logs]

In [11]:
# Write the per-run ESS curves to .npy files in the current working directory.
for out_name, per_run_curves in (
    ('test-10runs-ess_eta.npy', ESSs_eta),
    ('test-10runs-ess_z.npy', ESSs_z),
):
    np.save(out_name, np.array(per_run_curves))

# Saving of the DB metrics is disabled.
# NOTE(review): in the two lines below the file names and arrays look swapped
# ("DB-eta" <- DB_zs, "DB-z" <- DB_etas) -- confirm before re-enabling.
# np.save('test-10runs-DB-eta.npy', np.array(DB_zs))
# np.save('test-10runs-DB-z.npy', np.array(DB_etas))