In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

device

'cuda'

# SLCP experiment

In [4]:
from simulators import SLCP
from datasets import LTEDataset
from models import Flatten, MNRE
from criterions import WeightedLoss, RELoss, RDLoss
from samplers import TractableSampler, NRESampler
from histograms import pairhist, corner

## Simulator 

In [5]:
# Build the SLCP simulator, then move it to the selected device.
simulator = SLCP()
simulator = simulator.to(device)

# Draw once to inspect the output; the displayed result is a pair of tensors
# (5 values, then a 4x2 block) — presumably (parameters, observation).
simulator.sample()

(tensor([-1.4355, -1.8684,  2.5289, -0.5714, -2.0816], device='cuda:0'),
 tensor([[-13.7469,  -1.3300],
         [  7.9673,  -2.3458],
         [  4.1226,  -2.1861],
         [ -0.2871,  -1.8487]], device='cuda:0'))

In [6]:
# Training data drawn from the simulator; the distillation loop iterates it
# as (thetas, xs, mask) batches.
# NOTE(review): the meaning of mode=-1 is defined in LTEDataset — confirm there.
trainset = LTEDataset(simulator, mode=-1)

## Model

In [7]:
# One row per parameter subset, one column per parameter (1 = included).
# Rows: {4, 5}, {3, 4, 5} and all five parameters (1-based numbering).
joint_subsets = torch.tensor(
    [
        [0, 0, 0, 1, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1],
    ],
    dtype=torch.bool,
)

joint_subsets

tensor([[False, False, False,  True,  True],
        [False, False,  True,  True,  True],
        [ True,  True,  True,  True,  True]])

In [8]:
# Ratio estimator built over the parameter subsets in `joint_subsets`;
# observations go through a Flatten((4, 2)) encoder.
model = MNRE(
    masks=joint_subsets,
    encoder=Flatten((4, 2)),
    num_layers=10,
    hidden_size=512,
    activation=nn.SELU,
)
model = model.to(device)

# Single RDLoss term over the same subsets, weighted by 0.01.
criterion = WeightedLoss([RDLoss(joint_subsets)], [0.01]).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-5, amsgrad=True)

# Quarter the learning rate when the monitored loss plateaus.
# NOTE(review): `verbose` is deprecated in newer PyTorch releases — confirm
# against the installed version before upgrading.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.25,
    patience=5,
    threshold=1e-2,
    verbose=True,
)

In [9]:
# Restore the pretrained weights from the local checkpoint, mapped onto `device`.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
state_dict = torch.load('mnre.pth', map_location=device)
model.load_state_dict(state_dict)

<All keys matched successfully>

## Metropolis-Hastings (before distillation)

In [10]:
# CPU copies of the simulator's `low` / `high` bounds, used as histogram ranges.
low, high = (bound.cpu() for bound in (simulator.low, simulator.high))

In [11]:
# Fixed reference point: five parameters and a 4x2 observation, created on `device`.
# NOTE(review): presumably x_star was generated from theta_star — confirm.
theta_star = torch.tensor([0.7, -2.9, -1., -0.9, 0.6], device=device)

x_star = torch.tensor(
    [
        [-0.48406151, -3.13977371],
        [-0.43098274, -3.50238278],
        [-0.03512463, -2.87554796],
        [ 1.43279532, -2.80650507],
    ],
    device=device,
)

In [12]:
# Broadcast the observation to a leading dimension of 2 ** 12 = 4096.
# `expand` returns a view (no copy); -1 keeps the existing size of a dimension.
batch_shape = (2 ** 12, -1, -1)
x_star = x_star.expand(*batch_shape)

### Likelihood

In [13]:
# Reference run: sample with TractableSampler (presumably using the simulator's
# own likelihood), then histogram the samples within [low, high].
sampler = TractableSampler(simulator, x_star, sigma=0.1)
samples = sampler(4096)

hists_lhd = pairhist(
    samples,
    low,
    high,
    bins=60,
    normed=True,
    bounded=True,
)

### MNRE

In [14]:
# Put the model in evaluation mode before using it for sampling.
model.eval()

# The encoded observation is only used for inference, so compute it without
# autograd: z_star then carries no graph into the samplers below.
with torch.no_grad():
    z_star = model.encoder(x_star)

#### All

In [15]:
# model[2] unpacks to (mask, estimator); index 2 matches the all-True row of
# `joint_subsets`, so the mask is not needed here.
_, nre = model[2]

