In [1]:
from pprint import pprint

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
import torchbnn as bnn
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset


class BNN(nn.Module):
    """Fully connected Bayesian regression network built from ``bnn.BayesLinear`` layers.

    The network maps ``in_features`` inputs to a single output through a stack
    of Bayesian linear layers separated by ReLU activations.

    Args:
        in_features: Number of input features.
        hidden_sizes: Widths of the hidden layers, in order.
        prior_mu: Prior mean passed to every ``BayesLinear`` layer.
        prior_sigma: Prior standard deviation passed to every ``BayesLinear`` layer.
    """

    def __init__(self, in_features: int, hidden_sizes=(256, 64, 32), prior_mu: float = 0, prior_sigma: float = 0.1):
        super().__init__()
        widths = [in_features, *hidden_sizes, 1]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(
                bnn.BayesLinear(prior_mu=prior_mu, prior_sigma=prior_sigma, in_features=fan_in, out_features=fan_out)
            )
            modules.append(nn.ReLU())
        # The output layer is linear: drop the activation appended after it.
        modules.pop()
        self.layers = nn.Sequential(*modules)

    def to(self, *args, **kwargs):
        """Move/cast the module; accepts everything ``nn.Module.to`` accepts.

        ``nn.Module.to`` already recurses into ``self.layers``, so no manual
        per-layer loop is needed. Forwarding ``*args``/``**kwargs`` untouched
        also allows keyword-only calls such as ``.to(dtype=torch.float64)``.
        """
        return super().to(*args, **kwargs)

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        return self.layers(x)


