In [1]:
from typing import Dict, Iterable, List, Mapping, Sequence
import torch
import torch.nn as nn

class Factor:
    def __init__(self, table: torch.Tensor|list|float, variables: List[str]):
        self.variables = variables
        if not isinstance(table, torch.Tensor):
            table = torch.tensor(table)
        self.table = table
        assert len(table.shape) == len(variables) # the last dimension is the value. The rest are variables

    def __mul__(self, other: 'Factor') -> 'Factor':
        '''
        Multiply two factors.
        '''
        var_union = set(self.variables) | set(other.variables)

        var_name_map = {var: chr(ord('a') + i) for i, var in enumerate(var_union)}

        var_output = self.variables + [var for var in other.variables if var not in self.variables]

        def to_einsum_string(str_list: Iterable[str]) -> str:
            return ''.join(map(lambda x: var_name_map[x], str_list))
        result = torch.einsum(
            to_einsum_string(self.variables)+','+to_einsum_string(other.variables) 
            + '->' 
            + to_einsum_string(var_output), self.table, other.table)
        
        return Factor(result, var_output)
    
    def __truediv__(self, other: 'Factor') -> 'Factor':
        return self * other.inverse()
    
    def __getitem__(self, index: Mapping[str, int|None|slice]) -> torch.Tensor:
        arr_index:Sequence[slice|int] = [slice(None)] * len(self.variables)
        for var, val in index.items():
            if val is None:
                continue
            if isinstance(val, slice):
                arr_index[self.variables.index(var)] = val
            else:
                arr_index[self.variables.index(var)] = slice(val, val+1)
        return self.table[tuple(arr_index)]
    
    def marginalize(self, var_name: str| list[str]) -> 'Factor':
        '''
        Marginalize the factor with respect to the variable.
        '''
        if isinstance(var_name, str):
            var_name = [var_name]
        result = torch.sum(self.table, dim= tuple([self.variables.index(var) for var in var_name]))
        var_output = [var for var in self.variables if var not in var_name]
        return Factor(result, var_output)
    
    def inverse(self) -> 'Factor':
        '''
        Inverse the factor.
        '''
        return Factor(1/self.table, self.variables)

class Variable(nn.Module):
    '''
    A discrete random-variable node together with its conditional probability table.
    '''
    def __init__(self, name:str, cpt: torch.Tensor|List, parents: List['Variable']|None = None):
        '''
        Args:
            name: variable name; also labels this variable's axis in factors.
            cpt: table with one axis per parent (in `parents` order) followed by
                one axis for this variable's own states.
            parents: parent nodes; None (the default) means a root variable.
        '''
        super().__init__()
        self.name = name
        # Copy so instances never share a default list or alias the caller's list.
        self.parents = list(parents) if parents is not None else []
        if isinstance(cpt, list):
            cpt = torch.tensor(cpt)
        self.cpt = Factor(cpt, [parent.name for parent in self.parents] + [name])

