In [2]:
from typing import Tuple, Dict, List
import json
from itertools import chain
import pandas as pd
from sklearn.linear_model import LogisticRegression


def get_model_params(df: pd.DataFrame, ordering: List[List[str]], solver='liblinear', penalty='l1', C=0.2) -> pd.DataFrame:
    """
    Gets penalized (LASSO by default) logistic regression parameters for each variable.

    Every column in ``ordering[i]`` (for ``i >= 1``) is regressed on all columns
    of the preceding groups ``ordering[0] .. ordering[i-1]``. Columns in the
    first group are only ever used as predictors.

    :param df: Data. Must contain every column named in ``ordering``.
    :param ordering: Groups of column names; each group's columns are fit on the
        columns of all earlier groups.
    :param solver: Solver (liblinear or saga). Default: `liblinear`.
    :param penalty: Penalty. Default: `l1`.
    :param C: Regularization parameter passed to ``LogisticRegression``
        (inverse regularization strength). Default: `0.2`.
    :return: One row per fitted target with columns ``child`` (target name),
        ``intercept`` and one column per column of ``df``. Columns that were not
        predictors for that target are ``0.0``. If no target is fitted, an empty
        frame with those columns is returned.
    """

    def get_model(df, X_cols, y_col):
        X = df[X_cols]
        y = df[y_col]
        print(f'{y_col} = {X_cols}')

        model = LogisticRegression(penalty=penalty, solver=solver, C=C)
        model.fit(X, y)

        return model

    def extract_model_params(y, X_cols, fields, model):
        # coef_[0] is ordered like the columns of X passed to fit(), i.e. X_cols.
        # Pair by name so the result is correct regardless of df's column order.
        coefs = dict(zip(X_cols, model.coef_[0]))
        p = {'child': y, 'intercept': model.intercept_[0]}
        p.update({field: coefs.get(field, 0.0) for field in fields})
        return p

    models = []
    # The first group has no predecessors, so it is never a regression target.
    for i, y_cols in enumerate(ordering[1:], start=1):
        X_cols = list(chain(*ordering[:i]))
        for y_col in y_cols:
            models.append((y_col, X_cols, get_model(df, X_cols, y_col)))

    rows = [extract_model_params(y, X_cols, df.columns, model) for y, X_cols, model in models]
    if not rows:
        # Keep the output schema stable even when nothing was fitted.
        return pd.DataFrame(columns=['child', 'intercept', *df.columns])

    param_df = pd.DataFrame(rows)
    return param_df

# Load the binary dataset and the metadata whose 'ordering' entry gives the
# column groups consumed by get_model_params.
df = pd.read_csv('../data/data-binary.csv')

# JSON text is UTF-8; do not depend on the platform's default encoding.
with open('../data/data-binary-complete.json', encoding='utf-8') as f:
    meta = json.load(f)
ordering = meta['ordering']

# Left as the cell's final expression so the notebook renders the table.
get_model_params(df, ordering)

b = ['a']
c = ['a', 'b']
d = ['a', 'b', 'c']
e = ['a', 'b', 'c', 'd']


Unnamed: 0,child,intercept,a,b,c,d,e
0,b,-1.401362,0.0,0.0,0.0,0.0,0.0
1,c,-1.402015,0.009982,2.667356,0.0,0.0,0.0
2,d,-1.256663,-0.099697,-0.085201,0.0,0.0,0.0
3,e,-2.245986,0.0,0.02895,0.0,4.337833,0.0
