In [None]:
from IPython.display import Image, display

In [None]:
from abc import ABC, abstractmethod
from collections import namedtuple
from itertools import product
import numpy as np

In [None]:
from pgmpy.models import DiscreteBayesianNetwork
from pgmpy.factors.discrete import TabularCPD
from pgmpy.inference import VariableElimination
import pyagrum as gum
from pyagrum.lib import image as gumimage

In [None]:
# Backend-independent records describing a discrete Bayesian network.

# A node: its name and the list of its state labels.
Variable = namedtuple("Variable", ("name", "states"))

# A directed edge between two Variable records, pointing from `tail` to `head`.
Arc = namedtuple("Arc", ("tail", "head"))

# A conditional probability table:
#   variable - the Variable the table belongs to
#   parents  - list of parent Variables, or None for a root node
#   table    - nested list of probabilities, one row per state of `variable`
CPD = namedtuple("CPD", ("variable", "parents", "table"))


class ModelABC(ABC):
    """Common interface for a Bayesian network wrapped around a backend library.

    The constructor only stores the network description and leaves
    ``self.model`` as ``None``; subclasses create the backend object and
    implement the backend-specific reporting and drawing methods.
    """

    def __init__(
        self,
        *,
        name: str,
        variables: tuple[Variable, ...],
        arcs: tuple[Arc, ...],
        cdps: tuple[CPD, ...],
    ):
        """Store the network description.

        Args:
            name: Name of the model.
            variables: All variables of the network.
            arcs: Directed edges, each pointing from ``tail`` to ``head``.
            cdps: CPD descriptions. The keyword is spelled ``cdps``; the
                value is stored on the instance as ``self.cpds``.
        """
        self.name = name
        self.model = None  # backend network object, set by subclasses
        self.variables = variables
        self.arcs = arcs
        self.cpds = cdps

    @abstractmethod
    def get_cpd(self, variable: str):
        """Print the CPD of the variable called ``variable``."""
        raise NotImplementedError

    @abstractmethod
    def draw_graph(self):
        """Render the network structure as an image."""
        raise NotImplementedError

    def print_variables(self):
        """Print the stored ``Variable`` records."""
        print("Variables:")
        print(self.variables)
        print("\n")

    @abstractmethod
    def print_nodes(self):
        """Print the node names known to the backend model."""
        raise NotImplementedError

    def print_arcs(self):
        """Print every arc as a ``(tail name, head name)`` pair."""
        print("Arcs:")
        print([(tail.name, head.name) for (tail, head) in self.arcs])
        print("\n")

    @abstractmethod
    def print_potentials(self):
        """Print every CPD held by the backend model."""
        raise NotImplementedError


class PGM(ModelABC):
    """Bayesian network backed by pgmpy's ``DiscreteBayesianNetwork``."""

    def __init__(
        self,
        *,
        name: str,
        variables: tuple[Variable, ...],
        arcs: tuple[Arc, ...],
        cdps: tuple[CPD, ...],
    ):
        """Store the description, then build the pgmpy network from it.

        Args:
            name: Model name; also the stem of the PNG written by ``draw_graph``.
            variables: All variables of the network.
            arcs: Directed edges; nodes are created from the arc endpoints.
            cdps: CPD descriptions, stored as ``self.cpds``.
        """
        super().__init__(name=name, variables=variables, arcs=arcs, cdps=cdps)
        self.model = DiscreteBayesianNetwork()
        self._add_arcs()
        self._add_cpds()

    def _add_arcs(self):
        """Add every arc to the pgmpy network as a ``tail -> head`` edge."""
        for arc in self.arcs:
            self.model.add_edge(arc.tail.name, arc.head.name)

    @staticmethod
    def _add_cpd(cpd: CPD):
        """Convert one ``CPD`` record into a ``TabularCPD``.

        Returns:
            The ``TabularCPD``, or ``None`` when ``cpd`` is empty/``None``.
        """
        if not cpd:
            return None
        description = {
            "variable": cpd.variable.name,
            "variable_card": len(cpd.variable.states),
            "values": cpd.table,
        }
        # Root nodes have no parents, so the evidence arguments are omitted.
        if cpd.parents:
            description["evidence"] = [parent.name for parent in cpd.parents]
            description["evidence_card"] = tuple(len(parent.states) for parent in cpd.parents)
        return TabularCPD(**description)

    def _add_cpds(self):
        """Attach all CPDs to the network, skipping empty entries."""
        converted = (PGM._add_cpd(item) for item in self.cpds)
        # _add_cpd yields None for an empty entry; add_cpds must not receive it.
        self.model.add_cpds(*(cpd for cpd in converted if cpd is not None))

    def get_cpd(self, variable):
        """Print the stored CPD description(s) of the variable named ``variable``.

        Prints ``"cpd not found in model"`` when no stored CPD matches.
        """
        # The descriptions live in self.cpds and hold Variable records, so the
        # comparison has to be made on the record's name.
        matches = [cpd for cpd in self.cpds if cpd and cpd.variable.name == variable]
        for cpd in matches:
            print(cpd)
        if not matches:
            print("cpd not found in model")

    def print_nodes(self):
        """Print the node names of the pgmpy network."""
        print("Nodes:")
        print(list(self.model.nodes()))
        print("\n")

    def print_potentials(self):
        """Print every CPD attached to the pgmpy network."""
        print("Potentials:")
        for cpd in self.model.get_cpds():
            print(f"cpd for {cpd.variable}:\n{cpd}\n")
        print("\n")

    def draw_graph(self):
        """Write the graph to ``<name>.png`` with Graphviz ``dot`` and display it."""
        filename = f'{self.name}.png'
        viz = self.model.to_graphviz()
        viz.draw(filename, prog='dot')
        display(Image(filename))


