In [1]:
import numpy as np
import torch

import matplotlib
import matplotlib.pyplot as plt
%matplotlib widget

import corner

import seaborn as sns
sns.set()

In [2]:
# Seed torch's global RNG so that "Restart Kernel -> Run All" regenerates the
# same latent samples (and, downstream, the same noise draws).
torch.manual_seed(0)

# Latent distribution P(z): a mixture of three 2-D Gaussians.
# `means` holds one row per component; `covars` holds the matching 2x2
# (diagonal) covariance matrices in the same order.
means = torch.Tensor([
    [0.0, 0.0],
    [2, 3],
    [2, -3]
])
covars = torch.Tensor([
    [
        [0.1, 0],
        [0, 1.5]
    ],
    [
        [1, 0],
        [0, 0.1]
    ],
    [
        [1, 0],
        [0, 0.1]
    ]
])

# The distribution is batched over the three components, so sampling 50000
# times yields a (50000, 3, 2) tensor; flattening the leading axes gives
# 150000 two-dimensional points with equal weight per component.
Z = torch.distributions.MultivariateNormal(loc=means, covariance_matrix=covars).sample((50000,)).reshape((-1, 2))

# Shuffle the rows, then split them into equally sized train / test halves.
idx = torch.randperm(Z.shape[0])
Z = Z[idx]
Z_train = Z[:Z.shape[0] // 2]
Z_test = Z[Z.shape[0] // 2:]

# Axis limits shared by every figure in the notebook so the plots are comparable.
x_lim = (-2, 6)
y_lim = (-6, 6)

fig, ax = plt.subplots()
corner.hist2d(Z_train[:, 0].numpy(), Z_train[:, 1].numpy(), ax=ax)
ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.set_title(r'$P(\mathbf{z})$')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, '$P(\\mathbf{z})$')

In [3]:
# Noise covariance shared by every observation: small variance along the
# first axis, large variance along the second.
S = torch.Tensor([
    [0.1, 0],
    [0, 3]
])

# Observed data x = z + zero-mean Gaussian noise with covariance S, one
# independent noise draw per latent point.
X = Z + torch.distributions.MultivariateNormal(loc=torch.Tensor([0.0, 0.0]), covariance_matrix=S).sample((Z.shape[0],))

# Same half/half split as Z, so X_train[i] is the noisy version of Z_train[i].
X_train = X[:X.shape[0] // 2]
X_test = X[X.shape[0] // 2:]

fig, ax = plt.subplots()
# Pass the axes explicitly (as the other plotting cells do) rather than
# relying on pyplot's implicit "current axes".
corner.hist2d(X_train[:, 0].numpy(), X_train[:, 1].numpy(), ax=ax)
ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.set_title(r'$P(\mathbf{x})$')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, '$P(\\mathbf{x})$')

In [4]:
from deconv.gmm.data import DeconvDataset

# Every observation shares the same noise covariance S, so factorise the
# single 2x2 matrix once and tile the resulting factor, instead of running
# a batched Cholesky over N identical copies of S.
S_chol = torch.cholesky(S)

# Each dataset pairs the noisy observations with one Cholesky factor per row.
train_data = DeconvDataset(X_train, S_chol.repeat(X_train.shape[0], 1, 1))
test_data = DeconvDataset(X_test, S_chol.repeat(X_test.shape[0], 1, 1))

In [5]:
from deconv.flow.svi import SVIFlow

# Use the GPU when one is visible, otherwise fall back to the CPU so that
# constructing the model does not require CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): the two positional arguments are presumably the data
# dimensionality (2) and a model-size setting (5) -- confirm against the
# SVIFlow signature.
svi = SVIFlow(
    2,
    5,
    device=device,
    batch_size=512,
    epochs=10,
    lr=1e-4,
    n_samples=50,
    use_iwae=True
)

In [None]:
# Train on the noisy training set; no validation set is supplied.
svi.fit(
    train_data,
    val_data=None,
)

Epoch 0, Train Loss: -4.882514130859375
Epoch 1, Train Loss: -4.107092342122396
Epoch 2, Train Loss: -4.0849835546875


In [7]:
# Switching the default tensor type is a process-wide side effect that
# affects every later cell; only request CUDA tensors when CUDA actually exists.
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
prior_samples = svi.sample_prior(10000)

fig, ax = plt.subplots()
# Move the samples to host memory before converting to NumPy: `.numpy()`
# fails on CUDA tensors, and `.cpu()` is a no-op for tensors already on the CPU.
corner.hist2d(prior_samples[0, :, 0].cpu().numpy(), prior_samples[0, :, 1].cpu().numpy(), ax=ax)
ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.set_title('Prior')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'Prior')

In [8]:
# Two synthetic observations, each with the same noise covariance used to
# generate the training data. Only the first one is plotted below.
mean = np.array([[3.0, 0.0], [0.0, 0.0]])
cov = np.array([
    [
        [0.1, 0],
        [0, 3]
    ],
    [
        [0.1, 0],
        [0, 3]
    ]
])

# A test point is passed as [observed values, Cholesky factors of the noise
# covariances], both moved to the model's device.
test_point = [
    torch.Tensor(mean).to(svi.device),
    torch.cholesky(torch.Tensor(cov)).to(svi.device)
]

# Process-wide side effect; only request CUDA tensors when CUDA actually exists.
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
posterior_samples = svi.sample_posterior(test_point, 10000)

from deconv.gmm.plotting import plot_covariance
fig, ax = plt.subplots()
# `.cpu()` before `.numpy()`: required for CUDA tensors, a no-op otherwise.
corner.hist2d(posterior_samples[0, :, 0].cpu().numpy(), posterior_samples[0, :, 1].cpu().numpy(), ax=ax)
# Overlay the noise covariance of the first observation, centred on it.
plot_covariance(
    mean[0],
    cov[0],
    ax=ax,
    color='r'
)
# Apply the shared limits once, after all artists have been drawn.
ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.set_title('Posterior for test point')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'Posterior for test point')

