In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from activations import Cosine
from networks import WHVIRegression
from layers import WHVILinear
from torch_datasets import ToyDataset

In [2]:
# Pin PyTorch's global RNG so repeated runs start from the same state.
torch.manual_seed(0)

# Toy dataset built with n=128, iterated in mini-batches of 64
# (all other DataLoader options are left at their defaults).
dataset = ToyDataset(n=128)
data_loader = DataLoader(dataset=dataset, batch_size=64)

In [3]:
# Model: linear 1 -> 128 input layer, two WHVILinear(128, ...) layers, and a
# linear 128 -> 1 read-out, with a Cosine activation after each of the first
# three layers.
# NOTE(review): the meaning of lambda_ is defined in layers.WHVILinear - confirm there.
net = WHVIRegression(
    [
        nn.Linear(in_features=1, out_features=128),
        Cosine(),
        WHVILinear(128, lambda_=0.01),
        Cosine(),
        WHVILinear(128, lambda_=0.01),
        Cosine(),
        nn.Linear(in_features=128, out_features=1),
    ]
)

# Adam at a base learning rate of 1e-3. The scheduler multiplies that base
# rate by (1 + gamma * t) ** (-p), where t is the scheduler's step counter.
# NOTE(review): lr_scheduler is not referenced by any other cell visible in
# this notebook excerpt - confirm it is actually stepped during training.
gamma = 5e-4
p = 0.3
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)
lr_scheduler = optim.lr_scheduler.LambdaLR(
    optimizer=optimizer,
    lr_lambda=lambda t: (1 + gamma * t) ** (-p),
)

In [4]:
# Fit the network on the toy data; only the epochs2 argument is overridden.
# NOTE(review): the captured log below restarts its epoch counter once, so
# train_model presumably runs two phases with epochs2 sizing the second -
# confirm against WHVIRegression.train_model.
net.train_model(data_loader, optimizer, epochs2=5000)

[Epoch 0] Loss = 496.006, sigma = 1.000
[Epoch 100] Loss = 445.901, sigma = 1.000
[Epoch 200] Loss = 396.784, sigma = 1.000
[Epoch 300] Loss = 348.763, sigma = 1.000
[Epoch 400] Loss = 302.299, sigma = 1.000
[Epoch 0] Loss = 257.999, sigma = 0.999
[Epoch 100] Loss = 202.414, sigma = 0.797
[Epoch 200] Loss = 143.937, sigma = 0.569
[Epoch 300] Loss = 72.888, sigma = 0.296
[Epoch 400] Loss = 3.224, sigma = 0.112
[Epoch 500] Loss = -29.751, sigma = 0.098
[Epoch 600] Loss = -53.672, sigma = 0.077
[Epoch 700] Loss = -57.197, sigma = 0.081
[Epoch 800] Loss = -67.332, sigma = 0.084
[Epoch 900] Loss = -34.639, sigma = 0.075
[Epoch 1000] Loss = -67.592, sigma = 0.077
[Epoch 1100] Loss = -53.333, sigma = 0.073
[Epoch 1200] Loss = -83.559, sigma = 0.074
[Epoch 1300] Loss = -86.521, sigma = 0.069
[Epoch 1400] Loss = -74.947, sigma = 0.064
[Epoch 1500] Loss = -86.216, sigma = 0.064
[Epoch 1600] Loss = -91.580, sigma = 0.060
[Epoch 1700] Loss = -79.013, sigma = 0.058
[Epoch 1800] Loss = -97.094, sigm

In [None]:
# Put the network in evaluation mode and set its eval_samples attribute.
# NOTE(review): eval_samples is consumed inside WHVIRegression - presumably the
# number of stochastic forward passes per input; confirm there.
net.eval()
net.eval_samples = 500

# Evaluation grid: 1000 evenly spaced points on [-2, 3], reshaped to a column.
x_test = torch.linspace(-2, 3, 1000).reshape(-1, 1)
y_test = dataset.f(x_test)

# Inference only: run the forward pass without autograd so no computation
# graph is recorded for the predictions. The result is a plain tensor that
# needs no gradient; calling .detach() on it later is still valid.
with torch.no_grad():
    y_pred = net(x_test)

In [None]:
# Overlay every predictive curve as a faint red line, then add the training
# points and the reference function on top.
plt.figure()
plt.xlim(-2, 3)
plt.ylim(-1, 2.5)

# Convert once up front; the per-curve loop then only slices NumPy arrays.
x_grid = x_test.numpy().ravel()
pred_samples = y_pred.detach().numpy()

# One curve per index along axis 2 of y_pred.
for sample_idx in range(pred_samples.shape[2]):
    plt.plot(x_grid, pred_samples[..., sample_idx].ravel(), c='r', alpha=0.01)

plt.scatter(dataset.x, dataset.y, ec='k', label='Noisy training measurements')
plt.plot(x_test, y_test, label='True function')
plt.legend()
plt.show()