In [1]:
# Reload all imported modules automatically before each cell executes, so edits
# to local .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

device

'cuda'

# SLCP experiment

In [4]:
from simulators import SLCP
from datasets import LTEDataset
from models import Flatten, MNRE
from criterions import WeightedLoss, RDLoss
from samplers import TractableSampler, NRESampler
from histograms import pairhist, corner, kld, remd

## Simulator 

In [5]:
# Build the SLCP simulator, move it to the selected device, then draw a single
# sample so the returned pair of tensors can be inspected.
simulator = SLCP()
simulator = simulator.to(device)
simulator.sample()

(tensor([ 2.9177,  0.1338, -2.9742,  1.6339,  1.7366], device='cuda:0'),
 tensor([[-3.9109, -1.6918],
         [ 0.8944, -0.6594],
         [ 1.5325, -0.4307],
         [ 9.2922,  0.6959]], device='cuda:0'))

In [6]:
# Dataset built on top of the simulator; the distillation loop iterates over it.
# NOTE(review): the meaning of mode=-1 is defined in datasets.LTEDataset — confirm there.
trainset = LTEDataset(simulator, mode=-1)

## Model

In [7]:
# Boolean masks over the 5 simulator parameters, one row per parameter subset
# handed to the MNRE model.
joint_subsets = torch.tensor(
    [
        [False, False, False, True, True],  # parameters 4 & 5
        [False, False, True, True, True],   # parameters 3, 4 & 5
        [True, True, True, True, True],     # all parameters
    ],
    dtype=torch.bool,
)

joint_subsets

tensor([[False, False, False,  True,  True],
        [False, False,  True,  True,  True],
        [ True,  True,  True,  True,  True]])

In [8]:
# MNRE over the parameter subsets in joint_subsets; the encoder flattens each
# (4, 2) observation before it reaches the network.
model = MNRE(
    masks=joint_subsets,
    encoder=Flatten((4, 2)),
    num_layers=10,
    hidden_size=512,
    activation=nn.SELU,
)
model = model.to(device)

# NOTE(review): the roles of RDLoss and of the 0.001 weight are defined in the
# criterions module — confirm there.
criterion = WeightedLoss(RDLoss(), 0.001).to(device)

# Optimises every model parameter; used by the distillation loop.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-5,
    amsgrad=True,
)

In [9]:
# Restore the saved weights from 'mnre.pth', remapping tensors onto the selected device.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load('mnre.pth', map_location=device))

<All keys matched successfully>

## Metropolis-Hastings

In [10]:
# CPU copies of the simulator's `low` / `high` tensors; they are passed to
# pairhist as the histogram range.
low, high = simulator.low.cpu(), simulator.high.cpu()

In [11]:
# theta_star: a 5-parameter vector (not referenced again in the visible cells).
theta_star = torch.tensor([0.7, -2.9, -1., -0.9, 0.6], device=device)

# x_star: the (4, 2) observation that the samplers condition on.
x_star = torch.tensor(
    [
        [-0.48406151, -3.13977371],
        [-0.43098274, -3.50238278],
        [-0.03512463, -2.87554796],
        [ 1.43279532, -2.80650507],
    ],
    device=device,
)

### Likelihood

In [12]:
# Reference run: sample conditioned on x_star with TractableSampler and
# histogram the samples over [low, high] with 100 bins.
# NOTE(review): the meaning of the positional arguments (2 ** 12, and the
# 2 ** 8 / 2 ** 4 passed to the call) is defined in samplers — confirm there.
sampler = TractableSampler(
    simulator,
    x_star,
    2 ** 12,
    sigma=0.1,
)
samples = sampler(2 ** 8, 2 ** 4)
hists_lhd = pairhist(
    samples,
    low,
    high,
    bins=100,
    normed=True,
    bounded=True,
)

### MNRE

In [17]:
# Switch to evaluation mode and encode the observation once; the NRE samplers
# are all conditioned on z_star.
model.eval()

# Encoding is pure inference, so autograd tracking is disabled.
with torch.no_grad():
    z_star = model.encoder(x_star)

# Display the encoded observation (this matches the cell's recorded output).
z_star

tensor([-0.4841, -3.1398, -0.4310, -3.5024, -0.0351, -2.8755,  1.4328, -2.8065],
       device='cuda:0')


#### All parameters

In [14]:
# Index 2 corresponds to the last (all-True) row of joint_subsets; only the
# estimator is kept, the first element of the pair is discarded.
# NOTE(review): assigning to `_` shadows IPython's last-output variable for the
# rest of the session.
_, nre = model[2]