In [9]:
# Draw resampled posterior samples for the same test point as above.
rsamples = svi.resample_posterior(test_point, 10000)

# Only the first test point is visualised; pull its two coordinates to the host.
first_x = rsamples[0, :, 0].cpu().numpy()
first_y = rsamples[0, :, 1].cpu().numpy()

fig, ax = plt.subplots()
corner.hist2d(first_x, first_y, ax=ax)
ax.set_xlim(x_lim)
ax.set_ylim(y_lim)

# Overlay the observation's noise covariance in red, centred on the observation.
plot_covariance(mean[0], cov[0], ax=ax, color='r')

# Re-apply the shared limits after the overlay has been added.
ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.set_title('Posterior for test point')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'Posterior for test point')

In [None]:
# Inspect how many entries the shuffle permutation holds.
idx.size()

In [None]:
# NOTE(review): `samples` is not assigned in any cell visible in this notebook,
# and the `idx` defined earlier comes from torch.randperm (1-D), so
# `idx.T[:, :, None]` cannot be indexed with three axes. This looks like
# leftover scratch work on a per-batch gather -- confirm and remove, or
# define the inputs it needs.
samples[torch.arange(2)[:, None, None], idx.T[:, :, None], torch.arange(2)[None, None, :]].shape

In [None]:
# Shape of the resampled posterior draws once size-1 axes are dropped.
rsamples.squeeze().size()

In [None]:
# NOTE(review): `samples` is not assigned in any cell visible in this notebook,
# so this raises NameError on a fresh kernel -- presumably leftover scratch;
# confirm and remove.
samples.shape

In [None]:
# NOTE(review): np.newaxis is an alias for None, so this cell displays
# nothing -- looks like leftover scratch; confirm and remove.
np.newaxis

In [None]:
# NOTE(review): `samples` is not assigned in any cell visible in this notebook,
# so this raises NameError on a fresh kernel -- presumably leftover scratch;
# confirm and remove.
len(samples)