In [5]:
import matplotlib.pyplot as plt
import time
import torch
from torch import logsumexp
from torch.distributions.normal import Normal
import numpy as np
from utils import *
from plots import *
from objectives import *

In [6]:
## reproducibility: seed the torch and numpy global RNGs before anything samples
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
## training parameters (all four are forwarded to train() / the optimizer below)
STEPS = 10000
NUM_SAMPLES = 5
NUM_SAMPLES_SNR = 100
LEARNING_RATE = 1e-3
## model parameters
## p_* carry no gradient; q_* are the tensors handed to the optimizer.
## NOTE(review): the training cell re-creates these per estimator (with q_mu = 4.0),
## so the values here only matter to cells that run before it.
p_mu = torch.tensor([0.0])
p_sigma = torch.tensor([1.0])
q_mu = torch.tensor([3.0], requires_grad=True)
q_sigma = torch.tensor([2.0], requires_grad=True)
## initialize optimizer
optimizer = torch.optim.SGD([q_mu, q_sigma], lr=LEARNING_RATE)
## estimators (labels consumed by the training and plotting cells)
ests = ['mc', 'iwae', 'iwae-dreg', 'rws', 'rws-dreg', 'stl']

In [None]:
## Estimator label -> (objective function, alpha forwarded to train()).
## One table instead of a six-way if/elif, so every label is handled by the
## same train() call and an unknown label fails loudly rather than silently
## re-using the previous estimator's results.
ESTIMATOR_SPECS = {
    'mc': (mc, None),
    'iwae': (iwae, None),
    'iwae-dreg': (dreg, 0.0),
    'rws': (rws, None),
    'rws-dreg': (dreg, 1.0),
    'stl': (dreg, 0.5),
}

## Per-estimator results, keyed by estimator label.
dEUBOs = dict()
dELBOs = dict()
dIWELBOs = dict()
dSNRs = dict()
dVARs = dict()
dESSs = dict()

for est in ests:
    if est not in ESTIMATOR_SPECS:
        raise ValueError('unknown estimator %r; expected one of %s' % (est, sorted(ESTIMATOR_SPECS)))
    objective, alpha = ESTIMATOR_SPECS[est]
    ## model parameters: re-created for every estimator so each run starts
    ## from the same initial values
    p_mu = torch.tensor([0.0])
    p_sigma = torch.tensor([1.0])
    q_mu = torch.tensor([4.0], requires_grad=True)
    q_sigma = torch.tensor([2.0], requires_grad=True)
    ## initialize optimizer (fresh one per estimator, bound to the new q_* tensors)
    optimizer = torch.optim.SGD([q_mu, q_sigma], lr=LEARNING_RATE)
    print('======= start training by %s ========\n' % est)
    EUBOs, ELBOs, IWELBOs, ESSs, SNRs, VARs = train(
        objective, q_mu, q_sigma, p_mu, p_sigma,
        STEPS, NUM_SAMPLES, NUM_SAMPLES_SNR, optimizer,
        filename=est, alpha=alpha)

    dEUBOs[est] = EUBOs
    dELBOs[est] = ELBOs
    dIWELBOs[est] = IWELBOs
    dSNRs[est] = SNRs
    dVARs[est] = VARs
    dESSs[est] = ESSs
    print('======= end training by %s ========\n' % est)


iteration:0, EUBO:-0.812, ELBO:-11.354, IWELBO:-2.061, ESS:1.259 (0s)
iteration:100, EUBO:-2.127, ELBO:-7.757, IWELBO:-3.317, ESS:1.260 (9s)
iteration:200, EUBO:0.244, ELBO:-4.702, IWELBO:-1.204, ESS:1.068 (9s)
iteration:300, EUBO:1.570, ELBO:-2.582, IWELBO:0.263, ESS:1.170 (9s)
iteration:400, EUBO:3.042, ELBO:-1.214, IWELBO:1.713, ESS:1.144 (10s)
iteration:500, EUBO:2.995, ELBO:-4.537, IWELBO:1.582, ESS:1.103 (9s)
iteration:600, EUBO:1.006, ELBO:0.576, IWELBO:0.804, ESS:3.578 (9s)
iteration:700, EUBO:0.094, ELBO:-0.756, IWELBO:-0.338, ESS:2.399 (10s)
iteration:800, EUBO:0.192, ELBO:-0.795, IWELBO:-0.275, ESS:2.364 (9s)
iteration:900, EUBO:1.732, ELBO:-0.741, IWELBO:0.690, ESS:1.349 (9s)
iteration:1000, EUBO:1.864, ELBO:0.880, IWELBO:1.482, ESS:2.984 (9s)
iteration:1100, EUBO:2.182, ELBO:0.801, IWELBO:1.548, ESS:2.002 (10s)
iteration:1200, EUBO:1.960, ELBO:0.084, IWELBO:1.248, ESS:2.143 (9s)
iteration:1300, EUBO:0.465, ELBO:0.302, IWELBO:0.386, ESS:4.320 (9s)
iteration:1400, EUBO:1.52

