In [None]:
import matplotlib.pyplot as plt
import time
import torch
from torch import logsumexp
from torch.distributions.normal import Normal
import numpy as np
from utils_v2 import *
from plots import *
from objectives_v2 import *

In [None]:
## training parameters
# Number of optimisation steps handed to train() for every estimator.
STEPS = 2000
# Dimensionality of the p / q parameter vectors built in the training cell.
# NOTE(review): the original cell read `DATA_DIM= ` with no value, which is a
# SyntaxError. 1 is filled in because the commented-out prototype below uses
# one-element tensors -- confirm the dimensionality the experiment intends.
DATA_DIM = 1
# Sample count forwarded to train().
NUM_SAMPLES = 16
# Step size given to the Adam optimizer (same value as the original 5*1e-2).
LEARNING_RATE = 5e-2
# Forwarded to train() as its `joint_sample` keyword.
JOINT_SAMPLE = False
## model parameters
# Earlier hard-coded 1-D setup, kept for reference only; the training cell
# builds these tensors from DATA_DIM for each estimator instead.
# p_mu = torch.tensor([0.0])
# p_sigma = torch.tensor([1.0])
# q_mu = torch.tensor([8.0], requires_grad=True)
# q_sigma = torch.tensor([2.0], requires_grad=True)
## initialize optimizer
# optimizer = torch.optim.Adam([q_mu, q_sigma], lr=LEARNING_RATE)
## estimators
# Names of the gradient estimators to train and plot, in this order.
ests = ['mc', 'iwae', 'iwae-dreg', 'rws', 'rws-dreg']

In [None]:
# Per-estimator training curves, keyed by the estimator names in `ests`.
dLOSSs = dict()
dESSs = dict()
dKLs = dict()

# Estimator name -> objective callable passed to train().
# The callables are expected to come from the notebook's star imports
# (presumably objectives_v2 -- not visible in this cell).
objective_by_est = {
    'mc': mc,
    'iwae': iwae,
    'iwae-dreg': driwae,
    'rws': rws,
    'rws-dreg': drrws,
}

# Fail before any training starts: the previous if/elif chain had no else
# branch, so an unrecognised name silently re-stored the preceding
# estimator's results (or raised NameError on the first iteration).
unknown_ests = [est for est in ests if est not in objective_by_est]
if unknown_ests:
    raise ValueError('unknown estimator(s): %s' % ', '.join(unknown_ests))

# Drawn once so every estimator starts from the same variational mean.
init_q_mu = torch.randn(DATA_DIM)
print('======= start training ========\n')

for est in ests:
    ## model parameters
    p_mu = torch.zeros(DATA_DIM)
    p_sigma = torch.ones(DATA_DIM)
    # The multiplication yields a fresh tensor, so each estimator optimises
    # its own leaf parameter rather than sharing init_q_mu.
    q_mu = init_q_mu * 8.0
    q_mu.requires_grad = True
    q_sigma = torch.ones(DATA_DIM) * 2.0
    q_sigma.requires_grad = True
    ## initialize optimizer
    optimizer = torch.optim.Adam([q_mu, q_sigma], lr=LEARNING_RATE)
    time_start = time.time()
    LOSSs, ESSs, KLs = train(
        objective_by_est[est], q_mu, q_sigma, p_mu, p_sigma,
        STEPS, NUM_SAMPLES, DATA_DIM, optimizer,
        filename=est, joint_sample=JOINT_SAMPLE,
    )

    # Losses are stored as an ndarray; ESS and KL traces are stored exactly
    # as train() returned them.
    dLOSSs[est] = np.array(LOSSs)
    dESSs[est] = ESSs
    dKLs[est] = KLs
    time_end = time.time()
    print('%s training completed.. (%ds)' % (est, time_end - time_start))

In [None]:
def plot_results_simplified(dLOSSs, dESSs, dKLs, data_dim, num_samples, lr, ests, fs=15):
    """Plot objective, KL and ESS curves for each estimator and save as SVG.

    Three stacked panels are drawn: 'Objectives', 'exclusive KL' (log
    y-scale) and 'ESS'. For 'mc', 'iwae' and 'iwae-dreg' the stored loss is
    negated and labelled 'ELBO'; for 'rws' and 'rws-dreg' it is plotted as-is
    and labelled 'EUBO'. Any other name in ``ests`` is looked up in the three
    dicts but not drawn.

    Args:
        dLOSSs: dict mapping estimator name to its loss curve; the value must
            support unary minus (e.g. a numpy array).
        dESSs: dict mapping estimator name to its ESS curve (passed through
            ``np.array`` before plotting).
        dKLs: dict mapping estimator name to its KL curve.
        data_dim: integer used in the output file name.
        num_samples: integer used in the output file name.
        lr: learning rate used in the output file name (4 decimals).
        ests: iterable of estimator names to plot, in drawing order.
        fs: width and height of the square figure, in inches.

    Returns:
        None. The figure is written to
        ``results/<data_dim>dim-<num_samples>samples-<lr>lr.svg``.

    Raises:
        KeyError: if a name in ``ests`` is missing from any of the dicts.
    """
    import os  # local import: only needed to make sure the output folder exists

    colors = {'mc':'green', 'iwae': 'red', 'iwae-dreg': 'blue',
              'rws': 'deepskyblue', 'rws-dreg': 'firebrick', 'stl': 'black'}
    elbo_ests = ('mc', 'iwae', 'iwae-dreg')
    eubo_ests = ('rws', 'rws-dreg')

    # A single figure: the earlier version called plt.figure() twice, which
    # left an empty orphan figure open (and rendered in the notebook).
    fig, ax = plt.subplots(3, 1, figsize=(fs, fs),
                           gridspec_kw={'wspace':0.1, 'hspace':0.1})

    for est in ests:
        LOSSs = dLOSSs[est]
        ESSs = dESSs[est]
        KLs = dKLs[est]
        if est in elbo_ests:
            # Stored value is a loss; flip the sign to show the bound itself.
            objective, prefix = - LOSSs, 'ELBO'
        elif est in eubo_ests:
            objective, prefix = LOSSs, 'EUBO'
        else:
            # Names outside both groups were never drawn; keep skipping them.
            continue
        ax[0].plot(objective, c=colors[est], label= prefix + est)
        ax[1].plot(KLs, c=colors[est], label=est)
        ax[2].plot(np.array(ESSs), c=colors[est], label=est)

    ax[0].set_title('Objectives')
    ax[1].set_title('exclusive KL')
    ax[2].set_title('ESS')

    for panel in ax:
        panel.legend(fontsize=12)
        panel.tick_params(labelsize=12)
    # Only the KL panel uses a log scale.
    ax[1].set_yscale('log')

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs('results', exist_ok=True)
    fig.savefig('results/%ddim-%dsamples-%.4flr.svg' % (data_dim, num_samples, lr))

In [None]:
# Draw and save the comparison figure for all trained estimators.
plot_results_simplified(
    dLOSSs,
    dESSs,
    dKLs,
    data_dim=DATA_DIM,
    num_samples=NUM_SAMPLES,
    lr=LEARNING_RATE,
    ests=ests,
)