In [33]:
import csv
from typing import Any, Optional, Set, Dict
from dataclasses import dataclass, field
from pgmpy.readwrite import BIFReader
from numpy import array, ndarray
from collections import defaultdict

# Parameters: run configuration read by the classes below.
GROUP_ID: int = 29
# Inference algorithm: 've' = Variable Elimination, 'gibbs' = Gibbs Sampling
ALGORITHM: str = 've'
# Path of the BIF file that defines the Bayesian network
NETWORK_NAME: str = './networks/child.bif'
# Variable to report on (e.g. 'Disease' for the Child network)
REPORT: str = 'Disease'
# Evidence level for Child/Insurance: None, 'Little' or 'Moderate'
EVIDENCE_LEVEL: Optional[str] = 'Little'
# Observations written as 'Variable=value' entries separated by '; '.
# NOTE(review): '>=12' is not one of RUQO2's states as printed by the loaded
# network ('<5', '5-12', '12+'); confirm the intended value before matching.
EVIDENCE: str = 'LowerBodyO2=<5; RUQO2=>=12; CO2Report=>=7.5; XrayReport=Asy/Patchy'


@dataclass(frozen=True)
class Node:
    """A discrete variable of a Bayesian network.

    Identity is by ``name`` alone: ``__eq__`` and ``__hash__`` ignore every
    other field, so two Node objects with the same name are interchangeable
    as set members and dict keys.
    """

    # Variable name; the only field used for equality and hashing.
    name: str
    # Parent variable names, in the order supplied by the caller.
    parents: tuple
    # The variable's possible states.
    values: tuple
    # Probability table for this variable (presumably its CPD array; confirm
    # with the loader). Excluded from comparison and repr.
    probability_model: Optional[ndarray] = field(default=None, compare=False, repr=False)
    # Current state index; -1 is the default (presumably "unassigned").
    state: int = -1

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if isinstance(other, Node):
            return self.name == other.name
        # Let Python try the reflected comparison / fall back to identity
        # instead of claiming inequality outright.
        return NotImplemented

    def __str__(self):
        # map(str, ...) so non-string state labels do not break join().
        return f'{self.name}: [{", ".join(map(str, self.values))}]'


class Network:
    """Directed graph of ``Node`` objects.

    Edges are stored in both directions (``parents``: child -> parents,
    ``children``: parent -> children). Nodes are also indexed by name in
    ``by_name`` so edges can be added by variable name.
    """

    def __init__(self, nodes: "Set[Node] | None" = None):
        """Create a network, optionally seeded with ``nodes``.

        Seed nodes go through ``add_node`` so they are also indexed in
        ``by_name`` and can therefore be used by ``add_edge``. The caller's
        set is copied, not aliased, so later ``add_node`` calls do not
        mutate it.
        """
        self.nodes: Set[Node] = set()
        self.parents: Dict[Node, Set[Node]] = defaultdict(set)   # child -> {parents}
        self.children: Dict[Node, Set[Node]] = defaultdict(set)  # parent -> {children}
        self.by_name: Dict[str, Node] = {}
        for node in nodes or ():
            self.add_node(node)

    def add_node(self, node: "Node"):
        """Register ``node``; a node whose name is already known is ignored."""
        if node.name not in self.by_name:
            self.by_name[node.name] = node
            self.nodes.add(node)

    def add_edge(self, parent_name: str, child_name: str):
        """Add the directed edge ``parent_name -> child_name``.

        Raises:
            KeyError: if either name has not been registered via ``add_node``.
        """
        parent = self.by_name[parent_name]
        child = self.by_name[child_name]
        self.parents[child].add(parent)
        self.children[parent].add(child)

    def markov_blanket(self, node: "Node") -> "Set[Node]":
        """Return the Markov blanket of ``node`` as a new set.

        The blanket is the node's parents, its children, and its children's
        other parents; ``node`` itself is never included. The graph is left
        unchanged: stored adjacency sets are copied, not updated, and lookups
        use ``.get`` so no empty ``defaultdict`` entries are created.
        """
        # Copy so the update() calls below do not modify self.parents[node].
        blanket: Set[Node] = set(self.parents.get(node, ()))
        children = self.children.get(node, ())
        blanket.update(children)
        for child in children:
            # Co-parents of each child; includes ``node`` itself, removed below.
            blanket.update(self.parents.get(child, ()))
        blanket.discard(node)
        return blanket

    def __str__(self):
        return "\n".join(str(n) for n in sorted(self.nodes, key=lambda n: n.name))