In [15]:
# Sample over all parameters with the NRE estimator, conditioned on z_star, and
# histogram the result on the same 100-bin grid as the reference.
# NOTE(review): the positional 2 ** 16 is presumably a batch size — confirm in samplers.NRESampler.
sampler = NRESampler(
    simulator.prior,
    nre,
    z_star,
    2 ** 16,
    sigma=0.1,
)
samples = sampler(2 ** 8)
hists_all = pairhist(
    samples,
    low,
    high,
    bins=100,
    normed=True,
    bounded=True,
)

In [16]:
# KL divergence and EMD between the reference histogram and the all-parameter
# NRE histogram, both taken at entry [4][3].
# NOTE(review): [4][3] is presumably the joint histogram of parameters 5 and 4 —
# confirm in histograms.pairhist.
pair_lhd, pair_all = hists_lhd[4][3], hists_all[4][3]
print('KL =', kld(pair_lhd, pair_all))
print('EMD =', remd(pair_lhd, pair_all))

KL = tensor(0.0330)
EMD = tensor(0.0456)


#### 4 & 5

In [18]:
# Index 0 corresponds to the first row of joint_subsets (parameters 4 & 5).
mask, nre = model[0]
# The mask is moved to the CPU because it indexes the CPU tensors low / high.
mask = mask.cpu()

In [19]:
# Sample the masked parameters only, using the matching sub-prior and
# estimator, and histogram over the masked bounds.
sampler = NRESampler(
    simulator.subprior(mask),
    nre,
    z_star,
    2 ** 16,
    sigma=0.1,
)
samples = sampler(2 ** 8)
hists_45 = pairhist(
    samples,
    low[mask],
    high[mask],
    bins=100,
    normed=True,
    bounded=True,
)

In [20]:
# Reference histogram entry [4][3] versus entry [1][0] of the 2-parameter subset.
# NOTE(review): both entries presumably describe the same parameter pair (5, 4) —
# confirm against histograms.pairhist.
pair_lhd, pair_45 = hists_lhd[4][3], hists_45[1][0]
print('KL =', kld(pair_lhd, pair_45))
print('EMD =', remd(pair_lhd, pair_45))

KL = tensor(0.8700)
EMD = tensor(0.1430)


In [21]:
# All-parameter NRE histogram entry [4][3] versus entry [1][0] of the
# 2-parameter subset.
pair_all, pair_45 = hists_all[4][3], hists_45[1][0]
print('KL =', kld(pair_all, pair_45))
print('EMD =', remd(pair_all, pair_45))

KL = tensor(0.8551)
EMD = tensor(0.1450)


#### 3, 4 & 5

In [22]:
# Index 1 corresponds to the second row of joint_subsets (parameters 3, 4 & 5).
mask, nre = model[1]
# The mask is moved to the CPU because it indexes the CPU tensors low / high.
mask = mask.cpu()

In [28]:
# Sample the masked parameters only, using the matching sub-prior and
# estimator, and histogram over the masked bounds.
sampler = NRESampler(
    simulator.subprior(mask),
    nre,
    z_star,
    2 ** 16,
    sigma=0.1,
)
samples = sampler(2 ** 8)
hists_345 = pairhist(
    samples,
    low[mask],
    high[mask],
    bins=100,
    normed=True,
    bounded=True,
)

In [29]:
# Reference histogram entry [4][3] versus entry [2][1] of the 3-parameter subset.
# NOTE(review): both entries presumably describe the same parameter pair (5, 4) —
# confirm against histograms.pairhist.
pair_lhd, pair_345 = hists_lhd[4][3], hists_345[2][1]
print('KL =', kld(pair_lhd, pair_345))
print('EMD =', remd(pair_lhd, pair_345))

KL = tensor(0.7598)
EMD = tensor(0.1396)


In [30]:
# All-parameter NRE histogram entry [4][3] versus entry [2][1] of the
# 3-parameter subset.
pair_all, pair_345 = hists_all[4][3], hists_345[2][1]
print('KL =', kld(pair_all, pair_345))
print('EMD =', remd(pair_all, pair_345))

KL = tensor(0.7029)
EMD = tensor(0.1349)


## Distillation

In [None]:
# Distillation: train the model so that all ratio columns except the last agree
# with the last column, which serves as a fixed (detached) target.
# NOTE(review): the last column presumably belongs to the last (all-True) row
# of joint_subsets — confirm in models.MNRE.
model.train()