class GUM(ModelABC):
    """Bayesian network backed by pyAgrum's ``BayesNet``."""

    def __init__(
        self,
        *,
        name: str,
        variables: tuple[Variable, ...],
        arcs: tuple[Arc, ...],
        cdps: tuple[CPD, ...],
    ):
        """Store the description, then build the pyAgrum network from it.

        Args:
            name: Model name; also the stem of the PNG written by ``draw_graph``.
            variables: All variables of the network; each becomes a labelized variable.
            arcs: Directed edges between those variables.
            cdps: CPD descriptions, stored as ``self.cpds``.
        """
        super().__init__(name=name, variables=variables, arcs=arcs, cdps=cdps)
        self.model = gum.BayesNet()
        self._add_variables()
        self._add_arcs()
        self._add_cpds()

    def _add_variables(self):
        """Create one labelized variable per ``Variable`` record."""
        for variable in self.variables:
            self.model.add(gum.LabelizedVariable(variable.name, variable.name, variable.states))

    def _add_arcs(self):
        """Add every arc to the network as ``tail -> head``."""
        for arc in self.arcs:
            self.model.addArc(arc.tail.name, arc.head.name)

    def _add_cpd(self, cpd):
        """Fill the CPT of ``cpd.variable`` from ``cpd.table``.

        ``cpd.table`` uses the same layout that ``PGM`` hands to ``TabularCPD``:
        one row per state of the variable and one column per parent
        configuration, the last parent varying fastest. Empty entries are
        ignored.
        """
        if not cpd:
            return

        name = cpd.variable.name
        var_card = len(cpd.variable.states)
        # After transposing, each row is the distribution of the variable for
        # one parent configuration.
        per_config = np.array(cpd.table).T

        if not cpd.parents:
            self.model.cpt(name).fillWith(per_config.reshape((var_card,)).tolist())
            return

        parent_cards = [len(parent.states) for parent in cpd.parents]
        # Parent axes come first and the variable's own axis last, so indexing
        # with the parent indices yields the matching distribution even when
        # the cardinalities differ.
        per_config = per_config.reshape((*parent_cards, var_card))

        for index in product(*map(range, parent_cards)):
            assignment = {
                parent.name: parent.states[state]
                for parent, state in zip(cpd.parents, index)
            }
            self.model.cpt(name)[assignment] = per_config[index].tolist()

    def _add_cpds(self):
        """Fill the CPT of every stored CPD description."""
        for cpd in self.cpds:
            self._add_cpd(cpd)

    def print_nodes(self):
        """Print the node names of the pyAgrum network."""
        print("Nodes:")
        print(list(self.model.names()))
        print("\n")

    def print_potentials(self):
        """Print the CPT of every variable, in the order of ``self.variables``."""
        print("Potentials:")
        for variable in self.variables:
            print(f"cpd for {variable.name}:\n{self.model.cpt(variable.name)}\n")
        print("\n")

    def get_cpd(self, variable):
        """Print the CPT of the variable named ``variable``.

        Prints ``"cpd not found in model"`` when no such variable is stored.
        """
        if variable not in {item.name for item in self.variables}:
            print("cpd not found in model")
            return
        print(self.model.cpt(variable))

    def draw_graph(self):
        """Export the graph to ``<name>.png`` and display it."""
        filename = f'{self.name}.png'
        gumimage.export(self.model, filename)
        display(Image(filename))


