In [2]:
import numpy as np
import torch

import matplotlib
import matplotlib.pyplot as plt
%matplotlib widget

import corner

import seaborn as sns
sns.set()

In [3]:
# Seed the global torch RNG so the synthetic dataset and the shuffled
# train/test split below are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# One row per mixture component: the three 2-D component means.
means = torch.Tensor([
    [0.0, 0.0],
    [2, 3],
    [2, -3]
])
# Diagonal covariance matrix for each of the three components, in the same order.
covars = torch.Tensor([
    [
        [0.1, 0],
        [0, 1.5]
    ],
    [
        [1, 0],
        [0, 0.1]
    ],
    [
        [1, 0],
        [0, 0.1]
    ]
])

# Every draw produces one point from each component; flattening the
# (draws, components) axes gives an equally weighted mixture of 2-D points.
Z = torch.distributions.MultivariateNormal(loc=means, covariance_matrix=covars).sample((50000,)).reshape((-1, 2))
# Randomly permute the rows, then split them into two halves.
idx = torch.randperm(Z.shape[0])
Z = Z[idx]
Z_train = Z[:Z.shape[0] // 2]
Z_test = Z[Z.shape[0] // 2:]

# Axis limits shared by the plots in the cells below.
x_lim = (-2, 6)
y_lim = (-6, 6)

# 2-D histogram of the training half of the noise-free samples.
fig, ax = plt.subplots()
z_first, z_second = Z_train.numpy().T
corner.hist2d(z_first, z_second, ax=ax)
ax.set(xlim=x_lim, ylim=y_lim)
ax.set_title(r'$P(\mathbf{z})$')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, '$P(\\mathbf{z})$')

In [4]:
# Covariance of the zero-mean Gaussian noise added to every point of Z.
S = torch.Tensor([[0.1, 0],
                  [0, 3]])

noise_dist = torch.distributions.MultivariateNormal(
    loc=torch.Tensor([0.0, 0.0]),
    covariance_matrix=S
)
# One independent noise draw per row of Z.
X = Z + noise_dist.sample((Z.shape[0],))

# Same half/half split as Z, so row i of X_train is the noisy version of
# row i of Z_train.
n_half = X.shape[0] // 2
X_train, X_test = X[:n_half], X[n_half:]

# 2-D histogram of the noisy training observations.
fig, ax = plt.subplots()
# Pass the axes explicitly, as every other plotting cell does, instead of
# relying on the implicit "current axes".
corner.hist2d(X_train[:, 0].numpy(), X_train[:, 1].numpy(), ax=ax)
ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.set_title(r'$P(\mathbf{x})$')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, '$P(\\mathbf{x})$')

In [5]:
from deconv.gmm.data import DeconvDataset

# Pair each split with one copy of the noise covariance S per observation.
train_data, test_data = (
    DeconvDataset(split, S.repeat(split.shape[0], 1, 1))
    for split in (X_train, X_test)
)

In [6]:
from deconv.flow.svi import SVIFlow
# Construct the SVI flow model that the following cells train and sample from.
# NOTE(review): the two positional arguments (2, 5) are unnamed here;
# presumably the data dimensionality and a context/flow size. Confirm
# against the SVIFlow signature.
svi = SVIFlow(
    2,
    5,
    device=torch.device('cuda'),  # hard-coded to the GPU
    batch_size=512,
    epochs=25,
    lr=1e-4
)

In [7]:
# Fit on the noisy training set only. No validation data is supplied, and
# test_data built earlier is not used anywhere in the visible notebook.
svi.fit(train_data, val_data=None)

Epoch 0, Train Loss: -5.088949020182292
Epoch 1, Train Loss: -4.314747986653646
Epoch 2, Train Loss: -4.127806844889323
Epoch 3, Train Loss: -4.1097688818359375
Epoch 4, Train Loss: -4.100813252766927
Epoch 5, Train Loss: -4.0980733203125
Epoch 6, Train Loss: -4.094063190104166
Epoch 7, Train Loss: -4.089927169596354
Epoch 8, Train Loss: -4.0882167464192705
Epoch 9, Train Loss: -4.0847424308268225
Epoch 10, Train Loss: -4.083275260416666
Epoch 11, Train Loss: -4.083978240559896
Epoch 12, Train Loss: -4.080865526529948
Epoch 13, Train Loss: -4.0820182763671875
Epoch 14, Train Loss: -4.080985008138021
Epoch 15, Train Loss: -4.078597290039062
Epoch 16, Train Loss: -4.079479127604166
Epoch 17, Train Loss: -4.079202639973959
Epoch 18, Train Loss: -4.079061813151042
Epoch 19, Train Loss: -4.079924132486979
Epoch 20, Train Loss: -4.078985638020833
Epoch 21, Train Loss: -4.07660890625
Epoch 22, Train Loss: -4.078517631022136
Epoch 23, Train Loss: -4.076684321289062
Epoch 24, Train Loss: -4.076

In [None]:
import torch.utils.data as data_utils