class BayesianNetwork(nn.Module):
    '''
    Discrete Bayesian network: the joint distribution is the product of the
    CPT factors of every registered Variable.
    '''
    def __init__(self):
        super().__init__()
        # name -> Variable, in insertion order; this order determines the axis
        # order of the joint tables built in get_joint_distribution.
        self.variables: nn.ModuleDict = nn.ModuleDict()

    def copy(self) -> 'BayesianNetwork':
        '''
        Return a new network holding copies of every variable with cloned CPT tables.

        NOTE: each copy's `parents` still references this network's Variable
        objects; only their names are used to label CPT axes.
        '''
        new_instance = BayesianNetwork()
        # Iterating a ModuleDict yields its keys, so walk .values() to get the Variables.
        for var in self.variables.values():
            assert isinstance(var, Variable)
            new_instance.add_variable(Variable(var.name, var.cpt.table.clone(), var.parents))
        return new_instance

    def add_variable(self, variable: Variable):
        '''Register `variable` under its name (replacing any variable with that name).'''
        self.variables[variable.name] = variable

    def get_variable(self, name: str) -> Variable:
        '''Look up a registered variable by name.'''
        return self.variables[name] # type: ignore

    def infer(self, target: Dict[str,int|None|slice], observation: Dict[str, torch.Tensor]|None = None) -> Factor:
        '''
        p(target | observation).

        Args:
            target: variable name -> index spec. None keeps the whole axis, an int
                `v` keeps only state v (axis retained with size 1), a slice is
                applied as-is.
            observation: variable name -> evidence vector over that variable's
                states (e.g. one-hot for a hard observation).

        Returns:
            Factor over the target variables (in joint order). The table is
            normalized over all target states before `target` indexing is applied.

        Raises:
            ValueError: if the observation has zero probability under the model.
        '''
        joint = self.get_joint_distribution(observation)

        # Sum out every variable that is not a target ("don't care"); skip the
        # call when there are none, since summing over no axes is not a no-op in torch.
        dont_cares = [var for var in joint.variables if var not in target]
        if dont_cares:
            joint = joint.marginalize(dont_cares)

        # get_joint_distribution returns p(variables, evidence); dividing by its
        # total (p(evidence), or 1 without evidence for normalized CPTs) conditions it.
        evidence_prob = joint.table.sum()
        if evidence_prob == 0:
            raise ValueError('observation has zero probability under the model')
        joint = Factor(joint.table / evidence_prob, joint.variables)

        target_var = [var for var in joint.variables if var in target]
        return Factor(joint[target], target_var)

    def get_joint_distribution(self, observation: Dict[str, torch.Tensor]|None = None) -> Factor:
        '''
        Product of all CPTs, times the evidence vector of each observed variable.

        Observed variables keep their own CPT, so their dependence on their
        parents is preserved (conditioning rather than intervention). The result
        is therefore unnormalized when evidence is given.
        '''
        if observation is None:
            observation = {}
        joint = Factor(1, [])
        for var in self.variables.values():
            joint = joint * var.cpt
            if var.name in observation:
                joint = joint * Factor(observation[var.name], [var.name])
        return joint

    def get_entropy(self, target: str|list[str], observation: Dict[str, torch.Tensor]|None = None, condition: str|List[str]|None = None) -> float:
        '''
        Conditional entropy H(target | condition) in bits, under the
        distribution conditioned on `observation`.
        '''
        if isinstance(target, str):
            target = [target]

        if condition is None:
            condition = []
        elif isinstance(condition, str):
            condition = [condition]

        dist = self.infer({var: None for var in target + condition}, observation)
        cond_dist = dist.marginalize(target)
        log_term = (dist / cond_dist).table.log2()
        # For non-negative tables, nan (0 * inf) and -inf (log2 of 0) only occur
        # where dist is 0, so those terms contribute nothing to the sum.
        log_term[log_term.isinf() | log_term.isnan()] = 0
        return -(dist.table * log_term).sum().item()

    def sample(self, n: int, target: List[str]|None = None, observation: Dict[str, torch.Tensor]|None = None) -> Dict[str, torch.Tensor]:
        '''
        Draw n joint samples (from p(. | observation) when evidence is given).

        Returns:
            variable name -> int64 tensor of shape [n] holding sampled state
            indices, restricted to `target` when it is given.
        '''
        joint = self.get_joint_distribution(observation)
        # reshape rather than view: the einsum result may not be contiguous.
        # multinomial only needs non-negative weights, so no normalization is needed.
        flat = torch.multinomial(joint.table.reshape(-1), n, replacement=True)
        result = {}
        # Unravel row-major flat indices: the last axis varies fastest, so peel
        # axes off from the end, each with its own size.
        for var, size in zip(reversed(joint.variables), reversed(joint.table.shape)):
            if target is None or var in target:
                result[var] = flat % size
            flat = flat // size
        return {var: result[var].view(-1) for var in result}

In [2]:
# bn = BayesianNetwork()