class InputReader:
    """Load a BIF file with pgmpy and expose it as ``self.network`` (a ``Network``)."""

    def __init__(self, network_name: Optional[str] = None):
        """Parse the BIF file and build the Network.

        Args:
            network_name: path to the BIF file. Defaults to the module-level
                ``NETWORK_NAME``, so ``InputReader()`` behaves as before.
        """
        reader = BIFReader(network_name or NETWORK_NAME)
        model = reader.get_model()
        states = reader.get_states()
        net = Network()

        # Look up each CPD once; both passes below need it.
        cpds = {variable: model.get_cpds(variable) for variable in model.nodes()}

        # 1) Register every node first so add_edge can resolve any name in pass 2.
        for variable, cpd in cpds.items():
            # pgmpy's TabularCPD keeps [variable, *evidence] in ``variables``, in
            # the same order as the axes of ``values``. ``get_evidence()`` returns
            # that evidence list reversed (confirm for the installed pgmpy version).
            # Using variables[1:] makes parents[i] label axis i + 1 of
            # probability_model.
            evidence = tuple(cpd.variables[1:])
            net.add_node(Node(
                name=str(variable),
                parents=evidence,
                values=tuple(states[variable]),
                probability_model=cpd.values,
            ))

        # 2) Connect each parent to its child.
        for variable, cpd in cpds.items():
            for parent in cpd.variables[1:]:
                net.add_edge(str(parent), str(variable))

        self.network = net

class VESolver:
    """Variable Elimination inference (placeholder)."""

    def solve(self, network: Network):
        """Run variable elimination on ``network``.

        TODO: not implemented yet; returns None without doing any work.
        """
        return None

class GibbsSolver:
    """Gibbs Sampling inference (placeholder)."""

    def solve(self, network: Network):
        """Run Gibbs sampling on ``network``.

        TODO: not implemented yet; returns None without doing any work.
        """
        return None

class OutputWriter:
    """Create the run's result CSV, named after the module-level parameters."""

    def __init__(self, solver):
        """Create (or truncate) the output CSV for this run.

        Args:
            solver: solver whose results should be written; it is stored on
                ``self.solver`` but nothing is written from it yet.
        """
        self.solver = solver
        self.filename = self._output_filename(GROUP_ID, ALGORITHM, NETWORK_NAME, EVIDENCE_LEVEL)
        with open(self.filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            # TODO: write the solver's results with `writer`.

    @staticmethod
    def _output_filename(group_id, algorithm, network_name, evidence_level) -> str:
        """Return '<group>_<algorithm>_<network>_<evidence>.csv'.

        Only the file stem of ``network_name`` is used. Interpolating the full
        path (e.g. './networks/child.bif') would put '/' into the file name,
        and open() would then look for a non-existent directory.
        """
        from pathlib import Path
        return f"{group_id}_{algorithm}_{Path(network_name).stem}_{evidence_level}.csv"

class Driver:
    """Run one inference job configured by the module-level parameters.

    Construction does all the work: pick the solver for ALGORITHM, load the
    network via InputReader, and call the solver's solve() on it.
    """

    def __init__(self):
        """Select the solver for ALGORITHM (case-insensitive) and run it.

        Raises:
            NotImplementedError: if ALGORITHM names no known solver. This is
                checked before the network file is read, so a typo fails fast.
        """
        solvers = {"ve": VESolver, "gibbs": GibbsSolver}
        solver_cls = solvers.get(ALGORITHM.lower())
        if solver_cls is None:
            raise NotImplementedError(
                f"Unsupported ALGORITHM {ALGORITHM!r}; expected one of: {', '.join(solvers)}"
            )
        self.solver = solver_cls()
        self.reader = InputReader()
        self.solver.solve(self.reader.network)
        # OutputWriter(self.solver)  # disabled: the solvers' solve() methods are still stubs

# Run the whole pipeline when executed directly rather than imported.
if __name__ == "__main__":
    Driver()


LVHreport: [yes, no]
LowerBodyO2: [<5, 5-12, 12+]
LVH: [yes, no]
Sick: [yes, no]
CO2: [Normal, Low, High]
XrayReport: [Normal, Oligaemic, Plethoric, Grd_Glass, Asy/Patchy]
RUQO2: [<5, 5-12, 12+]
LungFlow: [Normal, Low, High]
BirthAsphyxia: [yes, no]
Disease: [PFC, TGA, Fallot, PAIVS, TAPVD, Lung]
ChestXray: [Normal, Oligaemic, Plethoric, Grd_Glass, Asy/Patch]
CO2Report: [<7.5, >=7.5]
CardiacMixing: [None, Mild, Complete, Transp.]
Grunting: [yes, no]
HypDistrib: [Equal, Unequal]
DuctFlow: [Lt_to_Rt, None, Rt_to_Lt]
Age: [0-3_days, 4-10_days, 11-30_days]
HypoxiaInO2: [Mild, Moderate, Severe]
LungParench: [Normal, Congested, Abnormal]
GruntingReport: [yes, no]
