In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import torch
import torch.nn as nn
import pandas as pd
from time import time

from matplotlib import gridspec
from mpl_toolkits.axes_grid1 import ImageGrid

import sys
sys.path.append('../')
from core import distributions, metrics, kernels

In [3]:
def eval_kernel(kernel, init_state, dim, T, batch_size, burn_in):
    """Run a Markov kernel for ``T`` steps and report acceptance rate, ESS and timing.

    Parameters
    ----------
    kernel : callable
        Called as ``kernel(state)``; must return ``(new_state, accepted_mask)``.
        ``new_state['x']`` is a tensor that is written into a
        ``[batch_size, dim]`` slice of the sample buffer, and ``accepted_mask``
        is a tensor whose sum counts the accepted proposals of that step.
    init_state : dict
        State handed to the first ``kernel`` call.
    dim : int
        Size of the last axis of the sample buffer.
    T : int
        Number of kernel applications.
    batch_size : int
        Number of chains; also used to normalise the acceptance rate.
    burn_in : int
        Number of leading steps dropped before the ESS is computed.

    Returns
    -------
    tuple
        ``(AR, ess, run_time)``: the accumulated acceptance rate (a torch
        tensor once at least one step has run), the value returned by
        ``metrics.batch_means_ess`` for the post-burn-in samples, and the
        wall-clock duration of the sampling loop in seconds.
    """
    samples = np.zeros([batch_size, T, dim])
    AR = 0.0
    state = init_state
    start_time = time()
    for t in range(T):
        state, accepted_mask = kernel(state)
        # detach() keeps the running acceptance rate out of any autograd graph,
        # so no graph is retained across the T iterations.
        AR += torch.sum(accepted_mask).detach().float() / batch_size / T
        # detach() before numpy(): Tensor.numpy() raises on tensors that still
        # require grad, which gradient-based kernels may hand back.
        samples[:, t, :] = state['x'].detach().cpu().numpy()
    run_time = time() - start_time

    samples = samples[:, burn_in:, :]
    ess = metrics.batch_means_ess(samples)

    # Reduce over axis 1 once; the summaries below reuse these per-chain minima.
    min_ess = np.min(ess, axis=1)
    min_ess_per_sec = np.min(ess / run_time, axis=1)

    print('AR:', AR.cpu().numpy())
    print('time:', run_time)
    print('mean ess:', np.mean(min_ess, axis=0))
    print('std ess:', np.std(min_ess, axis=0))
    print('mean ess/s:', np.mean(min_ess_per_sec, axis=0))
    print('std ess/s:', np.std(min_ess_per_sec, axis=0))
    return AR, ess, run_time

## MogTwo

In [4]:
# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
target = distributions.MOGTwo(device)
# Problem dimensionality: first-axis length of the target's mean().
dim = target.mean().shape[0]

In [5]:
# irrMALA benchmark: batch_size chains of T steps each, all started at the origin.
batch_size = 100
T = 20000
x = torch.zeros(batch_size, dim, device=device)
# `d` is an all-ones column carried in the state; its meaning is defined by the kernel.
d = torch.ones(batch_size, 1, device=device)
init_state_irr = dict(x=x, d=d)
eps = 1.2
kernel_irr = kernels.irrMALA(
    target, dim, step_size=eps, sigma=np.sqrt(2.0 * eps), device=device
)
# The first T // 20 steps (5%) of every chain are discarded as burn-in.
AR_irr, ess_irr, time_irr = eval_kernel(
    kernel_irr, init_state_irr, dim, T=T, batch_size=batch_size, burn_in=T // 20
)

AR: 0.29265884
time: 123.88203167915344
mean ess: 0.026629105146505617
std ess: 0.007602860435063467
mean ess/s: 0.00021495534732166252
std ess/s: 6.13717771012538e-05


In [6]:
# rwMALA benchmark: batch_size chains of T steps each, all started at the origin.
batch_size = 100
T = 20000
x = torch.zeros(batch_size, dim, device=device)
# `d` is an all-ones column carried in the state; its meaning is defined by the kernel.
d = torch.ones(batch_size, 1, device=device)
init_state_rw = dict(x=x, d=d)
eps = 1.1
kernel_rw = kernels.rwMALA(
    target, dim, step_size=eps, sigma=np.sqrt(2.0 * eps), device=device
)
# The first T // 20 steps (5%) of every chain are discarded as burn-in.
AR_rw, ess_rw, time_rw = eval_kernel(
    kernel_rw, init_state_rw, dim, T=T, batch_size=batch_size, burn_in=T // 20
)

AR: 0.25305277
time: 92.13073062896729
mean ess: 0.006902460617468302
std ess: 0.0018689924722145325
mean ess/s: 7.492028523323213e-05
std ess/s: 2.0286309024742422e-05


## Heart

In [4]:
# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
target = distributions.Heart(device)
# Problem dimensionality: first-axis length of the target's mean().
dim = target.mean().shape[0]

In [19]:
# irrMALA benchmark: batch_size chains of T steps each, all started at the origin.
batch_size = 100
T = 20000
x = torch.zeros(batch_size, dim, device=device)
# `d` is an all-ones column carried in the state; its meaning is defined by the kernel.
d = torch.ones(batch_size, 1, device=device)
init_state_irr = dict(x=x, d=d)
eps = 0.003
kernel_irr = kernels.irrMALA(
    target, dim, step_size=eps, sigma=np.sqrt(2.0 * eps), device=device
)
# The first T // 20 steps (5%) of every chain are discarded as burn-in.
AR_irr, ess_irr, time_irr = eval_kernel(
    kernel_irr, init_state_irr, dim, T=T, batch_size=batch_size, burn_in=T // 20
)