# # Define the variables
# a = Variable('a', [.96,.04])
# b = Variable('b', [[.5,.5,0],[0,0,1]], [a])
# c = Variable('c', [[.5,.5,0],[0,0,1]], [a])
# x0 = Variable('x0', [[1,0],[0,1],[0,1]], [b])
# x1 = Variable('x1', [[1,0],[0,1],[1,0]], [b])
# x2 = Variable('x2', [[1,0],[0,1],[1,0]], [c])
# x3 = Variable('x3', [[1,0],[0,1],[0,1]], [c])

# # Add the variables to the network
# for var in [a,b,c,x0,x1,x2,x3]:
#     bn.add_variable(var)

# # Infer the joint distribution of x0 and x1 (no observations)
# bn.infer({'x0':None,'x1':None}).table

In [3]:
# Fix the RNG so the sampled dataset (and every score derived from it) is
# reproducible under Restart & Run All.
torch.manual_seed(0)

N_STATES = 4      # states per variable
N_SAMPLES = 100   # number of samples drawn for the training dataset
OBSERVED_VARS = ['x0', 'x1', 'x2', 'x3']

def diagonal_cpt(on_diag: float, off_diag: float, n_states: int = N_STATES) -> List[List[float]]:
    '''Square CPT whose row i puts `on_diag` on state i and `off_diag` on every other state.'''
    return [[on_diag if i == j else off_diag for j in range(n_states)] for i in range(n_states)]

# Ground-truth network: root a -> b, c -> observed x0..x3 (a, b, c are
# marginalized out later when scoring models).
real_model = BayesianNetwork()
a = Variable('a', [1 / N_STATES] * N_STATES)
b = Variable('b', diagonal_cpt(.7, .1), [a])
c = Variable('c', diagonal_cpt(.7, .1), [a])
x0 = Variable('x0', diagonal_cpt(.94, .02), [b])
x1 = Variable('x1', diagonal_cpt(.94, .02), [b])
x2 = Variable('x2', diagonal_cpt(.94, .02), [c])
x3 = Variable('x3', diagonal_cpt(.94, .02), [c])

for var in [a, b, c, x0, x1, x2, x3]:
    real_model.add_variable(var)

dataset = real_model.sample(N_SAMPLES)
dataset_tensor = torch.stack([dataset[var] for var in OBSERVED_VARS], dim=1) # [N_SAMPLES, len(OBSERVED_VARS)]

In [4]:

from itertools import product

class Model:
    '''
    Common interface for the density models compared in this notebook.

    Subclasses override fit() to learn from data and get_likelihood() to score it.
    '''

    def __init__(self):
        # The base class holds no state of its own.
        super().__init__()

    def fit(self, dataset: torch.Tensor):
        '''Learn parameters from `dataset`; a no-op in the base class.'''
        return None

    def get_likelihood(self, data) -> torch.Tensor:
        '''
        Return the likelihood of `data`. Subclasses must override this.
        '''
        raise NotImplementedError
    
def kld(p: torch.Tensor, q: torch.Tensor) -> float:
    '''
    KL divergence D(p || q) in nats, summed over every element.

    Entries where p == 0 contribute 0 (the 0 * log 0 convention); entries
    where p > 0 and q == 0 contribute +inf.
    '''
    assert p.shape == q.shape
    log_ratio = torch.log(p / q)
    # Zero out p == 0 terms, which would otherwise be nan (0 * -inf or 0 * nan).
    terms = (p * log_ratio).masked_fill(p == 0, 0)
    return float(terms.sum())

