In [1]:
import math

import bayesfunc as bf
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from data import REGRESSION_CONFIG, RegressionDataset

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
class GI_Dense(nn.Module):
    """Dense layer holding a Gaussian weight prior and pseudo-observation
    parameters defined on a set of inducing inputs.

    The layer stores, per output unit, a diagonal Gaussian prior over a weight
    row of length ``num_input + 1`` and learns ``num_inducing`` pseudo
    observations (means and log-precisions).  ``q_prec_cov_chols`` turns those
    into Cholesky factors of the approximate-posterior precision and
    covariance.

    :param num_input: number of input features (one extra column is added
        internally -- presumably for a bias feature; TODO confirm).
    :param num_output: number of output units.
    :param num_inducing: number of inducing points / pseudo observations.
    :param nonlinearity: ``None`` for identity, a callable, or the name of a
        function in ``torch.nn.functional`` (e.g. ``"relu"``) or of a module
        class in ``torch.nn`` (e.g. ``"ReLU"``).
    :param prior_scale_factor: multiplier applied to the prior std-dev.
    :param dtype: torch dtype used for all tensors created by the layer.
    :param name: human-readable layer name (stored only).
    :param kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self,
                 num_input,
                 num_output,
                 num_inducing,
                 nonlinearity,
                 prior_scale_factor,
                 dtype,
                 name="global_inducing_fc",
                 **kwargs):

        # nn.Module.__init__ does not take Keras-style ``name``/``dtype``
        # keywords, so those are stored as plain attributes instead.
        super().__init__(**kwargs)

        self.name = name
        self.dtype = dtype

        self.num_input = num_input + 1
        self.num_output = num_output
        self.num_inducing = num_inducing
        self.prior_scale_factor = prior_scale_factor

        # Set nonlinearity for the layer
        self.nonlinearity = self._resolve_nonlinearity(nonlinearity)

        # PyTorch has no deferred Keras-style build step: create the prior and
        # the trainable parameters now so that ``parameters()`` is populated
        # before an optimiser is constructed.
        self.build()

    @staticmethod
    def _resolve_nonlinearity(nonlinearity):
        """Return a callable activation for ``nonlinearity``.

        :raises AttributeError: if a string name matches neither a function in
            ``torch.nn.functional`` nor a module class in ``torch.nn``.
        """
        if nonlinearity is None:
            return nn.Identity()

        if callable(nonlinearity):
            return nonlinearity

        # Functional form first: mirrors ``tf.nn.<name>`` style names.
        functional = getattr(nn.functional, nonlinearity, None)
        if callable(functional):
            return functional

        # Fall back to module classes such as "ReLU"; they must be
        # instantiated before they can be applied to a tensor.
        module_cls = getattr(nn, nonlinearity, None)
        if isinstance(module_cls, type) and issubclass(module_cls, nn.Module):
            return module_cls()

        raise AttributeError(
            f"Unknown nonlinearity {nonlinearity!r}: not found in "
            f"torch.nn.functional or torch.nn"
        )

    def build(self, input_shape=None):
        """(Re)create the prior distribution and the pseudo-observation
        parameters.

        Called once from ``__init__``.  Calling it again resets the parameters
        to zeros and replaces the ``nn.Parameter`` objects, so any optimiser
        built earlier would no longer track them.

        :param input_shape: unused; kept for signature compatibility.
        """
        weight_shape = (self.num_output, self.num_input)

        # Set up prior mean, scale and distribution
        self.prior_mean = t.zeros(weight_shape, dtype=self.dtype)

        prior_scale = t.ones(weight_shape, dtype=self.dtype)
        prior_scale = prior_scale / self.num_input**0.5
        self.prior_scale = prior_scale * self.prior_scale_factor

        # ``prior_scale`` holds per-weight standard deviations; embed each row
        # as a diagonal matrix so it is a valid ``scale_tril`` (batch of
        # num_output Gaussians, each over num_input weights).
        self.prior = t.distributions.MultivariateNormal(
            loc=self.prior_mean,
            scale_tril=t.diag_embed(self.prior_scale)
        )

        # Set up pseudo observation means and variances
        self.pseudo_means = t.zeros(
            (self.num_inducing, self.num_output),
            dtype=self.dtype
        )
        self.pseudo_mean = nn.Parameter(self.pseudo_means)

        # Assign the Parameter directly: nn.Module refuses to register a
        # parameter under a name already used by a plain tensor attribute.
        self.pseudo_log_prec = nn.Parameter(
            t.zeros((self.num_inducing,), dtype=self.dtype)
        )

    @property
    def pseudo_precision(self):
        """Pseudo-observation precisions, ``exp(pseudo_log_prec)``."""
        return t.exp(self.pseudo_log_prec)

    def q_prec_cov_chols(self, Uin):
        """Cholesky factors of the approximate-posterior precision and
        covariance.

        The precision is ``diag(prior_scale[0]**-2) + phi(U)^T Lambda phi(U)``
        where ``Lambda`` is the diagonal matrix of pseudo precisions.  Only
        the first row of ``prior_scale`` is used (all rows are identical as
        built by ``build``).

        :param Uin: inducing set U_in; after the nonlinearity it must have
            ``self.num_input`` columns and ``self.num_inducing`` rows
            (assumed from the einsum below -- the caller is presumably
            responsible for any bias column; TODO confirm).
        :return: tuple ``(q_prec_chol, q_cov_chol)`` of lower-triangular
            factors, each of shape ``(num_input, num_input)``.
        """

        phiU = self.nonlinearity(Uin)
        pseudo_prec = self.pseudo_precision

        # Compute precision matrix of multivariate normal
        phiT_lambda_phi = t.einsum("mi, m, mj -> ij", phiU, pseudo_prec, phiU)

        # The prior tensors are plain attributes and do not follow
        # ``module.to(device)``, so align the device explicitly.
        prior_prec_diag = self.prior_scale[0, :].to(phiT_lambda_phi.device)**-2.
        q_prec = t.diag(prior_prec_diag) + phiT_lambda_phi

        # Compute cholesky of approximate posterior precision
        q_prec_chol = t.linalg.cholesky(q_prec)

        identity = t.eye(
            q_prec_chol.shape[0],
            dtype=q_prec_chol.dtype,
            device=q_prec_chol.device
        )

        # Compute cholesky of approximate posterior covariance:
        # L^{-1} from a triangular solve, then (L L^T)^{-1} = L^{-T} L^{-1}.
        iq_prec_chol = t.linalg.solve_triangular(
            q_prec_chol,
            identity,
            upper=False
        )

        q_cov = t.matmul(iq_prec_chol.T, iq_prec_chol)
        # Small jitter keeps the covariance numerically positive definite.
        q_cov = q_cov + 1e-5 * identity
        q_cov_chol = t.linalg.cholesky(q_cov)

        return q_prec_chol, q_cov_chol
    