In [5]:
import pandas as pd
import json
from itertools import chain
from functools import reduce
import operator

# Input files: the observations (CSV) and the parents mapping (JSON).
df_path, pa_path = '../data/gmu-covid.csv', '../data/gmu-covid-parents.json'

def expand_data(df_path: str, pa_path: str) -> pd.DataFrame:
    """
    Loads the data and appends one product column per interaction term.

    An interaction term is any name in the parents file that contains ``'!'`` after its
    first character, e.g. ``'a!b!c'``. Its column is the row-wise product of the columns
    obtained by splitting the name on ``'!'`` (``df['a'] * df['b'] * df['c']``). Names are
    collected from both the keys and the value lists of the JSON object.

    :param df_path: Path to the CSV data file.
    :param pa_path: Path to the JSON file; presumably maps each child name to a list of
        parent names (only the keys and the value lists are read here).
    :return: The data frame with the interaction columns appended (a column that already
        exists under an interaction name is overwritten).
    """
    def get_interactions(names):
        # Map each interaction name to its component column names, in sorted name order.
        # `find('!') > 0` (not `'!' in s`) also skips names that *start* with '!'.
        return {s: s.split('!') for s in sorted(set(names)) if s.find('!') > 0}

    df = pd.read_csv(df_path)

    with open(pa_path, 'r') as f:
        parents = json.load(f)

    from_values = get_interactions(chain.from_iterable(parents.values()))
    from_keys = get_interactions(parents.keys())
    interactions = {**from_values, **from_keys}

    for col_name, cols in interactions.items():
        # Vectorized product over just the component columns. Unlike a row-wise
        # `apply(axis=1)`, this keeps integer columns integer (rows of a mixed int/float
        # frame are upcast to float), does not build a full-width row per element, and
        # works on an empty frame. `skipna=False` keeps NaN propagating, as plain
        # multiplication does.
        df[col_name] = df[cols].prod(axis=1, skipna=False)

    return df
    
# Raw data plus one product column per interaction term.
df = expand_data(df_path=df_path, pa_path=pa_path)
df.shape

(461, 175)

In [7]:
import networkx as nx

# `nx.read_gpickle` was removed in NetworkX 3.0; it only wrapped `pickle.load`, so
# loading the file with `pickle` directly works with every NetworkX version.
# NOTE: unpickling can execute arbitrary code — only load files from a trusted source.
import pickle

with open('../data/gmu-covid-networkx.gpickle', 'rb') as f:
    g = pickle.load(f)

In [None]:
from typing import Tuple, Dict, List, Any
from itertools import chain, combinations
import numpy as np

def get_parameters(df: pd.DataFrame, g: 'nx.DiGraph') -> Tuple[Dict[str, List[str]], Dict[str, List[float]]]:
    """
    Gets the parameters (domains and probability tables) by counting rows.

    Every column is compared as a string (``astype(str)``), so the domain of a node is the
    sorted list of the distinct string values in its column (lexicographic order, e.g.
    ``'10' < '2'``).

    For a node ``ch`` whose predecessors in ``g`` are ``pa_1 .. pa_k``, its table is a flat
    list aligned with ``itertools.product(domain(pa_1), ..., domain(pa_k), domain(ch))``,
    i.e. the child's value varies fastest. Each consecutive block of ``len(domain(ch))``
    entries is the empirical distribution of ``ch`` for one parent configuration. A
    configuration that never occurs in the data gets a uniform distribution. A node with
    no parents gets a single block (its marginal distribution).

    :param df: Data; must have a column for every node of ``g``.
    :param g: Graph (structure); only ``nodes()`` and ``predecessors()`` are used.
    :return: Tuple; first item is dictionary of domains; second item is dictionary of probabilities.
    """
    from collections import Counter
    from itertools import chain, product

    def normalize(counts):
        # Empirical distribution of one block of counts; uniform if the block is all zeros.
        arr = np.array(counts)
        total = np.sum(arr)
        if total == 0:
            return [1 / len(counts) for _ in range(len(counts))]
        return list(arr / total)

    def table(ch):
        # Flat probability table of `ch`; layout described in the docstring.
        cols = parents[ch] + [ch]
        block = len(domains[ch])
        if block == 0:
            # Only possible when `df` has no rows: there is nothing to estimate.
            return []
        # One pass over the rows counts every observed joint assignment. This replaces a
        # `DataFrame.query` per assignment and needs no quoting of column names
        # (e.g. interaction names containing '!') or of values.
        observed = Counter(zip(*(ddf[c] for c in cols)))
        # `product` enumerates exactly one value per variable, parents first, child last.
        joint = [observed[combo] for combo in product(*(domains[c] for c in cols))]
        blocks = (normalize(joint[i:i + block]) for i in range(0, len(joint), block))
        return list(chain.from_iterable(blocks))

    ddf = df.astype(str)
    nodes = list(g.nodes())

    domains = {n: sorted(ddf[n].unique()) for n in nodes}
    parents = {ch: list(g.predecessors(ch)) for ch in nodes}

    p = {ch: table(ch) for ch in nodes}
    return domains, p

# Estimate node domains and probability tables from the expanded data and the graph.
domains, p = get_parameters(df=df, g=g)

In [None]:
# Domain of each node: the sorted distinct values of its column, as strings.
domains

In [None]:
# Flat probability table per node: one block of len(domains[node]) probabilities per
# parent configuration, with the node's own value varying fastest.
p