In [35]:
import dgp
import polars as pl
import polars.selectors as cs
import numpy as np
import statsmodels.api as sm
from statsmodels.genmod.generalized_linear_model import GLMResults
from statsmodels.genmod.families import family

In [36]:
node1 = dgp.GenericNode('Z', node_restrictions=[dgp.Node])
node2 = dgp.GenericNode('X', parents=[node1], node_restrictions=[dgp.Node])
node3 = dgp.GenericNode('Y', parents=[node1], node_restrictions=[dgp.CategoricalNode], min_categories=3)
nc924 = dgp.NodeCollection('L-M Con. Indep.', [node1, node2, node3])

In [46]:
def _compute_categorical(_data, y_label, X_labels, betas):
    
    
    data = _data.to_dummies(cs.string(), separator='__cat__')
        
    results = {}
    
    #print(f'REGRESSING: {y_label} ~ {X_labels}')
    
    X = data.to_pandas()[sorted(X_labels)]
    X['__const'] = 1
    X = X.to_numpy().astype(float)
        
    models = {}
    for cat in betas.keys():
        y = data.to_pandas()[cat]
        y = y.to_numpy().astype(float)
        
        glm_model = sm.GLM(y, X, family=family.Binomial())
        glm_results = GLMResults(glm_model, betas[cat], normalized_cov_params=None, scale=None)
        models[cat] = glm_results
        
    # Get weights for each sample, depending on its class -> combat dataset imbalance
    # class_weights = {}
    # for r in self.data.group_by(y_label).agg(pl.len()).rows():
    #     class_weights[r[0]] = len(self.data)/((len(betas)+1)*r[1])  
    # sample_weights = self.data.with_columns(__sample_weight=pl.col(y_label).replace_strict(class_weights, return_dtype=pl.Float64))['__sample_weight'].to_numpy()
        
    etas = {c:np.clip(m.predict(which='linear'), -709, 305) for c,m in models.items()}
    denom = 1 + sum(np.exp(eta) for eta in etas.values())

    mus = {c:np.clip(np.exp(eta)/denom,1e-15,1-1e-15) for c,eta in etas.items()}
    dmu_deta = {c:mu*(1-mu) for c,mu in mus.items()}
    
    
    for cat in dmu_deta.keys():
        y = data.to_pandas()[cat]
        y = y.to_numpy().astype(float)
        
        z = etas[cat] + (y - mus[cat])/dmu_deta[cat]
        
        W = np.diag((dmu_deta[cat]**2)/max(np.var(mus[cat]), 1e-15))
        
        xw = X.T @ W
        xwx = xw @ X
        xwz = xw @ z
        
        results[cat] = (xwx, xwz)
    
    # LLF
    def get_cat_index(data, y_label, cat):
        cat_val = cat.split('__cat__')[-1]
        return data.with_row_index().filter(pl.col(y_label) == cat_val)['index'].to_list()
    
    def get_ref_cat_index(data, y_label, cats):
        cat_vals = [cat.split('__cat__')[-1] for cat in cats]
        return data.with_row_index().filter(~pl.col(y_label).is_in(cat_vals))['index'].to_list()
    
    cat_indexes = {cat: get_cat_index(_data, y_label, cat) for cat in mus.keys()}
    
    llf = 0 
    for cat in cat_indexes.keys():
        llf += np.sum(np.log(np.take(mus[cat], cat_indexes[cat])))
    llf += np.sum(np.log(np.take(1/denom, get_ref_cat_index(_data, y_label, cat_indexes.keys()))))
    
    # DEVIANCE + LLF SAT
    llf_sat = 0
    for cat in cat_indexes.keys():
        y = data.to_pandas()[cat]
        y = y.to_numpy().astype(float)
        
        # Only add log for y == 1, since log(0) should be excluded
        llf_sat += np.sum(y * np.log(np.clip(y, 1e-10, None)))  # Clip to avoid log(0) issues
        
    def get_ref_cat_mask(data, y_label, cats):
        cat_vals = [cat.split('__cat__')[-1] for cat in cats]
        return data.with_columns(__mask=~pl.col(y_label).is_in(cat_vals))['__mask'].cast(pl.Float64).to_numpy()
        
    # Handle reference category similarly
    #y_ref = data.to_pandas()[reference_category]
    #y_ref = y_ref.to_numpy().astype(float)
    y_ref = get_ref_cat_mask(_data, y_label, cat_indexes.keys())
    llf_sat += np.sum(y_ref * np.log(np.clip(y_ref, 1e-10, None)))
    
    deviance = 2 * (llf_sat - llf)
        
    return results, llf, deviance