class BNNTask(pl.LightningModule):
    """Lightning task that trains a ``BNN`` with an MSE + weighted KL objective.

    Args:
        model: The Bayesian network to optimise.
        learning_rate: Step size handed to the Adam optimiser.
        kl_weight: Multiplier applied to the KL term before it is added to the MSE.
    """

    def __init__(self, model: BNN, learning_rate=3e-4, kl_weight=0.1):
        super().__init__()
        self.model = model
        self.kl_weight = kl_weight
        self.learning_rate = learning_rate
        self.mse_loss = nn.MSELoss()
        self.kl_loss = bnn.BKLLoss(reduction="mean", last_layer_only=False)

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def _loss(self, batch):
        """Return ``mse + kl_weight * kl`` for an ``(x, y)`` batch."""
        inputs, targets = batch
        outputs = self.model(inputs).squeeze()
        data_term = self.mse_loss(outputs, targets.squeeze())
        complexity_term = self.kl_loss(self.model)
        return data_term + self.kl_weight * complexity_term

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        train_loss = self._loss(batch)
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it as ``val_loss``."""
        self.log("val_loss", self._loss(batch))

    def configure_optimizers(self):
        """Optimise every parameter of the task with Adam."""
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


class GaussianModel(nn.Module):
    """Regressor whose output layer carries a learnable Gaussian parameter covariance.

    The prediction is ``output_layer(features(x))``. The predictive variance is
    the quadratic form ``phi^T C phi`` where ``phi`` is the feature vector with
    a trailing 1 (for the bias) and ``C = R^T R + min_var * I`` is built from
    the learnable square root ``R`` (``output_cov_matrix_root``).

    Args:
        in_features: Number of input features.
        device: Device on which the parameters are created (``None`` = default).
        min_var: Value added to the diagonal of ``C``.
    """

    def __init__(self, in_features: int, device=None, min_var: float = 1e-9):
        super().__init__()

        self.in_features = in_features
        self.min_var = min_var
        self.device = device

        # No hidden layers at the moment: an empty ``nn.Sequential`` returns its
        # input unchanged, so the model is purely linear.
        # A deeper variant was sketched here previously:
        #   Linear(in, 256) -> ReLU -> Dropout(0.5) -> Linear(256, 64) -> ReLU
        #   -> Linear(64, 32) -> ReLU -> Linear(32, 1)
        self.hidden_layers = nn.Sequential()

        # Last layer; created on the same device as the covariance root below.
        self.output_layer = nn.Linear(self.in_features, 1, bias=True, device=device)

        cov_matrix_ndim = self.output_layer.in_features + 1  # +1 is for the bias
        self.output_cov_matrix_root = nn.Parameter(torch.randn(cov_matrix_ndim, cov_matrix_ndim, device=device))

    def to(self, *args, **kwargs):
        """Move/cast the module and keep ``self.device`` in sync.

        ``nn.Module.to`` already moves the registered covariance parameter.
        Re-assigning it by hand would replace the ``nn.Parameter`` with a plain
        tensor (a ``TypeError`` on a real device change), so only the cached
        device is refreshed, read back from the parameter itself. This also
        keeps ``self.device`` correct for ``.to(dtype)`` and ``.to(device=...)``.
        """
        super().to(*args, **kwargs)
        self.device = self.output_cov_matrix_root.device
        return self

    def features(self, x):
        """Return the hidden representation fed to the output layer."""
        return self.hidden_layers(x)

    def output_cov_matrix(self):
        """Return ``R^T R + min_var * I`` for the learnable root ``R``."""
        root = self.output_cov_matrix_root
        jitter = self.min_var * torch.eye(root.shape[0], dtype=root.dtype, device=root.device)
        return root.T @ root + jitter

    def variance(self, features):
        """Return the per-sample quadratic form ``phi^T C phi`` with shape ``(N, 1, 1)``."""
        # ones are for the bias
        ones = features.new_ones((features.shape[0], 1))
        features_and_ones = torch.hstack([features, ones])
        k = features_and_ones.shape[1]
        row_vectors = features_and_ones.reshape(-1, 1, k)
        column_vectors = features_and_ones.reshape(-1, k, 1)
        return row_vectors @ self.output_cov_matrix()[None, :, :] @ column_vectors

    def forward(self, x, return_var=True):
        """Return the prediction, plus the variance when ``return_var`` is true.

        Args:
            x: Input batch of shape ``(N, in_features)``.
            return_var: Boolean flag. When true, return ``(output, variance)``;
                otherwise return only ``output``.
        """
        features = self.features(x)
        output = self.output_layer(features)
        if return_var:
            return output, self.variance(features)
        return output


class GaussianTask(pl.LightningModule):
    """Lightning task that trains a ``GaussianModel`` on ``(x, y, y_err)`` batches.

    Args:
        model: Model called as ``model(x, return_var=...)``.
        learning_rate: Step size handed to the Adam optimiser.
    """

    def __init__(self, model, learning_rate=3e-4):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def forward(self, x, y_err=None):
        """Return ``(mean, variance)`` for ``x``.

        The second positional argument of the model is its boolean
        ``return_var`` flag, so ``y_err`` must not be forwarded positionally
        (a multi-element tensor there raises an "ambiguous boolean" error).
        When ``y_err`` is given, its square is added to the model variance.

        Args:
            x: Input batch.
            y_err: Optional known noise amplitude per sample.
        """
        mean, variance = self.model(x, return_var=True)
        if y_err is not None:
            # The model variance has shape (N, 1, 1); align y_err with it.
            variance = variance + torch.square(y_err).reshape(-1, 1, 1)
        return mean, variance

    def _loss(self, batch):
        """Return the mean squared error of the predictions for the batch.

        ``y_err`` is unpacked but currently unused. A Gaussian negative
        log-likelihood, mean((pred - y)^2 / var + log(var)) with
        var = model variance + y_err^2, was tried here and is disabled.
        """
        x, y, y_err = batch
        predictions = self.model(x, return_var=False).squeeze()
        return torch.mean(torch.square(predictions - y.squeeze()))

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it as ``val_loss``."""
        loss = self._loss(batch)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        """Optimise every parameter of the task with Adam."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)


def train_nn(
    X,
    y,
    y_err,
    *,
    batch_size: int = 1 << 10,
    n_epoch: int = 10_000,
    learning_rate: float = 3e-4,
    val_size: float = 0.4,
    val_batch_size: int = 1 << 14,
    patience: int = 10,
    accelerator: str = 'cpu',
    random_state: int = 42,
):
    """Train a ``GaussianModel`` with early stopping on a held-out validation split.

    Args:
        X: Feature array, one row per sample.
        y: Target values.
        y_err: Known noise amplitude of each target.
        batch_size: Training mini-batch size.
        n_epoch: Maximum number of epochs.
        learning_rate: Adam step size.
        val_size: Fraction of the data held out for validation.
        val_batch_size: Batch size of the validation loader.
        patience: Epochs without ``val_loss`` improvement before stopping.
        accelerator: Value passed to ``pl.Trainer(accelerator=...)``.
        random_state: Seed of the train/validation split.

    Returns:
        ``(model, task, trainer)`` after fitting.
    """
    # train_test_split returns train/val pairs for X, y, y_err in that order.
    x_train, x_val, y_train, y_val, y_err_train, y_err_val = (
        torch.tensor(a, dtype=torch.float32)
        for a in train_test_split(X, y, y_err, test_size=val_size, random_state=random_state)
    )

    # A BNN / BNNTask pair (trained on (x, y) datasets without y_err) can be
    # swapped in here instead of the Gaussian model.
    model = GaussianModel(x_train.shape[1])
    task = GaussianTask(model, learning_rate=learning_rate)

    train_dataset = TensorDataset(x_train, y_train, y_err_train)
    val_dataset = TensorDataset(x_val, y_val, y_err_val)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=val_batch_size)

    trainer = pl.Trainer(
        max_epochs=n_epoch,
        accelerator=accelerator,
        enable_progress_bar=False,
        logger=True,
        callbacks=[
            EarlyStopping('val_loss', patience=patience, mode='min'),
        ],
    )
    trainer.fit(task, train_dataloader, val_dataloader)

    return model, task, trainer

In [9]:
def linear_fit(X, y, y_err, model, jac):
    """Weighted linear least-squares fit with parameter covariance.

    Args:
        X: Input array handed to ``jac`` (one row per observation).
        y: Observed targets, one per row of ``X``.
        y_err: Noise amplitude of each target; weights are ``1 / y_err``.
        model: Callable ``model(x, params)`` returning predictions.
        jac: Callable ``jac(x)`` returning the design matrix (one row per
            observation, one column per parameter).

    Returns:
        ``(f, params, covariance)`` where ``f(x, cov=True)`` returns the
        prediction and, when ``cov`` is true, also ``J @ covariance @ J.T``
        for the rows of ``x``. ``covariance`` is ``inv(A.T @ A)`` scaled by
        the reduced chi-square of the fit.

    Raises:
        ValueError: If the weighted design matrix is rank-deficient or there
            are not more observations than parameters, so the reduced
            chi-square is undefined.
    """
    sqrt_w = 1 / y_err
    A = sqrt_w[:, None] * jac(X)
    n_obs, n_params = A.shape
    dof = n_obs - n_params

    params, residuals, rank, _ = np.linalg.lstsq(A, sqrt_w * y, rcond=None)
    # lstsq returns an empty residual array when rank < n_params or
    # n_obs <= n_params; report that explicitly instead of failing to unpack.
    if dof <= 0 or rank < n_params:
        raise ValueError(
            "linear_fit needs a full-rank design matrix and positive degrees of freedom "
            f"(got {n_obs} observations, {n_params} parameters, rank {rank})"
        )
    chi2 = residuals[0]
    reduced_chi2 = chi2 / dof
    covariance = np.linalg.inv(A.T @ A) * reduced_chi2

    def f(x, cov=True):
        x = np.atleast_2d(x)
        y_pred = model(x, params)

        if not cov:
            return y_pred

        j = jac(x)
        # Full covariance between the predictions at the rows of x.
        var = j @ covariance @ j.T
        return y_pred, var

    return f, params, covariance


def non_linear_fit(X, y, y_err, model):
    """Fit ``model`` with ``scipy.optimize.curve_fit`` and print the result.

    The starting point is the zero vector with ``X.shape[1] + 1`` entries and
    ``y_err`` is passed as ``sigma``. Nothing is returned.
    """
    from scipy.optimize import curve_fit

    def objective(x, *params):
        # curve_fit passes the parameters as separate scalars; pack them back.
        return model(x, np.array(list(params)))

    initial_guess = np.zeros(X.shape[1] + 1)
    fit_result = curve_fit(
        f=objective,
        xdata=X,
        ydata=y,
        sigma=y_err,
        p0=initial_guess,
    )
    print(fit_result)

In [13]:
# Size of the synthetic problem: input dimensionality and number of samples.
n_features = 3
n_samples = 2 ** 20

def model(x, params):
    """Affine model: ``x @ params[:-1] + params[-1]`` with singleton axes squeezed out."""
    weights = params[:-1, None]
    intercept = params[-1]
    return (x @ weights + intercept).squeeze()


def jac(x):
    """Design matrix of the affine model: ``x`` with a trailing column of ones."""
    bias_column = np.ones((x.shape[0], 1))
    return np.concatenate([x, bias_column], axis=1)


def true_noise_amplitude(x):
    """Noise amplitude per row of ``x``: ``0.05 + 0.5 * ||x||`` (Euclidean norm per row)."""
    # Constant-noise alternative: np.full(x.shape[0], 0.5)
    row_norms = np.linalg.norm(x, axis=1)
    return 0.05 + 0.5 * row_norms


def generate_data(params, n_samples: int = 1 << 20, rng=0, unknown_noise_scale: float = 0.0):
    """Draw a synthetic regression data set from the affine ``model``.

    The number of input features is taken from ``params`` (all entries but the
    last are weights, the last is the intercept) rather than from a notebook
    global, so the function always agrees with the parameters it is given.

    Args:
        params: Array of model parameters, weights followed by the intercept.
        n_samples: Number of samples to draw.
        rng: Seed or generator passed to ``np.random.default_rng``.
        unknown_noise_scale: Multiplier of the extra noise whose amplitude is
            driven by a hidden latent variable. The default ``0.0`` disables it
            while still consuming the same random draws.

    Returns:
        ``(X, y, y_err)``: inputs, noisy targets and the known noise amplitude
        of each target.
    """
    params = np.asarray(params)
    n_inputs = params.size - 1
    rng = np.random.default_rng(rng)

    X = rng.uniform(-1, 1, size=(n_samples, n_inputs))
    latent = rng.uniform(0, 2, size=(n_samples, 1))
    y_err = true_noise_amplitude(X)
    y_known_noise = rng.normal(loc=0, scale=y_err)
    y_unknown_noise = unknown_noise_scale * rng.normal(loc=0, scale=true_noise_amplitude(latent))
    y = model(X, params) + y_known_noise + y_unknown_noise

    return X, y, y_err


# Ground-truth parameters: one weight per feature followed by the intercept.
true_params = np.array([+2.0, -3.0, +1.0, +0.5])
assert true_params.size == n_features + 1

# Use the n_samples constant defined at the top of this cell instead of
# repeating the literal, so the two cannot drift apart.
X, y, y_err = generate_data(true_params, n_samples=n_samples)

X_train, X_test, y_train, y_test, y_err_train, y_err_test = train_test_split(
    X, y, y_err, test_size=0.2, random_state=42, shuffle=True
)

In [14]:
# Weighted least-squares baseline on the training split.
linear_f, params, cov = linear_fit(X_train, y_train, y_err_train, model, jac)
# non_linear_fit(X_train, y_train, y_err_train, model)

# Test-set mean squared error, then prediction + covariance for three test rows.
test_residuals = linear_f(X_test, cov=False) - y_test
print(np.mean(np.square(test_residuals)))
pprint(linear_f(X_test[:3], cov=True))

0.30094158973569707
(array([ 0.78355059, -1.1703901 ,  4.44560298]),
 array([[ 1.66216401e-06,  8.64392630e-08, -5.97198433e-07],
       [ 8.64392630e-08,  1.69493379e-06,  2.17300407e-07],
       [-5.97198433e-07,  2.17300407e-07,  2.21819043e-06]]))


In [17]:
from tqdm import tqdm

# Bootstrap the linear fit: each replicate refits on a resample (with
# replacement) that is 1/100 of the training-set size.
rng = np.random.default_rng(0)
n_train = X_train.shape[0]
linear_f_bootstrap, params_bootstrap = [], []
for _i_bootstrap in tqdm(range(10000)):
    idx = rng.choice(n_train, size=n_train // 100, replace=True)
    replicate_f, replicate_params, _cov = linear_fit(X_train[idx], y_train[idx], y_err_train[idx], model, jac)
    linear_f_bootstrap.append(replicate_f)
    params_bootstrap.append(replicate_params)

# Mean and covariance of the fitted parameters across replicates.
params_stack = np.array(params_bootstrap)
params_bootstrap_mean = params_stack.mean(axis=0)
params_bootstrap_cov = np.cov(params_stack, rowvar=False)
print(params_bootstrap_mean)
print(params_bootstrap_cov)

# Covariance of the replicate predictions at the first three test rows.
head_predictions = [replicate(X_test[:3], cov=False) for replicate in linear_f_bootstrap]
print(np.cov(head_predictions, rowvar=False))

100%|██████████| 10000/10000 [00:08<00:00, 1161.79it/s]

[ 2.00221801 -2.99992089  1.00080127  0.50053013]
[[ 1.12496224e-04 -1.73328443e-06 -2.02285541e-06  3.79410850e-07]
 [-1.73328443e-06  1.10778150e-04  2.30503376e-06  2.69186347e-07]
 [-2.02285541e-06  2.30503376e-06  1.08547284e-04  6.26235834e-07]
 [ 3.79410850e-07  2.69186347e-07  6.26235834e-07  2.36066733e-05]]
[[ 1.63068611e-04  8.37864034e-06 -5.79936991e-05]
 [ 8.37864034e-06  1.73394440e-04  1.85234150e-05]
 [-5.79936991e-05  1.85234150e-05  2.18637810e-04]]





In [107]:
# Train the network on the training split (at most 50 epochs).
nn_model, nn_task, nn_trainer = train_nn(
    X_train,
    y_train,
    y_err_train,
    n_epoch=50,
    learning_rate=1e-3,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name  | Type          | Params
----------------------------------------
0 | model | GaussianModel | 20    
----------------------------------------
20        Trainable params
0         Non-trainable params
20        Total params
0.000     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


In [124]:
# Network predictions for the first three test rows, as a flat numpy array.
X_head = torch.tensor(X_test[:3], dtype=torch.float32)
nn_model(X_head, return_var=False).detach().numpy().squeeze()

array([ 0.7715461, -1.1633983,  4.4548364], dtype=float32)

In [None]:
# Task loss evaluated on the whole test split as a single (x, y, y_err) batch.
test_batch = tuple(torch.tensor(a, dtype=torch.float32) for a in (X_test, y_test, y_err_test))
nn_task._loss(test_batch)

In [None]:
# The model's second argument is the boolean ``return_var`` flag, not y_err,
# and it returns a (mean, variance) tuple rather than a stacked tensor.
X_test_t = torch.tensor(X_test, dtype=torch.float32)
y_test_t = torch.tensor(y_test, dtype=torch.float32)
with torch.no_grad():
    mu, var = nn_model(X_test_t, return_var=True)
# Flatten (N, 1) / (N, 1, 1) to (N,) so the subtraction below is element-wise
# instead of broadcasting to an N x N matrix.
mu, var = mu.reshape(-1), var.reshape(-1)
# Variance-normalised alternative:
# torch.mean(torch.square(y_test_t - mu) / (var + torch.tensor(y_err_test, dtype=torch.float32)**2))
torch.mean(torch.square(y_test_t - mu))

In [None]:
# Prediction and model variance at the origin. The model's second argument is
# the boolean ``return_var`` flag; passing a noise tensor there only worked
# because a one-element tensor happens to be truthy.
nn_model(torch.zeros((1, n_features)), return_var=True)