epoch = 0
epoch_size = 256  # optimisation steps aggregated into one reported "epoch"
max_epochs = 50   # stop after this many reported epochs

losses = []

# The third item of each batch is not used. Binding it to `_batch_mask` keeps
# the loop from overwriting the notebook-level `mask` that the sampling cells
# rely on.
for thetas, xs, _batch_mask in trainset:
    ratios = model(thetas, xs)
    # Detach the last column so no gradient flows through the target.
    loss = criterion(ratios[:, :-1], ratios[:, -1].detach())

    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    # Clip the total gradient norm to 1 before the optimiser step.
    nn.utils.clip_grad_norm_(model.parameters(), 1.)
    optimizer.step()

    if len(losses) == epoch_size:
        # Keep `losses` a list; the statistics live in a separate tensor.
        epoch_losses = torch.tensor(losses)

        print(f'{epoch}: {epoch_losses.mean()} +- {epoch_losses.std()}')

        epoch += 1
        losses = []

        if epoch == max_epochs:
            break

## Metropolis-Hastings (after distillation)
### MNRE

In [None]:
# Back to evaluation mode after training, then re-encode the observation for
# the post-distillation samplers.
model.eval()

# Encoding is pure inference, so autograd tracking is disabled.
with torch.no_grad():
    z_star = model.encoder(x_star)

#### 4 & 5

In [None]:
# Index 0 corresponds to the first row of joint_subsets (parameters 4 & 5).
mask, nre = model[0]
# The mask is moved to the CPU because it indexes the CPU tensors low / high.
mask = mask.cpu()

In [None]:
# Re-sample the 4 & 5 subset with the distilled estimator.
# NOTE(review): unlike the pre-distillation cell, no 2 ** 16 positional argument
# is passed here, so NRESampler's default applies — confirm this is intended.
sampler = NRESampler(simulator.subprior(mask), nre, z_star, sigma=0.1)
samples = sampler(4096)

# Use 100 bins: the KL/EMD cells compare this histogram against hists_lhd and
# hists_all, which were both built with bins=100. A 60-bin grid would not line
# up with them.
hists_45 = pairhist(samples, low[mask], high[mask], bins=100, normed=True, bounded=True)

In [None]:
# Reference histogram entry [4][3] versus entry [1][0] of the 2-parameter
# subset, after distillation.
pair_lhd, pair_45 = hists_lhd[4][3], hists_45[1][0]
print('KL =', kld(pair_lhd, pair_45))
print('EMD =', remd(pair_lhd, pair_45))

In [None]:
# All-parameter NRE histogram entry [4][3] versus entry [1][0] of the
# 2-parameter subset, after distillation.
pair_all, pair_45 = hists_all[4][3], hists_45[1][0]
print('KL =', kld(pair_all, pair_45))
print('EMD =', remd(pair_all, pair_45))

#### 3, 4 & 5

In [None]:
# Index 1 corresponds to the second row of joint_subsets (parameters 3, 4 & 5).
mask, nre = model[1]
# The mask is moved to the CPU because it indexes the CPU tensors low / high.
mask = mask.cpu()

In [None]:
# Re-sample the 3, 4 & 5 subset with the distilled estimator.
# NOTE(review): unlike the pre-distillation cell, no 2 ** 16 positional argument
# is passed here, so NRESampler's default applies — confirm this is intended.
sampler = NRESampler(simulator.subprior(mask), nre, z_star, sigma=0.1)
samples = sampler(4096)

# Use 100 bins: the KL/EMD cells compare this histogram against hists_lhd and
# hists_all, which were both built with bins=100. A 60-bin grid would not line
# up with them.
hists_345 = pairhist(samples, low[mask], high[mask], bins=100, normed=True, bounded=True)

In [None]:
# Reference histogram entry [4][3] versus entry [2][1] of the 3-parameter
# subset, after distillation.
pair_lhd, pair_345 = hists_lhd[4][3], hists_345[2][1]
print('KL =', kld(pair_lhd, pair_345))
print('EMD =', remd(pair_lhd, pair_345))

In [None]:
# All-parameter NRE histogram entry [4][3] versus entry [2][1] of the
# 3-parameter subset, after distillation.
pair_all, pair_345 = hists_all[4][3], hists_345[2][1]
print('KL =', kld(pair_all, pair_345))
print('EMD =', remd(pair_all, pair_345))

In [None]:
# Materialise everything the sampler yields and concatenate it into one tensor.
a = torch.cat(list(sampler(4096)))

In [None]:
# Shape of the concatenated sample tensor.
a.shape