In [None]:
# --- Variables: five two-state nodes, all labelled "yes"/"no" -----------------
# NOTE(review): P(Pollution = "yes") is 0.9 below. These numbers look like the
# textbook cancer network, where the 0.9 state is *low* pollution - confirm the
# label means what is intended.
var_pollution = Variable(name="Pollution", states=["yes", "no"])
var_smoker = Variable(name="Smoker", states=["yes", "no"])
var_cancer = Variable(name="Cancer", states=["yes", "no"])
var_xray = Variable(name="Xray", states=["yes", "no"])
var_dyspnoea = Variable(name="Dyspnoea", states=["yes", "no"])

variables = (var_pollution, var_smoker, var_cancer, var_xray, var_dyspnoea)

# --- Structure: Pollution and Smoker point to Cancer, Cancer to both findings -
arcs = (
    Arc(tail=var_pollution, head=var_cancer),
    Arc(tail=var_smoker, head=var_cancer),
    Arc(tail=var_cancer, head=var_xray),
    Arc(tail=var_cancer, head=var_dyspnoea),
)

# --- Parameters ---------------------------------------------------------------
# Each table has one row per state of the variable ("yes" first, then "no").
# Root nodes have no parents and a single column.
cpd_pollution = CPD(var_pollution, None, [[0.9], [0.1]])
cpd_smoker = CPD(var_smoker, None, [[0.3], [0.7]])

# Four columns: one per (Smoker, Pollution) combination.
cpd_cancer = CPD(
    var_cancer,
    [var_smoker, var_pollution],
    [[0.03, 0.05, 0.001, 0.02], [0.97, 0.95, 0.999, 0.98]],
)

# Two columns: one per state of Cancer.
cpd_xray = CPD(var_xray, [var_cancer], [[0.9, 0.2], [0.1, 0.8]])
cpd_dyspnoea = CPD(var_dyspnoea, [var_cancer], [[0.65, 0.3], [0.35, 0.7]])

# The name `cdps` matches the keyword expected by the model constructors.
cdps = (cpd_pollution, cpd_smoker, cpd_cancer, cpd_xray, cpd_dyspnoea)

In [None]:
# Build the pgmpy-backed model from the shared network description.
pgm_model = PGM(name="Cancer_model", variables=variables, arcs=arcs, cdps=cdps)

In [None]:
# Text summary of the pgmpy model: nodes, arcs, variable records, then every CPD.
pgm_model.print_nodes()
pgm_model.print_arcs()
pgm_model.print_variables()
pgm_model.print_potentials()

In [None]:
# Run pgmpy's own check on the assembled network; the result is the cell output.
pgm_model.model.check_model()

In [None]:
# Writes "Cancer_model.png" (named after the model) and displays it inline.
pgm_model.draw_graph()

In [None]:
# Build the pyAgrum-backed model from the same network description.
# NOTE(review): the name is identical to pgm_model's, so both draw_graph()
# calls write to the same "Cancer_model.png" file.
gum_model = GUM(name="Cancer_model", variables=variables, arcs=arcs, cdps=cdps)

In [None]:
# Text summary of the pyAgrum model, in the same order as for the pgmpy model.
gum_model.print_nodes()
gum_model.print_arcs()
gum_model.print_variables()
gum_model.print_potentials()

In [None]:
# Exports the pyAgrum graph to "Cancer_model.png" and displays it inline.
gum_model.draw_graph()

In [None]:
# Create a pyAgrum LazyPropagation engine on the network and run inference.
# No evidence is entered in this cell.
ie = gum.LazyPropagation(gum_model.model)
ie.makeInference()

In [None]:
# Posterior of each finding as computed by the pyAgrum engine created above.
print(ie.posterior("Xray"))
print(ie.posterior("Dyspnoea"))

In [None]:
# Query the pgmpy model with variable elimination. A dedicated name keeps `ie`
# bound to the pyAgrum engine, so the posterior cell above can still be re-run.
pgm_inference = VariableElimination(pgm_model.model)
# query() returns one joint factor over both variables; print() renders its
# probability table, whereas the bare cell output only shows the factor's repr.
print(pgm_inference.query(variables=["Xray", "Dyspnoea"]))