In [None]:
import matplotlib.pyplot as plt
import time
import torch
from torch import logsumexp
from torch.distributions.normal import Normal
import numpy as np
from utils import *
from plots import *
from objectives import *

In [None]:
## training parameters
STEPS = 10000
NUM_SAMPLES = 5
NUM_SAMPLES_SNR = 100
LEARNING_RATE = 5e-3
## reproducibility: seed the global torch and numpy RNGs so that any sampling
## drawn from them during training is identical under Restart & Run All
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
## model parameters (initial values; the training cell re-creates these
## for every estimator so that all runs start from the same point)
p_mu = torch.tensor([0.0])
p_sigma = torch.tensor([1.0])
q_mu = torch.tensor([4.0], requires_grad=True)
q_sigma = torch.tensor([2.0], requires_grad=True)
## initialize optimizer
optimizer = torch.optim.SGD([q_mu, q_sigma], lr=LEARNING_RATE)
## estimators
ests = ['mc', 'iwae', 'iwae-dreg', 'rws', 'rws-dreg', 'stl']

In [None]:
## per-estimator training histories, keyed by estimator name
dEUBOs = dict()
dELBOs = dict()
dIWELBOs = dict()
dSNRs = dict()
dVARs = dict()
dESSs = dict()

## estimator name -> (objective passed to `train`, value of its `alpha` argument)
## NOTE(review): 'stl' is routed through `dreg` with alpha=0.5 -- confirm this is
## the intended STL estimator as implemented in objectives.dreg.
ESTIMATOR_CONFIGS = {
    'mc': (mc, None),
    'iwae': (iwae, None),
    'iwae-dreg': (dreg, 0.0),
    'rws': (rws, None),
    'rws-dreg': (dreg, 1.0),
    'stl': (dreg, 0.5),
}

## fail fast before any (expensive) training: an unknown name would otherwise
## leave stale results from the previous estimator stored under the new name
unknown_ests = [est for est in ests if est not in ESTIMATOR_CONFIGS]
if unknown_ests:
    raise ValueError('unknown estimator(s) %s; expected one of %s'
                     % (unknown_ests, sorted(ESTIMATOR_CONFIGS)))

for est in ests:
    objective, alpha = ESTIMATOR_CONFIGS[est]
    ## re-create model parameters so every estimator starts from the same point
    p_mu = torch.tensor([0.0])
    p_sigma = torch.tensor([1.0])
    q_mu = torch.tensor([4.0], requires_grad=True)
    q_sigma = torch.tensor([2.0], requires_grad=True)
    ## fresh optimizer bound to the fresh parameters
    optimizer = torch.optim.SGD([q_mu, q_sigma], lr=LEARNING_RATE)
    print('======= start training by %s ========\n' % est)
    EUBOs, ELBOs, IWELBOs, ESSs, SNRs, VARs = train(
        objective, q_mu, q_sigma, p_mu, p_sigma, STEPS, NUM_SAMPLES,
        NUM_SAMPLES_SNR, optimizer, filename=est, alpha=alpha)

    dEUBOs[est] = EUBOs
    dELBOs[est] = ELBOs
    dIWELBOs[est] = IWELBOs
    dSNRs[est] = SNRs
    dVARs[est] = VARs
    dESSs[est] = ESSs
    print('======= end training by %s ========\n' % est)

In [None]:
def plot_results_multiple(ELBOs, ESSs, SNRs, num_samples, ests, save_dir=None, ess_window=10):
    """Plot ELBO, SNR and ESS histories of several estimators in a 3x3 grid and save it as SVG.

    Columns group estimators: IWAE / IWAE-DReG in column 0, RWS / RWS-DReG in
    column 1, every other estimator in column 2. Rows show the ELBO history,
    the SNR history (log scale), and the ESS ratio averaged over consecutive
    windows of ``ess_window`` entries.

    Args:
        ELBOs, ESSs, SNRs: sequences ordered like ``ests``; element ``i`` is the
            history recorded for estimator ``ests[i]``.
        num_samples: number of samples; ESS values are divided by it.
        ests: estimator names. Grouping is case-insensitive, so both
            'IWAE-DReG' and 'iwae-dreg' are placed in column 0.
        save_dir: prefix prepended to the output filename. When None, the
            global ``PATH`` is used (not defined in this notebook; presumably
            provided by one of the star imports).
        ess_window: number of consecutive ESS entries averaged per plotted point.

    Raises:
        ValueError: if ``ess_window`` is smaller than 1.
    """
    if ess_window < 1:
        raise ValueError('ess_window must be >= 1, got %r' % (ess_window,))

    # Lower-cased estimator name -> grid column; unlisted names go to column 2.
    column_of = {'iwae': 0, 'iwae-dreg': 0, 'rws': 1, 'rws-dreg': 1}
    column_titles = ['IWAE / IWAE-DReG', 'RWS / RWS-DReG', 'other estimators']
    # One colour per estimator in the notebook's list; the modulo below keeps a
    # longer `ests` list from raising IndexError.
    colors = ['blue', 'red', 'orange', 'black', 'green', 'purple']

    fig, axes = plt.subplots(3, 3, sharex=True, figsize=(40, 20))
    for col in range(3):
        axes[1, col].set_yscale('log')

    for i, est in enumerate(ests):
        col = column_of.get(est.lower(), 2)
        color = colors[i % len(colors)]
        axes[0, col].plot(ELBOs[i], c=color, label='ELBO ' + est)
        axes[1, col].plot(SNRs[i], c=color, label='SNR ' + est)
        # Average the ESS ratio over non-overlapping windows; a trailing partial
        # window is dropped instead of making the reshape fail.
        ess_ratio = (np.array(ESSs[i]) / num_samples).ravel()
        n_windows = ess_ratio.shape[0] // ess_window
        ave_ess = ess_ratio[:n_windows * ess_window].reshape(n_windows, ess_window).mean(-1)
        axes[2, col].plot(np.arange(n_windows) * ess_window, ave_ess, '-o', c=color, label='ESS ' + est)

    row_labels = ('ELBO', 'SNR', 'ESS / num_samples')
    for col in range(3):
        axes[0, col].set_title(column_titles[col])
        axes[0, col].set_ylim([-10, 2])
        axes[1, col].set_ylim([1e-4, 1e3])
        # Same range in every column (assumes ESS <= num_samples, so the ratio is in [0, 1]).
        axes[2, col].set_ylim([0, 1])
        for row, ylabel in enumerate(row_labels):
            ax = axes[row, col]
            ax.set_ylabel(ylabel)
            # Empty panels get no legend (avoids matplotlib's "no artists" warning).
            if ax.get_legend_handles_labels()[0]:
                ax.legend()

    # Layout must be computed after the axes are populated.
    fig.tight_layout()
    prefix = PATH if save_dir is None else save_dir
    fig.savefig(prefix + 'training_results_%dsamples.svg' % num_samples)

In [None]:
## the plotting function indexes its histories by position in `ests`,
## so build per-estimator lists from the dicts filled by the training cell
plot_results_multiple(
    [dELBOs[est] for est in ests],
    [dESSs[est] for est in ests],
    [dSNRs[est] for est in ests],
    NUM_SAMPLES,
    ests,
)