In [None]:
def plot_results(dEUBOs, dELBOs, dIWELBOs, dESSs, dSNRs, dVARs, num_samples, num_samples_snr, steps, lr, ests, fs=10):
    """Plot per-estimator training traces on a 5x3 grid and save the figure as SVG.

    Rows, top to bottom: EUBO, ELBO, IWELBO, SNR together with variance
    (log-scaled y axis), ESS. Columns group the estimators: 'mc' on the left,
    'iwae' / 'iwae-dreg' in the middle, 'rws' / 'rws-dreg' on the right.

    Args:
        dEUBOs, dELBOs, dIWELBOs, dESSs, dSNRs, dVARs: dicts mapping an
            estimator label to the sequence drawn in the corresponding row.
        num_samples, num_samples_snr, steps, lr: only used to build the
            output file name.
        ests: iterable of estimator labels to draw.
        fs: width and height of the (square) figure, in inches.

    Returns:
        None. The figure is written to
        'results/results-<num_samples>samples-<num_samples_snr>SNRsamples-<steps>steps-<lr>lr.svg'.
    """
    import os  # local import: only needed to make sure the output directory exists

    colors = {'mc': 'green', 'iwae': 'red', 'iwae-dreg': 'blue',
              'rws': 'deepskyblue', 'rws-dreg': 'firebrick', 'stl': 'black'}
    # Grid column for each estimator label.
    # NOTE(review): 'stl' has a colour but no column, so it is not drawn --
    # this matches the previous if/elif chain; confirm which column it belongs in.
    columns = {'mc': 0, 'iwae': 1, 'iwae-dreg': 1, 'rws': 2, 'rws-dreg': 2}

    # Create the figure once (a second plt.figure() call would leave an
    # empty, never-closed figure behind).
    fig = plt.figure(figsize=(fs, fs))
    ax = fig.subplots(5, 3, gridspec_kw={'wspace': 0.1, 'hspace': 0.1})

    for est in ests:
        col = columns.get(est)
        if col is None:
            continue
        color = colors[est]
        ax[0, col].plot(dEUBOs[est], c=color, label=est)
        ax[1, col].plot(dELBOs[est], c=color, label=est)
        ax[2, col].plot(dIWELBOs[est], c=color, label=est)
        # SNR and variance share row 3; the variance trace is drawn with markers.
        ax[3, col].plot(dSNRs[est], c=color, label=est + '-snr')
        ax[3, col].plot(dVARs[est], marker='o', c=color, label=est + '-var')
        ax[4, col].plot(dESSs[est], c=color, label=est)

    # Row titles are placed on the middle column only.
    for row, title in enumerate(['EUBO', 'ELBO', 'IWELBO', 'SNR and Variance', 'ESS']):
        ax[row, 1].set_title(title)

    for i in range(5):
        for j in range(3):
            # Skip the legend on panels that received no curves.
            handles, _ = ax[i, j].get_legend_handles_labels()
            if handles:
                ax[i, j].legend(fontsize=14)
            ax[i, j].tick_params(labelsize=14)
            if i == 3:
                ax[i, j].set_yscale('log')

    # savefig does not create missing directories.
    os.makedirs('results', exist_ok=True)
    fig.savefig('results/results-%dsamples-%dSNRsamples-%dsteps-%flr.svg' % (num_samples, num_samples_snr, steps, lr))

In [None]:
## Plot the per-estimator traces gathered by the training loop, using the
## sampling / step / learning-rate constants from the config cell for the file name.
plot_results(dEUBOs, dELBOs, dIWELBOs, dESSs, dSNRs, dVARs, NUM_SAMPLES, NUM_SAMPLES_SNR, STEPS, LEARNING_RATE, ests)