In [16]:
# Sample with the all-parameter ratio estimator under the full prior, then
# histogram the samples over the same ranges as the reference run.
sampler = NRESampler(simulator.prior, nre, z_star, sigma=0.1)
samples = sampler(4096)

hists_all = pairhist(
    samples,
    low,
    high,
    bins=60,
    normed=True,
    bounded=True,
)

In [17]:
def kl_divergence(p: torch.Tensor, q: torch.Tensor, epsilon: float = 1e-8) -> torch.Tensor:
    """Discrete Kullback-Leibler divergence ``sum(p * (log p - log q))``.

    The inputs are used as given; no normalisation is performed.

    Args:
        p: Reference histogram.
        q: Approximating histogram, broadcastable against ``p``.
        epsilon: Constant added inside both logarithms to avoid ``log(0)``.

    Returns:
        Scalar tensor holding the summed divergence. Entries where ``p <= 0``
        contribute exactly zero (the ``0 * log 0 = 0`` convention).
    """
    mask = p > 0
    terms = p * ((p + epsilon).log() - (q + epsilon).log())

    # Apply the mask so empty bins of `p` cannot inject NaN (0 * -inf) into
    # the sum, e.g. when `epsilon` is 0.
    return torch.where(mask, terms, torch.zeros_like(terms)).sum()

# KL(likelihood reference || all-parameter NRE) at histogram entry [4][3]
# (presumably the joint histogram of parameters 4 and 5 — confirm in pairhist).
kl_divergence(hists_lhd[4][3], hists_all[4][3])

tensor(0.0905)

#### 4 & 5

In [18]:
# Estimator 0 and its parameter mask (row 0 of `joint_subsets`: parameters 4 & 5).
subset_mask, nre = model[0]

# Keep the mask on the CPU so it can index the CPU-side `low` / `high` bounds.
mask = subset_mask.cpu()

In [19]:
# Sample the masked parameters only: prior and histogram bounds are both
# restricted with `mask`.
subprior = simulator.subprior(mask)
sampler = NRESampler(subprior, nre, z_star, sigma=0.1)
samples = sampler(4096)

hists_45 = pairhist(
    samples,
    low[mask],
    high[mask],
    bins=60,
    normed=True,
    bounded=True,
)

In [20]:
# KL(likelihood reference || {4, 5} estimator). Entry [1][0] of the 2-parameter
# histograms presumably addresses the same parameter pair as [4][3] in the full set.
kl_divergence(hists_lhd[4][3], hists_45[1][0])

tensor(0.2208)

In [21]:
# KL(all-parameter estimator || {4, 5} estimator) on the same pair of parameters
# (index correspondence [4][3] <-> [1][0] presumed — confirm in pairhist).
kl_divergence(hists_all[4][3], hists_45[1][0])

tensor(0.3862)

#### 3, 4 & 5

In [22]:
# Estimator 1 and its parameter mask (row 1 of `joint_subsets`: parameters 3, 4 & 5).
subset_mask, nre = model[1]

# Keep the mask on the CPU so it can index the CPU-side `low` / `high` bounds.
mask = subset_mask.cpu()

In [23]:
# Sample the masked parameters only: prior and histogram bounds are both
# restricted with `mask`.
subprior = simulator.subprior(mask)
sampler = NRESampler(subprior, nre, z_star, sigma=0.1)
samples = sampler(4096)

hists_345 = pairhist(
    samples,
    low[mask],
    high[mask],
    bins=60,
    normed=True,
    bounded=True,
)

In [24]:
# KL(likelihood reference || {3, 4, 5} estimator). Entry [2][1] of the 3-parameter
# histograms presumably addresses the same parameter pair as [4][3] in the full set.
kl_divergence(hists_lhd[4][3], hists_345[2][1])

tensor(0.2821)

In [25]:
# KL(all-parameter estimator || {3, 4, 5} estimator) on the same pair of parameters
# (index correspondence [4][3] <-> [2][1] presumed — confirm in pairhist).
kl_divergence(hists_all[4][3], hists_345[2][1])

tensor(0.3039)

## Distillation

In [27]:
model.train()

epoch = 0
epoch_size = 256  # number of batches aggregated into one reported "epoch"

losses = []

