In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from activations import Cosine
from networks import WHVIRegression
from layers import WHVILinear
from torch_datasets import ToyDataset

In [2]:
# Settings for reproducibility and data loading.
SEED = 0
DATASET_N = 128   # value forwarded as ToyDataset's `n` argument
BATCH_SIZE = 64

# Fix PyTorch's global RNG state before any random object is created.
torch.manual_seed(SEED)

# Toy regression data, served in mini-batches. No `shuffle` argument is given,
# so DataLoader keeps its default (fixed iteration order).
dataset = ToyDataset(n=DATASET_N)
data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)

In [3]:
# Layer stack handed to WHVIRegression: widths 1 -> 128 -> 30 -> 20 -> 1,
# with a Cosine activation after each of the first three layers.
layers = [
    nn.Linear(1, 128),
    Cosine(),
    WHVILinear(128, 30, lambda_=1e-5),
    Cosine(),
    WHVILinear(30, 20, lambda_=1e-5),
    Cosine(),
    nn.Linear(20, 1),
]
net = WHVIRegression(layers)

# Hyper-parameters of the learning-rate decay: factor(t) = (1 + gamma * t) ** (-p).
gamma = 0.0005
p = 0.3


def lr_decay(t):
    """Return the multiplicative learning-rate factor for scheduler step ``t``."""
    return (1 + gamma * t) ** (-p)


optimizer = optim.Adam(net.parameters(), lr=1e-3)
# NOTE(review): `lr_scheduler` is never stepped in this notebook and is not
# passed to `net.train_model` below, so the decay only takes effect if something
# outside the visible cells calls `lr_scheduler.step()` -- confirm.
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_decay)

In [4]:
# Fit the network on the mini-batches. The saved progress bar shows a
# "[Fix. var.]" stage counting up to 500, which matches `epochs1`; `epochs2`
# presumably sets the length of a following stage -- confirm against
# WHVIRegression.train_model.
net.train_model(
    data_loader,
    optimizer,
    epochs1=500,
    epochs2=50000,
)

[Fix. var.] KL = 879.44, MNLL = 41.91:   1%|          | 3/500 [00:00<00:18, 26.29it/s]       

torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32

[Fix. var.] KL = 879.44, MNLL = 41.91:   3%|▎         | 13/500 [00:00<00:12, 39.11it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 797.40, MNLL = 33.52:   4%|▎         | 18/500 [00:00<00:11, 40.53it/s]

torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32

[Fix. var.] KL = 797.40, MNLL = 33.52:   6%|▌         | 28/500 [00:00<00:11, 41.36it/s]

torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128

[Fix. var.] KL = 797.40, MNLL = 33.52:   8%|▊         | 38/500 [00:00<00:10, 42.94it/s]

torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128

[Fix. var.] KL = 725.11, MNLL = 33.10:   9%|▊         | 43/500 [00:01<00:10, 42.84it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 725.11, MNLL = 33.10:  11%|█         | 53/500 [00:01<00:10, 43.10it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 661.13, MNLL = 33.14:  12%|█▏        | 58/500 [00:01<00:10, 43.09it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 661.13, MNLL = 33.14:  14%|█▎        | 68/500 [00:01<00:09, 43.39it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 661.13, MNLL = 33.14:  16%|█▌        | 78/500 [00:01<00:09, 43.66it/s]

torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128

[Fix. var.] KL = 604.27, MNLL = 33.14:  17%|█▋        | 83/500 [00:01<00:09, 43.82it/s]

torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128

[Fix. var.] KL = 604.27, MNLL = 33.14:  19%|█▊        | 93/500 [00:02<00:09, 44.05it/s]

torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20]

[Fix. var.] KL = 553.55, MNLL = 33.14:  21%|██        | 103/500 [00:02<00:09, 43.20it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 553.55, MNLL = 33.14:  22%|██▏       | 108/500 [00:02<00:09, 42.61it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 553.55, MNLL = 33.14:  24%|██▎       | 118/500 [00:02<00:08, 43.03it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 508.15, MNLL = 33.14:  25%|██▍       | 123/500 [00:02<00:08, 42.68it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 508.15, MNLL = 33.14:  27%|██▋       | 133/500 [00:03<00:08, 43.06it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 467.37, MNLL = 33.14:  29%|██▊       | 143/500 [00:03<00:08, 42.65it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 467.37, MNLL = 33.14:  30%|██▉       | 148/500 [00:03<00:08, 43.22it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 467.37, MNLL = 33.14:  32%|███▏      | 158/500 [00:03<00:08, 42.65it/s]

torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128

[Fix. var.] KL = 430.63, MNLL = 33.14:  33%|███▎      | 163/500 [00:03<00:07, 42.56it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 430.63, MNLL = 33.14:  35%|███▍      | 173/500 [00:04<00:07, 42.17it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 397.45, MNLL = 33.14:  37%|███▋      | 183/500 [00:04<00:07, 42.18it/s]

torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128

[Fix. var.] KL = 397.45, MNLL = 33.14:  38%|███▊      | 188/500 [00:04<00:07, 42.44it/s]

torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32

[Fix. var.] KL = 397.45, MNLL = 33.14:  40%|███▉      | 198/500 [00:04<00:07, 42.82it/s]

torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128

[Fix. var.] KL = 367.39, MNLL = 33.14:  40%|████      | 202/500 [00:04<00:07, 42.53it/s]


torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])
torch.Size([64, 128]) torch.Size([64, 128]) torch.Size([64, 30])
torch.Size([64, 30]) torch.Size([64, 32]) torch.Size([64, 20])
torch.Size([25, 128]) torch.Size([25, 128]) torch.Size([25, 30])
torch.Size([25, 30]) torch.Size([25, 32]) torch.Size([25, 20])


KeyboardInterrupt: 

In [None]:
# Put the network into evaluation mode before predicting.
net.eval()
# Presumably the number of stochastic forward passes drawn per input in eval
# mode (the plotting cell iterates over y_pred's last dimension) -- confirm
# against WHVIRegression.
net.eval_samples = 500

# Dense grid of 1000 inputs on [-2, 3], shaped (1000, 1) as a column of scalars.
x_test = torch.linspace(-2, 3, 1000).reshape(-1, 1)
# Reference values from the dataset's `f` on the same grid.
y_test = dataset.f(x_test)

# Inference only: disable autograd so no computation graph is recorded for the
# prediction tensor. This saves memory and time; the result is only plotted.
with torch.no_grad():
    y_pred = net(x_test)

In [None]:
# Overlay the predictions, the training points and the reference curve on a
# single set of axes.
fig, ax = plt.subplots()
ax.set_ylim(-1, 2.5)
ax.set_xlim(-2, 3)

# Detach once up front so the loop only slices plain tensors.
pred_samples = y_pred.detach()
# One faint red curve per index along y_pred's third dimension
# (presumably one curve per sampled prediction -- confirm).
for sample_idx in range(pred_samples.shape[2]):
    ax.plot(x_test, pred_samples[..., sample_idx], c='r', alpha=0.05)

ax.scatter(dataset.x, dataset.y, ec='k', label='Noisy training measurements')
ax.plot(x_test, y_test, label='True function')
ax.legend()
plt.show()