AR: 0.33483467
time: 246.5564935207367
mean ess: 0.011639801443606672
std ess: 0.0014356228615663592
mean ess/s: 4.720947024105738e-05
std ess/s: 5.822693375729793e-06


In [12]:
# rwMALA benchmark: batch_size chains of T steps each, all started at the origin.
batch_size = 100
T = 20000
x = torch.zeros(batch_size, dim, device=device)
# `d` is an all-ones column carried in the state; its meaning is defined by the kernel.
d = torch.ones(batch_size, 1, device=device)
init_state_rw = dict(x=x, d=d)
eps = 0.017
kernel_rw = kernels.rwMALA(
    target, dim, step_size=eps, sigma=np.sqrt(2.0 * eps), device=device
)
# The first T // 20 steps (5%) of every chain are discarded as burn-in.
AR_rw, ess_rw, time_rw = eval_kernel(
    kernel_rw, init_state_rw, dim, T=T, batch_size=batch_size, burn_in=T // 20
)

AR: 0.36826077
time: 175.35398077964783
mean ess: 0.08148800813348814
std ess: 0.012048615861217922
mean ess/s: 0.0004647057783985359
std ess/s: 6.871025001912203e-05


## Australian

In [6]:
# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
target = distributions.Australian(device)
# Problem dimensionality: first-axis length of the target's mean().
dim = target.mean().shape[0]

In [27]:
# irrMALA benchmark: batch_size chains of T steps each, all started at the origin.
batch_size = 100
T = 20000
x = torch.zeros(batch_size, dim, device=device)
# `d` is an all-ones column carried in the state; its meaning is defined by the kernel.
d = torch.ones(batch_size, 1, device=device)
init_state_irr = dict(x=x, d=d)
eps = 0.002
kernel_irr = kernels.irrMALA(
    target, dim, step_size=eps, sigma=np.sqrt(2.0 * eps), device=device
)
# The first T // 20 steps (5%) of every chain are discarded as burn-in.
AR_irr, ess_irr, time_irr = eval_kernel(
    kernel_irr, init_state_irr, dim, T=T, batch_size=batch_size, burn_in=T // 20
)

AR: 0.2755123
time: 226.49076437950134
mean ess: 0.006410871045491162
std ess: 0.0013060487761642281
mean ess/s: 2.830522058175093e-05
std ess/s: 5.766454891625739e-06


In [8]:
# rwMALA benchmark: batch_size chains of T steps each, all started at the origin.
batch_size = 100
T = 20000
x = torch.zeros(batch_size, dim, device=device)
# `d` is an all-ones column carried in the state; its meaning is defined by the kernel.
d = torch.ones(batch_size, 1, device=device)
init_state_rw = dict(x=x, d=d)
eps = 0.0075
kernel_rw = kernels.rwMALA(
    target, dim, step_size=eps, sigma=np.sqrt(2.0 * eps), device=device
)
# The first T // 20 steps (5%) of every chain are discarded as burn-in.
AR_rw, ess_rw, time_rw = eval_kernel(
    kernel_rw, init_state_rw, dim, T=T, batch_size=batch_size, burn_in=T // 20
)

AR: 0.49830386
time: 186.16454339027405
mean ess: 0.04257893329990809
std ess: 0.009961171865058792
mean ess/s: 0.00022871666389579843
std ess/s: 5.3507352601382636e-05


## German

In [4]:
# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
target = distributions.German(device)
# Problem dimensionality: first-axis length of the target's mean().
dim = target.mean().shape[0]

In [36]:
# irrMALA benchmark: batch_size chains of T steps each, all started at the origin.
batch_size = 100
T = 20000
x = torch.zeros(batch_size, dim, device=device)
# `d` is an all-ones column carried in the state; its meaning is defined by the kernel.
d = torch.ones(batch_size, 1, device=device)
init_state_irr = dict(x=x, d=d)
eps = 0.0007
kernel_irr = kernels.irrMALA(
    target, dim, step_size=eps, sigma=np.sqrt(2.0 * eps), device=device
)
# The first T // 20 steps (5%) of every chain are discarded as burn-in.
AR_irr, ess_irr, time_irr = eval_kernel(
    kernel_irr, init_state_irr, dim, T=T, batch_size=batch_size, burn_in=T // 20
)

AR: 0.21619353
time: 255.85654711723328
mean ess: 0.003996691934207145
std ess: 0.0005570143224095623
mean ess/s: 1.5620831201070903e-05
std ess/s: 2.1770571387971504e-06


In [39]:
# rwMALA benchmark: batch_size chains of T steps each, all started at the origin.
batch_size = 100
T = 20000
x = torch.zeros(batch_size, dim, device=device)
# `d` is an all-ones column carried in the state; its meaning is defined by the kernel.
d = torch.ones(batch_size, 1, device=device)
init_state_rw = dict(x=x, d=d)
eps = 0.003
kernel_rw = kernels.rwMALA(
    target, dim, step_size=eps, sigma=np.sqrt(2.0 * eps), device=device
)
# The first T // 20 steps (5%) of every chain are discarded as burn-in.
AR_rw, ess_rw, time_rw = eval_kernel(
    kernel_rw, init_state_rw, dim, T=T, batch_size=batch_size, burn_in=T // 20
)

AR: 0.4022522
time: 172.59669399261475
mean ess: 0.025507646781855634
std ess: 0.005416520726826999
mean ess/s: 0.0001477875745577553
std ess/s: 3.1382528839508174e-05