# NOTE: the loop variable `mask` rebinds the notebook-level `mask`; later cells
# reassign it before using it again.
for thetas, xs, mask in trainset:
    # Forward pass and loss for the current batch.
    ratios = model(thetas, xs)
    loss = criterion(ratios, mask)

    losses.append(loss.item())

    # Backward pass, with the gradient norm clipped at 1.
    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.)
    optimizer.step()

    # Keep accumulating until a full epoch of batch losses is available.
    if len(losses) != epoch_size:
        continue

    # End of epoch: report the loss statistics and feed the mean to the scheduler.
    epoch_losses = torch.tensor(losses)

    print(f'{epoch}: {epoch_losses.mean()} +- {epoch_losses.std()}')
    scheduler.step(epoch_losses.mean())

    epoch += 1
    losses = []

    # Stop once the scheduler has decayed the learning rate below 1e-6.
    if optimizer.param_groups[0]['lr'] < 1e-6:
        break

0: 6434.47998046875 +- 98648.21875
1: 243.2634735107422 +- 2301.369384765625
2: 107.98551940917969 +- 793.1503295898438
3: 33.458988189697266 +- 140.18136596679688
4: 248.1635284423828 +- 2836.357421875
5: 82.89403533935547 +- 663.243408203125
6: 28.50716209411621 +- 182.9430389404297
7: 62.96383285522461 +- 375.9570617675781
8: 103.03263092041016 +- 606.2960205078125
9: 129.53663635253906 +- 1362.8341064453125
10: 1564.6214599609375 +- 15017.2529296875
11: 23812.28125 +- 379942.65625
12: 80.72254180908203 +- 634.89208984375
Epoch    13: reducing learning rate of group 0 to 2.5000e-06.
13: 51644.11328125 +- 824921.3125
14: 348.1349182128906 +- 5117.32470703125
15: 147.636474609375 +- 1988.786376953125
16: 191.64784240722656 +- 1178.849609375
17: 674.4408569335938 +- 9948.1826171875
18: 51.47740173339844 +- 270.3169860839844
Epoch    19: reducing learning rate of group 0 to 6.2500e-07.


## Metropolis-Hastings (after distillation)
### MNRE

In [28]:
# Leave training mode before sampling with the distilled model.
model.eval()

# The encoded observation is only used for inference, so compute it without
# autograd: z_star then carries no graph into the samplers below.
with torch.no_grad():
    z_star = model.encoder(x_star)

#### 4 & 5

In [29]:
# Estimator 0 and its parameter mask (row 0 of `joint_subsets`: parameters 4 & 5).
subset_mask, nre = model[0]

# Keep the mask on the CPU so it can index the CPU-side `low` / `high` bounds.
mask = subset_mask.cpu()

In [30]:
# Re-sample the {4, 5} subset with the trained estimator; this overwrites the
# earlier `hists_45`.
subprior = simulator.subprior(mask)
sampler = NRESampler(subprior, nre, z_star, sigma=0.1)
samples = sampler(4096)

hists_45 = pairhist(
    samples,
    low[mask],
    high[mask],
    bins=60,
    normed=True,
    bounded=True,
)

In [31]:
# After distillation: KL(likelihood reference || {4, 5} estimator).
# `hists_lhd` is the histogram computed before training, not a fresh run.
kl_divergence(hists_lhd[4][3], hists_45[1][0])

tensor(0.2862)

In [32]:
# After distillation: KL(all-parameter estimator || {4, 5} estimator).
# `hists_all` is the histogram computed before training, not a fresh run.
kl_divergence(hists_all[4][3], hists_45[1][0])

tensor(0.2462)

#### 3, 4 & 5

In [33]:
# Estimator 1 and its parameter mask (row 1 of `joint_subsets`: parameters 3, 4 & 5).
subset_mask, nre = model[1]

# Keep the mask on the CPU so it can index the CPU-side `low` / `high` bounds.
mask = subset_mask.cpu()

In [34]:
# Re-sample the {3, 4, 5} subset with the trained estimator; this overwrites the
# earlier `hists_345`.
subprior = simulator.subprior(mask)
sampler = NRESampler(subprior, nre, z_star, sigma=0.1)
samples = sampler(4096)

hists_345 = pairhist(
    samples,
    low[mask],
    high[mask],
    bins=60,
    normed=True,
    bounded=True,
)

In [35]:
# After distillation: KL(likelihood reference || {3, 4, 5} estimator).
# `hists_lhd` is the histogram computed before training, not a fresh run.
kl_divergence(hists_lhd[4][3], hists_345[2][1])

tensor(0.1876)

In [36]:
# After distillation: KL(all-parameter estimator || {3, 4, 5} estimator).
# `hists_all` is the histogram computed before training, not a fresh run.
kl_divergence(hists_all[4][3], hists_345[2][1])

tensor(0.1468)