In [33]:
import csv
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

from numpy import array, ndarray
from pgmpy.readwrite import BIFReader

# Parameters
GROUP_ID = 29
ALGORITHM = 've' # ’ve’ = Variable Elimination, ’gibbs’ = Gibbs Sampling
NETWORK_NAME = './networks/child.bif'
REPORT = 'Disease' # e.g., Child: ’Disease’
EVIDENCE_LEVEL = 'Little' # {None | Little | Moderate} for Child/Insurance
EVIDENCE = 'LowerBodyO2=<5; RUQO2=>=12; CO2Report=>=7.5; XrayReport=Asy/Patchy'


@dataclass (frozen=True)
class Node:
    name: str
    parents: tuple
    values: tuple
    probability_model: array = None
    state: int = -1

    def __hash__(self):
        return hash(self.name + str(self.parents))

    def __str__(self):
        return f'{self.name}: [{", ".join(self.values)}]'


class Network:
    def __init__(self, nodes: set = None):
        self.nodes = nodes or set()
        self.connections = defaultdict(set)

    def __str__(self):
        return f'{"\n".join([str(node) for node in self.nodes])}'

    def markov_blanket(self, node: Node):
        blanket = self.connections[node]
        for connection in self.connections[node]:
            if connection not in node.parents: # if node is a child node
                blanket.update(connection.parents) # add the child's parents
        return blanket

class InputReader:
    """Read a BIF file with pgmpy and convert it into a Network of Nodes.

    The resulting Network is printed and stored on ``self.network``.
    """

    def __init__(self, network_name: str | None = None):
        """Load the network.

        Args:
            network_name: Path of the BIF file; defaults to NETWORK_NAME.
        """
        reader = BIFReader(network_name or NETWORK_NAME)
        model = reader.get_model()
        states = reader.get_states()

        # Build exactly one Node per variable first, so variables without any
        # edge are still included and every edge reuses the same Node objects.
        nodes_by_name = {}
        for variable in model.nodes():
            cpd = model.get_cpds(variable)
            nodes_by_name[variable] = Node(
                name=str(variable),
                # NOTE(review): pgmpy's TabularCPD.get_evidence() may list
                # evidence in the reverse of the axis order of cpd.values --
                # confirm before indexing probability_model by parents.
                parents=tuple(cpd.get_evidence()),
                probability_model=cpd.values,
                values=tuple(states[variable]),
            )

        network = Network(set(nodes_by_name.values()))
        # Record each directed edge in both directions (undirected adjacency).
        for tail, head in model.edges():
            network.connections[nodes_by_name[tail]].add(nodes_by_name[head])
            network.connections[nodes_by_name[head]].add(nodes_by_name[tail])

        print(network)
        self.network = network

class VESolver:
    """Exact inference by Variable Elimination (placeholder)."""

    def solve(self, network: Network):
        """Solve *network* with variable elimination.

        Not implemented yet: the call does nothing and returns None.
        """
        # TODO: implement variable elimination.
        return None

class GibbsSolver:
    """Approximate inference by Gibbs Sampling (placeholder)."""

    def solve(self, network: Network):
        """Solve *network* with Gibbs sampling.

        Not implemented yet: the call does nothing and returns None.
        """
        # TODO: implement Gibbs sampling.
        return None

class OutputWriter:
    """Create the results CSV for the current run.

    The file is written to the working directory and named
    ``{GROUP_ID}_{ALGORITHM}_{network}_{EVIDENCE_LEVEL}.csv``, where
    ``network`` is the BIF file name without directory or extension
    (e.g. ``child`` for ``./networks/child.bif``). The path is kept on
    ``self.path``.
    """

    def __init__(self):
        # Using NETWORK_NAME verbatim embedded "./networks/" and ".bif" in the
        # file name, turning it into a path under a non-existent directory
        # (e.g. "29_ve_./networks/child.bif_Little.csv") and failing on open.
        network = Path(NETWORK_NAME).stem
        self.path = Path(f"{GROUP_ID}_{ALGORITHM}_{network}_{EVIDENCE_LEVEL}.csv")
        with open(self.path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            # TODO: emit the solver's results with writer.writerow(...).

class Driver:
    """Select the solver, load the network, and run inference on it."""

    def __init__(self, algorithm: str | None = None):
        """Build the solver and run it on the network read by InputReader.

        Args:
            algorithm: 've' or 'gibbs' (case-insensitive); defaults to the
                module-level ALGORITHM setting.

        Raises:
            NotImplementedError: if the algorithm name is not recognised.
        """
        solvers = {"ve": VESolver, "gibbs": GibbsSolver}
        name = (algorithm or ALGORITHM).lower()
        try:
            solver_cls = solvers[name]
        except KeyError:
            # Validate before reading the network so a typo fails fast.
            raise NotImplementedError(
                f"Unsupported algorithm {name!r}; expected one of {sorted(solvers)}"
            ) from None
        self.solver = solver_cls()
        self.reader = InputReader()
        self.solver.solve(self.reader.network)
        # TODO: write results once OutputWriter accepts the solver.
        # OutputWriter(self.solver)

# Script entry point: constructing Driver loads the network and runs the
# configured solver.
if __name__ == "__main__":
    Driver()


LVHreport: [yes, no]
LowerBodyO2: [<5, 5-12, 12+]
LVH: [yes, no]
Sick: [yes, no]
CO2: [Normal, Low, High]
XrayReport: [Normal, Oligaemic, Plethoric, Grd_Glass, Asy/Patchy]
RUQO2: [<5, 5-12, 12+]
LungFlow: [Normal, Low, High]
BirthAsphyxia: [yes, no]
Disease: [PFC, TGA, Fallot, PAIVS, TAPVD, Lung]
ChestXray: [Normal, Oligaemic, Plethoric, Grd_Glass, Asy/Patch]
CO2Report: [<7.5, >=7.5]
CardiacMixing: [None, Mild, Complete, Transp.]
Grunting: [yes, no]
HypDistrib: [Equal, Unequal]
DuctFlow: [Lt_to_Rt, None, Rt_to_Lt]
Age: [0-3_days, 4-10_days, 11-30_days]
HypoxiaInO2: [Mild, Moderate, Severe]
LungParench: [Normal, Congested, Abnormal]
GruntingReport: [yes, no]