def eval_model(model: Model, real_dist: Factor, latent_vars: Sequence[str] = ('a', 'b', 'c')) -> float:
    '''
    Score a model by KL(real || model) over the observed variables, in nats.

    Args:
        model: fitted model; get_likelihood is queried once per joint outcome.
        real_dist: true joint distribution (e.g. from get_joint_distribution()).
        latent_vars: variables of real_dist to sum out before comparing.

    Returns:
        KL divergence between the true marginal over the remaining variables
        and the model's likelihood table. NOTE(review): assumes the remaining
        axes line up with the column order model.get_likelihood expects
        (x0..x3 in this notebook) — confirm for other networks.
    '''
    real = real_dist.marginalize(list(latent_vars)).table

    # Evaluate the model on every joint outcome of the remaining variables.
    likelihood = torch.zeros(real.shape)
    for x in product(*(range(size) for size in real.shape)):
        likelihood[x] = model.get_likelihood(torch.tensor(x)).item()

    return kld(real, likelihood)


## Model 1: empirical frequency of each possible joint value of (x0, x1, x2, x3)

In [5]:
class FreqModel(Model):
    '''
    Empirical frequency table over every joint outcome of the observed
    variables, with an optional uniform pseudo-count.
    '''
    def __init__(self):
        super().__init__()
        # outcome tuple (one state index per dataset column) -> probability
        self.freq = {}

    def fit(self, dataset: torch.Tensor, smoothing_prior: float = 0, n_states: int = 4):
        '''
        Estimate outcome probabilities from counts.

        Args:
            dataset: 2-D tensor, one row per sample and one column per variable;
                every entry must lie in range(n_states).
            smoothing_prior: scale of the uniform pseudo-count added to every
                outcome (0 disables smoothing).
            n_states: number of states of each variable.

        Raises:
            ValueError: if the table ends up with zero total mass (e.g. empty dataset).
            KeyError: if a row contains a state outside range(n_states).
        '''
        data, count = torch.unique(dataset, return_counts=True, dim=0)
        all_possibilities = list(product(range(n_states), repeat=dataset.shape[1]))

        # NOTE(review): the pseudo-count scales with the number of *distinct*
        # rows (len(data)), not the number of samples; confirm that is intended.
        pseudo_count = smoothing_prior * len(data) / len(all_possibilities)
        # Rebuild the table from scratch so nothing carries over from a previous fit.
        self.freq = {outcome: pseudo_count for outcome in all_possibilities}

        for outcome, n_rows in zip(data.tolist(), count.tolist()):
            self.freq[tuple(outcome)] += n_rows

        total = sum(self.freq.values())
        if total == 0:
            raise ValueError('cannot fit FreqModel: the frequency table has zero total mass')
        for key in self.freq:
            self.freq[key] /= total

    def _row_probs(self, data: torch.Tensor) -> torch.Tensor:
        '''Fitted probability of each row of `data` (a single 1-D row is accepted).'''
        if data.dim() == 1:
            data = data.unsqueeze(0)
        return torch.tensor([self.freq[tuple(row)] for row in data.tolist()])

    def get_likelihood(self, data) -> torch.Tensor:
        '''Likelihood of `data`: the product of per-row probabilities (rows treated as independent).'''
        return self._row_probs(data).prod()

    def get_log_likelihood(self, data) -> torch.Tensor:
        '''Natural-log likelihood of `data`; -inf if any row has probability 0.'''
        return self._row_probs(data).log().sum()


In [1]:
# Baseline: empirical frequency table with smoothing_prior=1, scored by
# KL(true marginal || model) against the generating network.
freq_model = FreqModel()
freq_model.fit(dataset_tensor, smoothing_prior=1)
true_joint = real_model.get_joint_distribution()
eval_model(freq_model, true_joint)

NameError: name 'FreqModel' is not defined

In [91]:
# Sanity check: fit() normalizes the table, so the probabilities must sum to 1
# up to floating-point round-off.
total_prob = sum(freq_model.freq.values())
assert abs(total_prob - 1) < 1e-9, f'frequency table sums to {total_prob}, expected 1'
total_prob

0.9999999999999949

In [79]:
# Scratch arithmetic: product of five hand-entered probabilities
# (the notebook does not say what they represent).
scratch_factors = [0.6, 0.3, 0.02, 0.01, 0.8]
scratch_product = 1.0
for factor in scratch_factors:
    scratch_product *= factor
scratch_product

2.8800000000000002e-05