In [1]:
from abc import ABC, abstractmethod

import torch

import pandas as pd
import numpy as np

import bambi as bmb

  import pandas.util.testing as tm


In [20]:
class Family(ABC):
    """Registry and factory for Kullback-Leibler divergence families.

    Every subclass that exposes a ``_FAMILY_NAME`` class attribute is recorded
    in ``subclasses`` when the class is defined, and ``create`` instantiates a
    family from that name. Concrete families must implement ``kl_div``.
    """

    # Maps family name -> family class. Filled in by ``__init_subclass__`` and
    # shared by the whole class hierarchy.
    subclasses = {}

    def __init__(self):
        super().__init__()

    def __init_subclass__(cls, **kwargs):
        """Register ``cls`` under its ``_FAMILY_NAME`` when it has one.

        A subclass without a ``_FAMILY_NAME`` (for example an intermediate
        abstract base) is simply left out of the registry instead of failing
        with ``AttributeError`` while the class is being defined.
        """
        super().__init_subclass__(**kwargs)
        family_name = getattr(cls, "_FAMILY_NAME", None)
        if family_name is not None:
            cls.subclasses[family_name] = cls

    @abstractmethod
    def kl_div(self, y_ast, y_perp):
        """Return the Kullback-Leibler divergence for this family.

        Args:
            y_ast: Values associated with the reference model.
            y_perp: Values associated with the submodel.
        """
        pass

    @classmethod
    def create(cls, family_name):
        """Instantiate the family registered under ``family_name``.

        Args:
            family_name: Key under which the family class was registered.

        Returns:
            A new instance of the registered class, built with no arguments.

        Raises:
            NotImplementedError: If no class is registered under
                ``family_name``.
            TypeError: If the registered class still has abstract methods and
                therefore cannot be instantiated.
        """
        if family_name not in cls.subclasses:
            raise NotImplementedError("Unsupported family.")

        return cls.subclasses[family_name]()
    
    
class Gaussian(Family):
    """Gaussian family, registered under the name ``"gaussian"``."""

    _FAMILY_NAME = "gaussian"

    def __init__(self):
        super().__init__()
        # Flag marking this family as having dispersion parameters.
        self.has_disp_params = True

    def kl_div(self, y_ast, y_perp):
        """Surrogate Kullback-Leibler divergence between two Gaussians.

        Each input is summarised by its mean and standard deviation over all
        of its elements, and the closed-form divergence between the two
        resulting Gaussians is returned.

        Args:
            y_ast (torch.tensor): Tensor summarised as the reference Gaussian
            y_perp (torch.tensor): Tensor summarised as the submodel Gaussian

        Returns:
            torch.tensor: Tensor of shape () containing sample KL divergence
        """

        # moments of the reference input and of the submodel input
        ref_mean = torch.mean(y_ast)
        sub_mean = torch.mean(y_perp)
        ref_std = torch.std(y_ast)
        sub_std = torch.std(y_perp)

        # closed-form Gaussian KL divergence, assembled term by term
        log_ratio = torch.log(sub_std / ref_std)
        quadratic = (ref_std**2 + (ref_mean - sub_mean) ** 2) / (2 * sub_std**2)
        div = log_ratio + quadratic - 1 / 2

        # the result must be a scalar tensor
        assert div.shape == (), f"Expected data dimensions {()}, received {div.shape}."
        return div

    def project_disp_params(self):
        """Project the dispersion parameters of the distribution analytically.

        Placeholder with no implementation yet; returns None.
        """
        return None

In [25]:
class NewFamily(Family):
    """Deliberately incomplete family: registered by name, but without ``kl_div``."""

    _FAMILY_NAME = "my_new_family"

    def __init__(self):
        super().__init__()


draws = torch.from_numpy(np.random.normal(0, 1, 100)).float()
# NewFamily does not override the abstract ``kl_div``, so instantiating it is
# expected to fail. The error is caught and displayed so that it appears as the
# cell's result and the rest of the notebook still runs.
try:
    family = Family.create("my_new_family")
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

TypeError: Can't instantiate abstract class NewFamily with abstract methods kl_div

In [23]:
# Sanity check: the Gaussian divergence of a sample against itself is zero.
draws = torch.from_numpy(np.random.normal(loc=0, scale=1, size=100)).float()
family = Family.create("gaussian")
family.kl_div(y_ast=draws, y_perp=draws) == 0.0

tensor(True)

In [24]:
# Asking for a name that is not in the registry must raise NotImplementedError.
# "my_new_family" cannot be used for this demonstration: once the NewFamily
# class has been defined that name is registered, and `create` fails with a
# TypeError (abstract class) instead. The expected error is caught and
# displayed so the notebook keeps running.
try:
    family = Family.create("unregistered_family")
except NotImplementedError as err:
    print(f"{type(err).__name__}: {err}")
else:
    family.kl_div(1, 2)

NotImplementedError: Unsupported family.

In [5]:
# Simulate 50 rows: a Gaussian response `y`, a two-level label `g`, and two
# Gaussian predictors `x1` and `x2`.
data = pd.DataFrame(
    data={
        "y": np.random.normal(size=50),
        "g": np.random.choice(["Yes", "No"], size=50),
        "x1": np.random.normal(size=50),
        "x2": np.random.normal(size=50),
    }
)
# Build the model only (nothing is fitted in this cell). The formula uses
# `-1` and does not include `g`.
model = bmb.Model("y ~ -1 + x1 + x2", data, family="gaussian")

In [8]:
# Does the model carry an intercept term? Evaluates to True when
# `model.intercept_term` is set and to False when it is None.
model.intercept_term is not None

False