# Pre-train only the prior of the SVI model by maximum likelihood on the
# noise-free samples Z_train.
optimiser = torch.optim.Adam(params=svi.model._prior.parameters(), lr=1e-4)
loader = data_utils.DataLoader(
    Z_train,
    batch_size=512,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

# Global, loop-invariant setting: apply it once instead of once per batch.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

for i in range(50):
    svi.model._prior.train()

    train_loss = 0.0

    for d in loader:
        d = d.to(svi.device)
        optimiser.zero_grad()

        log_p = svi.model._prior.log_prob(d)
        # Minimise the negative mean log-likelihood of the batch.
        loss = -1 * torch.mean(log_p)
        loss.backward()
        optimiser.step()

        # Accumulates the summed log-probability (not its negative), so the
        # value printed below is the mean log-probability per training point.
        train_loss += torch.sum(log_p).item()

    print('Epoch {}, Train Loss: {}'.format(i, train_loss / Z_train.shape[0]))

In [None]:
# Draw samples from the pre-trained prior and plot their 2-D histogram.
prior_samples = svi.sample_prior(10000)

fig, ax = plt.subplots()
corner.hist2d(prior_samples[:, 0].cpu().numpy(), prior_samples[:, 1].cpu().numpy(), ax=ax)
ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
# Raw string: '\m' in a normal string literal is an invalid escape sequence
# (a warning on current Python versions). The rendered title is unchanged.
ax.set_title(r'Prior fitted directly to $\mathbf{z}$')

In [None]:
# Freeze the pre-trained prior: gradients are disabled for all of its
# parameters before the model is fitted again on the noisy data.
for param in svi.model._prior.parameters():
    param.requires_grad = False
# NOTE(review): presumably only the remaining (non-prior) parameters are
# updated by this fit; confirm how SVIFlow.fit builds its optimiser.
svi.fit(train_data, val_data=None)

In [None]:
# Single observation to condition on: the point (3, 0) together with the
# Cholesky factor of a noise covariance with the same values as S.
noise_chol = torch.cholesky(torch.Tensor([[
    [0.1, 0],
    [0, 3]
]]))
test_point = (
    torch.Tensor([[3.0, 0.0]]).to(svi.device),
    noise_chol.to(svi.device)
)
ctx = svi.model._inputs_encoder(test_point)

n_samples, chunk = 10000, 500
# CPU buffer that receives all posterior draws.
posterior_samples = torch.zeros((n_samples, 2)).cpu()

torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Sample in chunks of 500 and copy each chunk into the CPU buffer; index 0
# selects the draws for the single context row.
with torch.no_grad():
    for start in range(0, n_samples, chunk):
        drawn = svi.model._approximate_posterior.sample(chunk, context=ctx)
        posterior_samples[start:start + chunk, :] = drawn.cpu()[0, :, :]

In [None]:
from deconv.gmm.plotting import plot_covariance
# 2-D histogram of the posterior samples for the test point.
fig, ax = plt.subplots()
corner.hist2d(posterior_samples[:, 0].numpy(), posterior_samples[:, 1].numpy(), ax=ax)
# Overlay the test point's location and noise covariance in red.
plot_covariance(
    np.array([3.0, 0.0]),
    np.array([
        [0.1, 0],
        [0, 3]
    ]),
    ax=ax,
    color='r'
)
# Apply the shared limits once, after everything is drawn. The earlier
# hard-coded (-2, 7) / (-5, 7) limits were always overwritten by these calls.
ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.set_title('Posterior for test point after fitting with frozen pretrained prior')

In [None]:
# Re-enable gradients for every prior parameter, then fit again so the prior
# is no longer frozen during training.
for param in svi.model._prior.parameters():
    param.requires_grad = True
svi.fit(train_data, val_data=None)

In [None]:
# Observation to condition on after joint fitting: the point (2, 0) with the
# Cholesky factor of its noise covariance.
obs_location = torch.Tensor([[2.0, 0]]).to(svi.device)
obs_noise_chol = torch.cholesky(torch.Tensor([[
    [0.1, 0],
    [0, 3]
]])).to(svi.device)
test_point = (obs_location, obs_noise_chol)
ctx = svi.model._inputs_encoder(test_point)

# CPU buffer for 10000 posterior draws, filled 500 rows at a time.
posterior_samples = torch.zeros((10000, 2)).cpu()

torch.set_default_tensor_type('torch.cuda.FloatTensor')

with torch.no_grad():
    for start, stop in zip(range(0, 10000, 500), range(500, 10001, 500)):
        batch = svi.model._approximate_posterior.sample(500, context=ctx).cpu()
        # Index 0 selects the draws for the single context row.
        posterior_samples[start:stop, :] = batch[0, :, :]

In [None]:
# Sample the prior after joint fitting; no gradients are needed for sampling.
with torch.no_grad():
    prior_samples = svi.model._prior.sample(10000)

# Move everything to the CPU once, then plot the two coordinates.
prior_np = prior_samples.cpu().numpy()
fig, ax = plt.subplots()
corner.hist2d(prior_np[:, 0], prior_np[:, 1], ax=ax)
ax.set(xlim=x_lim, ylim=y_lim)
ax.set_title('Prior after joint fitting with posterior')

In [None]:
from deconv.gmm.plotting import plot_covariance
# 2-D histogram of the posterior samples for the test point.
fig, ax = plt.subplots()
corner.hist2d(posterior_samples[:, 0].numpy(), posterior_samples[:, 1].numpy(), ax=ax)
# Overlay the test point's location and noise covariance in red.
plot_covariance(
    np.array([2.0, 0]),
    np.array([
        [0.1, 0],
        [0, 3]
    ]),
    ax=ax,
    color='r'
)
# Limits are applied once, after all drawing. The identical pair of calls
# that used to precede plot_covariance was redundant.
ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.set_title('Posterior for test point after joint fitting')

In [None]:
from deconv.flow.svi_mdn import SVIMDNFlow

In [None]:
# Construct the MDN variant of the SVI flow model used in the remaining cells.
# NOTE(review): the positional arguments (2, 5) mirror the SVIFlow call
# earlier in the notebook; their meaning is not visible here. Confirm against
# the SVIMDNFlow signature.
svi_mdn = SVIMDNFlow(
    2,
    5,
    device=torch.device('cuda'),  # hard-coded to the GPU
    batch_size=512,
    epochs=100,
    lr=1e-4,
    # NOTE(review): presumably these two settings disable KL annealing
    # (no warm-up, factor fixed at 1); verify in SVIMDNFlow.
    kl_warmup=0.0,
    kl_init_factor=1.0
)

In [None]:
import torch.utils.data as data_utils

# Pre-train only the prior of the MDN model by maximum likelihood on the
# noise-free samples Z_train.
optimiser = torch.optim.Adam(params=svi_mdn.model._prior.parameters(), lr=1e-4)
loader = data_utils.DataLoader(
    Z_train,
    batch_size=512,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

# Global, loop-invariant setting: apply it once instead of once per batch.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

for i in range(50):
    svi_mdn.model._prior.train()

    train_loss = 0.0

    for d in loader:
        d = d.to(svi_mdn.device)
        optimiser.zero_grad()

        log_p = svi_mdn.model._prior.log_prob(d)
        # Minimise the negative mean log-likelihood of the batch.
        loss = -1 * torch.mean(log_p)
        loss.backward()
        optimiser.step()

        # Accumulates the summed log-probability (not its negative), so the
        # value printed below is the mean log-probability per training point.
        train_loss += torch.sum(log_p).item()

    print('Epoch {}, Train Loss: {}'.format(i, train_loss / Z_train.shape[0]))

In [None]:
# Freeze the pre-trained prior of the MDN model (disable gradients for all of
# its parameters), then fit the model on the noisy training data.
for param in svi_mdn.model._prior.parameters():
    param.requires_grad = False
svi_mdn.fit(train_data, val_data=None)

In [None]:
torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Draw from the MDN model's prior without tracking gradients.
with torch.no_grad():
    prior_samples = svi_mdn.model._prior.sample(10000)

fig, ax = plt.subplots()
# Extract each coordinate column as a CPU numpy array before plotting.
prior_first, prior_second = (
    prior_samples[:, k].cpu().numpy() for k in (0, 1)
)
corner.hist2d(prior_first, prior_second, ax=ax)
ax.set(xlim=x_lim, ylim=y_lim)
ax.set_title('Prefitted Prior')

In [None]:
# Observation to condition the MDN posterior on: the point (3, -6) with the
# Cholesky factor of its noise covariance.
mdn_device = svi_mdn.device
test_point = (
    torch.Tensor([[3.0, -6.0]]).to(mdn_device),
    torch.cholesky(torch.Tensor([[
        [0.1, 0],
        [0, 3]
    ]])).to(svi_mdn.device)
)
ctx = svi_mdn.model._inputs_encoder(test_point)

# CPU buffer for 10000 posterior draws.
posterior_samples = torch.zeros((10000, 2)).cpu()

torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Fill the buffer 500 rows at a time; index 0 selects the draws for the
# single context row.
for chunk_no in range(10000 // 500):
    lo, hi = chunk_no * 500, (chunk_no + 1) * 500
    with torch.no_grad():
        draws = svi_mdn.model._approximate_posterior.sample(500, context=ctx)
        posterior_samples[lo:hi, :] = draws.cpu()[0, :, :]

In [None]:
from deconv.gmm.plotting import plot_covariance
# 2-D histogram of the MDN posterior samples for the test point.
fig, ax = plt.subplots()
corner.hist2d(posterior_samples[:, 0].numpy(), posterior_samples[:, 1].numpy(), ax=ax)
# Overlay the test point's location (read back from test_point) and its noise
# covariance in red.
plot_covariance(
    test_point[0].cpu().numpy()[0],
    np.array([
        [0.1, 0],
        [0, 3]
    ]),
    ax=ax,
    color='r'
)
# Limits are applied once, after all drawing. The identical pair of calls
# that used to precede plot_covariance was redundant.
ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.set_title('Posterior for test point with MDN and pretrained prior.')

In [None]:
# Displays the coordinates of the current test point as a numpy array: the
# same expression that is passed to plot_covariance in the previous cell.
# NOTE(review): looks like a leftover inspection cell; consider removing it.
test_point[0].cpu().numpy()[0]