In [173]:
def custom_multionomial_regression(data, y_label, X_labels, debug=True):
    last_deviance = None
    deviance = 0

    categories = [f'{y_label}__cat__{c}' for c in sorted(data[y_label].unique().to_list())]
    betas = {c:np.zeros(len(X_labels)+1) for c in categories[1:]}

    counter = 1

    # CALL FUNC
    while True:
        if debug and counter == 1 or counter % 10 == 0:
            print(f'Running iteration: {counter}')
        last_deviance = deviance
        results, llf, deviance =_compute_categorical(data, y_label, X_labels, betas)

        xwx= {c:results[c][0] for c in betas.keys()}
        xwz = {c:results[c][1] for c in betas.keys()}

        xwx_inv = {}
        for c,xwx in xwx.items():
            try:
                xwx_inv[c] = np.linalg.inv(xwx)
            except np.linalg.LinAlgError:
                xwx_inv[c] = np.linalg.pinv(xwx)

        new_betas = {cat:xwx_inv[cat] @ xwz[cat] for cat in betas.keys()}
        betas = new_betas

        # CHECK DEVIANCE
        if abs(deviance - last_deviance) / (0.1 + abs(deviance)) < 1e-8:
            break
        counter += 1

    if debug:
        print('')
        print(f'Converged after {counter} steps')
        print(f'LLF: {llf}')
        print(f'beta: {betas}')
    
    return llf, betas

In [154]:
def multinomial_logistic_regression(data, y_label, X_labels):
    # Add a constant (intercept) to the independent variables
    _data = data.with_columns(__const=pl.lit(1)).to_pandas()
    X = _data[X_labels + ['__const']]
    # The dependent variable (must be categorical)
    y = _data[y_label]
    
    # Fit the multinomial logistic regression model
    model = sm.MNLogit(y, X)
    result = model.fit()
    
    # Get the log-likelihood function (LLF) and coefficients
    llf = result.llf
    coefficients = result.params
    
    return llf, coefficients

In [188]:
nc924.reset()
data = nc924.get(10000)

y_label = 'Y'
X_labels = ['X']

print(data.head())

shape: (5, 3)
┌─────┬───────────┬───────────┐
│ Y   ┆ X         ┆ Z         │
│ --- ┆ ---       ┆ ---       │
│ str ┆ f64       ┆ f64       │
╞═════╪═══════════╪═══════════╡
│ 2   ┆ 0.56404   ┆ 0.54228   │
│ 4   ┆ -0.63956  ┆ 0.530321  │
│ 3   ┆ 0.445197  ┆ -0.06389  │
│ 1   ┆ -0.290108 ┆ 0.856704  │
│ 4   ┆ 0.087565  ┆ -0.725127 │
└─────┴───────────┴───────────┘


In [183]:

supersample_cat = '1'
data_a = data.filter(pl.col(y_label) == supersample_cat)
data_b = data.filter(pl.col(y_label) != supersample_cat)

data_b = data_b.sample(len(data_a))
data = pl.concat([data_a, data_b])


In [189]:
llf, coeff = custom_multionomial_regression(data, y_label, X_labels)

Running iteration: 1
Running iteration: 10

Converged after 15 steps
LLF: -13781.206100806756
beta: {'Y__cat__2': array([ 0.08472371, -0.11894132]), 'Y__cat__3': array([ 0.13553075, -0.32229342]), 'Y__cat__4': array([ 0.20806508, -0.11688616])}


In [190]:
llf_mn, coeff_mn = multinomial_logistic_regression(data, y_label, X_labels)

Optimization terminated successfully.
         Current function value: 1.378121
         Iterations 4


In [191]:
min_cat = int(list(coeff.keys())[0].split('__cat__')[-1])
_coeff = {str(int(c.split('__cat__')[-1])-min_cat):l.tolist() for c,l in coeff.items()}
coeff_df = pl.from_dict(_coeff).with_row_index()
coeff_df

index,0,1,2
u32,f64,f64,f64
0,0.084724,0.135531,0.208065
1,-0.118941,-0.322293,-0.116886


In [192]:
coeff_mn_df = pl.from_pandas(coeff_mn).with_row_index()
coeff_mn_df

index,0,1,2
u32,f64,f64,f64
0,0.084426,0.135578,0.208236
1,-0.118977,-0.322436,-0.117199
